54 lines
1.6 KiB
Python
54 lines
1.6 KiB
Python
"""
|
|
Required functions for optimized contractions of numpy arrays using theano.
|
|
"""
|
|
|
|
import numpy as np
|
|
|
|
from ..sharing import to_backend_cache_wrap
|
|
|
|
# Public names of this module (what ``from <module> import *`` exposes).
__all__ = ["to_theano", "build_expression", "evaluate_constants"]
|
|
|
|
|
|
@to_backend_cache_wrap(constants=True)
def to_theano(array, constant=False):
    """Wrap a numpy array so it can take part in a theano graph.

    Parameters
    ----------
    array : object
        Value to convert. Only ``np.ndarray`` instances are converted; any
        other object is handed back untouched.
    constant : bool, optional
        When true, the array's data is embedded via
        ``theano.tensor.constant``. Otherwise a fresh symbolic
        ``theano.tensor.TensorType`` variable is created with the array's
        dtype and one non-broadcastable axis per dimension.

    Returns
    -------
    object
        The theano constant or symbolic variable, or ``array`` itself when it
        is not a numpy array.
    """
    import theano

    # Anything that is not a numpy array (e.g. an existing theano variable)
    # passes straight through.
    if not isinstance(array, np.ndarray):
        return array

    if constant:
        return theano.tensor.constant(array)

    # Symbolic placeholder: same dtype and rank, no broadcastable axes.
    tensor_type = theano.tensor.TensorType(
        dtype=array.dtype,
        broadcastable=[False] * len(array.shape),
    )
    return tensor_type()
|
|
|
|
|
|
def build_expression(arrays, expr):
    """Compile ``expr`` into a callable backed by a single theano function.

    Each entry of ``arrays`` is passed through ``to_theano`` to obtain a graph
    input, and ``expr._contract(..., backend='theano')`` builds the output
    expression from them.

    Parameters
    ----------
    arrays : sequence
        Example operands used to build the symbolic inputs.
    expr : object
        Contraction expression exposing ``_contract(inputs, backend=...)``.

    Returns
    -------
    callable
        ``theano_contract(*operands)``, which runs the compiled function on
        every operand that is not a ``theano.tensor.TensorConstant``.
    """
    import theano

    def _is_const(value):
        # Constants are already baked into the graph and are never inputs.
        return isinstance(value, theano.tensor.TensorConstant)

    symbolic_inputs = [to_theano(arr) for arr in arrays]
    symbolic_output = expr._contract(symbolic_inputs, backend='theano')

    # don't supply constants to graph
    compiled = theano.function(
        [var for var in symbolic_inputs if not _is_const(var)],
        symbolic_output,
    )

    def theano_contract(*operands):
        # Drop the same constant operands before invoking the compiled graph.
        return compiled(*(op for op in operands if not _is_const(op)))

    return theano_contract
|
|
|
|
|
|
def evaluate_constants(const_arrays, expr):
    """Pre-compute the parts of ``expr`` that depend only on constant operands.

    Parameters
    ----------
    const_arrays : sequence
        The constant operands. Each one is wrapped with
        ``to_theano(..., constant=True)`` before being handed to ``expr``.
    expr : callable
        Invoked as ``expr(*consts, backend='theano', evaluate_constants=True)``
        and expected to return a pair ``(new_ops, new_contraction_list)``.

    Returns
    -------
    new_ops : list
        The ops returned by ``expr``. Every non-``None`` entry is evaluated
        eagerly with ``.eval()`` and re-wrapped via
        ``to_theano(..., constant=True)``; ``None`` entries stay ``None``.
    new_contraction_list
        Passed through unchanged from ``expr``.
    """
    # compute the partial graph of new inputs
    theano_consts = [to_theano(arr, constant=True) for arr in const_arrays]
    partial_ops, remaining_contractions = expr(
        *theano_consts, backend='theano', evaluate_constants=True)

    # Evaluate each partial result now and embed the value back as a theano
    # constant (a TensorConstant, not a shared variable).
    evaluated_ops = []
    for op in partial_ops:
        if op is None:
            evaluated_ops.append(None)
        else:
            evaluated_ops.append(to_theano(op.eval(), constant=True))

    return evaluated_ops, remaining_contractions
|