Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/experimental/jet.py
2023-06-19 00:49:18 +02:00

750 lines
26 KiB
Python

# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Jet is an experimental module for higher-order automatic differentiation
that does not rely on repeated first-order automatic differentiation.
How? Through the propagation of truncated Taylor polynomials.
Consider a function :math:`f = g \circ h`, some point :math:`x`
and some offset :math:`v`.
First-order automatic differentiation (such as :func:`jax.jvp`)
computes the pair :math:`(f(x), \partial f(x)[v])` from the pair
:math:`(h(x), \partial h(x)[v])`.
:func:`jet` implements the higher-order analogue:
Given the tuple

.. math::
  (h_0, \ldots, h_K) :=
  (h(x), \partial h(x)[v], \partial^2 h(x)[v, v], \ldots, \partial^K h(x)[v, \ldots, v]),

which represents a :math:`K`-th order Taylor approximation
of :math:`h` at :math:`x`, :func:`jet` returns a :math:`K`-th order
Taylor approximation of :math:`f` at :math:`x`,

.. math::
  (f_0, \ldots, f_K) :=
  (f(x), \partial f(x)[v], \partial^2 f(x)[v, v], \ldots, \partial^K f(x)[v, \ldots, v]).

More specifically, :func:`jet` computes

.. math::
  f_0, (f_1, \ldots, f_K) = \texttt{jet} (f, h_0, (h_1, \ldots, h_K))

