750 lines
26 KiB
Python
750 lines
26 KiB
Python
# Copyright 2020 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
r"""Jet is an experimental module for higher-order automatic differentiation
|
|
that does not rely on repeated first-order automatic differentiation.
|
|
|
|
How? Through the propagation of truncated Taylor polynomials.
|
|
Consider a function :math:`f = g \circ h`, some point :math:`x`
|
|
and some offset :math:`v`.
|
|
First-order automatic differentiation (such as :func:`jax.jvp`)
|
|
computes the pair :math:`(f(x), \partial f(x)[v])` from the pair
|
|
:math:`(h(x), \partial h(x)[v])`.
|
|
|
|
:func:`jet` implements the higher-order analogue:
|
|
Given the tuple
|
|
|
|
.. math::
|
|
  (h_0, ..., h_K) :=
|
|
(h(x), \partial h(x)[v], \partial^2 h(x)[v, v], ..., \partial^K h(x)[v,...,v]),
|
|
|
|
which represents a :math:`K`-th order Taylor approximation
|
|
of :math:`h` at :math:`x`, :func:`jet` returns a :math:`K`-th order
|
|
Taylor approximation of :math:`f` at :math:`x`,
|
|
|
|
.. math::
|
|
(f_0, ..., f_K) :=
|
|
(f(x), \partial f(x)[v], \partial^2 f(x)[v, v], ..., \partial^K f(x)[v,...,v]).
|
|
|
|
More specifically, :func:`jet` computes
|
|
|
|
.. math::
|
|
f_0, (f_1, . . . , f_K) = \texttt{jet} (f, h_0, (h_1, . . . , h_K))
|
|
|
|
and can thus be used for high-order
|
|
automatic differentiation of :math:`f`.
|
|
Details are explained in
|
|
`these notes <https://github.com/google/jax/files/6717197/jet.pdf>`__.
|
|
|
|
Note:
|
|
Help improve :func:`jet` by contributing
|
|
`outstanding primitive rules <https://github.com/google/jax/issues/2431>`__.
|
|
"""
|
|
|
|
from typing import Any, Callable, Dict, Tuple
|
|
|
|
from functools import partial
|
|
|
|
import numpy as np
|
|
|
|
from jax import lax
|
|
import jax.numpy as jnp
|
|
from jax.experimental import pjit
|
|
from jax.tree_util import (register_pytree_node, tree_structure,
|
|
treedef_is_leaf, tree_flatten, tree_unflatten,)
|
|
|
|
from jax._src import ad_util
|
|
from jax._src import core
|
|
from jax._src import dispatch
|
|
from jax._src import linear_util as lu
|
|
from jax._src import sharding_impls
|
|
from jax._src.api_util import shaped_abstractify
|
|
from jax._src.interpreters import partial_eval as pe
|
|
from jax._src.lax import lax as lax_internal
|
|
from jax._src.util import unzip2, weakref_lru_cache
|
|
|
|
|
|
def jet(fun, primals, series):
  r"""Taylor-mode higher-order automatic differentiation.

  Args:
    fun: Function to be differentiated. Its arguments should be arrays, scalars,
      or standard Python containers of arrays or scalars. It should return an
      array, scalar, or standard Python container of arrays or scalars.
    primals: The primal values at which the Taylor approximation of ``fun`` should be
      evaluated. Should be either a tuple or a list of arguments,
      and its length should be equal to the number of positional parameters of
      ``fun``.
    series: Higher order Taylor-series-coefficients.
      Together, `primals` and `series` make up a truncated Taylor polynomial.
      Should be either a tuple or a list of tuples or lists, with one entry per
      element of ``primals``; the common length of those entries dictates the
      degree of the truncated Taylor polynomial.

  Returns:
    A ``(primals_out, series_out)`` pair, where ``primals_out`` is ``fun(*primals)``,
    and together, ``primals_out`` and ``series_out`` are a
    truncated Taylor polynomial of :math:`f(h(\cdot))`.
    The ``primals_out`` value has the same Python tree structure as the output
    of ``fun``, and ``series_out`` has that same structure with each leaf
    replaced by the list of its higher-order coefficients.

  Raises:
    ValueError: if ``primals`` and ``series`` differ in length or are empty,
      if the entries of ``series`` have different lengths, or if any primal
      or series term is not a single array (pytree leaf).

  For example:

  >>> import jax
  >>> import jax.numpy as np

  Consider the function :math:`h(z) = z^3`, :math:`x = 0.5`,
  and the first few Taylor coefficients
  :math:`h_0=x^3`, :math:`h_1=3x^2`, and :math:`h_2=6x`.
  Let :math:`f(y) = \sin(y)`.

  >>> h0, h1, h2 = 0.5**3., 3.*0.5**2., 6.*0.5
  >>> f, df, ddf = np.sin, np.cos, lambda *args: -np.sin(*args)

  :func:`jet` returns the Taylor coefficients of :math:`f(h(z)) = \sin(z^3)`
  according to Faà di Bruno's formula:

  >>> f0, (f1, f2) = jet(f, (h0,), ((h1, h2),))
  >>> print(f0, f(h0))
  0.12467473 0.12467473

  >>> print(f1, df(h0) * h1)
  0.7441479 0.74414825

  >>> print(f2, ddf(h0) * h1 ** 2 + df(h0) * h2)
  2.9064622 2.9064634
  """
  # zip() below would silently drop the unmatched tail, calling `fun` with
  # too few arguments (or silently falling back to its defaults).
  if len(primals) != len(series):
    raise ValueError(
        f"jet expects one series entry per primal, but got {len(primals)} "
        f"primals and {len(series)} series entries")
  if not series:
    raise ValueError("jet requires at least one primal argument")
  try:
    order, = set(map(len, series))
  except ValueError:
    msg = "jet terms have inconsistent lengths for different arguments"
    raise ValueError(msg) from None

  # TODO(mattjj): consider supporting pytree inputs
  for i, (x, terms) in enumerate(zip(primals, series)):
    treedef = tree_structure(x)
    if not treedef_is_leaf(treedef):
      raise ValueError(f"primal value at position {i} is not an array")
    for j, t in enumerate(terms):
      treedef = tree_structure(t)
      if not treedef_is_leaf(treedef):
        raise ValueError(f"term {j} for argument {i} is not an array")

  @lu.transformation_with_aux
  def flatten_fun_output(*args):
    ans = yield args, {}
    yield tree_flatten(ans)

  f, out_tree = flatten_fun_output(lu.wrap_init(fun))
  out_primals, out_terms = jet_fun(jet_subtrace(f), order).call_wrapped(primals, series)
  return tree_unflatten(out_tree(), out_primals), tree_unflatten(out_tree(), out_terms)
