891 lines
35 KiB
Python
891 lines
35 KiB
Python
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
|
||
|
"""Code for model cloning, plus model-related API entries."""
|
||
|
|
||
|
import tensorflow.compat.v2 as tf
|
||
|
|
||
|
from keras import backend
|
||
|
from keras import metrics as metrics_module
|
||
|
from keras.engine import functional
|
||
|
from keras.engine import sequential
|
||
|
from keras.engine import training
|
||
|
from keras.engine import training_v1
|
||
|
from keras.engine.base_layer import AddMetric
|
||
|
from keras.engine.base_layer import Layer
|
||
|
from keras.engine.input_layer import Input
|
||
|
from keras.engine.input_layer import InputLayer
|
||
|
from keras.optimizers import optimizer_v1
|
||
|
from keras.saving.legacy import serialization
|
||
|
from keras.saving.object_registration import CustomObjectScope
|
||
|
from keras.utils import generic_utils
|
||
|
from keras.utils import version_utils
|
||
|
|
||
|
# isort: off
|
||
|
from tensorflow.python.platform import tf_logging as logging
|
||
|
from tensorflow.python.util.tf_export import keras_export
|
||
|
|
||
|
# API entries importable from `keras.models`:
|
||
|
Model = training.Model
|
||
|
Sequential = sequential.Sequential
|
||
|
|
||
|
|
||
|
# Callable used to clone a layer with weights preserved.
def share_weights(layer):
    """Identity layer-cloning callable: reuse `layer` as-is, sharing weights."""
    return layer
|
||
|
|
||
|
|
||
|
def _clone_layer(layer):
|
||
|
return layer.__class__.from_config(layer.get_config())
|
||
|
|
||
|
|
||
|
def _insert_ancillary_layers(model, ancillary_layers, metrics_names, new_nodes):
    """Inserts ancillary layers into the model with the proper order."""
    metric_layers = []
    other_layers = []
    for candidate in ancillary_layers:
        if isinstance(candidate, AddMetric):
            metric_layers.append(candidate)
        else:
            other_layers.append(candidate)
    # `AddMetric` layers must appear in the same order as `metrics_names`.
    metric_layers.sort(key=lambda lyr: metrics_names.index(lyr.metric_name))
    model._insert_layers(
        other_layers + metric_layers, relevant_nodes=list(new_nodes)
    )
|
||
|
|
||
|
|
||
|
def _make_new_nodes(nodes_by_depth, layer_fn, layer_map, tensor_map):
    """Make new nodes with the layers in `layer_map` based on `nodes_by_depth`.

    Args:
      nodes_by_depth: Provides structure information to create new nodes.
      layer_fn: Function to clone layers.
      layer_map: Map from layers in `model` to new layers.
      tensor_map: Map from tensors in `model` to newly compute tensors.

    Returns:
      A set of new nodes. `layer_map` and `tensor_map` are updated.
    """
    # Iterated over every node in the reference model, in depth order.
    new_nodes = set()
    depth_keys = list(nodes_by_depth.keys())
    # Higher depth means closer to the inputs, so reverse sorting walks the
    # graph topologically from inputs to outputs.
    depth_keys.sort(reverse=True)
    for depth in depth_keys:
        nodes = nodes_by_depth[depth]
        for node in nodes:
            # Recover the corresponding layer.
            layer = node.outbound_layer

            # Get or create layer.
            if layer not in layer_map:
                new_layer = layer_fn(layer)
                layer_map[layer] = new_layer
                layer = new_layer
            else:
                # Reuse previously cloned layer.
                layer = layer_map[layer]
                # Don't call InputLayer multiple times.
                if isinstance(layer, InputLayer):
                    continue

            # If all previous input tensors are available in tensor_map,
            # then call node.inbound_layer on them.
            if all(
                tensor in tensor_map
                for tensor in tf.nest.flatten(node.input_tensors)
            ):
                # Call layer. Non-tensor call arguments are passed through
                # unchanged (`tensor_map.get(t, t)` falls back to `t`).
                args = tf.nest.map_structure(
                    lambda t: tensor_map.get(t, t), node.call_args
                )
                kwargs = tf.nest.map_structure(
                    lambda t: tensor_map.get(t, t), node.call_kwargs
                )
                output_tensors = layer(*args, **kwargs)

                # Thread-safe way to keep track of what node was created.
                first_output_tensor = tf.nest.flatten(output_tensors)[0]
                new_nodes.add(
                    layer._inbound_nodes[
                        first_output_tensor._keras_history.node_index
                    ]
                )

                # Record the mapping from original output tensors to the
                # freshly computed ones so downstream nodes can be called.
                for x, y in zip(
                    tf.nest.flatten(node.output_tensors),
                    tf.nest.flatten(output_tensors),
                ):
                    tensor_map[x] = y
    return new_nodes
