"""
Fortran code printer

The FCodePrinter converts single SymPy expressions into single Fortran
expressions, using the functions defined in the Fortran 77 standard where
possible. Some useful pointers to Fortran can be found on Wikipedia:

https://en.wikipedia.org/wiki/Fortran

Most of the code below is based on the "Professional Programmer's Guide to
Fortran77" by Clive G. Page:

https://www.star.le.ac.uk/~cgp/prof77.html

Fortran is a case-insensitive language. This might cause trouble because
SymPy is case sensitive. So, fcode adds underscores to variable names when
it is necessary to make them different for Fortran.
"""
|
||
|
|
||
|
from __future__ import annotations
|
||
|
from typing import Any
|
||
|
|
||
|
from collections import defaultdict
|
||
|
from itertools import chain
|
||
|
import string
|
||
|
|
||
|
from sympy.codegen.ast import (
|
||
|
Assignment, Declaration, Pointer, value_const,
|
||
|
float32, float64, float80, complex64, complex128, int8, int16, int32,
|
||
|
int64, intc, real, integer, bool_, complex_
|
||
|
)
|
||
|
from sympy.codegen.fnodes import (
|
||
|
allocatable, isign, dsign, cmplx, merge, literal_dp, elemental, pure,
|
||
|
intent_in, intent_out, intent_inout
|
||
|
)
|
||
|
from sympy.core import S, Add, N, Float, Symbol
|
||
|
from sympy.core.function import Function
|
||
|
from sympy.core.numbers import equal_valued
|
||
|
from sympy.core.relational import Eq
|
||
|
from sympy.sets import Range
|
||
|
from sympy.printing.codeprinter import CodePrinter
|
||
|
from sympy.printing.precedence import precedence, PRECEDENCE
|
||
|
from sympy.printing.printer import printer_context
|
||
|
|
||
|
# These are defined in the other file so we can avoid importing sympy.codegen
|
||
|
# from the top-level 'import sympy'. Export them here as well.
|
||
|
from sympy.printing.codeprinter import fcode, print_fcode # noqa:F401
|
||
|
|
||
|
# Map from SymPy function name to the Fortran intrinsic used to print it.
# Most intrinsics share the SymPy name verbatim; the trailing entries need an
# explicit Fortran spelling.
known_functions = {
    **{fname: fname for fname in (
        "sin", "cos", "tan",
        "asin", "acos", "atan", "atan2",
        "sinh", "cosh", "tanh",
        "log", "exp", "erf",
    )},
    "Abs": "abs",
    "conjugate": "conjg",
    "Max": "max",
    "Min": "min",
}
|
||
|
|
||
|
|
||
|
class FCodePrinter(CodePrinter):
    """A printer to convert SymPy expressions to strings of Fortran code"""
    printmethod = "_fcode"
    language = "Fortran"

    # Generic codegen types are first resolved through this table (see
    # ``_print_Type``); extendable through the ``type_aliases`` setting.
    type_aliases = {
        integer: int32,
        real: float64,
        complex_: complex128,
    }

    # Fortran spelling of each concrete type; extendable through the
    # ``type_mappings`` setting.
    type_mappings = {
        intc: 'integer(c_int)',
        float32: 'real*4',  # real(kind(0.e0))
        float64: 'real*8',  # real(kind(0.d0))
        float80: 'real*10',  # real(kind(????))
        complex64: 'complex*8',
        complex128: 'complex*16',
        int8: 'integer*1',
        int16: 'integer*2',
        int32: 'integer*4',
        int64: 'integer*8',
        bool_: 'logical'
    }

    # Types whose Fortran spelling needs a name imported from a module:
    # ``{type: {module_name: imported_name}}``, recorded in ``module_uses``.
    type_modules = {
        intc: {'iso_c_binding': 'c_int'}
    }

    _default_settings: dict[str, Any] = {
        'order': None,
        'full_prec': 'auto',
        'precision': 17,
        'user_functions': {},
        'human': True,
        'allow_unknown_functions': False,
        'source_format': 'fixed',
        'contract': True,
        'standard': 77,
        'name_mangling': True,
    }

    _operators = {
        'and': '.and.',
        'or': '.or.',
        'xor': '.neqv.',
        'equivalent': '.eqv.',
        'not': '.not. ',
    }

    _relationals = {
        '!=': '/=',
    }

    def __init__(self, settings=None):
        """Create a printer.

        ``settings`` is copied; the ``type_aliases`` and ``type_mappings``
        entries are removed from the copy and merged into the class tables
        before the remaining settings are handed to ``CodePrinter``.

        Raises ValueError if ``standard`` is not a known Fortran standard.
        """
        # Work on a copy so that popping the printer-specific keys below does
        # not mutate the caller's dictionary.
        settings = dict(settings) if settings else {}
        self.mangled_symbols = {}  # Dict showing mapping of all words
        self.used_name = []
        self.type_aliases = dict(chain(self.type_aliases.items(),
                                       settings.pop('type_aliases', {}).items()))
        self.type_mappings = dict(chain(self.type_mappings.items(),
                                        settings.pop('type_mappings', {}).items()))
        super().__init__(settings)
        self.known_functions = dict(known_functions)
        userfuncs = settings.get('user_functions', {})
        self.known_functions.update(userfuncs)
        # Reject Fortran standards this printer does not know about.
        standards = {66, 77, 90, 95, 2003, 2008}
        if self._settings['standard'] not in standards:
            raise ValueError("Unknown Fortran standard: %s" % self._settings[
                'standard'])
        self.module_uses = defaultdict(set)  # e.g.: use iso_c_binding, only: c_int

    @property
    def _lead(self):
        """Leading text for code, continuation and comment lines.

        Fixed form puts code in column 7, the continuation mark in column 6
        and the comment marker in column 1.
        """
        if self._settings['source_format'] == 'fixed':
            return {'code': "      ", 'cont': "     @ ", 'comment': "C     "}
        elif self._settings['source_format'] == 'free':
            return {'code': "", 'cont': "      ", 'comment': "! "}
        else:
            raise ValueError("Unknown source format: %s" % self._settings['source_format'])

    def _print_Symbol(self, expr):
        """Print a symbol, renaming it when it clashes case-insensitively.

        With ``name_mangling`` enabled, a symbol whose lower-cased name was
        already used by another symbol gets underscores appended until it is
        unique. The chosen mapping is remembered in ``mangled_symbols``.
        """
        if self._settings['name_mangling'] == True:
            if expr not in self.mangled_symbols:
                name = expr.name
                while name.lower() in self.used_name:
                    name += '_'
                self.used_name.append(name.lower())
                if name == expr.name:
                    self.mangled_symbols[expr] = expr
                else:
                    self.mangled_symbols[expr] = Symbol(name)

            expr = expr.xreplace(self.mangled_symbols)

        name = super()._print_Symbol(expr)
        return name

    def _rate_index_position(self, p):
        # NOTE(review): negative weight, presumably so loop nesting favours
        # Fortran's column-major layout — confirm against CodePrinter.
        return -p*5

    def _get_statement(self, codestring):
        """Fortran statements need no terminator; return the code unchanged."""
        return codestring

    def _get_comment(self, text):
        """Return ``text`` as a ``!`` comment."""
        return "! {}".format(text)

    def _declare_number_const(self, name, value):
        """Declare a numeric constant with a ``parameter`` statement."""
        return "parameter ({} = {})".format(name, self._print(value))

    def _print_NumberSymbol(self, expr):
        # A Number symbol that is not implemented here or with _printmethod
        # is registered and evaluated
        self._number_symbols.add((expr, Float(expr.evalf(self._settings['precision']))))
        return str(expr)

    def _format_code(self, lines):
        """Indent the lines, then wrap them to the Fortran line length."""
        return self._wrap_fortran(self.indent_code(lines))

    def _traverse_matrix_indices(self, mat):
        """Yield ``(i, j)`` with the row index varying fastest."""
        rows, cols = mat.shape
        return ((i, j) for j in range(cols) for i in range(rows))

    def _get_loop_opening_ending(self, indices):
        """Return matching ``do`` / ``end do`` lines for each index."""
        open_lines = []
        close_lines = []
        for i in indices:
            # fortran arrays start at 1 and end at dimension
            var, start, stop = map(self._print,
                    [i.label, i.lower + 1, i.upper + 1])
            open_lines.append("do %s = %s, %s" % (var, start, stop))
            close_lines.append("end do")
        return open_lines, close_lines

    def _print_sign(self, expr):
        """Print ``sign(x)`` using ``merge`` with isign/dsign or x/abs(x)."""
        from sympy.functions.elementary.complexes import Abs
        arg, = expr.args
        if arg.is_integer:
            new_expr = merge(0, isign(1, arg), Eq(arg, 0))
        elif not arg.is_real and (arg.is_complex or arg.is_infinite):
            # In SymPy every real number is also complex, so a known-real
            # argument must be excluded here; it belongs to the dsign branch.
            new_expr = merge(cmplx(literal_dp(0), literal_dp(0)), arg/Abs(arg), Eq(Abs(arg), literal_dp(0)))
        else:
            new_expr = merge(literal_dp(0), dsign(literal_dp(1), arg), Eq(arg, literal_dp(0)))
        return self._print(new_expr)

    def _print_Piecewise(self, expr):
        """Print a Piecewise expression.

        If the Piecewise contains Assignments it becomes an
        ``if``/``else if``/``else`` block. Otherwise, for F95 and later, it
        becomes nested ``merge(T, F, COND)`` calls.

        Raises ValueError if the last condition is not True.
        Raises NotImplementedError for the inline form before F95.
        """
        if expr.args[-1].cond != True:
            # We need the last conditional to be a True, otherwise the resulting
            # function may not return a result.
            raise ValueError("All Piecewise expressions must contain an "
                             "(expr, True) statement to be used as a default "
                             "condition. Without one, the generated "
                             "expression may not evaluate to anything under "
                             "some condition.")
        lines = []
        if expr.has(Assignment):
            for i, (e, c) in enumerate(expr.args):
                if i == 0:
                    lines.append("if (%s) then" % self._print(c))
                elif i == len(expr.args) - 1 and c == True:
                    lines.append("else")
                else:
                    lines.append("else if (%s) then" % self._print(c))
                lines.append(self._print(e))
            lines.append("end if")
            return "\n".join(lines)
        elif self._settings["standard"] >= 95:
            # Only supported in F95 and newer:
            # The piecewise was used in an expression, need to do inline
            # operators. This has the downside that inline operators will
            # not work for statements that span multiple lines (Matrix or
            # Indexed expressions).
            pattern = "merge({T}, {F}, {COND})"
            code = self._print(expr.args[-1].expr)
            terms = list(expr.args[:-1])
            # Build from the innermost (last) branch outwards.
            while terms:
                e, c = terms.pop()
                branch = self._print(e)
                cond = self._print(c)
                code = pattern.format(T=branch, F=code, COND=cond)
            return code
        else:
            # `merge` is not supported prior to F95
            raise NotImplementedError("Using Piecewise as an expression using "
                                      "inline operators is not supported in "
                                      "standards earlier than Fortran95.")

    def _print_MatrixElement(self, expr):
        """Print ``M[i, j]`` as ``M(i + 1, j + 1)`` (1-based)."""
        return "{}({}, {})".format(self.parenthesize(expr.parent,
            PRECEDENCE["Atom"], strict=True), expr.i + 1, expr.j + 1)

    def _print_Add(self, expr):
        """Print a sum, writing numeric complex parts as ``cmplx(re, im)``."""
        # purpose: print complex numbers nicely in Fortran.
        # collect the purely real and purely imaginary parts:
        pure_real = []
        pure_imaginary = []
        mixed = []
        for arg in expr.args:
            if arg.is_number and arg.is_real:
                pure_real.append(arg)
            elif arg.is_number and arg.is_imaginary:
                pure_imaginary.append(arg)
            else:
                mixed.append(arg)
        if pure_imaginary:
            if mixed:
                PREC = precedence(expr)
                term = Add(*mixed)
                t = self._print(term)
                if t.startswith('-'):
                    sign = "-"
                    t = t[1:]
                else:
                    sign = "+"
                if precedence(term) < PREC:
                    t = "(%s)" % t

                return "cmplx(%s,%s) %s %s" % (
                    self._print(Add(*pure_real)),
                    self._print(-S.ImaginaryUnit*Add(*pure_imaginary)),
                    sign, t,
                )
            else:
                return "cmplx(%s,%s)" % (
                    self._print(Add(*pure_real)),
                    self._print(-S.ImaginaryUnit*Add(*pure_imaginary)),
                )
        else:
            return CodePrinter._print_Add(self, expr)

    def _print_Function(self, expr):
        """Print a function call after evaluating its arguments numerically."""
        # All constant function args are evaluated as floats
        prec = self._settings['precision']
        args = [N(a, prec) for a in expr.args]
        eval_expr = expr.func(*args)
        if not isinstance(eval_expr, Function):
            # Numeric evaluation collapsed the call to a plain value.
            return self._print(eval_expr)
        else:
            return CodePrinter._print_Function(self, eval_expr)

    def _print_Mod(self, expr):
        """Print ``Mod(x, y)`` as ``modulo(x, y)`` (Fortran 90 and later)."""
        # NOTE : Fortran has the functions mod() and modulo(). modulo() behaves
        # the same wrt to the sign of the arguments as Python and SymPy's
        # modulus computations (% and Mod()) but is not available in Fortran 66
        # or Fortran 77, thus we raise an error.
        if self._settings['standard'] in [66, 77]:
            msg = ("Python % operator and SymPy's Mod() function are not "
                   "supported by Fortran 66 or 77 standards.")
            raise NotImplementedError(msg)
        else:
            x, y = expr.args
            # No leading whitespace: this may be embedded in a larger expression.
            return "modulo({}, {})".format(self._print(x), self._print(y))

    def _print_ImaginaryUnit(self, expr):
        # purpose: print complex numbers nicely in Fortran.
        return "cmplx(0,1)"

    def _print_int(self, expr):
        return str(expr)

    def _print_Mul(self, expr):
        """Print a product; numeric pure imaginaries become ``cmplx(0, im)``."""
        # purpose: print complex numbers nicely in Fortran.
        if expr.is_number and expr.is_imaginary:
            return "cmplx(0,%s)" % (
                self._print(-S.ImaginaryUnit*expr)
            )
        else:
            return CodePrinter._print_Mul(self, expr)

    def _print_Pow(self, expr):
        """Print a power; ``x**-1`` and ``x**0.5`` get special forms."""
        PREC = precedence(expr)
        if equal_valued(expr.exp, -1):
            return '%s/%s' % (
                self._print(literal_dp(1)),
                self.parenthesize(expr.base, PREC)
            )
        elif equal_valued(expr.exp, 0.5):
            if expr.base.is_integer:
                # Fortran intrinsic sqrt() does not accept integer argument
                if expr.base.is_Number:
                    return 'sqrt(%s.0d0)' % self._print(expr.base)
                else:
                    return 'sqrt(dble(%s))' % self._print(expr.base)
            else:
                return 'sqrt(%s)' % self._print(expr.base)
        else:
            return CodePrinter._print_Pow(self, expr)

    def _print_Rational(self, expr):
        """Print ``p/q`` as a double-precision division ``p.0d0/q.0d0``."""
        p, q = int(expr.p), int(expr.q)
        return "%d.0d0/%d.0d0" % (p, q)

    def _print_Float(self, expr):
        """Print a Float with a ``d`` exponent (``d0`` when none is present)."""
        printed = CodePrinter._print_Float(self, expr)
        e = printed.find('e')
        if e > -1:
            return "%sd%s" % (printed[:e], printed[e + 1:])
        return "%sd0" % printed

    def _print_Relational(self, expr):
        """Print a comparison, translating ``!=`` to Fortran's ``/=``."""
        lhs_code = self._print(expr.lhs)
        rhs_code = self._print(expr.rhs)
        op = self._relationals.get(expr.rel_op, expr.rel_op)
        return "{} {} {}".format(lhs_code, op, rhs_code)

    def _print_Indexed(self, expr):
        """Print ``base[i, j]`` as ``base(i, j)``."""
        inds = [self._print(i) for i in expr.indices]
        return "%s(%s)" % (self._print(expr.base.label), ", ".join(inds))

    def _print_Idx(self, expr):
        return self._print(expr.label)

    def _print_AugmentedAssignment(self, expr):
        """Print ``lhs op= rhs`` as ``lhs = lhs op rhs``."""
        lhs_code = self._print(expr.lhs)
        rhs_code = self._print(expr.rhs)
        return self._get_statement("{0} = {0} {1} {2}".format(
            lhs_code, expr.binop, rhs_code))

    def _print_sum_(self, sm):
        """Print a ``sum``/``product`` intrinsic with optional dim and mask."""
        params = self._print(sm.array)
        if sm.dim != None:  # Must use '!= None', cannot use 'is not None'
            params += ', ' + self._print(sm.dim)
        if sm.mask != None:  # Must use '!= None', cannot use 'is not None'
            params += ', mask=' + self._print(sm.mask)
        return '%s(%s)' % (sm.__class__.__name__.rstrip('_'), params)

    def _print_product_(self, prod):
        return self._print_sum_(prod)

    def _print_Do(self, do):
        """Print a (possibly ``concurrent``) do-loop; a unit step is omitted."""
        excl = ['concurrent']
        if do.step == 1:
            excl.append('step')
            step = ''
        else:
            step = ', {step}'

        return (
            'do {concurrent}{counter} = {first}, {last}'+step+'\n'
            '{body}\n'
            'end do\n'
        ).format(
            concurrent='concurrent ' if do.concurrent else '',
            **do.kwargs(apply=lambda arg: self._print(arg), exclude=excl)
        )

    def _print_ImpliedDoLoop(self, idl):
        """Print ``(expr, counter = first, last[, step])``."""
        step = '' if idl.step == 1 else ', {step}'
        return ('({expr}, {counter} = {first}, {last}'+step+')').format(
            **idl.kwargs(apply=lambda arg: self._print(arg))
        )

    def _print_For(self, expr):
        """Print a ``For`` over a ``Range`` as a Fortran do-loop.

        Raises NotImplementedError for any iterable other than Range.
        """
        target = self._print(expr.target)
        if isinstance(expr.iterable, Range):
            start, stop, step = expr.iterable.args
        else:
            raise NotImplementedError("Only iterable currently supported is Range")
        # Range excludes ``stop`` while the Fortran do bound is inclusive:
        # step one unit back towards ``start`` (upwards for negative steps).
        last = stop + 1 if step.is_negative else stop - 1
        body = self._print(expr.body)
        return ('do {target} = {start}, {stop}, {step}\n'
                '{body}\n'
                'end do').format(target=target, start=self._print(start),
                                 stop=self._print(last),
                                 step=self._print(step), body=body)

    def _print_Type(self, type_):
        """Print a type's Fortran spelling, recording any module it needs."""
        type_ = self.type_aliases.get(type_, type_)
        type_str = self.type_mappings.get(type_, type_.name)
        module_uses = self.type_modules.get(type_)
        if module_uses:
            # module_uses maps module name -> imported name.
            for k, v in module_uses.items():
                self.module_uses[k].add(v)
        return type_str

    def _print_Element(self, elem):
        return '{symbol}({idxs})'.format(
            symbol=self._print(elem.symbol),
            idxs=', '.join((self._print(arg) for arg in elem.indices))
        )

    def _print_Extent(self, ext):
        return str(ext)

    def _print_Declaration(self, expr):
        """Print a variable declaration.

        For F90 and later, this is a single ``type, attrs :: name [= value]``
        line. Earlier standards get ``type name`` only.

        Raises ValueError if more than one intent is given. Raises
        NotImplementedError for pointers, and for F77 declarations that need
        an initial value or ``parameter``.
        """
        var = expr.variable
        val = var.value
        dim = var.attr_params('dimension')
        intents = [intent in var.attrs for intent in (intent_in, intent_out, intent_inout)]
        if intents.count(True) == 0:
            intent = ''
        elif intents.count(True) == 1:
            intent = ', intent(%s)' % ['in', 'out', 'inout'][intents.index(True)]
        else:
            raise ValueError("Multiple intents specified for %s" % expr)

        if isinstance(var, Pointer):
            raise NotImplementedError("Pointers are not available by default in Fortran.")
        if self._settings["standard"] >= 90:
            result = '{t}{vc}{dim}{intent}{alloc} :: {s}'.format(
                t=self._print(var.type),
                vc=', parameter' if value_const in var.attrs else '',
                dim=', dimension(%s)' % ', '.join((self._print(arg) for arg in dim)) if dim else '',
                intent=intent,
                alloc=', allocatable' if allocatable in var.attrs else '',
                s=self._print(var.symbol)
            )
            if val != None:  # Must be "!= None", cannot be "is not None"
                result += ' = %s' % self._print(val)
        else:
            # Compare with None exactly as above: a falsy value such as 0 is
            # still an initial value and must not be silently dropped.
            if value_const in var.attrs or val != None:  # Must be "!= None"
                raise NotImplementedError("F77 init./parameter statem. req. multiple lines.")
            result = ' '.join((self._print(arg) for arg in [var.type, var.symbol]))

        return result

    def _print_Infinity(self, expr):
        return '(huge(%s) + 1)' % self._print(literal_dp(0))

    def _print_While(self, expr):
        """Print a ``do while (cond)`` loop."""
        return 'do while ({condition})\n{body}\nend do'.format(**expr.kwargs(
            apply=lambda arg: self._print(arg)))

    def _print_BooleanTrue(self, expr):
        return '.true.'

    def _print_BooleanFalse(self, expr):
        return '.false.'

    def _pad_leading_columns(self, lines):
        """Prefix each line with the fixed/free-form lead for its kind."""
        result = []
        for line in lines:
            if line.startswith('!'):
                result.append(self._lead['comment'] + line[1:].lstrip())
            else:
                result.append(self._lead['code'] + line)
        return result

    def _wrap_fortran(self, lines):
        """Wrap long Fortran lines

           Argument:
             lines  --  a list of lines (without \\n character)

           A comment line is split at white space. Code lines are split with a more
           complex rule to give nice results.
        """
        # routine to find split point in a code line
        my_alnum = set("_+-." + string.digits + string.ascii_letters)
        my_white = set(" \t()")

        def split_pos_code(line, endpos):
            """Return a split position at or before ``endpos``."""
            if len(line) <= endpos:
                return len(line)

            def is_boundary(pos):
                # Split where an "alnum" run starts/ends, or where a
                # whitespace/parenthesis run starts/ends.
                here, prev = line[pos], line[pos - 1]
                return ((here in my_alnum) != (prev in my_alnum)
                        or (here in my_white) != (prev in my_white))

            pos = endpos
            while not is_boundary(pos):
                pos -= 1
                if pos == 0:
                    return endpos
            return pos

        # split line by line and add the split lines to result
        result = []
        if self._settings['source_format'] == 'free':
            trailing = ' &'
        else:
            trailing = ''
        for line in lines:
            if line.startswith(self._lead['comment']):
                # comment line
                if len(line) > 72:
                    pos = line.rfind(" ", 6, 72)
                    if pos == -1:
                        pos = 72
                    hunk = line[:pos]
                    line = line[pos:].lstrip()
                    result.append(hunk)
                    while line:
                        pos = line.rfind(" ", 0, 66)
                        if pos == -1 or len(line) < 66:
                            pos = 66
                        hunk = line[:pos]
                        line = line[pos:].lstrip()
                        result.append("%s%s" % (self._lead['comment'], hunk))
                else:
                    result.append(line)
            elif line.startswith(self._lead['code']):
                # code line
                pos = split_pos_code(line, 72)
                hunk = line[:pos].rstrip()
                line = line[pos:].lstrip()
                if line:
                    hunk += trailing
                result.append(hunk)
                while line:
                    pos = split_pos_code(line, 65)
                    hunk = line[:pos].rstrip()
                    line = line[pos:].lstrip()
                    if line:
                        hunk += trailing
                    result.append("%s%s" % (self._lead['cont'], hunk))
            else:
                result.append(line)
        return result

    def indent_code(self, code):
        """Accepts a string of code or a list of code lines"""
        if isinstance(code, str):
            code_lines = self.indent_code(code.splitlines(True))
            return ''.join(code_lines)

        free = self._settings['source_format'] == 'free'
        code = [line.lstrip(' \t') for line in code]

        inc_keyword = ('do ', 'if(', 'if ', 'do\n', 'else', 'program', 'interface')
        dec_keyword = ('end do', 'enddo', 'end if', 'endif', 'else', 'end program', 'end interface')

        increase = [int(any(map(line.startswith, inc_keyword)))
                    for line in code]
        decrease = [int(any(map(line.startswith, dec_keyword)))
                    for line in code]
        continuation = [int(any(map(line.endswith, ['&', '&\n'])))
                        for line in code]

        level = 0
        cont_padding = 0
        tabwidth = 3
        new_code = []
        for i, line in enumerate(code):
            if line in ('', '\n'):
                new_code.append(line)
                continue
            level -= decrease[i]

            if free:
                padding = " "*(level*tabwidth + cont_padding)
            else:
                padding = " "*level*tabwidth

            line = "%s%s" % (padding, line)
            if not free:
                line = self._pad_leading_columns([line])[0]

            new_code.append(line)

            if continuation[i]:
                cont_padding = 2*tabwidth
            else:
                cont_padding = 0
            level += increase[i]

        if not free:
            return self._wrap_fortran(new_code)
        return new_code

    def _print_GoTo(self, goto):
        """Print a plain or computed ``go to``."""
        if goto.expr:  # computed goto
            return "go to ({labels}), {expr}".format(
                labels=', '.join((self._print(arg) for arg in goto.labels)),
                expr=self._print(goto.expr)
            )
        else:
            lbl, = goto.labels
            return "go to %s" % self._print(lbl)

    def _print_Program(self, prog):
        return (
            "program {name}\n"
            "{body}\n"
            "end program\n"
        ).format(**prog.kwargs(apply=lambda arg: self._print(arg)))

    def _print_Module(self, mod):
        return (
            "module {name}\n"
            "{declarations}\n"
            "\ncontains\n\n"
            "{definitions}\n"
            "end module\n"
        ).format(**mod.kwargs(apply=lambda arg: self._print(arg)))

    def _print_Stream(self, strm):
        """Print an I/O unit for a named stream.

        For F2003 and later, stdout/stderr map to ``output_unit`` /
        ``error_unit`` from the intrinsic module ``iso_fortran_env``. That
        use is recorded in ``module_uses``.
        """
        if strm.name == 'stdout' and self._settings["standard"] >= 2003:
            self.module_uses['iso_fortran_env'].add('output_unit')
            return 'output_unit'
        elif strm.name == 'stderr' and self._settings["standard"] >= 2003:
            self.module_uses['iso_fortran_env'].add('error_unit')
            return 'error_unit'
        else:
            if strm.name == 'stdout':
                return '*'
            else:
                return strm.name

    def _print_Print(self, ps):
        """Print a list-directed (``*``) or formatted ``print`` statement."""
        if ps.format_string != None:  # Must be '!= None', cannot be 'is not None'
            fmt = self._print(ps.format_string)
        else:
            fmt = "*"
        return "print {fmt}, {iolist}".format(fmt=fmt, iolist=', '.join(
            (self._print(arg) for arg in ps.print_args)))

    def _print_Return(self, rs):
        """Assign the value to the enclosing function's result name."""
        arg, = rs.args
        return "{result_name} = {arg}".format(
            result_name=self._context.get('result_name', 'sympy_result'),
            arg=self._print(arg)
        )

    def _print_FortranReturn(self, frs):
        arg, = frs.args
        if arg:
            return 'return %s' % self._print(arg)
        else:
            return 'return'

    def _head(self, entity, fp, **kwargs):
        """Build a procedure header line followed by argument declarations."""
        bind_C_params = fp.attr_params('bind_C')
        if bind_C_params is None:
            bind = ''
        else:
            bind = ' bind(C, name="%s")' % bind_C_params[0] if bind_C_params else ' bind(C)'
        result_name = self._settings.get('result_name', None)
        return (
            "{entity}{name}({arg_names}){result}{bind}\n"
            "{arg_declarations}"
        ).format(
            entity=entity,
            name=self._print(fp.name),
            arg_names=', '.join([self._print(arg.symbol) for arg in fp.parameters]),
            result=(' result(%s)' % result_name) if result_name else '',
            bind=bind,
            arg_declarations='\n'.join((self._print(Declaration(arg)) for arg in fp.parameters))
        )

    def _print_FunctionPrototype(self, fp):
        """Print a function prototype as an ``interface`` block."""
        entity = "{} function ".format(self._print(fp.return_type))
        return (
            "interface\n"
            "{function_head}\n"
            "end function\n"
            "end interface"
        ).format(function_head=self._head(entity, fp))

    def _print_FunctionDefinition(self, fd):
        """Print a function definition, with an ``elemental``/``pure`` prefix."""
        if elemental in fd.attrs:
            prefix = 'elemental '
        elif pure in fd.attrs:
            prefix = 'pure '
        else:
            prefix = ''

        entity = "{} function ".format(self._print(fd.return_type))
        # Return statements in the body assign to the function's own name.
        with printer_context(self, result_name=fd.name):
            return (
                "{prefix}{function_head}\n"
                "{body}\n"
                "end function\n"
            ).format(
                prefix=prefix,
                function_head=self._head(entity, fd),
                body=self._print(fd.body)
            )

    def _print_Subroutine(self, sub):
        return (
            '{subroutine_head}\n'
            '{body}\n'
            'end subroutine\n'
        ).format(
            subroutine_head=self._head('subroutine ', sub),
            body=self._print(sub.body)
        )

    def _print_SubroutineCall(self, scall):
        return 'call {name}({args})'.format(
            name=self._print(scall.name),
            args=', '.join((self._print(arg) for arg in scall.subroutine_args))
        )

    def _print_use_rename(self, rnm):
        return "%s => %s" % tuple((self._print(arg) for arg in rnm.args))

    def _print_use(self, use):
        """Print a ``use`` statement with optional renames and ``only`` list."""
        result = 'use %s' % self._print(use.namespace)
        if use.rename != None:  # Must be '!= None', cannot be 'is not None'
            result += ', ' + ', '.join([self._print(rnm) for rnm in use.rename])
        if use.only != None:  # Must be '!= None', cannot be 'is not None'
            result += ', only: ' + ', '.join([self._print(nly) for nly in use.only])
        return result

    def _print_BreakToken(self, _):
        return 'exit'

    def _print_ContinueToken(self, _):
        return 'cycle'

    def _print_ArrayConstructor(self, ac):
        """Use ``[...]`` for F2003 and later, ``(/.../)`` otherwise."""
        fmtstr = "[%s]" if self._settings["standard"] >= 2003 else '(/%s/)'
        return fmtstr % ', '.join((self._print(arg) for arg in ac.elements))

    def _print_ArrayElement(self, elem):
        return '{symbol}({idxs})'.format(
            symbol=self._print(elem.name),
            idxs=', '.join((self._print(arg) for arg in elem.indices))
        )