and can thus be used for high-order
automatic differentiation of :math:`f`.
Details are explained in
`these notes <https://github.com/google/jax/files/6717197/jet.pdf>`__.
Note:
Help improve :func:`jet` by contributing
`outstanding primitive rules <https://github.com/google/jax/issues/2431>`__.
"""
from typing import Any, Callable, Dict, Tuple
from functools import partial
import numpy as np
from jax import lax
import jax.numpy as jnp
from jax.experimental import pjit
from jax.tree_util import (register_pytree_node, tree_structure,
treedef_is_leaf, tree_flatten, tree_unflatten,)
from jax._src import ad_util
from jax._src import core
from jax._src import dispatch
from jax._src import linear_util as lu
from jax._src import sharding_impls
from jax._src.api_util import shaped_abstractify
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax as lax_internal
from jax._src.util import unzip2, weakref_lru_cache
def jet(fun, primals, series):
  r"""Taylor-mode higher-order automatic differentiation.

  Args:
    fun: Function to be differentiated. Each positional argument must be a
      single array or scalar (pytree-valued arguments are rejected). It should
      return an array, scalar, or standard Python container of arrays or
      scalars.
    primals: The primal values at which the Taylor approximation of ``fun``
      should be evaluated. Should be either a tuple or a list of arguments,
      and its length should be equal to the number of positional parameters of
      ``fun``.
    series: Higher order Taylor-series-coefficients.
      Together, `primals` and `series` make up a truncated Taylor polynomial.
      Should be either a tuple or a list of tuples or lists, with one entry
      per primal; the (common) length of the entries dictates the degree of
      the truncated Taylor polynomial.

  Returns:
    A ``(primals_out, series_out)`` pair, where ``primals_out`` is
    ``fun(*primals)``, and together, ``primals_out`` and ``series_out`` are a
    truncated Taylor polynomial of :math:`f(h(\cdot))`.
    ``primals_out`` has the Python tree structure of the output of ``fun``;
    ``series_out`` has that same structure, with each leaf replaced by the
    sequence of series terms for that output.

  Raises:
    ValueError: If ``primals`` and ``series`` have different lengths, if the
      per-argument series have inconsistent lengths, or if a primal or a
      series term is not a single array/scalar.

  For example:

  >>> import jax
  >>> import jax.numpy as np

  Consider the function :math:`h(z) = z^3`, :math:`x = 0.5`,
  and the first few Taylor coefficients
  :math:`h_0=x^3`, :math:`h_1=3x^2`, and :math:`h_2=6x`.
  Let :math:`f(y) = \sin(y)`.

  >>> h0, h1, h2 = 0.5**3., 3.*0.5**2., 6.*0.5
  >>> f, df, ddf = np.sin, np.cos, lambda *args: -np.sin(*args)

  :func:`jet` returns the Taylor coefficients of :math:`f(h(z)) = \sin(z^3)`
  according to Faà di Bruno's formula:

  >>> f0, (f1, f2) = jet(f, (h0,), ((h1, h2),))
  >>> print(f0,  f(h0))
  0.12467473 0.12467473

  >>> print(f1, df(h0) * h1)
  0.7441479 0.74414825

  >>> print(f2, ddf(h0) * h1 ** 2 + df(h0) * h2)
  2.9064622 2.9064634
  """
  # Without this check a mismatch is silently truncated, both by the ``zip``
  # below and when the inputs are paired up into tracers.
  if len(primals) != len(series):
    raise ValueError(
        "jet expects one series per primal, but got "
        f"{len(primals)} primals and {len(series)} series")

  # All arguments must carry the same number of terms; that number is the
  # order of the truncated Taylor polynomial.
  try:
    order, = set(map(len, series))
  except ValueError:
    msg = "jet terms have inconsistent lengths for different arguments"
    raise ValueError(msg) from None

  # TODO(mattjj): consider supporting pytree inputs
  for i, (x, terms) in enumerate(zip(primals, series)):
    treedef = tree_structure(x)
    if not treedef_is_leaf(treedef):
      raise ValueError(f"primal value at position {i} is not an array")
    for j, t in enumerate(terms):
      treedef = tree_structure(t)
      if not treedef_is_leaf(treedef):
        raise ValueError(f"term {j} for argument {i} is not an array")

  @lu.transformation_with_aux
  def flatten_fun_output(*args):
    # Flatten the output of ``fun``; the tree definition is kept as auxiliary
    # data so the results can be unflattened below.
    ans = yield args, {}
    yield tree_flatten(ans)

  f, out_tree = flatten_fun_output(lu.wrap_init(fun))
  out_primals, out_terms = jet_fun(jet_subtrace(f), order).call_wrapped(primals, series)
  return tree_unflatten(out_tree(), out_primals), tree_unflatten(out_tree(), out_terms)
@lu.transformation
def jet_fun(order, primals, series):
  """Run the wrapped function under a fresh ``JetTrace`` main trace.

  This is a generator-style ``lu.transformation``: the first ``yield`` passes
  ``(main, primals, series)`` on to the wrapped function and receives its
  ``(out_primals, out_terms)``; the second ``yield`` produces the final result.

  Args:
    order: Number of series terms; stored on the main trace as ``main.order``.
    primals: Primal inputs, forwarded unchanged.
    series: Series inputs, forwarded unchanged.
  """
  with core.new_main(JetTrace) as main:
    main.order = order
    out_primals, out_terms = yield (main, primals, series), {}
    del main
  # Replace symbolic zero series with ``order`` explicit zero arrays shaped
  # like the corresponding primal output.
  out_terms = [[jnp.zeros_like(p)] * order if s is zero_series else s
               for p, s in zip(out_primals, out_terms)]
  yield out_primals, out_terms
@lu.transformation
def jet_subtrace(main, primals, series):
  """Wrap inputs in ``JetTracer``s, run the function, and unwrap the outputs.

  Args:
    main: The main trace used to build the ``JetTrace`` for this sublevel.
    primals: Primal inputs, one per function argument.
    series: Series inputs, paired element-wise with ``primals``.
  """
  trace = JetTrace(main, core.cur_sublevel())
  # One tracer per (primal, series) pair.
  in_tracers = map(partial(JetTracer, trace), primals, series)
  ans = yield in_tracers, {}
  # Raise every output to this trace so it exposes ``primal`` and ``terms``.
  out_tracers = map(trace.full_raise, ans)
  out_primals, out_terms = unzip2((t.primal, t.terms) for t in out_tracers)
  yield out_primals, out_terms
@lu.transformation_with_aux
def traceable(in_tree_def, *primals_and_series):
  """Adapt a ``(primals, series)`` function to flat inputs and outputs.

  Args:
    in_tree_def: Tree definition used to rebuild ``(primals, series)`` from
      the flat positional arguments.
    *primals_and_series: Flattened primals and series.

  The flattened outputs are yielded as the result, and their tree definition
  is yielded alongside as the auxiliary output.
  """
  primals_in, series_in = tree_unflatten(in_tree_def, primals_and_series)
  primals_out, series_out = yield (primals_in, series_in), {}
  out_flat, out_tree_def = tree_flatten((primals_out, series_out))
  yield out_flat, out_tree_def
class JetTracer(core.Tracer):
  """Tracer that carries a primal value together with its Taylor terms."""
  __slots__ = ["primal", "terms"]

  def __init__(self, trace, primal, terms):
    # ``terms`` is either the symbolic ``zero_series`` or a list/tuple of
    # per-order terms.
    assert type(terms) in (ZeroSeries, list, tuple)
    self._trace = trace
    self.primal = primal
    self.terms = terms

  @property
  def aval(self):
    """Abstract value of the primal."""
    return core.get_aval(self.primal)

  def full_lower(self):
    """Drop the tracer when every term is a symbolic zero."""
    terms = self.terms
    has_nonzero_term = (terms is not zero_series
                        and any(t is not zero_term for t in terms))
    if has_nonzero_term:
      return self
    return core.full_lower(self.primal)
class JetTrace(core.Trace):
  """Trace that propagates truncated Taylor series through primitives."""

  def pure(self, val):
    """Wrap a constant: its series is the symbolic ``zero_series``."""
    return JetTracer(self, val, zero_series)

  def lift(self, val):
    """Wrap a value from outside this trace with a symbolic zero series."""
    return JetTracer(self, val, zero_series)

  def sublift(self, val):
    """Re-wrap a tracer's primal and terms in a tracer of this trace."""
    return JetTracer(self, val.primal, val.terms)

  def process_primitive(self, primitive, tracers, params):
    """Apply the registered jet rule for ``primitive`` to the input tracers.

    Raises:
      KeyError: If no rule for ``primitive`` is registered in ``jet_rules``.
    """
    order = self.main.order              # pytype: disable=attribute-error
    primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers)
    # Expand symbolic zero series into ``order`` symbolic zero terms ...
    series_in = [[zero_term] * order if s is zero_series else s
                 for s in series_in]
    # ... and then materialise every symbolic zero term as a zero array with
    # the shape and result dtype of the matching primal.
    # TODO(mattjj): avoid always instantiating zeros
    series_in = [[jnp.zeros(np.shape(x), dtype=jnp.result_type(x))
                  if t is zero_term else t for t in series]
                 for x, series in zip(primals_in, series_in)]
    try:
      rule = jet_rules[primitive]
    except KeyError:
      # Same exception type as the plain dict lookup, but with a message that
      # says what is missing.
      raise KeyError(
          f"jet rule not implemented for primitive '{primitive}'") from None
    primal_out, terms_out = rule(primals_in, series_in, **params)
    if not primitive.multiple_results:
      return JetTracer(self, primal_out, terms_out)
    else:
      return [JetTracer(self, p, ts) for p, ts in zip(primal_out, terms_out)]

  def process_call(self, call_primitive, f, tracers, params):
    """Bind ``call_primitive`` on a jet-transformed, flattened version of ``f``."""
    primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers)
    primals_and_series, in_tree_def = tree_flatten((primals_in, series_in))
    f_jet, out_tree_def = traceable(jet_subtrace(f, self.main), in_tree_def)
    # A call primitive may register a hook to adjust its params for the new
    # number of flat arguments.
    update_params = call_param_updaters.get(call_primitive)
    new_params = (update_params(params, len(primals_and_series))
                  if update_params else params)
    result = call_primitive.bind(f_jet, *primals_and_series, **new_params)
    primals_out, series_out = tree_unflatten(out_tree_def(), result)
    return [JetTracer(self, p, ts) for p, ts in zip(primals_out, series_out)]

  def post_process_call(self, call_primitive, out_tracers, params):
    """Return flat outputs plus a callback that re-wraps them as tracers."""
    primals, series = unzip2((t.primal, t.terms) for t in out_tracers)
    out, treedef = tree_flatten((primals, series))
    del primals, series
    main = self.main
    def todo(x):
      # Rebuild (primals, series) and wrap them in tracers of a JetTrace at
      # the sublevel current when the callback runs.
      primals, series = tree_unflatten(treedef, x)
      trace = JetTrace(main, core.cur_sublevel())
      return map(partial(JetTracer, trace), primals, series)
    return out, todo

  def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *,
                              symbolic_zeros):
    """Ignore the custom JVP rule and trace through ``fun`` directly."""
    # TODO(mattjj): don't just ignore custom jvp rules?
    del primitive, jvp  # Unused.
    return fun.call_wrapped(*tracers)

  def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees):
    """Ignore the custom VJP rule and trace through ``fun`` directly."""
    del primitive, fwd, bwd, out_trees  # Unused.
    return fun.call_wrapped(*tracers)
