from functools import partial
import numpy as np
import textwrap
import operator
from typing import Literal, Optional, Tuple, Union, cast, overload

import jax
from jax import jit, custom_jvp
from jax import lax
from jax._src.lax import lax as lax_internal
from jax._src.lax import linalg as lax_linalg
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import reductions, ufuncs
from jax._src.numpy.util import _wraps, promote_dtypes_inexact, check_arraylike
from jax._src.util import canonicalize_axis
from jax._src.typing import ArrayLike, Array


def _H(x: ArrayLike) -> Array:
  """Return the conjugate of the matrix transpose of ``x``."""
  transposed = jnp.matrix_transpose(x)
  return ufuncs.conjugate(transposed)


def _symmetrize(x: Array) -> Array:
  """Return the average of ``x`` and its conjugate transpose ``_H(x)``."""
  doubled = x + _H(x)
  return doubled / 2


@_wraps(np.linalg.cholesky)
@jit
def cholesky(a: ArrayLike) -> Array:
  check_arraylike("jnp.linalg.cholesky", a)
  a, = promote_dtypes_inexact(jnp.asarray(a))
  return lax_linalg.cholesky(a)


@overload
def svd(a: ArrayLike, full_matrices: bool = True, *,
        compute_uv: Literal[True],
        hermitian: bool = False) -> Tuple[Array, Array, Array]: ...

@overload
def svd(a: ArrayLike, full_matrices: bool,
        compute_uv: Literal[True],
        hermitian: bool = False) -> Tuple[Array, Array, Array]: ...

@overload
def svd(a: ArrayLike, full_matrices: bool = True, *,
        compute_uv: Literal[False],
        hermitian: bool = False) -> Array: ...

@overload
def svd(a: ArrayLike, full_matrices: bool,
        compute_uv: Literal[False],
        hermitian: bool = False) -> Array: ...

@overload
def svd(a: ArrayLike, full_matrices: bool = True,
        compute_uv: bool = True,
        hermitian: bool = False) -> Union[Array, Tuple[Array, Array, Array]]: ...

@_wraps(np.linalg.svd)
@partial(jit, static_argnames=('full_matrices', 'compute_uv', 'hermitian'))
def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
        hermitian: bool = False) -> Union[Array, Tuple[Array, Array, Array]]:
  # Two code paths:
  #   * general input: defer entirely to the lax SVD primitive;
  #   * hermitian=True: derive the SVD from an eigendecomposition. The singular
  #     values are the absolute eigenvalues sorted in descending order, and the
  #     eigenvalue signs are folded into the right singular vectors.
  check_arraylike("jnp.linalg.svd", a)
  a, = promote_dtypes_inexact(jnp.asarray(a))

  if not hermitian:
    return lax_linalg.svd(a, full_matrices=full_matrices,
                          compute_uv=compute_uv)

  # lax_linalg.eigh yields (eigenvectors, eigenvalues), in that order.
  eigvecs, eigvals = lax_linalg.eigh(a)
  magnitudes = lax.abs(eigvals)
  last_axis = magnitudes.ndim - 1

  if not compute_uv:
    # lax.sort is ascending; reverse to get the conventional descending order.
    ascending = lax.sort(magnitudes, dimension=-1)
    return lax.rev(ascending, dimensions=[last_axis])

  signs = lax.sign(eigvals)
  order = lax.broadcasted_iota(np.int64, magnitudes.shape,
                               dimension=last_axis)
  # Sort by magnitude only (num_keys=1), carrying along the original column
  # index and the sign of each eigenvalue.
  magnitudes, order, signs = lax.sort((magnitudes, order, signs),
                                      dimension=-1, num_keys=1)
  s, order, signs = (lax.rev(operand, dimensions=[last_axis])
                     for operand in (magnitudes, order, signs))
  # Reorder the eigenvector columns to match the sorted singular values.
  u = jnp.take_along_axis(eigvecs, order[..., None, :], axis=-1)
  vh = _H(u * signs[..., None, :].astype(u.dtype))
  return u, s, vh


