3RNN/Lib/site-packages/opt_einsum/backends/tensorflow.py

129 lines
3.8 KiB
Python
Raw Permalink Normal View History

2024-05-26 19:49:15 +02:00
"""
Required functions for optimized contractions of numpy arrays using tensorflow.
"""
import numpy as np
from ..sharing import to_backend_cache_wrap
# Names exported by ``from <this module> import *``.
__all__ = ["to_tensorflow", "build_expression", "evaluate_constants"]
# Holds the ``(tf, device, eager)`` tuple once ``_get_tensorflow_and_device``
# has run; ``None`` means tensorflow has not been imported by this module yet.
_CACHED_TF_DEVICE = None
def _get_tensorflow_and_device():
    """Lazily import tensorflow and return the cached ``(tf, device, eager)``.

    The tuple is computed on the first call only and stored in the
    module-level ``_CACHED_TF_DEVICE``; later calls return the stored tuple.

    Returns
    -------
    tuple
        ``(tf, device, eager)`` where ``tf`` is the imported tensorflow
        module, ``device`` is the value of ``tf.test.gpu_device_name()`` (or
        ``'cpu'`` when that value is empty) and ``eager`` tells whether
        eager execution was detected.
    """
    global _CACHED_TF_DEVICE

    # fast path: everything was already worked out by a previous call
    if _CACHED_TF_DEVICE is not None:
        return _CACHED_TF_DEVICE

    import tensorflow as tf

    # probe the eager-mode query functions in turn; if neither attribute
    # exists, treat tensorflow as running in graph mode
    try:
        eager = tf.executing_eagerly()
    except AttributeError:
        try:
            eager = tf.contrib.eager.in_eager_mode()
        except AttributeError:
            eager = False

    # an empty device name means no GPU was reported
    device = tf.test.gpu_device_name() or 'cpu'

    _CACHED_TF_DEVICE = (tf, device, eager)
    return _CACHED_TF_DEVICE
@to_backend_cache_wrap(constants=True)
def to_tensorflow(array, constant=False):
    """Convert a numpy array to a tensorflow tensor or placeholder.

    Parameters
    ----------
    array : object
        The value to convert. Only ``numpy.ndarray`` instances are converted,
        anything else is returned unchanged.
    constant : bool, optional
        Only used in graph mode: if true the array is converted with
        ``tf.convert_to_tensor``, otherwise a placeholder with the same dtype
        and shape is created.

    Returns
    -------
    object
        In eager mode, the result of ``tf.convert_to_tensor`` created under
        ``tf.device(device)``. In graph mode, either a converted tensor
        (``constant=True``) or a placeholder. Non-numpy inputs are returned
        as they are.
    """
    tf, device, eager = _get_tensorflow_and_device()

    # anything that is not a numpy array is passed through untouched
    if not isinstance(array, np.ndarray):
        return array

    if eager:
        with tf.device(device):
            return tf.convert_to_tensor(array)

    if constant:
        return tf.convert_to_tensor(array)

    # ``tf.placeholder`` is only a top-level symbol in tensorflow 1.x; with
    # tensorflow 2.x in graph mode it has to be taken from ``tf.compat.v1``
    placeholder = getattr(tf, 'placeholder', None)
    if placeholder is None:
        placeholder = tf.compat.v1.placeholder

    return placeholder(array.dtype, array.shape)
# Standard graph mode
def build_expression_graph(arrays, expr):
    """Build a tensorflow function based on ``arrays`` and ``expr``.

    Parameters
    ----------
    arrays : sequence
        Example inputs; each one is passed through ``to_tensorflow`` to
        obtain the graph inputs.
    expr : object
        Expression whose ``_contract`` method is called once, with
        ``backend='tensorflow'``, to build the graph.

    Returns
    -------
    callable
        ``tensorflow_contract(*arrays)``, which runs the graph in the default
        session, feeding the given arrays to the placeholder inputs.
    """
    tf, _, _ = _get_tensorflow_and_device()

    placeholders = [to_tensorflow(array) for array in arrays]
    graph = expr._contract(placeholders, backend='tensorflow')

    def tensorflow_contract(*arrays):
        # ``tf.get_default_session`` is only a top-level symbol in tensorflow
        # 1.x; with tensorflow 2.x it has to be taken from ``tf.compat.v1``
        get_default_session = getattr(tf, 'get_default_session', None)
        if get_default_session is None:
            get_default_session = tf.compat.v1.get_default_session
        session = get_default_session()

        # only want to feed placeholders - constant tensors already have values
        feed_dict = {p: a for p, a in zip(placeholders, arrays) if p.op.type == 'Placeholder'}
        return session.run(graph, feed_dict=feed_dict)

    return tensorflow_contract
