Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/_src/numpy/linalg.py
2023-06-19 00:49:18 +02:00

677 lines
26 KiB
Python

# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import numpy as np
import textwrap
import operator
from typing import Literal, Optional, Tuple, Union, cast, overload
import jax
from jax import jit, custom_jvp
from jax import lax
from jax._src.lax import lax as lax_internal
from jax._src.lax import linalg as lax_linalg
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import reductions, ufuncs
from jax._src.numpy.util import _wraps, promote_dtypes_inexact, check_arraylike
from jax._src.util import canonicalize_axis
from jax._src.typing import ArrayLike, Array
def _H(x: ArrayLike) -> Array:
  """Return the conjugate of the matrix transpose of ``x``."""
  transposed = jnp.matrix_transpose(x)
  return ufuncs.conjugate(transposed)
def _symmetrize(x: Array) -> Array:
  """Return the average of ``x`` and its conjugate transpose ``_H(x)``."""
  conj_transposed = _H(x)
  return (x + conj_transposed) / 2
@_wraps(np.linalg.cholesky)
@jit
def cholesky(a: ArrayLike) -> Array:
check_arraylike("jnp.linalg.cholesky", a)
a, = promote_dtypes_inexact(jnp.asarray(a))
return lax_linalg.cholesky(a)
@overload
def svd(a: ArrayLike, full_matrices: bool = True, *, compute_uv: Literal[True],
hermitian: bool = False) -> Tuple[Array, Array, Array]: ...
@overload
def svd(a: ArrayLike, full_matrices: bool, compute_uv: Literal[True],
hermitian: bool = False) -> Tuple[Array, Array, Array]: ...
@overload
def svd(a: ArrayLike, full_matrices: bool = True, *, compute_uv: Literal[False],
hermitian: bool = False) -> Array: ...
@overload
def svd(a: ArrayLike, full_matrices: bool, compute_uv: Literal[False],
hermitian: bool = False) -> Array: ...
@overload
def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
hermitian: bool = False) -> Union[Array, Tuple[Array, Array, Array]]: ...
@_wraps(np.linalg.svd)
@partial(jit, static_argnames=('full_matrices', 'compute_uv', 'hermitian'))
def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
hermitian: bool = False) -> Union[Array, Tuple[Array, Array, Array]]:
check_arraylike("jnp.linalg.svd", a)
a, = promote_dtypes_inexact(jnp.asarray(a))
if hermitian:
w, v = lax_linalg.eigh(a)
s = lax.abs(v)
if compute_uv:
sign = lax.sign(v)
idxs = lax.broadcasted_iota(np.int64, s.shape, dimension=s.ndim - 1)
s, idxs, sign = lax.sort((s, idxs, sign), dimension=-1, num_keys=1)
s = lax.rev(s, dimensions=[s.ndim - 1])
idxs = lax.rev(idxs, dimensions=[s.ndim - 1])
sign = lax.rev(sign, dimensions=[s.ndim - 1])
u = jnp.take_along_axis(w, idxs[..., None, :], axis=-1)
vh = _H(u * sign[..., None, :].astype(u.dtype))
return u, s, vh
else:
return lax.rev(lax.sort(s, dimension=-1), dimensions=[s.ndim-1])
return lax_linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)
@_wraps(np.linalg.matrix_power)
@partial(jit, static_argnames=('n',))
def matrix_power(a: ArrayLike, n: int) -> Array:
check_arraylike("jnp.linalg.matrix_power", a)
arr, = promote_dtypes_inexact(jnp.asarray(a))
if arr.ndim < 2:
raise TypeError("{}-dimensional array given. Array must be at least "
"two-dimensional".format(arr.ndim))
if arr.shape[-2] != arr.shape[-1]:
raise TypeError("Last 2 dimensions of the array must be square")
try:
n = operator.index(n)
except TypeError as err:
raise TypeError(f"exponent must be an integer, got {n}") from err
if n == 0:
return jnp.broadcast_to(jnp.eye(arr.shape[-2], dtype=arr.dtype), arr.shape)
elif n < 0:
arr = inv(arr)
n = abs(n)
if n == 1:
return arr
elif n == 2:
return arr @ arr
elif n == 3:
return (arr @ arr) @ arr
z = result = None
while n > 0:
z = arr if z is None else (z @ z) # type: ignore[operator]
n, bit = divmod(n, 2)
if bit:
result = z if result is None else (result @ z)
assert result is not None
return result
@_wraps(np.linalg.matrix_rank)
@jit
def matrix_rank(M: ArrayLike, tol: Optional[ArrayLike] = None) -> Array:
check_arraylike("jnp.linalg.matrix_rank", M)
M, = promote_dtypes_inexact(jnp.asarray(M))
if M.ndim < 2:
return (M != 0).any().astype(jnp.int32)
S = svd(M, full_matrices=False, compute_uv=False)
if tol is None:
tol = S.max(-1) * np.max(M.shape[-2:]).astype(S.dtype) * jnp.finfo(S.dtype).eps
tol = jnp.expand_dims(tol, np.ndim(tol))
return reductions.sum(S > tol, axis=-1)
@custom_jvp
def _slogdet_lu(a: Array) -> Tuple[Array, Array]:
  # slogdet via LU decomposition: the determinant is the product of the
  # diagonal of the factorization, with a sign flip for every row swap.
  elem_dtype = lax.dtype(a)
  zero = jnp.array(0, dtype=elem_dtype)
  lu_factors, pivots, _ = lax_linalg.lu(a)
  u_diag = jnp.diagonal(lu_factors, axis1=-2, axis2=-1)
  # A zero on the diagonal means the matrix is singular.
  singular = reductions.any(u_diag == zero, axis=-1)
  # Each pivot entry that differs from its own index counts as one row swap.
  row_ids = lax.expand_dims(jnp.arange(a.shape[-1], dtype=pivots.dtype),
                            range(pivots.ndim - 1))
  swaps = reductions.count_nonzero(pivots != row_ids, axis=-1)
  if jnp.iscomplexobj(a):
    # Complex case: the "sign" is the product of the unit-modulus phases.
    phase = reductions.prod(u_diag / ufuncs.abs(u_diag).astype(u_diag.dtype),
                            axis=-1)
  else:
    # Real case: every negative diagonal entry is one more sign flip.
    phase = jnp.array(1, dtype=elem_dtype)
    swaps = swaps + reductions.count_nonzero(u_diag < 0, axis=-1)
  flip = jnp.array(-2 * (swaps % 2) + 1, dtype=elem_dtype)
  sign = jnp.where(singular, zero, phase * flip)
  log_abs_det = jnp.where(
      singular, jnp.array(-jnp.inf, dtype=elem_dtype),
      reductions.sum(ufuncs.log(ufuncs.abs(u_diag)).astype(elem_dtype), axis=-1))
  return sign, ufuncs.real(log_abs_det)