@_wraps(np.linalg.matrix_power)
@partial(jit, static_argnames=('n',))
def matrix_power(a: ArrayLike, n: int) -> Array:
  # Raise the (batched) square matrix `a` to the integer power `n`.
  # `n` is a static argument, so all branching below happens at trace time.
  # Raises TypeError for inputs with fewer than two dimensions, non-square
  # trailing dimensions, or a non-integer exponent.
  check_arraylike("jnp.linalg.matrix_power", a)
  arr, = promote_dtypes_inexact(jnp.asarray(a))

  if arr.ndim < 2:
    raise TypeError("{}-dimensional array given. Array must be at least "
                    "two-dimensional".format(arr.ndim))
  if arr.shape[-2] != arr.shape[-1]:
    raise TypeError("Last 2 dimensions of the array must be square")
  try:
    n = operator.index(n)
  except TypeError as err:
    raise TypeError(f"exponent must be an integer, got {n}") from err

  if n == 0:
    return jnp.broadcast_to(jnp.eye(arr.shape[-2], dtype=arr.dtype), arr.shape)
  elif n < 0:
    # Negative powers are positive powers of the inverse.
    arr = inv(arr)
    n = abs(n)

  # Small exponents are handled directly.
  if n == 1:
    return arr
  elif n == 2:
    return arr @ arr
  elif n == 3:
    return (arr @ arr) @ arr

  # Binary exponentiation: `z` walks through arr, arr^2, arr^4, ... and is
  # multiplied into `result` whenever the corresponding bit of `n` is set.
  z = result = None
  while n > 0:
    z = arr if z is None else (z @ z)  # type: ignore[operator]
    n, bit = divmod(n, 2)
    if bit:
      result = z if result is None else (result @ z)
  assert result is not None
  return result


@_wraps(np.linalg.matrix_rank)
@jit
def matrix_rank(M: ArrayLike, tol: Optional[ArrayLike] = None) -> Array:
  # Rank = number of singular values strictly greater than `tol`.
  # When `tol` is None it defaults to max(S) * max(rows, cols) * eps.
  check_arraylike("jnp.linalg.matrix_rank", M)
  M, = promote_dtypes_inexact(jnp.asarray(M))
  if M.ndim < 2:
    # Scalars and vectors: rank is 1 if any entry is nonzero, else 0.
    return (M != 0).any().astype(jnp.int32)
  S = svd(M, full_matrices=False, compute_uv=False)
  if tol is None:
    tol = S.max(-1) * np.max(M.shape[-2:]).astype(S.dtype) * jnp.finfo(S.dtype).eps
  # Add a trailing axis so `tol` broadcasts against the singular-value axis.
  tol = jnp.expand_dims(tol, np.ndim(tol))
  return reductions.sum(S > tol, axis=-1)


@custom_jvp
def _slogdet_lu(a: Array) -> Tuple[Array, Array]:
  """Compute ``(sign, log|det|)`` of ``a`` from its LU decomposition.

  Returns sign 0 and log-determinant ``-inf`` when any diagonal entry of the
  LU factor is exactly zero.
  """
  dtype = lax.dtype(a)
  lu, pivot, _ = lax_linalg.lu(a)
  diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
  is_zero = reductions.any(diag == jnp.array(0, dtype=dtype), axis=-1)
  # Every pivot entry that differs from its own index is counted as one row
  # swap; the parity of that count is the sign of the permutation.
  iota = lax.expand_dims(jnp.arange(a.shape[-1], dtype=pivot.dtype),
                         range(pivot.ndim - 1))
  parity = reductions.count_nonzero(pivot != iota, axis=-1)
  if jnp.iscomplexobj(a):
    # Complex case: the "sign" is the product of the unit-modulus phases.
    sign = reductions.prod(diag / ufuncs.abs(diag).astype(diag.dtype), axis=-1)
  else:
    # Real case: each negative diagonal entry flips the sign once.
    sign = jnp.array(1, dtype=dtype)
    parity = parity + reductions.count_nonzero(diag < 0, axis=-1)
  sign = jnp.where(is_zero,
                   jnp.array(0, dtype=dtype),
                   sign * jnp.array(-2 * (parity % 2) + 1, dtype=dtype))
  logdet = jnp.where(
      is_zero, jnp.array(-jnp.inf, dtype=dtype),
      reductions.sum(ufuncs.log(ufuncs.abs(diag)).astype(dtype), axis=-1))
  return sign, ufuncs.real(logdet)


@custom_jvp
def _slogdet_qr(a: Array) -> Tuple[Array, Array]:
  """Compute ``(sign, log|det|)`` of a real matrix ``a`` via QR decomposition.

  Raises:
    NotImplementedError: if ``a`` has a complex dtype.
  """
  # Implementation of slogdet using QR decomposition. One reason we might prefer
  # QR decomposition is that it is more amenable to a fast batched
  # implementation on TPU because of the lack of row pivoting.
  if jnp.issubdtype(lax.dtype(a), jnp.complexfloating):
    raise NotImplementedError("slogdet method='qr' not implemented for complex "
                              "inputs")
  n = a.shape[-1]
  a, taus = lax_linalg.geqrf(a)
  # The determinant of a triangular matrix is the product of its diagonal
  # elements. We are working in log space, so we compute the magnitude as the
  # the trace of the log-absolute values, and we compute the sign separately.
  log_abs_det = jnp.trace(ufuncs.log(ufuncs.abs(a)), axis1=-2, axis2=-1)
  sign_diag = reductions.prod(ufuncs.sign(jnp.diagonal(a, axis1=-2, axis2=-1)),
                              axis=-1)
  # The determinant of a Householder reflector is -1. So whenever we actually
  # made a reflection (tau != 0), multiply the result by -1.
  sign_taus = reductions.prod(jnp.where(taus[..., :(n-1)] != 0, -1, 1),
                              axis=-1).astype(sign_diag.dtype)
  return sign_diag * sign_taus, log_abs_det


