662 lines
21 KiB
Python
662 lines
21 KiB
Python
import math
|
|
from sympy.core.containers import Tuple
|
|
from sympy.core.numbers import nan, oo, Float, Integer
|
|
from sympy.core.relational import Lt
|
|
from sympy.core.symbol import symbols, Symbol
|
|
from sympy.functions.elementary.trigonometric import sin
|
|
from sympy.matrices.dense import Matrix
|
|
from sympy.matrices.expressions.matexpr import MatrixSymbol
|
|
from sympy.sets.fancysets import Range
|
|
from sympy.tensor.indexed import Idx, IndexedBase
|
|
from sympy.testing.pytest import raises
|
|
|
|
|
|
from sympy.codegen.ast import (
|
|
Assignment, Attribute, aug_assign, CodeBlock, For, Type, Variable, Pointer, Declaration,
|
|
AddAugmentedAssignment, SubAugmentedAssignment, MulAugmentedAssignment,
|
|
DivAugmentedAssignment, ModAugmentedAssignment, value_const, pointer_const,
|
|
integer, real, complex_, int8, uint8, float16 as f16, float32 as f32,
|
|
float64 as f64, float80 as f80, float128 as f128, complex64 as c64, complex128 as c128,
|
|
While, Scope, String, Print, QuotedString, FunctionPrototype, FunctionDefinition, Return,
|
|
FunctionCall, untyped, IntBaseType, intc, Node, none, NoneToken, Token, Comment
|
|
)
|
|
|
|
# Shared fixtures used by the tests below.
# Plain scalar symbols (no assumptions).
x, y, z, t, x0, x1, x2, a, b = symbols("x, y, z, t, x0, x1, x2, a, b")
# An integer-assumption symbol, used as loop target, Idx bound and for
# type deduction tests.
n = symbols("n", integer=True)

# A 3x1 matrix symbol and an explicit 3x1 matrix of matching shape.
A = MatrixSymbol('A', 3, 1)
mat = Matrix([1, 2, 3])

# Indexed fixtures: ``B[i]`` is a scalar element indexed by ``i``.
B = IndexedBase('B')
i = Idx("i", n)

# 2x2 matrix symbols used by the printing and cse tests.
A22, B22 = (MatrixSymbol(name, 2, 2) for name in ('A22', 'B22'))
|
|
|
|
|
|
def test_Assignment():
    # Well-formed assignments: constructing these must not raise.
    accepted = [
        (x, y),
        (x, 0),
        (A, mat),
        (A[1, 0], 0),
        (A[1, 0], x),
        (B[i], x),
        (B[i], 0),
    ]
    for lhs, rhs in accepted:
        Assignment(lhs, rhs)

    asgn = Assignment(x, y)
    assert asgn.func(*asgn.args) == asgn
    assert asgn.op == ':='

    # Shape mismatches raise ValueError: matrix -> scalar, then scalar -> matrix.
    shape_mismatches = [
        (B[i], A), (B[i], mat), (x, mat), (x, A), (A[1, 0], mat),
        (A, x), (A, 0),
    ]
    for lhs, rhs in shape_mismatches:
        raises(ValueError, lambda: Assignment(lhs, rhs))

    # Non-atomic / non-assignable left-hand sides raise TypeError.
    bad_targets = [(mat, A), (0, x), (x*x, 1), (A + A, mat), (B, 0)]
    for lhs, rhs in bad_targets:
        raises(TypeError, lambda: Assignment(lhs, rhs))
|
|
|
|
|
|
def test_AugAssign():
    # Well-formed augmented assignments: constructing these must not raise.
    accepted = [
        (x, y), (x, 0), (A, mat), (A[1, 0], 0), (A[1, 0], x),
        (B[i], x), (B[i], 0),
    ]
    for lhs, rhs in accepted:
        aug_assign(lhs, '+', rhs)

    # aug_assign(lhs, op, rhs) must build the same node as the matching class.
    op_to_class = {
        '+': AddAugmentedAssignment,
        '-': SubAugmentedAssignment,
        '*': MulAugmentedAssignment,
        '/': DivAugmentedAssignment,
        '%': ModAugmentedAssignment,
    }
    for binop, cls in op_to_class.items():
        via_helper = aug_assign(x, binop, y)
        via_class = cls(x, y)
        assert via_helper.func(*via_helper.args) == via_helper == via_class
        assert via_helper.binop == binop
        assert via_helper.op == binop + '='

    # Shape mismatches raise ValueError: matrix -> scalar, then scalar -> matrix.
    shape_mismatches = [
        (B[i], A), (B[i], mat), (x, mat), (x, A), (A[1, 0], mat),
        (A, x), (A, 0),
    ]
    for lhs, rhs in shape_mismatches:
        raises(ValueError, lambda: aug_assign(lhs, '+', rhs))

    # Non-atomic / non-assignable left-hand sides raise TypeError.
    bad_targets = [(mat, A), (0, x), (x * x, 1), (A + A, mat), (B, 0)]
    for lhs, rhs in bad_targets:
        raises(TypeError, lambda: aug_assign(lhs, '+', rhs))