@custom_jvp
def _slogdet_qr(a: Array) -> Tuple[Array, Array]:
  # Implementation of slogdet using QR decomposition. One reason we might prefer
  # QR decomposition is that it is more amenable to a fast batched
  # implementation on TPU because of the lack of row pivoting.
  if jnp.issubdtype(lax.dtype(a), jnp.complexfloating):
    raise NotImplementedError("slogdet method='qr' not implemented for complex "
                              "inputs")
  order = a.shape[-1]
  packed, taus = lax_linalg.geqrf(a)

  # The determinant of a triangular matrix is the product of its diagonal
  # elements. We are working in log space, so the magnitude is the trace of
  # the log-absolute values, and the sign is computed separately.
  log_magnitudes = ufuncs.log(ufuncs.abs(packed))
  log_abs_det = jnp.trace(log_magnitudes, axis1=-2, axis2=-1)
  r_diag = jnp.diagonal(packed, axis1=-2, axis2=-1)
  sign_diag = reductions.prod(ufuncs.sign(r_diag), axis=-1)

  # The determinant of a Householder reflector is -1. So whenever a reflection
  # was actually made (tau != 0), the result is multiplied by -1.
  reflector_signs = jnp.where(taus[..., :(order - 1)] != 0, -1, 1)
  sign_taus = reductions.prod(reflector_signs, axis=-1).astype(sign_diag.dtype)
  return sign_diag * sign_taus, log_abs_det
