3RNN/Lib/site-packages/opt_einsum/helpers.py

284 lines
7.5 KiB
Python
Raw Normal View History

2024-05-26 19:49:15 +02:00
"""
Contains helper functions for opt_einsum testing scripts
"""
from collections import OrderedDict
import numpy as np
from .parser import get_symbol
# Names exported by a star-import of this module.  ``rand_equation`` is
# defined further down but is deliberately left out of this list as-is.
__all__ = [
    "build_views",
    "compute_size_by_dict",
    "find_contraction",
    "flop_count",
]
# Default index characters used when no dimension dictionary is supplied.
# Note that 'n' is absent from this alphabet.
_valid_chars = "abcdefghijklmopqABC"
# One size per character above, matched by position.
_sizes = np.array([2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3, 2, 5, 7, 4, 3, 2, 3, 4])
# Character -> size lookup table built by pairing the two sequences.
_default_dim_dict = dict(zip(_valid_chars, _sizes))
def build_views(string, dimension_dict=None):
    """
    Builds random numpy arrays for testing.

    Parameters
    ----------
    string : str or list of str
        Either an einsum subscript string such as ``'ab,bc->ac'`` (only the
        input terms, left of ``'->'``, are used), or a list of individual
        term strings such as ``['ab', 'bc']``.
    dimension_dict : dictionary, optional
        Dictionary of index sizes. When omitted, the module's built-in
        default size table is used.

    Returns
    -------
    ret : list of np.ndarray's
        The resulting views, one per input term, filled with
        ``np.random.rand`` values.

    Examples
    --------
    >>> view = build_views('abbc', {'a': 2, 'b':3, 'c':5})
    >>> view[0].shape
    (2, 3, 3, 5)
    >>> view = build_views(['abbc'], {'a': 2, 'b':3, 'c':5})
    >>> view[0].shape
    (2, 3, 3, 5)
    """
    if dimension_dict is None:
        dimension_dict = _default_dim_dict

    if isinstance(string, str):
        # Einsum equation: keep only the comma-separated input terms.
        terms = string.split('->')[0].split(',')
    else:
        # Already a sequence of individual term strings (the documented form).
        terms = list(string)

    return [np.random.rand(*[dimension_dict[x] for x in term]) for term in terms]
def compute_size_by_dict(indices, idx_dict):
    """
    Multiply together the sizes that ``idx_dict`` assigns to each entry of
    ``indices``.

    Parameters
    ----------
    indices : iterable
        Indices to look up; an index that occurs several times contributes
        its size once per occurrence.
    idx_dict : dictionary
        Mapping of each index to its size.

    Returns
    -------
    ret : int
        The product of the looked-up sizes (1 when ``indices`` is empty).

    Examples
    --------
    >>> compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5})
    90
    """
    total = 1
    for size in (idx_dict[ix] for ix in indices):  # lgtm [py/iteration-string-and-sequence]
        total *= size
    return total
def find_contraction(positions, input_sets, output_set):
    """
    Finds the contraction for a given set of input and output sets.

    Parameters
    ----------
    positions : iterable
        Integer positions of terms used in the contraction.
    input_sets : list
        List of sets (``set`` or ``frozenset``) that represent the lhs side
        of the einsum subscript. The caller's list is not modified.
    output_set : set
        Set that represents the rhs side of the overall einsum subscript

    Returns
    -------
    new_result : set
        The indices of the resulting contraction
    remaining : list
        List of sets that have not been contracted, the new set is appended to
        the end of this list
    idx_removed : set
        Indices removed from the entire contraction
    idx_contraction : set
        The indices used in the current contraction

    Examples
    --------

    # A simple dot product test case
    >>> pos = (0, 1)
    >>> isets = [set('ab'), set('bc')]
    >>> oset = set('ac')
    >>> find_contraction(pos, isets, oset)
    ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'})

    # A more complex case with additional terms in the contraction
    >>> pos = (0, 2)
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set('ac')
    >>> find_contraction(pos, isets, oset)
    ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'})
    """
    # Work on a copy so the caller's list is left untouched.
    remaining = list(input_sets)
    # Pop from the highest position down so lower positions stay valid.
    inputs = (remaining.pop(i) for i in sorted(positions, reverse=True))
    # ``set().union`` instead of the unbound ``set.union``: the latter raises
    # TypeError for frozenset inputs and for an empty ``positions``.
    idx_contract = set().union(*inputs)
    # Indices still needed by the output or by any uncontracted term.
    idx_remain = output_set.union(*remaining)
    new_result = idx_remain & idx_contract
    idx_removed = idx_contract - new_result
    remaining.append(new_result)
    return new_result, remaining, idx_removed, idx_contract
