1715 lines
67 KiB
Python
1715 lines
67 KiB
Python
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
"""Version 2 of class Optimizer."""
|
||
|
|
||
|
|
||
|
import abc
|
||
|
import contextlib
|
||
|
import functools
|
||
|
import warnings
|
||
|
|
||
|
import tensorflow.compat.v2 as tf
|
||
|
|
||
|
from keras import backend
|
||
|
from keras import initializers
|
||
|
from keras.engine import base_layer_utils
|
||
|
from keras.optimizers import utils as optimizer_utils
|
||
|
from keras.optimizers.schedules import learning_rate_schedule
|
||
|
from keras.utils import generic_utils
|
||
|
from keras.utils import layer_utils
|
||
|
from keras.utils import tf_inspect
|
||
|
from keras.utils import tf_utils
|
||
|
|
||
|
# isort: off
|
||
|
from tensorflow.python.util.tf_export import keras_export
|
||
|
|
||
|
# Gauge metric used to instrument which optimizer classes are instantiated;
# one cell per optimizer class name (set in `OptimizerV2.__init__`).
keras_optimizers_gauge = tf.__internal__.monitoring.BoolGauge(
    "/tensorflow/api/keras/optimizers", "keras optimizer usage", "method"
)

# Dtypes accepted for trainable variables handed to the optimizer.
# NOTE(review): presumably consumed by `_assert_valid_dtypes` (defined later
# in this file, outside this view) — confirm before relying on this.
_DEFAULT_VALID_DTYPES = frozenset(
    [
        tf.float16,
        tf.bfloat16,
        tf.float32,
        tf.float64,
        tf.complex64,
        tf.complex128,
    ]
)
|
||
|
|
||
|
|
||
|
def _deduplicate_indexed_slices(values, indices):
    """Merge rows of `values` whose entries in `indices` coincide.

    Args:
      values: A `Tensor` with rank >= 1.
      indices: A one-dimensional integer `Tensor` indexing into the first
        dimension of `values` (as in an `IndexedSlices` object).

    Returns:
      A `(summed_values, unique_indices)` tuple where `unique_indices`
      contains each distinct index exactly once and `summed_values` holds
      the sum of the `values` rows that mapped to that index.
    """
    unique, positions = tf.unique(indices)
    num_segments = tf.shape(unique)[0]
    summed = tf.math.unsorted_segment_sum(values, positions, num_segments)
    return (summed, unique)
|
||
|
|
||
|
|
||
|
class NullContextmanager:
    """A do-nothing context manager that accepts and ignores any args."""

    def __init__(self, *args, **kwargs):
        del args, kwargs  # Unused; accepted for interface compatibility.

    def __enter__(self):
        return None

    def __exit__(self, type_arg, value_arg, traceback_arg):
        # A falsy return value lets any in-flight exception propagate.
        return False
|
||
|
|
||
|
|
||
|
def name_scope_only_in_function_or_graph(name):
    """Internal-only entry point for `name_scope*`.

    Enters a name scope only while tracing a function or building a graph;
    when running fully eagerly it returns a no-op context manager instead.

    Args:
      name: The name argument that is passed to the op function.

    Returns:
      `name_scope*` context manager.
    """
    if tf.executing_eagerly():
        return NullContextmanager()
    return tf.name_scope(name)