@_wraps(
np.linalg.slogdet,
extra_params=textwrap.dedent("""
method: string, optional
One of ``lu`` or ``qr``, specifying whether the determinant should be
computed using an LU decomposition or a QR decomposition. Defaults to
LU decomposition if ``None``.
"""))
@partial(jit, static_argnames=('method',))
def slogdet(a: ArrayLike, *, method: Optional[str] = None) -> Tuple[Array, Array]:
check_arraylike("jnp.linalg.slogdet", a)
a, = promote_dtypes_inexact(jnp.asarray(a))
a_shape = jnp.shape(a)
if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
msg = "Argument to slogdet() must have shape [..., n, n], got {}"
raise ValueError(msg.format(a_shape))
if method is None or method == "lu":
return _slogdet_lu(a)
elif method == "qr":
return _slogdet_qr(a)
else:
raise ValueError(f"Unknown slogdet method '{method}'. Supported methods "
"are 'lu' (`None`), and 'qr'.")
def _slogdet_jvp(primals, tangents):
  """JVP rule for the slogdet implementations.

  Returns ``((sign, logabsdet), (sign_dot, logabsdet_dot))`` for a single
  primal matrix (or batch of matrices) and its tangent.
  """
  (mat,) = primals
  (mat_dot,) = tangents
  sign, log_abs_det = slogdet(mat)
  # The tangent of log(det(A)) is trace(A^{-1} @ dA).
  log_det_dot = jnp.trace(solve(mat, mat_dot), axis1=-1, axis2=-2)
  if not jnp.issubdtype(jnp._dtype(mat), jnp.complexfloating):
    # Real inputs: the sign tangent is taken to be zero.
    sign_dot = jnp.zeros_like(sign)
  else:
    # Complex inputs: the non-real part of the trace drives the sign, and the
    # real part is the tangent of the log-magnitude.
    real_part = ufuncs.real(log_det_dot)
    sign_dot = (log_det_dot - real_part.astype(log_det_dot.dtype)) * sign
    log_det_dot = real_part
  return (sign, log_abs_det), (sign_dot, log_det_dot)
