"""Print SymPy expression trees as graphs in the DOT language."""

from sympy.core.basic import Basic
from sympy.core.expr import Expr
from sympy.core.symbol import Symbol
from sympy.core.numbers import Integer, Rational, Float
from sympy.printing.repr import srepr

__all__ = ['dotprint']

# Style dictionaries are merged in order (see ``styleof``), so later
# entries override earlier ones for objects that match several classes.
default_styles = (
    (Basic, {'color': 'blue', 'shape': 'ellipse'}),
    (Expr, {'color': 'black'})
)

slotClasses = (Symbol, Integer, Rational, Float)


def purestr(x, with_args=False):
    """A string that follows ```obj = type(obj)(*obj.args)``` exactly.

    Parameters
    ==========

    x : object
        The object to describe.  Objects that are not ``Basic`` are
        rendered with ``str``; ``Basic`` objects without args are rendered
        with ``srepr``; everything else is rendered recursively as
        ``ClassName(arg, ...)``.

    with_args : boolean, optional
        If ``True``, there will be a second argument for the return
        value, which is a tuple containing ``purestr`` applied to each
        of the subnodes.

        If ``False``, there will not be a second argument for the
        return.

        Default is ``False``

    Examples
    ========

    >>> from sympy import Float, Symbol, MatrixSymbol
    >>> from sympy import Integer # noqa: F401
    >>> from sympy.core.symbol import Str # noqa: F401
    >>> from sympy.printing.dot import purestr

    Applying ``purestr`` for basic symbolic object:

    >>> code = purestr(Symbol('x'))
    >>> code
    "Symbol('x')"
    >>> eval(code) == Symbol('x')
    True

    For basic numeric object:

    >>> purestr(Float(2))
    "Float('2.0', precision=53)"

    For matrix symbol:

    >>> code = purestr(MatrixSymbol('x', 2, 2))
    >>> code
    "MatrixSymbol(Str('x'), Integer(2), Integer(2))"
    >>> eval(code) == MatrixSymbol('x', 2, 2)
    True

    With ``with_args=True``:

    >>> purestr(Float(2), with_args=True)
    ("Float('2.0', precision=53)", ())
    >>> purestr(MatrixSymbol('x', 2, 2),
    ...         with_args=True)  # doctest: +NORMALIZE_WHITESPACE
    ("MatrixSymbol(Str('x'), Integer(2), Integer(2))",
     ("Str('x')", 'Integer(2)', 'Integer(2)'))
    """
    sargs = ()
    if not isinstance(x, Basic):
        rv = str(x)
    elif not x.args:
        rv = srepr(x)
    else:
        args = x.args
        sargs = tuple(map(purestr, args))
        rv = "%s(%s)" % (type(x).__name__, ', '.join(sargs))
    if with_args:
        rv = rv, sargs
    return rv


def styleof(expr, styles=default_styles):
    """ Merge style dictionaries in order

    Every ``(class, style)`` pair whose class ``expr`` is an instance of
    contributes its entries; later pairs override earlier ones.  A new
    dictionary is returned on each call.

    Examples
    ========

    >>> from sympy import Symbol, Basic, Expr, S
    >>> from sympy.printing.dot import styleof
    >>> styles = [(Basic, {'color': 'blue', 'shape': 'ellipse'}),
    ...           (Expr, {'color': 'black'})]

    >>> styleof(Basic(S(1)), styles)
    {'color': 'blue', 'shape': 'ellipse'}

    >>> x = Symbol('x')
    >>> styleof(x + 1, styles)  # this is an Expr
    {'color': 'black', 'shape': 'ellipse'}
    """
    style = {}
    for typ, sty in styles:
        if isinstance(expr, typ):
            style.update(sty)
    return style


def attrprint(d, delimiter=', '):
    """ Print a dictionary of attributes

    The items are emitted in sorted order as ``"key"="value"`` pairs joined
    by ``delimiter``.

    Examples
    ========

    >>> from sympy.printing.dot import attrprint
    >>> print(attrprint({'color': 'blue', 'shape': 'ellipse'}))
    "color"="blue", "shape"="ellipse"
    """
    return delimiter.join('"%s"="%s"' % item for item in sorted(d.items()))


def dotnode(expr, styles=default_styles, labelfunc=str, pos=(), repeat=True):
    """ String defining a node

    The node identifier is ``purestr(expr)``, suffixed with the position
    tuple ``pos`` when ``repeat`` is true.  Non-atomic ``Basic`` objects are
    labeled with their class name; everything else is labeled with
    ``labelfunc(expr)``.

    Examples
    ========

    >>> from sympy.printing.dot import dotnode
    >>> from sympy.abc import x
    >>> print(dotnode(x))
    "Symbol('x')_()" ["color"="black", "label"="x", "shape"="ellipse"];
    """
    style = styleof(expr, styles)

    if isinstance(expr, Basic) and not expr.is_Atom:
        label = str(expr.__class__.__name__)
    else:
        label = labelfunc(expr)
    style['label'] = label
    expr_str = purestr(expr)
    if repeat:
        expr_str += '_%s' % str(pos)
    return '"%s" [%s];' % (expr_str, attrprint(style))