@_wraps(
    np.linalg.slogdet,
    extra_params=textwrap.dedent("""
        method: string, optional
          One of ``lu`` or ``qr``, specifying whether the determinant should be
          computed using an LU decomposition or a QR decomposition. Defaults to
          LU decomposition if ``None``.
    """))
@partial(jit, static_argnames=('method',))
def slogdet(a: ArrayLike, *, method: Optional[str] = None) -> Tuple[Array, Array]:
  # Dispatch to the LU- or QR-based implementation after validating that the
  # input has shape [..., n, n]. Raises ValueError for a bad shape or an
  # unknown `method`.
  check_arraylike("jnp.linalg.slogdet", a)
  a, = promote_dtypes_inexact(jnp.asarray(a))
  a_shape = jnp.shape(a)
  if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
    msg = "Argument to slogdet() must have shape [..., n, n], got {}"
    raise ValueError(msg.format(a_shape))

  if method is None or method == "lu":
    return _slogdet_lu(a)
  elif method == "qr":
    return _slogdet_qr(a)
  else:
    raise ValueError(f"Unknown slogdet method '{method}'. Supported methods "
                     "are 'lu' (`None`), and 'qr'.")


def _slogdet_jvp(primals, tangents):
  """JVP rule shared by ``_slogdet_lu`` and ``_slogdet_qr``.

  The tangent of the log-determinant is ``trace(solve(x, g))``. For complex
  input, the non-real part of that trace (times ``sign``) becomes the tangent
  of the sign output and only the real part is kept for the log-determinant;
  for real input the sign tangent is zero.
  """
  x, = primals
  g, = tangents
  sign, ans = slogdet(x)
  ans_dot = jnp.trace(solve(x, g), axis1=-1, axis2=-2)
  if jnp.issubdtype(jnp._dtype(x), jnp.complexfloating):
    sign_dot = (ans_dot - ufuncs.real(ans_dot).astype(ans_dot.dtype)) * sign
    ans_dot = ufuncs.real(ans_dot)
  else:
    sign_dot = jnp.zeros_like(sign)
  return (sign, ans), (sign_dot, ans_dot)

_slogdet_lu.defjvp(_slogdet_jvp)
_slogdet_qr.defjvp(_slogdet_jvp)


