360 lines
11 KiB
Python
360 lines
11 KiB
Python
|
from sympy.core import (S, pi, oo, symbols, Rational, Integer,
|
||
|
GoldenRatio, EulerGamma, Catalan, Lambda, Dummy,
|
||
|
Eq, Ne, Le, Lt, Gt, Ge, Mod)
|
||
|
from sympy.functions import (Piecewise, sin, cos, Abs, exp, ceiling, sqrt,
|
||
|
sign, floor)
|
||
|
from sympy.logic import ITE
|
||
|
from sympy.testing.pytest import raises
|
||
|
from sympy.utilities.lambdify import implemented_function
|
||
|
from sympy.tensor import IndexedBase, Idx
|
||
|
from sympy.matrices import MatrixSymbol, SparseMatrix, Matrix
|
||
|
|
||
|
from sympy.printing.rust import rust_code
|
||
|
|
||
|
# Generic (assumption-free) symbols shared by most of the tests below.
x, y, z = symbols('x y z')
|
||
|
|
||
|
|
||
|
def test_Integer():
    """Integers print as plain decimal literals, sign included."""
    for value, text in ((42, "42"), (-56, "-56")):
        assert rust_code(Integer(value)) == text
|
||
|
|
||
|
|
||
|
def test_Relational():
    """Each relational class maps onto the matching Rust comparison operator."""
    cases = [
        (Eq, "x == y"),
        (Ne, "x != y"),
        (Le, "x <= y"),
        (Lt, "x < y"),
        (Gt, "x > y"),
        (Ge, "x >= y"),
    ]
    for relation, want in cases:
        assert rust_code(relation(x, y)) == want
|
||
|
|
||
|
|
||
|
def test_Rational():
    """Non-integral rationals print as f64 divisions; integral ones collapse."""
    three_sevenths = "3_f64/7.0"
    cases = [
        (Rational(3, 7), three_sevenths),
        (Rational(18, 9), "2"),
        # The sign is normalised onto the numerator.
        (Rational(3, -7), "-" + three_sevenths),
        (Rational(-3, -7), three_sevenths),
        (x + Rational(3, 7), "x + " + three_sevenths),
        # A rational coefficient is parenthesised before multiplication.
        (Rational(3, 7)*x, "(" + three_sevenths + ")*x"),
    ]
    for expr, want in cases:
        assert rust_code(expr) == want
|
||
|
|
||
|
|
||
|
def test_basic_ops():
    """Arithmetic operators print with Rust infix syntax."""
    cases = (
        (x + y, "x + y"),
        (x - y, "x - y"),
        (x * y, "x*y"),
        (x / y, "x/y"),
        (-x, "-x"),
    )
    for expr, want in cases:
        assert rust_code(expr) == want
|
||
|
|
||
|
|
||
|
def test_printmethod():
    """A ``_rust_code`` method on an expression class overrides the printer."""
    class fabs(Abs):
        def _rust_code(self, printer):
            (arg,) = self.args
            return "{}.fabs()".format(printer._print(arg))

    assert rust_code(fabs(x)) == "x.fabs()"

    # Element access on a 1x3 MatrixSymbol prints with a single subscript.
    row = MatrixSymbol("a", 1, 3)
    assert rust_code(row[0, 0]) == 'a[0]'
|
||
|
|
||
|
|
||
|
def test_Functions():
    """Elementary functions print as method calls on their argument."""
    cases = [
        (sin(x) ** cos(x), "x.sin().powf(x.cos())"),
        (abs(x), "x.abs()"),
        (ceiling(x), "x.ceil()"),
        (floor(x), "x.floor()"),
        # Mod is printed through an automatic rewrite in terms of floor.
        (Mod(x, 3), 'x - 3*((1_f64/3.0)*x).floor()'),
    ]
    for expr, want in cases:
        assert rust_code(expr) == want
