Intelegentny_Pszczelarz/.venv/Lib/site-packages/opt_einsum/tests/test_blas.py

106 lines
4.0 KiB
Python
Raw Normal View History

2023-06-19 00:49:18 +02:00
"""
Tests the BLAS capability for the opt_einsum module.
"""
import numpy as np
import pytest
from opt_einsum import blas, contract, helpers
# Case table shared by the parametrized tests in this module.
#
# Every entry has the form
#     ((input_subscripts, output_subscripts, contracted_indices), expected)
# where ``expected`` is the value that ``blas.can_blas(*inputs)`` is compared
# against: a routine label string, or ``False``.
blas_tests = [
    # DOT
    ((["k", "k"], "", {"k"}), "DOT"),  # DDOT
    ((["ijk", "ijk"], "", {"i", "j", "k"}), "DOT"),  # DDOT
    # GEMV? (no cases listed)
    # GEMM
    ((["ij", "jk"], "ik", {"j"}), "GEMM"),  # GEMM N N
    ((["ijl", "jlk"], "ik", {"j", "l"}), "GEMM"),  # GEMM N N Tensor
    ((["ij", "kj"], "ik", {"j"}), "GEMM"),  # GEMM N T
    ((["ijl", "kjl"], "ik", {"j", "l"}), "GEMM"),  # GEMM N T Tensor
    ((["ji", "jk"], "ik", {"j"}), "GEMM"),  # GEMM T N
    ((["jli", "jlk"], "ik", {"j", "l"}), "GEMM"),  # GEMM T N Tensor
    ((["ji", "kj"], "ik", {"j"}), "GEMM"),  # GEMM T T
    ((["jli", "kjl"], "ik", {"j", "l"}), "GEMM"),  # GEMM T T Tensor
    # GEMM with final transpose (output written as 'ki' instead of 'ik')
    ((["ij", "jk"], "ki", {"j"}), "GEMM"),  # GEMM N N
    ((["ijl", "jlk"], "ki", {"j", "l"}), "GEMM"),  # GEMM N N Tensor
    ((["ij", "kj"], "ki", {"j"}), "GEMM"),  # GEMM N T
    ((["ijl", "kjl"], "ki", {"j", "l"}), "GEMM"),  # GEMM N T Tensor
    ((["ji", "jk"], "ki", {"j"}), "GEMM"),  # GEMM T N
    ((["jli", "jlk"], "ki", {"j", "l"}), "GEMM"),  # GEMM T N Tensor
    ((["ji", "kj"], "ki", {"j"}), "GEMM"),  # GEMM T T
    ((["jli", "kjl"], "ki", {"j", "l"}), "GEMM"),  # GEMM T T Tensor
    # Tensor Dot (requires copy), let's not deal with this for now
    ((["ilj", "jlk"], "ik", {"j", "l"}), "TDOT"),  # FT GEMM N N Tensor
    ((["ijl", "ljk"], "ik", {"j", "l"}), "TDOT"),  # ST GEMM N N Tensor
    ((["ilj", "kjl"], "ik", {"j", "l"}), "TDOT"),  # FT GEMM N T Tensor
    ((["ijl", "klj"], "ik", {"j", "l"}), "TDOT"),  # ST GEMM N T Tensor
    ((["lji", "jlk"], "ik", {"j", "l"}), "TDOT"),  # FT GEMM T N Tensor
    ((["jli", "ljk"], "ik", {"j", "l"}), "TDOT"),  # ST GEMM T N Tensor
    # NOTE(review): the next two entries are identical to the two T N entries
    # directly above; presumably T T variants were intended -- confirm.
    ((["lji", "jlk"], "ik", {"j", "l"}), "TDOT"),  # FT GEMM T N Tensor
    ((["jli", "ljk"], "ik", {"j", "l"}), "TDOT"),  # ST GEMM T N Tensor
    # Tensor Dot (requires copy), labelled "with transpose".
    # NOTE(review): the output is still 'ik' and the contracted set is equal to
    # the one used in the previous group (a set has no order), so these eight
    # entries repeat the eight above; presumably a 'ki' output was intended --
    # confirm.
    ((["ilj", "jlk"], "ik", {"l", "j"}), "TDOT"),  # FT GEMM N N Tensor
    ((["ijl", "ljk"], "ik", {"l", "j"}), "TDOT"),  # ST GEMM N N Tensor
    ((["ilj", "kjl"], "ik", {"l", "j"}), "TDOT"),  # FT GEMM N T Tensor
    ((["ijl", "klj"], "ik", {"l", "j"}), "TDOT"),  # ST GEMM N T Tensor
    ((["lji", "jlk"], "ik", {"l", "j"}), "TDOT"),  # FT GEMM T N Tensor
    ((["jli", "ljk"], "ik", {"l", "j"}), "TDOT"),  # ST GEMM T N Tensor
    ((["lji", "jlk"], "ik", {"l", "j"}), "TDOT"),  # FT GEMM T N Tensor
    ((["jli", "ljk"], "ik", {"l", "j"}), "TDOT"),  # ST GEMM T N Tensor
    # Other
    ((["ijk", "ikj"], "", {"i", "j", "k"}), "DOT/EINSUM"),  # Transpose DOT
    ((["i", "j"], "ij", set()), "OUTER/EINSUM"),  # Outer
    ((["ijk", "ik"], "j", {"i", "k"}), "GEMV/EINSUM"),  # Matrix-vector
    ((["ijj", "jk"], "ik", {"j"}), False),  # Double index
    ((["ijk", "j"], "ij", set()), False),  # Index sum 1
    ((["ij", "ij"], "ij", set()), False),  # Index sum 2
]
@pytest.mark.parametrize("inp,benchmark", blas_tests)
def test_can_blas(inp, benchmark):
    """Assert that ``blas.can_blas`` returns the expected value for a case.

    ``inp`` is unpacked positionally into ``blas.can_blas`` and the return
    value is compared for equality with ``benchmark`` (a label string or
    ``False``, as listed in ``blas_tests``).
    """
    assert blas.can_blas(*inp) == benchmark
