"""Tests for the C-specific AST nodes in ``sympy.codegen.cnodes``.

Each test builds a node, checks that it can be rebuilt from
``node.func(*node.args)``, and compares the C code emitted by ``ccode``
against the expected text.
"""

from sympy.core.symbol import symbols
from sympy.printing.codeprinter import ccode
from sympy.codegen.ast import (
    Declaration, Variable, float64, int64, String, CodeBlock
)
from sympy.codegen.cnodes import (
    alignof, CommaOperator, goto, Label, PreDecrement, PostDecrement,
    PreIncrement, PostIncrement, sizeof, union, struct
)

x, y = symbols('x y')


def test_alignof():
    """``alignof`` prints as a call and is rebuildable from its args."""
    ax = alignof(x)
    assert ccode(ax) == 'alignof(x)'
    assert ax.func(*ax.args) == ax


def test_CommaOperator():
    """``CommaOperator`` prints its operands comma-separated in parentheses."""
    expr = CommaOperator(PreIncrement(x), 2*x)
    assert ccode(expr) == '(++(x), 2*x)'
    assert expr.func(*expr.args) == expr


def test_goto_Label():
    """``goto`` and ``Label`` compare by name and print with/without a body."""
    s = 'early_exit'
    g = goto(s)
    assert g.func(*g.args) == g
    assert g != goto('foobar')
    assert ccode(g) == 'goto early_exit'

    # A label without a body prints as just the name and a colon.
    l1 = Label(s)
    assert ccode(l1) == 'early_exit:'
    assert l1 == Label('early_exit')
    assert l1 != Label('foobar')

    # A single-statement body is printed directly after the label.
    body = [PreIncrement(x)]
    l2 = Label(s, body)
    assert l2.name == String("early_exit")
    assert l2.body == CodeBlock(PreIncrement(x))
    assert ccode(l2) == ("early_exit:\n"
                         "++(x);")

    # A multi-statement body is wrapped in braces; the expected text uses
    # the C printer's three-space indentation for the nested statements.
    body = [PreIncrement(x), PreDecrement(y)]
    l2 = Label(s, body)
    assert l2.name == String("early_exit")
    assert l2.body == CodeBlock(PreIncrement(x), PreDecrement(y))
    assert ccode(l2) == ("early_exit:\n"
                         "{\n   ++(x);\n   --(y);\n}")


def test_PreDecrement():
    """``PreDecrement`` prints as a prefix ``--``."""
    p = PreDecrement(x)
    assert p.func(*p.args) == p
    assert ccode(p) == '--(x)'


def test_PostDecrement():
    """``PostDecrement`` prints as a postfix ``--``."""
    p = PostDecrement(x)
    assert p.func(*p.args) == p
    assert ccode(p) == '(x)--'


def test_PreIncrement():
    """``PreIncrement`` prints as a prefix ``++``."""
    p = PreIncrement(x)
    assert p.func(*p.args) == p
    assert ccode(p) == '++(x)'


def test_PostIncrement():
    """``PostIncrement`` prints as a postfix ``++``."""
    p = PostIncrement(x)
    assert p.func(*p.args) == p
    assert ccode(p) == '(x)++'


def test_sizeof():
    """``sizeof`` prints as a call, is not an atom, and exposes String atoms."""
    typename = 'unsigned int'
    sz = sizeof(typename)
    assert ccode(sz) == 'sizeof(%s)' % typename
    assert sz.func(*sz.args) == sz
    assert not sz.is_Atom
    assert sz.atoms() == {String('unsigned int'), String('sizeof')}


def test_struct():
    """``struct`` wraps variables in declarations and is order-sensitive."""
    vx, vy = Variable(x, type=float64), Variable(y, type=float64)
    s = struct('vec2', [vx, vy])
    assert s.func(*s.args) == s
    assert s == struct('vec2', (vx, vy))
    assert s != struct('vec2', (vy, vx))
    assert str(s.name) == 'vec2'
    assert len(s.declarations) == 2
    assert all(isinstance(arg, Declaration) for arg in s.declarations)
    # Member declarations are indented with the C printer's three-space tab.
    assert ccode(s) == (
        "struct vec2 {\n"
        "   double x;\n"
        "   double y;\n"
        "}")


def test_union():
    """``union`` wraps variables in declarations and is order-sensitive."""
    vx, vy = Variable(x, type=float64), Variable(y, type=int64)
    u = union('dualuse', [vx, vy])
    assert u.func(*u.args) == u
    assert u == union('dualuse', (vx, vy))
    # Same order-sensitivity check as test_struct, for parity.
    assert u != union('dualuse', (vy, vx))
    assert str(u.name) == 'dualuse'
    assert len(u.declarations) == 2
    assert all(isinstance(arg, Declaration) for arg in u.declarations)
    # Member declarations are indented with the C printer's three-space tab.
    assert ccode(u) == (
        "union dualuse {\n"
        "   double x;\n"
        "   int64_t y;\n"
        "}")