class ZeroTerm:
  """Type of the sentinel standing for a single zero Taylor term."""


zero_term = ZeroTerm()
# Registered as a pytree node with no children; unflattening always gives
# back the module-level singleton.
register_pytree_node(
    ZeroTerm,
    lambda sentinel: ((), None),
    lambda aux_data, children: zero_term)


class ZeroSeries:
  """Type of the sentinel standing for an all-zero series."""


zero_series = ZeroSeries()
register_pytree_node(
    ZeroSeries,
    lambda sentinel: ((), None),
    lambda aux_data, children: zero_series)


# Optional per-call-primitive hooks, looked up by ``JetTrace.process_call``.
call_param_updaters: Dict[core.Primitive, Callable[..., Any]] = {}
### rule definitions

# Maps a primitive to its jet rule, called as
# ``rule(primals_in, series_in, **params)``.
jet_rules = {}


def defzero(prim):
  """Register ``prim`` as having an identically zero output series."""
  jet_rules[prim] = partial(zero_prop, prim)


def zero_prop(prim, primals_in, series_in, **params):
  """Evaluate ``prim`` on the primals and return a symbolic zero series."""
  return prim.bind(*primals_in, **params), zero_series


for _zero_prim in (
    lax.le_p,
    lax.lt_p,
    lax.gt_p,
    lax.ge_p,
    lax.eq_p,
    lax.ne_p,
    lax.not_p,
    lax.and_p,
    lax.or_p,
    lax.xor_p,
    lax.floor_p,
    lax.ceil_p,
    lax.round_p,
    lax.sign_p,
    ad_util.stop_gradient_p,
    lax.is_finite_p,
    lax.shift_left_p,
    lax.shift_right_arithmetic_p,
    lax.shift_right_logical_p,
    lax.bitcast_convert_type_p,
):
  defzero(_zero_prim)