|
|
|
|
@lu.transformation
def jet_fun(order, primals, series):
  """Run the wrapped function under a fresh ``JetTrace`` main trace.

  ``order`` (the number of Taylor coefficients per value) is stored on the
  main trace so ``JetTrace.process_primitive`` can read it. After the wrapped
  function returns, every output whose series is the ``zero_series`` sentinel
  is replaced by ``order`` explicit zero arrays shaped like its primal, so
  callers always receive dense coefficient lists.
  """
  with core.new_main(JetTrace) as main:
    main.order = order
    out_primals, out_terms = yield (main, primals, series), {}
    # Drop our reference to `main` before the context manager exits.
    # NOTE(review): presumably so the main trace is not kept alive (JAX's
    # leaked-trace checks) — confirm against core.new_main.
    del main
  out_terms = [[jnp.zeros_like(p)] * order if s is zero_series else s
               for p, s in zip(out_primals, out_terms)]
  yield out_primals, out_terms
|
|
|
|
@lu.transformation
def jet_subtrace(main, primals, series):
  """Box inputs as ``JetTracer``s on ``main`` and unbox the outputs.

  A new ``JetTrace`` is created at the current sublevel; each
  ``(primal, terms)`` input pair becomes a ``JetTracer`` on it. Every output
  of the wrapped function is raised to that trace and split back into its
  primal and its terms.
  """
  trace = JetTrace(main, core.cur_sublevel())
  in_tracers = map(partial(JetTracer, trace), primals, series)
  ans = yield in_tracers, {}
  # Outputs that do not depend on the inputs are lifted onto this trace too.
  out_tracers = map(trace.full_raise, ans)
  out_primals, out_terms = unzip2((t.primal, t.terms) for t in out_tracers)
  yield out_primals, out_terms
|
|
|
|
@lu.transformation_with_aux
def traceable(in_tree_def, *primals_and_series):
  """Adapt a ``(primals, series) -> (primals, series)`` function to flat args.

  The flat positional inputs are unflattened with ``in_tree_def`` into a
  ``(primals, series)`` pair before calling the wrapped function. Its
  ``(primals, series)`` result is flattened again, and the resulting treedef
  is reported as the auxiliary output.
  """
  structured_in = tree_unflatten(in_tree_def, primals_and_series)
  primals_in, series_in = structured_in
  structured_out = yield (primals_in, series_in), {}
  primals_out, series_out = structured_out
  flat_out, out_tree_def = tree_flatten((primals_out, series_out))
  yield flat_out, out_tree_def
|
|
|
|
|
|
class JetTracer(core.Tracer):
  """Tracer carrying a primal value together with its Taylor coefficients.

  ``terms`` is either the ``zero_series`` sentinel or a list/tuple of
  coefficients, any of which may be the ``zero_term`` sentinel.
  """
  __slots__ = ["primal", "terms"]

  def __init__(self, trace, primal, terms):
    assert type(terms) in (ZeroSeries, list, tuple)
    self._trace = trace
    self.primal = primal
    self.terms = terms

  @property
  def aval(self):
    # The abstract value is that of the primal; coefficients share its type.
    return core.get_aval(self.primal)

  def full_lower(self):
    """Drop to the bare primal when no coefficient can be nonzero."""
    has_nonzero_term = self.terms is not zero_series and any(
        term is not zero_term for term in self.terms)
    if has_nonzero_term:
      return self
    return core.full_lower(self.primal)