|
||
|
|
||
|
|
||
|
def test_Pow():
    """Powers use specialised Rust methods (recip, sqrt, cbrt, exp2, exp,
    powi) where one applies, and ``powf`` otherwise.

    Each group lists equivalent spellings of one power, symbolic and
    floating-point exponents alike, that must all print identically.
    """
    groups = [
        ((1/x, x**-1, x**-1.0), "x.recip()"),
        ((sqrt(x), x**S.Half, x**0.5), "x.sqrt()"),
        ((1/sqrt(x), x**-S.Half, x**-0.5), "x.sqrt().recip()"),
        ((1/pi, pi**-1, pi**-1.0), "PI.recip()"),
        ((pi**-0.5,), "PI.sqrt().recip()"),
        ((x**Rational(1, 3),), "x.cbrt()"),
        ((2**x,), "x.exp2()"),
        ((exp(x),), "x.exp()"),
        ((x**3,), "x.powi(3)"),
        ((x**(y**3),), "x.powf(y.powi(3))"),
        ((x**Rational(2, 3),), "x.powf(2_f64/3.0)"),
    ]
    for spellings, want in groups:
        assert [rust_code(e) for e in spellings] == [want] * len(spellings)

    # implemented_function bodies are inlined inside a larger power.
    g = implemented_function('g', Lambda(x, 2*x))
    nested = 1/(g(x)*3.5)**(x - y**x)/(x**2 + y)
    assert rust_code(nested) == "(3.5*2*x).powf(-x + y.powf(x))/(x.powi(2) + y)"

    # Conditional user override of Pow, selected on the exponent's integrality.
    conditional_pow = [(lambda b, e: e.is_integer, "dpowi", 1),
                       (lambda b, e: not e.is_integer, "pow", 1)]
    overrides = {'Pow': conditional_pow}
    assert rust_code(x**3, user_functions=overrides) == 'x.dpowi(3)'
    assert rust_code(x**3.2, user_functions=overrides) == 'x.pow(3.2)'
|
||
|
|
||
|
|
||
|
def test_constants():
    """Special numbers map onto the named Rust constants listed below."""
    table = (
        (pi, "PI"),
        (oo, "INFINITY"),
        (S.Infinity, "INFINITY"),
        (-oo, "NEG_INFINITY"),
        (S.NegativeInfinity, "NEG_INFINITY"),
        (S.NaN, "NAN"),
        (exp(1), "E"),
        (S.Exp1, "E"),
    )
    for number, name in table:
        assert rust_code(number) == name
|
||
|
|
||
|
|
||
|
def test_constants_other():
    """Constants without a Rust counterpart are emitted as a preceding
    ``const`` declaration holding their 17-digit value."""
    named = ((GoldenRatio, "GoldenRatio"),
             (Catalan, "Catalan"),
             (EulerGamma, "EulerGamma"))
    for constant, name in named:
        expected = "const %s: f64 = %s;\n2*%s" % (name, constant.evalf(17), name)
        assert rust_code(2*constant) == expected
|
||
|
|
||
|
|
||
|
def test_boolean():
    """Boolean literals and logical connectives use Rust syntax."""
    for truthy in (True, S.true):
        assert rust_code(truthy) == "true"
    for falsy in (False, S.false):
        assert rust_code(falsy) == "false"

    connectives = [
        (x & y, "x && y"),
        (x | y, "x || y"),
        (~x, "!x"),
        (x & y & z, "x && y && z"),
        (x | y | z, "x || y || z"),
        # Mixed connectives: parentheses appear only where precedence needs them.
        ((x & y) | z, "z || x && y"),
        ((x | y) & z, "z && (x || y)"),
    ]
    for expr, want in connectives:
        assert rust_code(expr) == want