def flop_count(idx_contraction, inner, num_terms, size_dictionary):
    """
    Computes the number of FLOPS in the contraction.

    The count is ``prod(sizes of idx_contraction) * op_factor``, where
    ``op_factor = max(1, num_terms - 1)``, plus one more if ``inner``.

    Parameters
    ----------
    idx_contraction : iterable
        The indices involved in the contraction
    inner : bool
        Does this contraction require an inner product?
    num_terms : int
        The number of terms in a contraction
    size_dictionary : dict
        The size of each of the indices in idx_contraction

    Returns
    -------
    flop_count : int
        The total number of FLOPS required for the contraction.

    Examples
    --------
    >>> flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5})
    30

    >>> flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5})
    60
    """
    overall_size = compute_size_by_dict(idx_contraction, size_dictionary)
    # Combining num_terms operands takes num_terms - 1 operations per element,
    # but always count at least one.
    op_factor = max(1, num_terms - 1)
    if inner:
        # An inner product adds one extra operation per element (the summation).
        op_factor += 1
    return overall_size * op_factor
def rand_equation(n, reg, n_out=0, d_min=2, d_max=9, seed=None, global_dim=False, return_size_dict=False):
    """Generate a random contraction and shapes.

    Parameters
    ----------
    n : int
        Number of array arguments.
    reg : int
        'Regularity' of the contraction graph. This essentially determines how
        many indices each tensor shares with others on average.
    n_out : int, optional
        Number of output indices (i.e. the number of non-contracted indices).
        Defaults to 0, i.e., a contraction resulting in a scalar.
    d_min : int, optional
        Minimum dimension size.
    d_max : int, optional
        Maximum dimension size.
    seed : int, optional
        If not None, seed numpy's global random generator with this.
    global_dim : bool, optional
        Add a global, 'broadcast', dimension to every operand.
    return_size_dict : bool, optional
        Return the mapping of indices to sizes.

    Returns
    -------
    eq : str
        The equation string.
    shapes : list[tuple[int]]
        The array shapes.
    size_dict : dict[str, int]
        The dict of index sizes, only returned if ``return_size_dict=True``.

    Raises
    ------
    ValueError
        If ``n < 2`` while ``n * reg // 2 > 0`` contracted (bond) indices
        are requested. Each bond must be placed on two different operands,
        so with a single operand the placement loop could never terminate.

    Examples
    --------
    >>> eq, shapes = rand_equation(n=10, reg=4, n_out=5, seed=42)
    >>> eq
    'oyeqn,tmaq,skpo,vg,hxui,n,fwxmr,hitplcj,kudlgfv,rywjsb->cebda'

    >>> shapes
    [(9, 5, 4, 5, 4),
     (4, 4, 8, 5),
     (9, 4, 6, 9),
     (6, 6),
     (6, 9, 7, 8),
     (4,),
     (9, 3, 9, 4, 9),
     (6, 8, 4, 6, 8, 6, 3),
     (4, 7, 8, 8, 6, 9, 6),
     (9, 5, 3, 3, 9, 5)]
    """
    # number of contracted indices; each one is placed on two operands
    num_bonds = n * reg // 2
    # Reject impossible requests up front (before touching the RNG): with a
    # single operand the second copy of a bond has nowhere to go and the
    # "no traces" retry loop below would spin forever.
    if n < 2 and num_bonds > 0:
        raise ValueError(
            "rand_equation needs at least 2 operands to place {} bond(s), got n={}".format(num_bonds, n)
        )

    if seed is not None:
        np.random.seed(seed)

    # total number of indices
    num_inds = num_bonds + n_out
    inputs = ["" for _ in range(n)]
    output = []

    size_dict = OrderedDict((get_symbol(i), np.random.randint(d_min, d_max + 1)) for i in range(num_inds))

    # generate a list of indices to place either once or twice
    def gen():
        for i, ix in enumerate(size_dict):
            # generate an outer index (placed once, kept in the output)
            if i < n_out:
                output.append(ix)
                yield ix
            # generate a bond (placed twice, contracted away)
            else:
                yield ix
                yield ix

    # add the indices randomly to the inputs
    for i, ix in enumerate(np.random.permutation(list(gen()))):
        # make sure all inputs have at least one index (when there are at
        # least n indices in total)
        if i < n:
            inputs[i] += ix
        else:
            # don't add any traces on same op
            where = np.random.randint(0, n)
            while ix in inputs[where]:
                where = np.random.randint(0, n)
            inputs[where] += ix

    # possibly add the same global dim to every arg
    if global_dim:
        gdim = get_symbol(num_inds)
        size_dict[gdim] = np.random.randint(d_min, d_max + 1)
        for i in range(n):
            inputs[i] += gdim
        output.append(gdim)

    # randomly transpose the output indices and form equation
    output = "".join(np.random.permutation(output))
    eq = "{}->{}".format(",".join(inputs), output)

    # make the shapes
    shapes = [tuple(size_dict[ix] for ix in op) for op in inputs]

    ret = (eq, shapes)

    if return_size_dict:
        ret += (size_dict, )

    return ret