566 lines
19 KiB
Python
566 lines
19 KiB
Python
|
"""
|
||
|
.. deprecated:: 1.8
|
||
|
|
||
|
``sympy.printing.theanocode`` is deprecated. Theano has been renamed to
|
||
|
Aesara. Use ``sympy.printing.aesaracode`` instead. See
|
||
|
:ref:`theanocode-deprecated` for more information.
|
||
|
|
||
|
"""
|
||
|
from __future__ import annotations
|
||
|
from typing import Any
|
||
|
|
||
|
from sympy.external import import_module
|
||
|
from sympy.printing.printer import Printer
|
||
|
from sympy.utilities.iterables import is_sequence
|
||
|
import sympy
|
||
|
from functools import partial
|
||
|
|
||
|
from sympy.utilities.decorator import doctest_depends_on
|
||
|
from sympy.utilities.exceptions import sympy_deprecation_warning
|
||
|
|
||
|
theano = import_module('theano')

if theano:
    # Short aliases for the Theano submodules used by the printer below.
    ts, tt = theano.scalar, theano.tensor
    from theano.sandbox import linalg as tlinalg

    # Dispatch table consulted by ``TheanoPrinter._print_Basic``: each SymPy
    # class is mapped to the Theano op that is applied to the printed
    # arguments of an instance of that class.
    mapping = {
        # Arithmetic and elementary functions
        sympy.Add: tt.add,
        sympy.Mul: tt.mul,
        sympy.Abs: tt.abs_,
        sympy.sign: tt.sgn,
        sympy.ceiling: tt.ceil,
        sympy.floor: tt.floor,
        sympy.log: tt.log,
        sympy.exp: tt.exp,
        sympy.sqrt: tt.sqrt,
        # Trigonometric and hyperbolic functions
        sympy.cos: tt.cos,
        sympy.acos: tt.arccos,
        sympy.sin: tt.sin,
        sympy.asin: tt.arcsin,
        sympy.tan: tt.tan,
        sympy.atan: tt.arctan,
        sympy.atan2: tt.arctan2,
        sympy.cosh: tt.cosh,
        sympy.acosh: tt.arccosh,
        sympy.sinh: tt.sinh,
        sympy.asinh: tt.arcsinh,
        sympy.tanh: tt.tanh,
        sympy.atanh: tt.arctanh,
        # Complex parts
        sympy.re: tt.real,
        sympy.im: tt.imag,
        sympy.arg: tt.angle,
        # Special functions
        sympy.erf: tt.erf,
        sympy.gamma: tt.gamma,
        sympy.loggamma: tt.gammaln,
        sympy.Pow: tt.pow,
        # Relations and logic
        sympy.Eq: tt.eq,
        sympy.StrictGreaterThan: tt.gt,
        sympy.StrictLessThan: tt.lt,
        sympy.LessThan: tt.le,
        sympy.GreaterThan: tt.ge,
        sympy.And: tt.and_,
        sympy.Or: tt.or_,
        sympy.Max: tt.maximum,  # SymPy accept >2 inputs, Theano only 2
        sympy.Min: tt.minimum,  # SymPy accept >2 inputs, Theano only 2
        sympy.conjugate: tt.conj,
        # ``I`` has no args, so its op is a zero-argument callable.
        sympy.core.numbers.ImaginaryUnit: lambda: tt.complex(0, 1),
        # Matrices
        sympy.MatAdd: tt.Elemwise(ts.add),
        sympy.HadamardProduct: tt.Elemwise(ts.mul),
        sympy.Trace: tlinalg.trace,
        sympy.Determinant: tlinalg.det,
        sympy.Inverse: tlinalg.matrix_inverse,
        sympy.Transpose: tt.DimShuffle((False, False), [1, 0]),
    }
