from sympy.core import (S, pi, oo, symbols, Rational, Integer, GoldenRatio, EulerGamma, Catalan, Lambda, Dummy, Eq, Ne, Le, Lt, Gt, Ge, Mod) from sympy.functions import (Piecewise, sin, cos, Abs, exp, ceiling, sqrt, sign, floor) from sympy.logic import ITE from sympy.testing.pytest import raises from sympy.utilities.lambdify import implemented_function from sympy.tensor import IndexedBase, Idx from sympy.matrices import MatrixSymbol, SparseMatrix, Matrix from sympy.printing.rust import rust_code x, y, z = symbols('x,y,z') def test_Integer(): assert rust_code(Integer(42)) == "42" assert rust_code(Integer(-56)) == "-56" def test_Relational(): assert rust_code(Eq(x, y)) == "x == y" assert rust_code(Ne(x, y)) == "x != y" assert rust_code(Le(x, y)) == "x <= y" assert rust_code(Lt(x, y)) == "x < y" assert rust_code(Gt(x, y)) == "x > y" assert rust_code(Ge(x, y)) == "x >= y" def test_Rational(): assert rust_code(Rational(3, 7)) == "3_f64/7.0" assert rust_code(Rational(18, 9)) == "2" assert rust_code(Rational(3, -7)) == "-3_f64/7.0" assert rust_code(Rational(-3, -7)) == "3_f64/7.0" assert rust_code(x + Rational(3, 7)) == "x + 3_f64/7.0" assert rust_code(Rational(3, 7)*x) == "(3_f64/7.0)*x" def test_basic_ops(): assert rust_code(x + y) == "x + y" assert rust_code(x - y) == "x - y" assert rust_code(x * y) == "x*y" assert rust_code(x / y) == "x/y" assert rust_code(-x) == "-x" def test_printmethod(): class fabs(Abs): def _rust_code(self, printer): return "%s.fabs()" % printer._print(self.args[0]) assert rust_code(fabs(x)) == "x.fabs()" a = MatrixSymbol("a", 1, 3) assert rust_code(a[0,0]) == 'a[0]' def test_Functions(): assert rust_code(sin(x) ** cos(x)) == "x.sin().powf(x.cos())" assert rust_code(abs(x)) == "x.abs()" assert rust_code(ceiling(x)) == "x.ceil()" assert rust_code(floor(x)) == "x.floor()" # Automatic rewrite assert rust_code(Mod(x, 3)) == 'x - 3*((1_f64/3.0)*x).floor()' def test_Pow(): assert rust_code(1/x) == "x.recip()" assert rust_code(x**-1) == rust_code(x**-1.0) == "x.recip()" assert rust_code(sqrt(x)) == "x.sqrt()" assert rust_code(x**S.Half) == rust_code(x**0.5) == "x.sqrt()" assert rust_code(1/sqrt(x)) == "x.sqrt().recip()" assert rust_code(x**-S.Half) == rust_code(x**-0.5) == "x.sqrt().recip()" assert rust_code(1/pi) == "PI.recip()" assert rust_code(pi**-1) == rust_code(pi**-1.0) == "PI.recip()" assert rust_code(pi**-0.5) == "PI.sqrt().recip()" assert rust_code(x**Rational(1, 3)) == "x.cbrt()" assert rust_code(2**x) == "x.exp2()" assert rust_code(exp(x)) == "x.exp()" assert rust_code(x**3) == "x.powi(3)" assert rust_code(x**(y**3)) == "x.powf(y.powi(3))" assert rust_code(x**Rational(2, 3)) == "x.powf(2_f64/3.0)" g = implemented_function('g', Lambda(x, 2*x)) assert rust_code(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \ "(3.5*2*x).powf(-x + y.powf(x))/(x.powi(2) + y)" _cond_cfunc = [(lambda base, exp: exp.is_integer, "dpowi", 1), (lambda base, exp: not exp.is_integer, "pow", 1)] assert rust_code(x**3, user_functions={'Pow': _cond_cfunc}) == 'x.dpowi(3)' assert rust_code(x**3.2, user_functions={'Pow': _cond_cfunc}) == 'x.pow(3.2)' def test_constants(): assert rust_code(pi) == "PI" assert rust_code(oo) == "INFINITY" assert rust_code(S.Infinity) == "INFINITY" assert rust_code(-oo) == "NEG_INFINITY" assert rust_code(S.NegativeInfinity) == "NEG_INFINITY" assert rust_code(S.NaN) == "NAN" assert rust_code(exp(1)) == "E" assert rust_code(S.Exp1) == "E" def test_constants_other(): assert rust_code(2*GoldenRatio) == "const GoldenRatio: f64 = %s;\n2*GoldenRatio" % GoldenRatio.evalf(17) assert rust_code( 2*Catalan) == "const Catalan: f64 = %s;\n2*Catalan" % Catalan.evalf(17) assert rust_code(2*EulerGamma) == "const EulerGamma: f64 = %s;\n2*EulerGamma" % EulerGamma.evalf(17) def test_boolean(): assert rust_code(True) == "true" assert rust_code(S.true) == "true" assert rust_code(False) == "false" assert rust_code(S.false) == "false" assert rust_code(x & y) == "x && y" assert rust_code(x | y) == "x || y" assert rust_code(~x) == "!x" assert rust_code(x & y & z) == "x && y && z" assert rust_code(x | y | z) == "x || y || z" assert rust_code((x & y) | z) == "z || x && y" assert rust_code((x | y) & z) == "z && (x || y)" def test_Piecewise(): expr = Piecewise((x, x < 1), (x + 2, True)) assert rust_code(expr) == ( "if (x < 1) {\n" " x\n" "} else {\n" " x + 2\n" "}") assert rust_code(expr, assign_to="r") == ( "r = if (x < 1) {\n" " x\n" "} else {\n" " x + 2\n" "};") assert rust_code(expr, assign_to="r", inline=True) == ( "r = if (x < 1) { x } else { x + 2 };") expr = Piecewise((x, x < 1), (x + 1, x < 5), (x + 2, True)) assert rust_code(expr, inline=True) == ( "if (x < 1) { x } else if (x < 5) { x + 1 } else { x + 2 }") assert rust_code(expr, assign_to="r", inline=True) == ( "r = if (x < 1) { x } else if (x < 5) { x + 1 } else { x + 2 };") assert rust_code(expr, assign_to="r") == ( "r = if (x < 1) {\n" " x\n" "} else if (x < 5) {\n" " x + 1\n" "} else {\n" " x + 2\n" "};") expr = 2*Piecewise((x, x < 1), (x + 1, x < 5), (x + 2, True)) assert rust_code(expr, inline=True) == ( "2*if (x < 1) { x } else if (x < 5) { x + 1 } else { x + 2 }") expr = 2*Piecewise((x, x < 1), (x + 1, x < 5), (x + 2, True)) - 42 assert rust_code(expr, inline=True) == ( "2*if (x < 1) { x } else if (x < 5) { x + 1 } else { x + 2 } - 42") # Check that Piecewise without a True (default) condition error expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0)) raises(ValueError, lambda: rust_code(expr)) def test_dereference_printing(): expr = x + y + sin(z) + z assert rust_code(expr, dereference=[z]) == "x + y + (*z) + (*z).sin()" def test_sign(): expr = sign(x) * y assert rust_code(expr) == "y*x.signum()" assert rust_code(expr, assign_to='r') == "r = y*x.signum();" expr = sign(x + y) + 42 assert rust_code(expr) == "(x + y).signum() + 42" assert rust_code(expr, assign_to='r') == "r = (x + y).signum() + 42;" expr = sign(cos(x)) assert rust_code(expr) == "x.cos().signum()" def test_reserved_words(): x, y = symbols("x if") expr = sin(y) assert rust_code(expr) == "if_.sin()" assert rust_code(expr, dereference=[y]) == "(*if_).sin()" assert rust_code(expr, reserved_word_suffix='_unreserved') == "if_unreserved.sin()" with raises(ValueError): rust_code(expr, error_on_reserved=True) def test_ITE(): expr = ITE(x < 1, y, z) assert rust_code(expr) == ( "if (x < 1) {\n" " y\n" "} else {\n" " z\n" "}") def test_Indexed(): n, m, o = symbols('n m o', integer=True) i, j, k = Idx('i', n), Idx('j', m), Idx('k', o) x = IndexedBase('x')[j] assert rust_code(x) == "x[j]" A = IndexedBase('A')[i, j] assert rust_code(A) == "A[m*i + j]" B = IndexedBase('B')[i, j, k] assert rust_code(B) == "B[m*o*i + o*j + k]" def test_dummy_loops(): i, m = symbols('i m', integer=True, cls=Dummy) x = IndexedBase('x') y = IndexedBase('y') i = Idx(i, m) assert rust_code(x[i], assign_to=y[i]) == ( "for i in 0..m {\n" " y[i] = x[i];\n" "}") def test_loops(): m, n = symbols('m n', integer=True) A = IndexedBase('A') x = IndexedBase('x') y = IndexedBase('y') z = IndexedBase('z') i = Idx('i', m) j = Idx('j', n) assert rust_code(A[i, j]*x[j], assign_to=y[i]) == ( "for i in 0..m {\n" " y[i] = 0;\n" "}\n" "for i in 0..m {\n" " for j in 0..n {\n" " y[i] = A[n*i + j]*x[j] + y[i];\n" " }\n" "}") assert rust_code(A[i, j]*x[j] + x[i] + z[i], assign_to=y[i]) == ( "for i in 0..m {\n" " y[i] = x[i] + z[i];\n" "}\n" "for i in 0..m {\n" " for j in 0..n {\n" " y[i] = A[n*i + j]*x[j] + y[i];\n" " }\n" "}") def test_loops_multiple_contractions(): n, m, o, p = symbols('n m o p', integer=True) a = IndexedBase('a') b = IndexedBase('b') y = IndexedBase('y') i = Idx('i', m) j = Idx('j', n) k = Idx('k', o) l = Idx('l', p) assert rust_code(b[j, k, l]*a[i, j, k, l], assign_to=y[i]) == ( "for i in 0..m {\n" " y[i] = 0;\n" "}\n" "for i in 0..m {\n" " for j in 0..n {\n" " for k in 0..o {\n" " for l in 0..p {\n" " y[i] = a[%s]*b[%s] + y[i];\n" % (i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\ " }\n" " }\n" " }\n" "}") def test_loops_addfactor(): m, n, o, p = symbols('m n o p', integer=True) a = IndexedBase('a') b = IndexedBase('b') c = IndexedBase('c') y = IndexedBase('y') i = Idx('i', m) j = Idx('j', n) k = Idx('k', o) l = Idx('l', p) code = rust_code((a[i, j, k, l] + b[i, j, k, l])*c[j, k, l], assign_to=y[i]) assert code == ( "for i in 0..m {\n" " y[i] = 0;\n" "}\n" "for i in 0..m {\n" " for j in 0..n {\n" " for k in 0..o {\n" " for l in 0..p {\n" " y[i] = (a[%s] + b[%s])*c[%s] + y[i];\n" % (i*n*o*p + j*o*p + k*p + l, i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\ " }\n" " }\n" " }\n" "}") def test_settings(): raises(TypeError, lambda: rust_code(sin(x), method="garbage")) def test_inline_function(): x = symbols('x') g = implemented_function('g', Lambda(x, 2*x)) assert rust_code(g(x)) == "2*x" g = implemented_function('g', Lambda(x, 2*x/Catalan)) assert rust_code(g(x)) == ( "const Catalan: f64 = %s;\n2*x/Catalan" % Catalan.evalf(17)) A = IndexedBase('A') i = Idx('i', symbols('n', integer=True)) g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x))) assert rust_code(g(A[i]), assign_to=A[i]) == ( "for i in 0..n {\n" " A[i] = (A[i] + 1)*(A[i] + 2)*A[i];\n" "}") def test_user_functions(): x = symbols('x', integer=False) n = symbols('n', integer=True) custom_functions = { "ceiling": "ceil", "Abs": [(lambda x: not x.is_integer, "fabs", 4), (lambda x: x.is_integer, "abs", 4)], } assert rust_code(ceiling(x), user_functions=custom_functions) == "x.ceil()" assert rust_code(Abs(x), user_functions=custom_functions) == "fabs(x)" assert rust_code(Abs(n), user_functions=custom_functions) == "abs(n)" def test_matrix(): assert rust_code(Matrix([1, 2, 3])) == '[1, 2, 3]' with raises(ValueError): rust_code(Matrix([[1, 2, 3]])) def test_sparse_matrix(): # gh-15791 assert 'Not supported in Rust' in rust_code(SparseMatrix([[1, 2, 3]]))