from sympy.core import S
from sympy.core.function import Lambda
from sympy.core.power import Pow
from .pycode import PythonCodePrinter, _known_functions_math, _print_known_const, _print_known_func, _unpack_integral_limits, ArrayPrinter
from .codeprinter import CodePrinter


# Names from the ``math`` function table that are excluded from the NumPy
# table built below (SciPyPrinter maps them to scipy.special further down).
_not_in_numpy = 'erf erfc factorial gamma loggamma'.split()
_in_numpy = [(k, v) for k, v in _known_functions_math.items() if k not in _not_in_numpy]
_known_functions_numpy = dict(_in_numpy, **{
    'acos': 'arccos',
    'acosh': 'arccosh',
    'asin': 'arcsin',
    'asinh': 'arcsinh',
    'atan': 'arctan',
    'atan2': 'arctan2',
    'atanh': 'arctanh',
    'exp2': 'exp2',
    'sign': 'sign',
    'logaddexp': 'logaddexp',
    'logaddexp2': 'logaddexp2',
})
# NOTE: ``PINF``/``NINF`` were removed in NumPy 2.0, so positive infinity is
# mapped to ``inf`` and negative infinity is printed as ``-<module>.inf`` by
# ``NumPyPrinter._print_NegativeInfinity`` instead of going through this table.
_known_constants_numpy = {
    'Exp1': 'e',
    'Pi': 'pi',
    'EulerGamma': 'euler_gamma',
    'NaN': 'nan',
    'Infinity': 'inf',
}

_numpy_known_functions = {k: 'numpy.' + v for k, v in _known_functions_numpy.items()}
_numpy_known_constants = {k: 'numpy.' + v for k, v in _known_constants_numpy.items()}


