"""
Contains helper functions for opt_einsum testing scripts
"""

from collections import OrderedDict

import numpy as np

from .parser import get_symbol

__all__ = ["build_views", "compute_size_by_dict", "find_contraction", "flop_count"]

_valid_chars = "abcdefghijklmopqABC"
_sizes = np.array([2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3, 2, 5, 7, 4, 3, 2, 3, 4])
_default_dim_dict = {c: s for c, s in zip(_valid_chars, _sizes)}


def build_views(string, dimension_dict=None):
    """
    Builds random numpy arrays for testing.

    Parameters
    ----------
    string : str
        Einsum subscript string; a view is built for each comma-separated
        term on its left-hand side.
    dimension_dict : dictionary
        Dictionary of index sizes

    Returns
    -------
    ret : list of np.ndarray
        The resulting views.

    Examples
    --------
    >>> view = build_views('abbc', {'a': 2, 'b': 3, 'c': 5})
    >>> view[0].shape
    (2, 3, 3, 5)

    """

    if dimension_dict is None:
        dimension_dict = _default_dim_dict

    views = []
    terms = string.split('->')[0].split(',')
    for term in terms:
        dims = [dimension_dict[x] for x in term]
        views.append(np.random.rand(*dims))
    return views


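# Illustrative usage (a sketch, not part of the public API): views built from
# a full equation line up with the comma-separated terms on its left-hand
# side, so they can be passed straight to ``np.einsum``.  With the default
# dimension dict, 'a' = 2, 'b' = 3 and 'c' = 4:
#
#     >>> views = build_views('ab,bc->ac')
#     >>> [v.shape for v in views]
#     [(2, 3), (3, 4)]
#     >>> np.einsum('ab,bc->ac', *views).shape
#     (2, 4)

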
def compute_size_by_dict(indices, idx_dict):
    """
    Computes the product of the elements in indices based on the dictionary
    idx_dict.

    Parameters
    ----------
    indices : iterable
        Indices to base the product on.
    idx_dict : dictionary
        Dictionary of index sizes

    Returns
    -------
    ret : int
        The resulting product.

    Examples
    --------
    >>> compute_size_by_dict('abbc', {'a': 2, 'b': 3, 'c': 5})
    90

    """
    ret = 1
    for i in indices:  # lgtm [py/iteration-string-and-sequence]
        ret *= idx_dict[i]
    return ret


def find_contraction(positions, input_sets, output_set):
    """
    Finds the contraction for a given set of input and output sets.

    Parameters
    ----------
    positions : iterable
        Integer positions of terms used in the contraction.
    input_sets : list
        List of sets that represent the left-hand side of the einsum subscript
    output_set : set
        Set that represents the right-hand side of the overall einsum subscript

    Returns
    -------
    new_result : set
        The indices of the resulting contraction
    remaining : list
        List of sets that have not been contracted; the new set is appended to
        the end of this list
    idx_removed : set
        Indices removed from the entire contraction
    idx_contraction : set
        The indices used in the current contraction

    Examples
    --------

    # A simple dot product test case
    >>> pos = (0, 1)
    >>> isets = [set('ab'), set('bc')]
    >>> oset = set('ac')
    >>> find_contraction(pos, isets, oset)
    ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'})

    # A more complex case with additional terms in the contraction
    >>> pos = (0, 2)
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set('ac')
    >>> find_contraction(pos, isets, oset)
    ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'})
    """

    remaining = list(input_sets)
    inputs = (remaining.pop(i) for i in sorted(positions, reverse=True))
    idx_contract = set.union(*inputs)
    idx_remain = output_set.union(*remaining)

    new_result = idx_remain & idx_contract
    idx_removed = (idx_contract - new_result)
    remaining.append(new_result)

    return new_result, remaining, idx_removed, idx_contract


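# Illustrative note (a sketch, assuming the example equation 'ab,bc,cd->ad'):
# contracting positions (0, 1) pops the first two terms, forms the
# intermediate over {'a', 'c'}, and appends it to the remaining terms, so the
# reduced problem reads 'cd,ac->ad' with index 'b' removed:
#
#     >>> isets = [set('ab'), set('bc'), set('cd')]
#     >>> new_result, remaining, idx_removed, idx_contract = \
#     ...     find_contraction((0, 1), isets, set('ad'))
#     >>> sorted(new_result), sorted(idx_removed)
#     (['a', 'c'], ['b'])

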
def flop_count(idx_contraction, inner, num_terms, size_dictionary):
    """
    Computes the number of FLOPS in the contraction.

    Parameters
    ----------
    idx_contraction : iterable
        The indices involved in the contraction
    inner : bool
        Does this contraction require an inner product?
    num_terms : int
        The number of terms in a contraction
    size_dictionary : dict
        The size of each of the indices in idx_contraction

    Returns
    -------
    flop_count : int
        The total number of FLOPS required for the contraction.

    Examples
    --------

    >>> flop_count('abc', False, 1, {'a': 2, 'b': 3, 'c': 5})
    30

    >>> flop_count('abc', True, 2, {'a': 2, 'b': 3, 'c': 5})
    60

    """

    overall_size = compute_size_by_dict(idx_contraction, size_dictionary)
    op_factor = max(1, num_terms - 1)
    if inner:
        op_factor += 1

    return overall_size * op_factor


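# Illustrative note (a sketch): when costing a single pairwise contraction,
# the helpers above combine naturally: ``find_contraction`` identifies the
# indices involved and whether any are summed away, and ``flop_count`` prices
# them.  For a matrix multiply 'ab,bc->ac' with sizes {'a': 2, 'b': 3, 'c': 5}:
#
#     >>> _, _, idx_removed, idx_contract = \
#     ...     find_contraction((0, 1), [set('ab'), set('bc')], set('ac'))
#     >>> flop_count(idx_contract, bool(idx_removed), 2, {'a': 2, 'b': 3, 'c': 5})
#     60

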
def rand_equation(n, reg, n_out=0, d_min=2, d_max=9, seed=None, global_dim=False, return_size_dict=False):
    """Generate a random contraction and shapes.

    Parameters
    ----------
    n : int
        Number of array arguments.
    reg : int
        'Regularity' of the contraction graph. This essentially determines how
        many indices each tensor shares with others on average.
    n_out : int, optional
        Number of output indices (i.e. the number of non-contracted indices).
        Defaults to 0, i.e., a contraction resulting in a scalar.
    d_min : int, optional
        Minimum dimension size.
    d_max : int, optional
        Maximum dimension size.
    seed : int, optional
        If not None, seed numpy's random generator with this.
    global_dim : bool, optional
        Add a global, 'broadcast', dimension to every operand.
    return_size_dict : bool, optional
        Return the mapping of indices to sizes.

    Returns
    -------
    eq : str
        The equation string.
    shapes : list[tuple[int]]
        The array shapes.
    size_dict : dict[str, int]
        The dict of index sizes, only returned if ``return_size_dict=True``.

    Examples
    --------
    >>> eq, shapes = rand_equation(n=10, reg=4, n_out=5, seed=42)
    >>> eq
    'oyeqn,tmaq,skpo,vg,hxui,n,fwxmr,hitplcj,kudlgfv,rywjsb->cebda'

    >>> shapes
    [(9, 5, 4, 5, 4),
     (4, 4, 8, 5),
     (9, 4, 6, 9),
     (6, 6),
     (6, 9, 7, 8),
     (4,),
     (9, 3, 9, 4, 9),
     (6, 8, 4, 6, 8, 6, 3),
     (4, 7, 8, 8, 6, 9, 6),
     (9, 5, 3, 3, 9, 5)]
    """

    if seed is not None:
        np.random.seed(seed)

    # total number of indices
    num_inds = n * reg // 2 + n_out
    inputs = ["" for _ in range(n)]
    output = []

    size_dict = OrderedDict((get_symbol(i), np.random.randint(d_min, d_max + 1)) for i in range(num_inds))

    # generate a list of indices to place either once or twice
    def gen():
        for i, ix in enumerate(size_dict):
            # generate an outer index
            if i < n_out:
                output.append(ix)
                yield ix
            # generate a bond
            else:
                yield ix
                yield ix

    # add the indices randomly to the inputs
    for i, ix in enumerate(np.random.permutation(list(gen()))):
        # make sure all inputs have at least one index
        if i < n:
            inputs[i] += ix
        else:
            # don't add any traces on the same op
            where = np.random.randint(0, n)
            while ix in inputs[where]:
                where = np.random.randint(0, n)

            inputs[where] += ix

    # possibly add the same global dim to every arg
    if global_dim:
        gdim = get_symbol(num_inds)
        size_dict[gdim] = np.random.randint(d_min, d_max + 1)
        for i in range(n):
            inputs[i] += gdim
        output += gdim

    # randomly transpose the output indices and form equation
    output = "".join(np.random.permutation(output))
    eq = "{}->{}".format(",".join(inputs), output)

    # make the shapes
    shapes = [tuple(size_dict[ix] for ix in op) for op in inputs]

    ret = (eq, shapes)

    if return_size_dict:
        ret += (size_dict, )

    return ret
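

# Illustrative usage (a sketch): the generated equation and shapes can be
# turned into concrete operands for testing a contraction path, e.g. with
# ``opt_einsum.contract``:
#
#     >>> eq, shapes = rand_equation(n=4, reg=2, n_out=2, seed=0)
#     >>> arrays = [np.random.rand(*shape) for shape in shapes]
#     >>> import opt_einsum as oe
#     >>> out = oe.contract(eq, *arrays)  # output dims follow the rhs of ``eq``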