876 lines
34 KiB
Python
876 lines
34 KiB
Python
|
from __future__ import annotations
|
||
|
from typing import Any
|
||
|
|
||
|
from functools import wraps
|
||
|
|
||
|
from sympy.core import Add, Mul, Pow, S, sympify, Float
|
||
|
from sympy.core.basic import Basic
|
||
|
from sympy.core.expr import UnevaluatedExpr
|
||
|
from sympy.core.function import Lambda
|
||
|
from sympy.core.mul import _keep_coeff
|
||
|
from sympy.core.sorting import default_sort_key
|
||
|
from sympy.core.symbol import Symbol
|
||
|
from sympy.functions.elementary.complexes import re
|
||
|
from sympy.printing.str import StrPrinter
|
||
|
from sympy.printing.precedence import precedence, PRECEDENCE
|
||
|
|
||
|
|
||
|
class requires:
    """Decorator recording the requirements of a print method.

    Every keyword argument names an attribute of the printer instance; that
    attribute must support ``update``. Each time the decorated method runs,
    the attribute is first ``update``-d with the value given here, and only
    then is the original method invoked.
    """

    def __init__(self, **kwargs):
        # Mapping of printer attribute name -> iterable of requirements.
        self._req = kwargs

    def __call__(self, method):
        @wraps(method)
        def _method_wrapper(self_, *args, **kwargs):
            # Register every requirement before delegating to the method.
            for attr_name, needed in self._req.items():
                target = getattr(self_, attr_name)
                target.update(needed)
            return method(self_, *args, **kwargs)

        return _method_wrapper
|
||
|
|
||
|
|
||
|
class AssignmentError(Exception):
    """Signals that loop code was requested but no assignment target exists.

    Raised if an assignment variable for a loop is missing.
    """
|
||
|
|
||
|
|
||
|
def _convert_python_lists(arg):
    """Recursively turn Python lists into ``sympy.codegen`` ``List`` nodes.

    * ``list``  -> ``List`` built from the converted elements.
    * ``tuple`` -> plain ``tuple`` of the converted elements.
    * anything else is returned unchanged.
    """
    if isinstance(arg, tuple):
        return tuple(map(_convert_python_lists, arg))
    if isinstance(arg, list):
        # Imported lazily so this module does not pull in sympy.codegen eagerly.
        from sympy.codegen.abstract_nodes import List
        return List(*map(_convert_python_lists, arg))
    return arg