|
|
|
|
class JetTrace(core.Trace):
  """Trace whose tracers carry a primal plus truncated Taylor coefficients.

  The truncation order is read from ``self.main.order``, which ``jet_fun``
  sets on the main trace.
  """

  def pure(self, val):
    # Wrap a value with the `zero_series` sentinel (no higher-order terms).
    return JetTracer(self, val, zero_series)

  def lift(self, val):
    # Same as `pure`: the value gets the `zero_series` sentinel.
    return JetTracer(self, val, zero_series)

  def sublift(self, val):
    # Re-box an existing JetTracer on this trace, keeping its primal and terms.
    return JetTracer(self, val.primal, val.terms)

  def process_primitive(self, primitive, tracers, params):
    """Apply ``jet_rules[primitive]`` to dense ``(primals, series)`` inputs.

    Sentinel zeros (``zero_series`` and ``zero_term``) are first expanded into
    explicit zero arrays matching each primal's shape and dtype, so every rule
    receives ``order`` concrete coefficients per input. A primitive without a
    registered rule raises ``KeyError`` at the ``jet_rules`` lookup.
    """
    order = self.main.order  # pytype: disable=attribute-error
    primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers)
    series_in = [[zero_term] * order if s is zero_series else s
                 for s in series_in]
    # TODO(mattjj): avoid always instantiating zeros
    series_in = [[jnp.zeros(np.shape(x), dtype=jnp.result_type(x))
                  if t is zero_term else t for t in series]
                 for x, series in zip(primals_in, series_in)]
    rule = jet_rules[primitive]
    primal_out, terms_out = rule(primals_in, series_in, **params)
    if not primitive.multiple_results:
      return JetTracer(self, primal_out, terms_out)
    else:
      return [JetTracer(self, p, ts) for p, ts in zip(primal_out, terms_out)]

  def process_call(self, call_primitive, f, tracers, params):
    """Bind ``call_primitive`` on a jet-transformed ``f`` with flat inputs.

    The ``(primals, series)`` of the inputs are flattened (``zero_series`` and
    ``zero_term`` are pytrees with no leaves, so they contribute no inputs).
    ``call_param_updaters`` may rewrite the params for the new input count.
    """
    primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers)
    primals_and_series, in_tree_def = tree_flatten((primals_in, series_in))
    f_jet, out_tree_def = traceable(jet_subtrace(f, self.main), in_tree_def)
    update_params = call_param_updaters.get(call_primitive)
    new_params = (update_params(params, len(primals_and_series))
                  if update_params else params)
    result = call_primitive.bind(f_jet, *primals_and_series, **new_params)
    primals_out, series_out = tree_unflatten(out_tree_def(), result)
    return [JetTracer(self, p, ts) for p, ts in zip(primals_out, series_out)]

  def post_process_call(self, call_primitive, out_tracers, params):
    """Return flat outputs plus a ``todo`` that rebuilds JetTracers from them.

    ``todo`` creates a new ``JetTrace`` on the same main at the sublevel
    current when it is invoked.
    """
    primals, series = unzip2((t.primal, t.terms) for t in out_tracers)
    out, treedef = tree_flatten((primals, series))
    del primals, series
    main = self.main
    def todo(x):
      primals, series = tree_unflatten(treedef, x)
      trace = JetTrace(main, core.cur_sublevel())
      return map(partial(JetTracer, trace), primals, series)
    return out, todo

  def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *,
                              symbolic_zeros):
    # Custom JVP rules are ignored: the primal function `fun` is traced instead.
    # TODO(mattjj): don't just ignore custom jvp rules?
    del primitive, jvp  # Unused.
    return fun.call_wrapped(*tracers)

  def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees):
    # Custom VJP rules are likewise ignored in favor of tracing `fun`.
    del primitive, fwd, bwd, out_trees  # Unused.
    return fun.call_wrapped(*tracers)
|
|
|
|
|
|
class ZeroTerm:
  """Sentinel for a single Taylor coefficient that is known to be zero."""


zero_term = ZeroTerm()
# No children and no aux data; unflattening always yields the singleton.
register_pytree_node(ZeroTerm, lambda _: ((), None),
                     lambda _aux, _children: zero_term)


class ZeroSeries:
  """Sentinel for an entire coefficient series that is known to be zero."""


zero_series = ZeroSeries()
register_pytree_node(ZeroSeries, lambda _: ((), None),
                     lambda _aux, _children: zero_series)
|
|
|
|
|
|
# Optional per-call-primitive hooks used by JetTrace.process_call: each is
# called as `update(params, num_flat_inputs)` and returns the params to bind
# the call with after primals and series have been flattened together.
call_param_updaters: Dict[core.Primitive, Callable[..., Any]] = {}


### rule definitions

# Maps a primitive to its jet rule, called by JetTrace.process_primitive as
# `rule(primals_in, series_in, **params)` and returning
# `(primal_out, series_out)`.
jet_rules: Dict[core.Primitive, Callable[..., Any]] = {}
|
|
|
|
def defzero(prim):
  """Register ``prim`` as having identically zero higher-order terms."""
  jet_rules[prim] = partial(zero_prop, prim)


def zero_prop(prim, primals_in, series_in, **params):
  """Jet rule for primitives whose output carries no higher-order terms.

  The primal is recomputed by binding ``prim``; ``series_in`` is ignored and
  the ``zero_series`` sentinel is returned instead of explicit coefficients.
  """
  return prim.bind(*primals_in, **params), zero_series


for _zero_prim in (
    lax.le_p, lax.lt_p, lax.gt_p, lax.ge_p, lax.eq_p, lax.ne_p,
    lax.not_p, lax.and_p, lax.or_p, lax.xor_p,
    lax.floor_p, lax.ceil_p, lax.round_p, lax.sign_p,
    ad_util.stop_gradient_p, lax.is_finite_p,
    lax.shift_left_p, lax.shift_right_arithmetic_p, lax.shift_right_logical_p,
    lax.bitcast_convert_type_p,
):
  defzero(_zero_prim)
del _zero_prim
|
|
|
|
|
|
def deflinear(prim):
  """Register ``prim`` as linear: the same op is applied to every coefficient."""
  jet_rules[prim] = partial(linear_prop, prim)


def linear_prop(prim, primals_in, series_in, **params):
  """Jet rule for linear primitives.

  The primal and each order's tuple of input coefficients are pushed through
  ``prim`` with the same ``params``.
  """
  apply_prim = partial(prim.bind, **params)
  primal_out = apply_prim(*primals_in)
  series_out = [apply_prim(*terms_in) for terms_in in zip(*series_in)]
  return primal_out, series_out


for _linear_prim in (
    lax.neg_p, lax.real_p, lax.complex_p, lax.conj_p, lax.imag_p,
    lax.add_p, ad_util.add_jaxvals_p, lax.sub_p,
    lax.convert_element_type_p, lax.broadcast_in_dim_p, lax.concatenate_p,
    lax.pad_p, lax.reshape_p, lax.squeeze_p, lax.rev_p, lax.transpose_p,
    lax.slice_p, lax.reduce_sum_p, lax.reduce_window_sum_p, lax.fft_p,
    dispatch.device_put_p,
):
  deflinear(_linear_prim)
