"""Dispatch handlers for multiplying and dividing sets (interval arithmetic)."""

from sympy.core import Basic, Expr
from sympy.core.numbers import oo
from sympy.core.symbol import symbols
from sympy.multipledispatch import Dispatcher
from sympy.sets.setexpr import set_mul
from sympy.sets.sets import Interval, Set


_x, _y = symbols("x y")


_set_mul = Dispatcher('_set_mul')
_set_div = Dispatcher('_set_div')


@_set_mul.register(Basic, Basic)
def _(x, y):
    """Fallback for arbitrary ``Basic`` operands: no product rule, return None."""
    return None


@_set_mul.register(Set, Set)
def _(x, y):
    """Fallback for two generic ``Set`` operands: no product rule, return None."""
    return None


@_set_mul.register(Expr, Expr)
def _(x, y):
    """Multiply two plain expressions with the ordinary ``*`` operator."""
    return x*y


@_set_mul.register(Interval, Interval)
def _(x, y):
    """
    Multiplications in interval arithmetic
    https://en.wikipedia.org/wiki/Interval_arithmetic

    The four products of the endpoints of ``x`` and ``y`` are formed; the
    smallest and largest of them become the bounds of the returned
    ``Interval``.  A corner product is flagged open when either of the two
    endpoints that produced it is open.  When several corners tie for a
    bound, the bound is closed if any of the tied corners is closed, since
    that corner actually attains the value.
    """
    # TODO: some intervals containing 0 and oo will fail as 0*oo returns nan.
    # TODO: a closed zero endpoint multiplied by an open endpoint is flagged
    # open below although the product 0 is attained (e.g. [0, 1] * (2, 3)).
    comvals = (
        (x.start * y.start, bool(x.left_open or y.left_open)),
        (x.start * y.end, bool(x.left_open or y.right_open)),
        (x.end * y.start, bool(x.right_open or y.left_open)),
        (x.end * y.end, bool(x.right_open or y.right_open)),
    )
    # TODO: handle symbolic intervals
    # Tuple ordering puts (v, False) before (v, True), so ``min`` already
    # prefers the closed corner when values tie.
    minval, minopen = min(comvals)
    # A plain ``max`` would prefer the *open* corner on ties and wrongly
    # drop an attained upper bound, e.g. [-1, 1] * [-1, 1) must contain
    # (-1)*(-1) = 1.  Ranking by (value, closed) makes the closed corner win.
    maxval, maxopen = max(
        comvals, key=lambda corner: (corner[0], not corner[1])
    )
    return Interval(minval, maxval, minopen, maxopen)


@_set_div.register(Basic, Basic)
def _(x, y):
    """Fallback for arbitrary ``Basic`` operands: no quotient rule, return None."""
    return None


@_set_div.register(Expr, Expr)
def _(x, y):
    """Divide two plain expressions with the ordinary ``/`` operator."""
    return x/y


@_set_div.register(Set, Set)
def _(x, y):
    """Fallback for two generic ``Set`` operands: no quotient rule, return None."""
    return None


@_set_div.register(Interval, Interval)
def _(x, y):
    """
    Divisions in interval arithmetic
    https://en.wikipedia.org/wiki/Interval_arithmetic

    ``x / y`` is computed as ``x`` times the reciprocal interval of ``y``.
    If the endpoints of ``y`` have opposite signs, the whole real line
    ``Interval(-oo, oo)`` is returned.
    """
    # Endpoints of opposite sign: the divisor straddles zero.
    if (y.start*y.end).is_negative:
        return Interval(-oo, oo)

    # Taking reciprocals reverses the order of the endpoints, so the upper
    # bound of 1/y comes from y.start and the lower bound from y.end.  A
    # zero endpoint is mapped to the matching infinity instead of dividing.
    if y.start == 0:
        s2 = oo
    else:
        s2 = 1/y.start
    if y.end == 0:
        s1 = -oo
    else:
        s1 = 1/y.end

    # The open/closed flags swap sides together with the endpoints.
    return set_mul(x, Interval(s1, s2, y.right_open, y.left_open))