|
|
|
|
|
|
def test_Assignment_printing():
    # repr() of every assignment class must read ``ClassName(repr(lhs), repr(rhs))``.
    classes = (
        Assignment,
        AddAugmentedAssignment,
        SubAugmentedAssignment,
        MulAugmentedAssignment,
        DivAugmentedAssignment,
        ModAugmentedAssignment,
    )
    operand_pairs = (
        (x, 2 * y + 2),
        (B[i], x),
        (A22, B22),
        (A[0, 0], x),
    )
    for cls in classes:
        for lhs, rhs in operand_pairs:
            expected = '%s(%s, %s)' % (cls.__name__, repr(lhs), repr(rhs))
            assert repr(cls(lhs, rhs)) == expected
|
|
|
|
|
|
def test_CodeBlock():
    block = CodeBlock(Assignment(x, 1), Assignment(y, x + 1))
    # Rebuilding from args yields an equal block.
    assert block.func(*block.args) == block

    # Targets and values of the contained assignments, in statement order.
    assert block.left_hand_sides == Tuple(x, y)
    assert block.right_hand_sides == Tuple(1, x + 1)
|
|
|
|
def test_CodeBlock_topological_sort():
    def assert_sorts_to(unordered, expected):
        # topological_sort must produce a CodeBlock in exactly this order.
        assert CodeBlock.topological_sort(unordered) == CodeBlock(*expected)

    # Dependencies are resolved; the unrelated z=1 and y=2 keep their
    # original relative order.
    assert_sorts_to(
        [Assignment(x, y + z), Assignment(z, 1),
         Assignment(t, x), Assignment(y, 2)],
        [Assignment(z, 1), Assignment(y, 2),
         Assignment(x, y + z), Assignment(t, x)],
    )

    # x depends on y and y depends on x: a cycle, which cannot be sorted.
    cyclic = [
        Assignment(x, y + z),
        Assignment(z, 1),
        Assignment(y, x),
        Assignment(y, 2),
    ]
    raises(ValueError, lambda: CodeBlock.topological_sort(cyclic))

    # a and b are never assigned; they are treated as free inputs.
    assert_sorts_to(
        [Assignment(x, y + z), Assignment(z, a * b),
         Assignment(t, x), Assignment(y, b + 3)],
        [Assignment(z, a * b), Assignment(y, b + 3),
         Assignment(x, y + z), Assignment(t, x)],
    )
|
|
|
|
def test_CodeBlock_free_symbols():
    # Every symbol that is read is also assigned in the block -> nothing free.
    closed_block = CodeBlock(
        Assignment(x, y + z),
        Assignment(z, 1),
        Assignment(t, x),
        Assignment(y, 2),
    )
    assert closed_block.free_symbols == set()

    # a and b are read but never assigned, so they remain free.
    open_block = CodeBlock(
        Assignment(x, y + z),
        Assignment(z, a * b),
        Assignment(t, x),
        Assignment(y, b + 3),
    )
    assert open_block.free_symbols == {a, b}