|
||
|
|
||
|
|
||
|
def _clone_functional_model(model, input_tensors=None, layer_fn=_clone_layer):
    """Clone a functional `Model` instance.

    Model cloning is similar to calling a model on new inputs,
    except that it creates new layers (and thus new weights) instead
    of sharing the weights of the existing layers.

    Input layers are always cloned.

    Args:
        model: Instance of `Model`.
        input_tensors: optional list of input tensors
            to build the model upon. If not provided,
            placeholders will be created.
        layer_fn: callable to be applied on non-input layers in the model. By
            default it clones the layer. Another example is to preserve the
            layer to share the weights. This is required when we create a
            per-replica copy of the model with distribution strategy; we want
            the weights to be shared but still feed inputs separately so we
            create new input layers.

    Returns:
        An instance of `Model` reproducing the behavior
        of the original model, on top of new inputs tensors,
        using newly instantiated weights.

    Raises:
        ValueError: in case of invalid `model` argument value or `layer_fn`
            argument value.
    """
    if layer_fn is None:
        layer_fn = _clone_layer

    if not isinstance(model, Model):
        raise ValueError(
            "Expected `model` argument "
            f"to be a `Model` instance. Received: model={model}"
        )
    if isinstance(model, Sequential):
        raise ValueError(
            "Expected `model` argument "
            "to be a functional `Model` instance, "
            f"got a `Sequential` instance instead: {model}"
        )
    if not model._is_graph_network:
        raise ValueError(
            "Expected `model` argument "
            "to be a functional `Model` instance, "
            f"but got a subclassed model instead: {model}"
        )

    new_input_layers = {}  # Cache for created layers.
    if input_tensors is not None:
        # Make sure that all input tensors come from a Keras layer.
        input_tensors = tf.nest.flatten(input_tensors)
        for i, input_tensor in enumerate(input_tensors):
            # NOTE(review): assumes `input_tensors` is ordered to match
            # `model._input_layers` — positional pairing only.
            original_input_layer = model._input_layers[i]

            # Cache input layer. Create a new layer if the tensor is
            # originally not from a Keras layer.
            if not backend.is_keras_tensor(input_tensor):
                name = original_input_layer.name
                input_tensor = Input(
                    tensor=input_tensor, name="input_wrapper_for_" + name
                )
                newly_created_input_layer = input_tensor._keras_history.layer
                new_input_layers[
                    original_input_layer
                ] = newly_created_input_layer
            else:
                new_input_layers[
                    original_input_layer
                ] = input_tensor._keras_history.layer

    if not callable(layer_fn):
        raise ValueError(
            "Expected `layer_fn` argument to be a callable. "
            f"Received: layer_fn={layer_fn}"
        )

    model_configs, created_layers = _clone_layers_and_model_config(
        model, new_input_layers, layer_fn
    )
    # Reconstruct model from the config, using the cloned layers.
    (
        input_tensors,
        output_tensors,
        created_layers,
    ) = functional.reconstruct_from_config(
        model_configs, created_layers=created_layers
    )
    metrics_names = model.metrics_names
    if functional.has_functional_like_constructor(model.__class__):
        new_model = model.__class__(
            input_tensors, output_tensors, name=model.name
        )
    else:
        # This may be incorrect: the new model will end up having a different
        # class than the original. However various existing models rely
        # on this behavior, so we keep it.
        new_model = Model(input_tensors, output_tensors, name=model.name)

    # Layers not directly tied to outputs of the Model, such as loss layers
    # created in `add_loss` and `add_metric`.
    ancillary_layers = [
        layer
        for layer in created_layers.values()
        if layer not in new_model.layers
    ]
    # TODO(b/162887610): This may need to adjust the inbound node index if the
    # created layers had already been used to define other models.
    if ancillary_layers:
        new_nodes = tf.nest.flatten(
            [
                layer.inbound_nodes[1:]
                if functional._should_skip_first_node(layer)
                else layer.inbound_nodes
                for layer in created_layers.values()
            ]
        )
        _insert_ancillary_layers(
            new_model, ancillary_layers, metrics_names, new_nodes
        )
    return new_model