def dotedges(expr, atom=lambda x: not isinstance(x, Basic), pos=(), repeat=True):
    """ List of strings for all expr->expr.arg pairs

    See the docstring of dotprint for explanations of the options.

    Examples
    ========

    >>> from sympy.printing.dot import dotedges
    >>> from sympy.abc import x
    >>> for e in dotedges(x+2):
    ...     print(e)
    "Add(Integer(2), Symbol('x'))_()" -> "Integer(2)_(0,)";
    "Add(Integer(2), Symbol('x'))_()" -> "Symbol('x')_(1,)";
    """
    if atom(expr):
        return []
    else:
        expr_str, arg_strs = purestr(expr, with_args=True)
        if repeat:
            # Suffix every identifier with its position in the tree so that
            # equal subexpressions become distinct nodes.
            expr_str += '_%s' % str(pos)
            arg_strs = ['%s_%s' % (a, str(pos + (i,)))
                        for i, a in enumerate(arg_strs)]
        return ['"%s" -> "%s";' % (expr_str, a) for a in arg_strs]


template = \
"""digraph{

# Graph style
%(graphstyle)s

#########
# Nodes #
#########

%(nodes)s

#########
# Edges #
#########

%(edges)s
}"""

_graphstyle = {'rankdir': 'TD', 'ordering': 'out'}


def dotprint(expr,
    styles=default_styles, atom=lambda x: not isinstance(x, Basic),
    maxdepth=None, repeat=True, labelfunc=str, **kwargs):
    """DOT description of a SymPy expression tree

    Parameters
    ==========

    styles : list of lists composed of (Class, mapping), optional
        Styles for different classes.

        The default is

        .. code-block:: python

            (
                (Basic, {'color': 'blue', 'shape': 'ellipse'}),
                (Expr,  {'color': 'black'})
            )

    atom : function, optional
        Function used to determine if an arg is an atom.

        A good choice is ``lambda x: not x.args``.

        The default is ``lambda x: not isinstance(x, Basic)``.

    maxdepth : integer, optional
        The maximum depth.

        The default is ``None``, meaning no limit.  With ``0`` only the
        root node is emitted.

    repeat : boolean, optional
        Whether to use different nodes for common subexpressions.

        The default is ``True``.

        For example, for ``x + x*y`` with ``repeat=True``, it will have
        two nodes for ``x``; with ``repeat=False``, it will have one
        node.

        .. warning::
            Even if a node appears twice in the same object like ``x`` in
            ``Pow(x, x)``, it will still only appear once.
            Hence, with ``repeat=False``, the number of arrows out of an
            object might not equal the number of args it has.

    labelfunc : function, optional
        A function to create a label for a given leaf node.

        The default is ``str``.

        Another good option is ``srepr``.

        For example with ``str``, the leaf nodes of ``x + 1`` are labeled,
        ``x`` and ``1``.  With ``srepr``, they are labeled ``Symbol('x')``
        and ``Integer(1)``.

    **kwargs : optional
        Additional keyword arguments are included as styles for the graph.

    Examples
    ========

    >>> from sympy import dotprint
    >>> from sympy.abc import x
    >>> print(dotprint(x+2)) # doctest: +NORMALIZE_WHITESPACE
    digraph{
    # Graph style
    "ordering"="out"
    "rankdir"="TD"
    #########
    # Nodes #
    #########
    "Add(Integer(2), Symbol('x'))_()" ["color"="black", "label"="Add", "shape"="ellipse"];
    "Integer(2)_(0,)" ["color"="black", "label"="2", "shape"="ellipse"];
    "Symbol('x')_(1,)" ["color"="black", "label"="x", "shape"="ellipse"];
    #########
    # Edges #
    #########
    "Add(Integer(2), Symbol('x'))_()" -> "Integer(2)_(0,)";
    "Add(Integer(2), Symbol('x'))_()" -> "Symbol('x')_(1,)";
    }

    """
    # repeat works by adding a signature tuple to the end of each node for its
    # position in the graph. For example, for expr = Add(x, Pow(x, 2)), the x in the
    # Pow will have the tuple (1, 0), meaning it is expr.args[1].args[0].
    graphstyle = _graphstyle.copy()
    graphstyle.update(kwargs)

    nodes = []
    edges = []

    def traverse(e, depth, pos=()):
        nodes.append(dotnode(e, styles, labelfunc=labelfunc, pos=pos, repeat=repeat))
        # Compare against None explicitly: a plain truthiness test would
        # read maxdepth=0 as "no limit" instead of "root node only".
        if maxdepth is not None and depth >= maxdepth:
            return
        edges.extend(dotedges(e, atom=atom, pos=pos, repeat=repeat))
        # A root that is not a Basic (e.g. a plain int) has no ``args``
        # attribute; treat it as a leaf rather than raising AttributeError.
        for i, arg in enumerate(getattr(e, 'args', ())):
            if not atom(arg):
                traverse(arg, depth + 1, pos + (i,))

    traverse(expr, 0)

    return template % {'graphstyle': attrprint(graphstyle, delimiter='\n'),
                       'nodes': '\n'.join(nodes),
                       'edges': '\n'.join(edges)}