def _cofactor_solve(a: ArrayLike, b: ArrayLike) -> Tuple[Array, Array]:
  """Equivalent to det(a)*solve(a, b) for nonsingular mat.

  Intermediate function used for jvp and vjp of det.
  This function borrows heavily from jax.numpy.linalg.solve and
  jax.numpy.linalg.slogdet to compute the gradient of the determinant
  in a way that is well defined even for low rank matrices.

  This function handles two different cases:
  * rank(a) == n or n-1
  * rank(a) < n-1

  For rank n-1 matrices, the gradient of the determinant is a rank 1 matrix.
  Rather than computing det(a)*solve(a, b), which would return NaN, we work
  directly with the LU decomposition. If a = p @ l @ u, then
  det(a)*solve(a, b) =
  prod(diag(u)) * u^-1 @ l^-1 @ p^-1 b =
  prod(diag(u)) * triangular_solve(u, solve(p @ l, b))
  If a is rank n-1, then the lower right corner of u will be zero and the
  triangular_solve will fail.
  Let x = solve(p @ l, b) and y = det(a)*solve(a, b).
  Then y_{n} =
  x_{n} / u_{nn} * prod_{i=1...n}(u_{ii}) =
  x_{n} * prod_{i=1...n-1}(u_{ii})
  So by replacing the lower-right corner of u with prod_{i=1...n-1}(u_{ii})^-1
  we can avoid the triangular_solve failing.
  To correctly compute the rest of y_{i} for i != n, we simply multiply
  x_{i} by det(a) for all i != n, which will be zero if rank(a) = n-1.

  For the second case, a check is done on the matrix to see if `solve`
  returns NaN or Inf, and gives a matrix of zeros as a result, as the
  gradient of the determinant of a matrix with rank less than n-1 is 0.
  This will still return the correct value for rank n-1 matrices, as the check
  is applied *after* the lower right corner of u has been updated.

  Args:
    a: A square matrix or batch of matrices, possibly singular.
    b: A matrix, or batch of matrices of the same dimension as a.

  Returns:
    det(a) and cofactor(a)^T*b, aka adjugate(a)*b

  Raises:
    ValueError: if ``a`` is not of shape [..., m, m] or the trailing two
      dimensions of ``b`` do not match those of ``a``.
  """
  a, = promote_dtypes_inexact(jnp.asarray(a))
  b, = promote_dtypes_inexact(jnp.asarray(b))
  a_shape = jnp.shape(a)
  b_shape = jnp.shape(b)
  a_ndims = len(a_shape)
  if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2]
          and b_shape[-2:] == a_shape[-2:]):
    msg = ("The arguments to _cofactor_solve must have shapes "
           "a=[..., m, m] and b=[..., m, m]; got a={} and b={}")
    raise ValueError(msg.format(a_shape, b_shape))
  if a_shape[-1] == 1:
    # 1x1 case: det(a) is the single entry and the adjugate is the identity.
    return a[..., 0, 0], b

  # lu contains u in the upper triangular matrix and l in the strict lower
  # triangular matrix.
  # The diagonal of l is set to ones without loss of generality.
  lu, pivots, permutation = lax_linalg.lu(a)
  dtype = lax.dtype(a)
  batch_dims = lax.broadcast_shapes(lu.shape[:-2], b.shape[:-2])
  x = jnp.broadcast_to(b, batch_dims + b.shape[-2:])
  lu = jnp.broadcast_to(lu, batch_dims + lu.shape[-2:])
  # Compute (partial) determinant, ignoring last diagonal of LU
  diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
  iota = lax.expand_dims(jnp.arange(a_shape[-1], dtype=pivots.dtype),
                         range(pivots.ndim - 1))
  parity = reductions.count_nonzero(pivots != iota, axis=-1)
  sign = jnp.asarray(-2 * (parity % 2) + 1, dtype=dtype)
  # partial_det[:, -1] contains the full determinant and
  # partial_det[:, -2] contains det(u) / u_{nn}.
  partial_det = reductions.cumprod(diag, axis=-1) * sign[..., None]
  lu = lu.at[..., -1, -1].set(1.0 / partial_det[..., -2])
  permutation = jnp.broadcast_to(permutation, (*batch_dims, a_shape[-1]))
  iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in (*batch_dims, 1)))
  # filter out any matrices that are not full rank
  d = jnp.ones(x.shape[:-1], x.dtype)
  d = lax_linalg.triangular_solve(lu, d, left_side=True, lower=False)
  d = reductions.any(ufuncs.logical_or(ufuncs.isnan(d), ufuncs.isinf(d)), axis=-1)
  d = jnp.tile(d[..., None, None], d.ndim*(1,) + x.shape[-2:])
  x = jnp.where(d, jnp.zeros_like(x), x)  # first filter
  # Apply the row permutation from the LU decomposition to the right-hand side.
  x = x[iotas[:-1] + (permutation, slice(None))]
  x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=True,
                                  unit_diagonal=True)
  # Scale every row except the last by det(a); the last row is handled by the
  # modified lower-right entry of `lu` (see docstring).
  x = jnp.concatenate((x[..., :-1, :] * partial_det[..., -1, None, None],
                       x[..., -1:, :]), axis=-2)
  x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False)
  x = jnp.where(d, jnp.zeros_like(x), x)  # second filter

  return partial_det[..., -1], x


def _det_2x2(a: Array) -> Array:
  """Closed-form determinant of a batch of 2x2 matrices."""
  return (a[..., 0, 0] * a[..., 1, 1] -
          a[..., 0, 1] * a[..., 1, 0])


def _det_3x3(a: Array) -> Array:
  """Closed-form determinant of a batch of 3x3 matrices."""
  return (a[..., 0, 0] * a[..., 1, 1] * a[..., 2, 2] +
          a[..., 0, 1] * a[..., 1, 2] * a[..., 2, 0] +
          a[..., 0, 2] * a[..., 1, 0] * a[..., 2, 1] -
          a[..., 0, 2] * a[..., 1, 1] * a[..., 2, 0] -
          a[..., 0, 0] * a[..., 1, 2] * a[..., 2, 1] -
          a[..., 0, 1] * a[..., 1, 0] * a[..., 2, 2])


@custom_jvp
@_wraps(np.linalg.det)
@jit
def det(a: ArrayLike) -> Array:
  # Determinant of a (batched) square matrix. 2x2 and 3x3 inputs use closed
  # forms; everything else goes through slogdet. Raises ValueError when the
  # input does not have shape [..., n, n].
  check_arraylike("jnp.linalg.det", a)
  a, = promote_dtypes_inexact(jnp.asarray(a))
  a_shape = jnp.shape(a)
  if len(a_shape) >= 2 and a_shape[-1] == 2 and a_shape[-2] == 2:
    return _det_2x2(a)
  elif len(a_shape) >= 2 and a_shape[-1] == 3 and a_shape[-2] == 3:
    return _det_3x3(a)
  elif len(a_shape) >= 2 and a_shape[-1] == a_shape[-2]:
    sign, logdet = slogdet(a)
    return sign * ufuncs.exp(logdet).astype(sign.dtype)
  else:
    msg = "Argument to _det() must have shape [..., n, n], got {}"
    raise ValueError(msg.format(a_shape))


