# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SGD optimizer implementation."""

import tensorflow.compat.v2 as tf

from keras.optimizers.legacy import optimizer_v2

# isort: off
from tensorflow.python.util.tf_export import keras_export


@keras_export(
    "keras.optimizers.legacy.SGD",
    v1=["keras.optimizers.SGD", "keras.optimizers.legacy.SGD"],
)
class SGD(optimizer_v2.OptimizerV2):
    r"""Gradient descent (with momentum) optimizer.

    Update rule for parameter `w` with gradient `g` when `momentum=0`:

    ```python
    w = w - learning_rate * g
    ```

    Update rule when `momentum` is larger than 0:

    ```python
    velocity = momentum * velocity - learning_rate * g
    w = w + velocity
    ```

    When `nesterov=True`, this rule becomes:

    ```python
    velocity = momentum * velocity - learning_rate * g
    w = w + momentum * velocity - learning_rate * g
    ```

    Args:
      learning_rate: A `Tensor`, floating point value, or a schedule that is a
        `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable
        that takes no arguments and returns the actual value to use. The
        learning rate. Defaults to 0.01.
      momentum: float hyperparameter >= 0 that accelerates gradient descent in
        the relevant direction and dampens oscillations. Defaults to 0, i.e.,
        vanilla gradient descent.
      nesterov: boolean. Whether to apply Nesterov momentum. Defaults to
        `False`.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to `"SGD"`.
      **kwargs: keyword arguments. Allowed arguments are `clipvalue`,
        `clipnorm`, `global_clipnorm`. If `clipvalue` (float) is set, the
        gradient of each weight is clipped to be no higher than this value.
        If `clipnorm` (float) is set, the gradient of each weight is
        individually clipped so that its norm is no higher than this value.
        If `global_clipnorm` (float) is set, the gradient of all weights is
        clipped so that their global norm is no higher than this value.

    Usage:

    >>> opt = tf.keras.optimizers.legacy.SGD(learning_rate=0.1)
    >>> var = tf.Variable(1.0)
    >>> loss = lambda: (var ** 2) / 2.0  # d(loss)/d(var) = var
    >>> step_count = opt.minimize(loss, [var]).numpy()
    >>> # Step is `- learning_rate * grad`
    >>> var.numpy()
    0.9

    >>> opt = tf.keras.optimizers.legacy.SGD(learning_rate=0.1, momentum=0.9)
    >>> var = tf.Variable(1.0)
    >>> val0 = var.value()
    >>> loss = lambda: (var ** 2) / 2.0  # d(loss)/d(var) = var
    >>> # First step is `- learning_rate * grad`
    >>> step_count = opt.minimize(loss, [var]).numpy()
    >>> val1 = var.value()
    >>> (val0 - val1).numpy()
    0.1
    >>> # On later steps, the step size grows because the velocity accumulates:
    >>> # velocity = 0.9 * (-0.1) - 0.1 * 0.9 = -0.18
    >>> step_count = opt.minimize(loss, [var]).numpy()
    >>> val2 = var.value()
    >>> (val1 - val2).numpy()
    0.18
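
    The learning rate can also be a
    `tf.keras.optimizers.schedules.LearningRateSchedule` instead of a fixed
    float. A minimal sketch (`ExponentialDecay` is used here purely as an
    illustration; any schedule works):

    ```python
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=0.1, decay_steps=1000, decay_rate=0.9
    )
    opt = tf.keras.optimizers.legacy.SGD(
        learning_rate=lr_schedule, momentum=0.9
    )
    ```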
""" _HAS_AGGREGATE_GRAD = True def __init__( self, learning_rate=0.01, momentum=0.0, nesterov=False, name="SGD", **kwargs, ): super().__init__(name, **kwargs) self._set_hyper("learning_rate", kwargs.get("lr", learning_rate)) self._set_hyper("decay", self._initial_decay) self._momentum = False if ( isinstance(momentum, tf.Tensor) or callable(momentum) or momentum > 0 ): self._momentum = True if isinstance(momentum, (int, float)) and ( momentum < 0 or momentum > 1 ): raise ValueError( "`momentum` must be between [0, 1]. Received: " f"momentum={momentum} (of type {type(momentum)})." ) self._set_hyper("momentum", momentum) self.nesterov = nesterov def _create_slots(self, var_list): if self._momentum: for var in var_list: self.add_slot(var, "momentum") def _prepare_local(self, var_device, var_dtype, apply_state): super()._prepare_local(var_device, var_dtype, apply_state) apply_state[(var_device, var_dtype)]["momentum"] = tf.identity( self._get_hyper("momentum", var_dtype) ) def _resource_apply_dense(self, grad, var, apply_state=None): var_device, var_dtype = var.device, var.dtype.base_dtype coefficients = (apply_state or {}).get( (var_device, var_dtype) ) or self._fallback_apply_state(var_device, var_dtype) if self._momentum: momentum_var = self.get_slot(var, "momentum") return tf.raw_ops.ResourceApplyKerasMomentum( var=var.handle, accum=momentum_var.handle, lr=coefficients["lr_t"], grad=grad, momentum=coefficients["momentum"], use_locking=self._use_locking, use_nesterov=self.nesterov, ) else: return tf.raw_ops.ResourceApplyGradientDescent( var=var.handle, alpha=coefficients["lr_t"], delta=grad, use_locking=self._use_locking, ) def _resource_apply_sparse_duplicate_indices( self, grad, var, indices, **kwargs ): if self._momentum: return super()._resource_apply_sparse_duplicate_indices( grad, var, indices, **kwargs ) else: var_device, var_dtype = var.device, var.dtype.base_dtype coefficients = kwargs.get("apply_state", {}).get( (var_device, var_dtype) ) or self._fallback_apply_state(var_device, var_dtype) return tf.raw_ops.ResourceScatterAdd( resource=var.handle, indices=indices, updates=-grad * coefficients["lr_t"], ) def _resource_apply_sparse(self, grad, var, indices, apply_state=None): # This method is only needed for momentum optimization. var_device, var_dtype = var.device, var.dtype.base_dtype coefficients = (apply_state or {}).get( (var_device, var_dtype) ) or self._fallback_apply_state(var_device, var_dtype) momentum_var = self.get_slot(var, "momentum") return tf.raw_ops.ResourceSparseApplyKerasMomentum( var=var.handle, accum=momentum_var.handle, lr=coefficients["lr_t"], grad=grad, indices=indices, momentum=coefficients["momentum"], use_locking=self._use_locking, use_nesterov=self.nesterov, ) def get_config(self): config = super().get_config() config.update( { "learning_rate": self._serialize_hyperparameter( "learning_rate" ), "decay": self._initial_decay, "momentum": self._serialize_hyperparameter("momentum"), "nesterov": self.nesterov, } ) return config