def evaluate_constants_graph(const_arrays, expr):
    """Convert constant arguments to tensorflow constants, and perform any
    possible constant contractions. Requires evaluating a tensorflow graph.

    Parameters
    ----------
    const_arrays : sequence
        The constant inputs; each one is converted with
        ``to_tensorflow(x, constant=True)``.
    expr : callable
        Called with the converted inputs, ``backend='tensorflow'`` and
        ``evaluate_constants=True``; expected to return a pair
        ``(new_ops, new_contraction_list)``.

    Returns
    -------
    tuple
        ``(new_ops, new_contraction_list)`` where every non-``None`` entry of
        ``new_ops`` has been evaluated in the default session and converted
        back with ``to_tensorflow(..., constant=True)``.
    """
    tf, _, _ = _get_tensorflow_and_device()

    # compute the partial graph of new inputs
    const_arrays = [to_tensorflow(x, constant=True) for x in const_arrays]
    new_ops, new_contraction_list = expr(*const_arrays, backend='tensorflow', evaluate_constants=True)

    # ``tf.get_default_session`` is only a top-level symbol in tensorflow 1.x;
    # with tensorflow 2.x it has to be taken from ``tf.compat.v1``
    get_default_session = getattr(tf, 'get_default_session', None)
    if get_default_session is None:
        get_default_session = tf.compat.v1.get_default_session
    session = get_default_session()

    # evaluate the new inputs and convert back to tensorflow, maintaining None as non-consts
    new_consts = iter(session.run([x for x in new_ops if x is not None]))
    new_ops = [None if x is None else to_tensorflow(next(new_consts), constant=True) for x in new_ops]

    return new_ops, new_contraction_list
# Eager execution mode
def build_expression_eager(_, expr):
    """Build an eager tensorflow function based on ``expr``.

    The first positional argument is ignored. The returned callable converts
    its inputs with ``to_tensorflow``, contracts them through
    ``expr._contract`` using the tensorflow backend and returns the result of
    calling ``.numpy()`` on the output.
    """

    def tensorflow_eager_contract(*arrays):
        tensors = [to_tensorflow(x) for x in arrays]
        contracted = expr._contract(tensors, backend='tensorflow')
        return contracted.numpy()

    return tensorflow_eager_contract
def evaluate_constants_eager(const_arrays, expr):
    """Convert the constant arguments with ``to_tensorflow`` and let ``expr``
    perform any possible constant contractions.

    Returns whatever ``expr`` returns when called with the converted inputs,
    ``backend='tensorflow'`` and ``evaluate_constants=True``.
    """
    tensors = [to_tensorflow(x) for x in const_arrays]
    return expr(*tensors, backend='tensorflow', evaluate_constants=True)
# Dispatch to eager or graph mode
def build_expression(arrays, expr):
    """Build a contraction function for ``expr``, delegating to
    ``build_expression_eager`` when eager execution was detected and to
    ``build_expression_graph`` otherwise.
    """
    eager = _get_tensorflow_and_device()[2]

    if eager:
        return build_expression_eager(arrays, expr)
    return build_expression_graph(arrays, expr)
def evaluate_constants(const_arrays, expr):
    """Evaluate the constant parts of ``expr``, delegating to
    ``evaluate_constants_eager`` when eager execution was detected and to
    ``evaluate_constants_graph`` otherwise.
    """
    eager = _get_tensorflow_and_device()[2]

    if eager:
        return evaluate_constants_eager(const_arrays, expr)
    return evaluate_constants_graph(const_arrays, expr)