# Both slogdet implementations (LU- and QR-based) share the same JVP rule.
_slogdet_lu.defjvp(_slogdet_jvp)
_slogdet_qr.defjvp(_slogdet_jvp)
def _cofactor_solve(a: ArrayLike, b: ArrayLike) -> Tuple[Array, Array]:
  """Equivalent to det(a)*solve(a, b) for nonsingular mat.

  Intermediate function used for jvp and vjp of det.
  This function borrows heavily from jax.numpy.linalg.solve and
  jax.numpy.linalg.slogdet to compute the gradient of the determinant
  in a way that is well defined even for low rank matrices.

  This function handles two different cases:

  * rank(a) == n or n-1
  * rank(a) < n-1

  For rank n-1 matrices, the gradient of the determinant is a rank 1 matrix.
  Rather than computing det(a)*solve(a, b), which would return NaN, we work
  directly with the LU decomposition. If a = p @ l @ u, then

    det(a)*solve(a, b) =
    prod(diag(u)) * u^-1 @ l^-1 @ p^-1 b =
    prod(diag(u)) * triangular_solve(u, solve(p @ l, b))

  If a is rank n-1, then the lower right corner of u will be zero and the
  triangular_solve will fail.
  Let x = solve(p @ l, b) and y = det(a)*solve(a, b).
  Then

    y_{n} =
    x_{n} / u_{nn} * prod_{i=1...n}(u_{ii}) =
    x_{n} * prod_{i=1...n-1}(u_{ii})

  So by replacing the lower-right corner of u with prod_{i=1...n-1}(u_{ii})^-1
  we can avoid the triangular_solve failing.
  To correctly compute the rest of y_{i} for i != n, we simply multiply
  x_{i} by det(a) for all i != n, which will be zero if rank(a) = n-1.

  For the second case, a check is done on the matrix to see if `solve`
  returns NaN or Inf, and gives a matrix of zeros as a result, as the
  gradient of the determinant of a matrix with rank less than n-1 is 0.
  This will still return the correct value for rank n-1 matrices, as the check
  is applied *after* the lower right corner of u has been updated.

  Args:
    a: A square matrix or batch of matrices, possibly singular.
    b: A matrix, or batch of matrices of the same dimension as a.

  Returns:
    det(a) and cofactor(a)^T*b, aka adjugate(a)*b

  Raises:
    ValueError: if `a` is not of shape [..., m, m] or the trailing two
      dimensions of `b` do not match those of `a`.
  """
  a, = promote_dtypes_inexact(jnp.asarray(a))
  b, = promote_dtypes_inexact(jnp.asarray(b))
  a_shape = jnp.shape(a)
  b_shape = jnp.shape(b)
  a_ndims = len(a_shape)
  if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2]
          and b_shape[-2:] == a_shape[-2:]):
    msg = ("The arguments to _cofactor_solve must have shapes "
           "a=[..., m, m] and b=[..., m, m]; got a={} and b={}")
    raise ValueError(msg.format(a_shape, b_shape))
  # 1x1 case: the determinant is the single entry and `b` is returned as is.
  if a_shape[-1] == 1:
    return a[..., 0, 0], b

  # lu contains u in the upper triangular matrix and l in the strict lower
  # triangular matrix.
  # The diagonal of l is set to ones without loss of generality.
  lu, pivots, permutation = lax_linalg.lu(a)
  dtype = lax.dtype(a)
  # Broadcast `b` and the factorization to a common set of batch dimensions.
  batch_dims = lax.broadcast_shapes(lu.shape[:-2], b.shape[:-2])
  x = jnp.broadcast_to(b, batch_dims + b.shape[-2:])
  lu = jnp.broadcast_to(lu, batch_dims + lu.shape[-2:])

  # Compute (partial) determinant, ignoring last diagonal of LU
  diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
  # Every pivot entry that differs from its own index counts as a row swap;
  # the parity of the swap count fixes the sign of the determinant.
  iota = lax.expand_dims(jnp.arange(a_shape[-1], dtype=pivots.dtype),
                         range(pivots.ndim - 1))
  parity = reductions.count_nonzero(pivots != iota, axis=-1)
  sign = jnp.asarray(-2 * (parity % 2) + 1, dtype=dtype)
  # partial_det[:, -1] contains the full determinant and
  # partial_det[:, -2] contains det(u) / u_{nn}.
  partial_det = reductions.cumprod(diag, axis=-1) * sign[..., None]
  # Overwrite u_{nn} as described in the docstring so that the final
  # triangular solve stays well defined for rank n-1 input.
  lu = lu.at[..., -1, -1].set(1.0 / partial_det[..., -2])
  permutation = jnp.broadcast_to(permutation, (*batch_dims, a_shape[-1]))
  # Open-mesh index arrays, one per batch dimension (plus one trailing entry
  # that is dropped below), used to apply `permutation` per batch element.
  iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in (*batch_dims, 1)))
  # filter out any matrices that are not full rank
  # (solve against a right-hand side of ones with the patched u; any NaN or
  # Inf in the result flags that batch element, and `d` becomes a boolean mask
  # tiled to the shape of x).
  d = jnp.ones(x.shape[:-1], x.dtype)
  d = lax_linalg.triangular_solve(lu, d, left_side=True, lower=False)
  d = reductions.any(ufuncs.logical_or(ufuncs.isnan(d), ufuncs.isinf(d)), axis=-1)
  d = jnp.tile(d[..., None, None], d.ndim*(1,) + x.shape[-2:])
  x = jnp.where(d, jnp.zeros_like(x), x)  # first filter
  # Gather the rows of x in the order given by the LU row permutation.
  x = x[iotas[:-1] + (permutation, slice(None))]
  # Solve with the unit-diagonal lower factor l.
  x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=True,
                                  unit_diagonal=True)
  # Scale every row except the last by det(a); the last row is accounted for
  # by the patched u_{nn}.
  x = jnp.concatenate((x[..., :-1, :] * partial_det[..., -1, None, None],
                       x[..., -1:, :]), axis=-2)
  # Solve with the (patched) upper factor u.
  x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False)
  x = jnp.where(d, jnp.zeros_like(x), x)  # second filter
  return partial_det[..., -1], x
def _det_2x2(a: Array) -> Array:
  """Closed-form determinant of the trailing 2x2 matrices of ``a``."""
  a00, a01 = a[..., 0, 0], a[..., 0, 1]
  a10, a11 = a[..., 1, 0], a[..., 1, 1]
  return a00 * a11 - a01 * a10
def _det_3x3(a: Array) -> Array:
  """Closed-form determinant of the trailing 3x3 matrices of ``a``."""
  a00, a01, a02 = a[..., 0, 0], a[..., 0, 1], a[..., 0, 2]
  a10, a11, a12 = a[..., 1, 0], a[..., 1, 1], a[..., 1, 2]
  a20, a21, a22 = a[..., 2, 0], a[..., 2, 1], a[..., 2, 2]
  # Rule of Sarrus: three "forward" diagonal products minus three "backward"
  # ones, accumulated in a fixed order.
  return (a00 * a11 * a22 +
          a01 * a12 * a20 +
          a02 * a10 * a21 -
          a02 * a11 * a20 -
          a00 * a12 * a21 -
          a01 * a10 * a22)