del _zero_prim
def deflinear(prim):
  """Register ``prim`` as linear: it is applied to each series term as-is."""
  jet_rules[prim] = partial(linear_prop, prim)


def linear_prop(prim, primals_in, series_in, **params):
  """Apply ``prim`` to the primals and, order by order, to the series terms."""
  apply_prim = partial(prim.bind, **params)
  primal_out = apply_prim(*primals_in)
  # ``zip(*series_in)`` groups the k-th term of every argument together.
  series_out = []
  for kth_terms in zip(*series_in):
    series_out.append(apply_prim(*kth_terms))
  return primal_out, series_out


for _linear_prim in (
    lax.neg_p,
    lax.real_p,
    lax.complex_p,
    lax.conj_p,
    lax.imag_p,
    lax.add_p,
    ad_util.add_jaxvals_p,
    lax.sub_p,
    lax.convert_element_type_p,
    lax.broadcast_in_dim_p,
    lax.concatenate_p,
    lax.pad_p,
    lax.reshape_p,
    lax.squeeze_p,
    lax.rev_p,
    lax.transpose_p,
    lax.slice_p,
    lax.reduce_sum_p,
    lax.reduce_window_sum_p,
    lax.fft_p,
    dispatch.device_put_p,
):
  deflinear(_linear_prim)
del _linear_prim
def _dynamic_slice_jet_rule(primals_in, series_in, **params):
operand, *start_indices = primals_in
primal_out = lax.dynamic_slice_p.bind(operand, *start_indices, **params)
series_out = [lax.dynamic_slice_p.bind(terms_in[0], *start_indices, **params)
for terms_in in zip(*series_in)]
return primal_out, series_out
# Register the jet rule for the dynamic-slice primitive.
jet_rules[lax.dynamic_slice_p] = _dynamic_slice_jet_rule
def _dynamic_update_slice_jet_rule(primals_in, series_in, **params):
operand, update, *start_indices = primals_in
primal_out = lax.dynamic_update_slice_p.bind(operand, update, *start_indices)
series_out = [lax.dynamic_update_slice_p.bind(*terms_in[:2], *start_indices, **params)
for terms_in in zip(*series_in)]
return primal_out, series_out
# Register the jet rule for the dynamic-update-slice primitive.
jet_rules[lax.dynamic_update_slice_p] = _dynamic_update_slice_jet_rule
def _cumulative_jet_rule(primals_in, series_in, *, axis: int, reverse: bool,
                         combine_fn: Callable):
  """Jet rule for cumulative reductions, expressed as an associative scan."""
  # Irrespective of backend, we always use the parallel prefix scan
  # implementation when differentiating because reduce_window is not
  # arbitrarily differentiable.
  prefix_scan = partial(lax.associative_scan, combine_fn, axis=axis,
                        reverse=reverse)
  return jet(prefix_scan, primals_in, series_in)


deflinear(lax.cumsum_p)
for _cumulative_prim, _combine_fn in (
    (lax.cumprod_p, lax.mul),
    (lax.cummax_p, lax.max),
    (lax.cummin_p, lax.min),
):
  jet_rules[_cumulative_prim] = partial(_cumulative_jet_rule,
                                        combine_fn=_combine_fn)
del _cumulative_prim, _combine_fn
def def_deriv(prim, deriv):
  """Register a jet rule for ``prim`` given its first derivative ``deriv``."""
  jet_rules[prim] = partial(deriv_prop, prim, deriv)


def deriv_prop(prim, deriv, primals_in, series_in):
  """Jet rule for a unary primitive, built from its first derivative.

  The series of ``deriv`` at the input is obtained with :func:`jet` and then
  combined with the input series, order by order.
  """
  x, = primals_in
  series, = series_in
  primal_out = prim.bind(x)
  deriv_primal, deriv_series = jet(deriv, primals_in, series_in)
  c = [deriv_primal] + deriv_series
  u = [x] + series
  series_out = []
  for k in range(1, len(series) + 1):
    series_out.append(
        fact(k-1) * sum(_scale(k, j) * c[k-j] * u[j] for j in range(1, k + 1)))
  return primal_out, series_out


def _erf_deriv(x):
  """First derivative of erf: ``2 / sqrt(pi) * exp(-x**2)``."""
  return lax.mul(
      lax_internal._const(x, 2. / np.sqrt(np.pi)),
      lax.exp(lax.neg(lax.square(x))))


def_deriv(lax.erf_p, _erf_deriv)
def def_comp(prim, comp):
  """
  Define the jet rule for a primitive in terms of a composition of simpler primitives.

  The rule is ``jet`` applied to ``comp``, so the primal output of the rule is
  ``comp`` evaluated at the primals. ``comp`` must therefore agree with
  ``prim`` in value, not only in its derivatives.
  """
  jet_rules[prim] = partial(jet, comp)