|
|
|
|
def test_CodeBlock_cse():
    # The repeated sin(y) is hoisted into the generated symbol x0.
    repeated = CodeBlock(
        Assignment(y, 1),
        Assignment(x, sin(y)),
        Assignment(z, sin(y)),
        Assignment(t, x*z),
    )
    hoisted = CodeBlock(
        Assignment(y, 1),
        Assignment(x0, sin(y)),
        Assignment(x, x0),
        Assignment(z, x0),
        Assignment(t, x*z),
    )
    assert repeated.cse() == hoisted

    # Assigning to the same symbol more than once is not supported.
    raises(NotImplementedError, lambda: CodeBlock(
        Assignment(x, 1),
        Assignment(y, 1), Assignment(y, 2)
    ).cse())

    # x0 and x1 are already in use, so the generated symbol must be x2.
    clashing = CodeBlock(
        Assignment(x0, sin(y) + 1),
        Assignment(x1, 2 * sin(y)),
        Assignment(z, x * y),
    )
    deduplicated = CodeBlock(
        Assignment(x2, sin(y)),
        Assignment(x0, x2 + 1),
        Assignment(x1, 2 * x2),
        Assignment(z, x * y),
    )
    assert clashing.cse() == deduplicated
|
|
|
|
|
|
def test_CodeBlock_cse__issue_14118():
    # Regression test for https://github.com/sympy/sympy/issues/14118:
    # cse must find common subexpressions inside explicit Matrix values.
    sy = sin(y)
    block = CodeBlock(
        Assignment(A22, Matrix([[x, sy], [3, 4]])),
        Assignment(B22, Matrix([[sy, 2*sy], [sy**2, 7]])),
    )
    expected = CodeBlock(
        Assignment(x0, sy),
        Assignment(A22, Matrix([[x, x0], [3, 4]])),
        Assignment(B22, Matrix([[x0, 2*x0], [x0**2, 7]])),
    )
    assert block.cse() == expected
|
|
|
|
def test_For():
    # Loop over a Range with a two-statement body. Previously this object was
    # built and then immediately overwritten, so it was never checked.
    over_range = For(n, Range(0, 3),
                     (Assignment(A[n, 0], x + n), aug_assign(x, '+', y)))
    assert over_range.func(*over_range.args) == over_range
    assert over_range.iterable == Range(0, 3)

    # Loop over an explicit tuple of values.
    over_tuple = For(n, (1, 2, 3, 4, 5), (Assignment(A[n, 0], x + n),))
    assert over_tuple.func(*over_tuple.args) == over_tuple

    # A non-iterable "iterable" argument is rejected.
    raises(TypeError, lambda: For(n, x, (x + y,)))
|
|
|
|
|
|
def test_none():
    assert none.is_Atom
    assert none == none

    class Foo(Token):
        pass

    # An unrelated Token subclass does not compare equal to the none token.
    assert Foo() != none
    # The none token deliberately compares equal to Python's None, so ``==``
    # (not ``is``) is what is being tested here.
    assert none == None  # noqa: E711
    assert none == NoneToken()
    assert none.func(*none.args) == none
|
|
|
|
|
|
def test_String():
    foobar = String('foobar')
    assert foobar.is_Atom
    assert foobar == String('foobar')
    assert foobar.text == 'foobar'
    # Reconstruction works both from keyword args and from positional args.
    assert foobar.func(**foobar.kwargs()) == foobar
    assert foobar.func(*foobar.args) == foobar

    class Signifier(String):
        pass

    # Same text but a different class: the nodes are not equal.
    signifier = Signifier('foobar')
    assert signifier != foobar
    assert signifier.text == foobar.text

    foo = String('foo')
    assert str(foo) == 'foo'
    assert repr(foo) == "String('foo')"
|
|
|
|
def test_Comment():
    comment = Comment('foobar')
    # Both the text attribute and str() return the raw comment text.
    for rendered in (comment.text, str(comment)):
        assert rendered == 'foobar'
|
|
|
|
def test_Node():
    # Two default-constructed nodes are equal, and a node round-trips via args.
    node = Node()
    assert node == Node()
    assert node.func(*node.args) == node
|
|
|
|
|
|
def test_Type():
    my_type = Type('MyType')
    # The name is stored as a single String argument.
    assert len(my_type.args) == 1
    assert my_type.name == String('MyType')
    assert str(my_type) == 'MyType'
    assert repr(my_type) == "Type(String('MyType'))"
    # Passing an existing Type to Type yields an equal Type.
    assert Type(my_type) == my_type
    assert my_type.func(*my_type.args) == my_type

    # Equality follows the name.
    first, second, first_again = Type('t1'), Type('t2'), Type('t1')
    assert first != second
    assert first == first and second == second
    assert first == first_again
    assert second != first_again