del _linear_prim
|
|
|
|
def _dynamic_slice_jet_rule(primals_in, series_in, **params):
  """Jet rule for ``dynamic_slice``.

  Every coefficient of the operand is sliced at the primal start indices;
  the start indices' own coefficients are ignored.
  """
  operand, *start_indices = primals_in
  slice_at = lambda arr: lax.dynamic_slice_p.bind(arr, *start_indices, **params)
  primal_out = slice_at(operand)
  return primal_out, [slice_at(terms[0]) for terms in zip(*series_in)]


jet_rules[lax.dynamic_slice_p] = _dynamic_slice_jet_rule
|
|
|
|
def _dynamic_update_slice_jet_rule(primals_in, series_in, **params):
  """Jet rule for ``dynamic_update_slice``.

  For fixed start indices the op is linear in ``(operand, update)``, so each
  coefficient is the same update applied to the operand and update
  coefficients at the primal start indices (index coefficients are ignored).
  The primal and the coefficients are bound with the same ``params``.
  """
  operand, update, *start_indices = primals_in
  bind = partial(lax.dynamic_update_slice_p.bind, **params)
  # The primal bind previously dropped `params` while the series binds used
  # them; bind both the same way so they can never disagree.
  primal_out = bind(operand, update, *start_indices)
  series_out = [bind(*terms_in[:2], *start_indices)
                for terms_in in zip(*series_in)]
  return primal_out, series_out


jet_rules[lax.dynamic_update_slice_p] = _dynamic_update_slice_jet_rule
|
|
|
|
def _cumulative_jet_rule(primals_in, series_in, *, axis: int, reverse: bool,
                         combine_fn: Callable):
  """Jet rule for cumulative reductions, expressed via ``lax.associative_scan``.

  Irrespective of backend, the parallel prefix scan implementation is used
  when differentiating because reduce_window is not arbitrarily
  differentiable.
  """
  prefix_scan = partial(lax.associative_scan, combine_fn, axis=axis,
                        reverse=reverse)
  return jet(prefix_scan, primals_in, series_in)


deflinear(lax.cumsum_p)
for _cum_prim, _combine in ((lax.cumprod_p, lax.mul),
                            (lax.cummax_p, lax.max),
                            (lax.cummin_p, lax.min)):
  jet_rules[_cum_prim] = partial(_cumulative_jet_rule, combine_fn=_combine)
del _cum_prim, _combine
|
|
|
|
|
|
def def_deriv(prim, deriv):
  """Register a jet rule for unary ``prim`` given its first derivative ``deriv``."""
  jet_rules[prim] = partial(deriv_prop, prim, deriv)


def deriv_prop(prim, deriv, primals_in, series_in):
  """Jet rule for a unary primitive defined through its first derivative.

  With ``v = prim(u)`` and ``c`` the Taylor coefficients of ``deriv(u)``
  (obtained by a nested ``jet``), ``v' = c * u'`` expands to
  ``v[k] = (k-1)! * sum_{j=1..k} c[k-j] * u[j] / ((k-j)! (j-1)!)``.
  """
  (x,), (terms,) = primals_in, series_in
  out0 = prim.bind(x)
  d0, d_terms = jet(deriv, primals_in, series_in)
  d = [d0] + d_terms
  u = [x] + terms
  out = [out0]
  for k in range(1, len(terms) + 1):
    out.append(fact(k-1) * sum(_scale(k, j) * d[k-j] * u[j] for j in range(1, k + 1)))
  return out[0], out[1:]


def _erf_derivative(x):
  """Return d/dx erf(x) = 2/sqrt(pi) * exp(-x**2), constant in x's dtype."""
  two_over_sqrt_pi = lax_internal._const(x, 2. / np.sqrt(np.pi))
  return lax.mul(two_over_sqrt_pi, lax.exp(lax.neg(lax.square(x))))


def_deriv(lax.erf_p, _erf_derivative)
|
|
|
|
|
|
def def_comp(prim, comp):
  """
  Define the jet rule for a primitive in terms of a composition of simpler primitives.
  """
  jet_rules[prim] = partial(jet, comp)


def _rem_comp(x, y):
  """``lax.rem`` as a composition of jet-supported ops: ``x - y * trunc(x / y)``.

  ``lax.rem`` takes the sign of the dividend (C ``fmod`` semantics), so the
  quotient has to be rounded toward zero. Rounding down (``floor``) gives
  Python-style modulo instead, and so a wrong primal whenever ``x / y < 0``
  (e.g. rem(-1, 3) is -1, whereas -1 - 3 * floor(-1/3) is 2).
  """
  return x - y * jnp.trunc(x / y)