|
||
|
|
||
|
|
||
|
@keras_export(
|
||
|
"keras.optimizers.legacy.Optimizer",
|
||
|
v1=["keras.optimizers.Optimizer", "keras.optimizers.legacy.Optimizer"],
|
||
|
)
|
||
|
class OptimizerV2(tf.__internal__.tracking.Trackable):
|
||
|
"""Base class for legacy Keras optimizers.
|
||
|
|
||
|
You should not use this class directly, but instead instantiate one of its
|
||
|
subclasses such as `tf.keras.optimizers.legacy.SGD`,
|
||
|
`tf.keras.optimizers.legacy.Adam`, etc.
|
||
|
|
||
|
This is the default Keras optimizer base class until v2.10 (included).
|
||
|
In v2.11 and later, `tf.keras.optimizers.Optimizer`
|
||
|
points to a new base class implementation. The legacy class won't be
|
||
|
deleted in the future and will continue to be available at
|
||
|
`tf.keras.optimizers.legacy.Optimizer`.
|
||
|
|
||
|
### Usage
|
||
|
|
||
|
```python
|
||
|
# Create an optimizer with the desired parameters.
|
||
|
opt = tf.keras.optimizers.legacy.SGD(learning_rate=0.1)
|
||
|
# `loss` is a callable that takes no argument and returns the value
|
||
|
# to minimize.
|
||
|
var1 = tf.Variable(2.0)
|
||
|
var2 = tf.Variable(5.0)
|
||
|
loss = lambda: 3 * var1 * var1 + 2 * var2 * var2
|
||
|
# In graph mode, returns op that minimizes the loss by updating the listed
|
||
|
# variables.
|
||
|
opt_op = opt.minimize(loss, var_list=[var1, var2])
|
||
|
opt_op.run()
|
||
|
# In eager mode, simply call minimize to update the list of variables.
|
||
|
opt.minimize(loss, var_list=[var1, var2])
|
||
|
```
|
||
|
|
||
|
### Usage in custom training loops
|
||
|
|
||
|
In Keras models, sometimes variables are created when the model is first
|
||
|
called, instead of construction time. Examples include 1) sequential models
|
||
|
without input shape pre-defined, or 2) subclassed models. Pass var_list as
|
||
|
callable in these cases.
|
||
|
|
||
|
Example:
|
||
|
|
||
|
```python
|
||
|
opt = tf.keras.optimizers.legacy.SGD(learning_rate=0.1)
|
||
|
model = tf.keras.Sequential()
|
||
|
model.add(tf.keras.layers.Dense(num_hidden, activation='relu'))
|
||
|
model.add(tf.keras.layers.Dense(num_classes, activation='sigmoid'))
|
||
|
loss_fn = lambda: tf.keras.losses.mse(model(input), output)
|
||
|
var_list_fn = lambda: model.trainable_weights
|
||
|
for input, output in data:
|
||
|
opt.minimize(loss_fn, var_list_fn)
|
||
|
```
|
||
|
|
||
|
### Processing gradients before applying them
|
||
|
|
||
|
Calling `minimize()` takes care of both computing the gradients and
|
||
|
applying them to the variables. If you want to process the gradients
|
||
|
before applying them you can instead use the optimizer in three steps:
|
||
|
|
||
|
1. Compute the gradients with `tf.GradientTape`.
|
||
|
2. Process the gradients as you wish.
|
||
|
3. Apply the processed gradients with `apply_gradients()`.
|
||
|
|
||
|
Example:
|
||
|
|
||
|
```python
|
||
|
# Create an optimizer.
|
||
|
opt = tf.keras.optimizers.legacy.SGD(learning_rate=0.1)
|
||
|
|
||
|
# Compute the gradients for a list of variables.
|
||
|
with tf.GradientTape() as tape:
|
||
|
loss = <call_loss_function>
|
||
|
vars = <list_of_variables>
|
||
|
grads = tape.gradient(loss, vars)
|
||
|
|
||
|
# Process the gradients, for example cap them, etc.
|
||
|
# capped_grads = [MyCapper(g) for g in grads]
|
||
|
processed_grads = [process_gradient(g) for g in grads]
|
||
|
|
||
|
# Ask the optimizer to apply the processed gradients.
|
||
|
opt.apply_gradients(zip(processed_grads, var_list))
|
||
|
```
|
||
|
|
||
|
### Use with `tf.distribute.Strategy`
|
||
|
|
||
|
This optimizer class is `tf.distribute.Strategy` aware, which means it
|
||
|
automatically sums gradients across all replicas. To average gradients,
|
||
|
you divide your loss by the global batch size, which is done
|
||
|
automatically if you use `tf.keras` built-in training or evaluation loops.
|
||
|
See the `reduction` argument of your loss which should be set to
|
||
|
`tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` for averaging or
|
||
|
`tf.keras.losses.Reduction.SUM` for not.
|
||
|
|
||
|
To aggregate gradients yourself, call `apply_gradients` with
|
||
|
`experimental_aggregate_gradients` set to False. This is useful if you need
|
||
|
to process aggregated gradients.
|
||
|
|
||
|
If you are not using these and you want to average gradients, you should use
|
||
|
`tf.math.reduce_sum` to add up your per-example losses and then divide by
|
||
|
the global batch size. Note that when using `tf.distribute.Strategy`, the
|
||
|
first component of a tensor's shape is the *replica-local* batch size, which
|
||
|
is off by a factor equal to the number of replicas being used to compute a
|
||
|
single step. As a result, using `tf.math.reduce_mean` will give the wrong
|
||
|
answer, resulting in gradients that can be many times too big.
|
||
|
|
||
|
### Variable Constraints
|
||
|
|
||
|
All Keras optimizers respect variable constraints. If constraint function is
|
||
|
passed to any variable, the constraint will be applied to the variable after
|
||
|
the gradient has been applied to the variable.
|
||
|
Important: If gradient is sparse tensor, variable constraint is not
|
||
|
supported.
|
||
|
|
||
|
### Thread Compatibility
|
||
|
|
||
|
The entire optimizer is currently thread compatible, not thread-safe. The
|
||
|
user needs to perform synchronization if necessary.
|
||
|
|
||
|
### Slots
|
||
|
|
||
|
Many optimizer subclasses, such as `Adam` and `Adagrad` allocate and manage
|
||
|
additional variables associated with the variables to train. These are
|
||
|
called <i>Slots</i>. Slots have names and you can ask the optimizer for the
|
||
|
names of the slots that it uses. Once you have a slot name you can ask the
|
||
|
optimizer for the variable it created to hold the slot value.
|
||
|
|
||
|
This can be useful if you want to log debug a training algorithm, report
|
||
|
stats about the slots, etc.
|
||
|
|
||
|
### Hyperparameters
|
||
|
|
||
|
These are arguments passed to the optimizer subclass constructor
|
||
|
(the `__init__` method), and then passed to `self._set_hyper()`.
|
||
|
They can be either regular Python values (like 1.0), tensors, or
|
||
|
callables. If they are callable, the callable will be called during
|
||
|
`apply_gradients()` to get the value for the hyper parameter.
|
||
|
|
||
|
Hyperparameters can be overwritten through user code:
|
||
|
|
||
|
Example:
|
||
|
|
||
|
```python
|
||
|
# Create an optimizer with the desired parameters.
|
||
|
opt = tf.keras.optimizers.legacy.SGD(learning_rate=0.1)
|
||
|
# `loss` is a callable that takes no argument and returns the value
|
||
|
# to minimize.
|
||
|
loss = lambda: 3 * var1 + 2 * var2
|
||
|
# In eager mode, simply call minimize to update the list of variables.
|
||
|
opt.minimize(loss, var_list=[var1, var2])
|
||
|
# update learning rate
|
||
|
opt.learning_rate = 0.05
|
||
|
opt.minimize(loss, var_list=[var1, var2])
|
||
|
```
|
||
|
|
||
|
### Callable learning rate
|
||
|
|
||
|
Optimizer accepts a callable learning rate in two ways. The first way is
|
||
|
through built-in or customized
|
||
|
`tf.keras.optimizers.schedules.LearningRateSchedule`. The schedule will be
|
||
|
called on each iteration with `schedule(iteration)`, a `tf.Variable`
|
||
|
owned by the optimizer.
|
||
|
|
||
|
Example:
|
||
|
|
||
|
>>> var = tf.Variable(np.random.random(size=(1,)))
|
||
|
>>> learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
|
||
|
... initial_learning_rate=.01, decay_steps=20, decay_rate=.1)
|
||
|
>>> opt = tf.keras.optimizers.legacy.SGD(learning_rate=learning_rate)
|
||
|
>>> loss = lambda: 3 * var
|
||
|
>>> opt.minimize(loss, var_list=[var])
|
||
|
<tf.Variable...
|
||
|
|
||
|
The second way is through a callable function that
|
||
|
does not accept any arguments.
|
||
|
|
||
|
Example:
|
||
|
|
||
|
>>> var = tf.Variable(np.random.random(size=(1,)))
|
||
|
>>> def lr_callable():
|
||
|
... return .1
|
||
|
>>> opt = tf.keras.optimizers.legacy.SGD(learning_rate=lr_callable)
|
||
|
>>> loss = lambda: 3 * var
|
||
|
>>> opt.minimize(loss, var_list=[var])
|
||
|
<tf.Variable...
|
||
|
|
||
|
### Creating a custom optimizer
|
||
|
|
||
|
If you intend to create your own optimization algorithm, simply inherit from
|
||
|
this class and override the following methods:
|
||
|
|
||
|
- `_resource_apply_dense` (update variable given gradient tensor is a
|
||
|
dense `tf.Tensor`)
|
||
|
- `_resource_apply_sparse` (update variable given gradient tensor is a
|
||
|
sparse `tf.IndexedSlices`. The most common way for this to happen
|
||
|
is if you are taking the gradient through a `tf.gather`.)
|
||
|
- `_create_slots`
|
||
|
(if your optimizer algorithm requires additional variables)
|
||
|
- `get_config`
|
||
|
(serialization of the optimizer, include all hyper parameters)
|
||
|
"""
|
||
|
|
||
|
# Subclasses should set this to True unless they override `apply_gradients`
|
||
|
# with a version that does not have the `experimental_aggregate_gradients`
|
||
|
# argument. Older versions of Keras did not have this argument so custom
|
||
|
# optimizers may have overridden `apply_gradients` without the
|
||
|
# `experimental_aggregate_gradients` argument. Keras only passes
|
||
|
# `experimental_aggregate_gradients` if this attribute is True.
|
||
|
# Note: This attribute will likely be removed in an upcoming release.
|
||
|
_HAS_AGGREGATE_GRAD = False
|
||
|
|
||
|
    def __init__(
        self,
        name,
        gradient_aggregator=None,
        gradient_transformers=None,
        **kwargs,
    ):
        """Create a new Optimizer.

        This must be called by the constructors of subclasses.
        Note that Optimizer instances should not bind to a single graph,
        and so shouldn't keep Tensors as member variables. Generally
        you should be able to use the _set_hyper()/state.get_hyper()
        facility instead.

        This class is stateful and thread-compatible.

        Example of custom gradient transformations:

        ```python
        def my_gradient_transformer(grads_and_vars):
            # Simple example, double the gradients.
            return [(2. * g, v) for g, v in grads_and_vars]

        optimizer = tf.keras.optimizers.legacy.SGD(
            1e-3, gradient_transformers=[my_gradient_transformer])
        ```

        Args:
          name: String. The name to use for momentum accumulator weights
            created by the optimizer.
          gradient_aggregator: The function to use to aggregate gradients
            across devices (when using `tf.distribute.Strategy`). If `None`,
            defaults to summing the gradients across devices. The function
            should accept and return a list of `(gradient, variable)` tuples.
          gradient_transformers: Optional. List of functions to use to
            transform gradients before applying updates to Variables. The
            functions are applied after `gradient_aggregator`. The functions
            should accept and return a list of `(gradient, variable)` tuples.
          **kwargs: keyword arguments. Allowed arguments are `clipvalue`,
            `clipnorm`, `global_clipnorm`.
            If `clipvalue` (float) is set, the gradient of each weight
            is clipped to be no higher than this value.
            If `clipnorm` (float) is set, the gradient of each weight
            is individually clipped so that its norm is no higher than this
            value. If `global_clipnorm` (float) is set the gradient of all
            weights is clipped so that their global norm is no higher than
            this value. The deprecated `lr` and `decay` arguments are also
            accepted.

        Raises:
          ValueError: in case of any invalid argument.
        """
        # Instrument optimizer usages (one gauge cell per subclass name).
        keras_optimizers_gauge.get_cell(self.__class__.__name__).set(True)

        allowed_kwargs = {
            "clipnorm",
            "clipvalue",
            "lr",
            "decay",
            "global_clipnorm",
        }
        for k in kwargs:
            if k not in allowed_kwargs:
                raise TypeError(
                    "Unexpected keyword argument "
                    f"passed to optimizer: {str(k)}. Allowed kwargs are "
                    f"{allowed_kwargs}."
                )
            # checks that all keyword arguments are non-negative.
            if kwargs[k] is not None and kwargs[k] < 0:
                raise ValueError(f"Expected {k} >= 0, received: {kwargs[k]}")
            if k == "lr":
                warnings.warn(
                    "The `lr` argument is deprecated, "
                    "use `learning_rate` instead.",
                    stacklevel=2,
                )

        # NOTE(review): presumably forwarded to resource-apply ops to request
        # locking updates — the consumer is outside this view; confirm.
        self._use_locking = True
        self._init_set_name(name)
        # Hyperparameters set via `_set_hyper` (floats, tensors, callables,
        # or `LearningRateSchedule`s).
        self._hyper = {}
        # dict: {variable name : {slot name : variable}}
        self._slots = {}
        self._slot_names = []
        self._weights = []
        # The step counter variable; created lazily (not here).
        self._iterations = None

        # For implementing Trackable. Stores information about how to restore
        # slot variables which have not yet been created
        # (trackable._CheckpointPosition objects).
        # {slot_name :
        #     {_var_key(variable_to_train): [checkpoint_position, ... ], ... },
        #  ... }
        self._deferred_slot_restorations = {}

        # `decay` was already validated non-negative above when provided;
        # this re-check also covers the 0.0 default path.
        decay = kwargs.pop("decay", 0.0)
        if decay < 0.0:
            raise ValueError(
                f"decay cannot be less than 0. Received: decay={decay}."
            )
        self._initial_decay = decay

        self._hypers_created = False
        # Store the distribution strategy object if the optimizer is created
        # inside strategy scope, so it could be used to create variables
        # later.
        if tf.distribute.has_strategy():
            self._distribution_strategy = tf.distribute.get_strategy()
        else:
            self._distribution_strategy = None

        # Configure gradient transformations. Note: `gradient_transformers`
        # must be assigned before the clip* property setters below, because
        # those setters reject clipping when transformers are present.
        if gradient_aggregator is None:
            gradient_aggregator = optimizer_utils.all_reduce_sum_gradients
        self.gradient_aggregator = gradient_aggregator
        if gradient_transformers is None:
            gradient_transformers = []
        self.gradient_transformers = gradient_transformers
        self.clipnorm = kwargs.pop("clipnorm", None)
        self.global_clipnorm = kwargs.pop("global_clipnorm", None)
        if self.clipnorm is not None and self.global_clipnorm is not None:
            raise ValueError(
                "Cannot accept both `clipnorm` and `global_clipnorm`. "
                "Received: `clipnorm`={}, `global_clipnorm`={}.".format(
                    self.clipnorm, self.global_clipnorm
                )
            )
        self.clipvalue = kwargs.pop("clipvalue", None)