|
||
|
|
||
|
|
||
|
def test_Piecewise():
    """Piecewise prints as a Rust ``if`` / ``else if`` / ``else`` expression.

    Multi-line expected output uses four-space indentation per nesting
    level, as produced by the Rust printer. ``inline=True`` keeps the
    whole expression on one line. ``assign_to`` prefixes ``r = `` and
    terminates the statement with ``;``.
    """
    expr = Piecewise((x, x < 1), (x + 2, True))
    assert rust_code(expr) == (
        "if (x < 1) {\n"
        "    x\n"
        "} else {\n"
        "    x + 2\n"
        "}")
    assert rust_code(expr, assign_to="r") == (
        "r = if (x < 1) {\n"
        "    x\n"
        "} else {\n"
        "    x + 2\n"
        "};")
    assert rust_code(expr, assign_to="r", inline=True) == (
        "r = if (x < 1) { x } else { x + 2 };")

    expr = Piecewise((x, x < 1), (x + 1, x < 5), (x + 2, True))
    assert rust_code(expr, inline=True) == (
        "if (x < 1) { x } else if (x < 5) { x + 1 } else { x + 2 }")
    assert rust_code(expr, assign_to="r", inline=True) == (
        "r = if (x < 1) { x } else if (x < 5) { x + 1 } else { x + 2 };")
    assert rust_code(expr, assign_to="r") == (
        "r = if (x < 1) {\n"
        "    x\n"
        "} else if (x < 5) {\n"
        "    x + 1\n"
        "} else {\n"
        "    x + 2\n"
        "};")

    # A Piecewise used as a sub-expression is embedded verbatim.
    expr = 2*Piecewise((x, x < 1), (x + 1, x < 5), (x + 2, True))
    assert rust_code(expr, inline=True) == (
        "2*if (x < 1) { x } else if (x < 5) { x + 1 } else { x + 2 }")
    expr = 2*Piecewise((x, x < 1), (x + 1, x < 5), (x + 2, True)) - 42
    assert rust_code(expr, inline=True) == (
        "2*if (x < 1) { x } else if (x < 5) { x + 1 } else { x + 2 } - 42")

    # A Piecewise without a True (default) condition must raise ValueError.
    expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
    raises(ValueError, lambda: rust_code(expr))
|
||
|
|
||
|
|
||
|
def test_dereference_printing():
    """Symbols listed in ``dereference`` are printed as ``(*name)``."""
    code = rust_code(x + y + sin(z) + z, dereference=[z])
    assert code == "x + y + (*z) + (*z).sin()"
|
||
|
|
||
|
|
||
|
def test_sign():
    """sign() prints as the ``signum`` method; compound arguments get parentheses."""
    scaled = sign(x) * y
    shifted = sign(x + y) + 42
    cases = [
        (scaled, {}, "y*x.signum()"),
        (scaled, {'assign_to': 'r'}, "r = y*x.signum();"),
        (shifted, {}, "(x + y).signum() + 42"),
        (shifted, {'assign_to': 'r'}, "r = (x + y).signum() + 42;"),
        (sign(cos(x)), {}, "x.cos().signum()"),
    ]
    for expr, settings, want in cases:
        assert rust_code(expr, **settings) == want
|
||
|
|
||
|
|
||
|
def test_reserved_words():
    """A Rust keyword used as a symbol name is suffixed, or rejected on request."""
    _, keyword_symbol = symbols("x if")
    expr = sin(keyword_symbol)

    assert rust_code(expr) == "if_.sin()"
    assert rust_code(expr, dereference=[keyword_symbol]) == "(*if_).sin()"
    assert rust_code(expr, reserved_word_suffix='_unreserved') == "if_unreserved.sin()"

    raises(ValueError, lambda: rust_code(expr, error_on_reserved=True))
|
||
|
|
||
|
|
||
|
def test_ITE():
    """ITE prints like a two-branch Piecewise: a Rust ``if``/``else`` expression.

    The branch bodies in the expected output are indented four spaces,
    matching the Rust printer's block indentation.
    """
    expr = ITE(x < 1, y, z)
    assert rust_code(expr) == (
        "if (x < 1) {\n"
        "    y\n"
        "} else {\n"
        "    z\n"
        "}")
