"""
Required functions for optimized contractions of numpy arrays using PyTorch.
"""

import numpy as np

from ..parser import convert_to_valid_einsum_chars
from ..sharing import to_backend_cache_wrap

__all__ = ["transpose", "einsum", "tensordot", "to_torch", "build_expression", "evaluate_constants"]

_TORCH_DEVICE = None
_TORCH_HAS_TENSORDOT = None

# symbols used to build equations in the tensordot fallback; the einsum
# wrapper below remaps whatever appears here to characters valid for the
# installed torch version
_torch_symbols_base = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'


def _get_torch_and_device():
    global _TORCH_DEVICE
    global _TORCH_HAS_TENSORDOT

    if _TORCH_DEVICE is None:
        import torch
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        _TORCH_DEVICE = torch, device
        _TORCH_HAS_TENSORDOT = hasattr(torch, 'tensordot')

    return _TORCH_DEVICE


def transpose(a, axes):
    """Normal torch transpose is only valid for 2D matrices.
    """
    return a.permute(*axes)


def einsum(equation, *operands):
    """Variadic version of torch.einsum to match numpy api.
    """
    # rename symbols to support PyTorch 0.4.1 and earlier,
    # which only accept the lowercase symbols a-z.
    equation = convert_to_valid_einsum_chars(equation)

    torch, _ = _get_torch_and_device()
    # pass operands as a single sequence; this form is accepted both by
    # old torch versions (which required a list) and by current ones
    return torch.einsum(equation, operands)


def tensordot(x, y, axes=2):
    """Simple translation of tensordot syntax to einsum.
    """
    torch, _ = _get_torch_and_device()

    if _TORCH_HAS_TENSORDOT:
        return torch.tensordot(x, y, dims=axes)

    xnd = x.ndimension()
    ynd = y.ndimension()

    # convert an int argument into an explicit pair of axis sequences
    if isinstance(axes, int):
        axes = range(xnd - axes, xnd), range(axes)
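        # e.g. axes=2 with xnd == 4 gives (range(2, 4), range(0, 2)):
        # the last two axes of x contract with the first two axes of y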

    # convert a bare (int, int) pair to ((int,), (int,))
    if isinstance(axes[0], int):
        axes = (axes[0], ), axes[1]
    if isinstance(axes[1], int):
        axes = axes[0], (axes[1], )

    # initialize empty indices
    x_ix = [None] * xnd
    y_ix = [None] * ynd
    out_ix = []

    # fill in repeated indices
    available_ix = iter(_torch_symbols_base)
    for ax1, ax2 in zip(*axes):
        repeat = next(available_ix)
        x_ix[ax1] = repeat
        y_ix[ax2] = repeat

    # fill in the rest, and maintain output order
    for i in range(xnd):
        if x_ix[i] is None:
            leave = next(available_ix)
            x_ix[i] = leave
            out_ix.append(leave)
    for i in range(ynd):
        if y_ix[i] is None:
            leave = next(available_ix)
            y_ix[i] = leave
            out_ix.append(leave)

    # form full string and contract!
    einsum_str = "{},{}->{}".format(*map("".join, (x_ix, y_ix, out_ix)))
    return einsum(einsum_str, x, y)


@to_backend_cache_wrap
def to_torch(array):
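    """Convert a numpy array to a ``torch`` tensor on the selected device
    (CUDA if available, otherwise CPU); non-numpy inputs pass through
    unchanged. Conversions are cached by ``to_backend_cache_wrap`` when a
    sharing context is active.
    """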
    torch, device = _get_torch_and_device()

    if isinstance(array, np.ndarray):
        return torch.from_numpy(array).to(device)

    return array


def build_expression(_, expr):  # pragma: no cover
    """Build a torch function based on ``arrays`` and ``expr``.
    """
    def torch_contract(*arrays):
        torch_arrays = [to_torch(x) for x in arrays]
        torch_out = expr._contract(torch_arrays, backend='torch')

        if torch_out.device.type == 'cpu':
            return torch_out.numpy()

        return torch_out.cpu().numpy()

    return torch_contract


def evaluate_constants(const_arrays, expr):
    """Convert constant arguments to torch, and perform any possible constant
    contractions.
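
    Illustrative sketch (hypothetical arrays ``b``, ``c``; this path is
    reached via the ``constants=`` option of
    ``opt_einsum.contract_expression``)::

        import opt_einsum as oe
        expr = oe.contract_expression('ij,jk,kl->il', (2, 3), b, c,
                                      constants=[1, 2])
        # the constant jk,kl->jl part is contracted once on torch and cached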
    """
    const_arrays = [to_torch(x) for x in const_arrays]
    return expr(*const_arrays, backend='torch', evaluate_constants=True)