80 lines
1.8 KiB
Python
80 lines
1.8 KiB
Python
|
from sympy.core import Basic, Expr
|
||
|
from sympy.core.numbers import oo
|
||
|
from sympy.core.symbol import symbols
|
||
|
from sympy.multipledispatch import Dispatcher
|
||
|
from sympy.sets.setexpr import set_mul
|
||
|
from sympy.sets.sets import Interval, Set
|
||
|
|
||
|
|
||
|
# ``_x`` and ``_y`` are not used by the handlers defined in this module.
# NOTE(review): presumably kept for code that imports them — verify before
# removing.
_x, _y = symbols("x y")

# Multiple-dispatch registries for elementwise set multiplication and
# division; concrete handlers are registered below per pair of argument types.
_set_mul, _set_div = Dispatcher('_set_mul'), Dispatcher('_set_div')
|
||
|
|
||
|
|
||
|
@_set_mul.register(Basic, Basic)
def _(x, y):
    """Fallback for two generic SymPy objects: no product, so return None."""
|
||
|
|
||
|
@_set_mul.register(Set, Set)
def _(x, y):
    """Fallback for two arbitrary sets: no product rule known, return None."""
|
||
|
|
||
|
@_set_mul.register(Expr, Expr)
def _(x, y):
    """Multiply two plain expressions with ordinary SymPy multiplication."""
    product = x*y
    return product
|
||
|
|
||
|
def _interval_corner_product(a, a_open, b, b_open):
    """Return ``(value, is_open)`` for the product of two interval endpoints.

    ``a`` and ``b`` are endpoints of the two factor intervals; ``a_open`` and
    ``b_open`` tell whether each endpoint is excluded from its interval.

    A zero endpoint forces the corner value to zero: SymPy evaluates
    ``0*oo`` to ``nan``, but every element of the product set formed with
    that zero endpoint equals zero.  For the same reason, a *closed* zero
    endpoint makes the corner attained (closed) even when the other endpoint
    is open, e.g. ``[0, 1]*(2, 3)`` contains ``0 = 0*2.5``.
    """
    if a.is_zero or b.is_zero:
        value = a if a.is_zero else b
    else:
        value = a*b
    is_open = bool(a_open or b_open)
    if (a.is_zero and not a_open) or (b.is_zero and not b_open):
        is_open = False
    return value, is_open


@_set_mul.register(Interval, Interval)
def _(x, y):
    """
    Multiplications in interval arithmetic
    https://en.wikipedia.org/wiki/Interval_arithmetic

    The product of two intervals is an interval bounded by the smallest and
    largest of the four endpoint ("corner") products.  A bound is open only
    if *every* corner attaining that value is open; if any attaining corner
    is closed, the bound belongs to the product set.
    """
    corners = (
        _interval_corner_product(x.start, x.left_open, y.start, y.left_open),
        _interval_corner_product(x.start, x.left_open, y.end, y.right_open),
        _interval_corner_product(x.end, x.right_open, y.start, y.left_open),
        _interval_corner_product(x.end, x.right_open, y.end, y.right_open),
    )
    values = [corner[0] for corner in corners]
    # TODO: handle symbolic intervals
    minval = min(values)
    maxval = max(values)
    # Resolve ties explicitly: comparing (value, open) tuples with max() would
    # prefer an open corner over a closed one with the same value and wrongly
    # exclude an attained upper bound (e.g. [-1, 1]*[-1, 1) must contain 1).
    minopen = all(is_open for value, is_open in corners if value == minval)
    maxopen = all(is_open for value, is_open in corners if value == maxval)
    return Interval(minval, maxval, minopen, maxopen)
|
||
|
|
||
|
@_set_div.register(Basic, Basic)
def _(x, y):
    """Fallback for two generic SymPy objects: no quotient, so return None."""
|
||
|
|
||
|
@_set_div.register(Expr, Expr)
def _(x, y):
    """Divide two plain expressions with ordinary SymPy division."""
    quotient = x/y
    return quotient
|
||
|
|
||
|
@_set_div.register(Set, Set)
def _(x, y):
    """Fallback for two arbitrary sets: no quotient rule known, return None."""
|
||
|
|
||
|
@_set_div.register(Interval, Interval)
def _(x, y):
    """Divide interval ``x`` by interval ``y`` computed as ``x * (1/y)``.

    If the product of ``y``'s endpoints is known to be negative (``y`` has
    points on both sides of zero), the whole real line is returned.  This is
    a superset of the true quotient, which can be a union of two rays
    (e.g. ``[1, 2]/[-1, 1]``).  Otherwise the reciprocal of ``y`` is built
    with its endpoints swapped, a zero endpoint mapping to the matching
    infinity, and the product is delegated to ``set_mul``.

    See https://en.wikipedia.org/wiki/Interval_arithmetic
    """
    low, high = y.start, y.end
    if (low*high).is_negative:
        return Interval(-oo, oo)
    # 1/t is decreasing on each side of zero, so the endpoints trade places
    # and their open/closed flags trade places with them.
    recip_low = -oo if high == 0 else 1/high
    recip_high = oo if low == 0 else 1/low
    reciprocal = Interval(recip_low, recip_high, y.right_open, y.left_open)
    return set_mul(x, reciprocal)
|