|
||
|
|
||
|
|
||
|
def test_Indexed():
    """Multi-index Indexed objects print as a single row-major flat index."""
    n, m, o = symbols('n m o', integer=True)
    i, j, k = Idx('i', n), Idx('j', m), Idx('k', o)

    cases = (
        (IndexedBase('x')[j], "x[j]"),
        (IndexedBase('A')[i, j], "A[m*i + j]"),
        (IndexedBase('B')[i, j, k], "B[m*o*i + o*j + k]"),
    )
    for indexed, want in cases:
        assert rust_code(indexed) == want
|
||
|
|
||
|
|
||
|
def test_dummy_loops():
    """Dummy-based loop indices still print under their plain names.

    The loop body in the expected output is indented four spaces,
    matching the Rust printer's block indentation.
    """
    i, m = symbols('i m', integer=True, cls=Dummy)
    x = IndexedBase('x')
    y = IndexedBase('y')
    i = Idx(i, m)

    assert rust_code(x[i], assign_to=y[i]) == (
        "for i in 0..m {\n"
        "    y[i] = x[i];\n"
        "}")
|
||
|
|
||
|
|
||
|
def test_loops():
    """A repeated index (``j``) is summed.

    The expected output is an initialising loop followed by nested
    accumulation loops. Loop bodies are indented four spaces per nesting
    level, matching the Rust printer's block indentation.
    """
    m, n = symbols('m n', integer=True)
    A = IndexedBase('A')
    x = IndexedBase('x')
    y = IndexedBase('y')
    z = IndexedBase('z')
    i = Idx('i', m)
    j = Idx('j', n)

    # Pure contraction: y is zero-initialised before accumulation.
    assert rust_code(A[i, j]*x[j], assign_to=y[i]) == (
        "for i in 0..m {\n"
        "    y[i] = 0;\n"
        "}\n"
        "for i in 0..m {\n"
        "    for j in 0..n {\n"
        "        y[i] = A[n*i + j]*x[j] + y[i];\n"
        "    }\n"
        "}")

    # Non-contracted terms initialise y instead of 0.
    assert rust_code(A[i, j]*x[j] + x[i] + z[i], assign_to=y[i]) == (
        "for i in 0..m {\n"
        "    y[i] = x[i] + z[i];\n"
        "}\n"
        "for i in 0..m {\n"
        "    for j in 0..n {\n"
        "        y[i] = A[n*i + j]*x[j] + y[i];\n"
        "    }\n"
        "}")
|
||
|
|
||
|
|
||
|
def test_loops_multiple_contractions():
    """Contracting over three repeated indices yields three nested sum loops.

    The flat array offsets are built from the same symbols, so their
    printed form matches what the printer emits. Loop bodies are
    indented four spaces per nesting level.
    """
    n, m, o, p = symbols('n m o p', integer=True)
    a = IndexedBase('a')
    b = IndexedBase('b')
    y = IndexedBase('y')
    i = Idx('i', m)
    j = Idx('j', n)
    k = Idx('k', o)
    l = Idx('l', p)

    # Row-major flat offsets of a[i, j, k, l] and b[j, k, l].
    a_offset = i*n*o*p + j*o*p + k*p + l
    b_offset = j*o*p + k*p + l

    # Format the whole template at once rather than splicing "% (...) +"
    # into the middle of an implicit string concatenation.
    expected = (
        "for i in 0..m {\n"
        "    y[i] = 0;\n"
        "}\n"
        "for i in 0..m {\n"
        "    for j in 0..n {\n"
        "        for k in 0..o {\n"
        "            for l in 0..p {\n"
        "                y[i] = a[%s]*b[%s] + y[i];\n"
        "            }\n"
        "        }\n"
        "    }\n"
        "}") % (a_offset, b_offset)

    assert rust_code(b[j, k, l]*a[i, j, k, l], assign_to=y[i]) == expected
