3RNN/Lib/site-packages/opt_einsum/backends/dispatch.py

146 lines
4.3 KiB
Python
Raw Permalink Normal View History

2024-05-26 19:49:15 +02:00
"""
Handles dispatching array operations to the correct backend library, as well
as converting arrays to backend formats and then potentially storing them as
constants.
"""
import importlib
import numpy
from . import object_arrays
from . import cupy as _cupy
from . import jax as _jax
from . import tensorflow as _tensorflow
from . import theano as _theano
from . import torch as _torch
__all__ = [
    "get_func",
    "has_einsum",
    "has_tensordot",
    "build_expression",
    "evaluate_constants",
    "has_backend",
]

# Backends whose array functions do not live in the top-level package.
# Maps the backend name a caller passes to the module that is actually
# imported when resolving ``{backend}.{func}``.
_aliases = dict(
    dask='dask.array',
    theano='theano.tensor',
    torch='opt_einsum.backends.torch',
    jax='jax.numpy',
    autograd='autograd.numpy',
    mars='mars.tensor',
)
def _import_func(func, backend, default=None):
    """Try and import ``{backend}.{func}``.

    The module imported is ``_aliases.get(backend, backend)``. This means
    aliased backends are resolved from their array submodule; for example,
    ``'jax'`` is resolved from ``jax.numpy``.

    If the module provides ``func``, return it. If it does not and
    ``default`` is not ``None``, return ``default``. Otherwise raise an
    ``AttributeError`` that points at the backend documentation; the
    original lookup error is chained as its cause.

    An ``ImportError`` raised while importing the backend module itself is
    NOT caught, even when ``default`` is given.
    """
    lib = importlib.import_module(_aliases.get(backend, backend))
    if default is not None:
        return getattr(lib, func, default)
    # Guard only the attribute lookup. An AttributeError raised while
    # importing the backend must not be misreported as a missing function.
    try:
        return getattr(lib, func)
    except AttributeError as err:
        error_msg = ("{} doesn't seem to provide the function {} - see "
                     "https://optimized-einsum.readthedocs.io/en/latest/backends.html "
                     "for details on which functions are required for which contractions.")
        raise AttributeError(error_msg.format(backend, func)) from err
# Memo of resolved ``(func, backend) -> callable`` pairs, filled lazily by
# ``get_func``. It is a hand-rolled dict because python2 lacked
# ``functools.lru_cache``. Other libraries are added on demand. Numpy is
# pre-populated, and so is the arbitrary ``object`` backend, which borrows
# numpy's tensordot/transpose but uses its own einsum implementation.
_cached_funcs = {
    (name, 'numpy'): getattr(numpy, name)
    for name in ('tensordot', 'transpose', 'einsum')
}
_cached_funcs.update({
    ('tensordot', 'object'): numpy.tensordot,
    ('transpose', 'object'): numpy.transpose,
    ('einsum', 'object'): object_arrays.object_einsum,
})
def get_func(func, backend='numpy', default=None):
    """Look up ``{backend}.{func}``, for example ``numpy.einsum``.

    Results are memoised in ``_cached_funcs``, keyed on ``(func, backend)``.
    On a miss, the function is resolved with ``_import_func``, which may
    return ``default`` when the backend lacks ``func``.

    ``default`` is not part of the cache key, so whatever is stored first
    for a pair is returned on later calls. Errors from ``_import_func``
    propagate and leave the cache unchanged.
    """
    key = (func, backend)
    try:
        found = _cached_funcs[key]
    except KeyError:
        found = _cached_funcs[key] = _import_func(func, backend, default)
    return found
# Record which libs provide einsum; for those that don't, tensordot/transpose are used as much as possible.
_has_einsum = {}


def has_einsum(backend):
    """Return whether ``{backend}.einsum`` can be resolved.

    The probe goes through ``get_func``, and its outcome is memoised per
    backend in ``_has_einsum``. Only ``AttributeError`` counts as "missing".
    Any other error (for instance an ``ImportError`` for a backend that is
    not installed) propagates, and nothing is cached.
    """
    cached = _has_einsum.get(backend)
    if cached is not None:
        return cached
    try:
        get_func('einsum', backend)
    except AttributeError:
        available = False
    else:
        available = True
    _has_einsum[backend] = available
    return available
_has_tensordot = {}


def has_tensordot(backend):
    """Return whether ``{backend}.tensordot`` can be resolved.

    The answer comes from probing ``get_func`` once per backend and is then
    served from ``_has_tensordot``. Only ``AttributeError`` is treated as
    "not provided". Other exceptions propagate and nothing is cached.
    """
    if backend not in _has_tensordot:
        try:
            get_func('tensordot', backend)
        except AttributeError:
            _has_tensordot[backend] = False
        else:
            _has_tensordot[backend] = True
    return _has_tensordot[backend]
# Dispatch to the correct expression backend.
# These are the backends that support explicit to-and-from numpy conversion.
# Each backend module supplies a ``build_expression`` hook and an
# ``evaluate_constants`` hook; both tables below are keyed by backend name.
_conversion_modules = (
    ('tensorflow', _tensorflow),
    ('theano', _theano),
    ('cupy', _cupy),
    ('torch', _torch),
    ('jax', _jax),
)
CONVERT_BACKENDS = {
    name: module.build_expression
    for name, module in _conversion_modules
}
EVAL_CONSTS_BACKENDS = {
    name: module.evaluate_constants
    for name, module in _conversion_modules
}
# Only the two public tables are meant to remain at module level.
del _conversion_modules
def build_expression(backend, arrays, expr):
    """Build an expression, based on ``expr`` and initial arrays ``arrays``,
    that evaluates using backend ``backend``.

    ``backend`` is matched case-insensitively against ``CONVERT_BACKENDS``,
    exactly as ``has_backend`` checks it. The call is delegated to that
    backend's ``build_expression(arrays, expr)``. An unknown backend name
    raises ``KeyError``.
    """
    # Lower-case to agree with ``has_backend``, which callers may use as a
    # guard: any name it accepts must also be found here.
    return CONVERT_BACKENDS[backend.lower()](arrays, expr)
def evaluate_constants(backend, arrays, expr):
    """Convert constant arrays to the correct backend, and perform as much of
    the contraction of ``expr`` with these as possible.

    ``backend`` is matched case-insensitively against
    ``EVAL_CONSTS_BACKENDS``, exactly as ``has_backend`` checks it. The
    backend's ``evaluate_constants(arrays, expr)`` result is returned
    unchanged. An unknown backend name raises ``KeyError``.
    """
    # Lower-case to agree with ``has_backend``, which callers may use as a
    # guard: any name it accepts must also be found here.
    return EVAL_CONSTS_BACKENDS[backend.lower()](arrays, expr)
def has_backend(backend):
    """Return ``True`` if ``backend`` names a known conversion backend.

    The check is case-insensitive: ``backend`` is lower-cased before being
    looked up among the keys of ``CONVERT_BACKENDS``.
    """
    normalized = backend.lower()
    return normalized in CONVERT_BACKENDS