# Copyright 2022 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Sharpness Aware Minimization implementation.""" import copy import tensorflow.compat.v2 as tf from keras.engine import data_adapter from keras.layers import deserialize as deserialize_layer from keras.models import Model from keras.saving.legacy.serialization import serialize_keras_object from keras.saving.object_registration import register_keras_serializable # isort: off from tensorflow.python.util.tf_export import keras_export @register_keras_serializable() @keras_export("keras.models.experimental.SharpnessAwareMinimization", v1=[]) class SharpnessAwareMinimization(Model): """Sharpness aware minimization (SAM) training flow. Sharpness-aware minimization (SAM) is a technique that improves the model generalization and provides robustness to label noise. Mini-batch splitting is proven to improve the SAM's performance, so users can control how mini batches are split via setting the `num_batch_splits` argument. Args: model: `tf.keras.Model` instance. The inner model that does the forward-backward pass. rho: float, defaults to 0.05. The gradients scaling factor. num_batch_splits: int, defaults to None. The number of mini batches to split into from each data batch. If None, batches are not split into sub-batches. name: string, defaults to None. The name of the SAM model. Reference: [Pierre Foret et al., 2020](https://arxiv.org/abs/2010.01412) """ def __init__(self, model, rho=0.05, num_batch_splits=None, name=None): super().__init__(name=name) self.model = model self.rho = rho self.num_batch_splits = num_batch_splits def train_step(self, data): """The logic of one SAM training step. Args: data: A nested structure of `Tensor`s. It should be of structure (x, y, sample_weight) or (x, y). Returns: A dict mapping metric names to running average values. """ x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data) if self.num_batch_splits is not None: x_split = tf.split(x, self.num_batch_splits) y_split = tf.split(y, self.num_batch_splits) else: x_split = [x] y_split = [y] gradients_all_batches = [] pred_all_batches = [] for x_batch, y_batch in zip(x_split, y_split): epsilon_w_cache = [] with tf.GradientTape() as tape: pred = self.model(x_batch) loss = self.compiled_loss(y_batch, pred) pred_all_batches.append(pred) trainable_variables = self.model.trainable_variables gradients = tape.gradient(loss, trainable_variables) gradients_order2_norm = self._gradients_order2_norm(gradients) scale = self.rho / (gradients_order2_norm + 1e-12) for gradient, variable in zip(gradients, trainable_variables): epsilon_w = gradient * scale self._distributed_apply_epsilon_w( variable, epsilon_w, tf.distribute.get_strategy() ) epsilon_w_cache.append(epsilon_w) with tf.GradientTape() as tape: pred = self(x_batch) loss = self.compiled_loss(y_batch, pred) gradients = tape.gradient(loss, trainable_variables) if len(gradients_all_batches) == 0: for gradient in gradients: gradients_all_batches.append([gradient]) else: for gradient, gradient_all_batches in zip( gradients, gradients_all_batches ): gradient_all_batches.append(gradient) for variable, epsilon_w in zip( trainable_variables, epsilon_w_cache ): # Restore the variable to its original value before # `apply_gradients()`. self._distributed_apply_epsilon_w( variable, -epsilon_w, tf.distribute.get_strategy() ) gradients = [] for gradient_all_batches in gradients_all_batches: gradients.append(tf.reduce_sum(gradient_all_batches, axis=0)) self.optimizer.apply_gradients(zip(gradients, trainable_variables)) pred = tf.concat(pred_all_batches, axis=0) self.compiled_metrics.update_state(y, pred, sample_weight) return {m.name: m.result() for m in self.metrics} def call(self, inputs): """Forward pass of SAM. SAM delegates the forward pass call to the wrapped model. Args: inputs: Tensor. The model inputs. Returns: A Tensor, the outputs of the wrapped model for given `inputs`. """ return self.model(inputs) def get_config(self): config = super().get_config() config.update( { "model": serialize_keras_object(self.model), "rho": self.rho, } ) return config @classmethod def from_config(cls, config, custom_objects=None): # Avoid mutating the input dict. config = copy.deepcopy(config) model = deserialize_layer( config.pop("model"), custom_objects=custom_objects ) config["model"] = model return super().from_config(config, custom_objects) def _distributed_apply_epsilon_w(self, var, epsilon_w, strategy): # Helper function to apply epsilon_w on model variables. if isinstance( tf.distribute.get_strategy(), ( tf.distribute.experimental.ParameterServerStrategy, tf.distribute.experimental.CentralStorageStrategy, ), ): # Under PSS and CSS, the AggregatingVariable has to be kept in sync. def distribute_apply(strategy, var, epsilon_w): strategy.extended.update( var, lambda x, y: x.assign_add(y), args=(epsilon_w,), group=False, ) tf.__internal__.distribute.interim.maybe_merge_call( distribute_apply, tf.distribute.get_strategy(), var, epsilon_w ) else: var.assign_add(epsilon_w) def _gradients_order2_norm(self, gradients): norm = tf.norm( tf.stack([tf.norm(grad) for grad in gradients if grad is not None]) ) return norm