def _rem_composition(x, y):
  """``lax.rem`` written with primitives that already have jet rules.

  ``lax.rem`` is the truncated remainder (the result takes the sign of the
  dividend), so the quotient has to be rounded toward zero. Rounding it down
  with ``floor`` would give the floored modulo instead, which differs whenever
  ``x`` and ``y`` have opposite signs.
  """
  quotient = x / y
  zero = lax_internal._const(quotient, 0)
  # Truncation toward zero: ceil for negative quotients, floor otherwise.
  truncated = lax.select(lax.lt(quotient, zero),
                         lax.ceil(quotient), lax.floor(quotient))
  return x - y * truncated


def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1)
def_comp(lax.log1p_p, lambda x: lax.log(1 + x))
def_comp(lax.sqrt_p, lambda x: x ** 0.5)
def_comp(lax.rsqrt_p, lambda x: x ** -0.5)
def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)))
def_comp(lax.acosh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)))
def_comp(lax.atanh_p, lambda x: 0.5 * lax.log(lax.div(1 + x, 1 - x)))
def_comp(lax.erfc_p, lambda x: 1 - lax.erf(x))
def_comp(lax.rem_p, _rem_composition)
def_comp(lax.clamp_p, lambda a, x, b: lax.min(lax.max(a, x), b))
def _erf_inv_rule(primals_in, series_in):
x, = primals_in
series, = series_in
u = [x] + series
primal_out = lax.erf_inv(x)
v = [primal_out] + [None] * len(series)
# derivative on co-domain for caching purposes
deriv_const = np.sqrt(np.pi) / 2.
deriv_y = lambda y: lax.mul(deriv_const, lax.exp(lax.square(y)))
# manually propagate through deriv_y since we don't have lazy evaluation of sensitivities
c = [deriv_y(primal_out)] + [None] * (len(series) - 1)
tmp_sq = [lax.square(v[0])] + [None] * (len(series) - 1)
tmp_exp = [lax.exp(tmp_sq[0])] + [None] * (len(series) - 1)
for k in range(1, len(series)):
# we know c[:k], we compute c[k]
# propagate c to get v
v[k] = fact(k-1) * sum(_scale(k, j) * c[k-j] * u[j] for j in range(1, k + 1))
# propagate v to get next c
# square
tmp_sq[k] = fact(k) * sum(_scale2(k, j) * v[k-j] * v[j] for j in range(k + 1))
# exp
tmp_exp[k] = fact(k-1) * sum(_scale(k, j) * tmp_exp[k-j] * tmp_sq[j] for j in range(1, k + 1))
# const
c[k] = deriv_const * tmp_exp[k]
# we can't, and don't need, to compute c[k+1], just need to get the last v[k]
k = len(series)
v[k] = fact(k-1) * sum(_scale(k, j) * c[k-j] * u[j] for j in range(1, k + 1))
primal_out, *series_out = v
return primal_out, series_out
# Register the jet rule for the inverse error function primitive.
jet_rules[lax.erf_inv_p] = _erf_inv_rule
### More complicated rules
def fact(n):
return lax.exp(lax.lgamma(n+1.))
def _scale(k, j):
return 1. / (fact(k - j) * fact(j - 1))
def _scale2(k, j):
return 1. / (fact(k - j) * fact(j))
def _exp_taylor(primals_in, series_in):
  """Jet rule for ``lax.exp``; each output term is built from earlier ones."""
  x, = primals_in
  series, = series_in
  inputs = [x] + series
  outputs = [lax.exp(x)]
  for k in range(1, len(inputs)):
    # Term k only needs outputs[0..k-1], which are already in the list.
    outputs.append(
        fact(k-1) * sum(_scale(k, j) * outputs[k-j] * inputs[j]
                        for j in range(1, k+1)))
  return outputs[0], outputs[1:]


jet_rules[lax.exp_p] = _exp_taylor
def _pow_taylor(primals_in, series_in):
  """Jet rule for ``lax.pow``, using ``u ** r = exp(r * log(u))``."""
  base, exponent = primals_in

  def exponent_times_log(x, y):
    return lax.mul(y, lax.log(x))

  # Series of r * log(u); the loop below then applies the exp recurrence.
  log_primal, log_series = jet(exponent_times_log, primals_in, series_in)
  inputs = [log_primal] + log_series
  outputs = [base ** exponent]
  for k in range(1, len(inputs)):
    outputs.append(
        fact(k-1) * sum(_scale(k, j) * outputs[k-j] * inputs[j]
                        for j in range(1, k+1)))
  return outputs[0], outputs[1:]