|
|
|
|
|
|
def test_Type__from_expr():
    u = symbols('u', real=True)
    # (expression, type Type.from_expr must deduce for it)
    cases = [
        (i, integer),
        (u, real),
        (n, integer),
        (3, integer),
        (3.0, real),
        (3+1j, complex_),
    ]
    for expr, expected in cases:
        assert Type.from_expr(expr) == expected

    # An object with no type mapping (here a builtin function) is rejected.
    raises(ValueError, lambda: Type.from_expr(sum))
|
|
|
|
|
|
def test_Type__cast_check__integers():
    # Rounding: a non-integral float is rejected ...
    raises(ValueError, lambda: integer.cast_check(3.5))
    # ... while these all cast to 3. The last one is perhaps unintuitive:
    # the tiny fractional part is tolerated.
    for castable in ('3',
                     Float('3.0000000000000000000'),
                     Float('3.0000000000000000001')):
        assert integer.cast_check(castable) == 3

    # Range: values inside the type's bounds are accepted ...
    in_range = [(int8, 127.0, 127), (int8, -128, -128),
                (uint8, 0, 0), (uint8, 128, 128)]
    for int_type, value, expected in in_range:
        assert int_type.cast_check(value) == expected

    # ... and values just outside them are rejected.
    out_of_range = [(int8, 128), (int8, -129), (uint8, 256.0), (uint8, -1)]
    for int_type, value in out_of_range:
        raises(ValueError, lambda: int_type.cast_check(value))
|
|
|
|
def test_Attribute():
    assert Attribute('noexcept') == Attribute('noexcept')

    # Same attribute name, different parameters: not equal.
    align16, align32 = (Attribute('alignas', [width]) for width in (16, 32))
    assert align16 != align32
    assert align16.func(*align16.args) == align16
|
|
|
|
|
|
def test_Variable():
    def assert_roundtrip(node):
        # Rebuilding a node from its args must yield an equal node.
        assert node.func(*node.args) == node

    # Explicitly typed real variable without attributes.
    real_x = Variable(x, type=real)
    assert real_x == Variable(real_x)
    assert real_x == Variable('x', type=real)
    assert real_x.symbol == x
    assert real_x.type == real
    assert value_const not in real_x.attrs
    assert_roundtrip(real_x)
    assert str(real_x) == 'Variable(x, type=real)'

    # Positional type plus a const attribute.
    const_y = Variable(y, f32, attrs={value_const})
    assert const_y.symbol == y
    assert const_y.type == f32
    assert value_const in const_y.attrs
    assert_roundtrip(const_y)

    # Types deduced via Type.from_expr; different symbols give different variables.
    var_n = Variable(n, type=Type.from_expr(n))
    assert var_n.type == integer
    assert_roundtrip(var_n)
    var_i = Variable(i, type=Type.from_expr(n))
    assert var_i.type == integer
    assert var_i != var_n

    # Variable.deduced infers the type from the symbol's assumptions.
    deduced_i = Variable.deduced(i)
    assert deduced_i.type == integer
    assert Variable.deduced(Symbol('x', real=True)).type == real
    assert_roundtrip(deduced_i)

    # A non-integral value for an integer symbol is only allowed without cast_check.
    unchecked = Variable.deduced(n, value=3.5, cast_check=False)
    assert_roundtrip(unchecked)
    assert abs(unchecked.value - 3.5) < 1e-15
    raises(ValueError, lambda: Variable.deduced(n, value=3.5, cast_check=True))

    deduced_n = Variable.deduced(n)
    assert deduced_n.type == integer
    assert str(deduced_n) == 'Variable(n, type=integer)'

    # For an assumption-free symbol the type comes from the value.
    for value, expected in ((3, integer), (3.0, real), (3.0+1j, complex_)):
        assert Variable.deduced(z, value=value).type == expected
|
|
|
|
|
|
def test_Pointer():
    const_attrs = (value_const, pointer_const)

    # Bare pointer: untyped and without const attributes.
    bare = Pointer(x)
    assert bare.symbol == x
    assert bare.type == untyped
    for attr in const_attrs:
        assert attr not in bare.attrs
    assert bare.func(*bare.args) == bare

    # Typed pointer carrying both const attributes.
    u = symbols('u', real=True)
    const_ptr = Pointer(u, type=Type.from_expr(u), attrs=set(const_attrs))
    assert const_ptr.symbol is u
    assert const_ptr.type == real
    for attr in const_attrs:
        assert attr in const_ptr.attrs
    assert const_ptr.func(*const_ptr.args) == const_ptr

    # Subscripting a pointer records the index in ``.indices``.
    idx = symbols('i', integer=True)
    assert const_ptr[idx].indices == (idx,)
