# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Examples of how to write optimizers with JAX.

You likely do not mean to import this module! The optimizers in this library
are intended as examples only. If you are looking for a fully featured optimizer
library, two good options are JAXopt_ and Optax_.

This module contains some convenient optimizer definitions, specifically
initialization and update functions, which can be used with ndarrays or
arbitrarily-nested tuple/list/dicts of ndarrays.

An optimizer is modeled as an ``(init_fun, update_fun, get_params)`` triple of
functions, where the component functions have these signatures:

::

  init_fun(params)

  Args:
    params: pytree representing the initial parameters.

  Returns:
    A pytree representing the initial optimizer state, which includes the
    initial parameters and may also include auxiliary values like initial
    momentum. The optimizer state pytree structure generally differs from that
    of `params`.

::

  update_fun(step, grads, opt_state)

  Args:
    step: integer representing the step index.
    grads: a pytree with the same structure as `get_params(opt_state)`
      representing the gradients to be used in updating the optimizer state.
    opt_state: a pytree representing the optimizer state to be updated.

  Returns:
    A pytree with the same structure as the `opt_state` argument representing
    the updated optimizer state.

::

  get_params(opt_state)

  Args:
    opt_state: pytree representing an optimizer state.

  Returns:
    A pytree representing the parameters extracted from `opt_state`, such that
    the invariant `params == get_params(init_fun(params))` holds true.


Notice that an optimizer implementation has a lot of flexibility in the form of
opt_state: it just has to be a pytree of JaxTypes (so that it can be passed to
the JAX transforms defined in api.py) and it has to be consumable by update_fun
and get_params.

Example Usage:

.. code-block:: python

  opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
  opt_state = opt_init(params)

  def step(step, opt_state):
    value, grads = jax.value_and_grad(loss_fn)(get_params(opt_state))
    opt_state = opt_update(step, grads, opt_state)
    return value, opt_state

  for i in range(num_steps):
    value, opt_state = step(i, opt_state)
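
Because the optimizer state is a registered pytree, the training step above can
typically be compiled with ``jax.jit`` (a sketch, assuming the ``step`` function
and ``loss_fn`` above are jit-compatible):

.. code-block:: python

  step = jax.jit(step)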


.. _JAXopt: https://github.com/google/jaxopt
.. _Optax: https://github.com/deepmind/optax
"""

from typing import Any, Callable, NamedTuple, Tuple, Union

from collections import namedtuple
import functools
from functools import partial

import jax.numpy as jnp
from jax._src.util import safe_zip, safe_map, unzip2
from jax import tree_util
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
                           register_pytree_node)

map = safe_map
zip = safe_zip


# The implementation here basically works by flattening pytrees. There are two
# levels of pytrees to think about: the pytree of params, which we can think of
# as defining an "outer pytree", and a pytree produced by applying init_fun to
# each leaf of the params pytree, which we can think of as the "inner pytrees".
# Since pytrees can be flattened, that structure is isomorphic to a list of
# lists (with no further nesting).
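#
# For example (illustrative only): for params = {'w': w, 'b': b} and a
# momentum-style per-leaf state (x, v), the outer pytree is the dict and the
# inner pytrees are the (x, v) pairs, so the packed representation is roughly
# [[w, v_w], [b, v_b]] together with the tree definitions needed to rebuild it.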

OptimizerState = namedtuple("OptimizerState",
                            ["packed_state", "tree_def", "subtree_defs"])
register_pytree_node(
    OptimizerState,
    lambda xs: ((xs.packed_state,), (xs.tree_def, xs.subtree_defs)),
    lambda data, xs: OptimizerState(xs[0], data[0], data[1]))  # type: ignore[index]


Array = Any
Params = Any  # Parameters are arbitrary nests of `jnp.ndarrays`.
State = Any   # internal State
Updates = Params  # Gradient updates are of the same type as parameters.

InitFn = Callable[[Params], OptimizerState]
Step = int
UpdateFn = Callable[[Step, Updates, OptimizerState], OptimizerState]
ParamsFn = Callable[[OptimizerState], Params]

class Optimizer(NamedTuple):
  init_fn: InitFn
  update_fn: UpdateFn
  params_fn: ParamsFn

Schedule = Callable[[Step], float]

def optimizer(opt_maker: Callable[...,
  Tuple[Callable[[Params], State],
        Callable[[Step, Updates, Params], Params],
        Callable[[State], Params]]]) -> Callable[..., Optimizer]:
  """Decorator to make an optimizer defined for arrays generalize to containers.

  With this decorator, you can write init, update, and get_params functions that
  each operate only on single arrays, and convert them to corresponding
  functions that operate on pytrees of parameters. See the optimizers defined in
  optimizers.py for examples.

  Args:
    opt_maker: a function that returns an ``(init_fun, update_fun, get_params)``
      triple of functions that might only work with ndarrays, as per

      .. code-block:: haskell

          init_fun :: ndarray -> OptStatePytree ndarray
          update_fun :: OptStatePytree ndarray -> OptStatePytree ndarray
          get_params :: OptStatePytree ndarray -> ndarray

  Returns:
    An ``(init_fun, update_fun, get_params)`` triple of functions that work on
    arbitrary pytrees, as per

    .. code-block:: haskell

        init_fun :: ParameterPytree ndarray -> OptimizerState
        update_fun :: OptimizerState -> OptimizerState
        get_params :: OptimizerState -> ParameterPytree ndarray

    The OptimizerState pytree type used by the returned functions is isomorphic
    to ``ParameterPytree (OptStatePytree ndarray)``, but may store the state
    instead as e.g. a partially-flattened data structure for performance.
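
  For example, a minimal sketch of defining a per-array optimizer and applying
  it to a parameter pytree (``grads`` is assumed to be a gradient pytree with
  the same structure as the parameters, e.g. from ``jax.grad``):

  .. code-block:: python

    @optimizer
    def simple_sgd(step_size):
      def init(x0):
        return x0
      def update(i, g, x):
        return x - step_size * g
      def get_params(x):
        return x
      return init, update, get_params

    init_fun, update_fun, get_params = simple_sgd(1e-3)
    opt_state = init_fun({'w': jnp.ones(3), 'b': jnp.zeros(3)})
    opt_state = update_fun(0, grads, opt_state)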
  """

  @functools.wraps(opt_maker)
  def tree_opt_maker(*args, **kwargs):
    init, update, get_params = opt_maker(*args, **kwargs)

    @functools.wraps(init)
    def tree_init(x0_tree):
      x0_flat, tree = tree_flatten(x0_tree)
      initial_states = [init(x0) for x0 in x0_flat]
      states_flat, subtrees = unzip2(map(tree_flatten, initial_states))
      return OptimizerState(states_flat, tree, subtrees)

    @functools.wraps(update)
    def tree_update(i, grad_tree, opt_state):
      states_flat, tree, subtrees = opt_state
      grad_flat, tree2 = tree_flatten(grad_tree)
      if tree2 != tree:
        msg = ("optimizer update function was passed a gradient tree that did "
               "not match the parameter tree structure with which it was "
               "initialized: parameter tree {} and grad tree {}.")
        raise TypeError(msg.format(tree, tree2))
      states = map(tree_unflatten, subtrees, states_flat)
      new_states = map(partial(update, i), grad_flat, states)
      new_states_flat, subtrees2 = unzip2(map(tree_flatten, new_states))
      for subtree, subtree2 in zip(subtrees, subtrees2):
        if subtree2 != subtree:
          msg = ("optimizer update function produced an output structure that "
                 "did not match its input structure: input {} and output {}.")
          raise TypeError(msg.format(subtree, subtree2))
      return OptimizerState(new_states_flat, tree, subtrees)

    @functools.wraps(get_params)
    def tree_get_params(opt_state):
      states_flat, tree, subtrees = opt_state
      states = map(tree_unflatten, subtrees, states_flat)
      params = map(get_params, states)
      return tree_unflatten(tree, params)

    return Optimizer(tree_init, tree_update, tree_get_params)
  return tree_opt_maker


### optimizers

@optimizer
def sgd(step_size):
  """Construct optimizer triple for stochastic gradient descent.

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to a positive scalar.

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  step_size = make_schedule(step_size)
  def init(x0):
    return x0
  def update(i, g, x):
    return x - step_size(i) * g
  def get_params(x):
    return x
  return Optimizer(init, update, get_params)

@optimizer
def momentum(step_size: Schedule, mass: float):
  """Construct optimizer triple for SGD with momentum.

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to a positive scalar.
    mass: positive scalar representing the momentum coefficient.

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  step_size = make_schedule(step_size)
  def init(x0):
    v0 = jnp.zeros_like(x0)
    return x0, v0
  def update(i, g, state):
    x, velocity = state
    velocity = mass * velocity + g
    x = x - step_size(i) * velocity
    return x, velocity
  def get_params(state):
    x, _ = state
    return x
  return init, update, get_params


@optimizer
def nesterov(step_size: Schedule, mass: float):
  """Construct optimizer triple for SGD with Nesterov momentum.

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to a positive scalar.
    mass: positive scalar representing the momentum coefficient.

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  step_size = make_schedule(step_size)
  def init(x0):
    v0 = jnp.zeros_like(x0)
    return x0, v0
  def update(i, g, state):
    x, velocity = state
    velocity = mass * velocity + g
    x = x - step_size(i) * (mass * velocity + g)
    return x, velocity
  def get_params(state):
    x, _ = state
    return x
  return init, update, get_params


@optimizer
def adagrad(step_size, momentum=0.9):
  """Construct optimizer triple for Adagrad.

  Adaptive Subgradient Methods for Online Learning and Stochastic Optimization:
  http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to a positive scalar.
    momentum: optional, a positive scalar value for momentum.

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  step_size = make_schedule(step_size)

  def init(x0):
    g_sq = jnp.zeros_like(x0)
    m = jnp.zeros_like(x0)
    return x0, g_sq, m

  def update(i, g, state):
    x, g_sq, m = state
    g_sq += jnp.square(g)
    g_sq_inv_sqrt = jnp.where(g_sq > 0, 1. / jnp.sqrt(g_sq), 0.0)
    m = (1. - momentum) * (g * g_sq_inv_sqrt) + momentum * m
    x = x - step_size(i) * m
    return x, g_sq, m

  def get_params(state):
    x, _, _ = state
    return x

  return init, update, get_params


@optimizer
def rmsprop(step_size, gamma=0.9, eps=1e-8):
  """Construct optimizer triple for RMSProp.

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to a positive scalar.
    gamma: Decay parameter.
    eps: Epsilon parameter.

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  step_size = make_schedule(step_size)
  def init(x0):
    avg_sq_grad = jnp.zeros_like(x0)
    return x0, avg_sq_grad
  def update(i, g, state):
    x, avg_sq_grad = state
    avg_sq_grad = avg_sq_grad * gamma + jnp.square(g) * (1. - gamma)
    x = x - step_size(i) * g / jnp.sqrt(avg_sq_grad + eps)
    return x, avg_sq_grad
  def get_params(state):
    x, _ = state
    return x
  return init, update, get_params


@optimizer
def rmsprop_momentum(step_size, gamma=0.9, eps=1e-8, momentum=0.9):
  """Construct optimizer triple for RMSProp with momentum.

  This optimizer is separate from the rmsprop optimizer because it needs to
  keep track of additional parameters.

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to a positive scalar.
    gamma: Decay parameter.
    eps: Epsilon parameter.
    momentum: Momentum parameter.

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  step_size = make_schedule(step_size)
  def init(x0):
    avg_sq_grad = jnp.zeros_like(x0)
    mom = jnp.zeros_like(x0)
    return x0, avg_sq_grad, mom
  def update(i, g, state):
    x, avg_sq_grad, mom = state
    avg_sq_grad = avg_sq_grad * gamma + jnp.square(g) * (1. - gamma)
    mom = momentum * mom + step_size(i) * g / jnp.sqrt(avg_sq_grad + eps)
    x = x - mom
    return x, avg_sq_grad, mom
  def get_params(state):
    x, _, _ = state
    return x
  return init, update, get_params


@optimizer
def adam(step_size, b1=0.9, b2=0.999, eps=1e-8):
  """Construct optimizer triple for Adam.

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to a positive scalar.
    b1: optional, a positive scalar value for beta_1, the exponential decay rate
      for the first moment estimates (default 0.9).
    b2: optional, a positive scalar value for beta_2, the exponential decay rate
      for the second moment estimates (default 0.999).
    eps: optional, a positive scalar value for epsilon, a small constant for
      numerical stability (default 1e-8).

  Returns:
    An (init_fun, update_fun, get_params) triple.
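
  Example (a sketch; ``params`` and ``grads`` are assumed to be matching
  pytrees):

  .. code-block:: python

    opt_init, opt_update, get_params = adam(exponential_decay(1e-3, 1000, 0.9))
    opt_state = opt_init(params)
    opt_state = opt_update(0, grads, opt_state)
    params = get_params(opt_state)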
  """
  step_size = make_schedule(step_size)
  def init(x0):
    m0 = jnp.zeros_like(x0)
    v0 = jnp.zeros_like(x0)
    return x0, m0, v0
  def update(i, g, state):
    x, m, v = state
    m = (1 - b1) * g + b1 * m  # First moment estimate.
    v = (1 - b2) * jnp.square(g) + b2 * v  # Second moment estimate.
    mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1))  # Bias correction.
    vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
    x = x - step_size(i) * mhat / (jnp.sqrt(vhat) + eps)
    return x, m, v
  def get_params(state):
    x, _, _ = state
    return x
  return init, update, get_params


@optimizer
def adamax(step_size, b1=0.9, b2=0.999, eps=1e-8):
  """Construct optimizer triple for AdaMax (a variant of Adam based on infinity norm).

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to a positive scalar.
    b1: optional, a positive scalar value for beta_1, the exponential decay rate
      for the first moment estimates (default 0.9).
    b2: optional, a positive scalar value for beta_2, the exponential decay rate
      for the second moment estimates (default 0.999).
    eps: optional, a positive scalar value for epsilon, a small constant for
      numerical stability (default 1e-8).

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  step_size = make_schedule(step_size)
  def init(x0):
    m0 = jnp.zeros_like(x0)
    u0 = jnp.zeros_like(x0)
    return x0, m0, u0
  def update(i, g, state):
    x, m, u = state
    m = (1 - b1) * g + b1 * m  # First moment estimate.
    u = jnp.maximum(b2 * u, jnp.abs(g))  # Update exponentially weighted infinity norm.
    x = (x - (step_size(i) / (1 - jnp.asarray(b1, m.dtype) ** (i + 1))) * m
         / (u + eps))
    return x, m, u
  def get_params(state):
    x, _, _ = state
    return x
  return init, update, get_params


@optimizer
def sm3(step_size, momentum=0.9):
  """Construct optimizer triple for SM3.

  Memory-Efficient Adaptive Optimization for Large-Scale Learning.
  https://arxiv.org/abs/1901.11150

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to a positive scalar.
    momentum: optional, a positive scalar value for momentum.

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  step_size = make_schedule(step_size)

  def splice(seq, i, x):
    # Replace the single element seq[i] with the elements of x.
    lst = list(seq)
    lst[i:i+1] = x
    return lst

  def broadcast_into(ndim, x, axis):
    # Reshape the 1-D accumulator x so it broadcasts along `axis` of an
    # ndim-dimensional array.
    idx = splice([None] * ndim, axis, [slice(None)])
    return x[tuple(idx)]

  def init(x0):
    x_shape = x0.shape
    x0 = jnp.atleast_1d(x0)
    vs = [jnp.zeros(sz, dtype=x0.dtype) for sz in x0.shape]
    return x0, jnp.zeros_like(x0), vs, x_shape

  def update(i, g, state):
    x, m, vs, x_shape = state
    vs = [broadcast_into(g.ndim, v, j) for j, v in enumerate(vs)]
    accum = functools.reduce(jnp.minimum, vs) + jnp.square(g)
    accum_inv_sqrt = jnp.where(accum > 0, 1. / jnp.sqrt(accum), 0)
    m = (1. - momentum) * (g * accum_inv_sqrt) + momentum * m
    x = x - step_size(i) * m
    vs = [accum.max(splice(range(x.ndim), j, [])) for j in range(x.ndim)]
    return x, m, vs, x_shape

  def get_params(state):
    x, _, _, x_shape = state
    return x.reshape(x_shape)

  return init, update, get_params


### learning rate schedules

def constant(step_size) -> Schedule:
  """Constant step size schedule."""
  def schedule(i):
    return step_size
  return schedule

def exponential_decay(step_size, decay_steps, decay_rate):
  """Exponential decay of the step size by `decay_rate` every `decay_steps` steps."""
  def schedule(i):
    return step_size * decay_rate ** (i / decay_steps)
  return schedule

def inverse_time_decay(step_size, decay_steps, decay_rate, staircase=False):
  """Inverse-time decay of the step size, optionally staircased."""
  if staircase:
    def schedule(i):
      return step_size / (1 + decay_rate * jnp.floor(i / decay_steps))
  else:
    def schedule(i):
      return step_size / (1 + decay_rate * i / decay_steps)
  return schedule

def polynomial_decay(step_size, decay_steps, final_step_size, power=1.0):
  """Polynomial decay from `step_size` to `final_step_size` over `decay_steps` steps."""
  def schedule(step_num):
    step_num = jnp.minimum(step_num, decay_steps)
    step_mult = (1 - step_num / decay_steps) ** power
    return step_mult * (step_size - final_step_size) + final_step_size

  return schedule

def piecewise_constant(boundaries: Any, values: Any):
  """Piecewise-constant schedule defined by `boundaries` and `values`."""
  boundaries = jnp.array(boundaries)
  values = jnp.array(values)
  if not boundaries.ndim == values.ndim == 1:
    raise ValueError("boundaries and values must be sequences")
  if not boundaries.shape[0] == values.shape[0] - 1:
    raise ValueError("boundaries length must be one shorter than values length")

  def schedule(i):
    return values[jnp.sum(i > boundaries)]
  return schedule
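
# For example (illustrative): piecewise_constant([100, 200], [1.0, 0.5, 0.1])
# returns 1.0 for steps up to 100, 0.5 for steps 101-200, and 0.1 afterwards.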

def make_schedule(scalar_or_schedule: Union[float, Schedule]) -> Schedule:
  """Coerce a scalar step size into a constant Schedule; pass callables through."""
  if callable(scalar_or_schedule):
    return scalar_or_schedule
  elif jnp.ndim(scalar_or_schedule) == 0:
    return constant(scalar_or_schedule)
  else:
    raise TypeError(type(scalar_or_schedule))


### utilities

def l2_norm(tree):
  """Compute the l2 norm of a pytree of arrays. Useful for weight decay."""
  leaves, _ = tree_flatten(tree)
  return jnp.sqrt(sum(jnp.vdot(x, x) for x in leaves))

def clip_grads(grad_tree, max_norm):
  """Clip gradients stored as a pytree of arrays to maximum norm `max_norm`."""
  norm = l2_norm(grad_tree)
  normalize = lambda g: jnp.where(norm < max_norm, g, g * (max_norm / norm))
  return tree_map(normalize, grad_tree)


### serialization utilities

class JoinPoint:
  """Marks the boundary between two joined (nested) pytrees."""
  def __init__(self, subtree):
    self.subtree = subtree

  # Since pytrees are containers of numpy arrays, make JoinPoint look iterable.
  def __iter__(self):
    yield self.subtree

def unpack_optimizer_state(opt_state):
  """Converts an OptimizerState to a marked pytree.

  Converts an OptimizerState to a marked pytree with the leaves of the outer
  pytree represented as JoinPoints to avoid losing information. This function is
  intended to be useful when serializing optimizer states.

  Args:
    opt_state: An OptimizerState.
  Returns:
    A pytree with JoinPoint leaves that contain a second level of pytrees.
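
  Example (a sketch; ``opt_state`` is assumed to come from an ``opt_init``
  defined in this module, and any serializer that handles nested containers of
  arrays, e.g. ``pickle``, could be used):

  .. code-block:: python

    import pickle

    serialized = pickle.dumps(unpack_optimizer_state(opt_state))
    restored = pack_optimizer_state(pickle.loads(serialized))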
  """
  states_flat, tree_def, subtree_defs = opt_state
  subtrees = map(tree_unflatten, subtree_defs, states_flat)
  sentinels = [JoinPoint(subtree) for subtree in subtrees]
  return tree_util.tree_unflatten(tree_def, sentinels)

def pack_optimizer_state(marked_pytree):
  """Converts a marked pytree to an OptimizerState.

  The inverse of unpack_optimizer_state. Converts a marked pytree with the
  leaves of the outer pytree represented as JoinPoints back into an
  OptimizerState. This function is intended to be useful when deserializing
  optimizer states.

  Args:
    marked_pytree: A pytree containing JoinPoint leaves that hold more pytrees.
  Returns:
    An equivalent OptimizerState to the input argument.
  """
  sentinels, tree_def = tree_flatten(marked_pytree)
  assert all(isinstance(s, JoinPoint) for s in sentinels)
  subtrees = [s.subtree for s in sentinels]
  states_flat, subtree_defs = unzip2(map(tree_flatten, subtrees))
  return OptimizerState(states_flat, tree_def, subtree_defs)