|
||
|
|
||
|
|
||
|
class CodePrinter(StrPrinter):
    """
    The base class for code-printing subclasses.

    Subclasses must implement the hooks that raise ``NotImplementedError``
    here: ``_rate_index_position``, ``_get_statement``, ``_get_comment``,
    ``_declare_number_const``, ``_format_code`` and
    ``_get_loop_opening_ending``.

    This class also reads ``self.language``, ``self.known_functions`` and
    ``self._traverse_matrix_indices`` without defining them itself.
    Presumably a concrete subclass (or a parent printer) provides them.
    """

    # Spelling of the logical operators in the target language.
    _operators = {
        'and': '&&',
        'or': '||',
        'not': '!',
    }

    _default_settings: dict[str, Any] = {
        'order': None,
        'full_prec': 'auto',
        'error_on_reserved': False,
        'reserved_word_suffix': '_',
        'human': True,
        'inline': False,
        'allow_unknown_functions': False,
    }

    # Functions which are "simple" to rewrite to other functions that
    # may be supported
    # function_to_rewrite : (function_to_rewrite_to, iterable_with_other_functions_required)
    _rewriteable_functions = {
        'cot': ('tan', []),
        'csc': ('sin', []),
        'sec': ('cos', []),
        'acot': ('atan', []),
        'acsc': ('asin', []),
        'asec': ('acos', []),
        'coth': ('exp', []),
        'csch': ('exp', []),
        'sech': ('exp', []),
        'acoth': ('log', []),
        'acsch': ('log', []),
        'asech': ('log', []),
        'catalan': ('gamma', []),
        'fibonacci': ('sqrt', []),
        'lucas': ('sqrt', []),
        'beta': ('gamma', []),
        'sinc': ('sin', ['Piecewise']),
        'Mod': ('floor', []),
        'factorial': ('gamma', []),
        'factorial2': ('gamma', ['Piecewise']),
        'subfactorial': ('uppergamma', []),
        'RisingFactorial': ('gamma', ['Piecewise']),
        'FallingFactorial': ('gamma', ['Piecewise']),
        'binomial': ('gamma', []),
        'frac': ('floor', []),
        'Max': ('Piecewise', []),
        'Min': ('Piecewise', []),
        'Heaviside': ('Piecewise', []),
        'erf2': ('erf', []),
        'erfc': ('erf', []),
        'Li': ('li', []),
        'Ei': ('li', []),
        'dirichlet_eta': ('zeta', []),
        'riemann_xi': ('zeta', ['gamma']),
    }

    def __init__(self, settings=None):
        super().__init__(settings=settings)
        # Subclasses may define ``reserved_words`` as a class attribute;
        # otherwise start with no reserved words.
        if not hasattr(self, 'reserved_words'):
            self.reserved_words = set()

    def _handle_UnevaluatedExpr(self, expr):
        """Strip ``re(...)`` wrappers around real-valued ``UnevaluatedExpr``.

        ``re(arg)`` is replaced by ``arg`` when ``arg`` is an
        ``UnevaluatedExpr`` whose wrapped expression is real. Every other
        ``re`` call is rebuilt unchanged.
        """
        return expr.replace(re, lambda arg: arg if isinstance(
            arg, UnevaluatedExpr) and arg.args[0].is_real else re(arg))

    def doprint(self, expr, assign_to=None):
        """
        Print the expression as code.

        Parameters
        ----------
        expr : Expression
            The expression to be printed.

        assign_to : Symbol, string, MatrixSymbol, list of strings or Symbols (optional)
            If provided, the printed code will set the expression to a variable or multiple variables
            with the name or names given in ``assign_to``.

        Returns
        -------
        If the ``human`` setting is true: a single string. It contains
        comments listing unsupported constructs and declarations of the
        number constants used, followed by the formatted code.
        Otherwise: a tuple ``(number_symbols, not_supported, code)``, where
        ``number_symbols`` is a set of ``(symbol, printed_value)`` pairs.

        Raises
        ------
        ValueError
            If a sequence ``assign_to`` has a different length from ``expr``.
        TypeError
            If ``assign_to`` is not a string, sequence or ``Basic`` instance.
        """
        from sympy.matrices.expressions.matexpr import MatrixSymbol
        from sympy.codegen.ast import CodeBlock, Assignment

        def _handle_assign_to(expr, assign_to):
            # Wrap ``expr`` in the Assignment node(s) described by ``assign_to``.
            if assign_to is None:
                return sympify(expr)
            if isinstance(assign_to, (list, tuple)):
                if len(expr) != len(assign_to):
                    raise ValueError('Failed to assign an expression of length {} to {} variables'.format(len(expr), len(assign_to)))
                return CodeBlock(*[_handle_assign_to(sub_expr, target)
                                   for sub_expr, target in zip(expr, assign_to)])
            if isinstance(assign_to, str):
                if expr.is_Matrix:
                    assign_to = MatrixSymbol(assign_to, *expr.shape)
                else:
                    assign_to = Symbol(assign_to)
            elif not isinstance(assign_to, Basic):
                raise TypeError("{} cannot assign to object of type {}".format(
                        type(self).__name__, type(assign_to)))
            return Assignment(assign_to, expr)

        expr = _convert_python_lists(expr)
        expr = _handle_assign_to(expr, assign_to)

        # Remove re(...) nodes due to UnevaluatedExpr.is_real always is None:
        expr = self._handle_UnevaluatedExpr(expr)

        # keep a set of expressions that are not strictly translatable to Code
        # and number constants that must be declared and initialized
        self._not_supported = set()
        self._number_symbols = set()

        lines = self._print(expr).splitlines()

        # format the output
        if self._settings["human"]:
            frontlines = []
            if self._not_supported:
                frontlines.append(self._get_comment(
                        "Not supported in {}:".format(self.language)))
                for unsupported in sorted(self._not_supported, key=str):
                    frontlines.append(self._get_comment(type(unsupported).__name__))
            for name, value in sorted(self._number_symbols, key=str):
                frontlines.append(self._declare_number_const(name, value))
            lines = frontlines + lines
            lines = self._format_code(lines)
            result = "\n".join(lines)
        else:
            lines = self._format_code(lines)
            num_syms = {(k, self._print(v)) for k, v in self._number_symbols}
            result = (num_syms, self._not_supported, "\n".join(lines))
        # Reset per-call state; ``result`` keeps its own references.
        self._not_supported = set()
        self._number_symbols = set()
        return result

    def _doprint_loops(self, expr, assign_to=None):
        """Print an expression containing ``Indexed`` objects as loops.

        ``Indexed`` objects correspond to arrays in the generated code. The
        low-level implementation loops over array elements. Terms with
        summations are accumulated into ``assign_to``.

        Raises
        ------
        AssignmentError
            If summation terms exist but ``assign_to`` is None.
        ValueError
            If a summation term contains ``assign_to`` itself.
        NotImplementedError
            If a factor of a term has its own internal contractions.
        """
        if self._settings.get('contract', True):
            from sympy.tensor import get_contraction_structure
            # Setup loops over non-dummy indices -- all terms need these
            indices = self._get_expression_indices(expr, assign_to)
            # Setup loops over dummy indices -- each term needs separate treatment
            dummies = get_contraction_structure(expr)
        else:
            indices = []
            dummies = {None: (expr,)}
        openloop, closeloop = self._get_loop_opening_ending(indices)

        # terms with no summations first
        if None in dummies:
            text = StrPrinter.doprint(self, Add(*dummies[None]))
        else:
            # If all terms have summations we must initialize array to Zero
            text = StrPrinter.doprint(self, 0)

        # skip redundant assignments (where lhs == rhs)
        lhs_printed = self._print(assign_to)
        lines = []
        if text != lhs_printed:
            lines.extend(openloop)
            if assign_to is not None:
                text = self._get_statement("%s = %s" % (lhs_printed, text))
            lines.append(text)
            lines.extend(closeloop)

        # then terms with summations
        for d in dummies:
            if isinstance(d, tuple):
                indices = self._sort_optimized(d, expr)
                openloop_d, closeloop_d = self._get_loop_opening_ending(
                    indices)

                for term in dummies[d]:
                    if term in dummies and not ([list(f.keys()) for f in dummies[term]]
                            == [[None] for f in dummies[term]]):
                        # If one factor in the term has it's own internal
                        # contractions, those must be computed first.
                        # (temporary variables?)
                        raise NotImplementedError(
                            "FIXME: no support for contractions in factor yet")
                    else:

                        # We need the lhs expression as an accumulator for
                        # the loops, i.e
                        #
                        # for (int d=0; d < dim; d++){
                        #    lhs[] = lhs[] + term[][d]
                        # }           ^.................. the accumulator
                        #
                        # We check if the expression already contains the
                        # lhs, and raise an exception if it does, as that
                        # syntax is currently undefined.  FIXME: What would be
                        # a good interpretation?
                        if assign_to is None:
                            raise AssignmentError(
                                "need assignment variable for loops")
                        if term.has(assign_to):
                            # Implicit concatenation keeps source indentation
                            # out of the message.
                            raise ValueError("FIXME: lhs present in rhs, "
                                             "this is undefined in CodePrinter")

                        lines.extend(openloop)
                        lines.extend(openloop_d)
                        text = "%s = %s" % (lhs_printed, StrPrinter.doprint(
                            self, assign_to + term))
                        lines.append(self._get_statement(text))
                        lines.extend(closeloop_d)
                        lines.extend(closeloop)

        return "\n".join(lines)

    def _get_expression_indices(self, expr, assign_to):
        """Return the free (non-dummy) loop indices, sorted for loop order.

        Raises ValueError when the lhs indices differ from the rhs ones. A
        scalar rhs (no indices) is broadcast over the lhs indices.
        """
        from sympy.tensor import get_indices
        rinds, junk = get_indices(expr)
        linds, junk = get_indices(assign_to)

        # support broadcast of scalar
        if linds and not rinds:
            rinds = linds
        if rinds != linds:
            raise ValueError("lhs indices must match non-dummy"
                    " rhs indices in %s" % expr)

        return self._sort_optimized(rinds, assign_to)

    def _sort_optimized(self, indices, expr):
        """Sort ``indices`` by ascending score, computed from their positions.

        Scores come from ``self._rate_index_position`` over every ``Indexed``
        object in ``expr``. The index with the highest score is put in the
        innermost loop.
        """
        from sympy.tensor.indexed import Indexed

        if not indices:
            return []

        # determine optimized loop order by giving a score to each index
        # the index with the highest score are put in the innermost loop.
        score_table = dict.fromkeys(indices, 0)

        arrays = expr.atoms(Indexed)
        for arr in arrays:
            for p, ind in enumerate(arr.indices):
                # Only indices we are ordering are scored; others are ignored.
                if ind in score_table:
                    score_table[ind] += self._rate_index_position(p)

        return sorted(indices, key=lambda x: score_table[x])

    def _rate_index_position(self, p):
        """function to calculate score based on position among indices

        This method is used to sort loops in an optimized order, see
        CodePrinter._sort_optimized()
        """
        raise NotImplementedError("This function must be implemented by "
                                  "subclass of CodePrinter.")

    def _get_statement(self, codestring):
        """Formats a codestring with the proper line ending."""
        raise NotImplementedError("This function must be implemented by "
                                  "subclass of CodePrinter.")

    def _get_comment(self, text):
        """Formats a text string as a comment."""
        raise NotImplementedError("This function must be implemented by "
                                  "subclass of CodePrinter.")

    def _declare_number_const(self, name, value):
        """Declare a numeric constant at the top of a function"""
        raise NotImplementedError("This function must be implemented by "
                                  "subclass of CodePrinter.")

    def _format_code(self, lines):
        """Take in a list of lines of code, and format them accordingly.

        This may include indenting, wrapping long lines, etc..."""
        raise NotImplementedError("This function must be implemented by "
                                  "subclass of CodePrinter.")

    def _get_loop_opening_ending(self, indices):
        """Returns a tuple (open_lines, close_lines) containing lists
        of codelines"""
        raise NotImplementedError("This function must be implemented by "
                                  "subclass of CodePrinter.")

    def _print_Dummy(self, expr):
        # Dummies without an explicit name get a leading underscore; named
        # dummies are disambiguated by their dummy_index.
        if expr.name.startswith('Dummy_'):
            return '_' + expr.name
        else:
            return '%s_%d' % (expr.name, expr.dummy_index)

    def _print_CodeBlock(self, expr):
        return '\n'.join([self._print(i) for i in expr.args])

    def _print_String(self, string):
        return str(string)

    def _print_QuotedString(self, arg):
        return '"%s"' % arg.text

    def _print_Comment(self, string):
        return self._get_comment(str(string))

    def _print_Assignment(self, expr):
        """Print an assignment, expanding Piecewise, matrix and loop forms."""
        from sympy.codegen.ast import Assignment
        from sympy.functions.elementary.piecewise import Piecewise
        from sympy.matrices.expressions.matexpr import MatrixSymbol
        from sympy.tensor.indexed import IndexedBase
        lhs = expr.lhs
        rhs = expr.rhs
        # We special case assignments that take multiple lines
        if isinstance(rhs, Piecewise):
            # Here we modify Piecewise so each expression is now
            # an Assignment, and then continue on the print.
            expressions = []
            conditions = []
            for (e, c) in rhs.args:
                expressions.append(Assignment(lhs, e))
                conditions.append(c)
            temp = Piecewise(*zip(expressions, conditions))
            return self._print(temp)
        elif isinstance(lhs, MatrixSymbol):
            # Here we form an Assignment for each element in the array,
            # printing each one.
            lines = []
            for (i, j) in self._traverse_matrix_indices(lhs):
                temp = Assignment(lhs[i, j], rhs[i, j])
                code0 = self._print(temp)
                lines.append(code0)
            return "\n".join(lines)
        elif self._settings.get("contract", False) and (lhs.has(IndexedBase) or
                rhs.has(IndexedBase)):
            # Here we check if there is looping to be done, and if so
            # print the required loops.
            return self._doprint_loops(rhs, lhs)
        else:
            lhs_code = self._print(lhs)
            rhs_code = self._print(rhs)
            return self._get_statement("%s = %s" % (lhs_code, rhs_code))

    def _print_AugmentedAssignment(self, expr):
        # lhs/rhs are printed exactly once; re-printing the resulting strings
        # (as before) was redundant.
        lhs_code = self._print(expr.lhs)
        rhs_code = self._print(expr.rhs)
        return self._get_statement("{} {} {}".format(
            lhs_code, self._print(expr.op), rhs_code))

    def _print_FunctionCall(self, expr):
        return '%s(%s)' % (
            expr.name,
            ', '.join(self._print(arg) for arg in expr.function_args))

    def _print_Variable(self, expr):
        return self._print(expr.symbol)

    def _print_Symbol(self, expr):
        """Print a symbol, escaping names that are reserved words.

        Reserved names either raise ValueError (``error_on_reserved``) or
        get ``reserved_word_suffix`` appended.
        """
        name = super()._print_Symbol(expr)

        if name in self.reserved_words:
            if self._settings['error_on_reserved']:
                msg = ('This expression includes the symbol "{}" which is a '
                       'reserved keyword in this language.')
                raise ValueError(msg.format(name))
            return name + self._settings['reserved_word_suffix']
        else:
            return name

    def _can_print(self, name):
        """ Check if function ``name`` is either a known function or has its own
            printing method. Used to check if rewriting is possible."""
        return name in self.known_functions or getattr(self, '_print_{}'.format(name), False)

    def _print_Function(self, expr):
        """Print a function call using ``known_functions`` or fallbacks.

        Resolution order:
        1. ``known_functions`` entry: either a string (the target name) or a
           list of ``(condition, function)`` pairs, where the first pair
           whose condition holds for the arguments is used.
        2. An inlined implementation (``_imp_`` that is a ``Lambda``).
        3. A rewrite listed in ``_rewriteable_functions``, if every function
           it needs can be printed.
        4. A generic ``name(args)`` call when ``allow_unknown_functions``
           is set; otherwise the expression is reported as not supported.
        """
        name = expr.func.__name__
        if name in self.known_functions:
            cond_func = self.known_functions[name]
            if isinstance(cond_func, str):
                return "%s(%s)" % (cond_func, self.stringify(expr.args, ", "))
            else:
                for cond, func in cond_func:
                    if cond(*expr.args):
                        break
                else:
                    # No condition matched (or the list is empty): there is
                    # no usable translation, so fall through below instead
                    # of silently using the last listed function.
                    func = None
                if func is not None:
                    try:
                        return func(*[self.parenthesize(item, 0) for item in expr.args])
                    except TypeError:
                        # ``func`` is a plain name rather than a callable.
                        return "%s(%s)" % (func, self.stringify(expr.args, ", "))
        elif hasattr(expr, '_imp_') and isinstance(expr._imp_, Lambda):
            # inlined function
            return self._print(expr._imp_(*expr.args))
        elif name in self._rewriteable_functions:
            # Simple rewrite to supported function possible
            target_f, required_fs = self._rewriteable_functions[name]
            if self._can_print(target_f) and all(self._can_print(f) for f in required_fs):
                return self._print(expr.rewrite(target_f))
        if expr.is_Function and self._settings.get('allow_unknown_functions', False):
            return '%s(%s)' % (self._print(expr.func), ', '.join(map(self._print, expr.args)))
        else:
            return self._print_not_supported(expr)

    _print_Expr = _print_Function

    # Don't inherit the str-printer method for Heaviside to the code printers
    _print_Heaviside = None

    def _print_NumberSymbol(self, expr):
        """Print a number constant either inline or as a named constant.

        With ``inline`` set, the value is printed as a Float evaluated to
        ``precision`` digits. Otherwise the (symbol, value) pair is recorded
        so ``doprint`` can declare it, and the symbol's name is printed.
        """
        if self._settings.get("inline", False):
            return self._print(Float(expr.evalf(self._settings["precision"])))
        else:
            # A Number symbol that is not implemented here or with _printmethod
            # is registered and evaluated
            self._number_symbols.add((expr,
                Float(expr.evalf(self._settings["precision"]))))
            return str(expr)

    def _print_Catalan(self, expr):
        return self._print_NumberSymbol(expr)

    def _print_EulerGamma(self, expr):
        return self._print_NumberSymbol(expr)

    def _print_GoldenRatio(self, expr):
        return self._print_NumberSymbol(expr)

    def _print_TribonacciConstant(self, expr):
        return self._print_NumberSymbol(expr)

    def _print_Exp1(self, expr):
        return self._print_NumberSymbol(expr)

    def _print_Pi(self, expr):
        return self._print_NumberSymbol(expr)

    def _print_And(self, expr):
        PREC = precedence(expr)
        return (" %s " % self._operators['and']).join(self.parenthesize(a, PREC)
                for a in sorted(expr.args, key=default_sort_key))

    def _print_Or(self, expr):
        PREC = precedence(expr)
        return (" %s " % self._operators['or']).join(self.parenthesize(a, PREC)
                for a in sorted(expr.args, key=default_sort_key))

    def _print_Xor(self, expr):
        # Without a native operator, print the negation-normal-form instead.
        if self._operators.get('xor') is None:
            return self._print(expr.to_nnf())
        PREC = precedence(expr)
        return (" %s " % self._operators['xor']).join(self.parenthesize(a, PREC)
                for a in expr.args)

    def _print_Equivalent(self, expr):
        # Without a native operator, print the negation-normal-form instead.
        if self._operators.get('equivalent') is None:
            return self._print(expr.to_nnf())
        PREC = precedence(expr)
        return (" %s " % self._operators['equivalent']).join(self.parenthesize(a, PREC)
                for a in expr.args)

    def _print_Not(self, expr):
        PREC = precedence(expr)
        return self._operators['not'] + self.parenthesize(expr.args[0], PREC)

    def _print_BooleanFunction(self, expr):
        return self._print(expr.to_nnf())

    def _print_Mul(self, expr):
        """Print a product as ``sign + numerator [/ denominator]``."""

        prec = precedence(expr)

        c, e = expr.as_coeff_Mul()
        if c < 0:
            expr = _keep_coeff(-c, e)
            sign = "-"
        else:
            sign = ""

        a = []  # items in the numerator
        b = []  # items that are in the denominator (if any)

        pow_paren = []  # Will collect all pow with more than one base element and exp = -1

        if self.order not in ('old', 'none'):
            args = expr.as_ordered_factors()
        else:
            # use make_args in case expr was something like -x -> x
            args = Mul.make_args(expr)

        # Gather args for numerator/denominator
        for item in args:
            if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative:
                if item.exp != -1:
                    b.append(Pow(item.base, -item.exp, evaluate=False))
                else:
                    if len(item.args[0].args) != 1 and isinstance(item.base, Mul):   # To avoid situations like #14160
                        pow_paren.append(item)
                    b.append(Pow(item.base, -item.exp))
            else:
                a.append(item)

        a = a or [S.One]

        if len(a) == 1 and sign == "-":
            # Unary minus does not have a SymPy class, and hence there's no
            # precedence weight associated with it, Python's unary minus has
            # an operator precedence between multiplication and exponentiation,
            # so we use this to compute a weight.
            a_str = [self.parenthesize(a[0], 0.5*(PRECEDENCE["Pow"]+PRECEDENCE["Mul"]))]
        else:
            a_str = [self.parenthesize(x, prec) for x in a]
        b_str = [self.parenthesize(x, prec) for x in b]

        # To parenthesize Pow with exp = -1 and having more than one Symbol
        for item in pow_paren:
            if item.base in b:
                b_str[b.index(item.base)] = "(%s)" % b_str[b.index(item.base)]

        if not b:
            return sign + '*'.join(a_str)
        elif len(b) == 1:
            return sign + '*'.join(a_str) + "/" + b_str[0]
        else:
            return sign + '*'.join(a_str) + "/(%s)" % '*'.join(b_str)

    def _print_not_supported(self, expr):
        """Record ``expr`` as unsupported and print it with ``emptyPrinter``."""
        try:
            self._not_supported.add(expr)
        except TypeError:
            # not hashable: deliberately best-effort, the expression is
            # still printed below.
            pass
        return self.emptyPrinter(expr)

    # The following can not be simply translated into C or Fortran
    _print_Basic = _print_not_supported
    _print_ComplexInfinity = _print_not_supported
    _print_Derivative = _print_not_supported
    _print_ExprCondPair = _print_not_supported
    _print_GeometryEntity = _print_not_supported
    _print_Infinity = _print_not_supported
    _print_Integral = _print_not_supported
    _print_Interval = _print_not_supported
    _print_AccumulationBounds = _print_not_supported
    _print_Limit = _print_not_supported
    _print_MatrixBase = _print_not_supported
    _print_DeferredVector = _print_not_supported
    _print_NaN = _print_not_supported
    _print_NegativeInfinity = _print_not_supported
    _print_Order = _print_not_supported
    _print_RootOf = _print_not_supported
    _print_RootsOf = _print_not_supported
    _print_RootSum = _print_not_supported
    _print_Uniform = _print_not_supported
    _print_Unit = _print_not_supported
    _print_Wild = _print_not_supported
    _print_WildFunction = _print_not_supported
    _print_Relational = _print_not_supported