|
||
|
|
||
|
    @property
    def clipnorm(self):
        """`float` or `None`. If set, clips gradients to a maximum norm.

        Applied per gradient in `_transform_gradients` when not `None`.
        """
        return self._clipnorm
|
||
|
|
||
|
    @property
    def global_clipnorm(self):
        """`float` or `None`.

        If set, clips gradients to a maximum norm.

        Applied across all gradients jointly in `_transform_gradients`.
        Check `tf.clip_by_global_norm` for more details.
        """
        return self._global_clipnorm
|
||
|
|
||
|
@clipnorm.setter
|
||
|
def clipnorm(self, val):
|
||
|
if val is not None and self.gradient_transformers:
|
||
|
raise ValueError(
|
||
|
"`clipnorm` cannot be set when `gradient_transformers` "
|
||
|
"is set. Instead, use the `gradient_transformers` to "
|
||
|
"specify clipping and other transformations. Received: "
|
||
|
f"val={val}, "
|
||
|
f"gradient_transformers={self.gradient_transformers}."
|
||
|
)
|
||
|
self._clipnorm = val
|
||
|
self._clipnorm_fn = optimizer_utils.make_gradient_clipnorm_fn(
|
||
|
self._clipnorm
|
||
|
)
|
||
|
|
||
|
@global_clipnorm.setter
|
||
|
def global_clipnorm(self, val):
|
||
|
if val is not None and self.gradient_transformers:
|
||
|
raise ValueError(
|
||
|
"`global_clipnorm` cannot be set when "
|
||
|
"`gradient_transformers` "
|
||
|
"is set. Instead, use the `gradient_transformers` to "
|
||
|
"specify clipping and other transformations. Received: "
|
||
|
f"val={val}, "
|
||
|
f"gradient_transformers={self.gradient_transformers}."
|
||
|
)
|
||
|
self._global_clipnorm = val
|
||
|
self._global_clipnorm_fn = (
|
||
|
optimizer_utils.make_global_gradient_clipnorm_fn(
|
||
|
self._global_clipnorm
|
||
|
)
|
||
|
)
|
||
|
|
||
|
    @property
    def clipvalue(self):
        """`float` or `None`. If set, clips gradients to a maximum value.

        Applied element-wise in `_transform_gradients` when not `None`.
        """
        return self._clipvalue
|
||
|
|
||
|
@clipvalue.setter
|
||
|
def clipvalue(self, val):
|
||
|
if val is not None and self.gradient_transformers:
|
||
|
raise ValueError(
|
||
|
"`clipvalue` cannot be set when `gradient_transformers` "
|
||
|
"is set. Instead, use the `gradient_transformers` to "
|
||
|
"specify clipping and other transformations. Received: "
|
||
|
f"val={val}, "
|
||
|
f"gradient_transformers={self.gradient_transformers}."
|
||
|
)
|
||
|
self._clipvalue = val
|
||
|
self._clipvalue_fn = optimizer_utils.make_gradient_clipvalue_fn(
|
||
|
self._clipvalue
|
||
|
)
|
||
|
|
||
|
def _transform_loss(self, loss):
|
||
|
"""Called in `.minimize` to transform loss before computing
|
||
|
gradients."""
|
||
|
return loss
|
||
|
|
||
|
def _get_gradients(self, tape, loss, var_list, grad_loss=None):
|
||
|
"""Called in `minimize` to compute gradients from loss."""
|
||
|
grads = tape.gradient(loss, var_list, grad_loss)
|
||
|
return list(zip(grads, var_list))
|
||
|
|
||
|
def _transform_unaggregated_gradients(self, grads_and_vars):
|
||
|
"""Called in `apply_gradients` before gradient aggregation."""
|
||
|
return grads_and_vars
|
||
|
|
||
|
def _aggregate_gradients(self, grads_and_vars):
|
||
|
"""Called in `apply_gradients` to aggregate gradients across devices.
|
||
|
|
||
|
Note that user subclasses may override this, so the interface should not
|
||
|
be changed.
|
||
|
|
||
|
Args:
|
||
|
grads_and_vars: List of (gradient, variable) pairs.
|
||
|
|
||
|
Returns:
|
||
|
A list of (aggregrated_gradient, variable) pairs. By default, this
|
||
|
calls `self.gradient_aggregator`.
|
||
|
"""
|
||
|
return self.gradient_aggregator(grads_and_vars)
|
||
|
|
||
|
def _transform_gradients(self, grads_and_vars):
|
||
|
"""Called in `apply_gradients` after aggregation."""
|
||
|
if self._clipvalue is not None:
|
||
|
grads_and_vars = self._clipvalue_fn(grads_and_vars)
|
||
|
if self._clipnorm is not None:
|
||
|
grads_and_vars = self._clipnorm_fn(grads_and_vars)
|
||
|
if self._global_clipnorm is not None:
|
||
|
grads_and_vars = self._global_clipnorm_fn(grads_and_vars)
|
||
|
|
||
|
for fn in self.gradient_transformers:
|
||
|
grads_and_vars = fn(grads_and_vars)
|
||
|
return grads_and_vars
|
||
|
|
||
|
def minimize(self, loss, var_list, grad_loss=None, name=None, tape=None):
|
||
|
"""Minimize `loss` by updating `var_list`.
|
||
|
|
||
|
This method simply computes gradient using `tf.GradientTape` and calls
|
||
|
`apply_gradients()`. If you want to process the gradient before applying
|
||
|
then call `tf.GradientTape` and `apply_gradients()` explicitly instead
|
||
|
of using this function.
|
||
|
|
||
|
Args:
|
||
|
loss: `Tensor` or callable. If a callable, `loss` should take no
|
||
|
arguments and return the value to minimize. If a `Tensor`, the
|
||
|
`tape` argument must be passed.
|
||
|
var_list: list or tuple of `Variable` objects to update to minimize
|
||
|
`loss`, or a callable returning the list or tuple of `Variable`
|
||
|
objects. Use callable when the variable list would otherwise be
|
||
|
incomplete before `minimize` since the variables are created at the
|
||
|
first time `loss` is called.
|
||
|
grad_loss: (Optional). A `Tensor` holding the gradient computed for
|
||
|
`loss`.
|
||
|
name: (Optional) str. Name for the returned operation.
|
||
|
tape: (Optional) `tf.GradientTape`. If `loss` is provided as a
|
||
|
`Tensor`, the tape that computed the `loss` must be provided.
|
||
|
|
||
|
Returns:
|
||
|
An `Operation` that updates the variables in `var_list`. The
|
||
|
`iterations` will be automatically increased by 1.
|
||
|
|
||
|
Raises:
|
||
|
ValueError: If some of the variables are not `Variable` objects.
|
||
|
|
||
|
"""
|
||
|
grads_and_vars = self._compute_gradients(
|
||
|
loss, var_list=var_list, grad_loss=grad_loss, tape=tape
|
||
|
)
|
||
|
return self.apply_gradients(grads_and_vars, name=name)
|
||
|
|
||
|
    def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
        """Compute gradients of `loss` for the variables in `var_list`.

        This is the first part of `minimize()`. It returns a list
        of (gradient, variable) pairs where "gradient" is the gradient
        for "variable". Note that "gradient" can be a `Tensor`, an
        `IndexedSlices`, or `None` if there is no gradient for the
        given variable.

        Args:
          loss: `Tensor` or callable. If a callable, `loss` should take no
            arguments and return the value to minimize. If a `Tensor`, the
            `tape` argument must be passed.
          var_list: list or tuple of `Variable` objects to update to minimize
            `loss`, or a callable returning the list or tuple of `Variable`
            objects. Use callable when the variable list would otherwise be
            incomplete before `minimize` and the variables are created at the
            first time when `loss` is called.
          grad_loss: Optional. A `Tensor` holding the gradient computed for
            `loss`.
          tape: (Optional) `tf.GradientTape`. If `loss` is provided as a
            `Tensor`, the tape that computed the `loss` must be provided.

        Returns:
          A list of (gradient, variable) pairs. Variable is always present,
          but gradient can be `None`.

        Raises:
          TypeError: If `var_list` contains anything else than `Variable`
            objects.
          ValueError: If some arguments are invalid, or var_list is None.
        """
        # TODO(joshl): Test that we handle weight decay in a reasonable way.
        # A Tensor loss was already evaluated, so the caller's tape is the
        # only record of how it was computed.
        if not callable(loss) and tape is None:
            raise ValueError(
                "`tape` is required when a `Tensor` loss is passed. "
                f"Received: loss={loss}, tape={tape}."
            )
        tape = tape if tape is not None else tf.GradientTape()

        if callable(loss):
            with tape:
                # A concrete var_list can be watched up front; a callable
                # var_list may create its variables only while `loss()` runs,
                # so it must be resolved afterwards (still under the tape).
                if not callable(var_list):
                    tape.watch(var_list)
                loss = loss()
                if callable(var_list):
                    var_list = var_list()

        # Re-enter the tape so any loss transformation is also recorded.
        with tape:
            loss = self._transform_loss(loss)

        var_list = tf.nest.flatten(var_list)
        with tf.name_scope(self._name + "/gradients"):
            grads_and_vars = self._get_gradients(
                tape, loss, var_list, grad_loss
            )

        # Validate dtypes only for variables that actually received a
        # gradient (resource-dtype handles are exempt).
        self._assert_valid_dtypes(
            [
                v
                for g, v in grads_and_vars
                if g is not None and v.dtype != tf.resource
            ]
        )

        return grads_and_vars