@det.defjvp
def _det_jvp(primals, tangents):
  """JVP rule for ``det``: the tangent is ``trace(adjugate(x) @ g)``."""
  x, = primals
  g, = tangents
  y, z = _cofactor_solve(x, g)
  return y, jnp.trace(z, axis1=-1, axis2=-2)


@_wraps(np.linalg.eig, lax_description="""
This differs from :func:`numpy.linalg.eig` in that the return type of
:func:`jax.numpy.linalg.eig` is always ``complex64`` for 32-bit input,
and ``complex128`` for 64-bit input.

At present, non-symmetric eigendecomposition is only implemented on the CPU
backend. However eigendecomposition for symmetric/Hermitian matrices is
implemented more widely (see :func:`jax.numpy.linalg.eigh`).
""")
def eig(a: ArrayLike) -> Tuple[Array, Array]:
  # Returns (eigenvalues, right eigenvectors); left eigenvectors are not
  # requested from the primitive.
  check_arraylike("jnp.linalg.eig", a)
  a, = promote_dtypes_inexact(jnp.asarray(a))
  w, v = lax_linalg.eig(a, compute_left_eigenvectors=False)
  return w, v


@_wraps(np.linalg.eigvals)
@jit
def eigvals(a: ArrayLike) -> Array:
  # Eigenvalues only: neither left nor right eigenvectors are computed.
  check_arraylike("jnp.linalg.eigvals", a)
  # Promote to an inexact dtype, as `eig` does, so integer input is handled the
  # same way by both functions.
  a, = promote_dtypes_inexact(jnp.asarray(a))
  return lax_linalg.eig(a, compute_left_eigenvectors=False,
                        compute_right_eigenvectors=False)[0]


@_wraps(np.linalg.eigh)
@partial(jit, static_argnames=('UPLO', 'symmetrize_input'))
def eigh(a: ArrayLike, UPLO: Optional[str] = None,
         symmetrize_input: bool = True) -> Tuple[Array, Array]:
  # Returns (eigenvalues, eigenvectors). `UPLO` of None or "L" selects the
  # lower triangle, "U" the upper one; anything else raises ValueError.
  check_arraylike("jnp.linalg.eigh", a)
  if UPLO is None or UPLO == "L":
    lower = True
  elif UPLO == "U":
    lower = False
  else:
    msg = f"UPLO must be one of None, 'L', or 'U', got {UPLO}"
    raise ValueError(msg)

  a, = promote_dtypes_inexact(jnp.asarray(a))
  # lax_linalg.eigh returns (eigenvectors, eigenvalues); swap to NumPy's order.
  v, w = lax_linalg.eigh(a, lower=lower, symmetrize_input=symmetrize_input)
  return w, v


@_wraps(np.linalg.eigvalsh)
@partial(jit, static_argnames=('UPLO',))
def eigvalsh(a: ArrayLike, UPLO: Optional[str] = 'L') -> Array:
  # Eigenvalues of a Hermitian matrix: run `eigh` and drop the eigenvectors.
  check_arraylike("jnp.linalg.eigvalsh", a)
  w, _ = eigh(a, UPLO)
  return w


@partial(custom_jvp, nondiff_argnums=(1, 2))
@_wraps(np.linalg.pinv, lax_description=textwrap.dedent("""\
    It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the
    default `rcond` is `1e-15`. Here the default is
    `10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps`.
    """))
@partial(jit, static_argnames=('hermitian',))
def pinv(a: ArrayLike, rcond: Optional[ArrayLike] = None,
         hermitian: bool = False) -> Array:
  # Uses same algorithm as
  # https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
  check_arraylike("jnp.linalg.pinv", a)
  arr = jnp.asarray(a)
  m, n = arr.shape[-2:]
  if m == 0 or n == 0:
    return jnp.empty(arr.shape[:-2] + (n, m), arr.dtype)
  # Promote to an inexact dtype so that `jnp.finfo` below (and the SVD) also
  # work for integer/boolean matrices; a no-op for floating/complex input.
  arr, = promote_dtypes_inexact(arr)
  arr = ufuncs.conj(arr)
  if rcond is None:
    max_rows_cols = max(arr.shape[-2:])
    rcond = 10. * max_rows_cols * jnp.array(jnp.finfo(arr.dtype).eps)
  rcond = jnp.asarray(rcond)
  u, s, vh = svd(arr, full_matrices=False, hermitian=hermitian)
  # Singular values less than or equal to ``rcond * largest_singular_value``
  # are set to zero.
  rcond = lax.expand_dims(rcond[..., jnp.newaxis], range(s.ndim - rcond.ndim - 1))
  cutoff = rcond * s[..., 0:1]
  # Replacing small singular values by inf makes their reciprocal contribution
  # vanish in the division below.
  s = jnp.where(s > cutoff, s, jnp.inf).astype(u.dtype)
  res = jnp.matmul(vh.mT, ufuncs.divide(u.mT, s[..., jnp.newaxis]),
                   precision=lax.Precision.HIGHEST)
  return lax.convert_element_type(res, arr.dtype)


