""" Required functions for optimized contractions of numpy arrays using theano. """ import numpy as np from ..sharing import to_backend_cache_wrap __all__ = ["to_theano", "build_expression", "evaluate_constants"] @to_backend_cache_wrap(constants=True) def to_theano(array, constant=False): """Convert a numpy array to ``theano.tensor.TensorType`` instance. """ import theano if isinstance(array, np.ndarray): if constant: return theano.tensor.constant(array) return theano.tensor.TensorType(dtype=array.dtype, broadcastable=[False] * len(array.shape))() return array def build_expression(arrays, expr): """Build a theano function based on ``arrays`` and ``expr``. """ import theano in_vars = [to_theano(array) for array in arrays] out_var = expr._contract(in_vars, backend='theano') # don't supply constants to graph graph_ins = [x for x in in_vars if not isinstance(x, theano.tensor.TensorConstant)] graph = theano.function(graph_ins, out_var) def theano_contract(*arrays): return graph(*[x for x in arrays if not isinstance(x, theano.tensor.TensorConstant)]) return theano_contract def evaluate_constants(const_arrays, expr): # compute the partial graph of new inputs const_arrays = [to_theano(x, constant=True) for x in const_arrays] new_ops, new_contraction_list = expr(*const_arrays, backend='theano', evaluate_constants=True) # evaluate the new inputs and convert to theano shared tensors new_ops = [None if x is None else to_theano(x.eval(), constant=True) for x in new_ops] return new_ops, new_contraction_list