jet_rules[lax.pow_p] = _pow_taylor
def _integer_pow_taylor(primals_in, series_in, *, y):
  """Jet rule for ``lax.integer_pow`` with static exponent ``y``."""
  # Exponents 0, 1 and 2 are delegated to jet on simpler functions.
  if y == 0:
    return jet(jnp.ones_like, primals_in, series_in)
  if y == 1:
    return jet(lambda x: x, primals_in, series_in)
  if y == 2:
    return jet(lambda x: x * x, primals_in, series_in)

  x, = primals_in
  series, = series_in
  inputs = [x] + series
  outputs = [lax.integer_pow(x, y)]
  for k in range(1, len(inputs)):
    vu = sum(_scale(k, j) * outputs[k-j] * inputs[j] for j in range(1, k + 1))
    uv = sum(_scale(k, j) * inputs[k-j] * outputs[j] for j in range(1, k))
    # The recurrence divides by x; entries where x == 0 are set to 0.
    outputs.append(jnp.where(x == 0, 0, fact(k-1) * (y * vu - uv) / x))
  return outputs[0], outputs[1:]


jet_rules[lax.integer_pow_p] = _integer_pow_taylor
def _logistic_taylor(primals_in, series_in):
  """Jet rule for ``lax.logistic``.

  ``v`` holds the output terms and ``e`` the terms for
  sigmoid' = sigmoid * (1 - sigmoid); the two lists are extended in lockstep.
  """
  x, = primals_in
  series, = series_in
  u = [x] + series
  v = [lax.logistic(x)]
  e = [v[0] * (1 - v[0])]
  for k in range(1, len(u)):
    v.append(
        fact(k-1) * sum(_scale(k, j) * e[k-j] * u[j] for j in range(1, k+1)))
    # e[k] uses v[k], which was appended just above.
    e.append(
        (1 - v[0]) * v[k]
        - fact(k) * sum(_scale2(k, j) * v[j] * v[k-j] for j in range(1, k+1)))
  return v[0], v[1:]


jet_rules[lax.logistic_p] = _logistic_taylor
def _tanh_taylor(primals_in, series_in):
  """Jet rule for ``lax.tanh`` via ``tanh(x) = 2 * logistic(2 * x) - 1``."""
  x, = primals_in
  series, = series_in
  doubled_primal = 2*x
  doubled_series = [2 * term for term in series]
  logistic_out, logistic_series = _logistic_taylor(
      (doubled_primal, ), (doubled_series, ))
  series_out = [2 * term for term in logistic_series]
  return 2 * logistic_out - 1, series_out


jet_rules[lax.tanh_p] = _tanh_taylor
def _log_taylor(primals_in, series_in):
  """Jet rule for ``lax.log``; each output term is built from earlier ones."""
  x, = primals_in
  series, = series_in
  inputs = [x] + series
  outputs = [lax.log(x)]
  for k in range(1, len(inputs)):
    conv = sum(_scale(k, j) * outputs[j] * inputs[k-j] for j in range(1, k))
    outputs.append((inputs[k] - fact(k - 1) * conv) / inputs[0])
  return outputs[0], outputs[1:]


jet_rules[lax.log_p] = _log_taylor
def _atan2_taylor(primals_in, series_in):
  """Jet rule for ``lax.atan2``.

  The primal comes from ``lax.atan2`` itself; the series is propagated through
  the quotient of the two arguments using the derivative ``1 / (1 + t**2)``.
  """
  x, y = primals_in
  primal_out = lax.atan2(x, y)

  ratio, ratio_series = jet(lax.div, primals_in, series_in)
  one = lax_internal._const(ratio, 1)

  def atan_deriv(t):
    return lax.div(one, 1 + lax.square(t))

  deriv_primal, deriv_series = jet(atan_deriv, (ratio, ), (ratio_series, ))
  c = [deriv_primal] + deriv_series
  u = [ratio] + ratio_series
  series_out = []
  for k in range(1, len(u)):
    series_out.append(
        fact(k-1) * sum(_scale(k, j) * c[k-j] * u[j] for j in range(1, k + 1)))
  return primal_out, series_out


jet_rules[lax.atan2_p] = _atan2_taylor
def _div_taylor_rule(primals_in, series_in):
  """Jet rule for ``lax.div``; term k is solved for from terms 0..k-1."""
  numerator, denominator = primals_in
  numerator_terms, denominator_terms = series_in
  u = [numerator] + numerator_terms
  w = [denominator] + denominator_terms

  def inv_factorials(k, j):
    return 1. / (fact(k - j) * fact(j))

  quotient = []
  for k in range(0, len(u)):
    # For k == 0 the sum is empty, so the primal is simply u[0] / w[0].
    conv = sum(inv_factorials(k, j) * quotient[j] * w[k-j]
               for j in range(0, k))
    quotient.append((u[k] - fact(k) * conv) / w[0])
  return quotient[0], quotient[1:]