|
||
|
|
||
|
|
||
|
# Code printer functions. These are included in this file so that they can be
|
||
|
# imported in the top-level __init__.py without importing the sympy.codegen
|
||
|
# module.
|
||
|
|
||
|
def ccode(expr, assign_to=None, standard='c99', **settings):
    """Converts an expr to a string of c code

    Parameters
    ==========

    expr : Expr
        A SymPy expression to be converted.
    assign_to : optional
        When given, the argument is used as the name of the variable to which
        the expression is assigned. Can be a string, ``Symbol``,
        ``MatrixSymbol``, or ``Indexed`` type. This is helpful in case of
        line-wrapping, or for expressions that generate multi-line statements.
    standard : str, optional
        String specifying the C standard, e.g. 'c89' or 'c99'. It is matched
        case-insensitively and selects the printer class from
        ``sympy.printing.c.c_code_printers``. A more modern standard allows
        the printer to use more math functions; pass 'c89' if your compiler
        only supports that older standard. [default='c99'].
    precision : integer, optional
        The precision for numbers such as pi [default=17].
    user_functions : dict, optional
        A dictionary where the keys are string representations of either
        ``FunctionClass`` or ``UndefinedFunction`` instances and the values
        are their desired C string representations. Alternatively, the
        dictionary value can be a list of tuples i.e. [(argument_test,
        cfunction_string)] or [(argument_test, cfunction_formater)]. See below
        for examples.
    dereference : iterable, optional
        An iterable of symbols that should be dereferenced in the printed code
        expression. These would be values passed by address to the function.
        For example, if ``dereference=[a]``, the resulting code would print
        ``(*a)`` instead of ``a``.
    human : bool, optional
        If True, the result is a single string that may contain some constant
        declarations for the number symbols. If False, the same information is
        returned in a tuple of (symbols_to_declare, not_supported_functions,
        code_text). [default=True].
    contract: bool, optional
        If True, ``Indexed`` instances are assumed to obey tensor contraction
        rules and the corresponding nested loops over indices are generated.
        Setting contract=False will not generate loops, instead the user is
        responsible to provide values for the indices in the code.
        [default=True].

    Examples
    ========

    >>> from sympy import ccode, symbols, Rational, sin, ceiling, Abs, Function
    >>> x, tau = symbols("x, tau")
    >>> expr = (2*tau)**Rational(7, 2)
    >>> ccode(expr)
    '8*M_SQRT2*pow(tau, 7.0/2.0)'
    >>> ccode(expr, math_macros={})
    '8*sqrt(2)*pow(tau, 7.0/2.0)'
    >>> ccode(sin(x), assign_to="s")
    's = sin(x);'
    >>> from sympy.codegen.ast import real, float80
    >>> ccode(expr, type_aliases={real: float80})
    '8*M_SQRT2l*powl(tau, 7.0L/2.0L)'

    Simple custom printing can be defined for certain types by passing a
    dictionary of {"type" : "function"} to the ``user_functions`` kwarg.
    Alternatively, the dictionary value can be a list of tuples i.e.
    [(argument_test, cfunction_string)].

    >>> custom_functions = {
    ...   "ceiling": "CEIL",
    ...   "Abs": [(lambda x: not x.is_integer, "fabs"),
    ...           (lambda x: x.is_integer, "ABS")],
    ...   "func": "f"
    ... }
    >>> func = Function('func')
    >>> ccode(func(Abs(x) + ceiling(x)), standard='C89', user_functions=custom_functions)
    'f(fabs(x) + CEIL(x))'

    or if the C-function takes a subset of the original arguments:

    >>> ccode(2**x + 3**x, standard='C99', user_functions={'Pow': [
    ...   (lambda b, e: b == 2, lambda b, e: 'exp2(%s)' % e),
    ...   (lambda b, e: b != 2, 'pow')]})
    'exp2(x) + pow(3, x)'

    ``Piecewise`` expressions are converted into conditionals. If an
    ``assign_to`` variable is provided an if statement is created, otherwise
    the ternary operator is used. Note that if the ``Piecewise`` lacks a
    default term, represented by ``(expr, True)`` then an error will be thrown.
    This is to prevent generating an expression that may not evaluate to
    anything.

    >>> from sympy import Piecewise
    >>> expr = Piecewise((x + 1, x > 0), (x, True))
    >>> print(ccode(expr, tau, standard='C89'))
    if (x > 0) {
    tau = x + 1;
    }
    else {
    tau = x;
    }

    Support for loops is provided through ``Indexed`` types. With
    ``contract=True`` these expressions will be turned into loops, whereas
    ``contract=False`` will just print the assignment expression that should be
    looped over:

    >>> from sympy import Eq, IndexedBase, Idx
    >>> len_y = 5
    >>> y = IndexedBase('y', shape=(len_y,))
    >>> t = IndexedBase('t', shape=(len_y,))
    >>> Dy = IndexedBase('Dy', shape=(len_y-1,))
    >>> i = Idx('i', len_y-1)
    >>> e=Eq(Dy[i], (y[i+1]-y[i])/(t[i+1]-t[i]))
    >>> ccode(e.rhs, assign_to=e.lhs, contract=False, standard='C89')
    'Dy[i] = (y[i + 1] - y[i])/(t[i + 1] - t[i]);'

    Matrices are also supported, but a ``MatrixSymbol`` of the same dimensions
    must be provided to ``assign_to``. Note that any expression that can be
    generated normally can also exist inside a Matrix:

    >>> from sympy import Matrix, MatrixSymbol
    >>> mat = Matrix([x**2, Piecewise((x + 1, x > 0), (x, True)), sin(x)])
    >>> A = MatrixSymbol('A', 3, 1)
    >>> print(ccode(mat, A, standard='C89'))
    A[0] = pow(x, 2);
    if (x > 0) {
    A[1] = x + 1;
    }
    else {
    A[1] = x;
    }
    A[2] = sin(x);
    """
    from sympy.printing.c import c_code_printers
    return c_code_printers[standard.lower()](settings).doprint(expr, assign_to)