|
||
|
|
||
|
def apply_gradients(
|
||
|
self, grads_and_vars, name=None, experimental_aggregate_gradients=True
|
||
|
):
|
||
|
"""Apply gradients to variables.
|
||
|
|
||
|
This is the second part of `minimize()`. It returns an `Operation` that
|
||
|
applies gradients.
|
||
|
|
||
|
The method sums gradients from all replicas in the presence of
|
||
|
`tf.distribute.Strategy` by default. You can aggregate gradients
|
||
|
yourself by passing `experimental_aggregate_gradients=False`.
|
||
|
|
||
|
Example:
|
||
|
|
||
|
```python
|
||
|
grads = tape.gradient(loss, vars)
|
||
|
grads = tf.distribute.get_replica_context().all_reduce('sum', grads)
|
||
|
# Processing aggregated gradients.
|
||
|
optimizer.apply_gradients(zip(grads, vars),
|
||
|
experimental_aggregate_gradients=False)
|
||
|
|
||
|
```
|
||
|
|
||
|
Args:
|
||
|
grads_and_vars: List of (gradient, variable) pairs.
|
||
|
name: Optional name for the returned operation. Default to the name
|
||
|
passed to the `Optimizer` constructor.
|
||
|
experimental_aggregate_gradients: Whether to sum gradients from
|
||
|
different replicas in the presence of `tf.distribute.Strategy`. If
|
||
|
False, it's user responsibility to aggregate the gradients. Default
|
||
|
to True.
|
||
|
|
||
|
Returns:
|
||
|
An `Operation` that applies the specified gradients. The `iterations`
|
||
|
will be automatically increased by 1.
|
||
|
|
||
|
Raises:
|
||
|
TypeError: If `grads_and_vars` is malformed.
|
||
|
ValueError: If none of the variables have gradients.
|
||
|
RuntimeError: If called in a cross-replica context.
|
||
|
"""
|
||
|
grads_and_vars = optimizer_utils.filter_empty_gradients(grads_and_vars)
|
||
|
var_list = [v for (_, v) in grads_and_vars]
|
||
|
|
||
|
with tf.name_scope(self._name):
|
||
|
# Create iteration if necessary.
|
||
|
with tf.init_scope():
|
||
|
self._create_all_weights(var_list)
|
||
|
|
||
|
if not grads_and_vars:
|
||
|
# Distribution strategy does not support reducing an empty list
|
||
|
# of gradients
|
||
|
return tf.no_op()
|
||
|
|
||
|
if tf.distribute.in_cross_replica_context():
|
||
|
raise RuntimeError(
|
||
|
"`apply_gradients() cannot be called in cross-replica "
|
||
|
"context. Use `tf.distribute.Strategy.run` to enter "
|
||
|
"replica context. For more information, please see the "
|
||
|
"docstring of `tf.distribute.get_replica_context`."
|
||
|
)
|
||
|
|
||
|
strategy = tf.distribute.get_strategy()
|
||
|
if (
|
||
|
not experimental_aggregate_gradients
|
||
|
and strategy
|
||
|
and isinstance(
|
||
|
strategy,
|
||
|
(
|
||
|
tf.compat.v1.distribute.experimental.ParameterServerStrategy, # noqa: E501
|
||
|
tf.distribute.experimental.ParameterServerStrategy,
|
||
|
tf.distribute.experimental.CentralStorageStrategy,
|
||
|
tf.compat.v1.distribute.experimental.CentralStorageStrategy, # noqa: E501
|
||
|
),
|
||
|
)
|
||
|
):
|
||
|
raise NotImplementedError(
|
||
|
"`experimental_aggregate_gradients=False is not supported "
|
||
|
"for ParameterServerStrategy and CentralStorageStrategy. "
|
||
|
f"Used: strategy={strategy}."
|
||
|
)
|
||
|
|
||
|
apply_state = self._prepare(var_list)
|
||
|
if experimental_aggregate_gradients:
|
||
|
grads_and_vars = self._transform_unaggregated_gradients(
|
||
|
grads_and_vars
|
||
|
)
|
||
|
grads_and_vars = self._aggregate_gradients(grads_and_vars)
|
||
|
grads_and_vars = self._transform_gradients(grads_and_vars)
|
||
|
|
||
|
return tf.__internal__.distribute.interim.maybe_merge_call(
|
||
|
functools.partial(
|
||
|
self._distributed_apply, apply_state=apply_state
|
||
|
),
|
||
|
strategy,
|
||
|
grads_and_vars,
|
||
|
name=name,
|
||
|
)
|
||
|
|
||
|
    def _distributed_apply(
        self, distribution, grads_and_vars, apply_state, name
    ):
        """`apply_gradients` using a `DistributionStrategy`.

        Applies each gradient to its variable via
        `distribution.extended.update`, then increments `self.iterations`
        once all updates are grouped.
        """

        def apply_grad_to_update_var(var, grad):
            """Apply gradient to variable."""
            if isinstance(var, tf.Tensor):
                raise NotImplementedError(
                    "Updating a `Tensor` is not implemented. "
                    f"Received: var={var}."
                )

            apply_kwargs = {}
            if isinstance(grad, tf.IndexedSlices):
                # Sparse path: constraints are unsupported because only a
                # subset of rows is updated.
                if var.constraint is not None:
                    raise RuntimeError(
                        "Cannot use a constraint function on a sparse "
                        f"variable. Received: grad={grad}, "
                        f"var.constraint={var.constraint}."
                    )
                # Only forward `apply_state` if the subclass's apply method
                # declares that parameter.
                if "apply_state" in self._sparse_apply_args:
                    apply_kwargs["apply_state"] = apply_state
                return self._resource_apply_sparse_duplicate_indices(
                    grad.values, var, grad.indices, **apply_kwargs
                )

            # Dense path.
            if "apply_state" in self._dense_apply_args:
                apply_kwargs["apply_state"] = apply_state
            update_op = self._resource_apply_dense(grad, var, **apply_kwargs)
            if var.constraint is not None:
                # Apply the constraint only after the gradient update has run.
                with tf.control_dependencies([update_op]):
                    return var.assign(var.constraint(var))
            else:
                return update_op

        eagerly_outside_functions = (
            tf.compat.v1.executing_eagerly_outside_functions()
        )
        update_ops = []
        with name_scope_only_in_function_or_graph(name or self._name):
            for grad, var in grads_and_vars:
                # Colocate the update with variables to avoid unnecessary
                # communication delays. See b/136304694.
                with distribution.extended.colocate_vars_with(var):
                    # `var.op` is only available in v1 graph mode, hence the
                    # scope-name split on execution mode.
                    with name_scope_only_in_function_or_graph(
                        "update"
                        if eagerly_outside_functions
                        else "update_" + var.op.name
                    ):
                        update_op = distribution.extended.update(
                            var,
                            apply_grad_to_update_var,
                            args=(grad,),
                            group=False,
                        )
                        if tf.distribute.in_cross_replica_context():
                            # In cross-replica context, extended.update
                            # returns a list of update ops from all replicas
                            # (group=False).
                            update_ops.extend(update_op)
                        else:
                            # In replica context, extended.update return the
                            # single update op of current replica.
                            update_ops.append(update_op)

            any_symbolic = any(
                isinstance(i, tf.Operation) or tf_utils.is_symbolic_tensor(i)
                for i in update_ops
            )
            if not tf.executing_eagerly() or any_symbolic:
                # If the current context is graph mode or any of the update
                # ops are symbolic then the step update should be carried out
                # under a graph context. (eager updates execute immediately)
                with backend._current_graph(update_ops).as_default():
                    # Step counter only advances after every update op has
                    # executed.
                    with tf.control_dependencies([tf.group(update_ops)]):
                        return self.iterations.assign_add(1, read_value=False)

            return self.iterations.assign_add(1)