@pinv.defjvp
@jax.default_matmul_precision("float32")
def _pinv_jvp(rcond, hermitian, primals, tangents):
  """JVP rule for ``pinv``; ``rcond`` and ``hermitian`` are non-differentiable."""
  # The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems
  # Whose Variables Separate. Author(s): G. H. Golub and V. Pereyra. SIAM
  # Journal on Numerical Analysis, Vol. 10, No. 2 (Apr., 1973), pp. 413-432.
  # (via https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative)
  a, = primals  # m x n
  a_dot, = tangents
  p = pinv(a, rcond=rcond, hermitian=hermitian)  # n x m
  if hermitian:
    # svd(..., hermitian=True) symmetrizes its input, and the JVP must match.
    a = _symmetrize(a)
    a_dot = _symmetrize(a_dot)

  # TODO(phawkins): this could be simplified in the Hermitian case if we
  # supported triangular matrix multiplication.
  m, n = a.shape[-2:]
  # Both branches evaluate the same expression; only the association order of
  # the matrix products differs depending on which dimension is larger.
  if m >= n:
    s = (p @ _H(p)) @ _H(a_dot)  # nxm
    t = (_H(a_dot) @ _H(p)) @ p  # nxm
    p_dot = -(p @ a_dot) @ p + s - (s @ a) @ p + t - (p @ a) @ t
  else:  # m < n
    s = p @ (_H(p) @ _H(a_dot))
    t = _H(a_dot) @ (_H(p) @ p)
    p_dot = -p @ (a_dot @ p) + s - s @ (a @ p) + t - p @ (a @ t)
  return p, p_dot


@_wraps(np.linalg.inv)
@jit
def inv(a: ArrayLike) -> Array:
  # Matrix inverse, computed by solving a @ x = I for each matrix in the batch.
  # Raises ValueError when the input does not have shape [..., n, n].
  check_arraylike("jnp.linalg.inv", a)
  arr = jnp.asarray(a)
  if arr.ndim < 2 or arr.shape[-1] != arr.shape[-2]:
    raise ValueError(
      f"Argument to inv must have shape [..., n, n], got {arr.shape}.")
  return solve(
    arr, lax.broadcast(jnp.eye(arr.shape[-1], dtype=arr.dtype), arr.shape[:-2]))