|
||
|
|
||
|
|
||
|
def _clone_layers_and_model_config(model, input_layers, layer_fn):
    """Clones all layers; returns the model config without serializing layers.

    Only the node graph is retrieved when getting the model config: the
    `layer_fn` used to clone layers might not rely on `layer.get_config()`,
    and some custom layers do not define `get_config`, so attempting to
    serialize them would raise.

    Args:
      model: A Functional model.
      input_layers: Dictionary mapping input layers in `model` to new input
        layers.
      layer_fn: Function used to clone all non-input layers.

    Returns:
      Model config object, and a dictionary of newly created layers.
    """
    new_layers = {}

    def _serialize_and_clone(layer):
        # Intercept the per-layer serialization: record a clone of the layer
        # and hand the network config a dummy (empty) serialization instead.
        if layer in input_layers:
            clone = input_layers[layer]
        elif layer in model._input_layers:
            clone = InputLayer(**layer.get_config())
        else:
            clone = layer_fn(layer)
        new_layers[layer.name] = clone
        return {}

    config = functional.get_network_config(
        model, serialize_layer_fn=_serialize_and_clone
    )
    return config, new_layers
|
||
|
|
||
|
|
||
|
def _remove_ancillary_layers(model, layer_map, layers):
|
||
|
"""Removes and returns any ancillary layers from `layers` based on `model`.
|
||
|
|
||
|
Ancillary layers are part of the model topology but not used to compute the
|
||
|
model outputs, e.g., layers from `add_loss` and `add_metric`.
|
||
|
|
||
|
Args:
|
||
|
model: A Keras Model.
|
||
|
layer_map: A map to from layers in the `model` to those in `layers`.
|
||
|
layers: A list of all layers.
|
||
|
|
||
|
Returns:
|
||
|
Two lists of layers: (1) `layers` with the ancillary layers removed, and
|
||
|
(2) the ancillary layers.
|
||
|
"""
|
||
|
ancillary_layers = [] # Additional layers for computing losses and metrics.
|
||
|
if not model._is_graph_network:
|
||
|
return layers, ancillary_layers
|
||
|
|
||
|
# Ancillary layers are those with depth < 0.
|
||
|
depths = [depth for depth in model._nodes_by_depth.keys() if depth < 0]
|
||
|
depths.sort(reverse=True) # Order topologically from inputs to outputs.
|
||
|
for depth in depths:
|
||
|
for node in model._nodes_by_depth[depth]:
|
||
|
ancillary_layers.append(layer_map[node.outbound_layer])
|
||
|
|
||
|
return [l for l in layers if l not in ancillary_layers], ancillary_layers
|
||
|
|
||
|
|
||
|
def _clone_sequential_model(model, input_tensors=None, layer_fn=_clone_layer):
    """Clone a `Sequential` model instance.

    Model cloning is similar to calling a model on new inputs,
    except that it creates new layers (and thus new weights) instead
    of sharing the weights of the existing layers.

    Args:
        model: Instance of `Sequential`.
        input_tensors: optional list of input tensors
            to build the model upon. If not provided,
            placeholders will be created.
        layer_fn: callable to be applied on non-input layers in the model. By
            default it clones the layer. Another example is to preserve the
            layer to share the weights. This is required when we create a
            per-replica copy of the model with distribution strategy; we want
            the weights to be shared but still feed inputs separately so we
            create new input layers.

    Returns:
        An instance of `Sequential` reproducing the behavior
        of the original model, on top of new inputs tensors,
        using newly instantiated weights.

    Raises:
        ValueError: in case of invalid `model` argument value or `layer_fn`
            argument value.
    """
    if layer_fn is None:
        layer_fn = _clone_layer

    if not isinstance(model, Sequential):
        raise ValueError(
            "Expected `model` argument "
            "to be a `Sequential` model instance. "
            f"Received: model={model}"
        )

    if not callable(layer_fn):
        raise ValueError(
            "Expected `layer_fn` argument to be a callable. "
            f"Received: layer_fn={layer_fn}"
        )

    layers = []  # Layers needed to compute the model's outputs.
    layer_map = {}
    # Ensure that all layers are cloned. The model's layers
    # property will exclude the initial InputLayer (if it exists) in the model,
    # resulting in a different Sequential model structure.
    for layer in model._flatten_layers(include_self=False, recursive=False):
        if isinstance(layer, InputLayer) and input_tensors is not None:
            # If input tensors are provided, the original model's InputLayer is
            # overwritten with a different InputLayer.
            continue
        # InputLayers are always cloned (never passed through `layer_fn`).
        cloned_layer = (
            _clone_layer(layer)
            if isinstance(layer, InputLayer)
            else layer_fn(layer)
        )
        layers.append(cloned_layer)
        layer_map[layer] = cloned_layer
    layers, ancillary_layers = _remove_ancillary_layers(
        model, layer_map, layers
    )

    if input_tensors is None:
        cloned_model = Sequential(layers=layers, name=model.name)
    elif len(generic_utils.to_list(input_tensors)) != 1:
        raise ValueError(
            "To clone a `Sequential` model, we expect at most one tensor as "
            f"part of `input_tensors`. Received: input_tensors={input_tensors}"
        )
    else:
        # Overwrite the original model's input layer.
        if isinstance(input_tensors, tuple):
            input_tensors = list(input_tensors)
        x = generic_utils.to_list(input_tensors)[0]
        if backend.is_keras_tensor(x):
            origin_layer = x._keras_history.layer
            if isinstance(origin_layer, InputLayer):
                cloned_model = Sequential(
                    layers=[origin_layer] + layers, name=model.name
                )
            else:
                raise ValueError(
                    "Cannot clone a `Sequential` model on top "
                    "of a tensor that comes from a Keras layer "
                    "other than an `InputLayer`. "
                    "Use the Functional API instead. "
                    f"Received: input_tensors={input_tensors}"
                )
        else:
            # Raw (non-Keras) tensor: wrap it in a fresh InputLayer.
            input_tensor = Input(
                tensor=x, name="input_wrapper_for_" + str(x.name)
            )
            input_layer = input_tensor._keras_history.layer
            cloned_model = Sequential(
                layers=[input_layer] + layers, name=model.name
            )

    if not ancillary_layers:
        return cloned_model

    tensor_map = {}  # Maps tensors from `model` to those in `cloned_model`.
    for depth, cloned_nodes in cloned_model._nodes_by_depth.items():
        nodes = model._nodes_by_depth[depth]
        # This should be safe in a Sequential model. In an arbitrary network,
        # you need to sort using the outbound layer of the node as a key.
        for cloned_node, node in zip(cloned_nodes, nodes):
            if isinstance(cloned_node.output_tensors, list):
                for j, output_tensor in enumerate(cloned_node.output_tensors):
                    tensor_map[node.output_tensors[j]] = output_tensor
            else:
                tensor_map[node.output_tensors] = cloned_node.output_tensors
    # Ancillary nodes have negative depth.
    new_nodes = _make_new_nodes(
        {
            depth: nodes
            for depth, nodes in model._nodes_by_depth.items()
            if depth < 0
        },
        layer_fn,
        layer_map,
        tensor_map,
    )
    _insert_ancillary_layers(
        cloned_model, ancillary_layers, model.metrics_names, new_nodes
    )
    return cloned_model