|
||
|
|
||
|
|
||
|
class TheanoPrinter(Printer):
    """ Code printer which creates Theano symbolic expression graphs.

    Parameters
    ==========

    cache : dict
        Cache dictionary to use. If None (default) a new, empty dictionary is
        created, so the printer neither depends on nor alters any shared
        state. Pass an existing dictionary (for example the module-level
        ``global_cache``) to share variables with other printers. Note: the
        dictionary is not copied on initialization of the printer and will be
        updated in-place, so using the same dict object when creating
        multiple printers or making multiple calls to :func:`.theano_code` or
        :func:`.theano_function` means the cache is shared between all these
        applications.

    Attributes
    ==========

    cache : dict
        A cache of Theano variables which have been created for SymPy
        symbol-like objects (e.g. :class:`sympy.core.symbol.Symbol` or
        :class:`sympy.matrices.expressions.MatrixSymbol`). This is used to
        ensure that all references to a given symbol in an expression (or
        multiple expressions) are printed as the same Theano variable, which is
        created only once. Symbols are differentiated only by name and type. The
        format of the cache's contents should be considered opaque to the user.
    """
    printmethod = "_theano"

    def __init__(self, *args, **kwargs):
        cache = kwargs.pop('cache', None)
        # An explicit ``cache=None`` behaves like an omitted argument; storing
        # None would make the first ``key in self.cache`` lookup fail.
        self.cache = {} if cache is None else cache
        super().__init__(*args, **kwargs)

    def _get_key(self, s, name=None, dtype=None, broadcastable=None):
        """ Get the cache key for a SymPy object.

        Parameters
        ==========

        s : sympy.core.basic.Basic
            SymPy object to get key for.

        name : str
            Name of object, if it does not have a ``name`` attribute.

        dtype, broadcastable
            Theano variable settings; they are part of the key so the same
            symbol requested with different settings gets distinct variables.
        """

        if name is None:
            name = s.name

        return (name, type(s), s.args, dtype, broadcastable)

    def _get_or_create(self, s, name=None, dtype=None, broadcastable=None):
        """
        Get the Theano variable for a SymPy symbol from the cache, or create it
        if it does not exist.

        ``name`` defaults to ``s.name``, ``dtype`` to ``'floatX'`` and
        ``broadcastable`` to ``()`` (a scalar).
        """

        # Defaults
        if name is None:
            name = s.name
        if dtype is None:
            dtype = 'floatX'
        if broadcastable is None:
            broadcastable = ()

        key = self._get_key(s, name, dtype=dtype, broadcastable=broadcastable)

        if key in self.cache:
            return self.cache[key]

        value = tt.tensor(name=name, dtype=dtype, broadcastable=broadcastable)
        self.cache[key] = value
        return value

    def _fold_args(self, op, expr, **kwargs):
        """Print ``expr.args`` and left-fold them with the binary Theano ``op``."""
        children = [self._print(arg, **kwargs) for arg in expr.args]
        result = children[0]
        for child in children[1:]:
            result = op(result, child)
        return result

    def _print_Symbol(self, s, **kwargs):
        dtype = kwargs.get('dtypes', {}).get(s)
        bc = kwargs.get('broadcastables', {}).get(s)
        return self._get_or_create(s, dtype=dtype, broadcastable=bc)

    def _print_AppliedUndef(self, s, **kwargs):
        # Undefined functions have no ``name``; derive one from the function
        # and its first argument.
        name = str(type(s)) + '_' + str(s.args[0])
        dtype = kwargs.get('dtypes', {}).get(s)
        bc = kwargs.get('broadcastables', {}).get(s)
        return self._get_or_create(s, name=name, dtype=dtype, broadcastable=bc)

    def _print_Basic(self, expr, **kwargs):
        """Fallback: apply the op from ``mapping`` to the printed arguments.

        Raises ``KeyError`` if ``type(expr)`` has no entry in ``mapping``.
        """
        op = mapping[type(expr)]
        children = [self._print(arg, **kwargs) for arg in expr.args]
        return op(*children)

    def _print_Max(self, expr, **kwargs):
        # SymPy's Max takes any number of arguments but tt.maximum is binary,
        # so fold the arguments pairwise.
        return self._fold_args(tt.maximum, expr, **kwargs)

    def _print_Min(self, expr, **kwargs):
        # SymPy's Min takes any number of arguments but tt.minimum is binary,
        # so fold the arguments pairwise.
        return self._fold_args(tt.minimum, expr, **kwargs)

    def _print_Number(self, n, **kwargs):
        # Integer and Rational have their own printers below; any other
        # number is interpreted as a Python float.
        return float(n.evalf())

    def _print_MatrixSymbol(self, X, **kwargs):
        dtype = kwargs.get('dtypes', {}).get(X)
        return self._get_or_create(X, dtype=dtype, broadcastable=(None, None))

    def _print_DenseMatrix(self, X, **kwargs):
        if not hasattr(tt, 'stacklists'):
            raise NotImplementedError(
               "Matrix translation not yet supported in this version of Theano")

        return tt.stacklists([
            [self._print(arg, **kwargs) for arg in L]
            for L in X.tolist()
        ])

    _print_ImmutableMatrix = _print_ImmutableDenseMatrix = _print_DenseMatrix

    def _print_MatMul(self, expr, **kwargs):
        # Chain the factors left to right with matrix products.
        return self._fold_args(tt.dot, expr, **kwargs)

    def _print_MatPow(self, expr, **kwargs):
        base, exp = [self._print(arg, **kwargs) for arg in expr.args]
        # The exponent is only usable when it printed to a Python int (see
        # _print_Integer); a power of zero would need an identity matrix of
        # possibly symbolic size, so only positive powers are supported.
        if not (isinstance(exp, int) and exp > 0):
            raise NotImplementedError(
                'Only positive integer powers of matrices can be handled by '
                'Theano at the moment')
        result = base
        for _ in range(exp - 1):
            result = tt.dot(result, base)
        return result

    def _print_MatrixSlice(self, expr, **kwargs):
        parent = self._print(expr.parent, **kwargs)
        rowslice = self._print(slice(*expr.rowslice), **kwargs)
        colslice = self._print(slice(*expr.colslice), **kwargs)
        return parent[rowslice, colslice]

    def _print_BlockMatrix(self, expr, **kwargs):
        nrows, ncols = expr.blocks.shape
        blocks = [[self._print(expr.blocks[r, c], **kwargs)
                        for c in range(ncols)]
                        for r in range(nrows)]
        # Join each block row horizontally, then stack the rows vertically.
        return tt.join(0, *[tt.join(1, *row) for row in blocks])

    def _print_slice(self, expr, **kwargs):
        # Only SymPy objects need printing; plain ints/None pass through.
        return slice(*[self._print(i, **kwargs)
                        if isinstance(i, sympy.Basic) else i
                        for i in (expr.start, expr.stop, expr.step)])

    def _print_Pi(self, expr, **kwargs):
        return 3.141592653589793

    def _print_Exp1(self, expr, **kwargs):
        return ts.exp(1)

    def _print_Piecewise(self, expr, **kwargs):
        import numpy as np
        e, cond = expr.args[0].args  # First condition and corresponding value

        # Print conditional expression and value for first condition
        p_cond = self._print(cond, **kwargs)
        p_e = self._print(e, **kwargs)

        # One condition only
        if len(expr.args) == 1:
            # Return value if condition else NaN
            return tt.switch(p_cond, p_e, np.nan)

        # Return value_1 if condition_1 else evaluate remaining conditions
        p_remaining = self._print(sympy.Piecewise(*expr.args[1:]), **kwargs)
        return tt.switch(p_cond, p_e, p_remaining)

    def _print_Rational(self, expr, **kwargs):
        return tt.true_div(self._print(expr.p, **kwargs),
                           self._print(expr.q, **kwargs))

    def _print_Integer(self, expr, **kwargs):
        return expr.p

    def _print_factorial(self, expr, **kwargs):
        # n! == gamma(n + 1)
        return self._print(sympy.gamma(expr.args[0] + 1), **kwargs)

    def _print_Derivative(self, deriv, **kwargs):
        rv = self._print(deriv.expr, **kwargs)
        for var in deriv.variables:
            var = self._print(var, **kwargs)
            rv = tt.Rop(rv, var, tt.ones_like(var))
        return rv

    def emptyPrinter(self, expr):
        # Objects without a dedicated printer are passed through unchanged.
        return expr

    def doprint(self, expr, dtypes=None, broadcastables=None):
        """ Convert a SymPy expression to a Theano graph variable.

        The ``dtypes`` and ``broadcastables`` arguments are used to specify the
        data type, dimension, and broadcasting behavior of the Theano variables
        corresponding to the free symbols in ``expr``. Each is a mapping from
        SymPy symbols to the value of the corresponding argument to
        ``theano.tensor.Tensor``.

        See the corresponding `documentation page`__ for more information on
        broadcasting in Theano.

        .. __: http://deeplearning.net/software/theano/tutorial/broadcasting.html

        Parameters
        ==========

        expr : sympy.core.expr.Expr
            SymPy expression to print.

        dtypes : dict
            Mapping from SymPy symbols to Theano datatypes to use when creating
            new Theano variables for those symbols. Corresponds to the ``dtype``
            argument to ``theano.tensor.Tensor``. Defaults to ``'floatX'``
            for symbols not included in the mapping.

        broadcastables : dict
            Mapping from SymPy symbols to the value of the ``broadcastable``
            argument to ``theano.tensor.Tensor`` to use when creating Theano
            variables for those symbols. Defaults to the empty tuple for symbols
            not included in the mapping (resulting in a scalar).

        Returns
        =======

        theano.gof.graph.Variable
            A variable corresponding to the expression's value in a Theano
            symbolic expression graph.

        """
        if dtypes is None:
            dtypes = {}
        if broadcastables is None:
            broadcastables = {}

        return self._print(expr, dtypes=dtypes, broadcastables=broadcastables)
