Intelegentny_Pszczelarz/.venv/Lib/site-packages/keras/optimizers/__init__.py
2023-06-19 00:49:18 +02:00

321 lines
12 KiB
Python

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Built-in optimizer classes.
For more examples see the base class `tf.keras.optimizers.Optimizer`.
"""
# Imports needed for deserialization.
import platform
import tensorflow.compat.v2 as tf
from absl import logging
from keras import backend
from keras.optimizers import adadelta
from keras.optimizers import adafactor
from keras.optimizers import adagrad
from keras.optimizers import adam
from keras.optimizers import adamax
from keras.optimizers import adamw
from keras.optimizers import ftrl
from keras.optimizers import nadam
from keras.optimizers import optimizer as base_optimizer
from keras.optimizers import rmsprop
from keras.optimizers import sgd
from keras.optimizers.legacy import adadelta as adadelta_legacy
from keras.optimizers.legacy import adagrad as adagrad_legacy
from keras.optimizers.legacy import adam as adam_legacy
from keras.optimizers.legacy import adamax as adamax_legacy
from keras.optimizers.legacy import ftrl as ftrl_legacy
from keras.optimizers.legacy import gradient_descent as gradient_descent_legacy
from keras.optimizers.legacy import nadam as nadam_legacy
from keras.optimizers.legacy import optimizer_v2 as base_optimizer_legacy
from keras.optimizers.legacy import rmsprop as rmsprop_legacy
from keras.optimizers.legacy.adadelta import Adadelta
from keras.optimizers.legacy.adagrad import Adagrad
from keras.optimizers.legacy.adam import Adam
from keras.optimizers.legacy.adamax import Adamax
from keras.optimizers.legacy.ftrl import Ftrl
# Symbols to be accessed under keras.optimizers. To be replaced with
# optimizers v2022 when they graduate out of experimental.
from keras.optimizers.legacy.gradient_descent import SGD
from keras.optimizers.legacy.nadam import Nadam
from keras.optimizers.legacy.rmsprop import RMSprop
from keras.optimizers.optimizer_v1 import Optimizer
from keras.optimizers.optimizer_v1 import TFOptimizer
from keras.optimizers.schedules import learning_rate_schedule
from keras.saving.legacy import serialization as legacy_serialization
from keras.saving.legacy.serialization import deserialize_keras_object
from keras.saving.legacy.serialization import serialize_keras_object
# isort: off
from tensorflow.python.util.tf_export import keras_export
@keras_export("keras.optimizers.serialize")
def serialize(optimizer, use_legacy_format=False):
    """Turn an optimizer into a JSON-compatible Python dict.

    The resulting configuration can be persisted and later used to rebuild
    the `Optimizer` instance.

    >>> tf.keras.optimizers.serialize(tf.keras.optimizers.legacy.SGD())
    {'class_name': 'SGD', 'config': {'name': 'SGD', 'learning_rate': 0.01,
                                     'decay': 0.0, 'momentum': 0.0,
                                     'nesterov': False}}

    Args:
      optimizer: An `Optimizer` instance to serialize.
      use_legacy_format: When truthy, serialize through
        `legacy_serialization.serialize_keras_object`; otherwise through the
        `serialize_keras_object` imported at module level.

    Returns:
      Python dict which contains the configuration of the input optimizer.
    """
    # NOTE(review): in this file both names are imported from
    # `keras.saving.legacy.serialization`, so the flag currently selects the
    # same function either way -- confirm whether that is intended.
    serializer = (
        legacy_serialization.serialize_keras_object
        if use_legacy_format
        else serialize_keras_object
    )
    return serializer(optimizer)
def is_arm_mac():
    """Return True when running on macOS with an "arm" processor string.

    Returns:
      bool: True only if `platform.system()` is `"Darwin"` and
      `platform.processor()` is `"arm"`.
    """
    # Guard clause: the processor is only inspected on macOS hosts.
    if platform.system() != "Darwin":
        return False
    return platform.processor() == "arm"
@keras_export("keras.optimizers.deserialize")
def deserialize(config, custom_objects=None, use_legacy_format=False, **kwargs):
    """Inverse of the `serialize` function.

    Args:
      config: Optimizer configuration dictionary. It must provide a
        `"class_name"` entry and a `"config"` entry. The dictionary passed
        in is not modified by this function.
      custom_objects: Optional dictionary mapping names (strings) to custom
        objects (classes and functions) to be considered during
        deserialization.
      use_legacy_format: When truthy, deserialize through
        `legacy_serialization.deserialize_keras_object`; otherwise through
        the `deserialize_keras_object` imported at module level.
      **kwargs: Only `use_legacy_optimizer` (default `False`) is accepted.
        It requests the legacy optimizer classes for built-in names. When
        `config["config"]` is non-empty, its `"is_legacy_optimizer"` entry
        (treated as `True` when absent) overrides this value.

    Returns:
      A Keras Optimizer instance.

    Raises:
      TypeError: If any keyword argument other than `use_legacy_optimizer`
        is passed.
    """
    # loss_scale_optimizer has a direct dependency of optimizer, import here
    # rather than top to avoid the cyclic dependency.
    from keras.mixed_precision import (
        loss_scale_optimizer,
    )

    use_legacy_optimizer = kwargs.pop("use_legacy_optimizer", False)
    if kwargs:
        raise TypeError(f"Invalid keyword arguments: {kwargs}")

    if len(config["config"]) > 0:
        # If the optimizer config is not empty, then we use the value of
        # `is_legacy_optimizer` to override `use_legacy_optimizer`. If
        # `is_legacy_optimizer` does not exist in config, it means we are
        # using the legacy optimizer.
        use_legacy_optimizer = config["config"].get("is_legacy_optimizer", True)

    # We observed a slowdown of optimizer on M1 Mac, so we fall back to the
    # legacy optimizer for M1 users now, see b/263339144 for more context.
    if (
        tf.__internal__.tf2.enabled()
        and tf.executing_eagerly()
        and not is_arm_mac()
        and not use_legacy_optimizer
    ):
        all_classes = {
            "adadelta": adadelta.Adadelta,
            "adagrad": adagrad.Adagrad,
            "adam": adam.Adam,
            "adamax": adamax.Adamax,
            "experimentaladadelta": adadelta.Adadelta,
            "experimentaladagrad": adagrad.Adagrad,
            "experimentaladam": adam.Adam,
            "experimentalsgd": sgd.SGD,
            "nadam": nadam.Nadam,
            "rmsprop": rmsprop.RMSprop,
            "sgd": sgd.SGD,
            "ftrl": ftrl.Ftrl,
            "lossscaleoptimizer": loss_scale_optimizer.LossScaleOptimizerV3,
            "lossscaleoptimizerv3": loss_scale_optimizer.LossScaleOptimizerV3,
            # LossScaleOptimizerV1 was an old version of LSO that was removed.
            # Deserializing it turns it into a LossScaleOptimizer
            "lossscaleoptimizerv1": loss_scale_optimizer.LossScaleOptimizer,
        }
    else:
        all_classes = {
            "adadelta": adadelta_legacy.Adadelta,
            "adagrad": adagrad_legacy.Adagrad,
            "adam": adam_legacy.Adam,
            "adamax": adamax_legacy.Adamax,
            "experimentaladadelta": adadelta.Adadelta,
            "experimentaladagrad": adagrad.Adagrad,
            "experimentaladam": adam.Adam,
            "experimentalsgd": sgd.SGD,
            "nadam": nadam_legacy.Nadam,
            "rmsprop": rmsprop_legacy.RMSprop,
            "sgd": gradient_descent_legacy.SGD,
            "ftrl": ftrl_legacy.Ftrl,
            "lossscaleoptimizer": loss_scale_optimizer.LossScaleOptimizer,
            "lossscaleoptimizerv3": loss_scale_optimizer.LossScaleOptimizerV3,
            # LossScaleOptimizerV1 was an old version of LSO that was removed.
            # Deserializing it turns it into a LossScaleOptimizer
            "lossscaleoptimizerv1": loss_scale_optimizer.LossScaleOptimizer,
        }

    # Make deserialization case-insensitive for built-in optimizers. The
    # normalized name goes into a shallow copy so that the dictionary owned
    # by the caller is never written to.
    lowered_class_name = config["class_name"].lower()
    if lowered_class_name in all_classes:
        config = {**config, "class_name": lowered_class_name}

    if use_legacy_format:
        return legacy_serialization.deserialize_keras_object(
            config,
            module_objects=all_classes,
            custom_objects=custom_objects,
            printable_module_name="optimizer",
        )

    return deserialize_keras_object(
        config,
        module_objects=all_classes,
        custom_objects=custom_objects,
        printable_module_name="optimizer",
    )
@keras_export(
    "keras.__internal__.optimizers.convert_to_legacy_optimizer", v1=[]
)
def convert_to_legacy_optimizer(optimizer):
    """Build the legacy counterpart of an experimental optimizer.

    Given a `tf.keras.optimizers.experimental.Optimizer` instance, this
    produces the matching `tf.keras.optimizers.legacy.Optimizer` instance,
    e.g. `tf.keras.optimizers.experimental.Adam(...)` becomes
    `tf.keras.optimizers.legacy.Adam(...)`.

    Args:
      optimizer: An instance of `tf.keras.optimizers.experimental.Optimizer`.

    Returns:
      The object produced by `deserialize` (called with
      `use_legacy_optimizer=True`) from the lower-cased class name of
      `optimizer` and its config with the experimental-only fields dropped.

    Raises:
      ValueError: If `optimizer` is not an instance of
        `base_optimizer.Optimizer`.
    """
    # loss_scale_optimizer has a direct dependency of optimizer, import here
    # rather than top to avoid the cyclic dependency.
    from keras.mixed_precision import (
        loss_scale_optimizer,
    )

    if not isinstance(optimizer, base_optimizer.Optimizer):
        raise ValueError(
            "`convert_to_legacy_optimizer` should only be called "
            "on instances of `tf.keras.optimizers.Optimizer`, but "
            f"received {optimizer} of type {type(optimizer)}."
        )

    legacy_name = type(optimizer).__name__.lower()
    legacy_config = optimizer.get_config()

    # Drop the fields that only the experimental optimizer knows about.
    for experimental_only_field in (
        "weight_decay",
        "use_ema",
        "ema_momentum",
        "ema_overwrite_frequency",
        "jit_compile",
        "is_legacy_optimizer",
    ):
        legacy_config.pop(experimental_only_field, None)

    if isinstance(optimizer, loss_scale_optimizer.LossScaleOptimizerV3):
        # A loss-scale wrapper needs its inner optimizer converted as well.
        legacy_config["inner_optimizer"] = convert_to_legacy_optimizer(
            optimizer.inner_optimizer
        )
        if legacy_name == "lossscaleoptimizerv3":
            legacy_name = "lossscaleoptimizer"

    # A custom LearningRateSchedule is stored as a dict in the config and
    # cannot be deserialized, so hand over the schedule object itself.
    schedule = getattr(optimizer, "_learning_rate", None)
    if isinstance(schedule, learning_rate_schedule.LearningRateSchedule):
        legacy_config["learning_rate"] = schedule

    return deserialize(
        {"class_name": legacy_name, "config": legacy_config},
        use_legacy_optimizer=True,
    )
@keras_export("keras.optimizers.get")
def get(identifier, **kwargs):
    """Retrieves a Keras Optimizer instance.

    Args:
      identifier: Optimizer identifier, one of
        - String: name of an optimizer.
        - Dictionary: configuration dictionary.
        - Keras Optimizer instance. `optimizer_v1.Optimizer` and legacy
          `OptimizerV2` instances are returned unchanged. Instances of the
          new `base_optimizer.Optimizer` are returned unchanged when TF2 is
          enabled and the host is not an ARM Mac; otherwise they are
          converted with `convert_to_legacy_optimizer`.
        - TensorFlow Optimizer instance (it will be wrapped as a Keras
          Optimizer).
      **kwargs: Only `use_legacy_optimizer` (default `False`) is accepted.
        It is forwarded to `deserialize` for string and dictionary
        identifiers.

    Returns:
      A Keras Optimizer instance.

    Raises:
      TypeError: If any keyword argument other than `use_legacy_optimizer`
        is passed.
      ValueError: If `identifier` cannot be interpreted.
    """
    use_legacy_optimizer = kwargs.pop("use_legacy_optimizer", False)
    if kwargs:
        raise TypeError(f"Invalid keyword arguments: {kwargs}")

    if isinstance(
        identifier,
        (
            Optimizer,
            base_optimizer_legacy.OptimizerV2,
        ),
    ):
        return identifier
    elif isinstance(identifier, base_optimizer.Optimizer):
        on_arm_mac = is_arm_mac()
        if tf.__internal__.tf2.enabled() and not on_arm_mac:
            return identifier
        # If TF2 is disabled or on a M1 mac, we convert to the legacy
        # optimizer. We observed a slowdown of optimizer on M1 Mac, so we
        # fall back to the legacy optimizer for now, see b/263339144
        # for more context.
        optimizer_name = identifier.__class__.__name__
        fallback_notice = (
            "Falling back to the "
            "legacy Keras optimizer, i.e., "
            f"`tf.keras.optimizers.legacy.{optimizer_name}`."
        )
        if on_arm_mac:
            logging.warning(
                "There is a known slowdown when using v2.11+ Keras optimizers "
                "on M1/M2 Macs. " + fallback_notice
            )
        else:
            # This branch is only reachable with TF2 disabled; report the
            # real reason instead of the Mac-specific slowdown message.
            logging.warning("TF2 behavior is disabled. " + fallback_notice)
        return convert_to_legacy_optimizer(identifier)
    # Wrap legacy TF optimizer instances
    elif isinstance(identifier, tf.compat.v1.train.Optimizer):
        opt = TFOptimizer(identifier)
        backend.track_tf_optimizer(opt)
        return opt
    elif isinstance(identifier, dict):
        # A config without a "module" entry is routed through the legacy
        # deserialization path.
        use_legacy_format = "module" not in identifier
        return deserialize(
            identifier,
            use_legacy_optimizer=use_legacy_optimizer,
            use_legacy_format=use_legacy_format,
        )
    elif isinstance(identifier, str):
        config = {"class_name": str(identifier), "config": {}}
        return deserialize(
            config,
            use_legacy_optimizer=use_legacy_optimizer,
        )
    else:
        raise ValueError(
            f"Could not interpret optimizer identifier: {identifier}"
        )