@custom_jvp
@_wraps(np.linalg.det)
@jit
def det(a: ArrayLike) -> Array:
  check_arraylike("jnp.linalg.det", a)
  (mat,) = promote_dtypes_inexact(jnp.asarray(a))
  shape = jnp.shape(mat)
  if len(shape) < 2 or shape[-1] != shape[-2]:
    raise ValueError(
        f"Argument to _det() must have shape [..., n, n], got {shape}")

  # Small matrices use closed-form expressions; everything else goes through
  # slogdet and is exponentiated back out of log space.
  size = shape[-1]
  if size == 2:
    return _det_2x2(mat)
  if size == 3:
    return _det_3x3(mat)
  sign, log_abs_det = slogdet(mat)
  return sign * ufuncs.exp(log_abs_det).astype(sign.dtype)
@det.defjvp
def _det_jvp(primals, tangents):
  """JVP rule for ``det``: returns ``(det(x), trace(adjugate(x) @ x_dot))``."""
  (mat,) = primals
  (mat_dot,) = tangents
  # _cofactor_solve returns det(mat) together with adjugate(mat) @ mat_dot.
  det_value, adjugate_times_dot = _cofactor_solve(mat, mat_dot)
  return det_value, jnp.trace(adjugate_times_dot, axis1=-1, axis2=-2)
@_wraps(np.linalg.eig, lax_description="""
This differs from :func:`numpy.linalg.eig` in that the return type of
:func:`jax.numpy.linalg.eig` is always ``complex64`` for 32-bit input,
and ``complex128`` for 64-bit input.
At present, non-symmetric eigendecomposition is only implemented on the CPU
backend. However eigendecomposition for symmetric/Hermitian matrices is
implemented more widely (see :func:`jax.numpy.linalg.eigh`).
""")
def eig(a: ArrayLike) -> Tuple[Array, Array]:
check_arraylike("jnp.linalg.eig", a)
a, = promote_dtypes_inexact(jnp.asarray(a))
w, v = lax_linalg.eig(a, compute_left_eigenvectors=False)
return w, v
@_wraps(np.linalg.eigvals)
@jit
def eigvals(a: ArrayLike) -> Array:
check_arraylike("jnp.linalg.eigvals", a)
return lax_linalg.eig(a, compute_left_eigenvectors=False,
compute_right_eigenvectors=False)[0]
@_wraps(np.linalg.eigh)
@partial(jit, static_argnames=('UPLO', 'symmetrize_input'))
def eigh(a: ArrayLike, UPLO: Optional[str] = None,
symmetrize_input: bool = True) -> Tuple[Array, Array]:
check_arraylike("jnp.linalg.eigh", a)
if UPLO is None or UPLO == "L":
lower = True
elif UPLO == "U":
lower = False
else:
msg = f"UPLO must be one of None, 'L', or 'U', got {UPLO}"
raise ValueError(msg)
a, = promote_dtypes_inexact(jnp.asarray(a))
v, w = lax_linalg.eigh(a, lower=lower, symmetrize_input=symmetrize_input)
return w, v
@_wraps(np.linalg.eigvalsh)
@partial(jit, static_argnames=('UPLO',))
def eigvalsh(a: ArrayLike, UPLO: Optional[str] = 'L') -> Array:
  check_arraylike("jnp.linalg.eigvalsh", a)
  # Run the full Hermitian eigendecomposition and keep only the eigenvalues.
  return eigh(a, UPLO)[0]
@partial(custom_jvp, nondiff_argnums=(1, 2))
@_wraps(np.linalg.pinv, lax_description=textwrap.dedent("""\
    It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the
    default `rcond` is `1e-15`. Here the default is
    `10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps`.
    """))
