# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils related to keras model saving."""

import copy
import os

import tensorflow.compat.v2 as tf

import keras
from keras import backend
from keras import losses
from keras import optimizers
from keras.engine import base_layer_utils
from keras.optimizers import optimizer_v1
from keras.saving.legacy import serialization
from keras.utils import version_utils
from keras.utils.io_utils import ask_to_proceed_with_overwrite

# isort: off
from tensorflow.python.platform import tf_logging as logging


def extract_model_metrics(model):
    """Convert metrics from a Keras model `compile` API to dictionary.

    This is used for converting Keras models to Estimators and SavedModels.

    Args:
      model: A `tf.keras.Model` object.

    Returns:
      Dictionary mapping metric names to metric instances. May return `None`
      if the model does not contain any metrics.
    """
    if getattr(model, "_compile_metrics", None):
        # TODO(psv/kathywu): use this implementation in model to estimator
        # flow. We are not using model.metrics here because we want to
        # exclude the metrics added using the `add_metric` API.
        return {m.name: m for m in model._compile_metric_functions}
    return None
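
# Example (a hypothetical sketch, not part of the original file): for a model
# compiled through the v1/Estimator flow that exposes `_compile_metrics`,
# `extract_model_metrics` returns a name -> metric mapping; for other models
# it simply returns None.
#
#   metrics = extract_model_metrics(model)
#   if metrics is not None:
#       for name, metric in metrics.items():
#           print(name, type(metric))

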
def model_call_inputs(model, keep_original_batch_size=False):
    """Inspect model to get its input signature.

    The model's input signature is a list with a single (possibly-nested)
    object. This is due to the Keras-enforced restriction that tensor inputs
    must be passed in as the first argument.

    For example, a model with input
    {'feature1': <Tensor>, 'feature2': <Tensor>} will have input signature:
    [{'feature1': TensorSpec, 'feature2': TensorSpec}]

    Args:
      model: Keras Model object.
      keep_original_batch_size: A boolean indicating whether we want to keep
        using the original batch size or set it to None. Default is `False`,
        which means that the batch dim of the returned input signature will
        always be set to `None`.

    Returns:
      A tuple containing `(args, kwargs)` TensorSpecs of the model call
      function inputs. `kwargs` does not contain the `training` argument.
    """
    input_specs = model.save_spec(dynamic_batch=not keep_original_batch_size)
    if input_specs is None:
        return None, None
    input_specs = _enforce_names_consistency(input_specs)
    return input_specs
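
# Example (a hypothetical sketch, not part of the original file): for a built
# functional model, the batch dimension of the returned specs is relaxed to
# None by default.
#
#   inp = tf.keras.Input(shape=(4,), name="feature1")
#   model = tf.keras.Model(inp, tf.keras.layers.Dense(1)(inp))
#   args, kwargs = model_call_inputs(model)
#   # args is roughly [[TensorSpec(shape=(None, 4), dtype=tf.float32,
#   # name="feature1")]] and kwargs is {}.

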
def raise_model_input_error(model):
    if isinstance(model, keras.models.Sequential):
        raise ValueError(
            f"Model {model} cannot be saved because the input shape is not "
            "available. Please specify an input shape either by calling "
            "`build(input_shape)` directly, or by calling the model on "
            "actual data using `Model()`, `Model.fit()`, or "
            "`Model.predict()`."
        )

    # If the model is not a `Sequential`, it is intended to be a subclassed
    # model.
    raise ValueError(
        f"Model {model} cannot be saved either because the input shape is "
        "not available or because the forward pass of the model is not "
        "defined. To define a forward pass, please override `Model.call()`. "
        "To specify an input shape, either call `build(input_shape)` "
        "directly, or call the model on actual data using `Model()`, "
        "`Model.fit()`, or `Model.predict()`. If you have a custom training "
        "step, please make sure to invoke the forward pass in train step "
        "through `Model.__call__`, i.e. `model(inputs)`, as opposed to "
        "`model.call()`."
    )
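
# Example (a hypothetical sketch, not part of the original file): saving a
# subclassed model that has never been built or called would hit the second
# error above, since no input shape can be inferred.
#
#   class MyModel(tf.keras.Model):
#       def call(self, x):
#           return x
#
#   m = MyModel()       # never built or called on data
#   # m.save("/tmp/m")  # would raise the ValueError above

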
def trace_model_call(model, input_signature=None):
    """Trace the model call to create a tf.function for exporting a Keras
    model.

    Args:
      model: A Keras model.
      input_signature: optional, a list of tf.TensorSpec objects specifying
        the inputs to the model.

    Returns:
      A tf.function wrapping the model's call function with input signatures
      set.

    Raises:
      ValueError: if input signature cannot be inferred from the model.
    """
    if input_signature is None:
        if isinstance(model.call, tf.__internal__.function.Function):
            input_signature = model.call.input_signature

    if input_signature:
        model_args = input_signature
        model_kwargs = {}
    else:
        model_args, model_kwargs = model_call_inputs(model)

        if model_args is None:
            raise_model_input_error(model)

    @tf.function
    def _wrapped_model(*args, **kwargs):
        """A concrete tf.function that wraps the model's call function."""
        # Force `training=False` on the traced call, regardless of what the
        # caller passed in.
        args, kwargs = model._call_spec.set_arg_value(
            "training", False, args, kwargs, inputs_in_args=True
        )

        with base_layer_utils.call_context().enter(
            model, inputs=None, build_graph=False, training=False, saving=True
        ):
            outputs = model(*args, **kwargs)

        # Outputs always has to be a flat dict.
        output_names = model.output_names  # Functional Model.
        if output_names is None:  # Subclassed Model.
            from keras.engine import compile_utils

            output_names = compile_utils.create_pseudo_output_names(outputs)
        outputs = tf.nest.flatten(outputs)
        return {name: output for name, output in zip(output_names, outputs)}

    return _wrapped_model.get_concrete_function(*model_args, **model_kwargs)
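
# Example (a hypothetical sketch, not part of the original file): the traced
# concrete function can be used as a serving signature when exporting the
# model with `tf.saved_model.save`.
#
#   serving_fn = trace_model_call(model)
#   tf.saved_model.save(
#       model, "/tmp/exported", signatures={"serving_default": serving_fn}
#   )

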
def model_metadata(model, include_optimizer=True, require_config=True):
    """Returns a dictionary containing the model metadata."""
    from keras import __version__ as keras_version
    from keras.optimizers.legacy import optimizer_v2

    model_config = {"class_name": model.__class__.__name__}
    try:
        model_config["config"] = model.get_config()
    except NotImplementedError as e:
        if require_config:
            raise e

    metadata = dict(
        keras_version=str(keras_version),
        backend=backend.backend(),
        model_config=model_config,
    )
    if model.optimizer and include_optimizer:
        if isinstance(model.optimizer, optimizer_v1.TFOptimizer):
            logging.warning(
                "TensorFlow optimizers do not "
                "make it possible to access "
                "optimizer attributes or optimizer state "
                "after instantiation. "
                "As a result, we cannot save the optimizer "
                "as part of the model save file. "
                "You will have to compile your model again after loading it. "
                "Prefer using a Keras optimizer instead "
                "(see keras.io/optimizers)."
            )
        elif model._compile_was_called:
            training_config = model._get_compile_args(user_metrics=False)
            training_config.pop("optimizer", None)  # Handled separately.
            metadata["training_config"] = _serialize_nested_config(
                training_config
            )
            if isinstance(model.optimizer, optimizer_v2.RestoredOptimizer):
                raise NotImplementedError(
                    "Optimizers loaded from a SavedModel cannot be saved. "
                    "If you are calling `model.save` or "
                    "`tf.keras.models.save_model`, "
                    "please set the `include_optimizer` option to `False`. "
                    "For `tf.saved_model.save`, "
                    "delete the optimizer from the model."
                )
            else:
                optimizer_config = {
                    "class_name": keras.utils.get_registered_name(
                        model.optimizer.__class__
                    ),
                    "config": model.optimizer.get_config(),
                }
                metadata["training_config"]["optimizer_config"] = (
                    optimizer_config
                )
    return metadata
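
# Example (a hypothetical sketch, not part of the original file): the returned
# metadata is a plain dict that the savers serialize alongside the weights.
#
#   meta = model_metadata(model)
#   meta["keras_version"]               # e.g. "2.12.0"
#   meta["model_config"]["class_name"]  # e.g. "Functional"
#   meta.get("training_config")         # compile args if model was compiled

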
def should_overwrite(filepath, overwrite):
    """Returns whether the filepath should be overwritten."""
    # If file exists and should not be overwritten.
    if not overwrite and os.path.isfile(filepath):
        return ask_to_proceed_with_overwrite(filepath)
    return True
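
# Example (a hypothetical sketch, not part of the original file):
#
#   if not should_overwrite("/tmp/model.h5", overwrite=False):
#       return  # the user declined the interactive overwrite prompt

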
def compile_args_from_training_config(training_config, custom_objects=None):
    """Return model.compile arguments from training config."""
    if custom_objects is None:
        custom_objects = {}

    with keras.utils.CustomObjectScope(custom_objects):
        optimizer_config = training_config["optimizer_config"]
        optimizer = optimizers.deserialize(optimizer_config)

        # Recover losses.
        loss = None
        loss_config = training_config.get("loss", None)
        if loss_config is not None:
            loss = _deserialize_nested_config(losses.deserialize, loss_config)

        # Recover metrics.
        metrics = None
        metrics_config = training_config.get("metrics", None)
        if metrics_config is not None:
            metrics = _deserialize_nested_config(
                _deserialize_metric, metrics_config
            )

        # Recover weighted metrics.
        weighted_metrics = None
        weighted_metrics_config = training_config.get(
            "weighted_metrics", None
        )
        if weighted_metrics_config is not None:
            weighted_metrics = _deserialize_nested_config(
                _deserialize_metric, weighted_metrics_config
            )

        # `training_config` is a plain dict, so check key membership;
        # `hasattr` would always be False here.
        sample_weight_mode = (
            training_config["sample_weight_mode"]
            if "sample_weight_mode" in training_config
            else None
        )
        loss_weights = training_config["loss_weights"]

    return dict(
        optimizer=optimizer,
        loss=loss,
        metrics=metrics,
        weighted_metrics=weighted_metrics,
        loss_weights=loss_weights,
        sample_weight_mode=sample_weight_mode,
    )
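
# Example (a hypothetical sketch, not part of the original file): restoring
# compile arguments saved by `model_metadata` above.
#
#   meta = model_metadata(model)
#   compile_args = compile_args_from_training_config(meta["training_config"])
#   restored_model.compile(**compile_args)

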
def _deserialize_nested_config(deserialize_fn, config):
    """Deserializes arbitrary Keras `config` using `deserialize_fn`."""

    def _is_single_object(obj):
        if isinstance(obj, dict) and "class_name" in obj:
            return True  # Serialized Keras object.
        if isinstance(obj, str):
            return True  # Serialized function or string.
        return False

    if config is None:
        return None
    if _is_single_object(config):
        return deserialize_fn(config)
    elif isinstance(config, dict):
        return {
            k: _deserialize_nested_config(deserialize_fn, v)
            for k, v in config.items()
        }
    elif isinstance(config, (tuple, list)):
        return [
            _deserialize_nested_config(deserialize_fn, obj) for obj in config
        ]

    raise ValueError(
        "Saved configuration not understood. Configuration should be a "
        f"dictionary, string, tuple or list. Received: config={config}."
    )
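
# Example (a hypothetical sketch, not part of the original file): the helper
# recurses through dicts, tuples, and lists, applying `deserialize_fn` only
# to leaf strings and serialized-object dicts.
#
#   config = {
#       "out": ["mae", {"class_name": "MeanSquaredError", "config": {}}]
#   }
#   _deserialize_nested_config(_deserialize_metric, config)
#   # -> {"out": [<deserialized "mae">, <MeanSquaredError instance>]}

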
def _serialize_nested_config(config):
    """Serializes a nested structure of Keras objects."""

    def _serialize_fn(obj):
        if callable(obj):
            return serialization.serialize_keras_object(obj)
        return obj

    return tf.nest.map_structure(_serialize_fn, config)
def _deserialize_metric(metric_config):
    """Deserialize metrics, leaving special strings untouched."""
    from keras import metrics as metrics_module

    if metric_config in ["accuracy", "acc", "crossentropy", "ce"]:
        # Do not deserialize accuracy and cross-entropy strings as we have
        # special case handling for these in compile, based on model output
        # shape.
        return metric_config
    return metrics_module.deserialize(metric_config)
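
# Example (a hypothetical sketch, not part of the original file):
#
#   _deserialize_metric("accuracy")  # -> "accuracy", left for compile()
#   _deserialize_metric("mae")       # -> a deserialized metric/function

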
def _enforce_names_consistency(specs):
    """Enforces that either all specs have names or none do."""

    def _has_name(spec):
        return spec is None or (
            hasattr(spec, "name") and spec.name is not None
        )

    def _clear_name(spec):
        spec = copy.deepcopy(spec)
        if hasattr(spec, "name"):
            spec._name = None
        return spec

    flat_specs = tf.nest.flatten(specs)
    name_inconsistency = any(_has_name(s) for s in flat_specs) and not all(
        _has_name(s) for s in flat_specs
    )

    if name_inconsistency:
        specs = tf.nest.map_structure(_clear_name, specs)
    return specs
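
# Example (a hypothetical sketch, not part of the original file): mixing
# named and unnamed specs clears every name so the signature is consistent.
#
#   specs = [
#       tf.TensorSpec((None, 4), tf.float32, name="x"),
#       tf.TensorSpec((None, 2), tf.float32),  # unnamed
#   ]
#   specs = _enforce_names_consistency(specs)
#   # Both specs now have name=None.

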
def try_build_compiled_arguments(model):
    if (
        not version_utils.is_v1_layer_or_model(model)
        and model.outputs is not None
    ):
        try:
            if not model.compiled_loss.built:
                model.compiled_loss.build(model.outputs)
            if not model.compiled_metrics.built:
                model.compiled_metrics.build(model.outputs, model.outputs)
        except:  # noqa: E722
            logging.warning(
                "Compiled the loaded model, but the compiled metrics have "
                "yet to be built. `model.compile_metrics` will be empty "
                "until you train or evaluate the model."
            )
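
# Example (a hypothetical sketch, not part of the original file): after
# loading and re-compiling a saved model, this can be called to eagerly
# build the compiled loss/metric containers.
#
#   restored = keras.models.load_model("/tmp/model")
#   try_build_compiled_arguments(restored)

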
def is_hdf5_filepath(filepath):
    # Note: in this legacy path, a `.keras` suffix is also treated as HDF5.
    return (
        filepath.endswith(".h5")
        or filepath.endswith(".keras")
        or filepath.endswith(".hdf5")
    )
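
# Example (a hypothetical sketch, not part of the original file):
#
#   is_hdf5_filepath("weights.h5")    # -> True
#   is_hdf5_filepath("model.keras")   # -> True (legacy behavior)
#   is_hdf5_filepath("saved_model/")  # -> False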