129 lines
3.4 KiB
Python
129 lines
3.4 KiB
Python
"""
|
|
Required functions for optimized contractions of numpy arrays using pytorch.
|
|
"""
|
|
|
|
import numpy as np
|
|
|
|
from ..parser import convert_to_valid_einsum_chars
|
|
from ..sharing import to_backend_cache_wrap
|
|
|
|
__all__ = [
    "transpose",
    "einsum",
    "tensordot",
    "to_torch",
    "build_expression",
    "evaluate_constants",
]
|
|
|
|
# Filled in by ``_get_torch_and_device`` on its first call with a
# ``(torch, device)`` tuple; ``None`` means torch has not been imported yet.
_TORCH_DEVICE = None

# Filled in alongside ``_TORCH_DEVICE``: whether the imported torch module
# has a ``tensordot`` attribute.
_TORCH_HAS_TENSORDOT = None

# Symbols handed out, in order, when ``tensordot`` builds an einsum equation.
_torch_symbols_base = 'abcdefghijklmnopqrstuvwxyz' + 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
|
|
|
|
|
|
def _get_torch_and_device():
    """Return the cached ``(torch, device)`` pair, importing torch on first use.

    ``device`` is ``'cuda'`` when ``torch.cuda.is_available()`` reports a GPU
    and ``'cpu'`` otherwise. The first call also records in
    ``_TORCH_HAS_TENSORDOT`` whether the imported torch has ``tensordot``.
    """
    global _TORCH_DEVICE, _TORCH_HAS_TENSORDOT

    if _TORCH_DEVICE is not None:
        return _TORCH_DEVICE

    # Deferred import: merely importing this module does not import torch.
    import torch

    _TORCH_DEVICE = (torch, 'cuda' if torch.cuda.is_available() else 'cpu')
    _TORCH_HAS_TENSORDOT = hasattr(torch, 'tensordot')
    return _TORCH_DEVICE
|
|
|
|
|
|
def transpose(a, axes):
    """Reorder the dimensions of tensor ``a`` into the order given by ``axes``.

    ``torch.transpose`` only swaps two dimensions, so a numpy-style transpose
    with an arbitrary axis order is expressed through ``Tensor.permute``.
    """
    order = tuple(axes)
    return a.permute(*order)
|
|
|
|
|
|
def einsum(equation, *operands):
    """Evaluate ``equation`` with ``torch.einsum`` using numpy's call signature.

    Like ``numpy.einsum``, the operands arrive as separate positional
    arguments; they are forwarded to ``torch.einsum`` together as one tuple.
    """
    # PyTorch 0.4.1 and earlier only accept the symbols a-z, so the
    # equation's symbols are renamed before dispatching.
    safe_equation = convert_to_valid_einsum_chars(equation)

    torch_module = _get_torch_and_device()[0]
    return torch_module.einsum(safe_equation, operands)
|
|
|
|
|
|
def tensordot(x, y, axes=2):
    """Contract tensors ``x`` and ``y`` with ``numpy.tensordot`` semantics.

    If the imported torch provides ``tensordot``, the call is passed straight
    through with ``axes`` as ``dims``. Otherwise the contraction is rewritten
    as an einsum equation and evaluated with :func:`einsum`.

    Parameters
    ----------
    x, y : torch.Tensor
    axes : int or pair, optional
        An int ``n`` contracts the last ``n`` axes of ``x`` with the first
        ``n`` axes of ``y``. A pair lists the axes of ``x`` and of ``y`` to
        contract, each side being a single int or a sequence of ints.

    Returns
    -------
    torch.Tensor
        The free axes of ``x`` followed by the free axes of ``y``, each in
        their original order.
    """
    torch, _ = _get_torch_and_device()

    if _TORCH_HAS_TENSORDOT:
        return torch.tensordot(x, y, dims=axes)

    ndim_x = x.ndimension()
    ndim_y = y.ndimension()

    # An int ``n`` means: last ``n`` axes of x against the first ``n`` of y.
    if isinstance(axes, int):
        axes = range(ndim_x - axes, ndim_x), range(axes)

    # Promote a bare int on either side to a one-element tuple.
    if isinstance(axes[0], int):
        axes = (axes[0], ), axes[1]
    if isinstance(axes[1], int):
        axes = axes[0], (axes[1], )

    symbols = iter(_torch_symbols_base)
    x_ix = [None] * ndim_x
    y_ix = [None] * ndim_y

    # Each contracted pair of axes shares a single symbol.
    for ax_x, ax_y in zip(*axes):
        x_ix[ax_x] = y_ix[ax_y] = next(symbols)

    # Every remaining axis gets a fresh symbol that also appears in the
    # output: x's free axes first, then y's, each in their original order.
    out_ix = []
    for ix in (x_ix, y_ix):
        for pos, sym in enumerate(ix):
            if sym is None:
                ix[pos] = next(symbols)
                out_ix.append(ix[pos])

    equation = "".join(x_ix) + "," + "".join(y_ix) + "->" + "".join(out_ix)
    return einsum(equation, x, y)
|
|
|
|
|
|
@to_backend_cache_wrap
def to_torch(array):
    """Convert a numpy array into a torch tensor on the backend's device.

    Parameters
    ----------
    array : numpy.ndarray or any
        ``numpy.ndarray`` instances are converted with ``torch.from_numpy``
        and moved with ``.to(device)`` to the device reported by
        ``_get_torch_and_device`` (``'cuda'`` if available, else ``'cpu'``).
        Any other object is returned unchanged.

    Returns
    -------
    torch.Tensor or any
        The converted tensor, or ``array`` itself if it is not an ndarray.
    """
    torch, device = _get_torch_and_device()

    if isinstance(array, np.ndarray):
        # ``torch.from_numpy`` raises ValueError for arrays with any negative
        # stride (e.g. views from ``x[::-1]`` or ``np.flip``). Such arrays are
        # copied into a fresh C-ordered array first; all other arrays are
        # handed to ``torch.from_numpy`` directly.
        if any(stride < 0 for stride in array.strides):
            array = array.copy()
        return torch.from_numpy(array).to(device)

    return array
|
|
|
|
|
|
def build_expression(_, expr):  # pragma: no cover
    """Wrap ``expr`` in a function that evaluates it with the torch backend.

    The first argument is unused here. The returned ``torch_contract`` accepts
    the operands positionally, converts each with ``to_torch``, evaluates
    ``expr._contract(..., backend='torch')`` and returns the result as a numpy
    array, moving it to host memory first when it is not on the CPU.
    """
    def torch_contract(*arrays):
        tensors = [to_torch(arr) for arr in arrays]
        result = expr._contract(tensors, backend='torch')

        # ``.numpy()`` needs a CPU tensor, so results on another device are
        # copied back first.
        host_result = result if result.device.type == 'cpu' else result.cpu()
        return host_result.numpy()

    return torch_contract
|
|
|
|
|
|
def evaluate_constants(const_arrays, expr):
    """Convert constant operands to torch and let ``expr`` pre-contract them.

    Each entry of ``const_arrays`` is passed through ``to_torch``. ``expr`` is
    then called with the converted operands plus ``backend='torch'`` and
    ``evaluate_constants=True``, so it can perform any possible constant
    contractions. Its return value is passed back unchanged.
    """
    return expr(
        *(to_torch(const) for const in const_arrays),
        backend='torch',
        evaluate_constants=True,
    )
|