285 lines
6.1 KiB
Python
285 lines
6.1 KiB
Python
from sympy.multipledispatch.dispatcher import (Dispatcher, MDNotImplementedError,
|
|
MethodDispatcher, halt_ordering,
|
|
restart_ordering,
|
|
ambiguity_register_error_ignore_dup)
|
|
from sympy.testing.pytest import raises, warns
|
|
|
|
|
|
def identity(x):
    """Return ``x`` unchanged."""
    return x
|
|
|
|
|
|
def inc(x):
    """Return ``x + 1``."""
    return x + 1
|
|
|
|
|
|
def dec(x):
    """Return ``x - 1``."""
    return x - 1
|
|
|
|
|
|
def test_dispatcher():
    """Register int/float implementations with ``add`` and check lookup and calls."""
    f = Dispatcher('f')
    f.add((int,), inc)
    f.add((float,), dec)

    # resolve() is expected to emit a DeprecationWarning.
    # NOTE(review): the original indentation was lost; only the resolve()
    # assertion is placed under warns() -- confirm against upstream.
    with warns(DeprecationWarning, test_stacklevel=False):
        assert f.resolve((int,)) == inc
    assert f.dispatch(int) is inc

    assert f(1) == 2
    assert f(1.0) == 0.0
|
|
|
|
|
|
def test_union_types():
    """Registering a tuple of types makes one implementation serve each of them."""
    f = Dispatcher('f')
    f.register((int, float))(inc)

    assert f(1) == 2
    assert f(1.0) == 2.0
|
|
|
|
|
|
def test_dispatcher_as_decorator():
    """``Dispatcher.register`` works as a decorator, one signature per function."""
    f = Dispatcher('f')

    # Both implementations deliberately reuse the name ``inc``; the dispatcher
    # is what the assertions below call, not the local name.
    @f.register(int)
    def inc(x):  # noqa:F811
        return x + 1

    @f.register(float)  # noqa:F811
    def inc(x):  # noqa:F811
        return x - 1

    assert f(1) == 2
    assert f(1.0) == 0.0
|
|
|
|
|
|
def test_register_instance_method():
    """``MethodDispatcher`` can back ``__init__`` and dispatch on the argument type."""

    class Test:
        __init__ = MethodDispatcher('f')

        @__init__.register(list)
        def _init_list(self, data):
            self.data = data

        @__init__.register(object)
        def _init_obj(self, datum):
            # Non-list input is wrapped so both constructors yield a list.
            self.data = [datum]

    a = Test(3)
    b = Test([3])
    assert a.data == b.data
|
|
|
|
|
|
def test_on_ambiguity():
    """The ``on_ambiguity`` callback fires only once an ambiguous signature is added."""
    f = Dispatcher('f')

    def identity(x):
        return x

    # One-element list so the nested callback can mutate the flag.
    ambiguities = [False]

    def on_ambiguity(dispatcher, amb):
        ambiguities[0] = True

    f.add((object, object), identity, on_ambiguity=on_ambiguity)
    assert not ambiguities[0]
    f.add((object, float), identity, on_ambiguity=on_ambiguity)
    assert not ambiguities[0]
    # (float, object) vs. (object, float): the callback is expected to fire.
    f.add((float, object), identity, on_ambiguity=on_ambiguity)
    assert ambiguities[0]
|
|
|
|
|
|
def test_raise_error_on_non_class():
    """``add`` must raise ``TypeError`` when a signature entry is not a class."""
    f = Dispatcher('f')
    assert raises(TypeError, lambda: f.add((1,), inc))
|
|
|
|
|
|
def test_docstring():
    """The dispatcher's ``__doc__`` combines its own doc with the implementations' docs."""

    def one(x, y):
        """ Docstring number one """
        return x + y

    def two(x, y):
        """ Docstring number two """
        return x + y

    # Intentionally has no docstring.
    def three(x, y):
        return x + y

    master_doc = 'Doc of the multimethod itself'

    f = Dispatcher('f', doc=master_doc)
    f.add((object, object), one)
    f.add((int, int), two)
    f.add((float, float), three)

    assert one.__doc__.strip() in f.__doc__
    assert two.__doc__.strip() in f.__doc__
    # The docstring of ``one`` must appear before that of ``two``.
    assert (f.__doc__.find(one.__doc__.strip()) <
            f.__doc__.find(two.__doc__.strip()))
    assert 'object, object' in f.__doc__
    assert master_doc in f.__doc__
|
|
|
|
|
|
def test_help():
    """``_help`` returns the docstring of the implementation chosen for the arguments."""

    def one(x, y):
        """ Docstring number one """
        return x + y

    def two(x, y):
        """ Docstring number two """
        return x + y

    def three(x, y):
        """ Docstring number three """
        return x + y

    master_doc = 'Doc of the multimethod itself'

    f = Dispatcher('f', doc=master_doc)
    f.add((object, object), one)
    f.add((int, int), two)
    f.add((float, float), three)

    assert f._help(1, 1) == two.__doc__
    assert f._help(1.0, 2.0) == three.__doc__