class NumPyPrinter(ArrayPrinter, PythonCodePrinter):
    """
    Numpy printer which handles vectorized piecewise functions,
    logical operators, etc.
    """

    # Name of the array module every generated call is prefixed with;
    # subclasses (CuPyPrinter, JaxPrinter) override it.
    _module = 'numpy'
    _kf = _numpy_known_functions
    _kc = _numpy_known_constants

    def __init__(self, settings=None):
        """
        `settings` is passed on to the parent printer's ``__init__``.

        The array module used in the generated code is taken from the class
        attribute ``_module`` (overridden by subclasses), not from an
        argument.
        """
        self.language = "Python with {}".format(self._module)
        self.printmethod = "_{}code".format(self._module)

        # Layer the array-module function table over the plain Python one so
        # that entries of this class win on a name clash.
        self._kf = {**PythonCodePrinter._kf, **self._kf}

        super().__init__(settings=settings)

    def _print_seq(self, seq):
        "General sequence printer: converts to tuple"
        # Print tuples here instead of lists because numba supports
        #     tuples in nopython mode.
        delimiter = ', '
        return '({},)'.format(delimiter.join(self._print(item) for item in seq))

    def _print_NegativeInfinity(self, expr):
        "Print negative infinity as the negated positive-infinity constant"
        return '-' + self._print(S.Infinity)

    def _print_MatMul(self, expr):
        "Matrix multiplication printer"
        # Split once into the scalar coefficient and the matrix factors.
        coeff, matrices = expr.as_coeff_matrices()
        if coeff is not S.One:
            # A non-trivial scalar is appended as the last ``.dot`` operand.
            factors = matrices + [coeff]
        else:
            factors = expr.args
        return '({})'.format(').dot('.join(self._print(i) for i in factors))

    def _print_MatPow(self, expr):
        "Matrix power printer"
        return '{}({}, {})'.format(self._module_format(self._module + '.linalg.matrix_power'),
            self._print(expr.args[0]), self._print(expr.args[1]))

    def _print_Inverse(self, expr):
        "Matrix inverse printer"
        return '{}({})'.format(self._module_format(self._module + '.linalg.inv'),
            self._print(expr.args[0]))

    def _print_DotProduct(self, expr):
        "Dot product printer"
        # DotProduct allows any shape order, but numpy.dot does matrix
        # multiplication, so we have to make sure it gets 1 x n by n x 1.
        arg1, arg2 = expr.args
        if arg1.shape[0] != 1:
            arg1 = arg1.T
        if arg2.shape[1] != 1:
            arg2 = arg2.T

        return "%s(%s, %s)" % (self._module_format(self._module + '.dot'),
                               self._print(arg1),
                               self._print(arg2))

    def _print_MatrixSolve(self, expr):
        "Linear solve printer: emits ``linalg.solve(matrix, vector)``"
        return "%s(%s, %s)" % (self._module_format(self._module + '.linalg.solve'),
                               self._print(expr.matrix),
                               self._print(expr.vector))

    def _print_ZeroMatrix(self, expr):
        "Zero matrix printer: emits ``zeros(shape)``"
        return '{}({})'.format(self._module_format(self._module + '.zeros'),
            self._print(expr.shape))

    def _print_OneMatrix(self, expr):
        "All-ones matrix printer: emits ``ones(shape)``"
        return '{}({})'.format(self._module_format(self._module + '.ones'),
            self._print(expr.shape))

    def _print_FunctionMatrix(self, expr):
        "FunctionMatrix printer: emits ``fromfunction(lambda ..., shape)``"
        from sympy.abc import i, j
        lamda = expr.lamda
        if not isinstance(lamda, Lambda):
            # Wrap a bare callable so that argument names and body can be
            # read off uniformly below.
            lamda = Lambda((i, j), lamda(i, j))
        return '{}(lambda {}: {}, {})'.format(self._module_format(self._module + '.fromfunction'),
            ', '.join(self._print(arg) for arg in lamda.args[0]),
            self._print(lamda.args[1]), self._print(expr.shape))

    def _print_HadamardProduct(self, expr):
        "Elementwise product printer: nests binary ``multiply`` calls"
        func = self._module_format(self._module + '.multiply')
        # Open one call per leading argument, print the last argument, then
        # close every call that was opened.
        return ''.join('{}({}, '.format(func, self._print(arg)) \
            for arg in expr.args[:-1]) + "{}{}".format(self._print(expr.args[-1]),
            ')' * (len(expr.args) - 1))

    def _print_KroneckerProduct(self, expr):
        "Kronecker product printer: nests binary ``kron`` calls"
        func = self._module_format(self._module + '.kron')
        return ''.join('{}({}, '.format(func, self._print(arg)) \
            for arg in expr.args[:-1]) + "{}{}".format(self._print(expr.args[-1]),
            ')' * (len(expr.args) - 1))

    def _print_Adjoint(self, expr):
        "Adjoint printer: emits ``conjugate(transpose(arg))``"
        return '{}({}({}))'.format(
            self._module_format(self._module + '.conjugate'),
            self._module_format(self._module + '.transpose'),
            self._print(expr.args[0]))

    def _print_DiagonalOf(self, expr):
        "DiagonalOf printer: emits ``diag(arg)`` reshaped to ``(-1, 1)``"
        vect = '{}({})'.format(
            self._module_format(self._module + '.diag'),
            self._print(expr.arg))
        return '{}({}, (-1, 1))'.format(
            self._module_format(self._module + '.reshape'), vect)

    def _print_DiagMatrix(self, expr):
        "DiagMatrix printer: emits ``diagflat(arg)``"
        return '{}({})'.format(self._module_format(self._module + '.diagflat'),
            self._print(expr.args[0]))

    def _print_DiagonalMatrix(self, expr):
        "DiagonalMatrix printer: emits ``multiply(arg, eye(rows, cols))``"
        return '{}({}, {}({}, {}))'.format(self._module_format(self._module + '.multiply'),
            self._print(expr.arg), self._module_format(self._module + '.eye'),
            self._print(expr.shape[0]), self._print(expr.shape[1]))

    def _print_Piecewise(self, expr):
        "Piecewise function printer"
        from sympy.logic.boolalg import ITE, simplify_logic

        def print_cond(cond):
            """ Problem having an ITE in the cond. """
            if cond.has(ITE):
                return self._print(simplify_logic(cond))
            else:
                return self._print(cond)

        exprs = '[{}]'.format(','.join(self._print(arg.expr) for arg in expr.args))
        conds = '[{}]'.format(','.join(print_cond(arg.cond) for arg in expr.args))
        # If [default_value, True] is a (expr, cond) sequence in a Piecewise object
        #     it will behave the same as passing the 'default' kwarg to select()
        #     *as long as* it is the last element in expr.args.
        # If this is not the case, it may be triggered prematurely.
        return '{}({}, {}, default={})'.format(
            self._module_format(self._module + '.select'), conds, exprs,
            self._print(S.NaN))

    def _print_Relational(self, expr):
        "Relational printer for Equality and Unequality"
        op = {
            '==' :'equal',
            '!=' :'not_equal',
            '<'  :'less',
            '<=' :'less_equal',
            '>'  :'greater',
            '>=' :'greater_equal',
        }
        if expr.rel_op in op:
            lhs = self._print(expr.lhs)
            rhs = self._print(expr.rhs)
            return '{op}({lhs}, {rhs})'.format(op=self._module_format(self._module + '.'+op[expr.rel_op]),
                                               lhs=lhs, rhs=rhs)
        # Operators without an entry above are left to the parent printer.
        return super()._print_Relational(expr)

    def _print_And(self, expr):
        "Logical And printer"
        # We have to override LambdaPrinter because it uses Python 'and' keyword.
        # If LambdaPrinter didn't define it, we could use StrPrinter's
        # version of the function and add 'logical_and' to NUMPY_TRANSLATIONS.
        return '{}.reduce(({}))'.format(self._module_format(self._module + '.logical_and'), ','.join(self._print(i) for i in expr.args))

    def _print_Or(self, expr):
        "Logical Or printer"
        # We have to override LambdaPrinter because it uses Python 'or' keyword.
        # If LambdaPrinter didn't define it, we could use StrPrinter's
        # version of the function and add 'logical_or' to NUMPY_TRANSLATIONS.
        return '{}.reduce(({}))'.format(self._module_format(self._module + '.logical_or'), ','.join(self._print(i) for i in expr.args))

    def _print_Not(self, expr):
        "Logical Not printer"
        # We have to override LambdaPrinter because it uses Python 'not' keyword.
        # If LambdaPrinter didn't define it, we would still have to define our
        #     own because StrPrinter doesn't define it.
        return '{}({})'.format(self._module_format(self._module + '.logical_not'), ','.join(self._print(i) for i in expr.args))

    def _print_Pow(self, expr, rational=False):
        "Power printer"
        # XXX Workaround for negative integer power error
        if expr.exp.is_integer and expr.exp.is_negative:
            expr = Pow(expr.base, expr.exp.evalf(), evaluate=False)
        return self._hprint_Pow(expr, rational=rational, sqrt=self._module + '.sqrt')

    def _print_Min(self, expr):
        "Min printer: emits ``amin((args), axis=0)``"
        return '{}(({}), axis=0)'.format(self._module_format(self._module + '.amin'), ','.join(self._print(i) for i in expr.args))

    def _print_Max(self, expr):
        "Max printer: emits ``amax((args), axis=0)``"
        return '{}(({}), axis=0)'.format(self._module_format(self._module + '.amax'), ','.join(self._print(i) for i in expr.args))

    def _print_arg(self, expr):
        "Complex argument printer: emits ``angle(arg)``"
        return "%s(%s)" % (self._module_format(self._module + '.angle'), self._print(expr.args[0]))

    def _print_im(self, expr):
        "Imaginary part printer: emits ``imag(arg)``"
        return "%s(%s)" % (self._module_format(self._module + '.imag'), self._print(expr.args[0]))

    def _print_Mod(self, expr):
        "Modulo printer: emits ``mod(args...)``"
        return "%s(%s)" % (self._module_format(self._module + '.mod'), ', '.join(
            (self._print(arg) for arg in expr.args)))

    def _print_re(self, expr):
        "Real part printer: emits ``real(arg)``"
        return "%s(%s)" % (self._module_format(self._module + '.real'), self._print(expr.args[0]))

    def _print_sinc(self, expr):
        "sinc printer: the argument is divided by pi before being passed on"
        return "%s(%s)" % (self._module_format(self._module + '.sinc'), self._print(expr.args[0]/S.Pi))

    def _print_MatrixBase(self, expr):
        "Explicit matrix printer: emits ``array(nested lists)`` by default"
        # A registered known function for the concrete matrix class takes
        # precedence over the generic ``array`` constructor.
        func = self.known_functions.get(expr.__class__.__name__, None)
        if func is None:
            func = self._module_format(self._module + '.array')
        return "%s(%s)" % (func, self._print(expr.tolist()))

    def _print_Identity(self, expr):
        "Identity matrix printer; raises NotImplementedError for symbolic shapes"
        shape = expr.shape
        if all(dim.is_Integer for dim in shape):
            return "%s(%s)" % (self._module_format(self._module + '.eye'), self._print(expr.shape[0]))
        else:
            raise NotImplementedError("Symbolic matrix dimensions are not yet supported for identity matrices")

    def _print_BlockMatrix(self, expr):
        "Block matrix printer: emits ``block(nested lists of blocks)``"
        return '{}({})'.format(self._module_format(self._module + '.block'),
                                 self._print(expr.args[0].tolist()))

    def _print_NDimArray(self, expr):
        "N-dimensional array printer (only ranks 1 and 2 are supported)"
        if len(expr.shape) == 1:
            # Go through _module_format like every other method so that the
            # import is registered and `fully_qualified_modules` is honoured.
            return self._module_format(self._module + '.array') + '(' + self._print(expr.args[0]) + ')'
        if len(expr.shape) == 2:
            return self._print(expr.tomatrix())
        # Should be possible to extend to more dimensions
        return CodePrinter._print_not_supported(self, expr)

    # Function names (relative to ``_module``) — presumably consumed by
    # ArrayPrinter; confirm against its definition in .pycode.
    _add = "add"
    _einsum = "einsum"
    _transpose = "transpose"
    _ones = "ones"
    _zeros = "zeros"

    # Marked as not supported here; SciPyPrinter overrides all four.
    _print_lowergamma = CodePrinter._print_not_supported
    _print_uppergamma = CodePrinter._print_not_supported
    _print_fresnelc = CodePrinter._print_not_supported
    _print_fresnels = CodePrinter._print_not_supported