jet_rules[lax.div_p] = _div_taylor_rule
def _sinusoidal_rule(sign, prims, primals_in, series_in):
  """Jointly propagate a sine-like / cosine-like pair of functions.

  Args:
    sign: Factor applied to every cosine-like series term (-1 for sin/cos,
      +1 for sinh/cosh in the registrations below).
    prims: Pair ``(sine_like, cosine_like)`` of functions for the primals.
    primals_in: One-element sequence holding the primal input.
    series_in: One-element sequence holding the list of input series terms.

  Returns:
    ``((s_primal, s_series), (c_primal, c_series))``.
  """
  x, = primals_in
  series, = series_in
  u = [x] + series
  sine_like, cosine_like = prims
  s_terms = [sine_like(x)]
  c_terms = [cosine_like(x)]
  for k in range(1, len(u)):
    # Both new terms depend only on terms of order < k of the other list.
    s_k = fact(k-1) * sum(_scale(k, j) * u[j] * c_terms[k-j]
                          for j in range(1, k + 1))
    c_k = fact(k-1) * sum(_scale(k, j) * u[j] * s_terms[k-j]
                          for j in range(1, k + 1)) * sign
    s_terms.append(s_k)
    c_terms.append(c_k)
  return (s_terms[0], s_terms[1:]), (c_terms[0], c_terms[1:])


def _get_ind(f, ind):
  """Return a function that calls ``f`` and keeps element ``ind`` of the result."""
  def select_output(*args):
    return f(*args)[ind]
  return select_output


jet_rules[lax.sin_p] = _get_ind(partial(_sinusoidal_rule, -1, (lax.sin, lax.cos)), 0)
jet_rules[lax.cos_p] = _get_ind(partial(_sinusoidal_rule, -1, (lax.sin, lax.cos)), 1)
jet_rules[lax.sinh_p] = _get_ind(partial(_sinusoidal_rule, 1, (lax.sinh, lax.cosh)), 0)
jet_rules[lax.cosh_p] = _get_ind(partial(_sinusoidal_rule, 1, (lax.sinh, lax.cosh)), 1)
def _bilinear_taylor_rule(prim, primals_in, series_in, **params):
  """Jet rule for a bilinear primitive.

  Output term k sums ``prim`` over all pairs of input terms whose orders add
  up to k, weighted by ``k! / ((k - j)! * j!)``.
  """
  lhs, rhs = primals_in
  lhs_terms, rhs_terms = series_in
  u = [lhs] + lhs_terms
  w = [rhs] + rhs_terms
  op = partial(prim.bind, **params)

  def inv_factorials(k, j):
    return 1. / (fact(k - j) * fact(j))

  v = [fact(k) * sum(inv_factorials(k, j) * op(u[j], w[k-j])
                     for j in range(0, k+1))
       for k in range(0, len(u))]
  primal_out, *series_out = v
  return primal_out, series_out


for _bilinear_prim in (lax.dot_general_p, lax.mul_p,
                       lax.conv_general_dilated_p):
  jet_rules[_bilinear_prim] = partial(_bilinear_taylor_rule, _bilinear_prim)
del _bilinear_prim
def _gather_taylor_rule(primals_in, series_in, **params):
  """Jet rule for ``lax.gather_p``.

  Every operand term is gathered at the primal indices; the series terms that
  belong to the indices are ignored.
  """
  operand, start_indices = primals_in
  operand_terms, _ = series_in

  def gather_from(array):
    return lax.gather_p.bind(array, start_indices, **params)

  primal_out = gather_from(operand)
  series_out = [gather_from(term) for term in operand_terms]
  return primal_out, series_out


jet_rules[lax.gather_p] = _gather_taylor_rule
def _gen_reduce_choose_taylor_rule(chooser_fun):
def chooser_taylor_rule(primals_in, series_in, **params):
operand, = primals_in
gs, = series_in
primal_out = chooser_fun(operand, **params)
axes = params.pop("axes", None)
primal_dtype = gs[0].dtype
shape = [1 if i in axes else d for i, d in enumerate(operand.shape)]
location_indicators = lax.convert_element_type(
lax_internal._eq_meet(operand, lax.reshape(primal_out, shape)),
primal_dtype)
counts = lax_internal._reduce_sum(location_indicators, axes)
def _reduce_chooser_taylor_rule(g):
return lax.div(
lax_internal._reduce_sum(lax.mul(g, location_indicators), axes),
counts)
series_out = [_reduce_chooser_taylor_rule(g) for g in gs]
return primal_out, series_out
return chooser_taylor_rule
# Register the chooser rule for the max and min reductions.
jet_rules[lax.reduce_max_p] = _gen_reduce_choose_taylor_rule(
    lax_internal._reduce_max)
jet_rules[lax.reduce_min_p] = _gen_reduce_choose_taylor_rule(
    lax_internal._reduce_min)