@pytest.mark.parametrize("inp,benchmark", blas_tests)
def test_tensor_blas(inp, benchmark):
    """Compare ``blas.tensor_blas`` with ``np.einsum`` on the same operands.

    Cases whose expected value is ``False`` are not exercised: the test
    returns immediately for them, so they are reported as passing.
    """
    # Weed out non-blas cases
    if benchmark is False:
        return

    subscripts, result_subscripts, contracted = inp

    # Only binary operations should be here
    assert len(subscripts) == 2

    # Assemble the einsum expression, e.g. ['ij', 'jk'] and 'ik' -> 'ij,jk->ik'.
    expression = "{}->{}".format(",".join(subscripts), result_subscripts)

    # Operands for the expression come from the helper; the same pair is fed
    # to both the reference einsum call and the BLAS path.
    lhs, rhs = helpers.build_views(expression)

    reference = np.einsum(expression, lhs, rhs)
    candidate = blas.tensor_blas(
        lhs, subscripts[0], rhs, subscripts[1], result_subscripts, contracted
    )

    assert np.allclose(reference, candidate)
def test_blas_out():
    """Check the contents of the ``out=`` array after calling ``contract``.

    A two-operand and a three-operand matrix chain are contracted into the
    same pre-allocated buffer, and after each call the buffer is compared
    against the equivalent matrix product.
    """
    # Three random 4x4 operands, drawn in the order lhs, mid, rhs.
    lhs, mid, rhs = (np.random.rand(4, 4) for _ in range(3))
    # Uninitialised destination buffer, reused for both contractions.
    dest = np.empty((4, 4))

    contract("ij,jk->ik", lhs, mid, out=dest)
    assert np.allclose(dest, lhs @ mid)

    contract("ij,jk,kl->il", lhs, mid, rhs, out=dest)
    assert np.allclose(dest, lhs @ mid @ rhs)