# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils related to keras model saving."""

import copy
import os

import tensorflow.compat.v2 as tf

import keras
from keras import backend
from keras import losses
from keras import optimizers
from keras.engine import base_layer_utils
from keras.optimizers import optimizer_v1
from keras.saving.legacy import serialization
from keras.utils import version_utils
from keras.utils.io_utils import ask_to_proceed_with_overwrite

# isort: off
from tensorflow.python.platform import tf_logging as logging


def extract_model_metrics(model):
    """Convert metrics from a Keras model `compile` API to dictionary.

    This is used for converting Keras models to Estimators and SavedModels.

    Args:
      model: A `tf.keras.Model` object.

    Returns:
      Dictionary mapping metric names to metric instances. May return `None`
      if the model does not contain any metrics.
    """
    if getattr(model, "_compile_metrics", None):
        # TODO(psv/kathywu): use this implementation in model to estimator
        # flow.
        # We are not using model.metrics here because we want to exclude the
        # metrics added using `add_metric` API.
        return {m.name: m for m in model._compile_metric_functions}
    return None


def model_call_inputs(model, keep_original_batch_size=False):
    """Inspect model to get its input signature.

    The model's input signature is a list with a single (possibly-nested)
    object. This is due to the Keras-enforced restriction that tensor inputs
    must be passed in as the first argument.

    For example, a model with input {'feature1': <Tensor>, 'feature2':
    <Tensor>} will have input signature:
    [{'feature1': TensorSpec, 'feature2': TensorSpec}]

    Args:
      model: Keras Model object.
      keep_original_batch_size: A boolean indicating whether we want to keep
        using the original batch size or set it to None. Default is `False`,
        which means that the batch dim of the returned input signature will
        always be set to `None`.

    Returns:
      A tuple containing `(args, kwargs)` TensorSpecs of the model call
      function inputs. `kwargs` does not contain the `training` argument.
    """
    input_specs = model.save_spec(dynamic_batch=not keep_original_batch_size)
    if input_specs is None:
        return None, None
    input_specs = _enforce_names_consistency(input_specs)
    return input_specs


def raise_model_input_error(model):
    """Raise a descriptive error when the model input shape is unavailable."""
    if isinstance(model, keras.models.Sequential):
        raise ValueError(
            f"Model {model} cannot be saved because the input shape is not "
            "available. Please specify an input shape either by calling "
            "`build(input_shape)` directly, or by calling the model on actual "
            "data using `Model()`, `Model.fit()`, or `Model.predict()`."
        )

    # If the model is not a `Sequential`, it is intended to be a subclassed
    # model.
    raise ValueError(
        f"Model {model} cannot be saved either because the input shape is not "
        "available or because the forward pass of the model is not defined. "
        "To define a forward pass, please override `Model.call()`. To specify "
        "an input shape, either call `build(input_shape)` directly, or call "
        "the model on actual data using `Model()`, `Model.fit()`, or "
        "`Model.predict()`. If you have a custom training step, please make "
        "sure to invoke the forward pass in train step through "
        "`Model.__call__`, i.e. `model(inputs)`, as opposed to `model.call()`."
    )


