Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/_src/lax/control_flow/solves.py

473 lines
18 KiB
Python
Raw Normal View History

2023-06-19 00:49:18 +02:00
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for the custom linear solve and utilities."""
import collections
from functools import partial
import operator
import jax
from jax.tree_util import (tree_flatten, treedef_children, tree_leaves,
tree_unflatten, treedef_tuple)
from jax._src import ad_util
from jax._src import core
from jax._src import linear_util as lu
from jax._src.core import raise_to_shaped
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.lax import lax
from jax._src.traceback_util import api_boundary
from jax._src.util import split_list, safe_map
import numpy as np
from jax._src.lax.control_flow.common import (
_abstractify,
_check_tree,
_initial_style_jaxpr,
)
_map = safe_map
_RootTuple = collections.namedtuple('_RootTuple', 'f, solve, l_and_s')
def _split_root_args(args, const_lengths):
  """Split flat ``args`` into per-jaxpr constants and the trailing operands.

  Args:
    args: flat sequence of all constants (grouped as f, solve, l_and_s),
      followed by the remaining operands.
    const_lengths: ``_RootTuple`` of the number of constants in each group.

  Returns:
    A pair ``(consts, rest)``. ``consts`` is a ``_RootTuple`` of lists of
    constants; ``rest`` is the list of remaining operands.
  """
  *const_groups, remaining = split_list(args, list(const_lengths))
  return _RootTuple(*const_groups), remaining
@api_boundary
def custom_root(f, initial_guess, solve, tangent_solve, has_aux=False):
  """Differentiably solve for the roots of a function.

  This is a low-level routine, mostly intended for internal use in JAX.
  Gradients of custom_root() are defined with respect to closed-over variables
  from the provided function ``f`` via the implicit function theorem:
  https://en.wikipedia.org/wiki/Implicit_function_theorem

  Args:
    f: function for which to find a root. Should accept a single argument and
      return a tree of arrays with the same structure as its input.
    initial_guess: initial guess for a zero of f.
    solve: function to solve for the roots of f. Should take two positional
      arguments, f and initial_guess, and return a solution with the same
      structure as initial_guess such that f(solution) = 0. In other words,
      the following is assumed to be true (but not checked)::

        solution = solve(f, initial_guess)
        error = f(solution)
        assert all(error == 0)

    tangent_solve: function to solve the tangent system. Should take two
      positional arguments, a linear function ``g`` (the function ``f``
      linearized at its root) and a tree of array(s) ``y`` with the same
      structure as initial_guess, and return a solution ``x`` such that
      ``g(x)=y``:

      - For scalar ``y``, use ``lambda g, y: y / g(1.0)``.
      - For vector ``y``, you could use a linear solve with the Jacobian, if
        dimensionality of ``y`` is not too large:
        ``lambda g, y: np.linalg.solve(jacobian(g)(y), y)``.
    has_aux: bool indicating whether the ``solve`` function returns
      auxiliary data like solver diagnostics as a second argument.

  Returns:
    The result of calling solve(f, initial_guess) with gradients defined via
    implicit differentiation assuming ``f(solve(f, initial_guess)) == 0``.
  """
  guess_flat, in_args_tree = tree_flatten((initial_guess,))
  guess_avals = tuple(_map(_abstractify, guess_flat))
  f_jaxpr, f_consts, out_tree = _initial_style_jaxpr(
      f, in_args_tree, guess_avals)

  in_tree, = treedef_children(in_args_tree)
  _check_tree("f", "initial_guess", out_tree, in_tree, False)

  solve_jaxpr, solve_consts, solution_tree = _initial_style_jaxpr(
      partial(solve, f), in_args_tree, guess_avals)
  _check_tree("solve", "initial_guess", solution_tree, in_tree, has_aux)

  def linearize_and_solve(x, b):
    # Only the linearized map is needed; f's primal output at x is discarded.
    _, f_jvp = jax.linearize(f, x)
    return tangent_solve(f_jvp, b)

  l_and_s_jaxpr, l_and_s_consts, out_tree = _initial_style_jaxpr(
      linearize_and_solve, treedef_tuple((in_tree,) * 2), guess_avals * 2)
  _check_tree("tangent_solve", "x", out_tree, in_tree, False)

  all_consts = [f_consts, solve_consts, l_and_s_consts]
  const_lengths = _RootTuple(*_map(len, all_consts))
  jaxprs = _RootTuple(f_jaxpr, solve_jaxpr, l_and_s_jaxpr)

  solution_flat = _custom_root(
      const_lengths, jaxprs, *(_flatten(all_consts) + guess_flat))
  # solution_tree also carries the aux subtree when has_aux=True.
  return tree_unflatten(solution_tree, solution_flat)
@partial(jax.custom_jvp, nondiff_argnums=(0, 1))
def _custom_root(const_lengths, jaxprs, *args):
  """Primal computation for ``custom_root``: run the ``solve`` jaxpr.

  ``const_lengths`` and ``jaxprs`` are ``_RootTuple``s and are declared
  non-differentiable. ``args`` holds all constants followed by the flattened
  initial guess. Only the solve constants and the guess are fed to the solve
  jaxpr.
  """
  consts, guess = _split_root_args(args, const_lengths)
  run_solve = core.jaxpr_as_fun(jaxprs.solve)
  return run_solve(*consts.solve, *guess)
@_custom_root.defjvp
def _root_jvp(const_lengths, jaxprs, primals, tangents):
  """JVP rule for ``_custom_root`` derived from the implicit function theorem.

  The primal output is recomputed with ``_custom_root``. The solution tangent
  is obtained in two steps. First, the parameter tangents of ``f`` are pushed
  through ``f`` evaluated at the solution. Then the ``l_and_s`` jaxpr solves
  the linearized system for that right-hand side, and the result is negated.
  Any aux outputs get dense all-zero tangents.
  """
  params, _ = _split_root_args(primals, const_lengths)
  full_out = _custom_root(const_lengths, jaxprs, *primals)
  num_solution_leaves = len(jaxprs.f.out_avals)
  solution, aux = split_list(full_out, [num_solution_leaves])

  params_dot, _ = _split_root_args(tangents, const_lengths)

  # F(m, u) = 0      # system of equations in u, parameterized by m
  #                  # solution is u*(m) defined in a neighborhood
  # F(m, u*(m)) = 0  # satisfied in a neighborhood
  #
  # ∂_0 F(m, u*(m)) + ∂_1 F(m, u*(m)) ∂ u*(m) = 0       # implied by line above
  # ∂ u*(m) = - (∂_1 F(m, u*(m)))^{-1} ∂_0 F(m, u*(m))  # rearrange
  #
  # ∂ u*(m)[v] = - (∂_1 F(m, u*(m)))^{-1} [∂_0 F(m, u*(m))[v]]  # jvp

  eval_f = core.jaxpr_as_fun(jaxprs.f)
  eval_l_and_s = core.jaxpr_as_fun(jaxprs.l_and_s)

  def f_at_solution(*f_params):
    # f with its root argument pinned to the computed solution, viewed as a
    # function of f's constants only (the ∂_0 F term above).
    return eval_f(*f_params, *solution)

  _, rhs = ad.jvp(lu.wrap_init(f_at_solution)).call_wrapped(
      params.f, params_dot.f)
  unnegated_dot = eval_l_and_s(*params.l_and_s, *solution, *rhs)
  solution_dot = [-t for t in unnegated_dot]

  # Aux outputs are passed through with dense (not symbolic) zero tangents.
  aux_dot = [lax.zeros_like_array(a) for a in aux]
  return solution + aux, solution_dot + aux_dot
class _LinearSolveTuple(collections.namedtuple(
'_LinearSolveTuple', 'matvec, vecmat, solve, transpose_solve')):
def transpose(self):
return type(self)(self.vecmat, self.matvec, self.transpose_solve, self.solve)
def _split_linear_solve_args(args, const_lengths):
  """Split flat ``args`` into per-jaxpr constants and the trailing operands.

  Args:
    args: flat sequence of constants (grouped as matvec, vecmat, solve,
      transpose_solve), followed by the remaining operands.
    const_lengths: ``_LinearSolveTuple`` of the size of each constant group.

  Returns:
    A pair ``(consts, rest)``. ``consts`` is a ``_LinearSolveTuple`` of lists;
    ``rest`` is the list of remaining operands.
  """
  *const_groups, remaining = split_list(args, list(const_lengths))
  return _LinearSolveTuple(*const_groups), remaining
def _transpose_one_output(linear_fun, primals):
transpose_fun = jax.linear_transpose(linear_fun, primals)
def transposed_fun(x):
(y,) = transpose_fun(x)
return y
return transposed_fun
def _flatten(args):
return [x for arg in args for x in arg]
def _check_shapes(func_name, expected_name, actual, expected):
actual_shapes = _map(np.shape, tree_leaves(actual))
expected_shapes = _map(np.shape, tree_leaves(expected))
if actual_shapes != expected_shapes:
raise ValueError(
f"{func_name}() output shapes must match {expected_name}, "
f"got {actual_shapes} and {expected_shapes}")
@api_boundary
def custom_linear_solve(
    matvec, b, solve, transpose_solve=None, symmetric=False, has_aux=False):
  """Perform a matrix-free linear solve with implicitly defined gradients.

  This function allows for overriding or defining gradients for a linear
  solve directly via implicit differentiation at the solution, rather than by
  differentiating *through* the solve operation. This can sometimes be much faster
  or more numerically stable, or differentiating through the solve operation
  may not even be implemented (e.g., if ``solve`` uses ``lax.while_loop``).

  Required invariant::

      x = solve(matvec, b)  # solve the linear equation
      assert matvec(x) == b  # not checked

  Args:
    matvec: linear function to invert. Must be differentiable.
    b: constant right hand side of the equation. May be any nested structure
      of arrays.
    solve: higher level function that solves for solution to the linear
      equation, i.e., ``matvec(solve(matvec, x)) == x`` for all ``x`` of the
      same form as ``b``. This function need not be differentiable.
    transpose_solve: higher level function for solving the transpose linear
      equation, i.e., ``vecmat(transpose_solve(vecmat, x)) == x``, where
      ``vecmat`` is the transpose of the linear map ``matvec`` (computed
      automatically with autodiff). Required for backwards mode automatic
      differentiation, unless ``symmetric=True``, in which case ``solve``
      provides the default value.
    symmetric: bool indicating if it is safe to assume the linear map
      corresponds to a symmetric matrix, i.e., ``matvec == vecmat``.
    has_aux: bool indicating whether the ``solve`` and ``transpose_solve`` functions
      return auxiliary data like solver diagnostics as a second argument.

  Returns:
    Result of ``solve(matvec, b)``, with gradients defined assuming that the
    solution ``x`` satisfies the linear equation ``matvec(x) == b``. When
    ``has_aux`` is true, the result has the same ``(solution, aux)`` structure
    as the output of ``solve``.
  """
  if transpose_solve is None and symmetric:
    transpose_solve = solve

  b_flat, in_args_tree = tree_flatten((b,))
  b_avals = tuple(_map(_abstractify, b_flat))

  tree, = treedef_children(in_args_tree)

  def _shape_checked(fun, name, has_aux):
    # Wrap ``fun`` so its (non-aux) output leaves must match b's leaf shapes.
    def f(x):
      y = fun(x)
      _check_shapes(name, "b", y, b_flat)
      return y

    def f_aux(x):
      y, aux = fun(x)
      _check_shapes(name, "b", y, b_flat)
      return y, aux

    return f_aux if has_aux else f

  # no auxiliary data assumed for matvec
  matvec_jaxpr, matvec_consts, matvec_tree = _initial_style_jaxpr(
      _shape_checked(matvec, "matvec", False), in_args_tree, b_avals,
      'custom_linear_solve')
  _check_tree("matvec", "b", matvec_tree, tree, False)

  # Keep solve's output tree separately: the primitive's outputs are solve's
  # outputs, so this (not transpose_solve's tree) is what the result follows.
  solve_jaxpr, solve_consts, solution_tree = _initial_style_jaxpr(
      _shape_checked(partial(solve, matvec), "solve", has_aux), in_args_tree, b_avals,
      'custom_linear_solve')
  _check_tree("solve", "b", solution_tree, tree, has_aux)

  if transpose_solve is None:
    vecmat_jaxpr = tr_solve_jaxpr = None
    vecmat_consts = tr_solve_consts = []
  else:
    if symmetric:
      vecmat = matvec
      vecmat_jaxpr = matvec_jaxpr
      vecmat_consts = matvec_consts
    else:
      vecmat = _transpose_one_output(matvec, b)
      vecmat_jaxpr, vecmat_consts, vecmat_tree = _initial_style_jaxpr(
          vecmat, in_args_tree, b_avals, 'custom_linear_solve')
      assert vecmat_tree == tree

    tr_solve_jaxpr, tr_solve_consts, tr_solve_tree = _initial_style_jaxpr(
        _shape_checked(partial(transpose_solve, vecmat), "transpose_solve", has_aux),
        in_args_tree, b_avals, 'custom_linear_solve')
    _check_tree("transpose_solve", "b", tr_solve_tree, tree, has_aux)

  all_consts = [matvec_consts, vecmat_consts, solve_consts, tr_solve_consts]
  const_lengths = _LinearSolveTuple(*_map(len, all_consts))
  jaxprs = _LinearSolveTuple(
      matvec_jaxpr, vecmat_jaxpr, solve_jaxpr, tr_solve_jaxpr)

  out_flat = linear_solve_p.bind(
      *(_flatten(all_consts) + b_flat),
      const_lengths=const_lengths, jaxprs=jaxprs)

  return tree_unflatten(solution_tree, out_flat)
def _linear_solve_abstract_eval(*args, const_lengths, jaxprs):
  """Abstract evaluation rule for ``linear_solve_p``.

  The solution has the same avals as the ``b`` operands, which follow the
  constants in ``args``. Any aux outputs of ``solve`` are appended after them.
  All results are passed through ``raise_to_shaped``.
  """
  num_consts = sum(const_lengths)
  out_avals = args[num_consts:]
  # solve and matvec map into the same vector space, so any extra outputs of
  # solve beyond matvec's are aux values.
  num_aux = len(jaxprs.solve.out_avals) - len(jaxprs.matvec.out_avals)
  if num_aux > 0:
    out_avals = out_avals + tuple(jaxprs.solve.out_avals[-num_aux:])
  return _map(raise_to_shaped, out_avals)
def _custom_linear_solve_impl(*args, const_lengths, jaxprs):
  """Evaluate ``linear_solve_p`` by running the ``solve`` jaxpr.

  The ``solve`` constants and the ``b`` operands are the jaxpr's inputs.
  """
  consts, b = _split_linear_solve_args(args, const_lengths)
  run_solve = core.jaxpr_as_fun(jaxprs.solve)
  return run_solve(*consts.solve, *b)
def _tangent_linear_map(func, params, params_dot, *x):
  """Compute the tangent of a linear map with respect to its parameters.

  Assuming ``func(*params, *x)`` is linear in ``x`` and computes ``A @ x``,
  this function computes ``∂A @ x``. The tangent of ``x`` is held at zero, so
  only the variation of the operator ``A`` contributes.

  Args:
    func: function of ``(*params, *x)``.
    params: list of primal parameter values defining the operator.
    params_dot: list of tangents for ``params``; at least one must not be a
      symbolic ``ad_util.Zero``.
    *x: primal point at which the map is applied.

  Returns:
    The output tangents of ``func`` from its JVP.
  """
  assert any(type(p) is not ad_util.Zero for p in params_dot)
  zeros = _map(ad_util.Zero.from_value, x)
  _, out_tangent = ad.jvp(lu.wrap_init(func)).call_wrapped(
      params + list(x), params_dot + zeros)
  return out_tangent
def _custom_linear_solve_jvp(primals, tangents, const_lengths, jaxprs):
  """JVP rule for ``linear_solve_p``.

  Differentiates the linear system implicitly:

    A x - b = 0
    ∂A x + A ∂x - ∂b = 0
    ∂x = A^{-1} (∂b - ∂A x)

  The right-hand side ``∂b - ∂A x`` is then solved with the same primitive.
  Aux outputs receive symbolic zero tangents.
  """
  bind_params = dict(const_lengths=const_lengths, jaxprs=jaxprs)
  x = linear_solve_p.bind(*primals, **bind_params)

  params, _ = _split_linear_solve_args(primals, const_lengths)
  params_dot, b_dot = _split_linear_solve_args(tangents, const_lengths)

  # x has the same tree as b, possibly followed by aux leaves; b_dot's length
  # tells us where the solution leaves end.
  num_x_leaves = len(b_dot)
  x_leaves, _ = split_list(x, [num_x_leaves])

  matvec_has_tangent = any(
      type(p) is not ad_util.Zero for p in params_dot.matvec)
  if matvec_has_tangent:
    matvec_tangents = _tangent_linear_map(
        core.jaxpr_as_fun(jaxprs.matvec), params.matvec, params_dot.matvec,
        *x_leaves)
    negated = [-t for t in matvec_tangents]
    rhs = _map(ad.add_tangents, b_dot, negated)
  else:
    # ∂A is zero, so the right-hand side is just ∂b.
    rhs = b_dot

  x_dot = linear_solve_p.bind(*_flatten(params), *rhs, **bind_params)

  dx_leaves, daux_leaves = split_list(x_dot, [num_x_leaves])
  zero_aux = [ad_util.Zero.from_value(v) for v in daux_leaves]
  return x, dx_leaves + zero_aux
def _linear_solve_transpose_rule(cotangent, *primals, const_lengths, jaxprs):
  """Transpose rule for ``linear_solve_p`` with respect to ``b``.

  Solves the transposed system by binding ``linear_solve_p`` with transposed
  constants and jaxprs (matvec <-> vecmat, solve <-> transpose_solve).

  Raises:
    TypeError: if no ``transpose_solve`` was provided.

  Returns:
    ``None`` for every constant, followed by the cotangent of each ``b`` leaf.
  """
  if jaxprs.transpose_solve is None:
    raise TypeError('transpose_solve required for backwards mode automatic '
                    'differentiation of custom_linear_solve')

  params, b = _split_linear_solve_args(primals, const_lengths)
  # Keep only the cotangents of the solution leaves (one per leaf of b) and
  # drop those of any trailing aux outputs.
  x_cotangent, _ = split_list(cotangent, [len(b)])
  # Some of these may be symbolic zeros (ad_util.Zero), which bind() cannot
  # accept. Materialize them as dense zeros.
  x_cotangent = _map(ad.instantiate_zeros, x_cotangent)
  assert all(ad.is_undefined_primal(x) for x in b)
  cotangent_b_full = linear_solve_p.bind(
      *(_flatten(params.transpose()) + x_cotangent),
      const_lengths=const_lengths.transpose(), jaxprs=jaxprs.transpose())
  # drop aux values in cotangent computation
  cotangent_b, _ = split_list(cotangent_b_full, [len(b)])
  return [None] * sum(const_lengths) + cotangent_b
def _linear_solve_batching_rule(spmd_axis_name, axis_size, axis_name, main_type,
                                args, dims, const_lengths, jaxprs):
  """Batching (vmap) rule for ``linear_solve_p``.

  Iterates to a fixpoint to decide which leaves of the solution ``x`` and of
  the right-hand side ``b`` must be batched, so that all four jaxprs
  (matvec, vecmat, solve, transpose_solve) agree. It then moves every batched
  constant's batch axis to position 0. Each ``b`` leaf is broadcast (if newly
  batched) or has its axis moved to 0. Finally it binds ``linear_solve_p``
  with the batched jaxprs.

  Args:
    spmd_axis_name: forwarded to ``batching.batch_jaxpr``. It is ``None``
      when this rule is registered as the plain axis-primitive batcher.
    axis_size: size of the mapped axis (also used to broadcast ``b``).
    axis_name: forwarded to ``batching.batch_jaxpr``.
    main_type: forwarded to ``batching.batch_jaxpr``.
    args: flat operands: all constants followed by the leaves of ``b``.
    dims: batch dimension of each operand, or ``batching.not_mapped``.
    const_lengths: ``_LinearSolveTuple`` of constant-group sizes.
    jaxprs: ``_LinearSolveTuple`` of the four jaxprs (vecmat and
      transpose_solve may be ``None``).

  Returns:
    ``(outs, out_dims)``. Each output is batched along axis 0 or is
    ``batching.not_mapped``.
  """
  orig_bat = [d is not batching.not_mapped for d in dims]

  params, b = _split_linear_solve_args(args, const_lengths)
  params_dims, b_dims = _split_linear_solve_args(dims, const_lengths)
  params_bat, orig_b_bat = _split_linear_solve_args(orig_bat, const_lengths)

  (matvec, vecmat, solve, solve_t) = jaxprs
  (matvec_bat, vecmat_bat, solve_bat, solve_t_bat) = params_bat

  # number of operator out avals is assumed to be the same for matvec/vecmat
  num_operator_out_avals = len(matvec.out_avals)
  num_aux = len(solve.out_avals) - num_operator_out_avals
  # Fixpoint computation of which parts of x and b are batched; we need to
  # ensure this is consistent between all four jaxprs
  b_bat = orig_b_bat
  x_bat = [False] * len(solve.out_avals)
  # NOTE(review): the iteration bound presumably relies on the flags only
  # flipping False -> True between rounds (each batch_jaxpr call instantiates
  # the previous round's batched outputs). If so, each flag can change at
  # most once.
  for i in range(1 + len(orig_b_bat) + len(solve.out_avals)):
    # Apply vecmat and solve -> new batched parts of x
    solve_jaxpr_batched, solve_x_bat = batching.batch_jaxpr(
        solve, axis_size, solve_bat + b_bat, instantiate=x_bat,
        axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type)
    if vecmat is None:
      vecmat_jaxpr_batched = None
      x_bat_out = solve_x_bat
    else:
      vecmat_jaxpr_batched, vecmat_x_bat = batching.batch_jaxpr(
          vecmat, axis_size, vecmat_bat + b_bat, instantiate=b_bat,
          axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type)
      # batch all aux data by default
      x_bat_out = _map(operator.or_, vecmat_x_bat + [True] * num_aux, solve_x_bat)
    # keep a slice of only the linear operator part of solve's avals
    x_bat_noaux = x_bat_out[:num_operator_out_avals]

    # Apply matvec and solve_t -> new batched parts of b
    matvec_jaxpr_batched, matvec_b_bat = batching.batch_jaxpr(
        matvec, axis_size, matvec_bat + x_bat_noaux, instantiate=b_bat,
        axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type)
    if solve_t is None:
      solve_t_jaxpr_batched = None
      b_bat_out = _map(operator.or_, matvec_b_bat, orig_b_bat)
    else:
      solve_t_jaxpr_batched, solve_t_b_aux_bat = batching.batch_jaxpr(
          solve_t, axis_size, solve_t_bat + x_bat_noaux, instantiate=x_bat_out,
          axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type)
      assert len(solve_t_b_aux_bat) == len(orig_b_bat) + num_aux
      # Only the b-shaped outputs of transpose_solve feed back into b_bat;
      # its trailing aux flags are discarded.
      solve_t_b_bat, _ = split_list(solve_t_b_aux_bat, [len(orig_b_bat)])
      b_bat_out = _map(lambda m, s, o: m or s or o, matvec_b_bat, solve_t_b_bat,
                       orig_b_bat)
    if x_bat_out == x_bat and b_bat_out == b_bat:
      break
    else:
      x_bat = x_bat_out
      b_bat = b_bat_out
  else:
    # for-else: runs only if the loop finished without hitting ``break``.
    assert False, "Fixedpoint not reached"

  batched_jaxprs = _LinearSolveTuple(matvec_jaxpr_batched, vecmat_jaxpr_batched,
                                     solve_jaxpr_batched, solve_t_jaxpr_batched)

  # Move batched axes to the front
  new_params = [
      batching.moveaxis(x, d, 0)
      if d is not batching.not_mapped and d != 0 else x
      for x, d in zip(_flatten(params), _flatten(params_dims))
  ]
  # Broadcast out b if necessary
  new_b = [
      batching.broadcast(x, axis_size, 0) if now_bat and not was_bat else
      batching.moveaxis(x, d, 0) if now_bat and d != 0 else x
      for x, d, was_bat, now_bat in zip(b, b_dims, orig_b_bat, b_bat)
  ]

  outs = linear_solve_p.bind(
      *(new_params + new_b),
      const_lengths=const_lengths,
      jaxprs=batched_jaxprs)
  # Output batching follows the batched solve jaxpr from the final round.
  out_dims = [0 if batched else batching.not_mapped for batched in solve_x_bat]
  return outs, out_dims
# The custom_linear_solve primitive. Its params are ``const_lengths`` and
# ``jaxprs`` (both ``_LinearSolveTuple``s), and it has multiple outputs.
linear_solve_p = core.AxisPrimitive('custom_linear_solve')
linear_solve_p.multiple_results = True
# Eager evaluation and MLIR lowering both use _custom_linear_solve_impl.
linear_solve_p.def_impl(_custom_linear_solve_impl)
linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval)
# Forward-mode (JVP) rule; the reverse-mode transpose rule is registered below.
ad.primitive_jvps[linear_solve_p] = _custom_linear_solve_jvp
xla.register_initial_style_primitive(linear_solve_p)
mlir.register_lowering(
    linear_solve_p, mlir.lower_fun(_custom_linear_solve_impl,
                                   multiple_results=True))
ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule
# The plain axis batcher is the SPMD batcher with spmd_axis_name=None.
batching.axis_primitive_batchers[linear_solve_p] = partial(_linear_solve_batching_rule, None)
batching.spmd_axis_primitive_batchers[linear_solve_p] = _linear_solve_batching_rule