2023-06-19 00:49:18 +02:00

284 lines
7.5 KiB

"""Contains helper functions for opt_einsum testing scripts."""
from collections import OrderedDict
import numpy as np
from .parser import get_symbol
__all__ = ["build_views", "compute_size_by_dict", "find_contraction", "flop_count", "rand_equation"]

# Default index labels and their sizes, paired positionally; both sequences
# must have the same length. Used by ``build_views`` when no
# ``dimension_dict`` is supplied.
_valid_chars = "abcdefghijklmopqABC"
_sizes = np.array([2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3, 2, 5, 7, 4, 3, 2, 3, 4])
_default_dim_dict = dict(zip(_valid_chars, _sizes))
def build_views(string, dimension_dict=None):
    """
    Builds random numpy arrays for testing.

    Parameters
    ----------
    string : str or list of str
        Either an einsum subscript string such as ``"ab,bc->ac"`` (anything
        after ``->`` is ignored) or a list of individual tensor terms such as
        ``["ab", "bc"]``.
    dimension_dict : dictionary, optional
        Dictionary of index sizes. Defaults to the module-level
        ``_default_dim_dict``.

    Returns
    -------
    ret : list of np.ndarray's
        The resulting views: one ``np.random.rand`` array per input term, whose
        shape is the size of each index of the term looked up in
        ``dimension_dict``.

    Examples
    --------
    >>> view = build_views('abbc', {'a': 2, 'b':3, 'c':5})
    >>> view[0].shape
    (2, 3, 3, 5)
    >>> view = build_views(['abbc'], {'a': 2, 'b':3, 'c':5})
    >>> view[0].shape
    (2, 3, 3, 5)
    """
    if dimension_dict is None:
        dimension_dict = _default_dim_dict

    if isinstance(string, str):
        # Only the input side of an equation describes the operands.
        terms = string.split('->')[0].split(',')
    else:
        # Already a sequence of individual terms.
        terms = list(string)

    views = []
    for term in terms:
        dims = [dimension_dict[x] for x in term]
        views.append(np.random.rand(*dims))
    return views
def compute_size_by_dict(indices, idx_dict):
    """
    Multiply together the sizes of ``indices`` as recorded in ``idx_dict``.

    Parameters
    ----------
    indices : iterable
        Index labels to look up; a label that occurs several times contributes
        its size once per occurrence.
    idx_dict : dictionary
        Mapping from each index label to its size.

    Returns
    -------
    ret : int
        The product of the looked-up sizes, or ``1`` when ``indices`` is empty.

    Examples
    --------
    >>> compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5})
    90
    """
    product = 1
    for label in indices:  # lgtm [py/iteration-string-and-sequence]
        product *= idx_dict[label]
    return product
def find_contraction(positions, input_sets, output_set):
    """
    Finds the contraction for a given set of input and output sets.

    Parameters
    ----------
    positions : iterable
        Integer positions of terms used in the contraction.
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript.
        Not modified; a copy is made.
    output_set : set
        Set that represents the rhs side of the overall einsum subscript.

    Returns
    -------
    new_result : set
        The indices of the resulting contraction.
    remaining : list
        List of sets that have not been contracted; the new set is appended to
        the end of this list.
    idx_removed : set
        Indices removed from the entire contraction.
    idx_contraction : set
        The indices used in the current contraction.

    Examples
    --------
    # A simple dot product test case
    >>> pos = (0, 1)
    >>> isets = [set('ab'), set('bc')]
    >>> oset = set('ac')
    >>> find_contraction(pos, isets, oset)
    ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'})

    # A more complex case with additional terms in the contraction
    >>> pos = (0, 2)
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set('ac')
    >>> find_contraction(pos, isets, oset)
    ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'})
    """
    remaining = list(input_sets)
    # Pop from the highest position down so the lower positions stay valid.
    inputs = [remaining.pop(i) for i in sorted(positions, reverse=True)]
    # ``set().union`` also accepts frozensets and an empty ``positions``.
    idx_contract = set().union(*inputs)
    # Indices still needed later: the final output plus every untouched term.
    idx_remain = output_set.union(*remaining)

    new_result = idx_remain & idx_contract
    idx_removed = idx_contract - new_result
    remaining.append(new_result)

    return new_result, remaining, idx_removed, idx_contract