def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1)
def_comp(lax.log1p_p, lambda x: lax.log(1 + x))
def_comp(lax.sqrt_p, lambda x: x ** 0.5)
def_comp(lax.rsqrt_p, lambda x: x ** -0.5)
def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)))
def_comp(lax.acosh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)))
def_comp(lax.atanh_p, lambda x: 0.5 * lax.log(lax.div(1 + x, 1 - x)))
def_comp(lax.erfc_p, lambda x: 1 - lax.erf(x))
def_comp(lax.rem_p, _rem_comp)
def_comp(lax.clamp_p, lambda a, x, b: lax.min(lax.max(a, x), b))
|
|
|
|
|
|
def _erf_inv_rule(primals_in, series_in):
  """Jet rule for ``erf_inv``.

  Uses ``d/dx erf_inv(x) = sqrt(pi)/2 * exp(erf_inv(x)**2)``, i.e. a derivative
  expressed through the output ``y = erf_inv(x)``. Because of that, the
  coefficients of the derivative (``c``) and of the output (``v``) are built
  up together, one order at a time.
  """
  x, = primals_in
  series, = series_in

  u = [x] + series
  primal_out = lax.erf_inv(x)
  if not series:
    # Order-0 jet: nothing to propagate. Without this, the closing step below
    # evaluates fact(-1) * 0 (inf * 0 = NaN) and overwrites the primal.
    return primal_out, []
  v = [primal_out] + [None] * len(series)

  # derivative on co-domain for caching purposes. The constant takes x's
  # dtype (as the erf derivative rule does) so it never promotes the result,
  # e.g. float32 inputs to float64 when x64 is enabled.
  deriv_const = lax_internal._const(x, np.sqrt(np.pi) / 2.)
  deriv_y = lambda y: lax.mul(deriv_const, lax.exp(lax.square(y)))

  # manually propagate through deriv_y since we don't have lazy evaluation of sensitivities

  c = [deriv_y(primal_out)] + [None] * (len(series) - 1)
  tmp_sq = [lax.square(v[0])] + [None] * (len(series) - 1)
  tmp_exp = [lax.exp(tmp_sq[0])] + [None] * (len(series) - 1)
  for k in range(1, len(series)):
    # we know c[:k], we compute c[k]

    # propagate c to get v
    v[k] = fact(k-1) * sum(_scale(k, j) * c[k-j] * u[j] for j in range(1, k + 1))

    # propagate v to get next c

    # square: Leibniz product of v with itself
    tmp_sq[k] = fact(k) * sum(_scale2(k, j) * v[k-j] * v[j] for j in range(k + 1))

    # exp: same recurrence as the exp rule, applied to tmp_sq
    tmp_exp[k] = fact(k-1) * sum(_scale(k, j) * tmp_exp[k-j] * tmp_sq[j] for j in range(1, k + 1))

    # const
    c[k] = deriv_const * tmp_exp[k]

  # we can't, and don't need, to compute c[k+1], just need to get the last v[k]
  k = len(series)
  v[k] = fact(k-1) * sum(_scale(k, j) * c[k-j] * u[j] for j in range(1, k + 1))

  primal_out, *series_out = v
  return primal_out, series_out

jet_rules[lax.erf_inv_p] = _erf_inv_rule
|
|
|
|
### More complicated rules
|
|
|
|
def fact(n):
  """Return ``n!`` computed as ``exp(lgamma(n + 1))``.

  NOTE(review): this is evaluated with floating-point lax ops, so the result
  is only approximately integral; callers in this file pass Python ints taken
  from loop indices.
  """
  return lax.exp(lax.lgamma(n+1.))


def _scale(k, j):
  """Return ``1 / ((k - j)! (j - 1)!)``.

  Multiplied by ``(k - 1)!`` this is the binomial coefficient
  ``C(k - 1, j - 1)``, as used in the ``v' = c * u'`` style recurrences.
  """
  return 1. / (fact(k - j) * fact(j - 1))


def _scale2(k, j):
  """Return ``1 / ((k - j)! j!)``; times ``k!`` this is ``C(k, j)`` (Leibniz)."""
  return 1. / (fact(k - j) * fact(j))
|
|
|
|
def _exp_taylor(primals_in, series_in):
  """Jet rule for ``exp``.

  With ``v = exp(u)`` we have ``v' = v * u'``, which expands to
  ``v[k] = (k-1)! * sum_{j=1..k} v[k-j] * u[j] / ((k-j)! (j-1)!)``.
  """
  (x,), (terms,) = primals_in, series_in
  u = [x] + terms
  out = [lax.exp(x)]
  for k in range(1, len(u)):
    out.append(fact(k-1) * sum(_scale(k, j) * out[k-j] * u[j] for j in range(1, k+1)))
  return out[0], out[1:]

jet_rules[lax.exp_p] = _exp_taylor
|
|
|
|
def _pow_taylor(primals_in, series_in):
  """Jet rule for ``pow`` using ``x ** y = exp(y * log(x))``.

  The coefficients of ``y * log(x)`` come from a nested ``jet``; they are then
  pushed through the same recurrence as the exp rule. The primal itself is
  computed directly as ``x ** y``.
  """
  base, exponent = primals_in

  log0, log_terms = jet(lambda x, y: lax.mul(y, lax.log(x)), primals_in, series_in)

  u = [log0] + log_terms
  out = [base ** exponent]
  for k in range(1, len(u)):
    out.append(fact(k-1) * sum(_scale(k, j) * out[k-j] * u[j] for j in range(1, k+1)))
  return out[0], out[1:]

jet_rules[lax.pow_p] = _pow_taylor
|
|
|
|
def _integer_pow_by_squaring(x, y):
  """Return ``x ** y`` for a Python int ``y >= 1`` using only ``lax.mul``.

  Binary exponentiation: O(log y) multiplications.
  """
  acc = None
  while True:
    if y & 1:
      acc = x if acc is None else lax.mul(acc, x)
    y >>= 1
    if not y:
      return acc
    x = lax.mul(x, x)