|
||
|
|
||
|
def get_gradients(self, loss, params):
    """Returns gradients of `loss` with respect to `params`.

    Should be used only in legacy v1 graph mode.

    Args:
      loss: Loss tensor.
      params: List of variables.

    Returns:
      List of gradient tensors.

    Raises:
      ValueError: In case any gradient cannot be computed (e.g. if gradient
        function not implemented).
    """
    params = tf.nest.flatten(params)
    with backend.get_graph().as_default(), backend.name_scope(
        self._name + "/gradients"
    ):
        grads = tf.compat.v1.gradients(loss, params)
        # Fail fast if any parameter is disconnected from the loss: a
        # `None` gradient would otherwise surface as a confusing error
        # deeper in apply_gradients.
        for grad, param in zip(grads, params):
            if grad is None:
                raise ValueError(
                    "Variable {} has `None` for gradient. "
                    "Please make sure that all of your ops have a "
                    "gradient defined (i.e. are differentiable). "
                    "Common ops without gradient: "
                    "K.argmax, K.round, K.eval.".format(param)
                )
        return grads
|
||
|
|
||
|
def get_updates(self, loss, params):
    """Returns the list of update ops for legacy v1 graph training.

    Computes gradients of `loss` w.r.t. `params` and returns the op that
    applies them (wrapped in a single-element list for v1 compatibility).
    """
    grads = self.get_gradients(loss, params)
    grads_and_vars = list(zip(grads, params))
    # Only validate dtypes of variables that actually received a gradient;
    # `resource`-dtype handles are skipped because their base dtype is not
    # meaningful here.
    self._assert_valid_dtypes(
        [
            v
            for g, v in grads_and_vars
            if g is not None and v.dtype != tf.resource
        ]
    )
    return [self.apply_gradients(grads_and_vars)]
|
||
|
|
||
|
def _set_hyper(self, name, value):
    """set hyper `name` to value. value can be callable, tensor, numeric."""
    if isinstance(value, tf.__internal__.tracking.Trackable):
        # Make the hyperparameter checkpointable (e.g. LearningRateSchedule).
        self._track_trackable(value, name, overwrite=True)
    if name not in self._hyper:
        self._hyper[name] = value
    else:
        prev_value = self._hyper[name]
        # If either the stored or the new value is not a plain variable
        # (callable, tensor, python number, or schedule), replace the entry
        # outright; otherwise assign into the existing variable in place so
        # any graph references to it stay valid.
        if (
            callable(prev_value)
            or isinstance(
                prev_value,
                (
                    tf.Tensor,
                    int,
                    float,
                    learning_rate_schedule.LearningRateSchedule,
                ),
            )
            or isinstance(
                value, learning_rate_schedule.LearningRateSchedule
            )
        ):
            self._hyper[name] = value
        else:
            backend.set_value(self._hyper[name], value)
|
||
|
|
||
|
def _get_hyper(self, name, dtype=None):
    """Returns hyperparameter `name`, optionally cast to `dtype`."""
    if not self._hypers_created:
        self._create_hypers()
    value = self._hyper[name]
    if isinstance(value, learning_rate_schedule.LearningRateSchedule):
        # Schedules are returned as-is; callers invoke them with the step.
        return value
    if callable(value):
        value = value()
    if dtype:
        return tf.cast(value, dtype)
    else:
        return value
|
||
|
|
||
|
def _create_slots(self, var_list):
    """Hook for subclasses to create their slot variables; no-op here."""
    pass
|
||
|
|
||
|
def _create_slots_for_sharded_variables(self, var_list):
    """Add ShardedVariables to slots to later reconstruct for checkpointing.

    ShardedVariables don't have slot variables created for them; their
    shards do. This function allows users to call get_slot with a
    ShardedVariable input and receive a ShardedVariable output containing
    the appropriate slot vars.

    Iterate over the variables to find shards, and aggregate the sharded
    containers in a set. Add these ShardedVariables to _slots so that
    get_slot can retrieve the proper slot variables for their component
    shards, and reconstruct those into a ShardedVariable.

    Args:
      var_list: list or tuple of `Variable` objects that will be minimized
        using this optimizer.
    """
    sharded_vars = set()
    for var in var_list:
        if getattr(var, "_sharded_container", False):
            sharded_vars.add(var._sharded_container())

    for sharded_var in sharded_vars:
        sharded_key = _var_key(sharded_var)
        slot_dict = {}
        for slot in self.get_slot_names():
            # Point each slot name at the sharded container itself;
            # get_slot detects this sentinel and reconstructs a
            # ShardedVariable from the per-shard slot variables.
            slot_dict[slot] = sharded_var
        self._slots[sharded_key] = slot_dict
|
||
|
|
||
|
def _create_all_weights(self, var_list):
    """Creates all weights, including iterations, hyperparameters and slot
    vars.

    This will add newly created variables to `optimizer.weights`.

    New variables are only created when this method is called the first
    time, or when called with different variables in the var_list.

    Args:
      var_list: list or tuple of `Variable` objects that will be minimized
        using this optimizer.
    """

    # Accessing the property creates the iterations variable if needed.
    _ = self.iterations
    self._create_hypers()
    self._create_slots(var_list)
    self._create_slots_for_sharded_variables(var_list)
|
||
|
|
||
|
def __getattribute__(self, name):
    """Overridden to support hyperparameter access."""
    try:
        return super().__getattribute__(name)
    except AttributeError as e:
        # Needed to avoid infinite recursion with __setattr__.
        if name == "_hyper":
            raise e
        # Backwards compatibility with Keras optimizers.
        if name == "lr":
            name = "learning_rate"
        # Fall back to the hyperparameter dict for unknown attributes.
        if name in self._hyper:
            return self._get_hyper(name)
        raise e
|
||
|
|
||
|
def __dir__(self):
    """Include dynamically-resolved hyperparameter names in `dir()`."""
    result = set(super().__dir__())
    if "_hyper" in result:
        result |= self._hyper.keys()
        if "learning_rate" in self._hyper.keys():
            # `lr` is accepted as a legacy alias for `learning_rate`.
            result.add("lr")
    return list(result)
|
||
|
|
||
|
def __setattr__(self, name, value):
    """Override setattr to support dynamic hyperparameter setting."""
    # Backwards compatibility with Keras optimizers.
    if name == "lr":
        name = "learning_rate"
    # hasattr guard: during __init__, `_hyper` may not exist yet.
    if hasattr(self, "_hyper") and name in self._hyper:
        self._set_hyper(name, value)
    else:
        super().__setattr__(name, value)
|
||
|
|
||
|
def get_slot_names(self):
    """A list of names for this optimizer's slots."""
    return self._slot_names