@_wraps(np.linalg.norm)
@partial(jit, static_argnames=('ord', 'axis', 'keepdims'))
def norm(x: ArrayLike, ord: Union[int, str, None] = None,
         axis: Union[None, Tuple[int, ...], int] = None,
         keepdims: bool = False) -> Array:
  # Vector norm when one axis is selected, matrix norm when two are.
  # `ord`, `axis` and `keepdims` are static, so the dispatch below happens at
  # trace time. Raises ValueError for an unsupported `ord` or for a number of
  # axes other than one or two.
  check_arraylike("jnp.linalg.norm", x)
  x, = promote_dtypes_inexact(jnp.asarray(x))
  x_shape = jnp.shape(x)
  ndim = len(x_shape)

  if axis is None:
    # NumPy has an undocumented behavior that admits arbitrary rank inputs if
    # `ord` is None: https://github.com/numpy/numpy/issues/14215
    if ord is None:
      return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)),
                                        keepdims=keepdims))
    axis = tuple(range(ndim))
  elif isinstance(axis, tuple):
    axis = tuple(canonicalize_axis(x, ndim) for x in axis)
  else:
    axis = (canonicalize_axis(axis, ndim),)

  num_axes = len(axis)
  if num_axes == 1:
    if ord is None or ord == 2:
      return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)),
                                        axis=axis, keepdims=keepdims))
    elif ord == jnp.inf:
      return reductions.amax(ufuncs.abs(x), axis=axis, keepdims=keepdims)
    elif ord == -jnp.inf:
      return reductions.amin(ufuncs.abs(x), axis=axis, keepdims=keepdims)
    elif ord == 0:
      return reductions.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype,
                            axis=axis, keepdims=keepdims)
    elif ord == 1:
      # Numpy has a special case for ord == 1 as an optimization. We don't
      # really need the optimization (XLA could do it for us), but the Numpy
      # code has slightly different type promotion semantics, so we need a
      # special case too.
      return reductions.sum(ufuncs.abs(x), axis=axis, keepdims=keepdims)
    elif isinstance(ord, str):
      msg = f"Invalid order '{ord}' for vector norm."
      # Leading space keeps the hint from running into the previous sentence.
      if ord == "inf":
        msg += " Use 'jax.numpy.inf' instead."
      elif ord == "-inf":
        msg += " Use '-jax.numpy.inf' instead."
      raise ValueError(msg)
    else:
      # General p-norm: (sum |x|^p)^(1/p).
      abs_x = ufuncs.abs(x)
      ord_arr = lax_internal._const(abs_x, ord)
      ord_inv = lax_internal._const(abs_x, 1. / ord_arr)
      out = reductions.sum(abs_x ** ord_arr, axis=axis, keepdims=keepdims)
      return ufuncs.power(out, ord_inv)

  elif num_axes == 2:
    row_axis, col_axis = cast(Tuple[int, ...], axis)
    if ord is None or ord in ('f', 'fro'):
      return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)),
                                        axis=axis, keepdims=keepdims))
    elif ord == 1:
      # Without keepdims the inner reduction removes `row_axis`, which shifts
      # `col_axis` down by one if it came after it. The same adjustment is
      # applied (with roles swapped where needed) in the three branches below.
      if not keepdims and col_axis > row_axis:
        col_axis -= 1
      return reductions.amax(reductions.sum(ufuncs.abs(x), axis=row_axis,
                                            keepdims=keepdims),
                             axis=col_axis, keepdims=keepdims)
    elif ord == -1:
      if not keepdims and col_axis > row_axis:
        col_axis -= 1
      return reductions.amin(reductions.sum(ufuncs.abs(x), axis=row_axis,
                                            keepdims=keepdims),
                             axis=col_axis, keepdims=keepdims)
    elif ord == jnp.inf:
      if not keepdims and row_axis > col_axis:
        row_axis -= 1
      return reductions.amax(reductions.sum(ufuncs.abs(x), axis=col_axis,
                                            keepdims=keepdims),
                             axis=row_axis, keepdims=keepdims)
    elif ord == -jnp.inf:
      if not keepdims and row_axis > col_axis:
        row_axis -= 1
      return reductions.amin(reductions.sum(ufuncs.abs(x), axis=col_axis,
                                            keepdims=keepdims),
                             axis=row_axis, keepdims=keepdims)
    elif ord in ('nuc', 2, -2):
      # Norms defined through singular values: move the matrix axes last,
      # then reduce the singular values with max (2), min (-2) or sum ('nuc').
      x = jnp.moveaxis(x, axis, (-2, -1))
      if ord == 2:
        reducer = reductions.amax
      elif ord == -2:
        reducer = reductions.amin
      else:
        # `sum` takes an extra dtype= argument, unlike `amax` and `amin`.
        reducer = reductions.sum  # type: ignore[assignment]
      y = reducer(svd(x, compute_uv=False), axis=-1)
      if keepdims:
        y = jnp.expand_dims(y, axis)
      return y
    else:
      raise ValueError(f"Invalid order '{ord}' for matrix norm.")
  else:
    raise ValueError(
        f"Invalid axis values ({axis}) for jnp.linalg.norm.")


@overload
def qr(a: ArrayLike, mode: Literal["r"]) -> Array: ...

@overload
def qr(a: ArrayLike, mode: str = "reduced") -> Union[Array, Tuple[Array, Array]]: ...

@_wraps(np.linalg.qr)
@partial(jit, static_argnames=('mode',))
def qr(a: ArrayLike, mode: str = "reduced") -> Union[Array, Tuple[Array, Array]]:
  # QR decomposition. Supported modes:
  #   "raw"                -> (transposed geqrf output, taus)
  #   "reduced" / "full"   -> (q, r) with full_matrices=False
  #   "complete"           -> (q, r) with full_matrices=True
  #   "r"                  -> r only
  # Any other mode raises ValueError.
  check_arraylike("jnp.linalg.qr", a)
  a, = promote_dtypes_inexact(jnp.asarray(a))
  if mode == "raw":
    a, taus = lax_linalg.geqrf(a)
    return a.mT, taus
  if mode in ("reduced", "r", "full"):
    full_matrices = False
  elif mode == "complete":
    full_matrices = True
  else:
    raise ValueError(f"Unsupported QR decomposition mode '{mode}'")
  q, r = lax_linalg.qr(a, full_matrices=full_matrices)
  if mode == "r":
    return r
  return q, r


@_wraps(np.linalg.solve)
@jit
def solve(a: ArrayLike, b: ArrayLike) -> Array:
  # Solve a @ x = b after promoting both operands to a common inexact dtype.
  check_arraylike("jnp.linalg.solve", a, b)
  a, b = promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
  return lax_linalg._solve(a, b)