|
||
|
|
||
|
|
||
|
# Module-level cache of Theano variables; theano_code() uses it whenever it is
# called with cache=None (its default), so variables are shared across calls.
global_cache: dict[Any, Any] = {}
|
||
|
|
||
|
|
||
|
def theano_code(expr, cache=None, **kwargs):
    """
    Convert a SymPy expression into a Theano graph variable.

    .. deprecated:: 1.8

      ``sympy.printing.theanocode`` is deprecated. Theano has been renamed to
      Aesara. Use ``sympy.printing.aesaracode`` instead. See
      :ref:`theanocode-deprecated` for more information.

    Parameters
    ==========

    expr : sympy.core.expr.Expr
        SymPy expression object to convert.

    cache : dict
        Cached Theano variables (see :class:`TheanoPrinter.cache
        <TheanoPrinter>`). When None (the default) the module-level global
        cache is used.

    dtypes : dict
        Passed to :meth:`.TheanoPrinter.doprint`.

    broadcastables : dict
        Passed to :meth:`.TheanoPrinter.doprint`.

    Returns
    =======

    theano.gof.graph.Variable
        A variable corresponding to the expression's value in a Theano symbolic
        expression graph.

    """
    sympy_deprecation_warning(
        """
        sympy.printing.theanocode is deprecated. Theano has been renamed to
        Aesara. Use sympy.printing.aesaracode instead.""",
        deprecated_since_version="1.8",
        active_deprecations_target='theanocode-deprecated')

    if not theano:
        raise ImportError("theano is required for theano_code")

    # None selects the shared module-level cache; any dict is used as given.
    printer = TheanoPrinter(
        cache=global_cache if cache is None else cache,
        settings={},
    )
    return printer.doprint(expr, **kwargs)