|
||
|
|
||
|
|
||
|
@keras_export("keras.models.clone_model")
def clone_model(model, input_tensors=None, clone_function=None):
    """Clone a Functional or Sequential `Model` instance.

    Model cloning is similar to calling a model on new inputs,
    except that it creates new layers (and thus new weights) instead
    of sharing the weights of the existing layers.

    Note that
    `clone_model` will not preserve the uniqueness of shared objects within
    the model (e.g. a single variable attached to two distinct layers will be
    restored as two separate variables).

    Args:
        model: Instance of `Model`
            (could be a Functional model or a Sequential model).
        input_tensors: optional list of input tensors or InputLayer objects
            to build the model upon. If not provided,
            new `Input` objects will be created.
        clone_function: Callable to be used to clone each layer in the target
            model (except `InputLayer` instances). It takes as argument the
            layer instance to be cloned, and returns the corresponding layer
            instance to be used in the model copy. If unspecified, this
            callable defaults to the following serialization/deserialization
            function:
            `lambda layer: layer.__class__.from_config(layer.get_config())`.
            By passing a custom callable, you can customize your copy of the
            model, e.g. by wrapping certain layers of interest (you might want
            to replace all `LSTM` instances with equivalent
            `Bidirectional(LSTM(...))` instances, for example).

    Returns:
        An instance of `Model` reproducing the behavior
        of the original model, on top of new inputs tensors,
        using newly instantiated weights. The cloned model may behave
        differently from the original model if a custom `clone_function`
        modifies the layer.

    Example:

    ```python
    # Create a test Sequential model.
    model = keras.Sequential([
        keras.Input(shape=(728,)),
        keras.layers.Dense(32, activation='relu'),
        keras.layers.Dense(1, activation='sigmoid'),
    ])
    # Create a copy of the test model (with freshly initialized weights).
    new_model = clone_model(model)
    ```

    Note that subclassed models cannot be cloned, since their internal
    layer structure is not known. To achieve equivalent functionality
    as `clone_model` in the case of a subclassed model, simply make sure
    that the model class implements `get_config()`
    (and optionally `from_config()`), and call:

    ```python
    new_model = model.__class__.from_config(model.get_config())
    ```
    """
    with serialization.DisableSharedObjectScope():
        if isinstance(model, Sequential):
            return _clone_sequential_model(
                model, input_tensors=input_tensors, layer_fn=clone_function
            )
        if isinstance(model, functional.Functional):
            # If the get_config() method is the same as a regular Functional
            # model, we're safe to use _clone_functional_model (which relies
            # on a Functional constructor). In the case where the get_config
            # is custom, this may not necessarily work, but if clone_function
            # or input_tensors are passed, we attempt it anyway
            # in order to preserve backwards compatibility.
            if generic_utils.is_default(model.get_config) or (
                clone_function or input_tensors
            ):
                return _clone_functional_model(
                    model, input_tensors=input_tensors, layer_fn=clone_function
                )

        # Case of a custom model class
        if clone_function or input_tensors:
            raise ValueError(
                "Arguments clone_function and input_tensors "
                "are only supported for Sequential models "
                "or Functional models. Received model of "
                f"type '{model.__class__.__name__}', with "
                f"clone_function={clone_function} and "
                f"input_tensors={input_tensors}"
            )
        # Note that a custom object scope may be required in this case.
        return model.__class__.from_config(model.get_config())
