Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/_src/third_party/numpy/linalg.py

import numpy as np

import jax.numpy as jnp
import jax.numpy.linalg as la
from jax._src.numpy.util import check_arraylike, _wraps


def _isEmpty2d(arr):
  # check size first for efficiency
  return arr.size == 0 and np.prod(arr.shape[-2:]) == 0
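
# Illustrative note (not part of the original module): only emptiness in the
# trailing matrix dimensions counts, so an empty batch of well-formed matrices
# still passes. Assuming:
#
#   >>> _isEmpty2d(jnp.zeros((3, 0, 2)))   # last-2-dims product is 0
#   True
#   >>> _isEmpty2d(jnp.zeros((0, 2, 2)))   # empty batch, but 2x2 matrices
#   False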


def _assertNoEmpty2d(*arrays):
  for a in arrays:
    if _isEmpty2d(a):
      raise np.linalg.LinAlgError("Arrays cannot be empty")


def _assertRankAtLeast2(*arrays):
  for a in arrays:
    if a.ndim < 2:
      raise np.linalg.LinAlgError(
          '%d-dimensional array given. Array must be '
          'at least two-dimensional' % a.ndim)


def _assertNdSquareness(*arrays):
  for a in arrays:
    m, n = a.shape[-2:]
    if m != n:
      raise np.linalg.LinAlgError(
          'Last 2 dimensions of the array must be square')


def _assert2d(*arrays):
  for a in arrays:
    if a.ndim != 2:
      raise ValueError(f'{a.ndim}-dimensional array given. '
                       'Array must be two-dimensional')


@_wraps(np.linalg.cond)
def cond(x, p=None):
  check_arraylike('jnp.linalg.cond', x)
  _assertNoEmpty2d(x)
  if p in (None, 2):
    s = la.svd(x, compute_uv=False)
    return s[..., 0] / s[..., -1]
  elif p == -2:
    s = la.svd(x, compute_uv=False)
    r = s[..., -1] / s[..., 0]
  else:
    _assertRankAtLeast2(x)
    _assertNdSquareness(x)
    invx = la.inv(x)
    r = la.norm(x, ord=p, axis=(-2, -1)) * la.norm(invx, ord=p, axis=(-2, -1))

  # Convert nans to infs unless the original array had nan entries
  orig_nan_check = jnp.full_like(r, ~jnp.isnan(r).any())
  nan_mask = jnp.logical_and(jnp.isnan(r), ~jnp.isnan(x).any(axis=(-2, -1)))
  r = jnp.where(orig_nan_check, jnp.where(nan_mask, jnp.inf, r), r)
  return r
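
# Illustrative usage sketch (not part of the original module): for the default
# p=None the condition number is the ratio of the largest to the smallest
# singular value. Assuming a simple diagonal matrix:
#
#   >>> x = jnp.array([[4.0, 0.0], [0.0, 2.0]])
#   >>> cond(x)            # singular values are [4, 2], so 4 / 2 = 2.0
#   >>> cond(x, p='fro')   # ||x||_F * ||inv(x)||_F = sqrt(20)*sqrt(0.3125) = 2.5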


@_wraps(np.linalg.tensorinv)
def tensorinv(a, ind=2):
  check_arraylike('jnp.linalg.tensorinv', a)
  a = jnp.asarray(a)
  oldshape = a.shape
  prod = 1
  if ind > 0:
    invshape = oldshape[ind:] + oldshape[:ind]
    for k in oldshape[ind:]:
      prod *= k
  else:
    raise ValueError("Invalid ind argument.")
  a = a.reshape(prod, -1)
  ia = la.inv(a)
  return ia.reshape(*invshape)
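
# Illustrative usage sketch (not part of the original module): with ind=2 the
# first two axes index "rows" and the remaining axes index "columns" of a
# flattened square matrix, which is inverted and reshaped back with the two
# axis groups swapped. Assuming a (4*6) x (8*3) = 24 x 24 flattening:
#
#   >>> a = jnp.eye(24).reshape(4, 6, 8, 3)
#   >>> tensorinv(a, ind=2).shape
#   (8, 3, 4, 6)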


@_wraps(np.linalg.tensorsolve)
def tensorsolve(a, b, axes=None):
  check_arraylike('jnp.linalg.tensorsolve', a, b)
  a = jnp.asarray(a)
  b = jnp.asarray(b)
  an = a.ndim
  if axes is not None:
    allaxes = list(range(0, an))
    for k in axes:
      allaxes.remove(k)
      allaxes.insert(an, k)
    a = a.transpose(allaxes)

  Q = a.shape[-(an - b.ndim):]
  prod = 1
  for k in Q:
    prod *= k

  a = a.reshape(-1, prod)
  b = b.ravel()
  res = jnp.asarray(la.solve(a, b))
  res = res.reshape(Q)
  return res
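
# Illustrative usage sketch (not part of the original module): tensorsolve
# solves a x = b for x by flattening the trailing an - b.ndim axes of `a`
# into the "unknowns" dimension Q. Assuming shapes (2, 3, 6) and (2, 3):
#
#   >>> a = jnp.eye(6).reshape(2, 3, 6)
#   >>> b = jnp.ones((2, 3))
#   >>> tensorsolve(a, b).shape    # Q = a.shape[-1:] = (6,)
#   (6,)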


@_wraps(np.linalg.multi_dot)
def multi_dot(arrays, *, precision=None):
  check_arraylike('jnp.linalg.multi_dot', *arrays)
  n = len(arrays)
  # optimization only makes sense for len(arrays) > 2
  if n < 2:
    raise ValueError("Expecting at least two arrays.")
  elif n == 2:
    return jnp.dot(arrays[0], arrays[1], precision=precision)

  arrays = [jnp.asarray(a) for a in arrays]

  # save original ndim to reshape the result array into the proper form later
  ndim_first, ndim_last = arrays[0].ndim, arrays[-1].ndim
  # Explicitly convert vectors to 2D arrays to keep the logic of the internal
  # _multi_dot_* functions as simple as possible.
  if arrays[0].ndim == 1:
    arrays[0] = jnp.atleast_2d(arrays[0])
  if arrays[-1].ndim == 1:
    arrays[-1] = jnp.atleast_2d(arrays[-1]).T
  _assert2d(*arrays)

  # _multi_dot_three is much faster than _multi_dot_matrix_chain_order
  if n == 3:
    result = _multi_dot_three(*arrays, precision)
  else:
    order = _multi_dot_matrix_chain_order(arrays)
    result = _multi_dot(arrays, order, 0, n - 1, precision)

  # return proper shape
  if ndim_first == 1 and ndim_last == 1:
    return result[0, 0]  # scalar
  elif ndim_first == 1 or ndim_last == 1:
    return result.ravel()  # 1-D
  else:
    return result
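
# Illustrative usage sketch (not part of the original module): multi_dot picks
# the cheapest parenthesization before multiplying. Assuming a chain whose
# intermediate product is small, the savings are large:
#
#   >>> A = jnp.ones((10, 100))
#   >>> B = jnp.ones((100, 5))
#   >>> C = jnp.ones((5, 50))
#   >>> multi_dot([A, B, C]).shape   # computed as (A @ B) @ C, ~7500 mults
#   (10, 50)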


def _multi_dot_three(A, B, C, precision):
  """
  Find the best order for three arrays and do the multiplication.

  For three arguments `_multi_dot_three` is approximately 15 times faster
  than `_multi_dot_matrix_chain_order`.
  """
  a0, a1b0 = A.shape
  b1c0, c1 = C.shape
  # cost1 = cost((AB)C) = a0*a1b0*b1c0 + a0*b1c0*c1
  cost1 = a0 * b1c0 * (a1b0 + c1)
  # cost2 = cost(A(BC)) = a1b0*b1c0*c1 + a0*a1b0*c1
  cost2 = a1b0 * c1 * (a0 + b1c0)

  if cost1 < cost2:
    return jnp.dot(jnp.dot(A, B, precision=precision), C, precision=precision)
  else:
    return jnp.dot(A, jnp.dot(B, C, precision=precision), precision=precision)
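
# Illustrative arithmetic sketch (not part of the original module): with the
# example shapes used in _multi_dot_matrix_chain_order below, A (10x100),
# B (100x5), C (5x50):
#
#   cost1 = 10 * 5 * (100 + 50) = 7500    # (AB)C
#   cost2 = 100 * 50 * (10 + 5) = 75000   # A(BC)
#
# so the (AB)C grouping is chosen.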


def _multi_dot_matrix_chain_order(arrays, return_costs=False):
  """
  Return a jnp.array that encodes the optimal order of multiplications.

  The optimal order array is then used by `_multi_dot()` to do the
  multiplication.

  Also return the cost matrix if `return_costs` is `True`.

  The implementation CLOSELY follows Cormen, "Introduction to Algorithms",
  Chapter 15.2, p. 370-378.  Note that Cormen uses 1-based indices.

      cost[i, j] = min([
          cost[prefix] + cost[suffix] + cost_mult(prefix, suffix)
          for k in range(i, j)])
  """
  n = len(arrays)
  # p stores the dimensions of the matrices
  # Example for p: A_{10x100}, B_{100x5}, C_{5x50} --> p = [10, 100, 5, 50]
  p = [a.shape[0] for a in arrays] + [arrays[-1].shape[1]]
  # m is a matrix of costs of the subproblems
  # m[i, j]: min number of scalar multiplications needed to compute A_{i..j}
  m = np.zeros((n, n), dtype=np.double)
  # s is the actual ordering
  # s[i, j] is the value of k at which we split the product A_i..A_j
  s = np.empty((n, n), dtype=np.intp)

  for l in range(1, n):
    for i in range(n - l):
      j = i + l
      m[i, j] = jnp.inf
      for k in range(i, j):
        q = m[i, k] + m[k+1, j] + p[i]*p[k+1]*p[j+1]
        if q < m[i, j]:
          m[i, j] = q
          s[i, j] = k  # Note that Cormen uses 1-based index

  return (s, m) if return_costs else s
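
# Illustrative sketch (not part of the original module): for the docstring
# example p = [10, 100, 5, 50] (three matrices), the DP fills in
#
#   m[0, 1] = 10*100*5 = 5000,   m[1, 2] = 100*5*50 = 25000,
#   m[0, 2] = min(m[1, 2] + 10*100*50, m[0, 1] + 10*5*50) = 7500,  s[0, 2] = 1
#
# i.e. the chain is split after the second matrix: (A B) C.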


def _multi_dot(arrays, order, i, j, precision):
  """Actually do the multiplication with the given order."""
  if i == j:
    return arrays[i]
  else:
    return jnp.dot(_multi_dot(arrays, order, i, order[i, j], precision),
                   _multi_dot(arrays, order, order[i, j] + 1, j, precision),
                   precision=precision)