|
||
|
|
||
|
|
||
|
def dim_handling(inputs, dim=None, dims=None, broadcastables=None):
    r"""
    Get value of ``broadcastables`` argument to :func:`.theano_code` from
    keyword arguments to :func:`.theano_function`.

    Included for backwards compatibility.

    Parameters
    ==========

    inputs
        Sequence of input symbols.

    dim : int
        Common number of dimensions for all inputs. Overrides other arguments
        if given.

    dims : dict
        Mapping from input symbols to number of dimensions. Overrides
        ``broadcastables`` argument if given. Inputs with fewer dimensions than
        the largest entry are padded with broadcastable (``True``) axes. An
        empty mapping yields an empty result.

    broadcastables : dict
        Explicit value of ``broadcastables`` argument to
        :meth:`.TheanoPrinter.doprint`. If not None function will return this value unchanged.

    Returns
    =======
    dict
        Dictionary mapping elements of ``inputs`` to their "broadcastable"
        values (tuple of ``bool``\ s).
    """
    if dim is not None:
        return {s: (False,) * dim for s in inputs}

    if dims is not None:
        # ``default`` keeps an empty mapping from raising ValueError in max().
        maxdim = max(dims.values(), default=0)
        return {
            s: (False,) * d + (True,) * (maxdim - d)
            for s, d in dims.items()
        }

    if broadcastables is not None:
        return broadcastables

    return {}