|
||
|
|
||
|
|
||
|
# "Clone" a subclassed model by resetting all of the attributes.
|
||
|
def _in_place_subclassed_model_reset(model):
    """Substitute for model cloning that works for subclassed models.

    Subclassed models cannot be cloned because their topology is not
    serializable. To "instantiate" an identical model in a new TF graph, we
    reuse the original model object, but we clear its state.

    After calling this function on a model instance, you can use the model
    instance as if it were a model clone (in particular you can use it in a
    new graph).

    This method clears the state of the input model. It is thus destructive.
    However the original state can be restored fully by calling
    `_in_place_subclassed_model_state_restoration`.

    Args:
      model: Instance of a Keras model created via subclassing.

    Raises:
      ValueError: In case the model uses a subclassed model as inner layer.
    """
    assert (
        not model._is_graph_network
    )  # Only makes sense for subclassed networks
    # Select correct base class for new Model.
    version_utils.swap_class(
        model.__class__,
        training.Model,
        training_v1.Model,
        tf.compat.v1.executing_eagerly_outside_functions(),
    )
    # Retrieve all layers tracked by the model as well as their attribute
    # names.
    attributes_cache = {}
    for name in dir(model):
        # Skip attrs that track other trackables.
        if name == "submodules" or name == "_self_tracked_trackables":
            continue

        # Some properties raise when accessed on a reset/uncompiled model;
        # those attributes are simply skipped.
        try:
            value = getattr(model, name)
        except (AttributeError, ValueError, TypeError):
            continue
        if isinstance(value, Layer):
            attributes_cache[name] = value
            assert value in model.layers
            if hasattr(value, "layers") and value.layers:
                raise ValueError(
                    "We do not support the use of nested layers "
                    "in `model_to_estimator` at this time. Found nested "
                    f"layer: {value}"
                )
        elif isinstance(value, (list, tuple)) and name not in (
            "layers",
            "_layers",
            "metrics",
            "_compile_metric_functions",
            "_output_loss_metrics",
        ):
            # Handle case: list/tuple of layers (also tracked by the Network
            # API).
            if value and all(isinstance(val, Layer) for val in value):
                raise ValueError(
                    "We do not support the use of list-of-layers "
                    "attributes in subclassed models used with "
                    "`model_to_estimator` at this time. Found list "
                    f"model: {name}"
                )

    # Replace layers on the model with fresh layers
    layers_to_names = {value: key for key, value in attributes_cache.items()}
    original_layers = list(
        model._flatten_layers(include_self=False, recursive=False)
    )
    # Temporarily disable attribute tracking so the replacement below does
    # not register spurious dependencies.
    setattr_tracking = model._setattr_tracking
    model._setattr_tracking = False
    model._self_tracked_trackables = []
    for layer in original_layers:  # We preserve layer order.
        config = layer.get_config()
        # This will not work for nested subclassed models used as layers.
        # This would be theoretically possible to support, but would add
        # complexity. Only do it if users complain.
        if isinstance(layer, training.Model) and not layer._is_graph_network:
            raise ValueError(
                "We do not support the use of nested subclassed models "
                "in `model_to_estimator` at this time. Found nested "
                f"model: {layer}"
            )
        fresh_layer = layer.__class__.from_config(config)
        name = layers_to_names[layer]
        setattr(model, name, fresh_layer)
        model._self_tracked_trackables.append(fresh_layer)

    # Cache original model build attributes (in addition to layers)
    if (
        not hasattr(model, "_original_attributes_cache")
        or model._original_attributes_cache is None
    ):
        if model.built:
            attributes_to_cache = [
                "inputs",
                "outputs",
                "total_loss",
                "optimizer",
                "train_function",
                "test_function",
                "predict_function",
                "_training_endpoints",
                "_collected_trainable_weights",
                "_feed_inputs",
                "_feed_input_names",
                "_feed_input_shapes",
            ]
            for name in attributes_to_cache:
                attributes_cache[name] = getattr(model, name)
    model._original_attributes_cache = attributes_cache
    _reset_build_compile_trackers(model)
    model._setattr_tracking = setattr_tracking