|
||
|
|
||
|
|
||
|
def test_loops_addfactor():
    """A sum factor inside a contraction is printed parenthesised in the loop body.

    Loop bodies in the expected output are indented four spaces per
    nesting level, matching the Rust printer's block indentation.
    """
    m, n, o, p = symbols('m n o p', integer=True)
    a = IndexedBase('a')
    b = IndexedBase('b')
    c = IndexedBase('c')
    y = IndexedBase('y')
    i = Idx('i', m)
    j = Idx('j', n)
    k = Idx('k', o)
    l = Idx('l', p)

    # Row-major flat offsets of the rank-4 (a, b) and rank-3 (c) arrays.
    rank4_offset = i*n*o*p + j*o*p + k*p + l
    rank3_offset = j*o*p + k*p + l

    code = rust_code((a[i, j, k, l] + b[i, j, k, l])*c[j, k, l], assign_to=y[i])
    expected = (
        "for i in 0..m {\n"
        "    y[i] = 0;\n"
        "}\n"
        "for i in 0..m {\n"
        "    for j in 0..n {\n"
        "        for k in 0..o {\n"
        "            for l in 0..p {\n"
        "                y[i] = (a[%s] + b[%s])*c[%s] + y[i];\n"
        "            }\n"
        "        }\n"
        "    }\n"
        "}") % (rank4_offset, rank4_offset, rank3_offset)
    assert code == expected
|
||
|
|
||
|
|
||
|
def test_settings():
    """An unrecognised printer setting is rejected with TypeError."""
    def print_with_bad_setting():
        return rust_code(sin(x), method="garbage")

    raises(TypeError, print_with_bad_setting)
|
||
|
|
||
|
|
||
|
def test_inline_function():
    """Calls to implemented_function are inlined from their Lambda body.

    The loop body in the expected output is indented four spaces,
    matching the Rust printer's block indentation.
    """
    x = symbols('x')
    g = implemented_function('g', Lambda(x, 2*x))
    assert rust_code(g(x)) == "2*x"

    # Constants in the inlined body still get their ``const`` declaration.
    g = implemented_function('g', Lambda(x, 2*x/Catalan))
    assert rust_code(g(x)) == (
        "const Catalan: f64 = %s;\n2*x/Catalan" % Catalan.evalf(17))

    # Inlining into an indexed assignment produces a loop over the index.
    A = IndexedBase('A')
    i = Idx('i', symbols('n', integer=True))
    g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x)))
    assert rust_code(g(A[i]), assign_to=A[i]) == (
        "for i in 0..n {\n"
        "    A[i] = (A[i] + 1)*(A[i] + 2)*A[i];\n"
        "}")
|
||
|
|
||
|
|
||
|
def test_user_functions():
    """user_functions can rename a function or choose one by a predicate.

    A plain string renames the method, e.g. ``ceiling`` becomes ``ceil``.
    A list of (predicate, name, nargs) tuples selects a free-function
    spelling, here chosen on whether the argument is an integer.
    """
    non_integer = symbols('x', integer=False)
    integer = symbols('n', integer=True)
    overrides = {
        "ceiling": "ceil",
        "Abs": [(lambda arg: not arg.is_integer, "fabs", 4),
                (lambda arg: arg.is_integer, "abs", 4)],
    }
    assert rust_code(ceiling(non_integer), user_functions=overrides) == "x.ceil()"
    assert rust_code(Abs(non_integer), user_functions=overrides) == "fabs(x)"
    assert rust_code(Abs(integer), user_functions=overrides) == "abs(n)"
|
||
|
|
||
|
|
||
|
def test_matrix():
    """A column vector prints as a Rust array; a 1x3 row matrix is rejected."""
    column = Matrix([1, 2, 3])
    assert rust_code(column) == '[1, 2, 3]'

    raises(ValueError, lambda: rust_code(Matrix([[1, 2, 3]])))
|
||
|
|
||
|
|
||
|
def test_sparse_matrix():
    """Regression test for gh-15791: a sparse matrix yields a 'not supported' note."""
    output = rust_code(SparseMatrix([[1, 2, 3]]))
    assert 'Not supported in Rust' in output
|