# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A JIT-compatible library for QDWH-based singular value decomposition.

QDWH is short for QR-based dynamically weighted Halley iteration. The Halley
iteration implemented through QR decompositions is numerically stable and does
not require solving a linear system involving the iteration matrix or
computing its inverse. This is desirable for multicore and heterogeneous
computing systems.

References:
Nakatsukasa, Yuji, and Nicholas J. Higham.
"Stable and efficient spectral divide and conquer algorithms for the symmetric
eigenvalue decomposition and the SVD." SIAM Journal on Scientific Computing 35,
no. 3 (2013): A1325-A1349.
https://epubs.siam.org/doi/abs/10.1137/120876605

Nakatsukasa, Yuji, Zhaojun Bai, and François Gygi.
"Optimizing Halley's iteration for computing the matrix polar decomposition."
SIAM Journal on Matrix Analysis and Applications 31, no. 5 (2010): 2700-2720.
https://epubs.siam.org/doi/abs/10.1137/090774999
"""

import functools
from typing import Any, Sequence, Union

import jax
import jax.numpy as jnp
from jax import lax
from jax._src import core


@functools.partial(jax.jit, static_argnums=(1, 2))
def _zero_svd(a: Any,
              full_matrices: bool,
              compute_uv: bool = True) -> Union[Any, Sequence[Any]]:
  """SVD of a matrix of all zeros.

  Every singular value is zero, so identity (or rectangular identity) factors
  form a valid decomposition.

  Args:
    a: A matrix of shape `m x n` whose entries are all zero. Only its shape and
      dtype are read.
    full_matrices: If True, `u` and `vh` have shapes `m x m` and `n x n`;
      otherwise `m x k` and `k x n`, where `k = min(m, n)`.
    compute_uv: Whether to return `u` and `vh` in addition to `s`.

  Returns:
    `(u, s, vh)` if `compute_uv` is True, otherwise only `s`. `s` has length `k`
    and the real dtype corresponding to `a.dtype`.
  """
  m, n = a.shape
  k = min(m, n)
  s = jnp.zeros(shape=(k,), dtype=a.real.dtype)
  if not compute_uv:
    return s

  if full_matrices:
    u = jnp.eye(m, m, dtype=a.dtype)
    vh = jnp.eye(n, n, dtype=a.dtype)
  else:
    u = jnp.eye(m, k, dtype=a.dtype)
    vh = jnp.eye(k, n, dtype=a.dtype)
  return (u, s, vh)


@functools.partial(jax.jit, static_argnums=(1, 2, 3))
def _svd_tall_and_square_input(
    a: Any, hermitian: bool, compute_uv: bool, max_iterations: int
) -> Union[Any, Sequence[Any]]:
  """Singular value decomposition for an `m x n` matrix with `m >= n`.

  The polar decomposition `a = u @ h` is computed with QDWH. The Hermitian
  factor `h` is then diagonalised with `eigh`, whose eigenvalues are the
  singular values of `a`.

  Args:
    a: A matrix of shape `m x n` with `m >= n`.
    hermitian: True if `a` itself is Hermitian. This is forwarded to
      `lax.linalg.qdwh`, so it must describe the matrix actually passed here.
    compute_uv: Whether to compute also `u` and `v` in addition to `s`.
    max_iterations: The predefined maximum number of iterations of QDWH.

  Returns:
    A 3-tuple (`u`, `s`, `v`), where:
      - `u` is a unitary matrix of shape `m x n`;
      - `s` is a vector of length `n` containing the singular values in
        descending order;
      - `v` is a unitary matrix of shape `n x n`;
      - `a = (u * s) @ v.T.conj()`.
    For `compute_uv=False`, only `s` is returned.
  """
  u, h, _, _ = lax.linalg.qdwh(a, is_hermitian=hermitian,
                               max_iterations=max_iterations)

  # TODO: Use `eigvals_only=True` if `compute_uv=False`.
  v, s = lax.linalg.eigh(h)

  # `eigh` yields eigenvalues in ascending order, so `s[0]` below is the
  # smallest one. Flip to report singular values in descending order.
  s_out = jnp.flip(s)

  if not compute_uv:
    return s_out

  # Reorder the eigenvectors to match the flipped eigenvalues.
  v_out = jnp.fliplr(v)

  u_out = u @ v_out

  # Makes correction if computed `u` from qdwh is not unitary.
  # Section 5.5 of Nakatsukasa, Yuji, and Nicholas J. Higham. "Stable and
  # efficient spectral divide and conquer algorithms for the symmetric
  # eigenvalue decomposition and the SVD." SIAM Journal on Scientific Computing
  # 35, no. 3 (2013): A1325-A1349.
  def correct_rank_deficiency(u_out):
    # Re-orthonormalise via QR. The sign fix-up makes diag(r) non-negative, so
    # the columns keep their original orientation.
    u_out, r = lax.linalg.qr(u_out, full_matrices=False)
    u_out = u_out @ jnp.diag(lax.sign(jnp.diag(r)))
    return u_out

  eps = float(jnp.finfo(a.dtype).eps)
  # Only pay for the extra QR when the smallest singular value is negligible
  # relative to the largest, i.e. `a` is numerically rank deficient.
  u_out = lax.cond(s[0] < a.shape[1] * eps * s_out[0],
                   correct_rank_deficiency,
                   lambda u_out: u_out,
                   u_out)

  return (u_out, s_out, v_out)


@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4))
def _qdwh_svd(a: Any,
              full_matrices: bool,
              compute_uv: bool = True,
              hermitian: bool = False,
              max_iterations: int = 10) -> Union[Any, Sequence[Any]]:
  """Singular value decomposition of a (generally nonzero) matrix.

  Args:
    a: A matrix of shape `m x n`.
    full_matrices: If True, `u` and `vh` have the shapes `m x m` and `n x n`,
      respectively. If False, the shapes are `m x k` and `k x n`,
      respectively, where `k = min(m, n)`.
    compute_uv: Whether to compute also `u` and `v` in addition to `s`.
    hermitian: True if `a` is Hermitian.
    max_iterations: The predefined maximum number of iterations of QDWH.

  Returns:
    A 3-tuple (`u`, `s`, `vh`), where `u` and `vh` are unitary matrices, `s` is
    a vector of length `k` containing the singular values in non-increasing
    order, and `k = min(m, n)`. The shapes of `u` and `vh` depend on the value
    of `full_matrices`. For `compute_uv=False`, only `s` is returned.
  """
  m, n = a.shape

  # The core routine needs a tall-or-square matrix. A wide input is handled
  # through its conjugate transpose, and the factors are swapped back at the
  # end.
  is_flip = m < n
  if is_flip:
    a = a.T.conj()
    m, n = a.shape

  reduce_to_square = False
  u_out_null = None
  if full_matrices and m > n:
    # A full QR gives the square factor R to decompose. Its trailing `m - n`
    # columns of Q complete `u` to an `m x m` unitary matrix. For square input
    # this QR would be pure overhead (Q is square, the complement is empty),
    # so it is skipped there.
    q_full, a_full = lax.linalg.qr(a, full_matrices=True)
    q = q_full[:, :n]
    u_out_null = q_full[:, n:]
    a = a_full[:n, :]
    reduce_to_square = True
  elif m > 1.15 * n:
    # The constant `1.15` comes from Yuji Nakatsukasa's implementation
    # https://www.mathworks.com/matlabcentral/fileexchange/36830-symmetric-eigenvalue-decomposition-and-the-svd?s_tid=FX_rc3_behav
    q, a = lax.linalg.qr(a, full_matrices=False)
    reduce_to_square = True

  # `hermitian` describes the caller's matrix. After a QR reduction the matrix
  # handed to QDWH is the triangular factor R, which is not Hermitian in
  # general, so the hint must not be forwarded in that case.
  qdwh_hermitian = hermitian and not reduce_to_square

  with jax.default_matmul_precision('float32'):
    result = _svd_tall_and_square_input(
        a, qdwh_hermitian, compute_uv, max_iterations)

  if not compute_uv:
    return result

  u_out, s_out, v_out = result

  if reduce_to_square:
    u_out = q @ u_out

  if u_out_null is not None:
    u_out = jnp.hstack((u_out, u_out_null))

  if is_flip:
    return (v_out, s_out, u_out.T.conj())

  return (u_out, s_out, v_out.T.conj())


@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4))
def svd(a: Any,
        full_matrices: bool,
        compute_uv: bool = True,
        hermitian: bool = False,
        max_iterations: int = 10) -> Union[Any, Sequence[Any]]:
  """Singular value decomposition.

  Args:
    a: A matrix of shape `m x n`.
    full_matrices: If True, `u` and `vh` have the shapes `m x m` and `n x n`,
      respectively. If False, the shapes are `m x k` and `k x n`,
      respectively, where `k = min(m, n)`.
    compute_uv: Whether to compute also `u` and `v` in addition to `s`.
    hermitian: True if `a` is Hermitian.
    max_iterations: The predefined maximum number of iterations of QDWH.

  Returns:
    A 3-tuple (`u`, `s`, `vh`), where `u` and `vh` are unitary matrices, `s` is
    a vector of length `k` containing the singular values in non-increasing
    order, and `k = min(m, n)`. The shapes of `u` and `vh` depend on the value
    of `full_matrices`. For `compute_uv=False`, only `s` is returned.

  Raises:
    An error from `core.concrete_or_error` if `full_matrices`, `compute_uv`,
    `hermitian` or `max_iterations` is not a concrete (static) value.
  """
  full_matrices = core.concrete_or_error(
      bool, full_matrices, 'The `full_matrices` argument must be statically '
      'specified to use `svd` within JAX transformations.')

  compute_uv = core.concrete_or_error(
      bool, compute_uv, 'The `compute_uv` argument must be statically '
      'specified to use `svd` within JAX transformations.')

  hermitian = core.concrete_or_error(
      bool, hermitian, 'The `hermitian` argument must be statically '
      'specified to use `svd` within JAX transformations.')

  max_iterations = core.concrete_or_error(
      int, max_iterations, 'The `max_iterations` argument must be statically '
      'specified to use `svd` within JAX transformations.')

  # The QDWH algorithm fails on a zero matrix `A` and produces all NaNs. This
  # can be seen from the dynamically weighted Halley (DWH) iteration
  #   X_{k+1} = X_k (a_k I + b_k X_k^H X_k) (I + c_k X_k^H X_k)^{-1},
  #   X_0 = A / alpha,  alpha = ||A||_2,
  # where (a_k, b_k, c_k) are weighting parameters and X_k is the k-th
  # iterate. For A = 0 we have alpha = 0, so X_0 = 0 / 0 is NaN. The zero case
  # is therefore answered directly.
  return lax.cond(
      jnp.all(a == 0),
      functools.partial(_zero_svd,
                        full_matrices=full_matrices,
                        compute_uv=compute_uv),
      functools.partial(_qdwh_svd,
                        full_matrices=full_matrices,
                        compute_uv=compute_uv,
                        hermitian=hermitian,
                        max_iterations=max_iterations),
      a)