def _integer_pow_taylor(primals_in, series_in, *, y):
  """Jet rule for ``integer_pow``: ``x ** y`` for a static Python int ``y``.

  Exponents 0, 1 and 2 are handled by tracing an equivalent function. For
  ``0 < y <= order`` the input series is multiplied out exactly by repeated
  squaring. Otherwise ``v = u ** y`` is expanded with the recurrence obtained
  from ``u * v' = y * v * u'``, which divides by ``x``; where ``x == 0`` the
  coefficients are set to 0. That is exact for ``y > order`` because every
  coefficient of order < y vanishes there; for ``y < 0``, ``x == 0`` is a pole.
  """
  if y == 0:
    return jet(jnp.ones_like, primals_in, series_in)
  elif y == 1:
    return jet(lambda x: x, primals_in, series_in)
  elif y == 2:
    return jet(lambda x: x * x, primals_in, series_in)
  x, = primals_in
  series, = series_in
  if 0 < y <= len(series):
    # The division-based recurrence below is masked to 0 where x == 0, but
    # there the coefficients of order >= y are nonzero (the 3rd coefficient of
    # (0 + t)**3 is 6). Multiplying the series out never divides by x.
    return jet(partial(_integer_pow_by_squaring, y=y), primals_in, series_in)
  u = [x] + series
  v = [lax.integer_pow(x, y)] + [None] * len(series)
  for k in range(1, len(v)):
    vu = sum(_scale(k, j) * v[k-j] * u[j] for j in range(1, k + 1))
    uv = sum(_scale(k, j) * u[k-j] * v[j] for j in range(1, k))
    v[k] = jnp.where(x == 0, 0, fact(k-1) * (y * vu - uv) / x)
  primal_out, *series_out = v

  return primal_out, series_out

jet_rules[lax.integer_pow_p] = _integer_pow_taylor
|
|
|
|
|
|
def _logistic_taylor(primals_in, series_in):
  """Jet rule for ``logistic`` (sigmoid).

  Uses ``sigma' = sigma * (1 - sigma)``. ``deriv`` holds the coefficients of
  sigma' and is extended alongside the output coefficients ``v``, since each
  new output coefficient needs the previous derivative coefficients.
  """
  (x,), (terms,) = primals_in, series_in
  u = [x] + terms
  v = [lax.logistic(x)]
  deriv = [v[0] * (1 - v[0])]
  for k in range(1, len(u)):
    v.append(fact(k-1) * sum(_scale(k, j) * deriv[k-j] * u[j] for j in range(1, k+1)))
    deriv.append((1 - v[0]) * v[k] - fact(k) * sum(_scale2(k, j) * v[j] * v[k-j] for j in range(1, k+1)))
  return v[0], v[1:]


jet_rules[lax.logistic_p] = _logistic_taylor
|
|
|
|
|
|
def _tanh_taylor(primals_in, series_in):
  """Jet rule for ``tanh`` via ``tanh(x) = 2 * logistic(2 * x) - 1``.

  The input coefficients are doubled, pushed through the logistic rule, and
  the output coefficients doubled again; the ``- 1`` only shifts the primal.
  """
  x, = primals_in
  series, = series_in
  doubled_primal = 2*x
  doubled_series = [2 * term for term in series]
  sig_primal, sig_series = _logistic_taylor((doubled_primal,), (doubled_series,))
  out_series = [2 * term for term in sig_series]
  return 2 * sig_primal - 1, out_series

jet_rules[lax.tanh_p] = _tanh_taylor
|
|
|
|
def _log_taylor(primals_in, series_in):
  """Jet rule for ``log``.

  From ``u * v' = u'`` (with ``v = log(u)``) and Leibniz's rule:
  ``v[k] = (u[k] - (k-1)! * sum_{j=1..k-1} v[j] u[k-j] / ((k-j)! (j-1)!)) / u[0]``.
  """
  (x,), (terms,) = primals_in, series_in
  u = [x] + terms
  v = [lax.log(x)]
  for k in range(1, len(u)):
    conv = sum(_scale(k, j) * v[j] * u[k-j] for j in range(1, k))
    v.append((u[k] - fact(k - 1) * conv) / u[0])
  return v[0], v[1:]

jet_rules[lax.log_p] = _log_taylor
|
|
|
|
def _atan2_taylor(primals_in, series_in):
  """Jet rule for ``atan2(x, y)``.

  The primal uses ``lax.atan2``. The coefficients follow from
  ``atan(q)' = q' / (1 + q**2)`` with ``q = x / y``; the coefficients of ``q``
  and of ``1 / (1 + q**2)`` come from nested ``jet`` calls.
  """
  x, y = primals_in
  primal_out = lax.atan2(x, y)

  q0, q_terms = jet(lax.div, primals_in, series_in)
  one = lax_internal._const(q0, 1)
  c0, c_terms = jet(lambda z: lax.div(one, 1 + lax.square(z)), (q0, ), (q_terms, ))
  c = [c0] + c_terms
  u = [q0] + q_terms
  out = [primal_out]
  for k in range(1, len(u)):
    out.append(fact(k-1) * sum(_scale(k, j) * c[k-j] * u[j] for j in range(1, k + 1)))
  return out[0], out[1:]

jet_rules[lax.atan2_p] = _atan2_taylor
|
|
|
|
def _div_taylor_rule(primals_in, series_in):
  """Jet rule for ``div``: ``v = x / y``.

  With ``u`` the numerator and ``w`` the denominator coefficients,
  ``u = v * w`` and Leibniz's rule give
  ``v[k] = (u[k] - k! * sum_{j<k} v[j] w[k-j] / (j! (k-j)!)) / w[0]``.
  The primal is computed with ``lax.div`` itself so it matches the primitive
  exactly (including truncating integer division) rather than going through
  the float arithmetic of the recurrence.
  """
  x, y = primals_in
  x_terms, y_terms = series_in
  u = [x] + x_terms
  w = [y] + y_terms
  v = [lax.div(x, y)]
  for k in range(1, len(u)):
    conv = sum(_scale2(k, j) * v[j] * w[k-j] for j in range(0, k))
    v.append((u[k] - fact(k) * conv) / w[0])
  primal_out, *series_out = v
  return primal_out, series_out

