3RNN/Lib/site-packages/opt_einsum/backends/theano.py
2024-05-26 19:49:15 +02:00

54 lines
1.6 KiB
Python

"""
Required functions for optimized contractions of numpy arrays using theano.
"""
import numpy as np
from ..sharing import to_backend_cache_wrap
__all__ = ["to_theano", "build_expression", "evaluate_constants"]
@to_backend_cache_wrap(constants=True)
def to_theano(array, constant=False):
"""Convert a numpy array to ``theano.tensor.TensorType`` instance.
"""
import theano
if isinstance(array, np.ndarray):
if constant:
return theano.tensor.constant(array)
return theano.tensor.TensorType(dtype=array.dtype, broadcastable=[False] * len(array.shape))()
return array
def build_expression(arrays, expr):
"""Build a theano function based on ``arrays`` and ``expr``.
"""
import theano
in_vars = [to_theano(array) for array in arrays]
out_var = expr._contract(in_vars, backend='theano')
# don't supply constants to graph
graph_ins = [x for x in in_vars if not isinstance(x, theano.tensor.TensorConstant)]
graph = theano.function(graph_ins, out_var)
def theano_contract(*arrays):
return graph(*[x for x in arrays if not isinstance(x, theano.tensor.TensorConstant)])
return theano_contract
def evaluate_constants(const_arrays, expr):
# compute the partial graph of new inputs
const_arrays = [to_theano(x, constant=True) for x in const_arrays]
new_ops, new_contraction_list = expr(*const_arrays, backend='theano', evaluate_constants=True)
# evaluate the new inputs and convert to theano shared tensors
new_ops = [None if x is None else to_theano(x.eval(), constant=True) for x in new_ops]
return new_ops, new_contraction_list