from __future__ import annotations from warnings import warn import inspect from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning from .utils import expand_tuples import itertools as itl class MDNotImplementedError(NotImplementedError): """ A NotImplementedError for multiple dispatch """ ### Functions for on_ambiguity def ambiguity_warn(dispatcher, ambiguities): """ Raise warning when ambiguity is detected Parameters ---------- dispatcher : Dispatcher The dispatcher on which the ambiguity was detected ambiguities : set Set of type signature pairs that are ambiguous within this dispatcher See Also: Dispatcher.add warning_text """ warn(warning_text(dispatcher.name, ambiguities), AmbiguityWarning) class RaiseNotImplementedError: """Raise ``NotImplementedError`` when called.""" def __init__(self, dispatcher): self.dispatcher = dispatcher def __call__(self, *args, **kwargs): types = tuple(type(a) for a in args) raise NotImplementedError( "Ambiguous signature for %s: <%s>" % ( self.dispatcher.name, str_signature(types) )) def ambiguity_register_error_ignore_dup(dispatcher, ambiguities): """ If super signature for ambiguous types is duplicate types, ignore it. Else, register instance of ``RaiseNotImplementedError`` for ambiguous types. Parameters ---------- dispatcher : Dispatcher The dispatcher on which the ambiguity was detected ambiguities : set Set of type signature pairs that are ambiguous within this dispatcher See Also: Dispatcher.add ambiguity_warn """ for amb in ambiguities: signature = tuple(super_signature(amb)) if len(set(signature)) == 1: continue dispatcher.add( signature, RaiseNotImplementedError(dispatcher), on_ambiguity=ambiguity_register_error_ignore_dup ) ### _unresolved_dispatchers: set[Dispatcher] = set() _resolve = [True] def halt_ordering(): _resolve[0] = False def restart_ordering(on_ambiguity=ambiguity_warn): _resolve[0] = True while _unresolved_dispatchers: dispatcher = _unresolved_dispatchers.pop() dispatcher.reorder(on_ambiguity=on_ambiguity) class Dispatcher: """ Dispatch methods based on type signature Use ``dispatch`` to add implementations Examples -------- >>> from sympy.multipledispatch import dispatch >>> @dispatch(int) ... def f(x): ... return x + 1 >>> @dispatch(float) ... def f(x): # noqa: F811 ... return x - 1 >>> f(3) 4 >>> f(3.0) 2.0 """ __slots__ = '__name__', 'name', 'funcs', 'ordering', '_cache', 'doc' def __init__(self, name, doc=None): self.name = self.__name__ = name self.funcs = {} self._cache = {} self.ordering = [] self.doc = doc def register(self, *types, **kwargs): """ Register dispatcher with new implementation >>> from sympy.multipledispatch.dispatcher import Dispatcher >>> f = Dispatcher('f') >>> @f.register(int) ... def inc(x): ... return x + 1 >>> @f.register(float) ... def dec(x): ... return x - 1 >>> @f.register(list) ... @f.register(tuple) ... def reverse(x): ... return x[::-1] >>> f(1) 2 >>> f(1.0) 0.0 >>> f([1, 2, 3]) [3, 2, 1] """ def _(func): self.add(types, func, **kwargs) return func return _ @classmethod def get_func_params(cls, func): if hasattr(inspect, "signature"): sig = inspect.signature(func) return sig.parameters.values() @classmethod def get_func_annotations(cls, func): """ Get annotations of function positional parameters """ params = cls.get_func_params(func) if params: Parameter = inspect.Parameter params = (param for param in params if param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)) annotations = tuple( param.annotation for param in params) if not any(ann is Parameter.empty for ann in annotations): return annotations def add(self, signature, func, on_ambiguity=ambiguity_warn): """ Add new types/method pair to dispatcher >>> from sympy.multipledispatch import Dispatcher >>> D = Dispatcher('add') >>> D.add((int, int), lambda x, y: x + y) >>> D.add((float, float), lambda x, y: x + y) >>> D(1, 2) 3 >>> D(1, 2.0) Traceback (most recent call last): ... NotImplementedError: Could not find signature for add: When ``add`` detects a warning it calls the ``on_ambiguity`` callback with a dispatcher/itself, and a set of ambiguous type signature pairs as inputs. See ``ambiguity_warn`` for an example. """ # Handle annotations if not signature: annotations = self.get_func_annotations(func) if annotations: signature = annotations # Handle union types if any(isinstance(typ, tuple) for typ in signature): for typs in expand_tuples(signature): self.add(typs, func, on_ambiguity) return for typ in signature: if not isinstance(typ, type): str_sig = ', '.join(c.__name__ if isinstance(c, type) else str(c) for c in signature) raise TypeError("Tried to dispatch on non-type: %s\n" "In signature: <%s>\n" "In function: %s" % (typ, str_sig, self.name)) self.funcs[signature] = func self.reorder(on_ambiguity=on_ambiguity) self._cache.clear() def reorder(self, on_ambiguity=ambiguity_warn): if _resolve[0]: self.ordering = ordering(self.funcs) amb = ambiguities(self.funcs) if amb: on_ambiguity(self, amb) else: _unresolved_dispatchers.add(self) def __call__(self, *args, **kwargs): types = tuple([type(arg) for arg in args]) try: func = self._cache[types] except KeyError: func = self.dispatch(*types) if not func: raise NotImplementedError( 'Could not find signature for %s: <%s>' % (self.name, str_signature(types))) self._cache[types] = func try: return func(*args, **kwargs) except MDNotImplementedError: funcs = self.dispatch_iter(*types) next(funcs) # burn first for func in funcs: try: return func(*args, **kwargs) except MDNotImplementedError: pass raise NotImplementedError("Matching functions for " "%s: <%s> found, but none completed successfully" % (self.name, str_signature(types))) def __str__(self): return "" % self.name __repr__ = __str__ def dispatch(self, *types): """ Deterimine appropriate implementation for this type signature This method is internal. Users should call this object as a function. Implementation resolution occurs within the ``__call__`` method. >>> from sympy.multipledispatch import dispatch >>> @dispatch(int) ... def inc(x): ... return x + 1 >>> implementation = inc.dispatch(int) >>> implementation(3) 4 >>> print(inc.dispatch(float)) None See Also: ``sympy.multipledispatch.conflict`` - module to determine resolution order """ if types in self.funcs: return self.funcs[types] try: return next(self.dispatch_iter(*types)) except StopIteration: return None def dispatch_iter(self, *types): n = len(types) for signature in self.ordering: if len(signature) == n and all(map(issubclass, types, signature)): result = self.funcs[signature] yield result def resolve(self, types): """ Deterimine appropriate implementation for this type signature .. deprecated:: 0.4.4 Use ``dispatch(*types)`` instead """ warn("resolve() is deprecated, use dispatch(*types)", DeprecationWarning) return self.dispatch(*types) def __getstate__(self): return {'name': self.name, 'funcs': self.funcs} def __setstate__(self, d): self.name = d['name'] self.funcs = d['funcs'] self.ordering = ordering(self.funcs) self._cache = {} @property def __doc__(self): docs = ["Multiply dispatched method: %s" % self.name] if self.doc: docs.append(self.doc) other = [] for sig in self.ordering[::-1]: func = self.funcs[sig] if func.__doc__: s = 'Inputs: <%s>\n' % str_signature(sig) s += '-' * len(s) + '\n' s += func.__doc__.strip() docs.append(s) else: other.append(str_signature(sig)) if other: docs.append('Other signatures:\n ' + '\n '.join(other)) return '\n\n'.join(docs) def _help(self, *args): return self.dispatch(*map(type, args)).__doc__ def help(self, *args, **kwargs): """ Print docstring for the function corresponding to inputs """ print(self._help(*args)) def _source(self, *args): func = self.dispatch(*map(type, args)) if not func: raise TypeError("No function found") return source(func) def source(self, *args, **kwargs): """ Print source code for the function corresponding to inputs """ print(self._source(*args)) def source(func): s = 'File: %s\n\n' % inspect.getsourcefile(func) s = s + inspect.getsource(func) return s class MethodDispatcher(Dispatcher): """ Dispatch methods based on type signature See Also: Dispatcher """ @classmethod def get_func_params(cls, func): if hasattr(inspect, "signature"): sig = inspect.signature(func) return itl.islice(sig.parameters.values(), 1, None) def __get__(self, instance, owner): self.obj = instance self.cls = owner return self def __call__(self, *args, **kwargs): types = tuple([type(arg) for arg in args]) func = self.dispatch(*types) if not func: raise NotImplementedError('Could not find signature for %s: <%s>' % (self.name, str_signature(types))) return func(self.obj, *args, **kwargs) def str_signature(sig): """ String representation of type signature >>> from sympy.multipledispatch.dispatcher import str_signature >>> str_signature((int, float)) 'int, float' """ return ', '.join(cls.__name__ for cls in sig) def warning_text(name, amb): """ The text for ambiguity warnings """ text = "\nAmbiguities exist in dispatched function %s\n\n" % (name) text += "The following signatures may result in ambiguous behavior:\n" for pair in amb: text += "\t" + \ ', '.join('[' + str_signature(s) + ']' for s in pair) + "\n" text += "\n\nConsider making the following additions:\n\n" text += '\n\n'.join(['@dispatch(' + str_signature(super_signature(s)) + ')\ndef %s(...)' % name for s in amb]) return text