|
||
|
|
||
|
|
||
|
def _reset_build_compile_trackers(model):
    """Reset state trackers for model.

    Note that we do not actually zero out attributes such as optimizer,
    but instead rely on the expectation that all of the attrs will be
    over-written on calling build/compile/etc. This is somewhat fragile,
    insofar as we check elsewhere for the presence of these attributes as
    evidence of having been built/compiled/etc. Pending a better way to do
    this, we reset key attributes here to allow building and compiling.

    Args:
      model: the model that is being reset
    """
    # Reset build state.
    model.built = False
    for attr in ("inputs", "outputs"):
        setattr(model, attr, None)
    # Reset compile state.
    model._is_compiled = False
    running_v2 = tf.compat.v1.executing_eagerly_outside_functions()
    if not running_v2:
        model._v1_compile_was_called = False
    model.optimizer = None
|
||
|
|
||
|
|
||
|
@keras_export(
    "keras.__internal__.models.in_place_subclassed_model_state_restoration",
    v1=[],
)
def in_place_subclassed_model_state_restoration(model):
    """Restores the original state of a model after it was "reset".

    This undoes this action of `_in_place_subclassed_model_reset`, which is
    called in `clone_and_build_model` if `in_place_reset` is set to True.

    Args:
      model: Instance of a Keras model created via subclassing, on which
        `_in_place_subclassed_model_reset` was previously called.
    """
    assert not model._is_graph_network
    # Restore layers and build attributes
    if (
        hasattr(model, "_original_attributes_cache")
        and model._original_attributes_cache is not None
    ):
        # Models have sticky attribute assignment, so we want to be careful to
        # add back the previous attributes and track Layers by their original
        # names without adding dependencies on "utility" attributes which
        # Models exempt when they're constructed.
        setattr_tracking = model._setattr_tracking
        model._setattr_tracking = False
        model._self_tracked_trackables = []
        for name, value in model._original_attributes_cache.items():
            setattr(model, name, value)
            # Only Layer attributes go back into the trackables list;
            # build/compile attributes (inputs, optimizer, ...) do not.
            if isinstance(value, Layer):
                model._self_tracked_trackables.append(value)
        model._original_attributes_cache = None
        model._setattr_tracking = setattr_tracking
    else:
        # Restore to the state of a never-called model.
        _reset_build_compile_trackers(model)