|
||
|
|
||
|
|
||
|
@doctest_depends_on(modules=('theano',))
def theano_function(inputs, outputs, scalar=False, *,
        dim=None, dims=None, broadcastables=None, **kwargs):
    """
    Create a Theano function from SymPy expressions.

    .. deprecated:: 1.8

      ``sympy.printing.theanocode`` is deprecated. Theano has been renamed to
      Aesara. Use ``sympy.printing.aesaracode`` instead. See
      :ref:`theanocode-deprecated` for more information.

    The inputs and outputs are converted to Theano variables using
    :func:`.theano_code` and then passed to ``theano.function``.

    Parameters
    ==========

    inputs
        Sequence of symbols which constitute the inputs of the function.

    outputs
        Sequence of expressions which constitute the output(s) of the
        function. The free symbols of each expression must be a subset of
        ``inputs``. A single (non-sequence) expression is also accepted and is
        treated like a one-element list.

    scalar : bool
        Convert 0-dimensional arrays in output to scalars. This will return a
        Python wrapper function around the Theano function object.

    cache : dict
        Cached Theano variables (see :class:`TheanoPrinter.cache
        <TheanoPrinter>`). Defaults to a new, empty dictionary, so each call
        uses its own cache; pass ``None`` explicitly to use the module-level
        global cache.

    dtypes : dict
        Passed to :meth:`.TheanoPrinter.doprint`.

    broadcastables : dict
        Passed to :meth:`.TheanoPrinter.doprint`.

    dims : dict
        Alternative to ``broadcastables`` argument. Mapping from elements of
        ``inputs`` to integers indicating the dimension of their associated
        arrays/tensors. Overrides ``broadcastables`` argument if given.

    dim : int
        Another alternative to the ``broadcastables`` argument. Common number of
        dimensions to use for all arrays/tensors.
        ``theano_function([x, y], [...], dim=2)`` is equivalent to using
        ``broadcastables={x: (False, False), y: (False, False)}``.

    Returns
    =======
    callable
        A callable object which takes values of ``inputs`` as positional
        arguments and returns an output array for each of the expressions
        in ``outputs``. If ``outputs`` holds a single expression (passed bare
        or in a one-element list) the function will return a single Numpy
        array; if it holds multiple expressions the function will return a
        list of arrays.
        The returned object will either be an instance of
        ``theano.compile.function_module.Function`` or a Python wrapper
        function around one. In both cases, the returned value will have a
        ``theano_function`` attribute which points to the return value of
        ``theano.function``.

    Examples
    ========

    >>> from sympy.abc import x, y, z
    >>> from sympy.printing.theanocode import theano_function

    A simple function with one input and one output:

    >>> f1 = theano_function([x], [x**2 - 1], scalar=True)
    >>> f1(3)
    8.0

    A function with multiple inputs and one output:

    >>> f2 = theano_function([x, y, z], [(x**z + y**z)**(1/z)], scalar=True)
    >>> f2(3, 4, 2)
    5.0

    A function with multiple inputs and multiple outputs:

    >>> f3 = theano_function([x, y], [x**2 + y**2, x**2 - y**2], scalar=True)
    >>> f3(2, 3)
    [13.0, -5.0]

    See also
    ========

    dim_handling

    """
    sympy_deprecation_warning(
        """
        sympy.printing.theanocode is deprecated. Theano has been renamed to Aesara. Use sympy.printing.aesaracode instead""",
        deprecated_since_version="1.8",
        active_deprecations_target='theanocode-deprecated')

    if not theano:
        raise ImportError("theano is required for theano_function")

    # Pop off non-theano keyword args. Note the cache default is a fresh dict
    # (not None), so the global cache is only used if None is passed.
    cache = kwargs.pop('cache', {})
    dtypes = kwargs.pop('dtypes', {})

    broadcastables = dim_handling(
        inputs, dim=dim, dims=dims, broadcastables=broadcastables,
    )

    # A bare expression is shorthand for a one-element list of outputs;
    # without this, map() below would fail on a non-iterable expression.
    if not is_sequence(outputs):
        outputs = [outputs]

    # Print inputs/outputs; sharing ``cache`` makes both refer to the same
    # Theano variables.
    code = partial(theano_code, cache=cache, dtypes=dtypes,
                   broadcastables=broadcastables)
    tinputs = list(map(code, inputs))
    toutputs = list(map(code, outputs))

    # fix constant expressions as variables
    toutputs = [output if isinstance(output, theano.Variable) else tt.as_tensor_variable(output) for output in toutputs]

    if len(toutputs) == 1:
        toutputs = toutputs[0]

    # Compile theano func
    func = theano.function(tinputs, toutputs, **kwargs)

    is_0d = [len(o.variable.broadcastable) == 0 for o in func.outputs]

    # No wrapper required
    if not scalar or not any(is_0d):
        func.theano_function = func
        return func

    # Create wrapper to convert 0-dimensional outputs to scalars
    def wrapper(*args):
        out = func(*args)
        # out can be array(1.0) or [array(1.0), array(2.0)]

        if is_sequence(out):
            return [o[()] if is_0d[i] else o for i, o in enumerate(out)]
        else:
            return out[()]

    wrapper.__wrapped__ = func
    wrapper.__doc__ = func.__doc__
    wrapper.theano_function = func
    return wrapper
|