def trace_model_call(model, input_signature=None):
    """Trace the model call to create a tf.function for exporting a Keras
    model.

    Args:
      model: A Keras model.
      input_signature: optional, a list of tf.TensorSpec objects specifying
        the inputs to the model.

    Returns:
      A tf.function wrapping the model's call function with input signatures
      set.

    Raises:
      ValueError: if input signature cannot be inferred from the model.
    """
    if input_signature is None:
        if isinstance(model.call, tf.__internal__.function.Function):
            input_signature = model.call.input_signature

    if input_signature:
        model_args = input_signature
        model_kwargs = {}
    else:
        model_args, model_kwargs = model_call_inputs(model)

        if model_args is None:
            raise_model_input_error(model)

    @tf.function
    def _wrapped_model(*args, **kwargs):
        """A concrete tf.function that wraps the model's call function."""
        # Force `training=False` regardless of what the traced call passes.
        (args, kwargs,) = model._call_spec.set_arg_value(
            "training", False, args, kwargs, inputs_in_args=True
        )

        with base_layer_utils.call_context().enter(
            model, inputs=None, build_graph=False, training=False, saving=True
        ):
            outputs = model(*args, **kwargs)

        # Outputs always has to be a flat dict.
        output_names = model.output_names  # Functional Model.
        if output_names is None:  # Subclassed Model.
            from keras.engine import compile_utils

            output_names = compile_utils.create_pseudo_output_names(outputs)
        outputs = tf.nest.flatten(outputs)
        return {name: output for name, output in zip(output_names, outputs)}

    return _wrapped_model.get_concrete_function(*model_args, **model_kwargs)


def model_metadata(model, include_optimizer=True, require_config=True):
    """Returns a dictionary containing the model metadata."""
    from keras import __version__ as keras_version
    from keras.optimizers.legacy import optimizer_v2

    model_config = {"class_name": model.__class__.__name__}
    try:
        model_config["config"] = model.get_config()
    except NotImplementedError as e:
        if require_config:
            raise e

    metadata = dict(
        keras_version=str(keras_version),
        backend=backend.backend(),
        model_config=model_config,
    )
    if model.optimizer and include_optimizer:
        if isinstance(model.optimizer, optimizer_v1.TFOptimizer):
            logging.warning(
                "TensorFlow optimizers do not "
                "make it possible to access "
                "optimizer attributes or optimizer state "
                "after instantiation. "
                "As a result, we cannot save the optimizer "
                "as part of the model save file. "
                "You will have to compile your model again after loading it. "
                "Prefer using a Keras optimizer instead "
                "(see keras.io/optimizers)."
            )
        elif model._compile_was_called:
            training_config = model._get_compile_args(user_metrics=False)
            training_config.pop("optimizer", None)  # Handled separately.
            metadata["training_config"] = _serialize_nested_config(
                training_config
            )
            if isinstance(model.optimizer, optimizer_v2.RestoredOptimizer):
                raise NotImplementedError(
                    "Optimizers loaded from a SavedModel cannot be saved. "
                    "If you are calling `model.save` or "
                    "`tf.keras.models.save_model`, "
                    "please set the `include_optimizer` option to `False`. For "
                    "`tf.saved_model.save`, "
                    "delete the optimizer from the model."
                )
            else:
                optimizer_config = {
                    "class_name": keras.utils.get_registered_name(
                        model.optimizer.__class__
                    ),
                    "config": model.optimizer.get_config(),
                }
            metadata["training_config"]["optimizer_config"] = optimizer_config
    return metadata


def should_overwrite(filepath, overwrite):
    """Returns whether the filepath should be overwritten."""
    # If file exists and should not be overwritten.
    if not overwrite and os.path.isfile(filepath):
        return ask_to_proceed_with_overwrite(filepath)
    return True


def compile_args_from_training_config(training_config, custom_objects=None):
    """Return model.compile arguments from training config."""
    if custom_objects is None:
        custom_objects = {}

    with keras.utils.CustomObjectScope(custom_objects):
        optimizer_config = training_config["optimizer_config"]
        optimizer = optimizers.deserialize(optimizer_config)

        # Recover losses.
        loss = None
        loss_config = training_config.get("loss", None)
        if loss_config is not None:
            loss = _deserialize_nested_config(losses.deserialize, loss_config)

        # Recover metrics.
        metrics = None
        metrics_config = training_config.get("metrics", None)
        if metrics_config is not None:
            metrics = _deserialize_nested_config(
                _deserialize_metric, metrics_config
            )

        # Recover weighted metrics.
        weighted_metrics = None
        weighted_metrics_config = training_config.get("weighted_metrics", None)
        if weighted_metrics_config is not None:
            weighted_metrics = _deserialize_nested_config(
                _deserialize_metric, weighted_metrics_config
            )

        # NOTE(fix): `training_config` is a dict, so the previous
        # `hasattr(training_config, "sample_weight_mode")` check was always
        # False and silently dropped the saved value. Use key membership.
        sample_weight_mode = (
            training_config["sample_weight_mode"]
            if "sample_weight_mode" in training_config
            else None
        )
        # Use `.get` so configs saved without loss weights do not KeyError.
        loss_weights = training_config.get("loss_weights", None)

    return dict(
        optimizer=optimizer,
        loss=loss,
        metrics=metrics,
        weighted_metrics=weighted_metrics,
        loss_weights=loss_weights,
        sample_weight_mode=sample_weight_mode,
    )