|
||
|
|
||
|
def add_slot(self, var, slot_name, initializer="zeros", shape=None):
    """Add a new slot variable for `var`.

    A slot variable is an additional variable associated with `var` to
    train. It is allocated and managed by optimizers, e.g. `Adam`.

    Args:
      var: a `Variable` object.
      slot_name: name of the slot variable.
      initializer: initializer of the slot variable
      shape: (Optional) shape of the slot variable. If not set, it will
        default to the shape of `var`.

    Returns:
      A slot variable.
    """
    if slot_name not in self._slot_names:
        self._slot_names.append(slot_name)
    var_key = _var_key(var)
    slot_dict = self._slots.setdefault(var_key, {})
    weight = slot_dict.get(slot_name, None)
    if weight is None:
        if isinstance(initializer, str) or callable(initializer):
            initializer = initializers.get(initializer)
            # A CheckpointInitialValueCallable carries its own shape from
            # the checkpoint, so `shape` (possibly None) is passed through.
            if isinstance(
                initializer,
                tf.__internal__.tracking.CheckpointInitialValueCallable,
            ) or (shape is not None):
                slot_shape = shape
            else:
                slot_shape = var.shape
            initial_value = functools.partial(
                initializer, shape=slot_shape, dtype=var.dtype
            )
        else:
            initial_value = initializer

        with self._distribution_strategy_scope():
            strategy = tf.distribute.get_strategy()
            if not strategy.extended.variable_created_in_scope(var):
                raise ValueError(
                    "Trying to create optimizer slot variable under the "
                    "scope for tf.distribute.Strategy ({}), which is "
                    "different from the scope used for the original "
                    "variable ({}). Make sure the slot variables are "
                    "created under the same strategy scope. This may "
                    "happen if you're restoring from a checkpoint "
                    "outside the scope.".format(strategy, var)
                )

            # Place the slot on the same device(s) as its primary variable
            # to avoid cross-device traffic during updates.
            with strategy.extended.colocate_vars_with(var):
                weight = tf.Variable(
                    name=f"{var._shared_name}/{slot_name}",
                    dtype=var.dtype,
                    trainable=False,
                    initial_value=initial_value,
                )
        backend.track_variable(weight)
        slot_dict[slot_name] = weight
        # Apply any checkpoint value that was deferred until creation.
        self._restore_slot_variable(
            slot_name=slot_name, variable=var, slot_variable=weight
        )
        self._weights.append(weight)
    return weight
|
||
|
|
||
|
def get_slot(self, var, slot_name):
    """Returns the slot variable named `slot_name` created for `var`."""
    var_key = _var_key(var)
    slot_dict = self._slots[var_key]
    slot_variable = slot_dict[slot_name]
    if isinstance(
        slot_variable, tf.__internal__.distribute.ShardedVariable
    ):
        # Construct a ShardedVariable that points to the input
        # ShardedVariable's component shard's slot variables.
        shard_vars = []
        for shard in slot_variable.variables:
            slot_shard = self.get_slot(shard, slot_name)
            shard_vars.append(slot_shard)
        slot_variable = tf.__internal__.distribute.ShardedVariable(
            shard_vars, name=slot_variable.name
        )
    return slot_variable
|
||
|
|
||
|
def _prepare(self, var_list):
    """Builds the per-(device, dtype) `apply_state` dict for `var_list`."""
    # Collect the unique (device, dtype) combinations so shared values
    # (e.g. the decayed lr) are computed once per combination, not per var.
    keys = set()
    for var in var_list:
        if isinstance(var, tf.distribute.DistributedValues):
            var_devices = var._devices
        else:
            var_devices = [var.device]
        var_dtype = var.dtype.base_dtype
        for var_device in var_devices:
            keys.add((var_device, var_dtype))

    apply_state = {}
    for var_device, var_dtype in keys:
        apply_state[(var_device, var_dtype)] = {}
        # Create the shared tensors on the device that will consume them.
        with tf.device(var_device):
            self._prepare_local(var_device, var_dtype, apply_state)

    return apply_state
|
||
|
|
||
|
def _prepare_local(self, var_device, var_dtype, apply_state):
    """Populates `apply_state` for one (device, dtype); subclasses extend."""
    if "learning_rate" in self._hyper:
        lr_t = tf.identity(self._decayed_lr(var_dtype))
        apply_state[(var_device, var_dtype)]["lr_t"] = lr_t
|
||
|
|
||
|
def _fallback_apply_state(self, var_device, var_dtype):
    """Compatibility for subclasses that don't pass apply_state through."""
    apply_state = {(var_device, var_dtype): {}}
    self._prepare_local(var_device, var_dtype, apply_state)
    return apply_state[(var_device, var_dtype)]
|
||
|
|
||
|
def _create_hypers(self):
    """Materializes numeric hyperparameters as variables (idempotent)."""
    if self._hypers_created:
        return
    with self._distribution_strategy_scope():
        # Iterate hyper values deterministically.
        for name, value in sorted(self._hyper.items()):
            if isinstance(value, (tf.Tensor, tf.Variable)) or callable(
                value
            ):
                # The check for `callable` covers the usage when `value` is
                # a `LearningRateSchedule`, in which case it does not need
                # to create a variable.
                continue
            else:
                self._hyper[name] = self.add_weight(
                    name,
                    shape=[],
                    trainable=False,
                    initializer=value,
                    aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
                )
    self._hypers_created = True
|
||
|
|
||
|
@property
def iterations(self):
    """Variable. The number of training steps this Optimizer has run."""
    if self._iterations is None:
        # Lazily created so `iterations` can be assigned before any
        # weights exist (see the setter below).
        with self._distribution_strategy_scope():
            self._iterations = self.add_weight(
                "iter",
                shape=[],
                dtype=tf.int64,
                trainable=False,
                aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
            )
        self._weights.append(self._iterations)
    return self._iterations
|
||
|
|
||
|
@iterations.setter
def iterations(self, variable):
    """Sets the step-count variable; only valid before weights exist."""
    if self._iterations is not None:
        raise RuntimeError(
            "Cannot set `iterations` to a new Variable after "
            "the Optimizer weights have been created. Here it is "
            f"attempting to set `iterations` to {variable}."
        )
    self._iterations = variable
    self._weights.append(self._iterations)
|
||
|
|
||
|
def _decayed_lr(self, var_dtype):
    """Get decayed learning rate as a Tensor with dtype=var_dtype."""
    lr_t = self._get_hyper("learning_rate", var_dtype)
    if isinstance(lr_t, learning_rate_schedule.LearningRateSchedule):
        # Evaluate the schedule at the current step.
        local_step = tf.cast(self.iterations, var_dtype)
        lr_t = tf.cast(lr_t(local_step), var_dtype)
    if self._initial_decay > 0.0:
        # Legacy `decay` argument: inverse-time decay on top of the lr.
        local_step = tf.cast(self.iterations, var_dtype)
        decay_t = tf.cast(self._initial_decay, var_dtype)
        lr_t = lr_t / (1.0 + decay_t * local_step)
    return lr_t
|
||
|
|
||
|
@abc.abstractmethod
def get_config(self):
    """Returns the config of the optimizer.

    An optimizer config is a Python dictionary (serializable)
    containing the configuration of an optimizer.
    The same optimizer can be reinstantiated later
    (without any saved state) from this configuration.

    Returns:
        Python dictionary.
    """
    # Base config: subclasses extend this dict with their own
    # hyperparameters. Clipping settings are only serialized when set.
    config = {"name": self._name}
    if self.clipnorm is not None:
        config["clipnorm"] = self.clipnorm
    if self.clipvalue is not None:
        config["clipvalue"] = self.clipvalue
    if self.global_clipnorm is not None:
        config["global_clipnorm"] = self.global_clipnorm
    return config
|
||
|
|
||
|
@classmethod
def from_config(cls, config, custom_objects=None):
    """Creates an optimizer from its config.

    This method is the reverse of `get_config`,
    capable of instantiating the same optimizer from the config
    dictionary.

    Args:
        config: A Python dictionary, typically the output of get_config.
        custom_objects: A Python dictionary mapping names to additional
          Python objects used to create this optimizer, such as a function
          used for a hyperparameter.

    Returns:
        An optimizer instance.
    """
    # Accept the legacy `lr` key as an alias for `learning_rate`.
    if "lr" in config:
        config["learning_rate"] = config.pop("lr")
    if "learning_rate" in config:
        # A dict here is a serialized LearningRateSchedule.
        if isinstance(config["learning_rate"], dict):
            config["learning_rate"] = learning_rate_schedule.deserialize(
                config["learning_rate"], custom_objects=custom_objects
            )
    return cls(**config)
|
||
|
|
||
|
def _serialize_hyperparameter(self, hyperparameter_name):
    """Serialize a hyperparameter that can be a float, callable, or
    Tensor."""
    value = self._hyper[hyperparameter_name]
    if isinstance(value, learning_rate_schedule.LearningRateSchedule):
        return learning_rate_schedule.serialize(value)
    if callable(value):
        return value()
    if tf.is_tensor(value):
        return backend.get_value(value)
    return value