def _lstsq(a: ArrayLike, b: ArrayLike, rcond: Optional[float], *,
           numpy_resid: bool = False) -> Tuple[Array, Array, Array, Array]:
  """SVD-based least-squares solver backing ``lstsq``.

  Args:
    a: two-dimensional coefficient matrix of shape (m, n).
    b: one- or two-dimensional right-hand side with leading dimension m.
    rcond: relative cutoff for small singular values. ``None`` selects
      ``eps * max(n, m)``; negative values are replaced by ``eps``.
    numpy_resid: if True, return an empty residual array when
      ``rank < n or m <= n``. That test needs a concrete ``rank``, which is why
      ``lstsq`` calls this path without ``jit``.

  Returns:
    The tuple ``(x, resid, rank, s)``: solution, squared residual norms,
    effective rank and singular values of ``a``.

  Raises:
    ValueError: if the leading dimensions of ``a`` and ``b`` differ.
    TypeError: if ``a`` is not 2-D or ``b`` is not 1-D or 2-D.
  """
  # TODO: add lstsq to lax_linalg and implement this function via those wrappers.
  # TODO: add custom jvp rule for more robust lstsq differentiation
  a, b = promote_dtypes_inexact(a, b)
  if a.shape[0] != b.shape[0]:
    raise ValueError("Leading dimensions of input arrays must match")
  b_orig_ndim = b.ndim
  if b_orig_ndim == 1:
    # Treat a vector right-hand side as a single column; undone before return.
    b = b[:, None]
  if a.ndim != 2:
    raise TypeError(
      f"{a.ndim}-dimensional array given. Array must be two-dimensional")
  if b.ndim != 2:
    raise TypeError(
      f"{b.ndim}-dimensional array given. Array must be one or two-dimensional")
  m, n = a.shape
  dtype = a.dtype
  if a.size == 0:
    s = jnp.empty(0, dtype=a.dtype)
    rank = jnp.array(0, dtype=int)
    x = jnp.empty((n, *b.shape[1:]), dtype=a.dtype)
  else:
    if rcond is None:
      rcond = jnp.finfo(dtype).eps * max(n, m)
    else:
      rcond = jnp.where(rcond < 0, jnp.finfo(dtype).eps, rcond)
    u, s, vt = svd(a, full_matrices=False)
    # Singular values below rcond * largest are treated as zero.
    mask = s >= jnp.array(rcond, dtype=s.dtype) * s[0]
    rank = mask.sum()
    # Substitute 1 for masked-out values so the reciprocal never divides by
    # zero; those entries are zeroed again by the second `where`.
    safe_s = jnp.where(mask, s, 1).astype(a.dtype)
    s_inv = jnp.where(mask, 1 / safe_s, 0)[:, jnp.newaxis]
    uTb = jnp.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST)
    x = jnp.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST)
  # Numpy returns empty residuals in some cases. To allow compilation, we
  # default to returning full residuals in all cases.
  if numpy_resid and (rank < n or m <= n):
    resid = jnp.asarray([])
  else:
    b_estimate = jnp.matmul(a, x, precision=lax.Precision.HIGHEST)
    resid = norm(b - b_estimate, axis=0) ** 2
  if b_orig_ndim == 1:
    x = x.ravel()
  return x, resid, rank, s

# Jitted variant used by `lstsq` when NumPy-style empty residuals are not needed.
_jit_lstsq = jit(partial(_lstsq, numpy_resid=False))


@_wraps(np.linalg.lstsq, lax_description=textwrap.dedent("""\
    It has two important differences:

    1. In `numpy.linalg.lstsq`, the default `rcond` is `-1`, and warns that in the future
       the default will be `None`. Here, the default rcond is `None`.
    2. In `np.linalg.lstsq` the returned residuals are empty for low-rank or
       under-determined (``M <= N``) solutions. Here, the residuals are returned in all
       cases, to make the function compatible with jit. The non-jit compatible numpy
       behavior can be recovered by passing numpy_resid=True.

    The lstsq function does not currently have a custom JVP rule, so the gradient is
    poorly behaved for some inputs, particularly for low-rank `a`.
    """))
def lstsq(a: ArrayLike, b: ArrayLike, rcond: Optional[float] = None, *,
          numpy_resid: bool = False) -> Tuple[Array, Array, Array, Array]:
  # Public entry point: use the un-jitted implementation only when NumPy-style
  # empty residuals are requested, otherwise the jitted one.
  check_arraylike("jnp.linalg.lstsq", a, b)
  if numpy_resid:
    return _lstsq(a, b, rcond, numpy_resid=True)
  return _jit_lstsq(a, b, rcond)