|
||
|
|
||
|
|
||
|
@keras_export("keras.__internal__.models.clone_and_build_model", v1=[])
def clone_and_build_model(
    model,
    input_tensors=None,
    target_tensors=None,
    custom_objects=None,
    compile_clone=True,
    in_place_reset=False,
    optimizer_iterations=None,
    optimizer_config=None,
):
    """Clone a `Model` and build/compile it with the same settings used before.

    This function can be run in the same graph or in a separate graph from
    the model. When using a separate graph, `in_place_reset` must be `False`.

    Note that, currently, the clone produced from this function may not work
    with TPU DistributionStrategy. Try at your own risk.

    Args:
      model: `tf.keras.Model` object. Can be Functional, Sequential, or
        sub-classed.
      input_tensors: Optional list or dictionary of input tensors to build the
        model upon. If not provided, placeholders will be created.
      target_tensors: Optional list of target tensors for compiling the model.
        If not provided, placeholders will be created.
      custom_objects: Optional dictionary mapping string names to custom
        classes or functions.
      compile_clone: Boolean, whether to compile model clone (default `True`).
      in_place_reset: Boolean, whether to reset the model in place. Only used
        if the model is a subclassed model. In the case of a subclassed model,
        this argument must be set to `True` (default `False`). To restore the
        original model, use the function
        `in_place_subclassed_model_state_restoration(model)`.
      optimizer_iterations: An iterations variable that will be incremented by
        the optimizer if the clone is compiled. This argument is used when a
        Keras model is cloned into an Estimator model function, because
        Estimators create their own global step variable.
      optimizer_config: Optimizer config dictionary or list of dictionary
        returned from `get_config()`. This argument should be defined if
        `clone_and_build_model` is called in a different graph or session from
        the original model, and the optimizer is an instance of `OptimizerV2`.

    Returns:
      Clone of the model.

    Raises:
      ValueError: Cloning fails in the following cases
        - cloning a subclassed model with `in_place_reset` set to False.
        - compiling the clone when the original model has not been compiled.
    """
    # Grab optimizer now, as we reset-in-place for subclassed models, but
    # want to maintain access to the original optimizer.
    orig_optimizer = model.optimizer
    if compile_clone and not orig_optimizer:
        raise ValueError(
            "Error when cloning model: `compile_clone` was set to True, but "
            f"the original model has not been compiled. Received: model={model}"
        )

    if compile_clone:
        compile_args = model._get_compile_args()
        # Allows this method to be robust to switching graph and eager classes.
        model._get_compile_args = lambda: compile_args

    with CustomObjectScope(custom_objects or {}):
        if model._is_graph_network:
            clone = clone_model(model, input_tensors=input_tensors)
        elif isinstance(model, Sequential):
            # Sequential model that is not graph-built (deferred-build).
            clone = clone_model(model, input_tensors=input_tensors)
            if (
                not clone._is_graph_network
                and model._build_input_shape is not None
            ):
                if tf.compat.v1.executing_eagerly_outside_functions():
                    clone.build(model._build_input_shape)
                else:
                    clone._set_inputs(
                        backend.placeholder(
                            model._build_input_shape,
                            dtype=model.inputs[0].dtype,
                        )
                    )
        else:
            try:
                # Prefer cloning the model if serial/deserial logic is
                # implemented for subclassed model.
                clone = model.__class__.from_config(model.get_config())
            except NotImplementedError:
                logging.warning(
                    "This model is a subclassed model. Please implement "
                    "`get_config` and `from_config` to better support "
                    "cloning the model."
                )
                if not in_place_reset:
                    raise ValueError(
                        f"This model ({model}) is a subclassed model. "
                        "Such a model cannot be cloned, but there is a "
                        "workaround where the model is reset in-place. "
                        "To use this, please set the "
                        "argument `in_place_reset` to `True`. This will reset "
                        "the attributes in the original model. "
                        "To restore the attributes, call "
                        "`in_place_subclassed_model_state_restoration(model)`."
                    )
                clone = model
                _in_place_subclassed_model_reset(clone)
            if input_tensors is not None:
                if (
                    isinstance(input_tensors, (list, tuple))
                    and len(input_tensors) == 1
                ):
                    input_tensors = input_tensors[0]
                clone._set_inputs(input_tensors)

    if compile_clone:
        if isinstance(orig_optimizer, optimizer_v1.TFOptimizer):
            optimizer = optimizer_v1.TFOptimizer(
                orig_optimizer.optimizer, optimizer_iterations
            )
            backend.track_tf_optimizer(optimizer)
        else:
            # Normalize to a list so single- and multi-optimizer models are
            # handled uniformly below.
            if not isinstance(orig_optimizer, (tuple, list)):
                orig_optimizer = [orig_optimizer]
            if optimizer_config is None:
                optimizer = [
                    opt.__class__.from_config(opt.get_config())
                    for opt in orig_optimizer
                ]
            elif isinstance(optimizer_config, dict):
                optimizer = [
                    orig_optimizer[0].__class__.from_config(optimizer_config)
                ]
            else:
                # optimizer config is list of dict, same order as
                # orig_optimizer.
                optimizer = [
                    opt.__class__.from_config(opt_config)
                    for (opt, opt_config) in zip(
                        orig_optimizer, optimizer_config
                    )
                ]
            if optimizer_iterations is not None:
                for opt in optimizer:
                    opt.iterations = optimizer_iterations

            if len(optimizer) == 1:
                optimizer = optimizer[0]

        compile_args["optimizer"] = optimizer
        if target_tensors is not None:
            compile_args["target_tensors"] = target_tensors
        # Ensure Metric objects in new model are separate from existing model.
        compile_args["metrics"] = metrics_module.clone_metrics(
            compile_args["metrics"]
        )
        compile_args["weighted_metrics"] = metrics_module.clone_metrics(
            compile_args["weighted_metrics"]
        )
        clone.compile(**compile_args)

    return clone
|