|
||
|
|
||
|
def print_ccode(expr, **settings):
    """Write the C code for ``expr`` to standard output.

    Every keyword argument is forwarded unchanged to ``ccode``.
    """
    code = ccode(expr, **settings)
    print(code)
|
||
|
|
||
|
def fcode(expr, assign_to=None, **settings):
    """Translate a SymPy expression into a string of Fortran code.

    Parameters
    ==========

    expr : Expr
        The SymPy expression to translate.
    assign_to : optional
        Name of the variable that receives the expression. May be given as a
        string, ``Symbol``, ``MatrixSymbol`` or ``Indexed``. Supplying it
        helps with line-wrapping and with expressions that expand into
        several statements.
    precision : integer, optional
        DEPRECATED; use type_mappings instead. Precision used for numbers
        such as pi [default=17].
    user_functions : dict, optional
        Maps ``FunctionClass`` instances to their string representation. A
        value may instead be a list of tuples, i.e.
        [(argument_test, cfunction_string)]. See the examples below.
    human : bool, optional
        When True, a single string is returned, possibly preceded by
        constant declarations for number symbols. When False, the same
        information comes back as a tuple of (symbols_to_declare,
        not_supported_functions, code_text). [default=True].
    contract: bool, optional
        When True, ``Indexed`` objects are treated according to tensor
        contraction rules and the matching nested loops over indices are
        emitted. With contract=False no loops are emitted; the caller must
        supply values for the indices in the surrounding code.
        [default=True].
    source_format : optional
        Either 'fixed' or 'free'. [default='fixed']
    standard : integer, optional
        Fortran standard to follow, given as an integer. Accepted values are
        66, 77, 90, 95, 2003 and 2008; the default is 77. At present the only
        internal distinction is between standards before 95 and those from
        95 onwards, which may be refined as features are added.
    name_mangling : bool, optional
        When True, variables that would clash in case-insensitive Fortran
        are disambiguated by appending a different number of ``_`` to each.
        When False, SymPy leaves variable names untouched. [default=True]

    Examples
    ========

    >>> from sympy import fcode, symbols, Rational, sin, ceiling, floor
    >>> x, tau = symbols("x, tau")
    >>> fcode((2*tau)**Rational(7, 2))
    ' 8*sqrt(2.0d0)*tau**(7.0d0/2.0d0)'
    >>> fcode(sin(x), assign_to="s")
    ' s = sin(x)'

    Custom printing for particular types is obtained by passing a dictionary
    of "type" : "function" as the ``user_functions`` keyword. A dictionary
    value may also be a list of tuples, i.e. [(argument_test,
    cfunction_string)].

    >>> custom_functions = {
    ...   "ceiling": "CEIL",
    ...   "floor": [(lambda x: not x.is_integer, "FLOOR1"),
    ...             (lambda x: x.is_integer, "FLOOR2")]
    ... }
    >>> fcode(floor(x) + ceiling(x), user_functions=custom_functions)
    ' CEIL(x) + FLOOR1(x)'

    ``Piecewise`` expressions become conditionals. With an ``assign_to``
    variable an if statement is generated; without one, an inline
    conditional form is used. A ``Piecewise`` without a default term,
    i.e. ``(expr, True)``, raises an error, so that no expression is
    generated that might fail to evaluate to anything.

    >>> from sympy import Piecewise
    >>> expr = Piecewise((x + 1, x > 0), (x, True))
    >>> print(fcode(expr, tau))
    if (x > 0) then
    tau = x + 1
    else
    tau = x
    end if

    Loops are supported through ``Indexed`` types. With ``contract=True``
    such expressions are turned into loops, whereas ``contract=False`` only
    prints the assignment that is meant to be looped over:

    >>> from sympy import Eq, IndexedBase, Idx
    >>> len_y = 5
    >>> y = IndexedBase('y', shape=(len_y,))
    >>> t = IndexedBase('t', shape=(len_y,))
    >>> Dy = IndexedBase('Dy', shape=(len_y-1,))
    >>> i = Idx('i', len_y-1)
    >>> e=Eq(Dy[i], (y[i+1]-y[i])/(t[i+1]-t[i]))
    >>> fcode(e.rhs, assign_to=e.lhs, contract=False)
    ' Dy(i) = (y(i + 1) - y(i))/(t(i + 1) - t(i))'

    Matrices work as well, provided a ``MatrixSymbol`` of matching
    dimensions is given as ``assign_to``. Any expression that can be
    printed on its own may also appear inside the Matrix:

    >>> from sympy import Matrix, MatrixSymbol
    >>> mat = Matrix([x**2, Piecewise((x + 1, x > 0), (x, True)), sin(x)])
    >>> A = MatrixSymbol('A', 3, 1)
    >>> print(fcode(mat, A))
    A(1, 1) = x**2
    if (x > 0) then
    A(2, 1) = x + 1
    else
    A(2, 1) = x
    end if
    A(3, 1) = sin(x)
    """
    # Imported lazily so importing this module does not load the Fortran printer.
    from sympy.printing.fortran import FCodePrinter
    printer = FCodePrinter(settings)
    return printer.doprint(expr, assign_to)
|
||
|
|
||
|
|
||
|
def print_fcode(expr, **settings):
    """Write the Fortran code for ``expr`` to standard output.

    The keyword arguments are passed straight through to ``fcode``; see
    that function for their meaning.
    """
    code = fcode(expr, **settings)
    print(code)
|
||
|
|
||
|
def cxxcode(expr, assign_to=None, standard='c++11', **settings):
    """C++ equivalent of :func:`~.ccode`.

    ``standard`` is lower-cased and used as the key into
    ``sympy.printing.cxx.cxx_code_printers`` to pick the printer class. That
    class is instantiated with ``settings`` and prints ``expr``, optionally
    assigned to ``assign_to``.
    """
    from sympy.printing.cxx import cxx_code_printers
    printer_cls = cxx_code_printers[standard.lower()]
    return printer_cls(settings).doprint(expr, assign_to)
|