|
|
|
|
|
|
def test_source():
    """``_source`` returns the source text of the implementation chosen for the arguments."""

    def one(x, y):
        """ Docstring number one """
        return x + y

    def two(x, y):
        """ Docstring number two """
        return x - y

    master_doc = 'Doc of the multimethod itself'

    f = Dispatcher('f', doc=master_doc)
    f.add((int, int), one)
    f.add((float, float), two)

    # The two bodies differ ('x + y' vs. 'x - y') so each source is identifiable.
    assert 'x + y' in f._source(1, 1)
    assert 'x - y' in f._source(1.0, 1.0)
|
|
|
|
|
|
def test_source_raises_on_missing_function():
    """``source`` must raise ``TypeError`` when nothing is registered for the arguments."""
    f = Dispatcher('f')

    assert raises(TypeError, lambda: f.source(1))
|
|
|
|
|
|
def test_halt_method_resolution():
    """Ambiguity reporting is deferred after ``halt_ordering`` until ``restart_ordering``."""
    # One-element list so the nested callback can mutate the counter.
    g = [0]

    def on_ambiguity(a, b):
        g[0] += 1

    f = Dispatcher('f')

    halt_ordering()

    def func(*args):
        pass

    f.add((int, object), func)
    f.add((object, int), func)

    # No callback has run while ordering is halted.
    assert g == [0]

    restart_ordering(on_ambiguity=on_ambiguity)

    # The callback is expected to have run exactly once after restarting.
    assert g == [1]

    assert set(f.ordering) == {(int, object), (object, int)}
|
|
|
|
|
|
def test_no_implementations():
    """Calling a dispatcher with no registered implementations raises ``NotImplementedError``."""
    f = Dispatcher('f')
    assert raises(NotImplementedError, lambda: f('hello'))
|
|
|
|
|
|
def test_register_stacking():
    """Stacked ``register`` decorators attach one function to several signatures."""
    f = Dispatcher('f')

    @f.register(list)
    @f.register(tuple)
    def rev(x):
        return x[::-1]

    assert f((1, 2, 3)) == (3, 2, 1)
    assert f([1, 2, 3]) == [3, 2, 1]

    # str is not registered on the dispatcher ...
    assert raises(NotImplementedError, lambda: f('hello'))
    # ... but the decorated name is still directly callable as a plain function.
    assert rev('hello') == 'olleh'
|
|
|
|
|
|
def test_dispatch_method():
    """``dispatch`` returns the registered function for given types, including subclasses."""
    f = Dispatcher('f')

    @f.register(list)
    def rev(x):
        return x[::-1]

    @f.register(int, int)
    def add(x, y):
        return x + y

    class MyList(list):
        pass

    assert f.dispatch(list) is rev
    # A subclass of list must resolve to the list implementation.
    assert f.dispatch(MyList) is rev
    assert f.dispatch(int, int) is add
|
|
|
|
|
|
def test_not_implemented():
    """An implementation raising ``MDNotImplementedError`` defers to a more general one."""
    f = Dispatcher('f')

    @f.register(object)
    def _(x):
        return 'default'

    @f.register(int)
    def _(x):
        if x % 2 == 0:
            return 'even'
        else:
            raise MDNotImplementedError()

    assert f('hello') == 'default'  # default behavior
    assert f(2) == 'even'           # specialized behavior
    assert f(3) == 'default'        # fall back to default behavior
    # No two-argument signature is registered at all.
    assert raises(NotImplementedError, lambda: f(1, 2))
|
|
|
|
|
|
def test_not_implemented_error():
    """``MDNotImplementedError`` with no fallback surfaces as ``NotImplementedError``."""
    f = Dispatcher('f')

    @f.register(float)
    def _(a):
        raise MDNotImplementedError()

    assert raises(NotImplementedError, lambda: f(1.0))
|
|
|
|
def test_ambiguity_register_error_ignore_dup():
    """With ``ambiguity_register_error_ignore_dup``, an ambiguous call raises ``NotImplementedError``."""
    f = Dispatcher('f')

    class A:
        pass

    class B(A):
        pass

    class C(A):
        pass

    # Pass the handler as on_ambiguity (third positional argument) while
    # registering these mutually ambiguous signatures.
    f.add((A, B), lambda x, y: None, ambiguity_register_error_ignore_dup)
    f.add((B, A), lambda x, y: None, ambiguity_register_error_ignore_dup)
    f.add((A, C), lambda x, y: None, ambiguity_register_error_ignore_dup)
    f.add((C, A), lambda x, y: None, ambiguity_register_error_ignore_dup)

    # (B, C) matches both (B, A) and (A, C); the call is expected to raise.
    assert raises(NotImplementedError, lambda: f(B(), C()))
|