|
|
|
|
|
|
def test_Declaration():
    u = symbols('u', real=True)
    var_u = Variable(u, type=Type.from_expr(u))
    var_n = Variable(n, type=Type.from_expr(n))
    for var, expected_type in ((var_u, real), (var_n, integer)):
        assert Declaration(var).variable.type == expected_type

    # NOTE: since PR 19107, relational comparisons (e.g. StrictLessThan)
    # between expressions and arbitrary Basic objects such as Variables are
    # no longer allowed, so that check was dropped.

    # Declaration of a const variable with a float value.
    const_u = Variable(u, Type.from_expr(u), value=3.0, attrs={value_const})
    assert value_const in const_u.attrs
    assert pointer_const not in const_u.attrs
    decl_u = Declaration(const_u)
    assert decl_u.variable == const_u
    assert isinstance(decl_u.variable.value, Float)
    assert decl_u.variable.value == 3.0
    assert decl_u.func(*decl_u.args) == decl_u
    assert const_u.as_Declaration() == decl_u
    # Overriding value/attrs with None drops them.
    assert const_u.as_Declaration(value=None, attrs=None) == Declaration(var_u)

    # An int value becomes a sympy Integer.
    int_y = Variable(y, type=integer, value=3)
    decl_y = Declaration(int_y)
    assert decl_y.variable == int_y
    assert decl_y.variable.value == Integer(3)

    # Integer-typed variable with a float value of 3.0.
    int_i = Variable(i, type=Type.from_expr(i), value=3.0)
    decl_i = Declaration(int_i)
    assert decl_i.variable.type == integer
    assert decl_i.variable.value == 3.0

    # Declaration takes no extra positional argument.
    raises(ValueError, lambda: Declaration(int_i, 42))
|
|
|
|
|
|
def test_IntBaseType():
    # The predefined intc type carries only its name as argument.
    expected_name = String('intc')
    assert intc.name == expected_name
    assert intc.args == (intc.name,)

    custom = IntBaseType('a')
    assert str(custom.name) == 'a'
|
|
|
|
|
|
def test_FloatType():
    # Integer-valued properties per type:
    # (type, dig, decimal_dig, max_exponent, min_exponent)
    int_props = [
        (f16, 3, 5, 16, -13),
        (f32, 6, 9, 128, -125),
        (f64, 15, 17, 1024, -1021),
        (f80, 18, 21, 16384, -16381),
        (f128, 33, 36, 16384, -16381),
    ]
    for ftype, dig, decimal_dig, max_exp, min_exp in int_props:
        assert ftype.dig == dig
        assert ftype.decimal_dig == decimal_dig
        assert ftype.max_exponent == max_exp
        assert ftype.min_exponent == min_exp

    # Reference limits per type, to be matched to a relative tolerance of
    # 0.1*10**-dig (cf. numpy.finfo(...).eps/.max/.tiny for the IEEE types):
    # (type, Float precision, eps, max, tiny)
    limits = [
        (f16, 16, '0.00097656', '65504', '6.1035e-05'),
        (f32, 32, '1.1920929e-07', '3.40282347e+38', '1.17549435e-38'),
        (f64, 64, '2.2204460492503131e-16', '1.79769313486231571e+308',
         '2.22507385850720138e-308'),
        (f80, 80, '1.08420217248550443401e-19',
         '1.18973149535723176502e+4932', '3.36210314311209350626e-4932'),
        # The eps reference previously had a stray leading space in its string.
        (f128, 128, '1.92592994438723585305597794258492732e-34',
         '1.18973149535723176508575932662800702e+4932',
         '3.3621031431120935062626778173217526e-4932'),
    ]
    for ftype, prec, eps_ref, max_ref, tiny_ref in limits:
        tol = 0.1*10**-ftype.dig
        for actual, ref in ((ftype.eps, eps_ref),
                            (ftype.max, max_ref),
                            (ftype.tiny, tiny_ref)):
            assert abs(actual / Float(ref, precision=prec) - 1) < tol

    # cast_check on float64.
    assert f64.cast_check(0.5) == Float(0.5, 17)
    assert abs(f64.cast_check(3.7) - 3.7) < 3e-17
    assert isinstance(f64.cast_check(3), (Float, float))

    # cast_nocheck must handle sympy and Python infinities as well as nan.
    assert f64.cast_nocheck(oo) == float('inf')
    assert f64.cast_nocheck(-oo) == float('-inf')
    assert f64.cast_nocheck(float(oo)) == float('inf')
    assert f64.cast_nocheck(float(-oo)) == float('-inf')
    assert math.isnan(f64.cast_nocheck(nan))

    assert f32 != f64
    assert f64 == f64.func(*f64.args)