|
||
|
|
||
|
def variables(self):
    """Returns variables of this Optimizer based on the order created."""
    return self._weights
|
||
|
|
||
|
@property
def weights(self):
    """Returns variables of this Optimizer based on the order created."""
    return self._weights
|
||
|
|
||
|
def get_weights(self):
    """Returns the current weights of the optimizer.

    The weights of an optimizer are its state (ie, variables).
    This function returns the weight values associated with this
    optimizer as a list of Numpy arrays. The first value is always the
    iterations count of the optimizer, followed by the optimizer's state
    variables in the order they were created. The returned list can in turn
    be used to load state into similarly parameterized optimizers.

    For example, the RMSprop optimizer for this simple model returns a list
    of three values-- the iteration count, followed by the root-mean-square
    value of the kernel and bias of the single Dense layer:

    >>> opt = tf.keras.optimizers.legacy.RMSprop()
    >>> m = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
    >>> m.compile(opt, loss='mse')
    >>> data = np.arange(100).reshape(5, 20)
    >>> labels = np.zeros(5)
    >>> results = m.fit(data, labels)  # Training.
    >>> len(opt.get_weights())
    3

    Returns:
        Weights values as a list of numpy arrays.
    """
    params = self.weights
    # Fetch all values in one round trip (important in graph/session mode).
    return backend.batch_get_value(params)
|
||
|
|
||
|
# TODO(tanzheny): Maybe share this logic with base_layer.
def set_weights(self, weights):
    """Set the weights of the optimizer.

    The weights of an optimizer are its state (ie, variables).
    This function takes the weight values associated with this
    optimizer as a list of Numpy arrays. The first value is always the
    iterations count of the optimizer, followed by the optimizer's state
    variables in the order they are created. The passed values are used to
    set the new state of the optimizer.

    For example, the RMSprop optimizer for this simple model takes a list of
    three values-- the iteration count, followed by the root-mean-square
    value of the kernel and bias of the single Dense layer:

    >>> opt = tf.keras.optimizers.legacy.RMSprop()
    >>> m = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
    >>> m.compile(opt, loss='mse')
    >>> data = np.arange(100).reshape(5, 20)
    >>> labels = np.zeros(5)
    >>> results = m.fit(data, labels)  # Training.
    >>> new_weights = [np.array(10), np.ones([20, 10]), np.zeros([10])]
    >>> opt.set_weights(new_weights)
    >>> opt.iterations
    <tf.Variable 'RMSprop/iter:0' shape=() dtype=int64, numpy=10>

    Args:
        weights: weight values as a list of numpy arrays.
    """
    params = self.weights
    if len(params) != len(weights):
        raise ValueError(
            f"You called `set_weights(weights)` on optimizer {self._name} "
            f"with a weight list of length {str(len(weights))}, "
            f"but the optimizer was expecting {str(len(params))} "
            f"weights. Provided weights: {str(weights)[:50]}..."
        )
    if not params:
        return
    weight_value_tuples = []
    # Validate every shape before mutating anything, so a mismatch leaves
    # the optimizer state untouched.
    param_values = backend.batch_get_value(params)
    for pv, p, w in zip(param_values, params, weights):
        if pv.shape != w.shape:
            raise ValueError(
                f"Optimizer weight shape {str(pv.shape)} "
                "not compatible with "
                f"provided weight shape {str(w.shape)}."
            )
        weight_value_tuples.append((p, w))
    backend.batch_set_value(weight_value_tuples)
|
||
|
|
||
|
def add_weight(
    self,
    name,
    shape,
    dtype=None,
    initializer="zeros",
    trainable=None,
    synchronization=tf.VariableSynchronization.AUTO,
    aggregation=tf.VariableAggregation.NONE,
):
    """Creates and tracks a new optimizer variable.

    Args:
      name: variable name.
      shape: variable shape.
      dtype: dtype of the variable; defaults to `tf.float32`.
      initializer: initializer instance, name, or callable.
      trainable: whether the variable is trainable. Forced to `False`
        when `synchronization` is `ON_READ`; defaults to `True` otherwise.
      synchronization: `tf.VariableSynchronization` policy.
      aggregation: `tf.VariableAggregation` policy.

    Returns:
      The created resource variable.

    Raises:
      ValueError: if `trainable=True` is combined with
        `synchronization=ON_READ`.
    """

    if dtype is None:
        dtype = tf.float32
    if isinstance(initializer, str) or callable(initializer):
        initializer = initializers.get(initializer)

    if synchronization == tf.VariableSynchronization.ON_READ:
        if trainable:
            raise ValueError(
                "Synchronization value can be set to "
                "VariableSynchronization.ON_READ only for non-trainable "
                "variables. You have specified trainable=True and "
                "synchronization=VariableSynchronization.ON_READ."
            )
        else:
            # Set trainable to be false when variable is to be synced on
            # read.
            trainable = False
    elif trainable is None:
        trainable = True

    variable = self._add_variable_with_custom_getter(
        name=name,
        shape=shape,
        getter=base_layer_utils.make_variable,
        overwrite=True,
        initializer=initializer,
        dtype=dtype,
        trainable=trainable,
        use_resource=True,
        synchronization=synchronization,
        aggregation=aggregation,
    )
    backend.track_variable(variable)

    return variable
|
||
|
|
||
|
def _init_set_name(self, name, zero_based=True):
    """Sets `self._name`, generating a unique snake_case name if empty."""
    if not name:
        self._name = backend.unique_object_name(
            generic_utils.to_snake_case(self.__class__.__name__),
            zero_based=zero_based,
        )
    else:
        self._name = name
|
||
|
|
||
|
def _assert_valid_dtypes(self, tensors):
    """Asserts tensors are all valid types (see `_valid_dtypes`).

    Args:
      tensors: Tensors to check.

    Raises:
      ValueError: If any tensor is not a valid type.
    """
    valid_dtypes = self._valid_dtypes()
    for t in tensors:
        dtype = t.dtype.base_dtype
        if dtype not in valid_dtypes:
            # `list(...)` (instead of the original identity comprehension,
            # ruff C416) formats identically in the message.
            raise ValueError(
                "Invalid type {} for {}, expected: {}.".format(
                    dtype, t.name, list(valid_dtypes)
                )
            )
|
||
|
|
||
|
def _valid_dtypes(self):
    """Valid types for loss, variables and gradients.

    Subclasses should override to allow other float types.

    Returns:
      Valid types for loss, variables and gradients.
    """
    return _DEFAULT_VALID_DTYPES
|
||
|
|
||
|
def _call_if_callable(self, param):
|
||
|
"""Call the function if param is callable."""
|
||
|
return param() if callable(param) else param
|
||
|
|
||
|
def _resource_apply_dense(self, grad, handle, apply_state):
    """Add ops to apply dense gradients to the variable `handle`.

    Args:
      grad: a `Tensor` representing the gradient.
      handle: a `Tensor` of dtype `resource` which points to the variable to
        be updated.
      apply_state: A dict which is used across multiple apply calls.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError(
        "`_resource_apply_dense` must be implemented in subclasses."
    )
|
||
|
|
||
|
def _resource_apply_sparse_duplicate_indices(
    self, grad, handle, indices, **kwargs
):
    """Add ops to apply sparse gradients to `handle`, with repeated indices.

    Optimizers which override this method must deal with repeated indices.
    See the docstring of `_apply_sparse_duplicate_indices` for details. By
    default the correct behavior, to sum non-unique indices and their
    associated gradients, is enforced by first pre-processing `grad` and
    `indices` and passing them on to `_resource_apply_sparse`. Optimizers
    which deal correctly with duplicate indices may instead override this
    method to avoid the overhead of summing.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable to
        be updated.
      indices: a `Tensor` of integral type representing the indices for
        which the gradient is nonzero. Indices may be repeated.
      **kwargs: May optionally contain `apply_state`

    Returns:
      An `Operation` which updates the value of the variable.
    """
    summed_grad, unique_indices = _deduplicate_indexed_slices(
        values=grad, indices=indices
    )
    return self._resource_apply_sparse(
        summed_grad, handle, unique_indices, **kwargs
    )
|
||
|
|
||
|
def _resource_apply_sparse(self, grad, handle, indices, apply_state):
    """Add ops to apply sparse gradients to the variable `handle`.

    Similar to `_apply_sparse`, the `indices` argument to this method has
    been de-duplicated. Optimizers which deal correctly with non-unique
    indices may instead override `_resource_apply_sparse_duplicate_indices`
    to avoid this overhead.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable to
        be updated.
      indices: a `Tensor` of integral type representing the indices for
        which the gradient is nonzero. Indices are unique.
      apply_state: A dict which is used across multiple apply calls.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError(
        "`_resource_apply_sparse` Must be implemented in subclasses."
    )