def flop_count(idx_contraction, inner, num_terms, size_dictionary):
    """
    Estimate the number of FLOPS needed for a single contraction.

    Parameters
    ----------
    idx_contraction : iterable
        Every index that takes part in the contraction.
    inner : bool
        Whether the contraction sums over at least one index (inner product).
    num_terms : int
        How many terms are combined by the contraction.
    size_dictionary : dict
        Size of each index appearing in ``idx_contraction``.

    Returns
    -------
    flop_count : int
        Product of all index sizes times a per-element operation factor.

    Examples
    --------
    >>> flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5})
    30
    >>> flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5})
    60
    """
    total_elements = compute_size_by_dict(idx_contraction, size_dictionary)
    # ``num_terms - 1`` ops per element, but never fewer than one; an inner
    # product costs one extra op per element.
    factor = max(1, num_terms - 1) + (1 if inner else 0)
    return total_elements * factor
def rand_equation(n, reg, n_out=0, d_min=2, d_max=9, seed=None, global_dim=False, return_size_dict=False):
    """Generate a random contraction and shapes.

    Parameters
    ----------
    n : int
        Number of array arguments.
    reg : int
        'Regularity' of the contraction graph. This essentially determines how
        many indices each tensor shares with others on average.
    n_out : int, optional
        Number of output indices (i.e. the number of non-contracted indices).
        Defaults to 0, i.e., a contraction resulting in a scalar.
    d_min : int, optional
        Minimum dimension size.
    d_max : int, optional
        Maximum dimension size (inclusive).
    seed: int, optional
        If not None, seed numpy's global random generator with this.
    global_dim : bool, optional
        Add a global, 'broadcast', dimension to every operand.
    return_size_dict : bool, optional
        Return the mapping of indices to sizes.

    Returns
    -------
    eq : str
        The equation string.
    shapes : list[tuple[int]]
        The array shapes.
    size_dict : dict[str, int]
        The dict of index sizes, only returned if ``return_size_dict=True``.

    Raises
    ------
    ValueError
        If ``n < 2`` while ``n * reg // 2 > 0`` bonds must be placed. A bond
        must join two different operands, so placement could never finish.

    Examples
    --------
    >>> eq, shapes = rand_equation(n=10, reg=4, n_out=5, seed=42)
    >>> len(shapes)
    10
    >>> len(eq.split('->')[1])
    5
    """
    # Each bond fills two operand slots, giving roughly ``reg`` bond indices
    # per operand on average.
    num_bonds = n * reg // 2
    if num_bonds > 0 and n < 2:
        raise ValueError("rand_equation needs at least two operands to place bonds "
                         "(got n={}, reg={})".format(n, reg))

    if seed is not None:
        np.random.seed(seed)

    # total number of indices
    num_inds = num_bonds + n_out
    inputs = ["" for _ in range(n)]

    size_dict = OrderedDict((get_symbol(i), np.random.randint(d_min, d_max + 1)) for i in range(num_inds))

    # The first ``n_out`` symbols are outer (output) indices and are placed
    # once; every remaining symbol is a bond and is placed twice.
    symbols = list(size_dict)
    output = symbols[:n_out]
    pool = output + [ix for ix in symbols[n_out:] for _ in range(2)]

    # add the indices randomly to the inputs
    for i, ix in enumerate(np.random.permutation(pool)):
        if i < n:
            # make sure all inputs have at least one index
            inputs[i] += ix
        else:
            # don't add any traces on same op
            where = np.random.randint(0, n)
            while ix in inputs[where]:
                where = np.random.randint(0, n)
            inputs[where] += ix

    # possibly add the same global dim to every arg
    if global_dim:
        gdim = get_symbol(num_inds)
        size_dict[gdim] = np.random.randint(d_min, d_max + 1)
        for i in range(n):
            inputs[i] += gdim
        output.append(gdim)

    # randomly transpose the output indices and form equation
    output = "".join(np.random.permutation(output))
    eq = "{}->{}".format(",".join(inputs), output)

    # make the shapes
    shapes = [tuple(size_dict[ix] for ix in op) for op in inputs]

    ret = (eq, shapes)

    if return_size_dict:
        ret += (size_dict, )

    return ret