# Attach a generic printer method for every known function / constant.
for func in _numpy_known_functions:
    setattr(NumPyPrinter, f'_print_{func}', _print_known_func)

for const in _numpy_known_constants:
    setattr(NumPyPrinter, f'_print_{const}', _print_known_const)


_known_functions_scipy_special = {
    'Ei': 'expi',
    'erf': 'erf',
    'erfc': 'erfc',
    'besselj': 'jv',
    'bessely': 'yv',
    'besseli': 'iv',
    'besselk': 'kv',
    'cosm1': 'cosm1',
    'powm1': 'powm1',
    'factorial': 'factorial',
    'gamma': 'gamma',
    'loggamma': 'gammaln',
    'digamma': 'psi',
    'polygamma': 'polygamma',
    'RisingFactorial': 'poch',
    'jacobi': 'eval_jacobi',
    'gegenbauer': 'eval_gegenbauer',
    'chebyshevt': 'eval_chebyt',
    'chebyshevu': 'eval_chebyu',
    'legendre': 'eval_legendre',
    'hermite': 'eval_hermite',
    'laguerre': 'eval_laguerre',
    'assoc_laguerre': 'eval_genlaguerre',
    'beta': 'beta',
    'LambertW' : 'lambertw',
}

_known_constants_scipy_constants = {
    'GoldenRatio': 'golden_ratio',
    'Pi': 'pi',
}
_scipy_known_functions = {k : "scipy.special." + v for k, v in _known_functions_scipy_special.items()}
_scipy_known_constants = {k : "scipy.constants." + v for k, v in _known_constants_scipy_constants.items()}


