# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sharpness Aware Minimization implementation."""
import copy
import tensorflow.compat.v2 as tf
from keras.engine import data_adapter
from keras.layers import deserialize as deserialize_layer
from keras.models import Model
from keras.saving.legacy.serialization import serialize_keras_object
from keras.saving.object_registration import register_keras_serializable
# isort: off
from tensorflow.python.util.tf_export import keras_export
@register_keras_serializable()
@keras_export("keras.models.experimental.SharpnessAwareMinimization", v1=[])
class SharpnessAwareMinimization(Model):
"""Sharpness aware minimization (SAM) training flow.
Sharpness-aware minimization (SAM) is a technique that improves the model
generalization and provides robustness to label noise. Mini-batch splitting
is proven to improve the SAM's performance, so users can control how mini
batches are split via setting the `num_batch_splits` argument.
Args:
model: `tf.keras.Model` instance. The inner model that does the
forward-backward pass.
rho: float, defaults to 0.05. The gradients scaling factor.
num_batch_splits: int, defaults to None. The number of mini batches to
split into from each data batch. If None, batches are not split into
sub-batches.
name: string, defaults to None. The name of the SAM model.
Reference:
[Pierre Foret et al., 2020](https://arxiv.org/abs/2010.01412)
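
    Example (a minimal illustrative sketch, not part of the original
    docstring; the base model, optimizer, loss, and training data below are
    placeholder assumptions):

    ```python
    base_model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    sam_model = SharpnessAwareMinimization(base_model, rho=0.05)
    sam_model.compile(optimizer="adam", loss="mse")
    sam_model.fit(x_train, y_train)  # x_train / y_train: user-supplied data
    ```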
"""

    def __init__(self, model, rho=0.05, num_batch_splits=None, name=None):
super().__init__(name=name)
self.model = model
self.rho = rho
self.num_batch_splits = num_batch_splits

    def train_step(self, data):
        """The logic of one SAM training step.

        Args:
            data: A nested structure of `Tensor`s. It should be of structure
                (x, y, sample_weight) or (x, y).

        Returns:
            A dict mapping metric names to running average values.
        """
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
if self.num_batch_splits is not None:
x_split = tf.split(x, self.num_batch_splits)
y_split = tf.split(y, self.num_batch_splits)
else:
x_split = [x]
y_split = [y]
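        # Gradients and predictions are accumulated per sub-batch so that a
        # single optimizer step and metric update cover the whole batch.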
gradients_all_batches = []
pred_all_batches = []
for x_batch, y_batch in zip(x_split, y_split):
epsilon_w_cache = []
with tf.GradientTape() as tape:
pred = self.model(x_batch)
loss = self.compiled_loss(y_batch, pred)
pred_all_batches.append(pred)
trainable_variables = self.model.trainable_variables
gradients = tape.gradient(loss, trainable_variables)
gradients_order2_norm = self._gradients_order2_norm(gradients)
scale = self.rho / (gradients_order2_norm + 1e-12)
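            # Perturb each trainable variable in its gradient direction,
            # scaled so the overall perturbation has an L2 norm of ~`rho`.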
for gradient, variable in zip(gradients, trainable_variables):
epsilon_w = gradient * scale
self._distributed_apply_epsilon_w(
variable, epsilon_w, tf.distribute.get_strategy()
)
epsilon_w_cache.append(epsilon_w)
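            # Second pass: gradients evaluated at the perturbed weights
            # `w + epsilon_w` are the sharpness-aware gradients that are
            # applied to the restored weights below.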
with tf.GradientTape() as tape:
pred = self(x_batch)
loss = self.compiled_loss(y_batch, pred)
gradients = tape.gradient(loss, trainable_variables)
if len(gradients_all_batches) == 0:
for gradient in gradients:
gradients_all_batches.append([gradient])
else:
for gradient, gradient_all_batches in zip(
gradients, gradients_all_batches
):
gradient_all_batches.append(gradient)
for variable, epsilon_w in zip(
trainable_variables, epsilon_w_cache
):
# Restore the variable to its original value before
# `apply_gradients()`.
self._distributed_apply_epsilon_w(
variable, -epsilon_w, tf.distribute.get_strategy()
)
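        # Sum the sharpness-aware gradients contributed by each sub-batch
        # before applying a single optimizer update.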
gradients = []
for gradient_all_batches in gradients_all_batches:
gradients.append(tf.reduce_sum(gradient_all_batches, axis=0))
self.optimizer.apply_gradients(zip(gradients, trainable_variables))
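        # Metrics are updated with the predictions from the first
        # (unperturbed) forward pass, concatenated across sub-batches.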
pred = tf.concat(pred_all_batches, axis=0)
self.compiled_metrics.update_state(y, pred, sample_weight)
return {m.name: m.result() for m in self.metrics}

    def call(self, inputs):
        """Forward pass of SAM.

        SAM delegates the forward pass call to the wrapped model.

        Args:
            inputs: Tensor. The model inputs.

        Returns:
            A Tensor, the outputs of the wrapped model for given `inputs`.
        """
return self.model(inputs)

    def get_config(self):
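        # Only the wrapped `model` and `rho` are added on top of the base
        # config; `num_batch_splits` is not added here, so it may not survive
        # a `get_config()` / `from_config()` round trip.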
config = super().get_config()
config.update(
{
"model": serialize_keras_object(self.model),
"rho": self.rho,
}
)
return config

    @classmethod
def from_config(cls, config, custom_objects=None):
# Avoid mutating the input dict.
config = copy.deepcopy(config)
model = deserialize_layer(
config.pop("model"), custom_objects=custom_objects
)
config["model"] = model
return super().from_config(config, custom_objects)

    def _distributed_apply_epsilon_w(self, var, epsilon_w, strategy):
# Helper function to apply epsilon_w on model variables.
if isinstance(
tf.distribute.get_strategy(),
(
tf.distribute.experimental.ParameterServerStrategy,
tf.distribute.experimental.CentralStorageStrategy,
),
):
# Under PSS and CSS, the AggregatingVariable has to be kept in sync.
def distribute_apply(strategy, var, epsilon_w):
strategy.extended.update(
var,
lambda x, y: x.assign_add(y),
args=(epsilon_w,),
group=False,
)

            tf.__internal__.distribute.interim.maybe_merge_call(
distribute_apply, tf.distribute.get_strategy(), var, epsilon_w
)
else:
var.assign_add(epsilon_w)

    def _gradients_order2_norm(self, gradients):
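        # Global L2 norm of the gradient: sqrt(sum_i ||g_i||_2^2), computed
        # over all non-None per-variable gradients.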
norm = tf.norm(
tf.stack([tf.norm(grad) for grad in gradients if grad is not None])
)
return norm