# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DTensor specific Keras optimizers."""

import tensorflow.compat.v2 as tf

from keras.dtensor import dtensor_api as dtensor
from keras.optimizers import adadelta
from keras.optimizers import adagrad
from keras.optimizers import adam
from keras.optimizers import adamw
from keras.optimizers import optimizer as optimizer_lib
from keras.optimizers import rmsprop
from keras.optimizers import sgd
from keras.optimizers.schedules import learning_rate_schedule

# isort: off
from tensorflow.python.util.tf_export import keras_export
from tensorflow.tools.docs import doc_controls


class Optimizer(optimizer_lib._BaseOptimizer):
    """DTensor specific optimizers.

    The major change for this class is that all of the variable init logic
    is mesh/layout aware.
    """

    # Note that we didn't subclass optimizer_lib.Optimizer since it contains
    # the extra logic of handling distribution strategy, which we don't need
    # for DTensor.

    def __init__(self, name, mesh=None):
        """Create a new Optimizer.

        Args:
          name: String. The name of the optimizer, which will appear in all
            the state variables created by this optimizer.
          mesh: dtensor.Mesh. The optional Mesh which will be used to create
            the states. Note that usually the state variable will use the
            layout from the corresponding model variables. This mesh is only
            used for global variables like global steps, learning rate, etc.
        """
        # TODO(scottzhu): Skip the gradients_clip_option and ema_option for
        # now, and cover them in the future if really needed.
        # TODO(scottzhu): We might want to make mesh required in the future.
        self._mesh = mesh
        super().__init__(name=name)

    def _create_iteration_variable(self):
        init_val = tf.constant(0, dtype=tf.int64)
        if self._mesh:
            init_val = dtensor.copy_to_mesh(
                init_val, dtensor.Layout.replicated(self._mesh, rank=0)
            )
        with tf.init_scope():
            # Lift the variable creation to init scope to avoid environment
            # issues.
            self._iterations = dtensor.DVariable(init_val, name="iteration")

    ################## Override methods from keras.Optimizer ################
    def add_variable_from_reference(
        self, model_variable, variable_name, initial_value=None
    ):
        """Create an optimizer variable from a model variable.

        Create an optimizer variable based on the information of the model
        variable. For example, in SGD optimizer momentum, for each model
        variable, a corresponding momentum variable is created of the same
        shape and dtype.

        Args:
          model_variable: The corresponding model variable to the optimizer
            variable to be created.
          variable_name: The name prefix of the optimizer variable to be
            created. The created variable's name will follow the pattern
            `{variable_name}/{model_variable.name}`, e.g., `momentum/dense_1`.
          initial_value: The initial value of the optimizer variable. If None,
            the value will default to 0.

        Returns:
          An optimizer variable.
        """
        if initial_value is None:
            # Use tf.zeros_like, which will propagate the layout information
            # from the model weights, if any.
            initial_value = tf.zeros_like(model_variable)
        elif isinstance(initial_value, tf.Tensor):
            initial_value = dtensor.copy_to_mesh(
                initial_value,
                dtensor.Layout.replicated(
                    self._mesh, rank=initial_value.shape.rank
                ),
            )
        variable = dtensor.DVariable(
            initial_value=initial_value,
            name=f"{variable_name}/{model_variable._shared_name}",
            dtype=model_variable.dtype,
            trainable=False,
        )
        self._variables.append(variable)
        return variable

    @doc_controls.do_not_generate_docs
    def aggregate_gradients(self, grads_and_vars):
        # Hide the aggregate_gradients from Optimizer.aggregate_gradients
        raise NotImplementedError(
            "Dtensor doesn't need to manually aggregate gradients"
        )

    def _var_key(self, variable):
        """Get a unique identifier of the given variable."""
        return optimizer_lib._BaseOptimizer._var_key(self, variable)

    def apply_gradients(self, grads_and_vars):
        """Apply gradients to variables.

        Args:
          grads_and_vars: List of (gradient, variable) pairs.

        Returns:
          None

        Raises:
          TypeError: If `grads_and_vars` is malformed.
        """
        # Explicitly call _BaseOptimizer.apply_gradients to avoid any chance
        # of using Optimizer.apply_gradients, which contains distribution
        # strategy logic.
        optimizer_lib._BaseOptimizer.apply_gradients(self, grads_and_vars)

    def _internal_apply_gradients(self, grads_and_vars):
        """Helper function of apply_gradients.

        This is required for separating out distributed training logic.

        Args:
          grads_and_vars: List of (gradient, variable) pairs.
        """
        # Explicitly call _BaseOptimizer._internal_apply_gradients to avoid
        # any chance of using the distribution strategy logic in
        # Optimizer.apply_gradients.
        optimizer_lib._BaseOptimizer._internal_apply_gradients(
            self, grads_and_vars
        )

    def _overwrite_model_variables_with_average_value_helper(self, var_list):
        """Helper function to _overwrite_model_variables_with_average_value."""
        (
            optimizer_lib._BaseOptimizer._overwrite_model_variables_with_average_value_helper(  # noqa: E501
                self, var_list
            )
        )

    def _build_learning_rate(self, learning_rate):
        if isinstance(
            learning_rate, learning_rate_schedule.LearningRateSchedule
        ):
            # Create a variable to hold the current learning rate.
            # Note that the init value `learning_rate(self.iterations)` should
            # have the correct layout information from self.iterations.
            self._current_learning_rate = dtensor.DVariable(
                learning_rate(self.iterations),
                name="learning_rate",
                dtype=tf.float32,
            )
            return learning_rate
        init_val = tf.constant(learning_rate, dtype=tf.float32)
        if self._mesh:
            init_val = dtensor.copy_to_mesh(
                init_val, dtensor.Layout.replicated(self._mesh, rank=0)
            )
        return dtensor.DVariable(init_val, name="learning_rate")
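

# Hedged illustration of the mesh-aware state creation above, not part of the
# original module and never called here. It assumes the caller passes in a
# `dtensor.Mesh` plus a concrete optimizer instance built from one of the
# subclasses defined below; the shapes and names are made-up examples.
def _sketch_add_slot_with_layout(optimizer, mesh):
    """Create a model variable on `mesh` and a matching optimizer slot."""
    # A replicated "model" weight; real models would typically shard it.
    layout = dtensor.Layout.replicated(mesh, rank=2)
    weight = dtensor.DVariable(
        dtensor.copy_to_mesh(tf.zeros([4, 8]), layout), name="dense/kernel"
    )
    # The slot is created as zeros_like(weight), so it inherits the same
    # layout as the model variable.
    return optimizer.add_variable_from_reference(weight, "momentum")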


@keras_export("keras.dtensor.experimental.optimizers.Adadelta", v1=[])
class Adadelta(Optimizer, adadelta.Adadelta):
    def __init__(
        self,
        learning_rate=0.001,
        rho=0.95,
        epsilon=1e-7,
        gradients_clip_option=None,
        ema_option=None,
        name="Adadelta",
        mesh=None,
    ):
        # Skip adadelta.Adadelta.__init__ and only call Optimizer.__init__.
        # This is to skip keras.Optimizer.__init__, which contains the
        # distribution strategy logic. Same for all the optimizer subclasses
        # below.
        Optimizer.__init__(self, name=name, mesh=mesh)
        self._learning_rate = self._build_learning_rate(learning_rate)
        self.rho = rho
        self.epsilon = epsilon


@keras_export("keras.dtensor.experimental.optimizers.Adagrad", v1=[])
class Adagrad(Optimizer, adagrad.Adagrad):
    def __init__(
        self,
        learning_rate=0.001,
        initial_accumulator_value=0.1,
        epsilon=1e-7,
        gradients_clip_option=None,
        ema_option=None,
        name="Adagrad",
        mesh=None,
    ):
        Optimizer.__init__(self, name=name, mesh=mesh)
        self._learning_rate = self._build_learning_rate(learning_rate)
        self.initial_accumulator_value = initial_accumulator_value
        self.epsilon = epsilon


@keras_export("keras.dtensor.experimental.optimizers.Adam", v1=[])
class Adam(Optimizer, adam.Adam):
    def __init__(
        self,
        learning_rate=0.001,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-7,
        amsgrad=False,
        gradients_clip_option=None,
        ema_option=None,
        name="Adam",
        mesh=None,
    ):
        Optimizer.__init__(self, name=name, mesh=mesh)
        self._learning_rate = self._build_learning_rate(learning_rate)
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon
        self.amsgrad = amsgrad


@keras_export("keras.dtensor.experimental.optimizers.AdamW", v1=[])
class AdamW(Optimizer, adamw.AdamW):
    def __init__(
        self,
        learning_rate=0.001,
        weight_decay=0.004,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-7,
        amsgrad=False,
        name="AdamW",
        mesh=None,
    ):
        Optimizer.__init__(self, name=name, mesh=mesh)
        self._learning_rate = self._build_learning_rate(learning_rate)
        self.weight_decay = weight_decay
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon
        self.amsgrad = amsgrad

        if self.weight_decay is None:
            raise ValueError(
                "Missing value of `weight_decay` which is required and"
                " must be a float value."
            )


@keras_export("keras.dtensor.experimental.optimizers.RMSprop", v1=[])
class RMSprop(Optimizer, rmsprop.RMSprop):
    def __init__(
        self,
        learning_rate=0.001,
        rho=0.9,
        momentum=0.0,
        epsilon=1e-7,
        centered=False,
        gradients_clip_option=None,
        ema_option=None,
        jit_compile=False,
        name="RMSprop",
        mesh=None,
    ):
        Optimizer.__init__(self, name=name, mesh=mesh)
        self._learning_rate = self._build_learning_rate(learning_rate)
        self.rho = rho
        self.momentum = momentum
        self.epsilon = epsilon
        self.centered = centered


@keras_export("keras.dtensor.experimental.optimizers.SGD", v1=[])
class SGD(Optimizer, sgd.SGD):
    def __init__(
        self,
        learning_rate=0.01,
        momentum=0.0,
        nesterov=False,
        amsgrad=False,
        gradients_clip_option=None,
        ema_option=None,
        jit_compile=False,
        name="SGD",
        mesh=None,
    ):
        Optimizer.__init__(self, name=name, mesh=mesh)
        self._learning_rate = self._build_learning_rate(learning_rate)
        self.momentum = momentum
        self.nesterov = nesterov
        if isinstance(momentum, (int, float)) and (
            momentum < 0 or momentum > 1
        ):
            raise ValueError("`momentum` must be in the range [0, 1].")
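

# Hedged sketch, not part of the original module and never called here: how a
# `LearningRateSchedule` interacts with `_build_learning_rate` above. It
# assumes the caller provides a `dtensor.Mesh`; the schedule parameters are
# illustrative only.
def _sketch_schedule_learning_rate(mesh):
    """Pair a LearningRateSchedule with the DTensor SGD wrapper."""
    schedule = learning_rate_schedule.ExponentialDecay(
        initial_learning_rate=0.1, decay_steps=1000, decay_rate=0.9
    )
    # `_build_learning_rate` keeps the schedule itself and mirrors its current
    # value in a `learning_rate` DVariable that follows `self.iterations`.
    return SGD(learning_rate=schedule, momentum=0.9, mesh=mesh)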


Adadelta.__doc__ = Optimizer.__doc__ + adadelta.Adadelta.__doc__
Adagrad.__doc__ = Optimizer.__doc__ + adagrad.Adagrad.__doc__
Adam.__doc__ = Optimizer.__doc__ + adam.Adam.__doc__
AdamW.__doc__ = Optimizer.__doc__ + adamw.AdamW.__doc__
RMSprop.__doc__ = Optimizer.__doc__ + rmsprop.RMSprop.__doc__
SGD.__doc__ = Optimizer.__doc__ + sgd.SGD.__doc__
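

# End-to-end hedged sketch, not part of the original module and never executed
# at import time. It assumes the caller already built a DTensor mesh (e.g. via
# `tf.experimental.dtensor.create_mesh`, which this module does not import)
# and runs inside a proper DTensor device context; shapes and names are
# illustrative assumptions only.
def _sketch_apply_gradients(mesh):
    """One optimizer step on a mesh-backed variable."""
    layout = dtensor.Layout.replicated(mesh, rank=1)
    weight = dtensor.DVariable(
        dtensor.copy_to_mesh(tf.ones([8]), layout), name="bias"
    )
    optimizer = Adam(learning_rate=1e-3, mesh=mesh)
    with tf.GradientTape() as tape:
        loss = tf.reduce_sum(weight * weight)
    grads = tape.gradient(loss, [weight])
    # No manual aggregation step: DTensor performs the cross-device reduction,
    # which is why `aggregate_gradients` above raises NotImplementedError.
    optimizer.apply_gradients(list(zip(grads, [weight])))
    return weight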