jet_rules[lax.div_p] = _div_taylor_rule
|
|
|
|
def _sinusoidal_rule(sign, prims, primals_in, series_in):
  """Jointly propagate a (sin, cos)-like pair for one input series.

  ``prims = (f, g)`` with ``f' = g`` and ``g' = sign * f``: ``sign = -1`` gives
  sin/cos and ``sign = 1`` gives sinh/cosh. Returns
  ``((f0, f_terms), (g0, g_terms))``.
  """
  (x,), (terms,) = primals_in, series_in
  u = [x] + terms
  f, g = prims
  f_coeffs = [f(x)]
  g_coeffs = [g(x)]
  for k in range(1, len(u)):
    f_coeffs.append(fact(k-1) * sum(_scale(k, j) * u[j] * g_coeffs[k-j] for j in range(1, k + 1)))
    g_coeffs.append(fact(k-1) * sum(_scale(k, j) * u[j] * f_coeffs[k-j] for j in range(1, k + 1)) * sign)
  return (f_coeffs[0], f_coeffs[1:]), (g_coeffs[0], g_coeffs[1:])


def _get_ind(f, ind):
  """Wrap ``f`` so that only element ``ind`` of its result is returned."""
  def select_result(*args):
    return f(*args)[ind]
  return select_result


for _sin_prim, _sign, _pair, _ind in (
    (lax.sin_p, -1, (lax.sin, lax.cos), 0),
    (lax.cos_p, -1, (lax.sin, lax.cos), 1),
    (lax.sinh_p, 1, (lax.sinh, lax.cosh), 0),
    (lax.cosh_p, 1, (lax.sinh, lax.cosh), 1),
):
  jet_rules[_sin_prim] = _get_ind(partial(_sinusoidal_rule, _sign, _pair), _ind)
del _sin_prim, _sign, _pair, _ind
|
|
|
|
def _bilinear_taylor_rule(prim, primals_in, series_in, **params):
  """Jet rule for primitives that are bilinear in their two operands.

  Leibniz's rule: ``v[k] = k! * sum_{j=0..k} op(u[j], w[k-j]) / (j! (k-j)!)``,
  where ``op`` is ``prim`` bound with ``params``.
  """
  x, y = primals_in
  x_terms, y_terms = series_in
  u = [x] + x_terms
  w = [y] + y_terms
  op = partial(prim.bind, **params)
  v = [fact(k) * sum(_scale2(k, j) * op(u[j], w[k-j]) for j in range(0, k+1))
       for k in range(len(u))]
  return v[0], v[1:]

for _bilinear_prim in (lax.dot_general_p, lax.mul_p, lax.conv_general_dilated_p):
  jet_rules[_bilinear_prim] = partial(_bilinear_taylor_rule, _bilinear_prim)
del _bilinear_prim
|
|
|
|
def _gather_taylor_rule(primals_in, series_in, **params):
  """Jet rule for ``gather``: gather each operand coefficient at the primal indices.

  The series of the indices is ignored.
  """
  operand, start_indices = primals_in
  operand_terms, _ = series_in
  gather = partial(lax.gather_p.bind, **params)
  primal_out = gather(operand, start_indices)
  return primal_out, [gather(term, start_indices) for term in operand_terms]

jet_rules[lax.gather_p] = _gather_taylor_rule
|
|
|
|
def _gen_reduce_choose_taylor_rule(chooser_fun):
  """Build a jet rule for a max/min-style reduction primitive.

  ``chooser_fun(operand, **params)`` computes the primal. Each output
  coefficient is the average of the operand's coefficients over the positions
  (along ``axes``) whose value equals the reduced result.
  """
  def chooser_taylor_rule(primals_in, series_in, **params):
    operand, = primals_in
    gs, = series_in
    primal_out = chooser_fun(operand, **params)
    if not gs:
      # Order-0 jet: there are no coefficients to reduce (and no gs[0] to
      # read a dtype from, which previously raised IndexError).
      return primal_out, []
    axes = params.pop("axes", None)
    primal_dtype = gs[0].dtype
    # Keep the reduced axes as size-1 dims so the result lines up with operand.
    shape = [1 if i in axes else d for i, d in enumerate(operand.shape)]
    location_indicators = lax.convert_element_type(
        lax_internal._eq_meet(operand, lax.reshape(primal_out, shape)),
        primal_dtype)
    # Number of positions attaining the extremum; ties share the coefficient.
    counts = lax_internal._reduce_sum(location_indicators, axes)
    def _reduce_chooser_taylor_rule(g):
      return lax.div(
          lax_internal._reduce_sum(lax.mul(g, location_indicators), axes),
          counts)
    series_out = [_reduce_chooser_taylor_rule(g) for g in gs]
    return primal_out, series_out
  return chooser_taylor_rule

jet_rules[lax.reduce_max_p] = _gen_reduce_choose_taylor_rule(
    lax_internal._reduce_max)
jet_rules[lax.reduce_min_p] = _gen_reduce_choose_taylor_rule(
    lax_internal._reduce_min)
|
|
|
|
def _abs_taylor_rule(x, series_in, **params):
  """Jet rule for ``abs``: every coefficient is multiplied by the sign of x.

  ``x`` is the one-element list of primal inputs. The sign is -1 where the
  primal is negative and +1 elsewhere (including at 0).
  """
  operand, = x
  zero = lax.full_like(operand, 0, shape=())
  primal_out = lax.abs_p.bind(operand, **params)
  negs = lax.select(lax.lt(operand, zero), lax.full_like(operand, -1),
                    lax.full_like(operand, 1.0))
  # `params` belong to abs_p only. They used to be forwarded into a
  # one-argument lambda here, which would raise TypeError if ever non-empty.
  series_out = [negs * term for term, in zip(*series_in)]
  return primal_out, series_out