@partial(jit, static_argnames=('hermitian',))
def pinv(a: ArrayLike, rcond: Optional[ArrayLike] = None,
         hermitian: bool = False) -> Array:
  # Uses same algorithm as
  # https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
  check_arraylike("jnp.linalg.pinv", a)
  # Promote to an inexact dtype up front. Otherwise an integer input either
  # fails in `jnp.finfo` below (default rcond) or has its pseudo-inverse cast
  # back to the integer dtype by the final `convert_element_type`.
  arr, = promote_dtypes_inexact(jnp.asarray(a))
  m, n = arr.shape[-2:]
  if m == 0 or n == 0:
    # Empty matrices: return an empty result with the last two axes swapped.
    return jnp.empty(arr.shape[:-2] + (n, m), arr.dtype)
  arr = ufuncs.conj(arr)
  if rcond is None:
    max_rows_cols = max(arr.shape[-2:])
    rcond = 10. * max_rows_cols * jnp.array(jnp.finfo(arr.dtype).eps)
  rcond = jnp.asarray(rcond)
  u, s, vh = svd(arr, full_matrices=False, hermitian=hermitian)
  # Singular values less than or equal to ``rcond * largest_singular_value``
  # are set to zero.
  # (rcond gets a trailing axis, plus leading axes as needed, so that it
  # broadcasts against the singular values.)
  rcond = lax.expand_dims(rcond[..., jnp.newaxis], range(s.ndim - rcond.ndim - 1))
  cutoff = rcond * s[..., 0:1]
  # Discarded singular values become +inf so that dividing by them yields 0.
  s = jnp.where(s > cutoff, s, jnp.inf).astype(u.dtype)
  res = jnp.matmul(vh.mT, ufuncs.divide(u.mT, s[..., jnp.newaxis]),
                   precision=lax.Precision.HIGHEST)
  return lax.convert_element_type(res, arr.dtype)
@pinv.defjvp
@jax.default_matmul_precision("float32")
def _pinv_jvp(rcond, hermitian, primals, tangents):
  # The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems
  # Whose Variables Separate. Author(s): G. H. Golub and V. Pereyra. SIAM
  # Journal on Numerical Analysis, Vol. 10, No. 2 (Apr., 1973), pp. 413-432.
  # (via https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative)
  (a,) = primals  # m x n
  (a_dot,) = tangents
  p = pinv(a, rcond=rcond, hermitian=hermitian)  # n x m
  if hermitian:
    # svd(..., hermitian=True) symmetrizes its input, and the JVP must match.
    a, a_dot = _symmetrize(a), _symmetrize(a_dot)

  # Conjugate transposes used by both branches below.
  p_h = _H(p)
  a_dot_h = _H(a_dot)

  # TODO(phawkins): this could be simplified in the Hermitian case if we
  # supported triangular matrix multiplication.
  # The two branches compute the same expression; only the association order
  # of the matrix products differs with the shape of `a`.
  rows, cols = a.shape[-2:]
  if rows < cols:
    s = p @ (p_h @ a_dot_h)
    t = a_dot_h @ (p_h @ p)
    p_dot = -p @ (a_dot @ p) + s - s @ (a @ p) + t - p @ (a @ t)
  else:  # rows >= cols
    s = (p @ p_h) @ a_dot_h  # n x m
    t = (a_dot_h @ p_h) @ p  # n x m
    p_dot = -(p @ a_dot) @ p + s - (s @ a) @ p + t - (p @ a) @ t
  return p, p_dot
