# 3RNN/Lib/site-packages/opt_einsum/backends/dispatch.py
# 2024-05-26 19:49:15 +02:00
#
# 146 lines
# 4.3 KiB
# Python

"""
Handles dispatching array operations to the correct backend library, as well
as converting arrays to backend formats and then potentially storing them as
constants.
"""
import importlib
import numpy
from . import object_arrays
from . import cupy as _cupy
from . import jax as _jax
from . import tensorflow as _tensorflow
from . import theano as _theano
from . import torch as _torch
__all__ = [
    "get_func",
    "has_einsum",
    "has_tensordot",
    "build_expression",
    "evaluate_constants",
    "has_backend",
]

# Backends whose functions are not imported from a top-level package of the
# same name: maps the backend name to the module that should be imported
# instead. Note that 'torch' resolves to the module bundled with opt_einsum
# (opt_einsum.backends.torch), not to the ``torch`` package itself.
_aliases = dict(
    dask='dask.array',
    theano='theano.tensor',
    torch='opt_einsum.backends.torch',
    jax='jax.numpy',
    autograd='autograd.numpy',
    mars='mars.tensor',
)
def _import_func(func, backend, default=None):
    """Try and import ``{backend}.{func}``.

    The backend name is first mapped through ``_aliases``. This lets a library
    whose array API lives in a submodule (e.g. ``'jax'`` -> ``jax.numpy``) be
    imported from the right place.

    If library is installed and func is found, return the func;
    otherwise if default is provided, return default;
    otherwise raise an error.

    Raises
    ------
    ImportError
        If the backend module itself cannot be imported (not caught here).
    AttributeError
        If the module lacks ``func`` and no ``default`` was given.
    """
    # Import outside the ``try``. Otherwise an AttributeError raised while the
    # backend module initialises would be misreported as a missing function.
    lib = importlib.import_module(_aliases.get(backend, backend))

    if default is not None:
        # NOTE: a ``default`` of None is indistinguishable from "no default".
        return getattr(lib, func, default)

    try:
        return getattr(lib, func)
    except AttributeError as err:
        error_msg = ("{} doesn't seem to provide the function {} - see "
                     "https://optimized-einsum.readthedocs.io/en/latest/backends.html "
                     "for details on which functions are required for which contractions.")
        raise AttributeError(error_msg.format(backend, func)) from err
# Hand-rolled memo of resolved backend functions, keyed by
# ``(func_name, backend_name)``. A plain dict is used because python2 lacked
# ``functools.lru_cache``. Further entries are filled in by ``get_func``.
# The table is seeded with numpy and with the generic ``object`` backend,
# which reuses numpy's tensordot/transpose but has its own einsum.
_cached_funcs = {
    (func_name, backend_name): fn
    for backend_name, einsum_fn in (
        ('numpy', numpy.einsum),
        ('object', object_arrays.object_einsum),
    )
    for func_name, fn in (
        ('tensordot', numpy.tensordot),
        ('transpose', numpy.transpose),
        ('einsum', einsum_fn),
    )
}
def get_func(func, backend='numpy', default=None):
    """Return ``{backend}.{func}``, e.g. ``numpy.einsum``,
    or a default func if provided. Cache result.

    Only functions actually resolved on the backend are cached. A fallback
    ``default`` is returned but never stored, because the cache key is just
    ``(func, backend)``.

    Raises
    ------
    AttributeError
        If the backend lacks ``func`` and no ``default`` was given.
    """
    try:
        return _cached_funcs[func, backend]
    except KeyError:
        pass

    fn = _import_func(func, backend, default)
    # Caching a fallback would make every later lookup of (func, backend),
    # including the default-less calls in has_einsum/has_tensordot, treat
    # the fallback as if the backend itself provided ``func``. If the backend
    # genuinely exposes the very same object as ``default``, skipping the
    # cache only costs a repeated (sys.modules-cached) import.
    if default is None or fn is not default:
        _cached_funcs[func, backend] = fn
    return fn
# Per-backend memo of whether ``einsum`` is available. Backends without it
# fall back to tensordot/transpose as much as possible.
_has_einsum = {}


def has_einsum(backend):
    """Report whether ``{backend}.einsum`` can be resolved via ``get_func``.

    The answer is memoised per backend in ``_has_einsum``. An
    ``AttributeError`` from ``get_func`` yields ``False``; any other
    exception it raises propagates to the caller.
    """
    if backend not in _has_einsum:
        try:
            get_func('einsum', backend)
        except AttributeError:
            found = False
        else:
            found = True
        _has_einsum[backend] = found
    return _has_einsum[backend]
# Per-backend memo of whether ``tensordot`` is available.
_has_tensordot = {}


def has_tensordot(backend):
    """Report whether ``{backend}.tensordot`` can be resolved via ``get_func``.

    The answer is memoised per backend in ``_has_tensordot``. An
    ``AttributeError`` from ``get_func`` yields ``False``; any other
    exception it raises propagates to the caller.
    """
    if backend not in _has_tensordot:
        try:
            get_func('tensordot', backend)
        except AttributeError:
            found = False
        else:
            found = True
        _has_tensordot[backend] = found
    return _has_tensordot[backend]
# Dispatch tables for backends that support explicit conversion to and from
# numpy. Each backend module provides two hooks:
#   CONVERT_BACKENDS[name]     -> module.build_expression
#   EVAL_CONSTS_BACKENDS[name] -> module.evaluate_constants
# Both tables are built from the same backend list, so they always agree.
CONVERT_BACKENDS, EVAL_CONSTS_BACKENDS = (
    {
        name: getattr(module, hook)
        for name, module in (
            ('tensorflow', _tensorflow),
            ('theano', _theano),
            ('cupy', _cupy),
            ('torch', _torch),
            ('jax', _jax),
        )
    }
    for hook in ('build_expression', 'evaluate_constants')
)
def build_expression(backend, arrays, expr):
    """Build an expression, based on ``expr`` and initial arrays ``arrays``,
    that evaluates using backend ``backend``.

    The backend name is matched case-insensitively, the same way
    ``has_backend`` matches it. A name accepted by ``has_backend`` is therefore
    always dispatchable here.

    Raises
    ------
    KeyError
        If ``backend`` has no entry in ``CONVERT_BACKENDS``.
    """
    return CONVERT_BACKENDS[backend.lower()](arrays, expr)
def evaluate_constants(backend, arrays, expr):
    """Convert constant arrays to the correct backend, and perform as much of
    the contraction of ``expr`` with these as possible.

    The backend name is matched case-insensitively, the same way
    ``has_backend`` matches it.

    Raises
    ------
    KeyError
        If ``backend`` has no entry in ``EVAL_CONSTS_BACKENDS``.
    """
    return EVAL_CONSTS_BACKENDS[backend.lower()](arrays, expr)
def has_backend(backend):
    """Tell whether ``backend`` (compared case-insensitively) is one of the
    numpy-convertible backends registered in ``CONVERT_BACKENDS``.
    """
    normalized_name = backend.lower()
    return normalized_name in CONVERT_BACKENDS