|
|
|
|
|
|
def test_Type__cast_check__floating_point():
    # float32 rejects these values ...
    for rejected in (123.45678949, 12.345678949, 1.2345678949, .12345678949):
        raises(ValueError, lambda: f32.cast_check(rejected))
    # ... but accepts these, with the expected rounding error.
    assert abs(123.456789049 - f32.cast_check(123.456789049) - 4.9e-8) < 1e-8
    assert abs(0.12345678904 - f32.cast_check(0.12345678904) - 4e-11) < 1e-11

    decimals21 = Float('0.123456789012345670499')  # 21 decimals
    assert abs(decimals21 - f64.cast_check(decimals21) - 4.99e-19) < 1e-19

    # float80: the first value must be accepted (no exception), the second
    # is rejected.
    f80.cast_check(Float('0.12345678901234567890103', precision=88))
    raises(ValueError, lambda: f80.cast_check(Float('0.12345678901234567890149', precision=88)))

    ten_digits = 12345.67894
    raises(ValueError, lambda: f32.cast_check(ten_digits))
    assert abs(Float(str(ten_digits), precision=64+8) - f64.cast_check(ten_digits)) < ten_digits*1e-16

    # A large integer cast to float32 is rounded to the nearest representable value.
    assert abs(f32.cast_check(2147483647) - 2147483650) < 1
|
|
|
|
|
|
def test_Type__cast_check__complex_floating_point():
    # complex64 rejects a value it cannot hold accurately enough ...
    raises(ValueError, lambda: c64.cast_check(.12345678949 + .12345678949j))
    # ... and accepts one within tolerance.
    nine_eleven = 123.456789049 + 0.123456789049j
    assert abs(nine_eleven - c64.cast_check(nine_eleven) - 4.9e-8) < 1e-8

    decimals21 = Float('0.123456789012345670499') + 1e-20j  # 21 decimals
    assert abs(decimals21 - c128.cast_check(decimals21) - 4.99e-19) < 1e-19

    part19 = Float('0.1234567890123456749')
    raises(ValueError, lambda: c128.cast_check(part19 + 1j*part19))
|
|
|
|
|
|
def test_While():
    incr = AddAugmentedAssignment(x, 1)
    loop = While(x < 2, [incr])

    cond = loop.condition
    assert cond.args[0] == x
    assert cond.args[1] == 2
    assert cond == Lt(x, 2, evaluate=False)
    assert loop.body.args == (incr,)
    assert loop.func(*loop.args) == loop

    # A list body and an equivalent CodeBlock body give equal loops.
    assert loop == While(x < 2, CodeBlock(AddAugmentedAssignment(x, 1)))
    # A different condition gives a different loop.
    assert loop != While(x < 3, [incr])
|
|
|
|
|
|
def test_Scope():
    assign, incr = Assignment(x, y), AddAugmentedAssignment(x, 1)
    block = CodeBlock(assign, incr)
    scope = Scope([assign, incr])

    # A list body is stored as the equivalent CodeBlock.
    assert scope.body == block
    assert scope == Scope(block)
    # Statement order matters for equality.
    assert scope != Scope([incr, assign])
    assert scope.func(*scope.args) == scope
|
|
|
|
|
|
def test_Print():
    fmt = "%d %.3f"
    with_fmt = Print([n, x], fmt)
    assert str(with_fmt.format_string) == fmt
    assert with_fmt.print_args == Tuple(n, x)
    # Format string is stored as a QuotedString; the third arg defaults to none.
    assert with_fmt.args == (Tuple(n, x), QuotedString(fmt), none)
    # Tuple and list print_args are equivalent; their order matters.
    assert with_fmt == Print((n, x), fmt)
    assert with_fmt != Print([x, n], fmt)
    assert with_fmt.func(*with_fmt.args) == with_fmt

    without_fmt = Print([n, x])
    assert without_fmt == Print([n, x])
    assert without_fmt != with_fmt
    # The missing format string is the none token, which compares equal to None.
    assert without_fmt.format_string == None  # noqa: E711
