# Intelegentny_Pszczelarz/.venv/Lib/site-packages/keras/saving/legacy/saving_utils.py

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils related to keras model saving."""

import copy
import os

import tensorflow.compat.v2 as tf

import keras
from keras import backend
from keras import losses
from keras import optimizers
from keras.engine import base_layer_utils
from keras.optimizers import optimizer_v1
from keras.saving.legacy import serialization
from keras.utils import version_utils
from keras.utils.io_utils import ask_to_proceed_with_overwrite

# isort: off
from tensorflow.python.platform import tf_logging as logging


def extract_model_metrics(model):
    """Convert metrics from a Keras model `compile` API to a dictionary.

    This is used for converting Keras models to Estimators and SavedModels.

    Args:
        model: A `tf.keras.Model` object.

    Returns:
        Dictionary mapping metric names to metric instances. May return
        `None` if the model does not contain any metrics.
    """
    if getattr(model, "_compile_metrics", None):
        # TODO(psv/kathywu): use this implementation in model to estimator
        # flow. We are not using model.metrics here because we want to
        # exclude the metrics added using the `add_metric` API.
        return {m.name: m for m in model._compile_metric_functions}
    return None
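

# Usage sketch for `extract_model_metrics` (illustrative, not part of the
# original module; the layer shapes are arbitrary assumptions, and whether
# `_compile_metrics` is populated depends on the Keras version):
#
#     model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
#     model.compile(optimizer="adam", loss="mse", metrics=["mae"])
#     extract_model_metrics(model)  # {"mae": <metric instance>} or None

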
def model_call_inputs(model, keep_original_batch_size=False):
    """Inspect model to get its input signature.

    The model's input signature is a list with a single (possibly-nested)
    object. This is due to the Keras-enforced restriction that tensor inputs
    must be passed in as the first argument.

    For example, a model with input {'feature1': <Tensor>, 'feature2':
    <Tensor>} will have input signature:
    [{'feature1': TensorSpec, 'feature2': TensorSpec}]

    Args:
        model: Keras Model object.
        keep_original_batch_size: A boolean indicating whether we want to
            keep using the original batch size or set it to None. Default is
            `False`, which means that the batch dim of the returned input
            signature will always be set to `None`.

    Returns:
        A tuple containing `(args, kwargs)` TensorSpecs of the model call
        function inputs. `kwargs` does not contain the `training` argument.
    """
    input_specs = model.save_spec(dynamic_batch=not keep_original_batch_size)
    if input_specs is None:
        return None, None
    input_specs = _enforce_names_consistency(input_specs)
    return input_specs
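

# Usage sketch for `model_call_inputs` (illustrative, not part of the
# original module): for a model built on float32 inputs of shape (None, 4),
# the call returns roughly
#
#     args, kwargs = model_call_inputs(model)
#     # args   -> [TensorSpec(shape=(None, 4), dtype=tf.float32, ...)]
#     # kwargs -> {}
#
# with the batch dimension forced to None unless
# `keep_original_batch_size=True` is passed.

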
def raise_model_input_error(model):
    if isinstance(model, keras.models.Sequential):
        raise ValueError(
            f"Model {model} cannot be saved because the input shape is not "
            "available. Please specify an input shape either by calling "
            "`build(input_shape)` directly, or by calling the model on "
            "actual data using `Model()`, `Model.fit()`, or "
            "`Model.predict()`."
        )

    # If the model is not a `Sequential`, it is intended to be a subclassed
    # model.
    raise ValueError(
        f"Model {model} cannot be saved either because the input shape is "
        "not available or because the forward pass of the model is not "
        "defined. To define a forward pass, please override `Model.call()`. "
        "To specify an input shape, either call `build(input_shape)` "
        "directly, or call the model on actual data using `Model()`, "
        "`Model.fit()`, or `Model.predict()`. If you have a custom training "
        "step, please make sure to invoke the forward pass in train step "
        "through `Model.__call__`, i.e. `model(inputs)`, as opposed to "
        "`model.call()`."
    )


def trace_model_call(model, input_signature=None):
    """Trace the model call to create a tf.function for exporting a Keras
    model.

    Args:
        model: A Keras model.
        input_signature: optional, a list of tf.TensorSpec objects specifying
            the inputs to the model.

    Returns:
        A tf.function wrapping the model's call function with input
        signatures set.

    Raises:
        ValueError: if input signature cannot be inferred from the model.
    """
    if input_signature is None:
        if isinstance(model.call, tf.__internal__.function.Function):
            input_signature = model.call.input_signature

    if input_signature:
        model_args = input_signature
        model_kwargs = {}
    else:
        model_args, model_kwargs = model_call_inputs(model)
        if model_args is None:
            raise_model_input_error(model)

    @tf.function
    def _wrapped_model(*args, **kwargs):
        """A concrete tf.function that wraps the model's call function."""
        args, kwargs = model._call_spec.set_arg_value(
            "training", False, args, kwargs, inputs_in_args=True
        )

        with base_layer_utils.call_context().enter(
            model, inputs=None, build_graph=False, training=False, saving=True
        ):
            outputs = model(*args, **kwargs)

        # The output always has to be a flat dict.
        output_names = model.output_names  # Functional Model.
        if output_names is None:  # Subclassed Model.
            from keras.engine import compile_utils

            output_names = compile_utils.create_pseudo_output_names(outputs)
        outputs = tf.nest.flatten(outputs)
        return {name: output for name, output in zip(output_names, outputs)}

    return _wrapped_model.get_concrete_function(*model_args, **model_kwargs)
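

# Usage sketch for `trace_model_call` (illustrative, not part of the
# original module; the export path is hypothetical): the returned concrete
# function is suitable as a serving signature.
#
#     concrete_fn = trace_model_call(model)
#     tf.saved_model.save(model, "/tmp/exported_model", signatures=concrete_fn)

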
def model_metadata(model, include_optimizer=True, require_config=True):
    """Returns a dictionary containing the model metadata."""
    from keras import __version__ as keras_version
    from keras.optimizers.legacy import optimizer_v2

    model_config = {"class_name": model.__class__.__name__}
    try:
        model_config["config"] = model.get_config()
    except NotImplementedError as e:
        if require_config:
            raise e

    metadata = dict(
        keras_version=str(keras_version),
        backend=backend.backend(),
        model_config=model_config,
    )
    if model.optimizer and include_optimizer:
        if isinstance(model.optimizer, optimizer_v1.TFOptimizer):
            logging.warning(
                "TensorFlow optimizers do not "
                "make it possible to access "
                "optimizer attributes or optimizer state "
                "after instantiation. "
                "As a result, we cannot save the optimizer "
                "as part of the model save file. "
                "You will have to compile your model again after loading it. "
                "Prefer using a Keras optimizer instead "
                "(see keras.io/optimizers)."
            )
        elif model._compile_was_called:
            training_config = model._get_compile_args(user_metrics=False)
            training_config.pop("optimizer", None)  # Handled separately.
            metadata["training_config"] = _serialize_nested_config(
                training_config
            )
            if isinstance(model.optimizer, optimizer_v2.RestoredOptimizer):
                raise NotImplementedError(
                    "Optimizers loaded from a SavedModel cannot be saved. "
                    "If you are calling `model.save` or "
                    "`tf.keras.models.save_model`, "
                    "please set the `include_optimizer` option to `False`. "
                    "For `tf.saved_model.save`, "
                    "delete the optimizer from the model."
                )
            else:
                optimizer_config = {
                    "class_name": keras.utils.get_registered_name(
                        model.optimizer.__class__
                    ),
                    "config": model.optimizer.get_config(),
                }
                metadata["training_config"][
                    "optimizer_config"
                ] = optimizer_config
    return metadata
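

# Illustrative shape of the returned metadata (not part of the original
# module; values are examples only):
#
#     {
#         "keras_version": "2.12.0",
#         "backend": "tensorflow",
#         "model_config": {"class_name": "Sequential", "config": {...}},
#         "training_config": {          # only present for compiled models
#             ...,
#             "optimizer_config": {"class_name": "Adam", "config": {...}},
#         },
#     }

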
def should_overwrite(filepath, overwrite):
    """Returns whether the filepath should be overwritten."""
    # If file exists and should not be overwritten.
    if not overwrite and os.path.isfile(filepath):
        return ask_to_proceed_with_overwrite(filepath)
    return True
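

# Usage sketch for `should_overwrite` (illustrative, not part of the
# original module; `do_save` is a hypothetical helper):
#
#     if should_overwrite("model.h5", overwrite=False):
#         do_save("model.h5")  # runs immediately if the file is absent,
#                              # otherwise only after the user confirms

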
def compile_args_from_training_config(training_config, custom_objects=None):
    """Return model.compile arguments from training config."""
    if custom_objects is None:
        custom_objects = {}

    with keras.utils.CustomObjectScope(custom_objects):
        optimizer_config = training_config["optimizer_config"]
        optimizer = optimizers.deserialize(optimizer_config)

        # Recover losses.
        loss = None
        loss_config = training_config.get("loss", None)
        if loss_config is not None:
            loss = _deserialize_nested_config(losses.deserialize, loss_config)

        # Recover metrics.
        metrics = None
        metrics_config = training_config.get("metrics", None)
        if metrics_config is not None:
            metrics = _deserialize_nested_config(
                _deserialize_metric, metrics_config
            )

        # Recover weighted metrics.
        weighted_metrics = None
        weighted_metrics_config = training_config.get("weighted_metrics", None)
        if weighted_metrics_config is not None:
            weighted_metrics = _deserialize_nested_config(
                _deserialize_metric, weighted_metrics_config
            )

        # `training_config` is a plain dict, so key presence must be checked
        # with `.get()`/`in`; `hasattr` on a dict always returns False.
        sample_weight_mode = training_config.get("sample_weight_mode", None)
        loss_weights = training_config["loss_weights"]

    return dict(
        optimizer=optimizer,
        loss=loss,
        metrics=metrics,
        weighted_metrics=weighted_metrics,
        loss_weights=loss_weights,
        sample_weight_mode=sample_weight_mode,
    )
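

# Usage sketch for `compile_args_from_training_config` (illustrative, not
# part of the original module): this inverts the `training_config` entry
# produced by `model_metadata` above, so a loaded model can be recompiled.
#
#     compile_args = compile_args_from_training_config(
#         metadata["training_config"]  # assuming `metadata` from model_metadata
#     )
#     model.compile(**compile_args)

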
def _deserialize_nested_config(deserialize_fn, config):
    """Deserializes arbitrary Keras `config` using `deserialize_fn`."""

    def _is_single_object(obj):
        if isinstance(obj, dict) and "class_name" in obj:
            return True  # Serialized Keras object.
        if isinstance(obj, str):
            return True  # Serialized function or string.
        return False

    if config is None:
        return None
    if _is_single_object(config):
        return deserialize_fn(config)
    elif isinstance(config, dict):
        return {
            k: _deserialize_nested_config(deserialize_fn, v)
            for k, v in config.items()
        }
    elif isinstance(config, (tuple, list)):
        return [
            _deserialize_nested_config(deserialize_fn, obj) for obj in config
        ]

    raise ValueError(
        "Saved configuration not understood. Configuration should be a "
        f"dictionary, string, tuple or list. Received: config={config}."
    )
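

# Illustrative behaviour of `_deserialize_nested_config` (not part of the
# original module): the recursion mirrors the config structure, e.g.
#
#     _deserialize_nested_config(
#         losses.deserialize,
#         ["mse", {"class_name": "Huber", "config": {}}],
#     )
#     # -> [<mean_squared_error function>, <Huber instance>]

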
def _serialize_nested_config(config):
    """Serializes a nested structure of Keras objects."""

    def _serialize_fn(obj):
        if callable(obj):
            return serialization.serialize_keras_object(obj)
        return obj

    return tf.nest.map_structure(_serialize_fn, config)
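

# Illustrative behaviour of `_serialize_nested_config` (not part of the
# original module; the exact serialized form can vary between Keras
# versions):
#
#     _serialize_nested_config({"out": losses.mean_squared_error, "w": 1.0})
#     # -> {"out": "mean_squared_error", "w": 1.0}

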
def _deserialize_metric(metric_config):
    """Deserialize metrics, leaving special strings untouched."""
    from keras import metrics as metrics_module

    if metric_config in ["accuracy", "acc", "crossentropy", "ce"]:
        # Do not deserialize accuracy and cross-entropy strings as we have
        # special case handling for these in compile, based on model output
        # shape.
        return metric_config
    return metrics_module.deserialize(metric_config)
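

# Illustrative behaviour of `_deserialize_metric` (not part of the original
# module):
#
#     _deserialize_metric("accuracy")  # -> "accuracy" (left for compile())
#     _deserialize_metric({"class_name": "AUC", "config": {}})  # -> AUC()

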
def _enforce_names_consistency(specs):
    """Enforces that either all specs have names or none do."""

    def _has_name(spec):
        return spec is None or (
            hasattr(spec, "name") and spec.name is not None
        )

    def _clear_name(spec):
        spec = copy.deepcopy(spec)
        if hasattr(spec, "name"):
            spec._name = None
        return spec

    flat_specs = tf.nest.flatten(specs)
    name_inconsistency = any(_has_name(s) for s in flat_specs) and not all(
        _has_name(s) for s in flat_specs
    )

    if name_inconsistency:
        specs = tf.nest.map_structure(_clear_name, specs)
    return specs
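

# Illustrative behaviour of `_enforce_names_consistency` (not part of the
# original module): mixing named and unnamed specs strips the names from all
# of them.
#
#     specs = [
#         tf.TensorSpec((None, 4), tf.float32, name="x"),
#         tf.TensorSpec((None, 2), tf.float32),
#     ]
#     _enforce_names_consistency(specs)  # both specs come back with name=None

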
def try_build_compiled_arguments(model):
    if (
        not version_utils.is_v1_layer_or_model(model)
        and model.outputs is not None
    ):
        try:
            if not model.compiled_loss.built:
                model.compiled_loss.build(model.outputs)
            if not model.compiled_metrics.built:
                model.compiled_metrics.build(model.outputs, model.outputs)
        except:  # noqa: E722
            logging.warning(
                "Compiled the loaded model, but the compiled metrics have "
                "yet to be built. `model.compile_metrics` will be empty "
                "until you train or evaluate the model."
            )
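

# Usage sketch for `try_build_compiled_arguments` (illustrative, not part of
# the original module; the file path is hypothetical): typically called right
# after deserializing a compiled model so that `compiled_loss` and
# `compiled_metrics` are built eagerly rather than on the first train or
# evaluate step.
#
#     model = tf.keras.models.load_model("model.h5")
#     try_build_compiled_arguments(model)

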
def is_hdf5_filepath(filepath):
    return (
        filepath.endswith(".h5")
        or filepath.endswith(".keras")
        or filepath.endswith(".hdf5")
    )
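

# Illustrative behaviour of `is_hdf5_filepath` (not part of the original
# module):
#
#     is_hdf5_filepath("weights.h5")    # True
#     is_hdf5_filepath("model.keras")   # True (this legacy helper treats
#                                       # `.keras` as HDF5 as well)
#     is_hdf5_filepath("export_dir")    # False (SavedModel directory)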