def _deserialize_nested_config(deserialize_fn, config):
    """Deserializes arbitrary Keras `config` using `deserialize_fn`."""

    def _is_single_object(obj):
        if isinstance(obj, dict) and "class_name" in obj:
            return True  # Serialized Keras object.
        if isinstance(obj, str):
            return True  # Serialized function or string.
        return False

    if config is None:
        return None
    if _is_single_object(config):
        return deserialize_fn(config)
    elif isinstance(config, dict):
        return {
            k: _deserialize_nested_config(deserialize_fn, v)
            for k, v in config.items()
        }
    elif isinstance(config, (tuple, list)):
        return [
            _deserialize_nested_config(deserialize_fn, obj) for obj in config
        ]

    raise ValueError(
        "Saved configuration not understood. Configuration should be a "
        f"dictionary, string, tuple or list. Received: config={config}."
    )


def _serialize_nested_config(config):
    """Serialized a nested structure of Keras objects."""

    def _serialize_fn(obj):
        if callable(obj):
            return serialization.serialize_keras_object(obj)
        return obj

    return tf.nest.map_structure(_serialize_fn, config)


def _deserialize_metric(metric_config):
    """Deserialize metrics, leaving special strings untouched."""
    from keras import metrics as metrics_module

    if metric_config in ["accuracy", "acc", "crossentropy", "ce"]:
        # Do not deserialize accuracy and cross-entropy strings as we have
        # special case handling for these in compile, based on model output
        # shape.
        return metric_config
    return metrics_module.deserialize(metric_config)


def _enforce_names_consistency(specs):
    """Enforces that either all specs have names or none do."""

    def _has_name(spec):
        return spec is None or (hasattr(spec, "name") and spec.name is not None)

    def _clear_name(spec):
        spec = copy.deepcopy(spec)
        if hasattr(spec, "name"):
            spec._name = None
        return spec

    flat_specs = tf.nest.flatten(specs)
    name_inconsistency = any(_has_name(s) for s in flat_specs) and not all(
        _has_name(s) for s in flat_specs
    )

    if name_inconsistency:
        specs = tf.nest.map_structure(_clear_name, specs)
    return specs


def try_build_compiled_arguments(model):
    """Best-effort build of a loaded model's compiled loss/metric containers."""
    if (
        not version_utils.is_v1_layer_or_model(model)
        and model.outputs is not None
    ):
        try:
            if not model.compiled_loss.built:
                model.compiled_loss.build(model.outputs)
            if not model.compiled_metrics.built:
                model.compiled_metrics.build(model.outputs, model.outputs)
        except:  # noqa: E722
            # Deliberate best-effort: a failure here must not abort loading.
            logging.warning(
                "Compiled the loaded model, but the compiled metrics have "
                "yet to be built. `model.compile_metrics` will be empty "
                "until you train or evaluate the model."
            )


def is_hdf5_filepath(filepath):
    """Returns True if `filepath` has an HDF5/Keras file extension."""
    # `str.endswith` accepts a tuple of suffixes — one call instead of three.
    return filepath.endswith((".h5", ".keras", ".hdf5"))