@_wraps(np.linalg.inv)
@jit
def inv(a: ArrayLike) -> Array:
check_arraylike("jnp.linalg.inv", a)
arr = jnp.asarray(a)
if arr.ndim < 2 or arr.shape[-1] != arr.shape[-2]:
raise ValueError(
f"Argument to inv must have shape [..., n, n], got {arr.shape}.")
return solve(
arr, lax.broadcast(jnp.eye(arr.shape[-1], dtype=arr.dtype), arr.shape[:-2]))
@_wraps(np.linalg.norm)
@partial(jit, static_argnames=('ord', 'axis', 'keepdims'))
def norm(x: ArrayLike, ord: Union[int, str, None] = None,
axis: Union[None, Tuple[int, ...], int] = None,
keepdims: bool = False) -> Array:
check_arraylike("jnp.linalg.norm", x)
x, = promote_dtypes_inexact(jnp.asarray(x))
x_shape = jnp.shape(x)
ndim = len(x_shape)
if axis is None:
# NumPy has an undocumented behavior that admits arbitrary rank inputs if
# `ord` is None: https://github.com/numpy/numpy/issues/14215
if ord is None:
return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)), keepdims=keepdims))
axis = tuple(range(ndim))
elif isinstance(axis, tuple):
axis = tuple(canonicalize_axis(x, ndim) for x in axis)
else:
axis = (canonicalize_axis(axis, ndim),)
num_axes = len(axis)
if num_axes == 1:
if ord is None or ord == 2:
return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)), axis=axis,
keepdims=keepdims))
elif ord == jnp.inf:
return reductions.amax(ufuncs.abs(x), axis=axis, keepdims=keepdims)
elif ord == -jnp.inf:
return reductions.amin(ufuncs.abs(x), axis=axis, keepdims=keepdims)
elif ord == 0:
return reductions.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype,
axis=axis, keepdims=keepdims)
elif ord == 1:
# Numpy has a special case for ord == 1 as an optimization. We don't
# really need the optimization (XLA could do it for us), but the Numpy
# code has slightly different type promotion semantics, so we need a
# special case too.
return reductions.sum(ufuncs.abs(x), axis=axis, keepdims=keepdims)
elif isinstance(ord, str):
msg = f"Invalid order '{ord}' for vector norm."
if ord == "inf":
msg += "Use 'jax.numpy.inf' instead."
if ord == "-inf":
msg += "Use '-jax.numpy.inf' instead."
raise ValueError(msg)
else:
abs_x = ufuncs.abs(x)
ord_arr = lax_internal._const(abs_x, ord)
ord_inv = lax_internal._const(abs_x, 1. / ord_arr)
out = reductions.sum(abs_x ** ord_arr, axis=axis, keepdims=keepdims)
return ufuncs.power(out, ord_inv)
elif num_axes == 2:
row_axis, col_axis = cast(Tuple[int, ...], axis)
if ord is None or ord in ('f', 'fro'):
return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)), axis=axis,
keepdims=keepdims))
elif ord == 1:
if not keepdims and col_axis > row_axis:
col_axis -= 1
return reductions.amax(reductions.sum(ufuncs.abs(x), axis=row_axis, keepdims=keepdims),
axis=col_axis, keepdims=keepdims)
elif ord == -1:
if not keepdims and col_axis > row_axis:
col_axis -= 1
return reductions.amin(reductions.sum(ufuncs.abs(x), axis=row_axis, keepdims=keepdims),
axis=col_axis, keepdims=keepdims)
elif ord == jnp.inf:
if not keepdims and row_axis > col_axis:
row_axis -= 1
return reductions.amax(reductions.sum(ufuncs.abs(x), axis=col_axis, keepdims=keepdims),
axis=row_axis, keepdims=keepdims)
elif ord == -jnp.inf:
if not keepdims and row_axis > col_axis:
row_axis -= 1
return reductions.amin(reductions.sum(ufuncs.abs(x), axis=col_axis, keepdims=keepdims),
axis=row_axis, keepdims=keepdims)
elif ord in ('nuc', 2, -2):
x = jnp.moveaxis(x, axis, (-2, -1))
if ord == 2:
reducer = reductions.amax
elif ord == -2:
reducer = reductions.amin
else:
# `sum` takes an extra dtype= argument, unlike `amax` and `amin`.
reducer = reductions.sum # type: ignore[assignment]
y = reducer(svd(x, compute_uv=False), axis=-1)
if keepdims:
y = jnp.expand_dims(y, axis)
return y
else:
raise ValueError(f"Invalid order '{ord}' for matrix norm.")
else:
raise ValueError(
f"Invalid axis values ({axis}) for jnp.linalg.norm.")
@overload
def qr(a: ArrayLike, mode: Literal["r"]) -> Array: ...
@overload
def qr(a: ArrayLike, mode: str = "reduced") -> Union[Array, Tuple[Array, Array]]: ...
@_wraps(np.linalg.qr)
@partial(jit, static_argnames=('mode',))
def qr(a: ArrayLike, mode: str = "reduced") -> Union[Array, Tuple[Array, Array]]:
check_arraylike("jnp.linalg.qr", a)
a, = promote_dtypes_inexact(jnp.asarray(a))
if mode == "raw":
a, taus = lax_linalg.geqrf(a)
return a.mT, taus
if mode in ("reduced", "r", "full"):
full_matrices = False
elif mode == "complete":
full_matrices = True
else:
raise ValueError(f"Unsupported QR decomposition mode '{mode}'")
q, r = lax_linalg.qr(a, full_matrices=full_matrices)
if mode == "r":
return r
return q, r
@_wraps(np.linalg.solve)
@jit
def solve(a: ArrayLike, b: ArrayLike) -> Array:
check_arraylike("jnp.linalg.solve", a, b)
a, b = promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
return lax_linalg._solve(a, b)
def _lstsq(a: ArrayLike, b: ArrayLike, rcond: Optional[float], *,
           numpy_resid: bool = False) -> Tuple[Array, Array, Array, Array]:
  """SVD-based least-squares solve; returns ``(x, residuals, rank, s)``."""
  # TODO: add lstsq to lax_linalg and implement this function via those wrappers.
  # TODO: add custom jvp rule for more robust lstsq differentiation
  a, b = promote_dtypes_inexact(a, b)
  if a.shape[0] != b.shape[0]:
    raise ValueError("Leading dimensions of input arrays must match")
  # A 1-D right-hand side is treated as a single column and flattened again
  # before returning.
  b_was_vector = b.ndim == 1
  if b_was_vector:
    b = b[:, None]
  if a.ndim != 2:
    raise TypeError(
      f"{a.ndim}-dimensional array given. Array must be two-dimensional")
  if b.ndim != 2:
    raise TypeError(
      f"{b.ndim}-dimensional array given. Array must be one or two-dimensional")

  num_rows, num_cols = a.shape
  highest = lax.Precision.HIGHEST
  if a.size == 0:
    # Degenerate input: no singular values, rank zero, empty-initialized x.
    s = jnp.empty(0, dtype=a.dtype)
    rank = jnp.array(0, dtype=int)
    x = jnp.empty((num_cols, *b.shape[1:]), dtype=a.dtype)
  else:
    eps = jnp.finfo(a.dtype).eps
    if rcond is None:
      rcond = eps * max(num_cols, num_rows)
    else:
      # A negative rcond falls back to machine epsilon.
      rcond = jnp.where(rcond < 0, eps, rcond)
    u, s, vt = svd(a, full_matrices=False)
    # Keep only singular values at or above rcond * (largest singular value).
    keep = s >= jnp.array(rcond, dtype=s.dtype) * s[0]
    rank = keep.sum()
    # Substitute 1 for discarded values so the reciprocal is always finite,
    # then zero those reciprocals out.
    safe_s = jnp.where(keep, s, 1).astype(a.dtype)
    s_inv = jnp.where(keep, 1 / safe_s, 0)[:, jnp.newaxis]
    u_h_b = jnp.matmul(u.conj().T, b, precision=highest)
    x = jnp.matmul(vt.conj().T, s_inv * u_h_b, precision=highest)

  # Numpy returns empty residuals in some cases. To allow compilation, we
  # default to returning full residuals in all cases.
  if numpy_resid and (rank < num_cols or num_rows <= num_cols):
    resid = jnp.asarray([])
  else:
    fitted = jnp.matmul(a, x, precision=highest)
    resid = norm(b - fitted, axis=0) ** 2

  if b_was_vector:
    x = x.ravel()
  return x, resid, rank, s
