3RNN/Lib/site-packages/opt_einsum/backends/torch.py
2024-05-26 19:49:15 +02:00

129 lines
3.4 KiB
Python

"""
Required functions for optimized contractions of numpy arrays using pytorch.
"""
import numpy as np
from ..parser import convert_to_valid_einsum_chars
from ..sharing import to_backend_cache_wrap
# Public entry points of this backend module.
__all__ = [
    "transpose",
    "einsum",
    "tensordot",
    "to_torch",
    "build_expression",
    "evaluate_constants",
]
# Populated lazily by _get_torch_and_device(): a ``(torch_module, device)``
# pair, plus a flag recording whether that torch build has ``tensordot``.
_TORCH_DEVICE = _TORCH_HAS_TENSORDOT = None
# Subscript alphabet drawn from when tensordot falls back to an einsum string.
_torch_symbols_base = 'abcdefghijklmnopqrstuvwxyz' + 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
def _get_torch_and_device():
    """Return the cached ``(torch, device)`` pair, importing torch on first call.

    ``device`` is ``'cuda'`` when ``torch.cuda.is_available()`` is true and
    ``'cpu'`` otherwise. The first call also stores in ``_TORCH_HAS_TENSORDOT``
    whether the imported torch exposes ``torch.tensordot``. Later calls return
    the cached tuple unchanged.
    """
    global _TORCH_DEVICE, _TORCH_HAS_TENSORDOT
    if _TORCH_DEVICE is not None:
        return _TORCH_DEVICE

    # Deferred import: torch is only loaded when this helper is first called.
    import torch

    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    _TORCH_DEVICE = (torch, device)
    _TORCH_HAS_TENSORDOT = hasattr(torch, 'tensordot')
    return _TORCH_DEVICE
def transpose(a, axes):
    """Reorder the dimensions of tensor ``a`` by the full permutation ``axes``.

    ``torch.transpose`` only swaps a single pair of dimensions, so the
    numpy-style permutation is delegated to ``Tensor.permute`` instead.
    """
    permutation = tuple(axes)
    return a.permute(*permutation)
def einsum(equation, *operands):
    """Evaluate ``equation`` over ``operands`` using ``torch.einsum``.

    Operands are accepted variadically (numpy calling convention) and are
    forwarded to torch as a single tuple.
    """
    # PyTorch 0.4.1 and earlier only accept the symbols a-z, so the
    # equation is first remapped with convert_to_valid_einsum_chars.
    safe_equation = convert_to_valid_einsum_chars(equation)
    torch_mod, _ = _get_torch_and_device()
    return torch_mod.einsum(safe_equation, operands)
def tensordot(x, y, axes=2):
    """Contract tensors ``x`` and ``y`` over ``axes`` (numpy tensordot syntax).

    ``axes`` may be an int ``N`` (the last ``N`` axes of ``x`` against the
    first ``N`` axes of ``y``) or a pair whose items are each an int or a
    sequence of axis positions. When the installed torch provides
    ``torch.tensordot`` the call is forwarded to it. Otherwise the contraction
    is rewritten as an einsum equation.
    """
    torch, _ = _get_torch_and_device()
    if _TORCH_HAS_TENSORDOT:
        return torch.tensordot(x, y, dims=axes)

    x_ndim = x.ndimension()
    y_ndim = y.ndimension()

    # Normalise ``axes`` to a pair of axis sequences.
    if isinstance(axes, int):
        axes = (range(x_ndim - axes, x_ndim), range(axes))
    if isinstance(axes[0], int):
        axes = ((axes[0],), axes[1])
    if isinstance(axes[1], int):
        axes = (axes[0], (axes[1],))

    symbols = iter(_torch_symbols_base)
    x_subs = [None] * x_ndim
    y_subs = [None] * y_ndim

    # Each contracted pair of axes shares one symbol between the operands.
    for x_ax, y_ax in zip(*axes):
        shared = next(symbols)
        x_subs[x_ax] = y_subs[y_ax] = shared

    # Every remaining (free) axis gets a fresh symbol and appears in the
    # output: x's free axes first, then y's, each in positional order.
    out_subs = []
    for subs in (x_subs, y_subs):
        for pos, sym in enumerate(subs):
            if sym is None:
                fresh = next(symbols)
                subs[pos] = fresh
                out_subs.append(fresh)

    equation = "{},{}->{}".format(
        "".join(x_subs), "".join(y_subs), "".join(out_subs)
    )
    return einsum(equation, x, y)
@to_backend_cache_wrap
def to_torch(array):
    """Convert a numpy array to a torch tensor on the backend's device.

    Parameters
    ----------
    array : numpy.ndarray or any
        Input operand. Only ``numpy.ndarray`` instances are converted; any
        other object (e.g. an existing torch tensor) is returned unchanged.

    Returns
    -------
    torch.Tensor or the original object
        For numpy input: ``torch.from_numpy(array)`` moved to the device chosen
        by ``_get_torch_and_device()``.
    """
    torch, device = _get_torch_and_device()
    if isinstance(array, np.ndarray):
        # torch.from_numpy raises ValueError for arrays with negative strides
        # (e.g. ``a[::-1]``). Take a positively-strided copy only in that
        # case, so the common path still shares memory with the numpy array.
        if any(stride < 0 for stride in array.strides):
            array = array.copy()
        return torch.from_numpy(array).to(device)
    return array
def build_expression(_, expr):  # pragma: no cover
    """Return a callable that evaluates ``expr`` with torch on numpy inputs.

    The returned ``torch_contract(*arrays)`` converts each input with
    ``to_torch``, runs ``expr._contract`` with the torch backend, and hands the
    result back as a numpy array.
    """
    def torch_contract(*arrays):
        result = expr._contract([to_torch(arr) for arr in arrays], backend='torch')
        # A result on a non-CPU device must be copied to host memory first,
        # because ``.numpy()`` is only called on a CPU tensor here.
        host_result = result if result.device.type == 'cpu' else result.cpu()
        return host_result.numpy()
    return torch_contract
def evaluate_constants(const_arrays, expr):
    """Move constant operands to torch and pre-contract what ``expr`` can.

    Each entry of ``const_arrays`` is converted with ``to_torch``. ``expr`` is
    then invoked on them with ``backend='torch'`` and
    ``evaluate_constants=True``, and its result is returned unchanged.
    """
    torch_constants = tuple(to_torch(arr) for arr in const_arrays)
    return expr(*torch_constants, backend='torch', evaluate_constants=True)