jet_rules[lax.abs_p] = _abs_taylor_rule
|
|
|
|
def _select_n_taylor_rule(primal_in, series_in, **params):
  """Jet rule for ``select_n``: pick coefficients with the primal selector.

  The selector's own series is ignored, and so are ``params``.
  """
  which, *cases = primal_in
  primal_out = lax.select_n(which, *cases)
  series_out = [lax.select_n(which, *case_terms)
                for _, *case_terms in zip(*series_in)]
  return primal_out, series_out

jet_rules[lax.select_n_p] = _select_n_taylor_rule
|
|
|
|
def _lax_max_taylor_rule(primal_in, series_in):
  """Jet rule for ``max``.

  Operands (and each order's coefficients) are broadcast together. Each
  coefficient comes from the larger operand, and is averaged where the
  primals tie.
  """
  x, y = jnp.broadcast_arrays(*primal_in)

  x_wins = x > y
  tied = x == y
  primal_out = lax.select(x_wins, x, y)

  def pick(x_i, y_i):
    """Select x where x>y or average when x==y"""
    chosen = lax.select(x_wins, x_i, y_i)
    return lax.select(tied, (x_i + y_i)/2, chosen)

  series_out = [pick(*jnp.broadcast_arrays(*terms_in)) for terms_in in zip(*series_in)]
  return primal_out, series_out

jet_rules[lax.max_p] = _lax_max_taylor_rule
|
|
|
|
def _lax_min_taylor_rule(primal_in, series_in):
|
|
x, y = primal_in
|
|
xgy = x < y # less than mask
|
|
xey = x == y # equal to mask
|
|
primal_out = lax.select(xgy, x, y)
|
|
|
|
def select_min_and_avg_eq(x_i, y_i):
|
|
"""Select x where x>y or average when x==y"""
|
|
min_i = lax.select(xgy, x_i, y_i)
|
|
min_i = lax.select(xey, (x_i + y_i)/2, min_i)
|
|
return min_i
|
|
|
|
series_out = [select_min_and_avg_eq(*terms_in) for terms_in in zip(*series_in)]
|
|
return primal_out, series_out
|
|
# Register the jet rule for lax.min_p.
jet_rules[lax.min_p] = _lax_min_taylor_rule
|
|
|
|
def _scatter_add_rule(primals_in, series_in, *, update_jaxpr, update_consts,
                      dimension_numbers, indices_are_sorted, unique_indices,
                      mode):
  """Jet rule for ``scatter_add``.

  For fixed indices the op is linear in (operand, updates), so each order's
  operand and update coefficients are scattered with the primal indices; the
  indices' own series is ignored.
  """
  scatter = partial(lax.scatter_add_p.bind, update_jaxpr=update_jaxpr,
                    update_consts=update_consts,
                    dimension_numbers=dimension_numbers,
                    indices_are_sorted=indices_are_sorted,
                    unique_indices=unique_indices, mode=mode)
  operand, scatter_indices, updates = primals_in
  primal_out = scatter(operand, scatter_indices, updates)
  series_out = [scatter(operand_term, scatter_indices, update_term)
                for operand_term, _, update_term in zip(*series_in)]
  return primal_out, series_out

jet_rules[lax.scatter_add_p] = _scatter_add_rule
|
|
|
|
|
|
@weakref_lru_cache
def _jet_jaxpr(
    jaxpr: core.ClosedJaxpr, order: int, primals_and_series_avals, in_tree_def
) -> Tuple[core.ClosedJaxpr, Any]:
  """Trace the jet-transformed version of ``jaxpr`` into a new ClosedJaxpr.

  Results are memoized by ``weakref_lru_cache``.
  NOTE(review): presumably keyed weakly on ``jaxpr`` and strongly on the other
  (hashable) arguments — confirm against jax._src.util.

  Args:
    jaxpr: the closed jaxpr to transform.
    order: number of Taylor coefficients per value.
    primals_and_series_avals: abstract values of the flattened
      ``(primals, series)`` inputs, in flatten order.
    in_tree_def: treedef that unflattens those inputs into
      ``(primals, series)``.

  Returns:
    The traced jet jaxpr (closing over its consts), and the ``out_tree_def``
    thunk from ``traceable`` that describes how to unflatten its outputs.
  """
  f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
  f_jet, out_tree_def = traceable(jet_fun(jet_subtrace(f), order), in_tree_def)
  jaxpr_jet, _, consts = pe.trace_to_jaxpr_dynamic(
      f_jet, primals_and_series_avals)
  return core.ClosedJaxpr(jaxpr_jet, consts), out_tree_def
|
|
|
|
|
|
def _pjit_jet_rule(primals_in, series_in, **params):
  """Jet rule for ``pjit``: re-trace the inner jaxpr under jet and re-bind.

  The flat inputs are the primals followed by their series coefficients. The
  extra series inputs and outputs get unspecified shardings and are never
  donated.
  """
  flat_inputs, in_tree_def = tree_flatten((primals_in, series_in))
  order = len(series_in[0])
  flat_avals = tuple(shaped_abstractify(v) for v in flat_inputs)
  jet_jaxpr, out_tree_def = _jet_jaxpr(params['jaxpr'], order, flat_avals,
                                       in_tree_def)

  extra_in = len(primals_in) * order
  extra_out = len(params['out_shardings']) * order
  unspecified = sharding_impls.UNSPECIFIED
  new_params = dict(params)
  new_params['jaxpr'] = jet_jaxpr
  new_params['in_shardings'] = params['in_shardings'] + (unspecified,) * extra_in
  new_params['out_shardings'] = (
      params['out_shardings'] + (unspecified,) * extra_out)
  new_params['donated_invars'] = params['donated_invars'] + (False,) * extra_in

  flat_out = pjit.pjit_p.bind(*flat_inputs, **new_params)
  return tree_unflatten(out_tree_def(), flat_out)


jet_rules[pjit.pjit_p] = _pjit_jet_rule
|