class SciPyPrinter(NumPyPrinter):
    """
    NumPy printer extended with functions and constants from
    ``scipy.special``, ``scipy.constants``, ``scipy.sparse`` and
    ``scipy.integrate``.
    """

    # SciPy entries override NumPy entries with the same key.
    _kf = {**NumPyPrinter._kf, **_scipy_known_functions}
    _kc = {**NumPyPrinter._kc, **_scipy_known_constants}

    def __init__(self, settings=None):
        super().__init__(settings=settings)
        # Replace the language string set by NumPyPrinter.__init__.
        self.language = "Python with SciPy and NumPy"

    def _print_SparseRepMatrix(self, expr):
        "Sparse matrix printer: emits a ``scipy.sparse.coo_matrix`` call"
        # Collect row indices, column indices and values of the stored entries.
        i, j, data = [], [], []
        for (r, c), v in expr.todok().items():
            i.append(r)
            j.append(c)
            data.append(v)

        return "{name}(({data}, ({i}, {j})), shape={shape})".format(
            name=self._module_format('scipy.sparse.coo_matrix'),
            data=data, i=i, j=j, shape=expr.shape
        )

    _print_ImmutableSparseMatrix = _print_SparseRepMatrix

    # SciPy's lpmv has a different order of arguments from assoc_legendre
    def _print_assoc_legendre(self, expr):
        "Associated Legendre printer: emits ``lpmv`` with the first two args swapped"
        return "{0}({2}, {1}, {3})".format(
            self._module_format('scipy.special.lpmv'),
            self._print(expr.args[0]),
            self._print(expr.args[1]),
            self._print(expr.args[2]))

    def _print_lowergamma(self, expr):
        "Lower incomplete gamma printer: emits ``gamma(a)*gammainc(a, x)``"
        return "{0}({2})*{1}({2}, {3})".format(
            self._module_format('scipy.special.gamma'),
            self._module_format('scipy.special.gammainc'),
            self._print(expr.args[0]),
            self._print(expr.args[1]))

    def _print_uppergamma(self, expr):
        "Upper incomplete gamma printer: emits ``gamma(a)*gammaincc(a, x)``"
        return "{0}({2})*{1}({2}, {3})".format(
            self._module_format('scipy.special.gamma'),
            self._module_format('scipy.special.gammaincc'),
            self._print(expr.args[0]),
            self._print(expr.args[1]))

    def _print_betainc(self, expr):
        "Incomplete beta printer: difference of two ``betainc`` calls times ``beta``"
        betainc = self._module_format('scipy.special.betainc')
        beta = self._module_format('scipy.special.beta')
        args = [self._print(arg) for arg in expr.args]
        return f"({betainc}({args[0]}, {args[1]}, {args[3]}) - {betainc}({args[0]}, {args[1]}, {args[2]})) \
            * {beta}({args[0]}, {args[1]})"

    def _print_betainc_regularized(self, expr):
        "Regularized incomplete beta printer: difference of two ``betainc`` calls"
        return "{0}({1}, {2}, {4}) - {0}({1}, {2}, {3})".format(
            self._module_format('scipy.special.betainc'),
            self._print(expr.args[0]),
            self._print(expr.args[1]),
            self._print(expr.args[2]),
            self._print(expr.args[3]))

    def _print_fresnels(self, expr):
        "Fresnel S printer: element 0 of ``scipy.special.fresnel``"
        return "{}({})[0]".format(
                self._module_format("scipy.special.fresnel"),
                self._print(expr.args[0]))

    def _print_fresnelc(self, expr):
        "Fresnel C printer: element 1 of ``scipy.special.fresnel``"
        return "{}({})[1]".format(
                self._module_format("scipy.special.fresnel"),
                self._print(expr.args[0]))

    def _print_airyai(self, expr):
        "Airy Ai printer: element 0 of ``scipy.special.airy``"
        return "{}({})[0]".format(
                self._module_format("scipy.special.airy"),
                self._print(expr.args[0]))

    def _print_airyaiprime(self, expr):
        "Airy Ai' printer: element 1 of ``scipy.special.airy``"
        return "{}({})[1]".format(
                self._module_format("scipy.special.airy"),
                self._print(expr.args[0]))

    def _print_airybi(self, expr):
        "Airy Bi printer: element 2 of ``scipy.special.airy``"
        return "{}({})[2]".format(
                self._module_format("scipy.special.airy"),
                self._print(expr.args[0]))

    def _print_airybiprime(self, expr):
        "Airy Bi' printer: element 3 of ``scipy.special.airy``"
        return "{}({})[3]".format(
                self._module_format("scipy.special.airy"),
                self._print(expr.args[0]))

    def _print_bernoulli(self, expr):
        "Bernoulli printer: prints the zeta-function rewrite of ``expr``"
        # scipy's bernoulli is inconsistent with SymPy's so rewrite
        return self._print(expr._eval_rewrite_as_zeta(*expr.args))

    def _print_harmonic(self, expr):
        "Harmonic number printer: prints the zeta-function rewrite of ``expr``"
        return self._print(expr._eval_rewrite_as_zeta(*expr.args))

    def _print_Integral(self, e):
        "Definite integral printer: emits a ``quad``/``nquad`` call, element 0"
        integration_vars, limits = _unpack_integral_limits(e)

        if len(limits) == 1:
            # nicer (but not necessary) to prefer quad over nquad for 1D case
            module_str = self._module_format("scipy.integrate.quad")
            limit_str = "%s, %s" % tuple(map(self._print, limits[0]))
        else:
            module_str = self._module_format("scipy.integrate.nquad")
            limit_str = "({})".format(", ".join(
                "(%s, %s)" % tuple(map(self._print, l)) for l in limits))

        return "{}(lambda {}: {}, {})[0]".format(
                module_str,
                ", ".join(map(self._print, integration_vars)),
                self._print(e.args[0]),
                limit_str)

    def _print_Si(self, expr):
        "Sine integral printer: element 0 of ``scipy.special.sici``"
        return "{}({})[0]".format(
                self._module_format("scipy.special.sici"),
                self._print(expr.args[0]))

    def _print_Ci(self, expr):
        "Cosine integral printer: element 1 of ``scipy.special.sici``"
        return "{}({})[1]".format(
                self._module_format("scipy.special.sici"),
                self._print(expr.args[0]))


