2195 lines
82 KiB
Python
2195 lines
82 KiB
Python
# Copyright 2018 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import inspect
|
|
import functools
|
|
from functools import partial
|
|
import math
|
|
from typing import cast, Any, Callable, List, Literal, Optional, Tuple, TypeVar, Union, overload
|
|
import warnings
|
|
|
|
import numpy as np
|
|
|
|
import jax
|
|
from jax import lax
|
|
|
|
from jax._src import ad_util
|
|
from jax._src import api
|
|
from jax._src import dispatch
|
|
from jax._src import dtypes
|
|
from jax._src.core import (
|
|
Primitive, ShapedArray, raise_to_shaped, is_constant_dim, is_constant_shape)
|
|
from jax._src.interpreters import ad
|
|
from jax._src.interpreters import batching
|
|
from jax._src.interpreters import mlir
|
|
from jax._src.lax import control_flow
|
|
from jax._src.lax import eigh as lax_eigh
|
|
from jax._src.lax import lax as lax_internal
|
|
from jax._src.lax import svd as lax_svd
|
|
from jax._src.lax.lax import (
|
|
standard_primitive, standard_unop, naryop_dtype_rule, _float, _complex,
|
|
_input_dtype)
|
|
from jax._src.lib import gpu_linalg
|
|
from jax._src.lib import gpu_solver
|
|
from jax._src.lib import gpu_sparse
|
|
from jax._src.lib import lapack
|
|
from jax._src.lib import version as jaxlib_version
|
|
from jax._src.lib import xla_client
|
|
from jax._src.lib.mlir import ir
|
|
from jax._src.lib.mlir.dialects import chlo
|
|
from jax._src.lib.mlir.dialects import hlo
|
|
from jax._src.numpy import lax_numpy as jnp
|
|
from jax._src.numpy import reductions
|
|
from jax._src.numpy import ufuncs
|
|
from jax._src.numpy.vectorize import vectorize
|
|
from jax._src.typing import Array, ArrayLike
|
|
|
|
xops = xla_client.ops
|
|
|
|
TFun = TypeVar('TFun', bound=Callable[..., Any])
|
|
|
|
# traceables
|
|
|
|
# TODO(phawkins): remove backward compatibility shim after 2022/08/11.
|
|
def _warn_on_positional_kwargs(f: TFun) -> TFun:
|
|
"""Decorator used for backward compatibility of keyword-only arguments.
|
|
|
|
Some functions were changed to mark their keyword arguments as keyword-only.
|
|
This decorator allows existing code to keep working temporarily, while issuing
|
|
a warning if a now keyword-only parameter is passed positionally."""
|
|
sig = inspect.signature(f)
|
|
pos_names = [name for name, p in sig.parameters.items()
|
|
if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD]
|
|
kwarg_names = [name for name, p in sig.parameters.items()
|
|
if p.kind == inspect.Parameter.KEYWORD_ONLY]
|
|
|
|
# This decorator assumes that all arguments to `f` are either
|
|
# positional-or-keyword or keyword-only.
|
|
assert len(pos_names) + len(kwarg_names) == len(sig.parameters)
|
|
|
|
@functools.wraps(f)
|
|
def wrapped(*args, **kwargs):
|
|
if len(args) < len(pos_names):
|
|
a = pos_names[len(args)]
|
|
raise TypeError(f"{f.__name__} missing required positional argument: {a}")
|
|
|
|
pos_args = args[:len(pos_names)]
|
|
extra_kwargs = args[len(pos_names):]
|
|
|
|
if len(extra_kwargs) > len(kwarg_names):
|
|
raise TypeError(f"{f.__name__} takes at most {len(sig.parameters)} "
|
|
f" arguments but {len(args)} were given.")
|
|
|
|
for name, value in zip(kwarg_names, extra_kwargs):
|
|
if name in kwargs:
|
|
raise TypeError(f"{f.__name__} got multiple values for argument: "
|
|
f"{name}")
|
|
|
|
warnings.warn(f"Argument {name} to {f.__name__} is now a keyword-only "
|
|
"argument. Support for passing it positionally will be "
|
|
"removed in an upcoming JAX release.",
|
|
DeprecationWarning)
|
|
kwargs[name] = value
|
|
return f(*pos_args, **kwargs)
|
|
|
|
return cast(TFun, wrapped)
|
|
|
|
@_warn_on_positional_kwargs
|
|
def cholesky(x: Array, *, symmetrize_input: bool = True) -> Array:
|
|
"""Cholesky decomposition.
|
|
|
|
Computes the Cholesky decomposition
|
|
|
|
.. math::
|
|
A = L . L^H
|
|
|
|
of square matrices, :math:`A`, such that :math:`L`
|
|
is lower triangular. The matrices of :math:`A` must be positive-definite and
|
|
either Hermitian, if complex, or symmetric, if real.
|
|
|
|
Args:
|
|
x: A batch of square Hermitian (symmetric if real) positive-definite
|
|
matrices with shape ``[..., n, n]``.
|
|
symmetrize_input: If ``True``, the matrix is symmetrized before Cholesky
|
|
decomposition by computing :math:`\\frac{1}{2}(x + x^H)`. If ``False``,
|
|
only the lower triangle of ``x`` is used; the upper triangle is ignored
|
|
and not accessed.
|
|
|
|
Returns:
|
|
The Cholesky decomposition as a matrix with the same dtype as ``x`` and
|
|
shape ``[..., n, n]``. If Cholesky decomposition fails, returns a matrix
|
|
full of NaNs. The behavior on failure may change in the future.
|
|
"""
|
|
if symmetrize_input:
|
|
x = symmetrize(x)
|
|
return jnp.tril(cholesky_p.bind(x))
|
|
|
|
@_warn_on_positional_kwargs
|
|
def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True,
|
|
compute_right_eigenvectors: bool = True) -> List[Array]:
|
|
"""Eigendecomposition of a general matrix.
|
|
|
|
Nonsymmetric eigendecomposition is at present only implemented on CPU.
|
|
"""
|
|
return eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors,
|
|
compute_right_eigenvectors=compute_right_eigenvectors)
|
|
|
|
@_warn_on_positional_kwargs
|
|
def eigh(x: Array, *, lower: bool = True, symmetrize_input: bool = True,
|
|
sort_eigenvalues: bool = True) -> Tuple[Array, Array]:
|
|
r"""Eigendecomposition of a Hermitian matrix.
|
|
|
|
Computes the eigenvectors and eigenvalues of a complex Hermitian or real
|
|
symmetric square matrix.
|
|
|
|
Args:
|
|
x: A batch of square complex Hermitian or real symmetric matrices with shape
|
|
``[..., n, n]``.
|
|
lower: If ``symmetrize_input`` is ``False``, describes which triangle of the
|
|
input matrix to use. If ``symmetrize_input`` is ``False``, only the
|
|
triangle given by ``lower`` is accessed; the other triangle is ignored and
|
|
not accessed.
|
|
symmetrize_input: If ``True``, the matrix is symmetrized before the
|
|
eigendecomposition by computing :math:`\frac{1}{2}(x + x^H)`.
|
|
sort_eigenvalues: If ``True``, the eigenvalues will be sorted in ascending
|
|
order. If ``False`` the eigenvalues are returned in an
|
|
implementation-defined order.
|
|
|
|
Returns:
|
|
A tuple ``(v, w)``.
|
|
|
|
``v`` is an array with the same dtype as ``x`` such that ``v[..., :, i]`` is
|
|
the normalized eigenvector corresponding to eigenvalue ``w[..., i]``.
|
|
|
|
``w`` is an array with the same dtype as ``x`` (or its real counterpart if
|
|
complex) with shape ``[..., n]`` containing the eigenvalues of ``x`` in
|
|
ascending order(each repeated according to its multiplicity).
|
|
"""
|
|
if symmetrize_input:
|
|
x = symmetrize(x)
|
|
v, w = eigh_p.bind(x, lower=lower, sort_eigenvalues=sort_eigenvalues)
|
|
return v, w
|
|
|
|
|
|
def lu_pivots_to_permutation(pivots: ArrayLike, permutation_size: int) -> Array:
|
|
"""Converts the pivots (row swaps) returned by LU to a permutation.
|
|
|
|
We build a permutation rather than applying `pivots` directly to the rows
|
|
of a matrix because lax loops aren't differentiable.
|
|
|
|
Args:
|
|
pivots: an int32 array of shape (..., k) of row swaps to perform
|
|
permutation_size: the size of the output permutation. Has to be >= k.
|
|
|
|
Returns:
|
|
An int32 array of shape (..., permutation_size).
|
|
"""
|
|
permutation = lu_pivots_to_permutation_p.bind(
|
|
pivots, permutation_size=int(permutation_size))
|
|
return permutation
|
|
|
|
|
|
def lu(x: ArrayLike) -> Tuple[Array, Array, Array]:
|
|
"""LU decomposition with partial pivoting.
|
|
|
|
Computes the matrix decomposition:
|
|
|
|
.. math::
|
|
P.A = L.U
|
|
|
|
where :math:`P` is a permutation of the rows of :math:`A`, :math:`L` is a
|
|
lower-triangular matrix with unit-diagonal elements, and :math:`U` is an
|
|
upper-triangular matrix.
|
|
|
|
Args:
|
|
x: A batch of matrices with shape ``[..., m, n]``.
|
|
|
|
Returns:
|
|
A tuple ``(lu, pivots, permutation)``.
|
|
|
|
``lu`` is a batch of matrices with the same shape and dtype as ``x``
|
|
containing the :math:`L` matrix in its lower triangle and the :math:`U`
|
|
matrix in its upper triangle. The (unit) diagonal elements of :math:`L` are
|
|
not represented explicitly.
|
|
|
|
``pivots`` is an int32 array with shape ``[..., min(m, n)]`` representing a
|
|
sequence of row swaps that should be performed on :math:`A`.
|
|
|
|
``permutation`` is an alternative representation of the sequence of row
|
|
swaps as a permutation, represented as an int32 array with shape
|
|
``[..., m]``.
|
|
"""
|
|
lu, pivots, permutation = lu_p.bind(x)
|
|
return lu, pivots, permutation
|
|
|
|
@_warn_on_positional_kwargs
|
|
def qr(x: ArrayLike, *, full_matrices: bool = True) -> Tuple[Array, Array]:
|
|
"""QR decomposition.
|
|
|
|
Computes the QR decomposition
|
|
|
|
.. math::
|
|
A = Q . R
|
|
|
|
of matrices :math:`A`, such that :math:`Q` is a unitary (orthogonal) matrix,
|
|
and :math:`R` is an upper-triangular matrix.
|
|
|
|
Args:
|
|
x: A batch of matrices with shape ``[..., m, n]``.
|
|
full_matrices: Determines if full or reduced matrices are returned; see
|
|
below.
|
|
|
|
Returns:
|
|
A pair of arrays ``(q, r)``.
|
|
|
|
Array ``q`` is a unitary (orthogonal) matrix,
|
|
with shape ``[..., m, m]`` if ``full_matrices=True``, or
|
|
``[..., m, min(m, n)]`` if ``full_matrices=False``.
|
|
|
|
Array ``r`` is an upper-triangular matrix with shape ``[..., m, n]`` if
|
|
``full_matrices=True``, or ``[..., min(m, n), n]`` if
|
|
``full_matrices=False``.
|
|
"""
|
|
q, r = qr_p.bind(x, full_matrices=full_matrices)
|
|
return q, r
|
|
|
|
@overload
|
|
def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: Literal[True]) -> Tuple[Array, Array, Array]: ...
|
|
|
|
@overload
|
|
def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: Literal[False]) -> Array: ...
|
|
|
|
@overload
|
|
def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: bool = True) -> Union[Array, Tuple[Array, Array, Array]]: ...
|
|
|
|
# TODO: Add `max_qdwh_iterations` to the function signature for TPU SVD.
|
|
@_warn_on_positional_kwargs
|
|
def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: bool = True) -> Union[Array, Tuple[Array, Array, Array]]:
|
|
"""Singular value decomposition.
|
|
|
|
Returns the singular values if compute_uv is False, otherwise returns a triple
|
|
containing the left singular vectors, the singular values and the adjoint of
|
|
the right singular vectors.
|
|
"""
|
|
result = svd_p.bind(x, full_matrices=full_matrices, compute_uv=compute_uv)
|
|
if compute_uv:
|
|
s, u, v = result
|
|
return u, s, v
|
|
else:
|
|
s, = result
|
|
return s
|
|
|
|
@_warn_on_positional_kwargs
|
|
def triangular_solve(a: ArrayLike, b: ArrayLike, *,
|
|
left_side: bool = False, lower: bool = False,
|
|
transpose_a: bool = False, conjugate_a: bool = False,
|
|
unit_diagonal: bool = False) -> Array:
|
|
r"""Triangular solve.
|
|
|
|
Solves either the matrix equation
|
|
|
|
.. math::
|
|
\mathit{op}(A) . X = B
|
|
|
|
if ``left_side`` is ``True`` or
|
|
|
|
.. math::
|
|
X . \mathit{op}(A) = B
|
|
|
|
if ``left_side`` is ``False``.
|
|
|
|
``A`` must be a lower or upper triangular square matrix, and where
|
|
:math:`\mathit{op}(A)` may either transpose :math:`A` if ``transpose_a``
|
|
is ``True`` and/or take its complex conjugate if ``conjugate_a`` is ``True``.
|
|
|
|
Args:
|
|
a: A batch of matrices with shape ``[..., m, m]``.
|
|
b: A batch of matrices with shape ``[..., m, n]`` if ``left_side`` is
|
|
``True`` or shape ``[..., n, m]`` otherwise.
|
|
left_side: describes which of the two matrix equations to solve; see above.
|
|
lower: describes which triangle of ``a`` should be used. The other triangle
|
|
is ignored.
|
|
transpose_a: if ``True``, the value of ``a`` is transposed.
|
|
conjugate_a: if ``True``, the complex conjugate of ``a`` is used in the
|
|
solve. Has no effect if ``a`` is real.
|
|
unit_diagonal: if ``True``, the diagonal of ``a`` is assumed to be unit
|
|
(all 1s) and not accessed.
|
|
|
|
Returns:
|
|
A batch of matrices the same shape and dtype as ``b``.
|
|
"""
|
|
conjugate_a = conjugate_a and jnp.issubdtype(lax.dtype(a), jnp.complexfloating)
|
|
singleton = jnp.ndim(b) == jnp.ndim(a) - 1
|
|
if singleton:
|
|
b = jnp.expand_dims(b, -1 if left_side else -2)
|
|
out = triangular_solve_p.bind(
|
|
a, b, left_side=left_side, lower=lower, transpose_a=transpose_a,
|
|
conjugate_a=conjugate_a, unit_diagonal=unit_diagonal)
|
|
if singleton:
|
|
out = out[..., 0] if left_side else out[..., 0, :]
|
|
return out
|
|
|
|
|
|
# utilities
|
|
@partial(vectorize, signature='(n,m),(m)->(n)')
|
|
def _matvec_multiply(a: Array, b: Array) -> Array:
|
|
return lax.dot(a, b, precision=lax.Precision.HIGHEST)
|
|
|
|
def _check_solve_shapes(a: Array, b: Array):
|
|
if not (a.ndim >= 2 and b.ndim in [a.ndim, a.ndim - 1] and
|
|
a.shape[-1] == a.shape[-2] == b.shape[a.ndim - 2]):
|
|
raise ValueError(
|
|
"The arguments to solve must have shapes a=[..., m, m] and "
|
|
f"b=[..., m, k] or b=[..., m]; got a={a.shape} and b={b.shape}")
|
|
|
|
def _solve(a: Array, b: Array) -> Array:
|
|
_check_solve_shapes(a, b)
|
|
|
|
# Broadcast leading dimensions of b to the shape of a, as is required by
|
|
# custom_linear_solve.
|
|
out_shape = tuple(d_a if d_b == 1 else d_b
|
|
for d_a, d_b in zip(a.shape[:-1] + (1,), b.shape))
|
|
b = jnp.broadcast_to(b, out_shape)
|
|
|
|
# With custom_linear_solve, we can reuse the same factorization when
|
|
# computing sensitivities. This is considerably faster.
|
|
lu_, _, permutation = lu(lax.stop_gradient(a))
|
|
custom_solve = partial(
|
|
lax.custom_linear_solve,
|
|
lambda x: _matvec_multiply(a, x),
|
|
solve=lambda _, x: lu_solve(lu_, permutation, x, trans=0),
|
|
transpose_solve=lambda _, x: lu_solve(lu_, permutation, x, trans=1))
|
|
if a.ndim == b.ndim + 1:
|
|
# b.shape == [..., m]
|
|
return custom_solve(b)
|
|
else:
|
|
# b.shape == [..., m, k]
|
|
return api.vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)
|
|
|
|
def _T(x: Array) -> Array: return jnp.swapaxes(x, -1, -2)
|
|
def _H(x: Array) -> Array: return ufuncs.conj(_T(x))
|
|
def symmetrize(x: Array) -> Array: return (x + _H(x)) / 2
|
|
|
|
# primitives
|
|
|
|
_cpu_lapack_types = {np.dtype(np.float32), np.dtype(np.float64),
|
|
np.dtype(np.complex64), np.dtype(np.complex128)}
|
|
|
|
# Cholesky decomposition
|
|
|
|
def _cholesky_jvp_rule(primals, tangents):
|
|
x, = primals
|
|
sigma_dot, = tangents
|
|
L = jnp.tril(cholesky_p.bind(x))
|
|
|
|
# Forward-mode rule from https://arxiv.org/pdf/1602.07527.pdf
|
|
def phi(X):
|
|
l = jnp.tril(X)
|
|
return l / lax.expand_dims(
|
|
lax_internal._const(X, 1) + jnp.eye(X.shape[-1], dtype=X.dtype),
|
|
range(l.ndim - 2))
|
|
|
|
tmp = triangular_solve(L, sigma_dot, left_side=False, transpose_a=True,
|
|
conjugate_a=True, lower=True)
|
|
L_dot = lax.batch_matmul(L, phi(triangular_solve(
|
|
L, tmp, left_side=True, transpose_a=False, lower=True)),
|
|
precision=lax.Precision.HIGHEST)
|
|
return L, L_dot
|
|
|
|
def _cholesky_batching_rule(batched_args, batch_dims):
|
|
x, = batched_args
|
|
bd, = batch_dims
|
|
x = batching.moveaxis(x, bd, 0)
|
|
return cholesky(x), 0
|
|
|
|
cholesky_p = standard_unop(_float | _complex, 'cholesky')
|
|
ad.primitive_jvps[cholesky_p] = _cholesky_jvp_rule
|
|
batching.primitive_batchers[cholesky_p] = _cholesky_batching_rule
|
|
|
|
def _cholesky_lowering(ctx, x):
|
|
return hlo.CholeskyOp(x, lower=ir.BoolAttr.get(True)).results
|
|
|
|
mlir.register_lowering(cholesky_p, _cholesky_lowering)
|
|
|
|
def _cholesky_cpu_gpu_lowering(potrf_impl, ctx, operand):
|
|
if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
|
|
raise NotImplementedError("Shape polymorphism for custom call is not implemented (cholesky); b/261671778")
|
|
operand_aval, = ctx.avals_in
|
|
out_aval, = ctx.avals_out
|
|
batch_dims = operand_aval.shape[:-2]
|
|
result, info = potrf_impl(operand_aval.dtype, operand, lower=True)
|
|
ok = mlir.compare_hlo(
|
|
info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))),
|
|
"EQ", "SIGNED")
|
|
select_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
|
|
return [_broadcasting_select_hlo(
|
|
ctx,
|
|
mlir.broadcast_in_dim(ctx, ok,
|
|
select_aval,
|
|
broadcast_dimensions=range(len(batch_dims))),
|
|
select_aval,
|
|
result, out_aval, _nan_like_hlo(ctx, out_aval), out_aval)]
|
|
|
|
mlir.register_lowering(
|
|
cholesky_p,
|
|
partial(_cholesky_cpu_gpu_lowering, lapack.potrf_hlo),
|
|
platform='cpu')
|
|
|
|
# Asymmetric eigendecomposition
|
|
|
|
def eig_impl(operand, *, compute_left_eigenvectors, compute_right_eigenvectors):
|
|
return dispatch.apply_primitive(
|
|
eig_p,
|
|
operand,
|
|
compute_left_eigenvectors=compute_left_eigenvectors,
|
|
compute_right_eigenvectors=compute_right_eigenvectors,
|
|
)
|
|
|
|
def eig_lower(*args, **kw):
|
|
raise NotImplementedError(
|
|
"Nonsymmetric eigendecomposition is only implemented on the CPU backend. "
|
|
"If your matrix is symmetric or Hermitian, you should use eigh instead.")
|
|
|
|
def eig_abstract_eval(operand, *, compute_left_eigenvectors,
|
|
compute_right_eigenvectors):
|
|
if isinstance(operand, ShapedArray):
|
|
if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
|
|
raise ValueError("Argument to nonsymmetric eigendecomposition must have "
|
|
"shape [..., n, n], got shape {}".format(operand.shape))
|
|
|
|
batch_dims = operand.shape[:-2]
|
|
n = operand.shape[-1]
|
|
dtype = np.complex64 if dtypes.finfo(operand.dtype).bits == 32 else np.complex128
|
|
dtype = dtypes.canonicalize_dtype(dtype)
|
|
vl = vr = operand.update(shape=batch_dims + (n, n), dtype=dtype)
|
|
w = operand.update(shape=batch_dims + (n,), dtype=dtype)
|
|
else:
|
|
raise NotImplementedError
|
|
|
|
output = [w]
|
|
if compute_left_eigenvectors:
|
|
output.append(vl)
|
|
if compute_right_eigenvectors:
|
|
output.append(vr)
|
|
|
|
return tuple(output)
|
|
|
|
def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors,
|
|
compute_right_eigenvectors):
|
|
if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
|
|
raise NotImplementedError("Shape polymorphism for custom call is not implemented (eig); b/261671778")
|
|
operand_aval, = ctx.avals_in
|
|
out_aval = ctx.avals_out[0]
|
|
batch_dims = operand_aval.shape[:-2]
|
|
|
|
w, vl, vr, info = lapack.geev_hlo(operand_aval.dtype, operand,
|
|
jobvl=compute_left_eigenvectors,
|
|
jobvr=compute_right_eigenvectors)
|
|
|
|
ok = mlir.compare_hlo(
|
|
info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))),
|
|
"EQ", "SIGNED")
|
|
select_w_aval = ShapedArray(batch_dims + (1,), np.dtype(np.bool_))
|
|
w = _broadcasting_select_hlo(
|
|
ctx,
|
|
mlir.broadcast_in_dim(ctx, ok, select_w_aval,
|
|
broadcast_dimensions=range(len(batch_dims))),
|
|
select_w_aval,
|
|
w, out_aval, _nan_like_hlo(ctx, out_aval), out_aval)
|
|
output = [w]
|
|
|
|
if compute_left_eigenvectors:
|
|
aval = ctx.avals_out[len(output)]
|
|
select_vl_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
|
|
vl = _broadcasting_select_hlo(
|
|
ctx,
|
|
mlir.broadcast_in_dim(ctx, ok, select_vl_aval,
|
|
broadcast_dimensions=range(len(batch_dims))),
|
|
select_vl_aval,
|
|
vl, aval, _nan_like_hlo(ctx, aval), aval)
|
|
output.append(vl)
|
|
|
|
if compute_right_eigenvectors:
|
|
aval = ctx.avals_out[len(output)]
|
|
select_vr_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
|
|
vr = _broadcasting_select_hlo(
|
|
ctx,
|
|
mlir.broadcast_in_dim(ctx, ok, select_vr_aval,
|
|
broadcast_dimensions=range(len(batch_dims))),
|
|
select_vr_aval,
|
|
vr, aval, _nan_like_hlo(ctx, aval), aval)
|
|
output.append(vr)
|
|
|
|
return output
|
|
|
|
|
|
def eig_batching_rule(batched_args, batch_dims, *, compute_left_eigenvectors,
|
|
compute_right_eigenvectors):
|
|
x, = batched_args
|
|
bd, = batch_dims
|
|
x = batching.moveaxis(x, bd, 0)
|
|
|
|
return (eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors,
|
|
compute_right_eigenvectors=compute_right_eigenvectors),
|
|
(0,) * (1 + compute_left_eigenvectors + compute_right_eigenvectors))
|
|
|
|
def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors,
|
|
compute_right_eigenvectors):
|
|
if compute_left_eigenvectors or compute_right_eigenvectors:
|
|
raise NotImplementedError(
|
|
'The derivatives of eigenvectors are not implemented, only '
|
|
'eigenvalues. See '
|
|
'https://github.com/google/jax/issues/2748 for discussion.')
|
|
# Formula for derivative of eigenvalues w.r.t. a is eqn 4.60 in
|
|
# https://arxiv.org/abs/1701.00392
|
|
a, = primals
|
|
da, = tangents
|
|
l, v = eig(a, compute_left_eigenvectors=False)
|
|
return [l], [reductions.sum(_solve(v, da.astype(v.dtype)) * _T(v), -1)]
|
|
|
|
eig_p = Primitive('eig')
|
|
eig_p.multiple_results = True
|
|
eig_p.def_impl(eig_impl)
|
|
eig_p.def_abstract_eval(eig_abstract_eval)
|
|
mlir.register_lowering(eig_p, eig_lower)
|
|
mlir.register_lowering(eig_p, _eig_cpu_lowering, platform='cpu')
|
|
batching.primitive_batchers[eig_p] = eig_batching_rule
|
|
ad.primitive_jvps[eig_p] = eig_jvp_rule
|
|
|
|
|
|
# Symmetric/Hermitian eigendecomposition
|
|
|
|
|
|
def eigh_jacobi(x: ArrayLike, *, lower: bool = True,
|
|
sort_eigenvalues: bool = True) -> Tuple[Array, Array]:
|
|
"""Helper Jacobi eigendecomposition implemented by XLA.
|
|
|
|
Used as a subroutine of QDWH-eig on TPU."""
|
|
w, v = eigh_jacobi_p.bind(x, lower=lower, sort_eigenvalues=sort_eigenvalues)
|
|
return w, v
|
|
|
|
def _eigh_jacobi_impl(operand, *, lower, sort_eigenvalues):
|
|
w, v = dispatch.apply_primitive(eigh_jacobi_p, operand, lower=lower,
|
|
sort_eigenvalues=sort_eigenvalues)
|
|
return w, v
|
|
|
|
def _eigh_jacobi_abstract_eval(operand, *, lower, sort_eigenvalues):
|
|
if isinstance(operand, ShapedArray):
|
|
if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
|
|
raise ValueError(
|
|
"Argument to symmetric eigendecomposition must have shape [..., n, n],"
|
|
"got shape {}".format(operand.shape))
|
|
|
|
batch_dims = operand.shape[:-2]
|
|
n = operand.shape[-1]
|
|
w = operand.update(shape=batch_dims + (n,),
|
|
dtype=lax_internal._complex_basetype(operand.dtype))
|
|
v = operand.update(shape=batch_dims + (n, n))
|
|
else:
|
|
w, v = operand, operand
|
|
return w, v
|
|
|
|
|
|
def _eigh_jacobi_lowering_rule(ctx, operand, lower, sort_eigenvalues):
|
|
operand_aval, = ctx.avals_in
|
|
if operand_aval.shape[-1] == 0:
|
|
reshape_aval = operand_aval.update(shape=operand_aval.shape[:-1])
|
|
return [
|
|
hlo.RealOp(mlir.reshape(ctx, operand, reshape_aval)).result,
|
|
operand,
|
|
]
|
|
|
|
eigvals_type = mlir.aval_to_ir_type(ctx.avals_out[0])
|
|
eigvecs_type = mlir.aval_to_ir_type(ctx.avals_out[1])
|
|
result_types = [eigvecs_type, eigvals_type]
|
|
|
|
backend_config = f"{int(lower)},{int(sort_eigenvalues)},100,1e-6"
|
|
|
|
if any(not is_constant_shape(aval_out.shape)
|
|
for aval_out in ctx.avals_out):
|
|
if jaxlib_version < (0, 4, 9):
|
|
raise ValueError("shape polymorphism with native lowering for eigh on "
|
|
"TPU requires jaxlib version 0.4.9.")
|
|
result_shapes = [
|
|
mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, aval_out.shape))
|
|
# The custom call returns the results swapped
|
|
for aval_out in list(reversed(ctx.avals_out))
|
|
]
|
|
else:
|
|
result_shapes = None
|
|
op = mlir.custom_call(
|
|
"Eigh",
|
|
result_types,
|
|
[operand],
|
|
backend_config=backend_config,
|
|
api_version=1,
|
|
result_shapes=result_shapes,
|
|
)
|
|
return op.results[1], op.results[0]
|
|
|
|
eigh_jacobi_p = Primitive('eigh_jacobi')
|
|
eigh_jacobi_p.multiple_results = True
|
|
eigh_jacobi_p.def_impl(_eigh_jacobi_impl)
|
|
eigh_jacobi_p.def_abstract_eval(_eigh_jacobi_abstract_eval)
|
|
mlir.register_lowering(eigh_jacobi_p, _eigh_jacobi_lowering_rule)
|
|
|
|
|
|
def _eigh_impl(operand, *, lower, sort_eigenvalues):
|
|
v, w = dispatch.apply_primitive(eigh_p, operand, lower=lower,
|
|
sort_eigenvalues=sort_eigenvalues)
|
|
return v, w
|
|
|
|
def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues):
|
|
if isinstance(operand, ShapedArray):
|
|
if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
|
|
raise ValueError(
|
|
"Argument to symmetric eigendecomposition must have shape [..., n, n],"
|
|
"got shape {}".format(operand.shape))
|
|
|
|
batch_dims = operand.shape[:-2]
|
|
n = operand.shape[-1]
|
|
v = operand.update(shape=batch_dims + (n, n))
|
|
w = operand.update(shape=batch_dims + (n,),
|
|
dtype=lax_internal._complex_basetype(operand.dtype))
|
|
else:
|
|
v, w = operand, operand
|
|
return v, w
|
|
|
|
def _eigh_cpu_gpu_lowering(syevd_impl, ctx, operand, *, lower,
|
|
sort_eigenvalues):
|
|
del sort_eigenvalues # The CPU/GPU implementations always sort.
|
|
operand_aval, = ctx.avals_in
|
|
v_aval, w_aval = ctx.avals_out
|
|
|
|
batch_dims = operand_aval.shape[:-2]
|
|
|
|
if jaxlib_version < (0, 4, 9):
|
|
if not is_constant_shape(operand_aval.shape):
|
|
raise NotImplementedError("Shape polymorphism for native lowering for "
|
|
"eigh requires "
|
|
"jaxlib version 0.4.9; b/261671778")
|
|
v, w, info = syevd_impl(operand_aval.dtype, operand, lower=lower)
|
|
else:
|
|
# The eigh implementation on CPU and GPU uses lapack helper routines to
|
|
# find the size of the workspace based on the non-batch dimensions.
|
|
# Therefore, we cannot yet support dynamic non-batch dimensions.
|
|
if not is_constant_shape(operand_aval.shape[-2:]):
|
|
raise NotImplementedError(
|
|
"Shape polymorphism for for native lowering for eigh is implemented "
|
|
f"only for the batch dimensions: {operand_aval.shape}")
|
|
|
|
batch_size_num = math.prod(batch_dims) if batch_dims else 1
|
|
batch_size = mlir.eval_dynamic_shape(ctx, (batch_size_num,))[0]
|
|
if isinstance(batch_size, int):
|
|
batch_size = mlir.ir_constant(np.int32(batch_size))
|
|
v_shape: ir.Value = mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, v_aval.shape))
|
|
w_shape: ir.Value = mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, w_aval.shape))
|
|
info_shape: ir.Value = mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, batch_dims))
|
|
v, w, info = syevd_impl(operand_aval.dtype, operand, batch_size,
|
|
v_shape, w_shape, info_shape,
|
|
lower=lower)
|
|
|
|
zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32)))
|
|
ok = mlir.compare_hlo(info, zeros, "EQ", "SIGNED")
|
|
select_v_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
|
|
v = _broadcasting_select_hlo(
|
|
ctx,
|
|
mlir.broadcast_in_dim(ctx, ok, select_v_aval,
|
|
broadcast_dimensions=range(len(batch_dims))),
|
|
select_v_aval,
|
|
v, v_aval, _nan_like_hlo(ctx, v_aval), v_aval)
|
|
select_w_aval = ShapedArray(batch_dims + (1,), np.dtype(np.bool_))
|
|
w = _broadcasting_select_hlo(
|
|
ctx,
|
|
mlir.broadcast_in_dim(ctx, ok, select_w_aval,
|
|
broadcast_dimensions=range(len(batch_dims))),
|
|
select_w_aval,
|
|
w, w_aval, _nan_like_hlo(ctx, w_aval), w_aval)
|
|
return [v, w]
|
|
|
|
def _eigh_tpu_impl(x, *, lower, sort_eigenvalues):
|
|
*_, m, n = x.shape
|
|
assert m == n, (m, n)
|
|
|
|
termination_size = 256
|
|
if not is_constant_dim(m):
|
|
# TODO: maybe we can relax the check below for shape polymorphism?
|
|
raise NotImplementedError(
|
|
"Shape polymorphism for for native lowering for eigh is implemented "
|
|
f"only for the batch dimensions: {x.shape}")
|
|
if m <= termination_size:
|
|
eig_vals, eig_vecs = eigh_jacobi(x, lower=lower,
|
|
sort_eigenvalues=sort_eigenvalues)
|
|
return eig_vecs, eig_vals
|
|
|
|
def eigh_qdwh(x):
|
|
if len(x.shape) > 2:
|
|
return control_flow.map(eigh_qdwh, x)
|
|
|
|
# We should only look at elements from the lower/upper triangle. Reflects
|
|
# that triangle into the other triangle to form a Hermitian matrix.
|
|
if lower:
|
|
mask = jnp.tri(n, k=0, dtype=bool)
|
|
else:
|
|
mask = ufuncs.logical_not(jnp.tri(n, k=-1, dtype=bool))
|
|
if dtypes.issubdtype(x.dtype, jnp.complexfloating):
|
|
re = lax.select(mask, lax.real(x), _T(lax.real(x)))
|
|
if lower:
|
|
im_mask = jnp.tri(n, k=-1, dtype=bool)
|
|
else:
|
|
im_mask = ufuncs.logical_not(jnp.tri(n, k=0, dtype=bool))
|
|
im = lax.select(im_mask, lax.imag(x), jnp.zeros_like(lax.imag(x)))
|
|
im = lax.select(mask, im, -_T(im))
|
|
x = lax.complex(re, im)
|
|
else:
|
|
x = lax.select(mask, x, _T(x))
|
|
|
|
return lax_eigh.eigh(x, sort_eigenvalues=sort_eigenvalues,
|
|
termination_size=termination_size)
|
|
|
|
eig_vals, eig_vecs = eigh_qdwh(x)
|
|
return eig_vecs, eig_vals
|
|
|
|
def _eigh_jvp_rule(primals, tangents, *, lower, sort_eigenvalues):
|
|
# Derivative for eigh in the simplest case of distinct eigenvalues.
|
|
# This is classic nondegenerate perurbation theory, but also see
|
|
# https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
|
|
# The general solution treating the case of degenerate eigenvalues is
|
|
# considerably more complicated. Ambitious readers may refer to the general
|
|
# methods below or refer to degenerate perturbation theory in physics.
|
|
# https://www.win.tue.nl/analysis/reports/rana06-33.pdf and
|
|
# https://people.orie.cornell.edu/aslewis/publications/99-clarke.pdf
|
|
a, = primals
|
|
a_dot, = tangents
|
|
|
|
v, w_real = eigh_p.bind(symmetrize(a), lower=lower,
|
|
sort_eigenvalues=sort_eigenvalues)
|
|
|
|
# for complex numbers we need eigenvalues to be full dtype of v, a:
|
|
w = w_real.astype(a.dtype)
|
|
eye_n = jnp.eye(a.shape[-1], dtype=a.dtype)
|
|
# carefully build reciprocal delta-eigenvalue matrix, avoiding NaNs.
|
|
Fmat = ufuncs.reciprocal(eye_n + w[..., jnp.newaxis, :] - w[..., jnp.newaxis]) - eye_n
|
|
# eigh impl doesn't support batch dims, but future-proof the grad.
|
|
dot = partial(lax.dot if a.ndim == 2 else lax.batch_matmul,
|
|
precision=lax.Precision.HIGHEST)
|
|
vdag_adot_v = dot(dot(_H(v), a_dot), v)
|
|
dv = dot(v, ufuncs.multiply(Fmat, vdag_adot_v))
|
|
dw = ufuncs.real(jnp.diagonal(vdag_adot_v, axis1=-2, axis2=-1))
|
|
return (v, w_real), (dv, dw)
|
|
|
|
def _eigh_batching_rule(batched_args, batch_dims, *, lower, sort_eigenvalues):
|
|
x, = batched_args
|
|
bd, = batch_dims
|
|
x = batching.moveaxis(x, bd, 0)
|
|
return eigh_p.bind(x, lower=lower, sort_eigenvalues=sort_eigenvalues), (0, 0)
|
|
|
|
eigh_p = Primitive('eigh')
|
|
eigh_p.multiple_results = True
|
|
eigh_p.def_impl(_eigh_impl)
|
|
eigh_p.def_abstract_eval(_eigh_abstract_eval)
|
|
ad.primitive_jvps[eigh_p] = _eigh_jvp_rule
|
|
batching.primitive_batchers[eigh_p] = _eigh_batching_rule
|
|
|
|
mlir.register_lowering(
|
|
eigh_p, partial(_eigh_cpu_gpu_lowering, lapack.syevd_hlo),
|
|
platform='cpu')
|
|
|
|
if gpu_solver is not None:
|
|
mlir.register_lowering(
|
|
eigh_p, partial(_eigh_cpu_gpu_lowering, gpu_solver.cuda_syevd),
|
|
platform='cuda')
|
|
mlir.register_lowering(
|
|
eigh_p, partial(_eigh_cpu_gpu_lowering, gpu_solver.rocm_syevd),
|
|
platform='rocm')
|
|
|
|
mlir.register_lowering(
|
|
eigh_p, mlir.lower_fun(_eigh_tpu_impl, multiple_results=True),
|
|
platform='tpu')
|
|
|
|
|
|
_triangular_solve_dtype_rule = partial(
|
|
naryop_dtype_rule, _input_dtype, (_float | _complex, _float | _complex),
|
|
'triangular_solve')
|
|
|
|
def _triangular_solve_shape_rule(a, b, *, left_side=False, **unused_kwargs):
|
|
if a.ndim < 2:
|
|
msg = "triangular_solve requires a.ndim to be at least 2, got {}."
|
|
raise TypeError(msg.format(a.ndim))
|
|
if b.ndim < 2:
|
|
msg = "triangular_solve requires b.ndim to be at least 2, got {}."
|
|
raise TypeError(msg.format(b.ndim))
|
|
if a.shape[-1] != a.shape[-2]:
|
|
msg = ("triangular_solve requires the last two dimensions of a to be equal "
|
|
"in size, got a.shape of {}.")
|
|
raise TypeError(msg.format(a.shape))
|
|
if a.shape[:-2] != b.shape[:-2]:
|
|
msg = ("triangular_solve requires both arguments to have the same number "
|
|
"of dimensions and equal batch dimensions, got {} and {}.")
|
|
raise TypeError(msg.format(a.shape, b.shape))
|
|
common_dim = -2 if left_side else -1
|
|
if a.shape[-1] != b.shape[common_dim]:
|
|
msg = "Incompatible shapes for arguments to triangular_solve: {} and {}."
|
|
raise TypeError(msg.format(a.shape, b.shape))
|
|
return b.shape
|
|
|
|
def _triangular_solve_jvp_rule_a(
|
|
g_a, ans, a, b, *, left_side, lower, transpose_a, conjugate_a,
|
|
unit_diagonal):
|
|
m, n = b.shape[-2:]
|
|
k = 1 if unit_diagonal else 0
|
|
g_a = jnp.tril(g_a, k=-k) if lower else jnp.triu(g_a, k=k)
|
|
g_a = lax.neg(g_a)
|
|
g_a = jnp.swapaxes(g_a, -1, -2) if transpose_a else g_a
|
|
g_a = ufuncs.conj(g_a) if conjugate_a else g_a
|
|
dot = partial(lax.dot if g_a.ndim == 2 else lax.batch_matmul,
|
|
precision=lax.Precision.HIGHEST)
|
|
|
|
def a_inverse(rhs):
|
|
return triangular_solve(a, rhs, left_side=left_side, lower=lower,
|
|
transpose_a=transpose_a, conjugate_a=conjugate_a,
|
|
unit_diagonal=unit_diagonal)
|
|
|
|
# triangular_solve is about the same cost as matrix multplication (~n^2 FLOPs
|
|
# for matrix/vector inputs). Order these operations in whichever order is
|
|
# cheaper.
|
|
if left_side:
|
|
assert g_a.shape[-2:] == a.shape[-2:] == (m, m) and ans.shape[-2:] == (m, n)
|
|
if m > n:
|
|
return a_inverse(dot(g_a, ans)) # A^{-1} (∂A X)
|
|
else:
|
|
return dot(a_inverse(g_a), ans) # (A^{-1} ∂A) X
|
|
else:
|
|
assert g_a.shape[-2:] == a.shape[-2:] == (n, n) and ans.shape[-2:] == (m, n)
|
|
if m < n:
|
|
return a_inverse(dot(ans, g_a)) # (X ∂A) A^{-1}
|
|
else:
|
|
return dot(ans, a_inverse(g_a)) # X (∂A A^{-1})
|
|
|
|
def _triangular_solve_transpose_rule(
|
|
cotangent, a, b, *, left_side, lower, transpose_a, conjugate_a,
|
|
unit_diagonal):
|
|
# Triangular solve is nonlinear in its first argument and linear in its second
|
|
# argument, analogous to `div` but swapped.
|
|
assert not ad.is_undefined_primal(a) and ad.is_undefined_primal(b)
|
|
if type(cotangent) is ad_util.Zero:
|
|
cotangent_b = ad_util.Zero(b.aval)
|
|
else:
|
|
cotangent_b = triangular_solve(a, cotangent, left_side=left_side,
|
|
lower=lower, transpose_a=not transpose_a,
|
|
conjugate_a=conjugate_a,
|
|
unit_diagonal=unit_diagonal)
|
|
return [None, cotangent_b]
|
|
|
|
|
|
def _triangular_solve_batching_rule(batched_args, batch_dims, *, left_side,
|
|
lower, transpose_a, conjugate_a,
|
|
unit_diagonal):
|
|
x, y = batched_args
|
|
bx, by = batch_dims
|
|
if bx is batching.not_mapped:
|
|
if left_side:
|
|
y = batching.moveaxis(y, by, -1)
|
|
y_flat = y.reshape(y.shape[:-2] + (y.shape[-2] * y.shape[-1],))
|
|
bdim_out = y.ndim - 1
|
|
else:
|
|
y = batching.moveaxis(y, by, -2)
|
|
y_flat = y.reshape(y.shape[:-3] + (y.shape[-3] * y.shape[-2], y.shape[-1]))
|
|
bdim_out = y.ndim - 2
|
|
out_flat = triangular_solve(
|
|
x, y_flat, left_side=left_side, lower=lower,
|
|
transpose_a=transpose_a, conjugate_a=conjugate_a,
|
|
unit_diagonal=unit_diagonal)
|
|
return out_flat.reshape(y.shape), bdim_out
|
|
else:
|
|
size = next(t.shape[i] for t, i in zip(batched_args, batch_dims)
|
|
if i is not None)
|
|
x = batching.bdim_at_front(x, bx, size)
|
|
y = batching.bdim_at_front(y, by, size)
|
|
return triangular_solve(x, y, left_side=left_side, lower=lower,
|
|
transpose_a=transpose_a, conjugate_a=conjugate_a,
|
|
unit_diagonal=unit_diagonal), 0
|
|
|
|
triangular_solve_p = standard_primitive(
|
|
_triangular_solve_shape_rule, _triangular_solve_dtype_rule,
|
|
'triangular_solve')
|
|
ad.defjvp2(triangular_solve_p,
|
|
_triangular_solve_jvp_rule_a,
|
|
lambda g_b, _, a, b, **kws: triangular_solve(a, g_b, **kws))
|
|
ad.primitive_transposes[triangular_solve_p] = _triangular_solve_transpose_rule
|
|
batching.primitive_batchers[triangular_solve_p] = _triangular_solve_batching_rule
|
|
|
|
|
|
def _triangular_solve_lowering(
|
|
ctx, a, b, *, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
|
|
out_aval, = ctx.avals_out
|
|
if conjugate_a and not transpose_a:
|
|
a = chlo.ConjOp(a)
|
|
conjugate_a = False
|
|
if not transpose_a:
|
|
transpose = "NO_TRANSPOSE"
|
|
else:
|
|
transpose = "ADJOINT" if conjugate_a else "TRANSPOSE"
|
|
return hlo.TriangularSolveOp(
|
|
a, b, ir.BoolAttr.get(left_side),
|
|
ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal),
|
|
hlo.TransposeAttr.get(transpose)).results
|
|
|
|
mlir.register_lowering(triangular_solve_p, _triangular_solve_lowering)
|
|
|
|
def _triangular_solve_cpu_lower(
|
|
ctx, a, b, *, left_side, lower, transpose_a,
|
|
conjugate_a, unit_diagonal):
|
|
a_aval, _ = ctx.avals_in
|
|
|
|
if conjugate_a and not transpose_a:
|
|
a = chlo.ConjOp(a).result
|
|
conjugate_a = False
|
|
if len(a_aval.shape) == 2 and np.dtype(a_aval.dtype) in _cpu_lapack_types:
|
|
alpha = mlir.ir_constant(np.array(1, dtype=a_aval.dtype))
|
|
return [lapack.trsm_hlo(
|
|
a_aval.dtype, alpha,
|
|
a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal)]
|
|
else:
|
|
# Fall back to the HLO implementation for unsupported types or batching.
|
|
# TODO: Consider swapping XLA for LAPACK in batched case
|
|
if transpose_a:
|
|
transpose = "ADJOINT" if conjugate_a else "TRANSPOSE"
|
|
else:
|
|
transpose = "NO_TRANSPOSE"
|
|
return hlo.TriangularSolveOp(a, b, ir.BoolAttr.get(left_side),
|
|
ir.BoolAttr.get(lower),
|
|
ir.BoolAttr.get(unit_diagonal),
|
|
hlo.TransposeAttr.get(transpose)).results
|
|
|
|
mlir.register_lowering(triangular_solve_p, _triangular_solve_cpu_lower,
|
|
platform='cpu')
|
|
|
|
|
|
# Support operation for LU decomposition: Transformation of the pivots returned
|
|
# by LU decomposition into permutations.
|
|
|
|
|
|
# Define this outside lu_pivots_to_permutation to ensure fori_loop cache hits
|
|
def _lu_pivots_body_fn(i, permutation_and_swaps):
|
|
permutation, swaps = permutation_and_swaps
|
|
batch_dims = swaps.shape[:-1]
|
|
j = swaps[..., i]
|
|
iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims))
|
|
x = permutation[..., i]
|
|
y = permutation[iotas + (j,)]
|
|
permutation = permutation.at[..., i].set(y)
|
|
return permutation.at[iotas + (j,)].set(x), swaps
|
|
|
|
|
|
def _generic_lu_pivots_to_permutation(swaps, permutation_size):
|
|
"""Converts the pivots (row swaps) returned by LU to a permutation.
|
|
|
|
We build a permutation rather than applying `swaps` directly to the rows
|
|
of a matrix because lax loops aren't differentiable.
|
|
|
|
Args:
|
|
swaps: an array of shape (..., k) of row swaps to perform
|
|
permutation_size: the size of the output permutation. Should be >= k.
|
|
Returns:
|
|
An int32 array of shape (..., m).
|
|
"""
|
|
assert len(swaps.shape) >= 1
|
|
batch_dims = swaps.shape[:-1]
|
|
k = swaps.shape[-1]
|
|
m = permutation_size
|
|
|
|
permutation = lax.broadcasted_iota(jnp.int32, batch_dims + (m,),
|
|
len(batch_dims))
|
|
if m == 0:
|
|
return permutation
|
|
result, _ = lax.fori_loop(np.array(0, np.int32), np.array(k, np.int32),
|
|
_lu_pivots_body_fn, (permutation, swaps))
|
|
return result
|
|
|
|
|
|
def _lu_pivots_to_permutation_abstract_eval(pivots, *, permutation_size):
|
|
pivots = raise_to_shaped(pivots)
|
|
if isinstance(pivots, ShapedArray):
|
|
if pivots.ndim < 1 or pivots.dtype != np.dtype(np.int32):
|
|
raise ValueError(
|
|
'Argument to lu_pivots_to_permutation must have rank >= 1 and dtype '
|
|
'int32. Got shape={} and dtype={}'.format(pivots.shape, pivots.dtype))
|
|
|
|
if permutation_size < pivots.shape[-1]:
|
|
raise ValueError(
|
|
'Output permutation size {} has to exceed the trailing dimension of '
|
|
'the pivots. Got shape {}'.format(permutation_size, pivots.shape))
|
|
|
|
batch_dims = pivots.shape[:-1]
|
|
permutations = pivots.update(shape=batch_dims + (permutation_size,))
|
|
else:
|
|
permutations = pivots
|
|
|
|
return permutations
|
|
|
|
|
|
def _lu_pivots_to_permutation_batching_rule(batched_args, batch_dims, *,
|
|
permutation_size):
|
|
x, = batched_args
|
|
bd, = batch_dims
|
|
x = batching.moveaxis(x, bd, 0)
|
|
return lu_pivots_to_permutation_p.bind(
|
|
x, permutation_size=permutation_size), 0
|
|
|
|
def _lu_pivots_to_permutation_gpu_lowering(lowering, ctx, pivots, *,
|
|
permutation_size):
|
|
return [lowering(pivots, permutation_size=permutation_size)]
|
|
|
|
|
|
lu_pivots_to_permutation_p = Primitive('lu_pivots_to_permutation')
|
|
lu_pivots_to_permutation_p.multiple_results = False
|
|
lu_pivots_to_permutation_p.def_impl(
|
|
partial(dispatch.apply_primitive, lu_pivots_to_permutation_p))
|
|
lu_pivots_to_permutation_p.def_abstract_eval(
|
|
_lu_pivots_to_permutation_abstract_eval)
|
|
batching.primitive_batchers[lu_pivots_to_permutation_p] = (
|
|
_lu_pivots_to_permutation_batching_rule)
|
|
mlir.register_lowering(
|
|
lu_pivots_to_permutation_p,
|
|
mlir.lower_fun(_generic_lu_pivots_to_permutation, multiple_results=False))
|
|
mlir.register_lowering(
|
|
lu_pivots_to_permutation_p,
|
|
partial(_lu_pivots_to_permutation_gpu_lowering,
|
|
gpu_linalg.cuda_lu_pivots_to_permutation),
|
|
platform='cuda')
|
|
mlir.register_lowering(
|
|
lu_pivots_to_permutation_p,
|
|
partial(_lu_pivots_to_permutation_gpu_lowering,
|
|
gpu_linalg.hip_lu_pivots_to_permutation),
|
|
platform='rocm')
|
|
|
|
|
|
# LU decomposition
|
|
|
|
# Computes a pivoted LU decomposition such that
|
|
# PA = LU
|
|
# In the style of LAPACK, LU are stored in the same matrix.
|
|
|
|
def _lu_unblocked(a):
|
|
"""Unblocked LU decomposition, as a rolled loop."""
|
|
m, n = a.shape
|
|
def body(k, state):
|
|
pivot, perm, a = state
|
|
m_idx = jnp.arange(m)
|
|
n_idx = jnp.arange(n)
|
|
|
|
if jnp.issubdtype(a.dtype, jnp.complexfloating):
|
|
t = a[:, k]
|
|
magnitude = ufuncs.abs(ufuncs.real(t)) + ufuncs.abs(ufuncs.imag(t))
|
|
else:
|
|
magnitude = ufuncs.abs(a[:, k])
|
|
i = jnp.argmax(jnp.where(m_idx >= k, magnitude, -jnp.inf))
|
|
pivot = pivot.at[k].set(i)
|
|
a = a.at[[k, i],].set(a[[i, k],])
|
|
perm = perm.at[[i, k],].set(perm[[k, i],])
|
|
|
|
# a[k+1:, k] /= a[k, k], adapted for loop-invariant shapes
|
|
x = a[k, k]
|
|
a = a.at[:, k].set(jnp.where(m_idx > k, a[:, k] / x, a[:, k]))
|
|
|
|
# a[k+1:, k+1:] -= jnp.outer(a[k+1:, k], a[k, k+1:])
|
|
a = a - jnp.where((m_idx[:, None] > k) & (n_idx[None, :] > k),
|
|
jnp.outer(a[:, k], a[k, :]), jnp.array(0, dtype=a.dtype))
|
|
return pivot, perm, a
|
|
|
|
pivot = jnp.zeros((min(m, n),), dtype=jnp.int32)
|
|
perm = jnp.arange(m, dtype=jnp.int32)
|
|
if m == 0 and n == 0:
|
|
# If the array is empty, the loop body never executes but tracing it to a
|
|
# jaxpr fails because the indexing cannot succeed.
|
|
return (pivot, perm, a)
|
|
return lax.fori_loop(0, min(m, n), body, (pivot, perm, a))
|
|
|
|
|
|
def _lu_blocked(a, block_size=128):
|
|
"""Blocked LU decomposition, as an unrolled loop."""
|
|
m, n = a.shape
|
|
r = min(m, n)
|
|
pivot = jnp.zeros((r,), dtype=jnp.int32)
|
|
perm = jnp.arange(m, dtype=jnp.int32)
|
|
for k in range(0, r, block_size):
|
|
b = min(r - k, block_size)
|
|
block_pivot, block_perm, lu_block = _lu_unblocked(a[k:, k:k+b])
|
|
|
|
pivot = pivot.at[k:k+b].set(block_pivot + k)
|
|
perm = perm.at[k:].set(perm[block_perm + k])
|
|
a = a.at[k:, :].set(a[block_perm + k, :])
|
|
a = a.at[k:, k:k+b].set(lu_block)
|
|
|
|
if k + b < n:
|
|
a = a.at[k:k+b, k+b:].set(
|
|
triangular_solve(a[k:k+b, k:k+b], a[k:k+b, k+b:], left_side=True,
|
|
lower=True, unit_diagonal=True))
|
|
a = a.at[k+b:, k+b:].add(-lax.dot(a[k+b:, k:k+b], a[k:k+b, k+b:],
|
|
precision=lax.Precision.HIGHEST))
|
|
return a, pivot, perm
|
|
|
|
def _lu_python(x):
|
|
"""Default LU decomposition in Python, where no better version exists."""
|
|
batch_dims = x.shape[:-2]
|
|
fn = _lu_blocked
|
|
for _ in range(len(batch_dims)):
|
|
fn = api.vmap(fn)
|
|
|
|
return fn(x)
|
|
|
|
def _lu_impl(operand):
|
|
lu, pivot, perm = dispatch.apply_primitive(lu_p, operand)
|
|
return lu, pivot, perm
|
|
|
|
def _lu_abstract_eval(operand):
|
|
operand = raise_to_shaped(operand)
|
|
if isinstance(operand, ShapedArray):
|
|
if operand.ndim < 2:
|
|
raise ValueError("Argument to LU decomposition must have ndims >= 2")
|
|
|
|
batch_dims = operand.shape[:-2]
|
|
m = operand.shape[-2]
|
|
n = operand.shape[-1]
|
|
pivot = operand.update(shape=batch_dims + (min(m, n),), dtype=jnp.int32)
|
|
perm = operand.update(shape=batch_dims + (m,), dtype=jnp.int32)
|
|
else:
|
|
pivot = operand
|
|
perm = operand
|
|
return operand, pivot, perm
|
|
|
|
def _lu_jvp_rule(primals, tangents):
|
|
a, = primals
|
|
a_dot, = tangents
|
|
lu, pivots, permutation = lu_p.bind(a)
|
|
|
|
a_shape = jnp.shape(a)
|
|
m, n = a_shape[-2:]
|
|
dtype = lax.dtype(a)
|
|
k = min(m, n)
|
|
|
|
batch_dims = a_shape[:-2]
|
|
iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1,)))
|
|
x = a_dot[iotas[:-1] + (permutation, slice(None))]
|
|
|
|
# Differentiation of Matrix Functionals Using Triangular Factorization
|
|
# F. R. De Hoog, R. S. Anderssen, and M. A. Lukas
|
|
#
|
|
# LU = A
|
|
# ==> L'U + LU' = A'
|
|
# ==> inv(L) . L' + U' . inv(U) = inv(L) A' inv(U)
|
|
# ==> L' = L . tril(inv(L) . A' . inv(U), -1)
|
|
# U' = triu(inv(L) . A' . inv(U)) . U
|
|
|
|
ndims = len(a_shape)
|
|
l_padding = [(0, 0, 0)] * ndims
|
|
l_padding[-1] = (0, m - k, 0)
|
|
zero = lax_internal._const(lu, 0)
|
|
l = lax.pad(jnp.tril(lu[..., :, :k], -1), zero, l_padding)
|
|
l = l + lax.expand_dims(jnp.eye(m, m, dtype=dtype), range(l.ndim - 2))
|
|
|
|
u_eye = lax.pad(jnp.eye(n - k, n - k, dtype=dtype), zero,
|
|
((k, 0, 0), (k, 0, 0)))
|
|
u_padding = [(0, 0, 0)] * ndims
|
|
u_padding[-2] = (0, n - k, 0)
|
|
u = (lax.pad(jnp.triu(lu[..., :k, :]), zero, u_padding) +
|
|
lax.expand_dims(u_eye, range(lu.ndim - 2)))
|
|
|
|
la = triangular_solve(l, x, left_side=True, transpose_a=False, lower=True,
|
|
unit_diagonal=True)
|
|
lau = triangular_solve(u, la, left_side=False, transpose_a=False,
|
|
lower=False)
|
|
|
|
l_dot = jnp.matmul(l, jnp.tril(lau, -1), precision=lax.Precision.HIGHEST)
|
|
u_dot = jnp.matmul(jnp.triu(lau), u, precision=lax.Precision.HIGHEST)
|
|
lu_dot = l_dot + u_dot
|
|
return (lu, pivots, permutation), (lu_dot, ad_util.Zero.from_value(pivots),
|
|
ad_util.Zero.from_value(permutation))
|
|
|
|
|
|
def _lu_batching_rule(batched_args, batch_dims):
|
|
x, = batched_args
|
|
bd, = batch_dims
|
|
x = batching.moveaxis(x, bd, 0)
|
|
return lu_p.bind(x), (0, 0, 0)
|
|
|
|
def _lu_cpu_gpu_lowering(getrf_impl, ctx, operand):
|
|
if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
|
|
raise NotImplementedError("Shape polymorphism for custom call is not implemented (lu); b/261671778")
|
|
operand_aval, = ctx.avals_in
|
|
out_aval, pivot_aval, perm_aval = ctx.avals_out
|
|
batch_dims = operand_aval.shape[:-2]
|
|
m = operand_aval.shape[-2]
|
|
lu, pivot, info = getrf_impl(operand_aval.dtype, operand)
|
|
# Subtract 1 from the pivot to get 0-based indices.
|
|
pivot = hlo.SubtractOp(pivot, mlir.full_like_aval(ctx, 1, pivot_aval)).result
|
|
ok = mlir.compare_hlo(
|
|
info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))),
|
|
"GE", "SIGNED")
|
|
select_lu_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
|
|
lu = _broadcasting_select_hlo(
|
|
ctx,
|
|
mlir.broadcast_in_dim(ctx, ok, select_lu_aval,
|
|
broadcast_dimensions=range(len(batch_dims))),
|
|
select_lu_aval,
|
|
lu, out_aval, _nan_like_hlo(ctx, out_aval), out_aval)
|
|
sub_ctx = ctx.replace(primitive=None, avals_in=[pivot_aval], avals_out=[perm_aval])
|
|
perm_fn = mlir.lower_fun(lambda x: lu_pivots_to_permutation(x, m),
|
|
multiple_results=False)
|
|
perm, = perm_fn(sub_ctx, pivot)
|
|
return [lu, pivot, perm]
|
|
|
|
|
|
def _lu_tpu_lowering_rule(ctx, operand):
|
|
if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
|
|
raise NotImplementedError(f"Shape polymorphism for custom call is not implemented (lu); b/261671778; {ctx.avals_in + ctx.avals_out}")
|
|
result_types = [
|
|
mlir.aval_to_ir_type(ctx.avals_out[0]),
|
|
mlir.aval_to_ir_type(ctx.avals_out[1]),
|
|
mlir.aval_to_ir_type(ctx.avals_out[2])
|
|
]
|
|
op = hlo.CustomCallOp(
|
|
result_types,
|
|
[operand],
|
|
call_target_name=ir.StringAttr.get("LuDecomposition"),
|
|
has_side_effect=ir.BoolAttr.get(False),
|
|
)
|
|
return op.results
|
|
|
|
|
|
lu_p = Primitive('lu')
|
|
lu_p.multiple_results = True
|
|
lu_p.def_impl(_lu_impl)
|
|
lu_p.def_abstract_eval(_lu_abstract_eval)
|
|
mlir.register_lowering(lu_p, mlir.lower_fun(_lu_python, multiple_results=True))
|
|
ad.primitive_jvps[lu_p] = _lu_jvp_rule
|
|
batching.primitive_batchers[lu_p] = _lu_batching_rule
|
|
|
|
mlir.register_lowering(lu_p,
|
|
partial(_lu_cpu_gpu_lowering, lapack.getrf_hlo),
|
|
platform='cpu')
|
|
|
|
mlir.register_lowering(
|
|
lu_p, partial(_lu_cpu_gpu_lowering, gpu_solver.cuda_getrf),
|
|
platform='cuda')
|
|
mlir.register_lowering(
|
|
lu_p, partial(_lu_cpu_gpu_lowering, gpu_solver.rocm_getrf),
|
|
platform='rocm')
|
|
|
|
mlir.register_lowering(lu_p, _lu_tpu_lowering_rule, platform='tpu')
|
|
|
|
|
|
@partial(vectorize, excluded={3}, signature='(n,n),(n),(n,k)->(n,k)')
|
|
def _lu_solve_core(lu: Array, permutation: Array, b: Array, trans: int) -> Array:
|
|
m = lu.shape[0]
|
|
x = jnp.reshape(b, (m, math.prod(b.shape[1:])))
|
|
if trans == 0:
|
|
x = x[permutation, :]
|
|
x = triangular_solve(lu, x, left_side=True, lower=True, unit_diagonal=True)
|
|
x = triangular_solve(lu, x, left_side=True, lower=False)
|
|
elif trans == 1 or trans == 2:
|
|
conj = trans == 2
|
|
x = triangular_solve(lu, x, left_side=True, lower=False, transpose_a=True,
|
|
conjugate_a=conj)
|
|
x = triangular_solve(lu, x, left_side=True, lower=True, unit_diagonal=True,
|
|
transpose_a=True, conjugate_a=conj)
|
|
x = x[jnp.argsort(permutation), :]
|
|
else:
|
|
raise ValueError(f"'trans' value must be 0, 1, or 2, got {trans}")
|
|
return lax.reshape(x, b.shape)
|
|
|
|
|
|
@partial(api.jit, static_argnums=(3,))
|
|
def _lu_solve(lu: Array, permutation: Array, b: Array, trans: int) -> Array:
|
|
if len(lu.shape) < 2 or lu.shape[-1] != lu.shape[-2]:
|
|
raise ValueError("last two dimensions of LU decomposition must be equal, "
|
|
"got shape {}".format(lu.shape))
|
|
if len(b.shape) < 1:
|
|
raise ValueError("b matrix must have rank >= 1, got shape {}"
|
|
.format(b.shape))
|
|
# Broadcasting follows NumPy's convention for linalg.solve: the RHS is
|
|
# treated as a (batched) vector if the number of dimensions differ by 1.
|
|
# Otherwise, broadcasting rules apply.
|
|
rhs_vector = lu.ndim == b.ndim + 1
|
|
if rhs_vector:
|
|
if b.shape[-1] != lu.shape[-1]:
|
|
raise ValueError("When LU decomposition matrix and b have the same "
|
|
"number of dimensions, last axis of LU decomposition "
|
|
"matrix (shape {}) and b array (shape {}) must match"
|
|
.format(lu.shape, b.shape))
|
|
b = b[..., jnp.newaxis]
|
|
else:
|
|
if b.shape[-2] != lu.shape[-1]:
|
|
raise ValueError("When LU decomposition matrix and b different "
|
|
"numbers of dimensions, last axis of LU decomposition "
|
|
"matrix (shape {}) and second to last axis of b array "
|
|
"(shape {}) must match"
|
|
.format(lu.shape, b.shape))
|
|
x = _lu_solve_core(lu, permutation, b, trans)
|
|
return x[..., 0] if rhs_vector else x
|
|
|
|
|
|
def lu_solve(lu: ArrayLike, permutation: ArrayLike, b: ArrayLike,
|
|
trans: int = 0) -> Array:
|
|
"""LU solve with broadcasting."""
|
|
return _lu_solve(lu, permutation, b, trans)
|
|
|
|
|
|
# QR decomposition
|
|
|
|
# QR decomposition is implemented as a composition of two lower-level primitives
|
|
# geqrf and orgqr. The names, while cryptic Fortran alphabet soup, are LAPACK's
|
|
# names for the primitives, and we stick with them for consistency.
|
|
|
|
def geqrf(a: ArrayLike) -> Tuple[Array, Array]:
|
|
"""Computes the QR decomposition of a matrix.
|
|
|
|
Args:
|
|
a: an ``[..., m, n]`` batch of matrices, with floating-point or complex type.
|
|
Returns:
|
|
An ``(a, taus)`` pair where ``r`` is in the upper triangle of ``a``,
|
|
``q`` is represented in the lower triangle of ``a`` and in ``taus`` as
|
|
elementary Householder reflectors.
|
|
"""
|
|
a_out, taus = geqrf_p.bind(a)
|
|
return a_out, taus
|
|
|
|
def _geqrf_abstract_eval(operand):
|
|
if not isinstance(operand, ShapedArray):
|
|
raise NotImplementedError("Unsupported aval in geqrf_abstract_eval: "
|
|
f"{operand.aval}")
|
|
if operand.ndim < 2:
|
|
raise ValueError("Argument to QR decomposition must have ndims >= 2")
|
|
*batch_dims, m, n = operand.shape
|
|
taus = operand.update(shape=(*batch_dims, min(m, n)))
|
|
return operand, taus
|
|
|
|
def _geqrf_batching_rule(batched_args, batch_dims):
|
|
x, = batched_args
|
|
bd, = batch_dims
|
|
return geqrf(batching.moveaxis(x, bd, 0)), (0, 0)
|
|
|
|
def _geqrf_lowering_rule(ctx, operand):
|
|
ts_type = mlir.aval_to_ir_type(ctx.avals_out[0])
|
|
r_type = mlir.aval_to_ir_type(ctx.avals_out[1])
|
|
result_types = [ts_type, r_type]
|
|
if any(not is_constant_shape(aval_out.shape)
|
|
for aval_out in ctx.avals_out):
|
|
result_shapes = [
|
|
mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, aval_out.shape))
|
|
for aval_out in ctx.avals_out
|
|
]
|
|
else:
|
|
result_shapes = None
|
|
op = mlir.custom_call(
|
|
"Qr",
|
|
result_types,
|
|
[operand],
|
|
api_version=1,
|
|
result_shapes=result_shapes
|
|
)
|
|
return op.results
|
|
|
|
def _geqrf_cpu_gpu_lowering(geqrf_impl, batched_geqrf_impl, ctx, a):
|
|
if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
|
|
raise NotImplementedError("Shape polymorphism for custom call is not implemented (geqrf); b/261671778")
|
|
a_aval, taus_aval = ctx.avals_out
|
|
*batch_dims, m, n = a_aval.shape
|
|
batch = math.prod(batch_dims)
|
|
|
|
if batch == 0 or m == 0 or n == 0:
|
|
return mlir.full_like_aval(ctx, 0, a_aval), mlir.full_like_aval(ctx, 0, taus_aval)
|
|
|
|
if (batched_geqrf_impl is not None and batch > 1 and m // batch <= 128 and
|
|
n // batch <= 128):
|
|
a_out, taus = batched_geqrf_impl(a_aval.dtype, a)
|
|
else:
|
|
a_out, taus, info_geqrf = geqrf_impl(a_aval.dtype, a)
|
|
zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32)))
|
|
ok = mlir.compare_hlo(info_geqrf, zeros, "EQ", "SIGNED")
|
|
select_ok_a_aval = ShapedArray(batch_dims + [1, 1], np.dtype(np.bool_))
|
|
ok_a = mlir.broadcast_in_dim(ctx, ok, select_ok_a_aval,
|
|
broadcast_dimensions=range(len(batch_dims)))
|
|
a_out = _broadcasting_select_hlo(ctx, ok_a, select_ok_a_aval, a_out, a_aval, _nan_like_hlo(ctx, a_aval), a_aval)
|
|
select_ok_taus_aval = ShapedArray(batch_dims + [1], np.dtype(np.bool_))
|
|
ok_taus = mlir.broadcast_in_dim(ctx, ok, select_ok_taus_aval,
|
|
broadcast_dimensions=range(len(batch_dims)))
|
|
taus = _broadcasting_select_hlo(ctx, ok_taus, select_ok_taus_aval, taus, taus_aval, _nan_like_hlo(ctx, taus_aval), taus_aval)
|
|
return a_out, taus
|
|
|
|
geqrf_p = Primitive('geqrf')
|
|
geqrf_p.multiple_results = True
|
|
geqrf_p.def_impl(partial(dispatch.apply_primitive, geqrf_p))
|
|
geqrf_p.def_abstract_eval(_geqrf_abstract_eval)
|
|
batching.primitive_batchers[geqrf_p] = _geqrf_batching_rule
|
|
mlir.register_lowering(geqrf_p, _geqrf_lowering_rule)
|
|
|
|
mlir.register_lowering(
|
|
geqrf_p, partial(_geqrf_cpu_gpu_lowering, lapack.geqrf_hlo, None),
|
|
platform='cpu')
|
|
mlir.register_lowering(
|
|
geqrf_p,
|
|
partial(_geqrf_cpu_gpu_lowering, gpu_solver.cuda_geqrf,
|
|
gpu_solver.cuda_geqrf_batched),
|
|
platform='cuda')
|
|
mlir.register_lowering(
|
|
geqrf_p,
|
|
partial(_geqrf_cpu_gpu_lowering, gpu_solver.rocm_geqrf,
|
|
gpu_solver.rocm_geqrf_batched),
|
|
platform='rocm')
|
|
|
|
|
|
# householder_product: product of elementary Householder reflectors
|
|
|
|
def householder_product(a: ArrayLike, taus: ArrayLike) -> Array:
|
|
"""Product of elementary Householder reflectors.
|
|
|
|
Args:
|
|
a: A matrix with shape ``[..., m, n]``, whose lower triangle contains
|
|
elementary Householder reflectors.
|
|
taus: A vector with shape ``[..., k]``, where ``k < min(m, n)``, containing
|
|
the scalar factors of the elementary Householder reflectors.
|
|
|
|
Returns:
|
|
A batch of orthogonal (unitary) matrices with the same shape as ``a``,
|
|
containing the products of the elementary Householder reflectors.
|
|
"""
|
|
return householder_product_p.bind(a, taus)
|
|
|
|
|
|
def _householder_product_abstract_eval(a, taus):
|
|
if not isinstance(a, ShapedArray) or not isinstance(taus, ShapedArray):
|
|
raise NotImplementedError("Unsupported aval in householder_product_abstract_eval: "
|
|
f"{a.aval} {taus.aval}")
|
|
if a.ndim < 2:
|
|
raise ValueError("Argument to Householder product must have ndims >= 2")
|
|
*batch_dims, m, n = a.shape
|
|
*taus_batch_dims, k = taus.shape
|
|
if a.dtype != taus.dtype or batch_dims != taus_batch_dims or k > min(m, n):
|
|
raise ValueError(f"Type mismatch for Householder product: {a=} {taus=}")
|
|
if m < n:
|
|
raise ValueError("Householder product inputs must have at least as many "
|
|
f"rows as columns, got shape {a.shape}")
|
|
return a
|
|
|
|
def _householder_product_batching_rule(batched_args, batch_dims):
|
|
a, taus = batched_args
|
|
b_a, b_taus, = batch_dims
|
|
return householder_product(batching.moveaxis(a, b_a, 0),
|
|
batching.moveaxis(taus, b_taus, 0)), (0,)
|
|
|
|
def _householder_product_lowering_rule(ctx, a, taus):
|
|
aval_out, = ctx.avals_out
|
|
if not is_constant_shape(aval_out.shape):
|
|
result_shapes = [
|
|
mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, aval_out.shape))]
|
|
else:
|
|
result_shapes = None
|
|
op = mlir.custom_call(
|
|
"ProductOfElementaryHouseholderReflectors",
|
|
[mlir.aval_to_ir_type(aval_out)],
|
|
[a, taus],
|
|
api_version=1,
|
|
result_shapes=result_shapes)
|
|
return [op.result]
|
|
|
|
def _householder_product_cpu_gpu_lowering(orgqr_impl, ctx, a, taus):
|
|
if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
|
|
raise NotImplementedError("Shape polymorphism for custom call is not implemented (householder product); b/261671778")
|
|
a_aval, _ = ctx.avals_in
|
|
*batch_dims, m, n = a_aval.shape
|
|
|
|
if m == 0 or n == 0:
|
|
return [mlir.full_like_aval(ctx, 0, a_aval)]
|
|
|
|
a, info_orgqr = orgqr_impl(a_aval.dtype, a, taus)
|
|
zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32)))
|
|
ok = mlir.compare_hlo(info_orgqr, zeros, "EQ", "SIGNED")
|
|
select_a_aval = ShapedArray(batch_dims + [1, 1], np.dtype(np.bool_))
|
|
ok = mlir.broadcast_in_dim(ctx, ok, select_a_aval,
|
|
broadcast_dimensions=range(len(batch_dims)))
|
|
a = _broadcasting_select_hlo(ctx, ok, select_a_aval, a, a_aval, _nan_like_hlo(ctx, a_aval), a_aval)
|
|
return [a]
|
|
|
|
|
|
householder_product_p = Primitive('householder_product')
|
|
householder_product_p.def_impl(partial(dispatch.apply_primitive, householder_product_p))
|
|
householder_product_p.def_abstract_eval(_householder_product_abstract_eval)
|
|
batching.primitive_batchers[householder_product_p] = _householder_product_batching_rule
|
|
mlir.register_lowering(householder_product_p, _householder_product_lowering_rule)
|
|
|
|
mlir.register_lowering(
|
|
householder_product_p,
|
|
partial(_householder_product_cpu_gpu_lowering, lapack.orgqr_hlo),
|
|
platform='cpu')
|
|
mlir.register_lowering(
|
|
householder_product_p,
|
|
partial(_householder_product_cpu_gpu_lowering, gpu_solver.cuda_orgqr),
|
|
platform='cuda')
|
|
mlir.register_lowering(
|
|
householder_product_p,
|
|
partial(_householder_product_cpu_gpu_lowering, gpu_solver.rocm_orgqr),
|
|
platform='rocm')
|
|
|
|
|
|
def _qr_impl(operand, *, full_matrices):
|
|
q, r = dispatch.apply_primitive(qr_p, operand, full_matrices=full_matrices)
|
|
return q, r
|
|
|
|
def _qr_abstract_eval(operand, *, full_matrices):
|
|
if isinstance(operand, ShapedArray):
|
|
if operand.ndim < 2:
|
|
raise ValueError("Argument to QR decomposition must have ndims >= 2")
|
|
*batch_dims, m, n = operand.shape
|
|
k = m if full_matrices else min(m, n)
|
|
q = operand.update(shape=(*batch_dims, m, k))
|
|
r = operand.update(shape=(*batch_dims, k, n))
|
|
else:
|
|
q = operand
|
|
r = operand
|
|
return q, r
|
|
|
|
def qr_jvp_rule(primals, tangents, *, full_matrices):
|
|
# See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation.
|
|
x, = primals
|
|
dx, = tangents
|
|
q, r = qr_p.bind(x, full_matrices=False)
|
|
*_, m, n = x.shape
|
|
if m < n or (full_matrices and m != n):
|
|
raise NotImplementedError(
|
|
"Unimplemented case of QR decomposition derivative")
|
|
dx_rinv = triangular_solve(r, dx) # Right side solve by default
|
|
qt_dx_rinv = jnp.matmul(_H(q), dx_rinv)
|
|
qt_dx_rinv_lower = jnp.tril(qt_dx_rinv, -1)
|
|
do = qt_dx_rinv_lower - _H(qt_dx_rinv_lower) # This is skew-symmetric
|
|
# The following correction is necessary for complex inputs
|
|
I = lax.expand_dims(jnp.eye(n, dtype=do.dtype), range(qt_dx_rinv.ndim - 2))
|
|
do = do + I * (qt_dx_rinv - qt_dx_rinv.real.astype(qt_dx_rinv.dtype))
|
|
dq = jnp.matmul(q, do - qt_dx_rinv) + dx_rinv
|
|
dr = jnp.matmul(qt_dx_rinv - do, r)
|
|
return (q, r), (dq, dr)
|
|
|
|
def _qr_batching_rule(batched_args, batch_dims, *, full_matrices):
|
|
x, = batched_args
|
|
bd, = batch_dims
|
|
x = batching.moveaxis(x, bd, 0)
|
|
return qr_p.bind(x, full_matrices=full_matrices), (0, 0)
|
|
|
|
def _qr_lowering(a, *, full_matrices):
|
|
*batch_dims, m, n = a.shape
|
|
if m == 0 or n == 0:
|
|
k = m if full_matrices else min(m, n)
|
|
q = jnp.broadcast_to(jnp.eye(m, k, dtype=a.dtype), (*batch_dims, m, k))
|
|
r = jnp.empty((*batch_dims, k, n), dtype=a.dtype)
|
|
return q, r
|
|
|
|
r, taus = geqrf(a)
|
|
if m < n:
|
|
q = householder_product(r[..., :m, :m], taus)
|
|
elif full_matrices:
|
|
pads = [(0, 0, 0)] * (len(batch_dims) + 1) + [(0, m - n, 0)]
|
|
q = lax.pad(r, lax_internal._zero(r), pads)
|
|
q = householder_product(q, taus)
|
|
else:
|
|
q = householder_product(r, taus)
|
|
r = r[..., :n, :n]
|
|
r = jnp.triu(r)
|
|
return q, r
|
|
|
|
|
|
qr_p = Primitive('qr')
|
|
qr_p.multiple_results = True
|
|
qr_p.def_impl(_qr_impl)
|
|
qr_p.def_abstract_eval(_qr_abstract_eval)
|
|
|
|
ad.primitive_jvps[qr_p] = qr_jvp_rule
|
|
batching.primitive_batchers[qr_p] = _qr_batching_rule
|
|
|
|
mlir.register_lowering(qr_p, mlir.lower_fun(_qr_lowering));
|
|
|
|
|
|
# Singular value decomposition
|
|
|
|
def _svd_impl(operand, *, full_matrices, compute_uv):
|
|
return dispatch.apply_primitive(svd_p, operand, full_matrices=full_matrices,
|
|
compute_uv=compute_uv)
|
|
|
|
def _svd_abstract_eval(operand, *, full_matrices, compute_uv):
|
|
if isinstance(operand, ShapedArray):
|
|
if operand.ndim < 2:
|
|
raise ValueError("Argument to singular value decomposition must have ndims >= 2")
|
|
|
|
batch_dims = operand.shape[:-2]
|
|
m = operand.shape[-2]
|
|
n = operand.shape[-1]
|
|
s = operand.update(shape=batch_dims + (min(m, n),),
|
|
dtype=lax_internal._complex_basetype(operand.dtype))
|
|
if compute_uv:
|
|
u = operand.update(shape=batch_dims + (m, m if full_matrices else min(m, n)))
|
|
vt = operand.update(shape=batch_dims + (n if full_matrices else min(m, n), n))
|
|
return s, u, vt
|
|
else:
|
|
return s,
|
|
else:
|
|
raise NotImplementedError
|
|
|
|
@jax.default_matmul_precision("float32")
|
|
def _svd_jvp_rule(primals, tangents, *, full_matrices, compute_uv):
|
|
A, = primals
|
|
dA, = tangents
|
|
s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)
|
|
|
|
if compute_uv and full_matrices:
|
|
# TODO: implement full matrices case, documented here: https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
|
|
raise NotImplementedError(
|
|
"Singular value decomposition JVP not implemented for full matrices")
|
|
|
|
Ut, V = _H(U), _H(Vt)
|
|
s_dim = s[..., None, :]
|
|
dS = Ut @ dA @ V
|
|
ds = ufuncs.real(jnp.diagonal(dS, 0, -2, -1))
|
|
|
|
if not compute_uv:
|
|
return (s,), (ds,)
|
|
|
|
s_diffs = (s_dim + _T(s_dim)) * (s_dim - _T(s_dim))
|
|
s_diffs_zeros = jnp.eye(s.shape[-1], dtype=s.dtype) # jnp.ones((), dtype=A.dtype) * (s_diffs == 0.) # is 1. where s_diffs is 0. and is 0. everywhere else
|
|
s_diffs_zeros = lax.expand_dims(s_diffs_zeros, range(s_diffs.ndim - 2))
|
|
F = 1 / (s_diffs + s_diffs_zeros) - s_diffs_zeros
|
|
dSS = s_dim.astype(A.dtype) * dS # dS.dot(jnp.diag(s))
|
|
SdS = _T(s_dim.astype(A.dtype)) * dS # jnp.diag(s).dot(dS)
|
|
|
|
s_zeros = (s == 0).astype(s.dtype)
|
|
s_inv = 1 / (s + s_zeros) - s_zeros
|
|
s_inv_mat = jnp.vectorize(jnp.diag, signature='(k)->(k,k)')(s_inv)
|
|
dUdV_diag = .5 * (dS - _H(dS)) * s_inv_mat.astype(A.dtype)
|
|
dU = U @ (F.astype(A.dtype) * (dSS + _H(dSS)) + dUdV_diag)
|
|
dV = V @ (F.astype(A.dtype) * (SdS + _H(SdS)))
|
|
|
|
m, n = A.shape[-2:]
|
|
if m > n:
|
|
dAV = dA @ V
|
|
dU = dU + (dAV - U @ (Ut @ dAV)) / s_dim.astype(A.dtype)
|
|
if n > m:
|
|
dAHU = _H(dA) @ U
|
|
dV = dV + (dAHU - V @ (Vt @ dAHU)) / s_dim.astype(A.dtype)
|
|
|
|
return (s, U, Vt), (ds, dU, _H(dV))
|
|
|
|
def _empty_svd(a, *, full_matrices, compute_uv):
|
|
batch_shape = a.shape[:-2]
|
|
m, n = a.shape[-2:]
|
|
s = jnp.empty(batch_shape + (0,), dtype=lax_internal._complex_basetype(a.dtype))
|
|
if not compute_uv:
|
|
return (s,)
|
|
if full_matrices:
|
|
size = max(m, n)
|
|
u = jnp.broadcast_to(jnp.eye(size, dtype=a.dtype), batch_shape + (size, size))
|
|
else:
|
|
u = jnp.empty(batch_shape + (m, n), dtype=a.dtype)
|
|
v = jnp.empty(batch_shape + (0, 0), dtype=a.dtype)
|
|
if m < n:
|
|
u, v = v, u
|
|
return s, u, v
|
|
|
|
def _svd_cpu_gpu_lowering(gesvd_impl, ctx, operand, *, full_matrices,
|
|
compute_uv):
|
|
if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
|
|
raise NotImplementedError("Shape polymorphism for custom call is not implemented (svd); b/261671778")
|
|
operand_aval, = ctx.avals_in
|
|
s_aval = ctx.avals_out[0]
|
|
m, n = operand_aval.shape[-2:]
|
|
batch_dims = operand_aval.shape[:-2]
|
|
|
|
if m == 0 or n == 0:
|
|
return mlir.lower_fun(_empty_svd, multiple_results=True)(
|
|
ctx, operand, full_matrices=full_matrices, compute_uv=compute_uv)
|
|
|
|
s, u, vt, info = gesvd_impl(operand_aval.dtype, operand,
|
|
full_matrices=full_matrices,
|
|
compute_uv=compute_uv)
|
|
zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32)))
|
|
ok = mlir.compare_hlo(info, zeros, "EQ", "SIGNED")
|
|
select_s_aval = ShapedArray(batch_dims + (1,), np.dtype(np.bool_))
|
|
s = _broadcasting_select_hlo(
|
|
ctx,
|
|
mlir.broadcast_in_dim(ctx, ok, select_s_aval,
|
|
broadcast_dimensions=range(len(batch_dims))),
|
|
select_s_aval,
|
|
s, s_aval, _nan_like_hlo(ctx, s_aval), s_aval)
|
|
result = [s]
|
|
|
|
if compute_uv:
|
|
u_aval, vt_aval = ctx.avals_out[1:]
|
|
select_u_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
|
|
u = _broadcasting_select_hlo(
|
|
ctx,
|
|
mlir.broadcast_in_dim(ctx, ok, select_u_aval,
|
|
broadcast_dimensions=range(len(batch_dims))),
|
|
select_u_aval,
|
|
u, u_aval, _nan_like_hlo(ctx, u_aval), u_aval)
|
|
select_v_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
|
|
vt = _broadcasting_select_hlo(
|
|
ctx,
|
|
mlir.broadcast_in_dim(ctx, ok, select_v_aval,
|
|
broadcast_dimensions=range(len(batch_dims))),
|
|
select_v_aval,
|
|
vt, vt_aval, _nan_like_hlo(ctx, vt_aval), vt_aval)
|
|
result += [u, vt]
|
|
|
|
return result
|
|
|
|
def _svd_tpu(a, *, full_matrices, compute_uv):
|
|
batch_dims = a.shape[:-2]
|
|
|
|
fn = partial(lax_svd.svd, full_matrices=full_matrices, compute_uv=compute_uv)
|
|
for _ in range(len(batch_dims)):
|
|
fn = api.vmap(fn)
|
|
|
|
if compute_uv:
|
|
u, s, vh = fn(a)
|
|
return [s, u, vh]
|
|
else:
|
|
s = fn(a)
|
|
return [s]
|
|
|
|
def _svd_tpu_lowering_rule(ctx, operand, *, full_matrices, compute_uv):
|
|
operand_aval, = ctx.avals_in
|
|
m, n = operand_aval.shape[-2:]
|
|
|
|
if m == 0 or n == 0:
|
|
return mlir.lower_fun(_empty_svd, multiple_results=True)(
|
|
ctx, operand, full_matrices=full_matrices, compute_uv=compute_uv)
|
|
|
|
return mlir.lower_fun(_svd_tpu, multiple_results=True)(
|
|
ctx, operand, full_matrices=full_matrices, compute_uv=compute_uv)
|
|
|
|
def _svd_batching_rule(batched_args, batch_dims, *, full_matrices, compute_uv):
|
|
x, = batched_args
|
|
bd, = batch_dims
|
|
x = batching.moveaxis(x, bd, 0)
|
|
outs = svd_p.bind(x, full_matrices=full_matrices, compute_uv=compute_uv)
|
|
|
|
if compute_uv:
|
|
return outs, (0, 0, 0)
|
|
else:
|
|
return outs, (0,)
|
|
|
|
svd_p = Primitive('svd')
|
|
svd_p.multiple_results = True
|
|
svd_p.def_impl(_svd_impl)
|
|
svd_p.def_abstract_eval(_svd_abstract_eval)
|
|
ad.primitive_jvps[svd_p] = _svd_jvp_rule
|
|
batching.primitive_batchers[svd_p] = _svd_batching_rule
|
|
|
|
mlir.register_lowering(
|
|
svd_p, partial(_svd_cpu_gpu_lowering, lapack.gesdd_hlo),
|
|
platform='cpu')
|
|
mlir.register_lowering(
|
|
svd_p, partial(_svd_cpu_gpu_lowering, gpu_solver.cuda_gesvd),
|
|
platform='cuda')
|
|
mlir.register_lowering(
|
|
svd_p, partial(_svd_cpu_gpu_lowering, gpu_solver.rocm_gesvd),
|
|
platform='rocm')
|
|
|
|
mlir.register_lowering(svd_p, _svd_tpu_lowering_rule)
|
|
|
|
def _tridiagonal_solve_gpu_lowering(lowering, ctx, dl, d, du, b, *, m, n, ldb, t):
|
|
return [lowering(dl, d, du, b, m=m, n=n, ldb=ldb,
|
|
t=dtypes.canonicalize_dtype(t))]
|
|
|
|
tridiagonal_solve_p = Primitive('tridiagonal_solve')
|
|
tridiagonal_solve_p.multiple_results = False
|
|
tridiagonal_solve_p.def_impl(
|
|
functools.partial(dispatch.apply_primitive, tridiagonal_solve_p))
|
|
tridiagonal_solve_p.def_abstract_eval(lambda dl, d, du, b, *, m, n, ldb, t: b)
|
|
# TODO(tomhennigan): Consider AD rules using lax.custom_linear_solve?
|
|
|
|
mlir.register_lowering(
|
|
tridiagonal_solve_p,
|
|
partial(_tridiagonal_solve_gpu_lowering, gpu_sparse.cuda_gtsv2),
|
|
platform='cuda')
|
|
mlir.register_lowering(
|
|
tridiagonal_solve_p,
|
|
partial(_tridiagonal_solve_gpu_lowering, gpu_sparse.rocm_gtsv2),
|
|
platform='rocm')
|
|
|
|
|
|
def _tridiagonal_solve_jax(dl, d, du, b, **kw):
|
|
"""Pure JAX implementation of `tridiagonal_solve`."""
|
|
prepend_zero = lambda x: jnp.append(jnp.zeros([1], dtype=x.dtype), x[:-1])
|
|
fwd1 = lambda tu_, x: x[1] / (x[0] - x[2] * tu_)
|
|
fwd2 = lambda b_, x: (x[0] - x[3] * b_) / (x[1] - x[3] * x[2])
|
|
bwd1 = lambda x_, x: x[0] - x[1] * x_
|
|
double = lambda f, args: (f(*args), f(*args))
|
|
|
|
# Forward pass.
|
|
_, tu_ = lax.scan(lambda tu_, x: double(fwd1, (tu_, x)),
|
|
du[0] / d[0],
|
|
(d, du, dl),
|
|
unroll=32)
|
|
|
|
_, b_ = lax.scan(lambda b_, x: double(fwd2, (b_, x)),
|
|
b[0] / d[0],
|
|
(b, d, prepend_zero(tu_), dl),
|
|
unroll=32)
|
|
|
|
# Backsubstitution.
|
|
_, x_ = lax.scan(lambda x_, x: double(bwd1, (x_, x)),
|
|
b_[-1],
|
|
(b_[::-1], tu_[::-1]),
|
|
unroll=32)
|
|
|
|
return x_[::-1]
|
|
|
|
|
|
mlir.register_lowering(tridiagonal_solve_p, mlir.lower_fun(
|
|
_tridiagonal_solve_jax, multiple_results=False))
|
|
|
|
|
|
def tridiagonal_solve(dl: Array, d: Array, du: Array, b: Array) -> Array:
|
|
r"""Computes the solution of a tridiagonal linear system.
|
|
|
|
This function computes the solution of a tridiagonal linear system:
|
|
|
|
.. math::
|
|
A . X = B
|
|
|
|
Args:
|
|
dl: The lower diagonal of A: ``dl[i] := A[i, i-1]`` for i in ``[0,m)``.
|
|
Note that ``dl[0] = 0``.
|
|
d: The middle diagnoal of A: ``d[i] := A[i, i]`` for i in ``[0,m)``.
|
|
du: The upper diagonal of A: ``du[i] := A[i, i+1]`` for i in ``[0,m)``.
|
|
Note that ``dl[m - 1] = 0``.
|
|
b: Right hand side matrix.
|
|
|
|
Returns:
|
|
Solution ``X`` of tridiagonal system.
|
|
"""
|
|
if dl.ndim != 1 or d.ndim != 1 or du.ndim != 1:
|
|
raise ValueError('dl, d and du must be vectors')
|
|
|
|
if dl.shape != d.shape or d.shape != du.shape:
|
|
raise ValueError(
|
|
f'dl={dl.shape}, d={d.shape} and du={du.shape} must all be `[m]`')
|
|
|
|
if b.ndim != 2:
|
|
raise ValueError(f'b={b.shape} must be a matrix')
|
|
|
|
m, = dl.shape
|
|
if m < 3:
|
|
raise ValueError(f'm ({m}) must be >= 3')
|
|
|
|
ldb, n = b.shape
|
|
if ldb < max(1, m):
|
|
raise ValueError(f'Leading dimension of b={ldb} must be ≥ max(1, {m})')
|
|
|
|
if dl.dtype != d.dtype or d.dtype != du.dtype or du.dtype != b.dtype:
|
|
raise ValueError(f'dl={dl.dtype}, d={d.dtype}, du={du.dtype} and '
|
|
f'b={b.dtype} must be the same dtype,')
|
|
|
|
t = dl.dtype
|
|
if t not in (np.float32, np.float64):
|
|
raise ValueError(f'Only f32/f64 are supported, got {t}')
|
|
|
|
return tridiagonal_solve_p.bind(dl, d, du, b, m=m, n=n, ldb=ldb, t=t)
|
|
|
|
|
|
# Schur Decomposition
|
|
|
|
|
|
@_warn_on_positional_kwargs
|
|
def schur(x: ArrayLike, *,
|
|
compute_schur_vectors: bool = True,
|
|
sort_eig_vals: bool = False,
|
|
select_callable: Optional[Callable[..., Any]] = None) -> Tuple[Array, Array]:
|
|
return schur_p.bind(
|
|
x,
|
|
compute_schur_vectors=compute_schur_vectors,
|
|
sort_eig_vals=sort_eig_vals,
|
|
select_callable=select_callable)
|
|
|
|
|
|
def _schur_impl(operand, *, compute_schur_vectors, sort_eig_vals,
|
|
select_callable):
|
|
return dispatch.apply_primitive(
|
|
schur_p,
|
|
operand,
|
|
compute_schur_vectors=compute_schur_vectors,
|
|
sort_eig_vals=sort_eig_vals,
|
|
select_callable=select_callable)
|
|
|
|
def _schur_lowering(ctx, *args, **kwargs):
|
|
raise NotImplementedError(
|
|
"Schur decomposition is only implemented on the CPU backend.")
|
|
|
|
def _schur_abstract_eval(operand, *, compute_schur_vectors, sort_eig_vals,
|
|
select_callable):
|
|
|
|
if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
|
|
raise ValueError("Argument to Schur decomposition must have "
|
|
"shape [..., n, n], got shape {}".format(operand.shape))
|
|
|
|
batch_dims = operand.shape[:-2]
|
|
n = operand.shape[-1]
|
|
dtype = operand.dtype
|
|
dtype = dtypes.canonicalize_dtype(dtype)
|
|
T = operand.update(shape=batch_dims + (n, n), dtype=dtype)
|
|
vs = operand.update(shape=batch_dims + (n, n), dtype=dtype)
|
|
|
|
return (T, vs) if compute_schur_vectors else (T,)
|
|
|
|
def _schur_cpu_lowering(ctx, operand, *, compute_schur_vectors, sort_eig_vals,
|
|
select_callable):
|
|
operand_aval, = ctx.avals_in
|
|
batch_dims = operand_aval.shape[:-2]
|
|
|
|
gees_result = lapack.gees_hlo(operand_aval.dtype, operand,
|
|
jobvs=compute_schur_vectors,
|
|
sort=sort_eig_vals,
|
|
select=select_callable)
|
|
|
|
# Number of return values depends on value of sort_eig_vals.
|
|
T, vs, *_, info = gees_result
|
|
|
|
ok = mlir.compare_hlo(
|
|
info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))),
|
|
"EQ", "SIGNED")
|
|
|
|
select_T_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
|
|
T = _broadcasting_select_hlo(
|
|
ctx,
|
|
mlir.broadcast_in_dim(ctx, ok, select_T_aval,
|
|
broadcast_dimensions=range(len(batch_dims))),
|
|
select_T_aval,
|
|
T, ctx.avals_out[0],_nan_like_hlo(ctx, ctx.avals_out[0]), ctx.avals_out[0])
|
|
output = [T]
|
|
if compute_schur_vectors:
|
|
select_vs_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
|
|
vs = _broadcasting_select_hlo(
|
|
ctx,
|
|
mlir.broadcast_in_dim(ctx, ok, select_vs_aval,
|
|
broadcast_dimensions=range(len(batch_dims))),
|
|
select_vs_aval,
|
|
vs, ctx.avals_out[1], _nan_like_hlo(ctx, ctx.avals_out[1]), ctx.avals_out[1])
|
|
|
|
output.append(vs)
|
|
|
|
return output
|
|
|
|
|
|
def _schur_batching_rule(batched_args, batch_dims, *, compute_schur_vectors,
|
|
sort_eig_vals, select_callable):
|
|
x, = batched_args
|
|
bd, = batch_dims
|
|
x = batching.moveaxis(x, bd, 0)
|
|
|
|
return schur_p.bind(
|
|
x,
|
|
compute_schur_vectors=compute_schur_vectors,
|
|
sort_eig_vals=sort_eig_vals,
|
|
select_callable=select_callable), (0,) * (1 + compute_schur_vectors)
|
|
|
|
|
|
def _schur_jvp_rule(primals, tangents, *, compute_schur_vectors, sort_eig_vals):
|
|
raise NotImplementedError(
|
|
'The differentiation rules for the Schur factorization have not been implemented.'
|
|
)
|
|
|
|
|
|
schur_p = Primitive('schur')
|
|
schur_p.multiple_results = True
|
|
schur_p.def_impl(_schur_impl)
|
|
schur_p.def_abstract_eval(_schur_abstract_eval)
|
|
mlir.register_lowering(schur_p, _schur_lowering)
|
|
mlir.register_lowering(schur_p, _schur_cpu_lowering, platform='cpu')
|
|
batching.primitive_batchers[schur_p] = _schur_batching_rule
|
|
ad.primitive_jvps[schur_p] = _schur_jvp_rule
|
|
|
|
|
|
# hessenberg: Upper Hessenberg reduction
|
|
|
|
def hessenberg(a: ArrayLike) -> Tuple[Array, Array]:
|
|
"""Reduces a square matrix to upper Hessenberg form.
|
|
|
|
Currently implemented on CPU only.
|
|
|
|
Args:
|
|
a: A floating point or complex square matrix or batch of matrices.
|
|
|
|
Returns:
|
|
A ``(a, taus)`` pair, where the upper triangle and first subdiagonal of ``a``
|
|
contain the upper Hessenberg matrix, and the elements below the first
|
|
subdiagonal contain the Householder reflectors. For each Householder
|
|
reflector ``taus`` contains the scalar factors of the elementary Householder
|
|
reflectors.
|
|
"""
|
|
return hessenberg_p.bind(a)
|
|
|
|
def _hessenberg_abstract_eval(a):
|
|
if a.dtype not in (jnp.float32, jnp.float64, jnp.complex64, jnp.complex128):
|
|
raise TypeError("hessenberg requires a.dtype to be float32, float64, "
|
|
f"complex64, or complex128, got {a.dtype}.")
|
|
if a.ndim < 2:
|
|
raise TypeError("hessenberg requires a.ndim to be at least 2, got "
|
|
f"{a.ndim}.")
|
|
if a.shape[-1] != a.shape[-2]:
|
|
raise TypeError("hessenberg requires the last two dimensions of a to be "
|
|
f"equal in size, got a.shape of {a.shape}.")
|
|
return [a, ShapedArray(a.shape[:-2] + (a.shape[-1] - 1,), a.dtype)]
|
|
|
|
hessenberg_p = Primitive("hessenberg")
|
|
hessenberg_p.def_impl(partial(dispatch.apply_primitive, hessenberg_p))
|
|
hessenberg_p.def_abstract_eval(_hessenberg_abstract_eval)
|
|
hessenberg_p.multiple_results = True
|
|
|
|
def _hessenberg_batching_rule(batched_args, batch_dims):
|
|
x, = batched_args
|
|
bd, = batch_dims
|
|
x = batching.moveaxis(x, bd, 0)
|
|
return hessenberg(x), 0
|
|
|
|
batching.primitive_batchers[hessenberg_p] = _hessenberg_batching_rule
|
|
|
|
def _hessenberg_cpu_hlo(ctx, a):
|
|
a_aval, = ctx.avals_in
|
|
batch_dims = a_aval.shape[:-2]
|
|
a, taus, info = lapack.gehrd_hlo(a_aval.dtype, a)
|
|
ok = mlir.compare_hlo(
|
|
info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))),
|
|
"EQ", "SIGNED")
|
|
select_a_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
|
|
select_taus_aval = ShapedArray(batch_dims + (1,), np.dtype(np.bool_))
|
|
return [
|
|
_broadcasting_select_hlo(
|
|
ctx,
|
|
mlir.broadcast_in_dim(ctx, ok, select_a_aval,
|
|
broadcast_dimensions=range(len(batch_dims))),
|
|
select_a_aval,
|
|
a, ctx.avals_out[0], _nan_like_hlo(ctx, ctx.avals_out[0]), ctx.avals_out[0]),
|
|
_broadcasting_select_hlo(
|
|
ctx,
|
|
mlir.broadcast_in_dim(ctx, ok, select_taus_aval,
|
|
broadcast_dimensions=range(len(batch_dims))),
|
|
select_taus_aval,
|
|
taus, ctx.avals_out[1], _nan_like_hlo(ctx, ctx.avals_out[1]), ctx.avals_out[1]),
|
|
]
|
|
|
|
mlir.register_lowering(hessenberg_p, _hessenberg_cpu_hlo, platform='cpu')
|
|
|
|
|
|
# tridiagonal: Upper Hessenberg reduction
|
|
|
|
def tridiagonal(a: ArrayLike, *, lower=True
|
|
) -> Tuple[Array, Array, Array, Array]:
|
|
"""Reduces a symmetric/Hermitian matrix to tridiagonal form.
|
|
|
|
Currently implemented on CPU and GPU only.
|
|
|
|
Args:
|
|
a: A floating point or complex matrix or batch of matrices.
|
|
lower: Describes which triangle of the input matrices to use.
|
|
The other triangle is ignored and not accessed.
|
|
|
|
Returns:
|
|
A ``(a, d, e, taus)`` pair. If ``lower=True``, the diagonal and first subdiagonal of
|
|
matrix (or batch of matrices) ``a`` contain the tridiagonal representation,
|
|
and elements below the first subdiagonal contain the elementary Householder
|
|
reflectors, where additionally ``d`` contains the diagonal of the matrix and ``e`` contains
|
|
the first subdiagonal.If ``lower=False`` the diagonal and first superdiagonal of the
|
|
matrix contains the tridiagonal representation, and elements above the first
|
|
superdiagonal contain the elementary Householder reflectors, where
|
|
additionally ``d`` contains the diagonal of the matrix and ``e`` contains the
|
|
first superdiagonal. ``taus`` contains the scalar factors of the elementary
|
|
Householder reflectors.
|
|
"""
|
|
arr, d, e, taus, info = tridiagonal_p.bind(jnp.asarray(a), lower=lower)
|
|
nan = arr.dtype.type(jnp.nan)
|
|
if jnp.issubdtype(arr.dtype, np.complexfloating):
|
|
nan = nan + arr.dtype.type(jnp.nan * 1j)
|
|
arr = jnp.where((info == 0)[..., None, None], arr, nan)
|
|
real_type = jnp.finfo(arr.dtype).dtype.type
|
|
d = jnp.where((info == 0)[..., None], d, real_type(jnp.nan))
|
|
e = jnp.where((info == 0)[..., None], e, real_type(jnp.nan))
|
|
taus = jnp.where((info == 0)[..., None], taus, nan)
|
|
return arr, d, e, taus
|
|
|
|
def _tridiagonal_abstract_eval(a, *, lower):
|
|
if a.dtype not in (jnp.float32, jnp.float64, jnp.complex64, jnp.complex128):
|
|
raise TypeError("tridiagonal requires a.dtype to be float32, float64, "
|
|
f"complex64, or complex128, got {a.dtype}.")
|
|
if a.ndim < 2:
|
|
raise TypeError("tridiagonal requires a.ndim to be at least 2, got "
|
|
f"{a.ndim}.")
|
|
if a.shape[-1] != a.shape[-2]:
|
|
raise TypeError("tridiagonal requires the last two dimensions of a to be "
|
|
f"equal in size, got a.shape of {a.shape}.")
|
|
if a.shape[-1] == 0:
|
|
raise TypeError("tridiagonal requires the last two dimensions of a to be "
|
|
f"non-zero, got a.shape of {a.shape}.")
|
|
real_dtype = jnp.finfo(a.dtype).dtype
|
|
return [
|
|
a,
|
|
ShapedArray(a.shape[:-2] + (a.shape[-1],), real_dtype),
|
|
ShapedArray(a.shape[:-2] + (a.shape[-1] - 1,), real_dtype),
|
|
ShapedArray(a.shape[:-2] + (a.shape[-1] - 1,), a.dtype),
|
|
ShapedArray(a.shape[:-2], np.int32)
|
|
]
|
|
|
|
tridiagonal_p = Primitive("tridiagonal")
|
|
tridiagonal_p.def_impl(partial(dispatch.apply_primitive, tridiagonal_p))
|
|
tridiagonal_p.def_abstract_eval(_tridiagonal_abstract_eval)
|
|
tridiagonal_p.multiple_results = True
|
|
|
|
def _tridiagonal_batching_rule(batched_args, batch_dims, *, lower):
|
|
x, = batched_args
|
|
bd, = batch_dims
|
|
x = batching.moveaxis(x, bd, 0)
|
|
return tridiagonal(x), 0
|
|
|
|
batching.primitive_batchers[tridiagonal_p] = _tridiagonal_batching_rule
|
|
|
|
def _tridiagonal_cpu_gpu_hlo(sytrd_impl, ctx, a, *, lower):
|
|
a_aval, = ctx.avals_in
|
|
a, d, e, taus, info = sytrd_impl(a_aval.dtype, a, lower=lower)
|
|
return a, d, e, taus, info
|
|
|
|
mlir.register_lowering(
|
|
tridiagonal_p, partial(_tridiagonal_cpu_gpu_hlo, lapack.sytrd_hlo),
|
|
platform='cpu')
|
|
mlir.register_lowering(
|
|
tridiagonal_p, partial(_tridiagonal_cpu_gpu_hlo, gpu_solver.cuda_sytrd),
|
|
platform='cuda')
|
|
mlir.register_lowering(
|
|
tridiagonal_p, partial(_tridiagonal_cpu_gpu_hlo, gpu_solver.rocm_sytrd),
|
|
platform='rocm')
|
|
|
|
# Utilities
|
|
|
|
def _nan_like_hlo(ctx: mlir.LoweringRuleContext, aval) -> ir.Value:
|
|
if jnp.issubdtype(aval.dtype, np.complexfloating):
|
|
return mlir.full_like_aval(ctx, np.nan + np.nan * 1j, aval)
|
|
else:
|
|
return mlir.full_like_aval(ctx, np.nan, aval)
|
|
|
|
def _broadcasting_select_hlo(ctx, which, which_aval, x, x_aval, y, y_aval) -> ir.Value:
|
|
"""Wrapper around XLA `Select` that broadcasts its arguments."""
|
|
out_shapes = list(lax_internal.broadcast_shapes(
|
|
tuple(which_aval.shape), tuple(x_aval.shape), tuple(y_aval.shape)))
|
|
which, x, y = mlir.multi_broadcast_in_dim(ctx, (which, x, y),
|
|
(which_aval, x_aval, y_aval),
|
|
out_shapes)
|
|
return hlo.SelectOp(which, x, y).result
|