|
|
|
|
|
|
def test_FunctionPrototype_and_FunctionDefinition():
    var_x = Variable(x, type=real)
    var_n = Variable(n, type=integer)

    # Prototype: return type, name and ordered parameters.
    proto = FunctionPrototype(real, 'power', [var_x, var_n])
    assert proto.return_type == real
    assert proto.name == String('power')
    assert proto.parameters == Tuple(var_x, var_n)
    assert proto == FunctionPrototype(real, 'power', [var_x, var_n])
    # Parameter order matters.
    assert proto != FunctionPrototype(real, 'power', [var_n, var_x])
    assert proto.func(*proto.args) == proto

    # Definition: a prototype plus a body, stored as a CodeBlock.
    stmts = [Assignment(x, x**n), Return(x)]
    defn = FunctionDefinition(real, 'power', [var_x, var_n], stmts)
    assert defn.return_type == real
    assert str(defn.name) == 'power'
    assert defn.parameters == Tuple(var_x, var_n)
    assert defn.body == CodeBlock(*stmts)
    assert defn == FunctionDefinition(real, 'power', [var_x, var_n], stmts)
    # Statement order in the body matters.
    assert defn != FunctionDefinition(real, 'power', [var_x, var_n], stmts[::-1])
    assert defn.func(*defn.args) == defn

    # Converting in either direction is consistent.
    assert FunctionPrototype.from_FunctionDefinition(defn) == proto
    assert FunctionDefinition.from_FunctionPrototype(proto, stmts) == defn
|
|
|
|
|
|
def test_Return():
    ret_x = Return(x)
    # The returned expression is the single argument.
    assert ret_x.args == (x,)
    assert ret_x == Return(x)
    assert ret_x != Return(y)
    assert ret_x.func(*ret_x.args) == ret_x
|
|
|
|
|
|
def test_FunctionCall():
    call = FunctionCall('power', (x, 3))
    fargs = call.function_args
    assert fargs[0] == x
    assert fargs[1] == 3
    assert len(fargs) == 2
    # Python ints among the arguments are stored as sympy Integers.
    assert isinstance(fargs[1], Integer)
    assert call == FunctionCall('power', (x, 3))
    # Argument order and the (case-sensitive) name both matter.
    assert call != FunctionCall('power', (3, x))
    assert call != FunctionCall('Power', (x, 3))
    assert call.func(*call.args) == call

    fma = FunctionCall('fma', [2, 3, 4])
    assert len(fma.function_args) == 3
    for pos, expected in enumerate((2, 3, 4)):
        assert fma.function_args[pos] == expected
    # Either quoting of the function name is accepted (not sure if
    # QuotedString is a better default...).
    assert str(fma) in (
        'FunctionCall(fma, function_args=(2, 3, 4))',
        'FunctionCall("fma", function_args=(2, 3, 4))',
    )
|
|
|
|
def test_ast_replace():
    # Use distinct local names so the module-level symbols x, y, n are not
    # shadowed by Variable objects.
    var_x = Variable('x', real)
    var_y = Variable('y', real)
    var_n = Variable('n', integer)

    pwer = FunctionDefinition(real, 'pwer', [var_x, var_n],
                              [pow(var_x.symbol, var_n.symbol)])
    pname = pwer.name
    pcall = FunctionCall('pwer', [var_y, 3])

    tree1 = CodeBlock(pwer, pcall)
    assert str(tree1.args[0].name) == 'pwer'
    assert str(tree1.args[1].name) == 'pwer'
    # Compare the full statement sequence. The previous zip()-based loop
    # stopped at the shorter side and would pass vacuously if the block
    # yielded fewer statements than expected.
    assert list(tree1) == [pwer, pcall]

    # replace() returns a new tree; the original keeps its old names.
    tree2 = tree1.replace(pname, String('power'))
    assert str(tree1.args[0].name) == 'pwer'
    assert str(tree1.args[1].name) == 'pwer'
    assert str(tree2.args[0].name) == 'power'
    assert str(tree2.args[1].name) == 'power'
|