def _abs_taylor_rule(x, series_in, **params):
x, = x
zero = lax.full_like(x, 0, shape=())
primal_out = lax.abs_p.bind(x, **params)
negs = lax.select(lax.lt(x, zero), lax.full_like(x, -1), lax.full_like(x, 1.0))
fix_sign = lambda y: negs * y
series_out = [fix_sign(*terms_in, **params) for terms_in in zip(*series_in)]
return primal_out, series_out
# Register the jet rule for the absolute-value primitive.
jet_rules[lax.abs_p] = _abs_taylor_rule
def _select_n_taylor_rule(primal_in, series_in, **params):
b, *cases = primal_in
primal_out = lax.select_n(b, *cases)
sel = lambda _, *xs: lax.select_n(b, *xs)
series_out = [sel(*terms_in) for terms_in in zip(*series_in)]
return primal_out, series_out
# Register the jet rule for the n-way select primitive.
jet_rules[lax.select_n_p] = _select_n_taylor_rule
def _lax_max_taylor_rule(primal_in, series_in):
x, y = jnp.broadcast_arrays(*primal_in)
xgy = x > y # greater than mask
xey = x == y # equal to mask
primal_out = lax.select(xgy, x, y)
def select_max_and_avg_eq(x_i, y_i):
"""Select x where x>y or average when x==y"""
max_i = lax.select(xgy, x_i, y_i)
max_i = lax.select(xey, (x_i + y_i)/2, max_i)
return max_i
series_out = [select_max_and_avg_eq(*jnp.broadcast_arrays(*terms_in)) for terms_in in zip(*series_in)]
return primal_out, series_out
# Register the jet rule for the elementwise maximum primitive.
jet_rules[lax.max_p] = _lax_max_taylor_rule
def _lax_min_taylor_rule(primal_in, series_in):
x, y = primal_in
xgy = x < y # less than mask
xey = x == y # equal to mask
primal_out = lax.select(xgy, x, y)
def select_min_and_avg_eq(x_i, y_i):
"""Select x where x>y or average when x==y"""
min_i = lax.select(xgy, x_i, y_i)
min_i = lax.select(xey, (x_i + y_i)/2, min_i)
return min_i
series_out = [select_min_and_avg_eq(*terms_in) for terms_in in zip(*series_in)]
return primal_out, series_out
# Register the jet rule for the elementwise minimum primitive.
jet_rules[lax.min_p] = _lax_min_taylor_rule
def _scatter_add_rule(primals_in, series_in, *, update_jaxpr, update_consts,
                      dimension_numbers, indices_are_sorted, unique_indices,
                      mode):
  """Jet rule for ``lax.scatter_add_p``.

  At every order the operand's term receives the updates' term, scattered at
  the primal indices; the series terms that belong to the indices are ignored.
  """
  scatter_params = dict(
      update_jaxpr=update_jaxpr,
      update_consts=update_consts,
      dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted,
      unique_indices=unique_indices,
      mode=mode)
  scatter_add = partial(lax.scatter_add_p.bind, **scatter_params)
  operand, scatter_indices, updates = primals_in
  primal_out = scatter_add(operand, scatter_indices, updates)
  series_out = []
  for operand_term, _, updates_term in zip(*series_in):
    series_out.append(scatter_add(operand_term, scatter_indices, updates_term))
  return primal_out, series_out


jet_rules[lax.scatter_add_p] = _scatter_add_rule
@weakref_lru_cache
def _jet_jaxpr(
    jaxpr: core.ClosedJaxpr, order: int, primals_and_series_avals, in_tree_def
) -> Tuple[core.ClosedJaxpr, Any]:
  """Trace the jet transformation of ``jaxpr`` into a new closed jaxpr.

  Returns the new closed jaxpr together with the thunk that gives the tree
  definition of its outputs. Results are memoised by ``weakref_lru_cache``.
  """
  wrapped = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
  jet_of_wrapped = jet_fun(jet_subtrace(wrapped), order)
  flat_jet, out_tree_def = traceable(jet_of_wrapped, in_tree_def)
  traced_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
      flat_jet, primals_and_series_avals)
  closed_jaxpr = core.ClosedJaxpr(traced_jaxpr, consts)
  return closed_jaxpr, out_tree_def
def _pjit_jet_rule(primals_in, series_in, **params):
  """Jet rule for ``pjit_p``: bind pjit on the jet-transformed inner jaxpr.

  The flat argument list grows from the primals to the primals followed by
  all series terms, so the per-argument and per-result parameters are padded
  for the extra entries.
  """
  flat_args, in_tree_def = tree_flatten((primals_in, series_in))
  order = len(series_in[0])
  flat_avals = tuple(shaped_abstractify(arg) for arg in flat_args)
  jaxpr_jet, out_tree_def = _jet_jaxpr(params['jaxpr'], order,
                                       flat_avals, in_tree_def)

  # One extra entry per series term of every input / output.
  num_series_in = len(primals_in) * order
  num_series_out = len(params['out_shardings']) * order
  unspecified_in = (sharding_impls.UNSPECIFIED,) * num_series_in
  unspecified_out = (sharding_impls.UNSPECIFIED,) * num_series_out

  new_params = dict(params)
  new_params.update({
      'jaxpr': jaxpr_jet,
      'in_shardings': params['in_shardings'] + unspecified_in,
      'out_shardings': params['out_shardings'] + unspecified_out,
      # The added series arguments are not donated.
      'donated_invars': params['donated_invars'] + (False,) * num_series_in,
  })
  result = pjit.pjit_p.bind(*flat_args, **new_params)
  return tree_unflatten(out_tree_def(), result)


jet_rules[pjit.pjit_p] = _pjit_jet_rule