|
||
|
|
||
|
def _resource_scatter_add(self, x, i, v):
    """Scatter-adds `v` into variable `x` at indices `i`; returns x's value.

    The control dependency ensures the returned value reflects the update.
    """
    with tf.control_dependencies(
        [
            tf.raw_ops.ResourceScatterAdd(
                resource=x.handle, indices=i, updates=v
            )
        ]
    ):
        return x.value()
|
||
|
|
||
|
def _resource_scatter_update(self, x, i, v):
    """Scatter-writes `v` into variable `x` at indices `i`; returns x's value.

    The control dependency ensures the returned value reflects the update.
    """
    with tf.control_dependencies(
        [
            tf.raw_ops.ResourceScatterUpdate(
                resource=x.handle, indices=i, updates=v
            )
        ]
    ):
        return x.value()
|
||
|
|
||
|
@property
@layer_utils.cached_per_instance
def _dense_apply_args(self):
    """Argument names of `_resource_apply_dense` (cached per instance)."""
    return tf_inspect.getfullargspec(self._resource_apply_dense).args
|
||
|
|
||
|
@property
@layer_utils.cached_per_instance
def _sparse_apply_args(self):
    """Argument names of `_resource_apply_sparse` (cached per instance)."""
    return tf_inspect.getfullargspec(self._resource_apply_sparse).args
|
||
|
|
||
|
# ---------------
# For implementing the trackable interface
# ---------------

def _restore_slot_variable(self, slot_name, variable, slot_variable):
    """Restore a newly created slot variable's value."""
    variable_key = _var_key(variable)
    # Pop any restores that were deferred until this slot existed.
    deferred_restorations = self._deferred_slot_restorations.get(
        slot_name, {}
    ).pop(variable_key, [])
    # Iterate over restores, highest restore UID first to minimize the
    # number of assignments.
    deferred_restorations.sort(
        key=lambda position: position.restore_uid, reverse=True
    )
    for checkpoint_position in deferred_restorations:
        checkpoint_position.restore(slot_variable)
|
||
|
|
||
|
def _create_or_restore_slot_variable(
    self, slot_variable_position, slot_name, variable
):
    """Returns the slot variable that should have a value restored into it.

    It is up to the caller to restore the value into the slot variable if a
    valid slot variable is returned.

    Called when a variable which has an associated slot variable is created
    or restored. When executing eagerly, we create the slot variable with a
    restoring initializer.

    No new variables are created when graph building. Instead,
    _restore_slot_variable catches these after normal creation and adds
    restore ops to the graph. This method is nonetheless important when
    graph building for the case when a slot variable has already been
    created but `variable` has just been added to a dependency graph
    (causing us to realize that the slot variable needs to be restored).

    Args:
      slot_variable_position: A `trackable._CheckpointPosition` object
        indicating the slot variable `Trackable` object to be restored.
      slot_name: The name of this `Optimizer`'s slot to restore into.
      variable: The variable object this slot is being created for.

    Returns:
      A slot variable that should have a value restored into it, or None if
      a slot variable should not be restored at this time.
    """
    variable_key = _var_key(variable)
    slot_dict = self._slots.get(variable_key, {})
    slot_variable = slot_dict.get(slot_name, None)
    if (
        slot_variable is None
        and tf.executing_eagerly()
        and slot_variable_position.is_simple_variable()
        # Defer slot variable creation if there is an active variable
        # creator scope. Generally we'd like to eagerly create/restore slot
        # variables when possible, but this may mean that scopes intended to
        # catch `variable` also catch its eagerly created slot variable
        # unintentionally (specifically make_template would add a dependency
        # on a slot variable if not for this case). Deferring is mostly
        # harmless (aside from double initialization), and makes variable
        # creator scopes behave the same way they do when graph building.
        #
        # One notable case is with distribution strategy, which uses
        # variable creator scope but always desires the `variable` and the
        # slot to use the same scope, thus we can safely eagerly
        # create/restore slot variables.
        and (
            not tf.compat.v1.get_default_graph()._variable_creator_stack
            or self._distribution_strategy
        )
    ):
        # Create the slot eagerly with an initializer that pulls its
        # initial value straight from the checkpoint.
        initializer = (
            tf.__internal__.tracking.CheckpointInitialValueCallable(
                checkpoint_position=slot_variable_position
            )
        )
        slot_variable = self.add_slot(
            var=variable,
            initializer=initializer,
            slot_name=slot_name,
            shape=slot_variable_position.value_shape(),
        )
        # Slot variables are not owned by any one object (because we don't
        # want to save the slot variable if the optimizer is saved without
        # the non-slot variable, or if the non-slot variable is saved
        # without the optimizer; it's a dependency hypergraph with edges of
        # the form (optimizer, non-slot variable, variable)). So we don't
        # _track_ slot variables anywhere, and instead special-case this
        # dependency and otherwise pretend it's a normal graph.
    if slot_variable is not None:
        # For sharded variables, we need the logic in get_slot to combine
        # slot variables for its shards
        if (slot_variable is variable) and (
            isinstance(variable, tf.__internal__.distribute.ShardedVariable)
        ):
            return self.get_slot(variable, slot_name)
        # If we've either made this slot variable, or if we've pulled out an
        # existing slot variable, we should restore it.
        return slot_variable
    else:
        # We didn't make the slot variable. Defer restoring until it gets
        # created normally. We keep a list rather than the one with the
        # highest restore UID in case slot variables have their own
        # dependencies, in which case those could differ between restores.
        self._deferred_slot_restorations.setdefault(
            slot_name, {}
        ).setdefault(variable_key, []).append(slot_variable_position)
        return None
|
||
|
|
||
|
@contextlib.contextmanager
def _distribution_strategy_scope(self):
    """Context manager entering the scope of the `tf.distribute.Strategy`
    this optimizer was created under, if any and if no strategy is already
    active; otherwise a no-op."""
    if self._distribution_strategy and not tf.distribute.has_strategy():
        with self._distribution_strategy.scope():
            yield self._distribution_strategy.scope()
    else:
        yield
|
||
|
|
||
|
|
||
|
def _var_key(var):
    """Key for representing a primary variable, for looking up slots.

    In graph mode the name is derived from the var shared name.
    In eager mode the name is derived from the var unique id.
    If distribution strategy exists, get the primary variable first.

    Args:
      var: the variable.

    Returns:
      the unique name of the variable.
    """

    # Get the distributed variable if it exists.
    if hasattr(var, "_distributed_container"):
        var = var._distributed_container()
    elif (
        tf_utils.is_extension_type(var)
        and hasattr(var, "handle")
        and hasattr(var.handle, "_distributed_container")
    ):
        # For ResourceVariables, the _distributed_container attribute
        # is added to their handle tensors.
        var = var.handle._distributed_container()
    if getattr(var, "_in_graph_mode", False):
        return var._shared_name
    return var._unique_id
|
||
|
|
||
|
|
||
|
def _get_slot_key_from_var(var, slot_name):
    """Get the slot key for the variable: var_name/slot_name."""
    return f"{_var_key(var)}/{slot_name}"
|
||
|
|
||
|
|
||
|
class RestoredOptimizer(OptimizerV2):
    """A non-functional Optimizer implementation for checkpoint compatibility.

    Holds slot variables and hyperparameters when an optimizer is restored from
    a SavedModel. These variables may be referenced in functions along with ops
    created by the original optimizer, but currently we do not support using the
    optimizer object itself (e.g. through `apply_gradients`).
    """

    # TODO(allenl): Make the restored optimizer functional by tracing its apply
    # methods.

    def __init__(self):
        super().__init__("RestoredOptimizer")
        # Mark hypers as created so attribute access doesn't try to build
        # variables for this placeholder optimizer.
        self._hypers_created = True

    def get_config(self):
        # TODO(allenl): Save and restore the Optimizer's config
        raise NotImplementedError(
            "Restoring functional Optimizers from SavedModels is not currently "
            "supported. Please file a feature request if this limitation "
            "bothers you."
        )
|
||
|
|
||
|
|
||
|
# Register OptimizerV2 with the SavedModel loader so restored optimizers come
# back as RestoredOptimizer instances, with hyperparameters re-attached via
# `_set_hyper`.
tf.__internal__.saved_model.load.register_revived_type(
    "optimizer",
    lambda obj: isinstance(obj, OptimizerV2),
    versions=[
        tf.__internal__.saved_model.load.VersionedTypeRegistration(
            object_factory=lambda proto: RestoredOptimizer(),
            version=2,
            min_producer_version=1,
            min_consumer_version=1,
            setter=RestoredOptimizer._set_hyper,
        )
    ],
)
|