for func in _scipy_known_functions:
    setattr(SciPyPrinter, f'_print_{func}', _print_known_func)

for const in _scipy_known_constants:
    setattr(SciPyPrinter, f'_print_{const}', _print_known_const)


_cupy_known_functions = {k : "cupy." + v for k, v in _known_functions_numpy.items()}
_cupy_known_constants = {k : "cupy." + v for k, v in _known_constants_numpy.items()}


class CuPyPrinter(NumPyPrinter):
    """
    CuPy printer which handles vectorized piecewise functions,
    logical operators, etc.
    """

    _module = 'cupy'
    _kf = _cupy_known_functions
    _kc = _cupy_known_constants

    def __init__(self, settings=None):
        super().__init__(settings=settings)


for func in _cupy_known_functions:
    setattr(CuPyPrinter, f'_print_{func}', _print_known_func)

for const in _cupy_known_constants:
    setattr(CuPyPrinter, f'_print_{const}', _print_known_const)


_jax_known_functions = {k: 'jax.numpy.' + v for k, v in _known_functions_numpy.items()}
_jax_known_constants = {k: 'jax.numpy.' + v for k, v in _known_constants_numpy.items()}


class JaxPrinter(NumPyPrinter):
    """
    JAX printer which handles vectorized piecewise functions,
    logical operators, etc.
    """
    _module = "jax.numpy"

    _kf = _jax_known_functions
    _kc = _jax_known_constants

    def __init__(self, settings=None):
        super().__init__(settings=settings)

    # These need specific override to allow for the lack of "jax.numpy.reduce"
    def _print_And(self, expr):
        "Logical And printer"
        return "{}({}.asarray([{}]), axis=0)".format(
            self._module_format(self._module + ".all"),
            self._module_format(self._module),
            ",".join(self._print(i) for i in expr.args),
        )

    def _print_Or(self, expr):
        "Logical Or printer"
        return "{}({}.asarray([{}]), axis=0)".format(
            self._module_format(self._module + ".any"),
            self._module_format(self._module),
            ",".join(self._print(i) for i in expr.args),
        )


for func in _jax_known_functions:
    setattr(JaxPrinter, f'_print_{func}', _print_known_func)

for const in _jax_known_constants:
    setattr(JaxPrinter, f'_print_{const}', _print_known_const)