# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Built-in optimizer classes.

For more examples see the base class `tf.keras.optimizers.Optimizer`.
"""

# Imports needed for deserialization.

import platform

import tensorflow.compat.v2 as tf
from absl import logging

from keras import backend
from keras.optimizers import adadelta
from keras.optimizers import adafactor
from keras.optimizers import adagrad
from keras.optimizers import adam
from keras.optimizers import adamax
from keras.optimizers import adamw
from keras.optimizers import ftrl
from keras.optimizers import nadam
from keras.optimizers import optimizer as base_optimizer
from keras.optimizers import rmsprop
from keras.optimizers import sgd
from keras.optimizers.legacy import adadelta as adadelta_legacy
from keras.optimizers.legacy import adagrad as adagrad_legacy
from keras.optimizers.legacy import adam as adam_legacy
from keras.optimizers.legacy import adamax as adamax_legacy
from keras.optimizers.legacy import ftrl as ftrl_legacy
from keras.optimizers.legacy import gradient_descent as gradient_descent_legacy
from keras.optimizers.legacy import nadam as nadam_legacy
from keras.optimizers.legacy import optimizer_v2 as base_optimizer_legacy
from keras.optimizers.legacy import rmsprop as rmsprop_legacy
from keras.optimizers.legacy.adadelta import Adadelta
from keras.optimizers.legacy.adagrad import Adagrad
from keras.optimizers.legacy.adam import Adam
from keras.optimizers.legacy.adamax import Adamax
from keras.optimizers.legacy.ftrl import Ftrl

# Symbols to be accessed under keras.optimizers. To be replaced with
# optimizers v2022 when they graduate out of experimental.
from keras.optimizers.legacy.gradient_descent import SGD
from keras.optimizers.legacy.nadam import Nadam
from keras.optimizers.legacy.rmsprop import RMSprop
from keras.optimizers.optimizer_v1 import Optimizer
from keras.optimizers.optimizer_v1 import TFOptimizer
from keras.optimizers.schedules import learning_rate_schedule
from keras.saving.legacy import serialization as legacy_serialization
from keras.saving.legacy.serialization import deserialize_keras_object
from keras.saving.legacy.serialization import serialize_keras_object

# isort: off
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.optimizers.serialize")
def serialize(optimizer, use_legacy_format=False):
    """Serialize the optimizer configuration to a JSON-compatible Python dict.

    The returned configuration can be persisted and later used to reconstruct
    the `Optimizer` instance.

    >>> tf.keras.optimizers.serialize(tf.keras.optimizers.legacy.SGD())
    {'class_name': 'SGD', 'config': {'name': 'SGD', 'learning_rate': 0.01,
    'decay': 0.0, 'momentum': 0.0, 'nesterov': False}}

    Args:
        optimizer: An `Optimizer` instance to serialize.
        use_legacy_format: Boolean, whether to use the legacy serialization
            format. Defaults to `False`.

    Returns:
        Python dict which contains the configuration of the input optimizer.
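
    A config produced this way can be passed back to
    `tf.keras.optimizers.deserialize` to rebuild an equivalent optimizer.
    Illustrative round trip (a sketch; the exact config contents depend on
    the optimizer and Keras version):

    >>> config = tf.keras.optimizers.serialize(
    ...     tf.keras.optimizers.legacy.SGD())
    >>> new_optimizer = tf.keras.optimizers.deserialize(config)
    >>> new_optimizer.get_config()["learning_rate"]
    0.01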
""" if use_legacy_format: return legacy_serialization.serialize_keras_object(optimizer) return serialize_keras_object(optimizer) def is_arm_mac(): return platform.system() == "Darwin" and platform.processor() == "arm" @keras_export("keras.optimizers.deserialize") def deserialize(config, custom_objects=None, use_legacy_format=False, **kwargs): """Inverse of the `serialize` function. Args: config: Optimizer configuration dictionary. custom_objects: Optional dictionary mapping names (strings) to custom objects (classes and functions) to be considered during deserialization. Returns: A Keras Optimizer instance. """ # loss_scale_optimizer has a direct dependency of optimizer, import here # rather than top to avoid the cyclic dependency. from keras.mixed_precision import ( loss_scale_optimizer, ) use_legacy_optimizer = kwargs.pop("use_legacy_optimizer", False) if kwargs: raise TypeError(f"Invalid keyword arguments: {kwargs}") if len(config["config"]) > 0: # If the optimizer config is not empty, then we use the value of # `is_legacy_optimizer` to override `use_legacy_optimizer`. If # `is_legacy_optimizer` does not exist in config, it means we are # using the legacy optimzier. use_legacy_optimizer = config["config"].get("is_legacy_optimizer", True) if ( tf.__internal__.tf2.enabled() and tf.executing_eagerly() and not is_arm_mac() and not use_legacy_optimizer ): # We observed a slowdown of optimizer on M1 Mac, so we fall back to the # legacy optimizer for M1 users now, see b/263339144 for more context. all_classes = { "adadelta": adadelta.Adadelta, "adagrad": adagrad.Adagrad, "adam": adam.Adam, "adamax": adamax.Adamax, "experimentaladadelta": adadelta.Adadelta, "experimentaladagrad": adagrad.Adagrad, "experimentaladam": adam.Adam, "experimentalsgd": sgd.SGD, "nadam": nadam.Nadam, "rmsprop": rmsprop.RMSprop, "sgd": sgd.SGD, "ftrl": ftrl.Ftrl, "lossscaleoptimizer": loss_scale_optimizer.LossScaleOptimizerV3, "lossscaleoptimizerv3": loss_scale_optimizer.LossScaleOptimizerV3, # LossScaleOptimizerV1 was an old version of LSO that was removed. # Deserializing it turns it into a LossScaleOptimizer "lossscaleoptimizerv1": loss_scale_optimizer.LossScaleOptimizer, } else: all_classes = { "adadelta": adadelta_legacy.Adadelta, "adagrad": adagrad_legacy.Adagrad, "adam": adam_legacy.Adam, "adamax": adamax_legacy.Adamax, "experimentaladadelta": adadelta.Adadelta, "experimentaladagrad": adagrad.Adagrad, "experimentaladam": adam.Adam, "experimentalsgd": sgd.SGD, "nadam": nadam_legacy.Nadam, "rmsprop": rmsprop_legacy.RMSprop, "sgd": gradient_descent_legacy.SGD, "ftrl": ftrl_legacy.Ftrl, "lossscaleoptimizer": loss_scale_optimizer.LossScaleOptimizer, "lossscaleoptimizerv3": loss_scale_optimizer.LossScaleOptimizerV3, # LossScaleOptimizerV1 was an old version of LSO that was removed. # Deserializing it turns it into a LossScaleOptimizer "lossscaleoptimizerv1": loss_scale_optimizer.LossScaleOptimizer, } # Make deserialization case-insensitive for built-in optimizers. 
    if config["class_name"].lower() in all_classes:
        config["class_name"] = config["class_name"].lower()

    if use_legacy_format:
        return legacy_serialization.deserialize_keras_object(
            config,
            module_objects=all_classes,
            custom_objects=custom_objects,
            printable_module_name="optimizer",
        )

    return deserialize_keras_object(
        config,
        module_objects=all_classes,
        custom_objects=custom_objects,
        printable_module_name="optimizer",
    )


@keras_export(
    "keras.__internal__.optimizers.convert_to_legacy_optimizer", v1=[]
)
def convert_to_legacy_optimizer(optimizer):
    """Convert an experimental optimizer to its legacy counterpart.

    This function takes in a `tf.keras.optimizers.experimental.Optimizer`
    instance and converts it to the corresponding
    `tf.keras.optimizers.legacy.Optimizer` instance. For example,
    `tf.keras.optimizers.experimental.Adam(...)` is converted to
    `tf.keras.optimizers.legacy.Adam(...)`.

    Args:
        optimizer: An instance of `tf.keras.optimizers.experimental.Optimizer`.

    Returns:
        An instance of the corresponding `tf.keras.optimizers.legacy.Optimizer`
        class.
    """
    # loss_scale_optimizer has a direct dependency on optimizer; import it
    # here rather than at the top to avoid a cyclic dependency.
    from keras.mixed_precision import (
        loss_scale_optimizer,
    )

    if not isinstance(optimizer, base_optimizer.Optimizer):
        raise ValueError(
            "`convert_to_legacy_optimizer` should only be called "
            "on instances of `tf.keras.optimizers.Optimizer`, but "
            f"received {optimizer} of type {type(optimizer)}."
        )
    optimizer_name = optimizer.__class__.__name__.lower()
    config = optimizer.get_config()
    # Remove fields that only exist in the experimental optimizer.
    keys_to_remove = [
        "weight_decay",
        "use_ema",
        "ema_momentum",
        "ema_overwrite_frequency",
        "jit_compile",
        "is_legacy_optimizer",
    ]
    for key in keys_to_remove:
        config.pop(key, None)

    if isinstance(optimizer, loss_scale_optimizer.LossScaleOptimizerV3):
        # For LossScaleOptimizers, recursively convert the inner optimizer.
        config["inner_optimizer"] = convert_to_legacy_optimizer(
            optimizer.inner_optimizer
        )
        if optimizer_name == "lossscaleoptimizerv3":
            optimizer_name = "lossscaleoptimizer"

    # The learning rate can be a custom `LearningRateSchedule`, which is
    # stored as a dict in the config and cannot be deserialized, so we pass
    # the schedule object through directly.
    if hasattr(optimizer, "_learning_rate") and isinstance(
        optimizer._learning_rate, learning_rate_schedule.LearningRateSchedule
    ):
        config["learning_rate"] = optimizer._learning_rate
    legacy_optimizer_config = {
        "class_name": optimizer_name,
        "config": config,
    }
    return deserialize(legacy_optimizer_config, use_legacy_optimizer=True)


@keras_export("keras.optimizers.get")
def get(identifier, **kwargs):
    """Retrieves a Keras Optimizer instance.

    Args:
        identifier: Optimizer identifier, one of:
            - String: name of an optimizer
            - Dictionary: configuration dictionary.
            - Keras Optimizer instance (it will be returned unchanged).
            - TensorFlow Optimizer instance (it will be wrapped as a Keras
              Optimizer).

    Returns:
        A Keras Optimizer instance.

    Raises:
        ValueError: If `identifier` cannot be interpreted.
    """
    use_legacy_optimizer = kwargs.pop("use_legacy_optimizer", False)
    if kwargs:
        raise TypeError(f"Invalid keyword arguments: {kwargs}")
    if isinstance(
        identifier,
        (
            Optimizer,
            base_optimizer_legacy.OptimizerV2,
        ),
    ):
        return identifier
    elif isinstance(identifier, base_optimizer.Optimizer):
        if tf.__internal__.tf2.enabled() and not is_arm_mac():
            return identifier
        else:
            # If TF2 is disabled or we are on an M1 Mac, we convert to the
            # legacy optimizer. We observed a slowdown of the new optimizer
            # on M1 Macs, so we fall back to the legacy optimizer for now;
            # see b/263339144 for more context.
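            # `convert_to_legacy_optimizer` (defined above) rebuilds the
            # optimizer as its `tf.keras.optimizers.legacy.*` counterpart
            # from its config.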
            optimizer_name = identifier.__class__.__name__
            logging.warning(
                "There is a known slowdown when using v2.11+ Keras optimizers "
                "on M1/M2 Macs. Falling back to the "
                "legacy Keras optimizer, i.e., "
                f"`tf.keras.optimizers.legacy.{optimizer_name}`."
            )
            return convert_to_legacy_optimizer(identifier)
    # Wrap legacy TF optimizer instances.
    elif isinstance(identifier, tf.compat.v1.train.Optimizer):
        opt = TFOptimizer(identifier)
        backend.track_tf_optimizer(opt)
        return opt
    elif isinstance(identifier, dict):
        use_legacy_format = "module" not in identifier
        return deserialize(
            identifier,
            use_legacy_optimizer=use_legacy_optimizer,
            use_legacy_format=use_legacy_format,
        )
    elif isinstance(identifier, str):
        config = {"class_name": str(identifier), "config": {}}
        return deserialize(
            config,
            use_legacy_optimizer=use_legacy_optimizer,
        )
    else:
        raise ValueError(
            f"Could not interpret optimizer identifier: {identifier}"
        )
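

# Usage sketch for `keras.optimizers.get` (illustrative only; the accepted
# identifier forms are described in the docstring above):
#
#     tf.keras.optimizers.get("rmsprop")  # by name
#     tf.keras.optimizers.get({"class_name": "Adam", "config": {}})  # by dict
#     tf.keras.optimizers.get(tf.keras.optimizers.Adam())  # instance
#
# Instances are returned unchanged, except that new-style optimizers are
# converted to their legacy counterparts when TF2 is disabled or when running
# on M1/M2 Macs.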