# Jitted variant of `_lstsq` with `numpy_resid` fixed to False; `lstsq` uses it
# whenever NumPy-style (possibly empty) residuals are not requested.
_jit_lstsq = jit(partial(_lstsq, numpy_resid=False))
@_wraps(np.linalg.lstsq, lax_description=textwrap.dedent("""\
    It has two important differences:
    1. In `numpy.linalg.lstsq`, the default `rcond` is `-1`, and warns that in the future
    the default will be `None`. Here, the default rcond is `None`.
    2. In `np.linalg.lstsq` the returned residuals are empty for low-rank or over-determined
    solutions. Here, the residuals are returned in all cases, to make the function
    compatible with jit. The non-jit compatible numpy behavior can be recovered by
    passing numpy_resid=True.
    The lstsq function does not currently have a custom JVP rule, so the gradient is
    poorly behaved for some inputs, particularly for low-rank `a`.
    """))
def lstsq(a: ArrayLike, b: ArrayLike, rcond: Optional[float] = None, *,
          numpy_resid: bool = False) -> Tuple[Array, Array, Array, Array]:
  check_arraylike("jnp.linalg.lstsq", a, b)
  # The default path is jitted. NumPy-style residuals make the output shape
  # depend on the computed rank, so that path runs the un-jitted function.
  if not numpy_resid:
    return _jit_lstsq(a, b, rcond)
  return _lstsq(a, b, rcond, numpy_resid=True)