1688 lines
68 KiB
Python
1688 lines
68 KiB
Python
![]() |
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
|
||
|
"""A `Network` is way to compose layers: the topological form of a `Model`."""
|
||
|
|
||
|
|
||
|
import collections
|
||
|
import copy
|
||
|
import itertools
|
||
|
import warnings
|
||
|
|
||
|
import tensorflow.compat.v2 as tf
|
||
|
|
||
|
from keras import backend
|
||
|
from keras.dtensor import layout_map as layout_map_lib
|
||
|
from keras.engine import base_layer
|
||
|
from keras.engine import base_layer_utils
|
||
|
from keras.engine import functional_utils
|
||
|
from keras.engine import input_layer as input_layer_module
|
||
|
from keras.engine import input_spec
|
||
|
from keras.engine import node as node_module
|
||
|
from keras.engine import training as training_lib
|
||
|
from keras.engine import training_utils
|
||
|
from keras.saving.legacy import serialization
|
||
|
from keras.saving.legacy.saved_model import json_utils
|
||
|
from keras.saving.legacy.saved_model import network_serialization
|
||
|
from keras.saving.legacy.saved_model import utils as saved_model_utils
|
||
|
from keras.utils import generic_utils
|
||
|
from keras.utils import tf_inspect
|
||
|
from keras.utils import tf_utils
|
||
|
|
||
|
# isort: off
|
||
|
from tensorflow.python.platform import tf_logging as logging
|
||
|
from tensorflow.tools.docs import doc_controls
|
||
|
|
||
|
|
||
|
class Functional(training_lib.Model):
|
||
|
"""A `Functional` model is a `Model` defined as a directed graph of layers.
|
||
|
|
||
|
Three types of `Model` exist: subclassed `Model`, `Functional` model,
|
||
|
and `Sequential` (a special case of `Functional`).
|
||
|
In general, more Keras features are supported with `Functional`
|
||
|
than with subclassed `Model`s, specifically:
|
||
|
|
||
|
- Model cloning (`keras.models.clone`)
|
||
|
- Serialization (`model.get_config()/from_config`, `model.to_json()`
|
||
|
- Whole-model saving (`model.save()`)
|
||
|
|
||
|
A `Functional` model can be instantiated by passing two arguments to
|
||
|
`__init__`. The first argument is the `keras.Input` Tensors that represent
|
||
|
the inputs to the model. The second argument specifies the output
|
||
|
tensors that represent the outputs of this model. Both arguments can be a
|
||
|
nested structure of tensors.
|
||
|
|
||
|
Example:
|
||
|
|
||
|
```
|
||
|
inputs = {'x1': keras.Input(shape=(10,)), 'x2': keras.Input(shape=(1,))}
|
||
|
t = keras.layers.Dense(1, activation='relu')(inputs['x1'])
|
||
|
outputs = keras.layers.Add()([t, inputs['x2'])
|
||
|
model = keras.Model(inputs, outputs)
|
||
|
```
|
||
|
|
||
|
A `Functional` model constructed using the Functional API can also include
|
||
|
raw TensorFlow functions, with the exception of functions that create
|
||
|
Variables or assign ops.
|
||
|
|
||
|
Example:
|
||
|
|
||
|
```python
|
||
|
inputs = keras.Input(shape=(10,))
|
||
|
x = keras.layers.Dense(1)(inputs)
|
||
|
outputs = tf.nn.relu(x)
|
||
|
model = keras.Model(inputs, outputs)
|
||
|
```
|
||
|
|
||
|
A new `Functional` model can also be created by using the
|
||
|
intermediate tensors. This enables you to quickly extract sub-components
|
||
|
of the model.
|
||
|
|
||
|
Example:
|
||
|
|
||
|
```python
|
||
|
inputs = keras.Input(shape=(None, None, 3))
|
||
|
processed = keras.layers.RandomCrop(width=32, height=32)(inputs)
|
||
|
conv = keras.layers.Conv2D(filters=2, kernel_size=3)(processed)
|
||
|
pooling = keras.layers.GlobalAveragePooling2D()(conv)
|
||
|
feature = keras.layers.Dense(10)(pooling)
|
||
|
|
||
|
full_model = keras.Model(inputs, feature)
|
||
|
backbone = keras.Model(processed, conv)
|
||
|
activations = keras.Model(conv, feature)
|
||
|
```
|
||
|
|
||
|
Note that the `backbone` and `activations` models are not
|
||
|
created with `keras.Input` objects, but with the tensors that are originated
|
||
|
from `keras.Input` objects. Under the hood, the layers and weights will
|
||
|
be shared across these models, so that user can train the `full_model`, and
|
||
|
use `backbone` or `activations` to do feature extraction.
|
||
|
The inputs and outputs of the model can be nested structures of tensors as
|
||
|
well, and the created models are standard `Functional` model that support
|
||
|
all the existing API.
|
||
|
|
||
|
Args:
|
||
|
inputs: List of input tensors (must be created via `tf.keras.Input()` or
|
||
|
originated from `tf.keras.Input()`).
|
||
|
outputs: List of output tensors.
|
||
|
name: String, optional. Name of the model.
|
||
|
trainable: Boolean, optional. If the model's variables should be
|
||
|
trainable.
|
||
|
"""
|
||
|
|
||
|
# See tf.Module for the usage of this property.
|
||
|
# The key of _layer_call_argspecs is a layer. tf.Module._flatten will fail
|
||
|
# to flatten the key since it is trying to convert Trackable/Layer to a
|
||
|
# string.
|
||
|
_TF_MODULE_IGNORED_PROPERTIES = frozenset(
|
||
|
itertools.chain(
|
||
|
(
|
||
|
"_layer_call_argspecs",
|
||
|
"_compiled_trainable_state",
|
||
|
"_output_mask_cache",
|
||
|
"_output_tensor_cache",
|
||
|
"_output_shape_cache",
|
||
|
),
|
||
|
training_lib.Model._TF_MODULE_IGNORED_PROPERTIES,
|
||
|
)
|
||
|
)
|
||
|
|
||
|
@tf.__internal__.tracking.no_automatic_dependency_tracking
|
||
|
def __init__(self, inputs, outputs, name=None, trainable=True, **kwargs):
|
||
|
# This is used by the Model class, since we have some logic to swap the
|
||
|
# class in the __new__ method, which will lead to __init__ get invoked
|
||
|
# twice. Using the skip_init to skip one of the invocation of __init__
|
||
|
# to avoid any side effects
|
||
|
skip_init = kwargs.pop("skip_init", False)
|
||
|
if skip_init:
|
||
|
return
|
||
|
generic_utils.validate_kwargs(kwargs, {})
|
||
|
super().__init__(name=name, trainable=trainable)
|
||
|
# Check if the inputs contain any intermediate `KerasTensor` (not
|
||
|
# created by tf.keras.Input()). In this case we need to clone the `Node`
|
||
|
# and `KerasTensor` objects to mimic rebuilding a new model from new
|
||
|
# inputs. This feature is only enabled in TF2 not in v1 graph mode.
|
||
|
if tf.compat.v1.executing_eagerly_outside_functions():
|
||
|
if not all(
|
||
|
[
|
||
|
functional_utils.is_input_keras_tensor(t)
|
||
|
for t in tf.nest.flatten(inputs)
|
||
|
]
|
||
|
):
|
||
|
inputs, outputs = functional_utils.clone_graph_nodes(
|
||
|
inputs, outputs
|
||
|
)
|
||
|
self._init_graph_network(inputs, outputs)
|
||
|
|
||
|
@tf.__internal__.tracking.no_automatic_dependency_tracking
|
||
|
def _init_graph_network(self, inputs, outputs):
|
||
|
# This method is needed for Sequential to reinitialize graph network
|
||
|
# when layer is added or removed.
|
||
|
|
||
|
base_layer.keras_api_gauge.get_cell("Functional").set(True)
|
||
|
self._is_graph_network = True
|
||
|
|
||
|
# Normalize and set self.inputs, self.outputs.
|
||
|
if isinstance(inputs, list) and len(tf.nest.flatten(inputs)) == 1:
|
||
|
inputs = inputs[0]
|
||
|
if isinstance(outputs, list) and len(tf.nest.flatten(outputs)) == 1:
|
||
|
outputs = outputs[0]
|
||
|
self._nested_inputs = inputs
|
||
|
self._nested_outputs = outputs
|
||
|
self.inputs = tf.nest.flatten(inputs)
|
||
|
self.outputs = tf.nest.flatten(outputs)
|
||
|
|
||
|
# Models constructed with a single Tensor or list of Tensors can
|
||
|
# be called with a dict, where the keys of the dict are the names
|
||
|
# of the `Input` objects. Extra keys are ignored with warning.
|
||
|
if not tf.nest.is_nested(self._nested_inputs):
|
||
|
self._enable_dict_to_input_mapping = True
|
||
|
elif isinstance(self._nested_inputs, (list, tuple)) and not any(
|
||
|
tf.nest.is_nested(t) for t in self._nested_inputs
|
||
|
):
|
||
|
self._enable_dict_to_input_mapping = True
|
||
|
elif isinstance(self._nested_inputs, dict) and not any(
|
||
|
tf.nest.is_nested(t) for t in self._nested_inputs.values()
|
||
|
):
|
||
|
self._enable_dict_to_input_mapping = True
|
||
|
else:
|
||
|
self._enable_dict_to_input_mapping = False
|
||
|
|
||
|
if not tf.compat.v1.executing_eagerly_outside_functions():
|
||
|
if any(
|
||
|
not hasattr(tensor, "_keras_history") for tensor in self.outputs
|
||
|
):
|
||
|
base_layer_utils.create_keras_history(self._nested_outputs)
|
||
|
|
||
|
self._validate_graph_inputs_and_outputs()
|
||
|
|
||
|
# A Network does not create weights of its own, thus it is already
|
||
|
# built.
|
||
|
self.built = True
|
||
|
self._build_input_shape = tf.nest.map_structure(
|
||
|
lambda x: x.shape, inputs
|
||
|
)
|
||
|
self._compute_output_and_mask_jointly = True
|
||
|
# `_expects_training_arg` is True since the `training` argument is
|
||
|
# always present in the signature of the `call` method of a graph
|
||
|
# network.
|
||
|
self._call_spec.expects_training_arg = True
|
||
|
self._call_spec.expects_mask_arg = True
|
||
|
# A graph network does not autocast inputs, as its layers will cast them
|
||
|
# instead.
|
||
|
self._autocast = False
|
||
|
|
||
|
self._input_layers = []
|
||
|
self._output_layers = []
|
||
|
self._input_coordinates = []
|
||
|
self._output_coordinates = []
|
||
|
|
||
|
# This is for performance optimization when calling the Network on new
|
||
|
# inputs. Every time the Network is called on a set on input tensors, we
|
||
|
# compute the output tensors, output masks and output shapes in one
|
||
|
# pass, then cache them here. When any of these outputs is queried
|
||
|
# later, we retrieve it from there instead of recomputing it.
|
||
|
self._output_mask_cache = {}
|
||
|
self._output_tensor_cache = {}
|
||
|
self._output_shape_cache = {}
|
||
|
|
||
|
# Build self._output_layers:
|
||
|
for x in self.outputs:
|
||
|
(
|
||
|
layer,
|
||
|
node_index,
|
||
|
tensor_index,
|
||
|
) = x._keras_history
|
||
|
self._output_layers.append(layer)
|
||
|
self._output_coordinates.append((layer, node_index, tensor_index))
|
||
|
|
||
|
# Build self._input_layers:
|
||
|
for x in self.inputs:
|
||
|
(
|
||
|
layer,
|
||
|
node_index,
|
||
|
tensor_index,
|
||
|
) = x._keras_history
|
||
|
# It's supposed to be an input layer, so only one node
|
||
|
# and one tensor output.
|
||
|
assert node_index == 0
|
||
|
assert tensor_index == 0
|
||
|
self._input_layers.append(layer)
|
||
|
self._input_coordinates.append((layer, node_index, tensor_index))
|
||
|
|
||
|
# Keep track of the network's nodes and layers.
|
||
|
nodes, nodes_by_depth, layers, _ = _map_graph_network(
|
||
|
self.inputs, self.outputs
|
||
|
)
|
||
|
self._network_nodes = nodes
|
||
|
self._nodes_by_depth = nodes_by_depth
|
||
|
self._self_tracked_trackables = layers
|
||
|
self._layer_call_argspecs = {}
|
||
|
for layer in self._self_tracked_trackables:
|
||
|
self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(
|
||
|
layer.call
|
||
|
)
|
||
|
|
||
|
# Build self.input_names and self.output_names.
|
||
|
self._set_output_names()
|
||
|
self.input_names = []
|
||
|
self._feed_input_names = []
|
||
|
self._feed_inputs = []
|
||
|
self._feed_input_shapes = []
|
||
|
for layer in self._input_layers:
|
||
|
self.input_names.append(layer.name)
|
||
|
if layer.is_placeholder:
|
||
|
self._feed_input_names.append(layer.name)
|
||
|
# Use batch_input_shape here because non-eager composite tensors
|
||
|
# may not have a shape attribute that's meaningful (sparse, for
|
||
|
# instance, has a tensor that's non-constant and needs to be
|
||
|
# fed). This means that input layers that create placeholders
|
||
|
# will need to have the batch_input_shape attr to allow for
|
||
|
# input shape validation.
|
||
|
self._feed_input_shapes.append(layer._batch_input_shape)
|
||
|
self._feed_inputs.append(layer.input)
|
||
|
|
||
|
self._compute_tensor_usage_count()
|
||
|
self._set_save_spec(self._nested_inputs)
|
||
|
tf_utils.assert_no_legacy_layers(self.layers)
|
||
|
|
||
|
# Note that this method is used by both functional and sequential
|
||
|
# models, so we can't just have this method in functional.__init__,
|
||
|
# which will miss the coverage of sequential model.
|
||
|
if self._layout_map is not None:
|
||
|
layout_map_lib._map_functional_model_variable(
|
||
|
self, self._layout_map
|
||
|
)
|
||
|
|
||
|
@property
|
||
|
def input(self):
|
||
|
"""Retrieves the input tensor(s) of a layer.
|
||
|
|
||
|
Only applicable if the layer has exactly one input,
|
||
|
i.e. if it is connected to one incoming layer.
|
||
|
|
||
|
Returns:
|
||
|
Input tensor or list of input tensors.
|
||
|
|
||
|
Raises:
|
||
|
RuntimeError: If called in Eager mode.
|
||
|
AttributeError: If no inbound nodes are found.
|
||
|
"""
|
||
|
return self._nested_inputs
|
||
|
|
||
|
@property
|
||
|
def input_shape(self):
|
||
|
"""Retrieves the input shape(s) of a layer.
|
||
|
|
||
|
Only applicable if the layer has exactly one input,
|
||
|
i.e. if it is connected to one incoming layer, or if all inputs
|
||
|
have the same shape.
|
||
|
|
||
|
Returns:
|
||
|
Input shape, as an integer shape tuple
|
||
|
(or list of shape tuples, one tuple per input tensor).
|
||
|
|
||
|
Raises:
|
||
|
AttributeError: if the layer has no defined input_shape.
|
||
|
RuntimeError: if called in Eager mode.
|
||
|
"""
|
||
|
return tf.nest.map_structure(backend.int_shape, self.input)
|
||
|
|
||
|
@property
|
||
|
def input_spec(self):
|
||
|
if hasattr(self, "_manual_input_spec"):
|
||
|
return self._manual_input_spec
|
||
|
if isinstance(self._nested_inputs, (dict, list, tuple)) and len(
|
||
|
self._nested_inputs
|
||
|
) != len(self.inputs):
|
||
|
# Case where we have a nested structure.
|
||
|
# In such a case we can't safely run any checks.
|
||
|
return None
|
||
|
if isinstance(self._nested_inputs, dict):
|
||
|
# Case where `_nested_inputs` is a plain dict of Inputs.
|
||
|
names = sorted(self._nested_inputs.keys())
|
||
|
return [
|
||
|
input_spec.InputSpec(
|
||
|
shape=shape_with_no_batch_size(self._nested_inputs[name]),
|
||
|
allow_last_axis_squeeze=True,
|
||
|
name=name,
|
||
|
)
|
||
|
for name in names
|
||
|
]
|
||
|
else:
|
||
|
# Single input, or list / tuple of inputs.
|
||
|
# The data may be passed as a dict keyed by input name.
|
||
|
return [
|
||
|
input_spec.InputSpec(
|
||
|
shape=shape_with_no_batch_size(x),
|
||
|
allow_last_axis_squeeze=True,
|
||
|
name=x._keras_history.layer.name,
|
||
|
)
|
||
|
for x in self.inputs
|
||
|
]
|
||
|
|
||
|
@input_spec.setter
|
||
|
def input_spec(self, value):
|
||
|
self._manual_input_spec = value
|
||
|
|
||
|
@property
|
||
|
def output(self):
|
||
|
"""Retrieves the output tensor(s) of a layer.
|
||
|
|
||
|
Only applicable if the layer has exactly one output,
|
||
|
i.e. if it is connected to one incoming layer.
|
||
|
|
||
|
Returns:
|
||
|
Output tensor or list of output tensors.
|
||
|
|
||
|
Raises:
|
||
|
AttributeError: if the layer is connected to more than one incoming
|
||
|
layers.
|
||
|
RuntimeError: if called in Eager mode.
|
||
|
"""
|
||
|
return self._nested_outputs
|
||
|
|
||
|
@property
|
||
|
def output_shape(self):
|
||
|
"""Retrieves the output shape(s) of a layer.
|
||
|
|
||
|
Only applicable if the layer has one output,
|
||
|
or if all outputs have the same shape.
|
||
|
|
||
|
Returns:
|
||
|
Output shape, as an integer shape tuple
|
||
|
(or list of shape tuples, one tuple per output tensor).
|
||
|
|
||
|
Raises:
|
||
|
AttributeError: if the layer has no defined output shape.
|
||
|
RuntimeError: if called in Eager mode.
|
||
|
"""
|
||
|
return tf.nest.map_structure(backend.int_shape, self.output)
|
||
|
|
||
|
def _set_output_names(self):
|
||
|
"""Assigns unique names to the Network's outputs.
|
||
|
|
||
|
Output layers with multiple output tensors would otherwise lead to
|
||
|
duplicate names in self.output_names.
|
||
|
"""
|
||
|
uniquified = []
|
||
|
output_names = set()
|
||
|
prefix_count = {}
|
||
|
for layer in self._output_layers:
|
||
|
proposal = layer.name
|
||
|
while proposal in output_names:
|
||
|
existing_count = prefix_count.get(layer.name, 1)
|
||
|
proposal = f"{layer.name}_{existing_count}"
|
||
|
prefix_count[layer.name] = existing_count + 1
|
||
|
output_names.add(proposal)
|
||
|
uniquified.append(proposal)
|
||
|
self.output_names = uniquified
|
||
|
|
||
|
@property
|
||
|
def _layer_checkpoint_dependencies(self):
|
||
|
"""Dictionary of layer dependencies to be included in the checkpoint."""
|
||
|
weight_layer_index = 0
|
||
|
|
||
|
dependencies = collections.OrderedDict()
|
||
|
for layer_index, layer in enumerate(self.layers):
|
||
|
try:
|
||
|
if layer.weights:
|
||
|
# Keep a separate index for layers which have weights. This
|
||
|
# allows users to insert Layers without weights anywhere in
|
||
|
# the network without breaking checkpoints.
|
||
|
dependencies[
|
||
|
"layer_with_weights-%d" % weight_layer_index
|
||
|
] = layer
|
||
|
weight_layer_index += 1
|
||
|
except ValueError:
|
||
|
# The layer might have weights, but may not be built yet. We
|
||
|
# just treat it as layer without weight.
|
||
|
pass
|
||
|
|
||
|
# Even if it doesn't have weights, we should still track everything
|
||
|
# in case it has/will have Trackable dependencies.
|
||
|
dependencies["layer-%d" % layer_index] = layer
|
||
|
return dependencies
|
||
|
|
||
|
def _trackable_children(self, save_type="checkpoint", **kwargs):
|
||
|
dependencies = self._layer_checkpoint_dependencies
|
||
|
dependencies.update(super()._trackable_children(save_type, **kwargs))
|
||
|
return dependencies
|
||
|
|
||
|
def _lookup_dependency(self, name):
|
||
|
layer_dependencies = self._layer_checkpoint_dependencies
|
||
|
if name in layer_dependencies:
|
||
|
return layer_dependencies[name]
|
||
|
return super()._lookup_dependency(name)
|
||
|
|
||
|
def _handle_deferred_layer_dependencies(self, layers):
|
||
|
"""Handles layer checkpoint dependencies that are added after init."""
|
||
|
layer_checkpoint_dependencies = self._layer_checkpoint_dependencies
|
||
|
layer_to_name = {v: k for k, v in layer_checkpoint_dependencies.items()}
|
||
|
for layer in layers:
|
||
|
if layer in layer_to_name:
|
||
|
self._handle_deferred_dependencies(
|
||
|
name=layer_to_name[layer], trackable=layer
|
||
|
)
|
||
|
|
||
|
@property
|
||
|
def _should_compute_mask(self):
|
||
|
return True
|
||
|
|
||
|
def compute_mask(self, inputs, mask):
|
||
|
# TODO(omalleyt): b/123540974 This function is not really safe to call
|
||
|
# by itself because it will duplicate any updates and losses in graph
|
||
|
# mode by `call`ing the Layers again.
|
||
|
output_tensors = self._run_internal_graph(inputs, mask=mask)
|
||
|
return tf.nest.map_structure(
|
||
|
lambda t: getattr(t, "_keras_mask", None), output_tensors
|
||
|
)
|
||
|
|
||
|
@doc_controls.do_not_doc_inheritable
|
||
|
def call(self, inputs, training=None, mask=None):
|
||
|
"""Calls the model on new inputs.
|
||
|
|
||
|
In this case `call` just reapplies
|
||
|
all ops in the graph to the new inputs
|
||
|
(e.g. build a new computational graph from the provided inputs).
|
||
|
|
||
|
Args:
|
||
|
inputs: A tensor or list of tensors.
|
||
|
training: Boolean or boolean scalar tensor, indicating whether to
|
||
|
run the `Network` in training mode or inference mode.
|
||
|
mask: A mask or list of masks. A mask can be
|
||
|
either a tensor or None (no mask).
|
||
|
|
||
|
Returns:
|
||
|
A tensor if there is a single output, or
|
||
|
a list of tensors if there are more than one outputs.
|
||
|
"""
|
||
|
return self._run_internal_graph(inputs, training=training, mask=mask)
|
||
|
|
||
|
def compute_output_shape(self, input_shape):
|
||
|
# Convert any shapes in tuple format to TensorShapes.
|
||
|
input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
|
||
|
|
||
|
if len(tf.nest.flatten(input_shape)) != len(
|
||
|
tf.nest.flatten(self._input_layers)
|
||
|
):
|
||
|
raise ValueError(
|
||
|
f"Invalid `input_shape` argument {input_shape}: "
|
||
|
f"the model expects {len(self._input_layers)} "
|
||
|
"input tensors."
|
||
|
)
|
||
|
|
||
|
# Use the tuple of TensorShape as the cache key, since tuple is hashable
|
||
|
# and can be used as hash key.
|
||
|
try:
|
||
|
cache_key = tuple(
|
||
|
tf_utils.convert_shapes(input_shape, to_tuples=True)
|
||
|
)
|
||
|
if cache_key in self._output_shape_cache:
|
||
|
# Cache hit. Return shapes as TensorShapes.
|
||
|
return self._output_shape_cache[cache_key]
|
||
|
except ValueError:
|
||
|
# In case there are unknown TensorShape, eg for sparse tensor input,
|
||
|
# We skip the caching since the shape is unknown.
|
||
|
pass
|
||
|
|
||
|
layers_to_output_shapes = {}
|
||
|
for layer, shape in zip(
|
||
|
self._input_layers, tf.nest.flatten(input_shape)
|
||
|
):
|
||
|
# It's an input layer: then `compute_output_shape` is identity,
|
||
|
# and there is only one node and one tensor..
|
||
|
shape_key = layer.name + "_0_0"
|
||
|
layers_to_output_shapes[shape_key] = shape
|
||
|
|
||
|
depth_keys = list(self._nodes_by_depth.keys())
|
||
|
depth_keys.sort(reverse=True)
|
||
|
# Iterate over nodes, by depth level.
|
||
|
if len(depth_keys) > 1:
|
||
|
for depth in depth_keys:
|
||
|
nodes = self._nodes_by_depth[depth]
|
||
|
for node in nodes:
|
||
|
layer = node.layer
|
||
|
if layer in self._input_layers:
|
||
|
# We've already covered the input layers
|
||
|
# a few lines above.
|
||
|
continue
|
||
|
# Get the input shapes for the first argument of the node
|
||
|
layer_input_shapes = []
|
||
|
layer_inputs = node.call_args[0]
|
||
|
for layer_input in tf.nest.flatten(layer_inputs):
|
||
|
kh = layer_input._keras_history
|
||
|
input_layer_key = kh.layer.name + "_%s_%s" % (
|
||
|
kh.node_index,
|
||
|
kh.tensor_index,
|
||
|
)
|
||
|
layer_input_shapes.append(
|
||
|
layers_to_output_shapes[input_layer_key]
|
||
|
)
|
||
|
layer_input_shapes = tf.nest.pack_sequence_as(
|
||
|
layer_inputs, layer_input_shapes
|
||
|
)
|
||
|
# Layers expect shapes to be tuples for
|
||
|
# `compute_output_shape`.
|
||
|
layer_input_shapes = tf_utils.convert_shapes(
|
||
|
layer_input_shapes, to_tuples=True
|
||
|
)
|
||
|
layer_output_shapes = layer.compute_output_shape(
|
||
|
layer_input_shapes
|
||
|
)
|
||
|
# Convert back to TensorShapes.
|
||
|
layer_output_shapes = tf_utils.convert_shapes(
|
||
|
layer_output_shapes, to_tuples=False
|
||
|
)
|
||
|
|
||
|
node_index = layer._inbound_nodes.index(node)
|
||
|
for j, shape in enumerate(
|
||
|
tf.nest.flatten(layer_output_shapes)
|
||
|
):
|
||
|
shape_key = layer.name + f"_{node_index}_{j}"
|
||
|
layers_to_output_shapes[shape_key] = shape
|
||
|
|
||
|
# Read final output shapes from layers_to_output_shapes.
|
||
|
output_shapes = []
|
||
|
for i in range(len(self._output_layers)):
|
||
|
layer, node_index, tensor_index = self._output_coordinates[i]
|
||
|
shape_key = layer.name + f"_{node_index}_{tensor_index}"
|
||
|
output_shapes.append(layers_to_output_shapes[shape_key])
|
||
|
output_shapes = tf.nest.pack_sequence_as(
|
||
|
self._nested_outputs, output_shapes
|
||
|
)
|
||
|
# Store in cache.
|
||
|
self._output_shape_cache[cache_key] = output_shapes
|
||
|
|
||
|
# Return shapes as TensorShapes.
|
||
|
return output_shapes
|
||
|
|
||
|
def _init_set_name(self, name, zero_based=True):
|
||
|
if not name:
|
||
|
cls_name = self.__class__.__name__
|
||
|
if self.__class__ == Functional:
|
||
|
# Hide the functional class name from user, since its not a
|
||
|
# public visible class. Use "Model" instead,
|
||
|
cls_name = "Model"
|
||
|
self._name = backend.unique_object_name(
|
||
|
generic_utils.to_snake_case(cls_name), zero_based=zero_based
|
||
|
)
|
||
|
else:
|
||
|
self._name = name
|
||
|
|
||
|
def _run_internal_graph(self, inputs, training=None, mask=None):
|
||
|
"""Computes output tensors for new inputs.
|
||
|
|
||
|
# Note:
|
||
|
- Can be run on non-Keras tensors.
|
||
|
|
||
|
Args:
|
||
|
inputs: Tensor or nested structure of Tensors.
|
||
|
training: Boolean learning phase.
|
||
|
mask: (Optional) Tensor or nested structure of Tensors.
|
||
|
|
||
|
Returns:
|
||
|
output_tensors
|
||
|
"""
|
||
|
inputs = self._flatten_to_reference_inputs(inputs)
|
||
|
if mask is None:
|
||
|
masks = [None] * len(inputs)
|
||
|
else:
|
||
|
masks = self._flatten_to_reference_inputs(mask)
|
||
|
for input_t, mask in zip(inputs, masks):
|
||
|
input_t._keras_mask = mask
|
||
|
|
||
|
# Dictionary mapping reference tensors to computed tensors.
|
||
|
tensor_dict = {}
|
||
|
tensor_usage_count = self._tensor_usage_count
|
||
|
for x, y in zip(self.inputs, inputs):
|
||
|
y = self._conform_to_reference_input(y, ref_input=x)
|
||
|
x_id = str(id(x))
|
||
|
tensor_dict[x_id] = [y] * tensor_usage_count[x_id]
|
||
|
|
||
|
nodes_by_depth = self._nodes_by_depth
|
||
|
depth_keys = list(nodes_by_depth.keys())
|
||
|
depth_keys.sort(reverse=True)
|
||
|
|
||
|
for depth in depth_keys:
|
||
|
nodes = nodes_by_depth[depth]
|
||
|
for node in nodes:
|
||
|
if node.is_input:
|
||
|
continue # Input tensors already exist.
|
||
|
|
||
|
if any(t_id not in tensor_dict for t_id in node.flat_input_ids):
|
||
|
continue # Node is not computable, try skipping.
|
||
|
|
||
|
args, kwargs = node.map_arguments(tensor_dict)
|
||
|
outputs = node.layer(*args, **kwargs)
|
||
|
|
||
|
# Update tensor_dict.
|
||
|
for x_id, y in zip(
|
||
|
node.flat_output_ids, tf.nest.flatten(outputs)
|
||
|
):
|
||
|
tensor_dict[x_id] = [y] * tensor_usage_count[x_id]
|
||
|
|
||
|
output_tensors = []
|
||
|
for x in self.outputs:
|
||
|
x_id = str(id(x))
|
||
|
assert x_id in tensor_dict, "Could not compute output " + str(x)
|
||
|
output_tensors.append(tensor_dict[x_id].pop())
|
||
|
|
||
|
return tf.nest.pack_sequence_as(self._nested_outputs, output_tensors)
|
||
|
|
||
|
def _flatten_to_reference_inputs(self, tensors):
|
||
|
"""Maps `tensors` to their respective `keras.Input`."""
|
||
|
if self._enable_dict_to_input_mapping and isinstance(tensors, dict):
|
||
|
ref_inputs = self._nested_inputs
|
||
|
if not tf.nest.is_nested(ref_inputs):
|
||
|
ref_inputs = [self._nested_inputs]
|
||
|
if isinstance(ref_inputs, dict):
|
||
|
# In the case that the graph is constructed with dict input
|
||
|
# tensors, We will use the original dict key to map with the
|
||
|
# keys in the input data. Note that the model.inputs is using
|
||
|
# nest.flatten to process the input tensors, which means the
|
||
|
# dict input tensors are ordered by their keys.
|
||
|
ref_input_names = sorted(ref_inputs.keys())
|
||
|
else:
|
||
|
ref_input_names = [
|
||
|
inp._keras_history.layer.name for inp in ref_inputs
|
||
|
]
|
||
|
|
||
|
# Raise an warning if there are more input data comparing to input
|
||
|
# tensor
|
||
|
if len(tensors) > len(ref_input_names):
|
||
|
warnings.warn(
|
||
|
"Input dict contained keys {} which did not match any "
|
||
|
"model input. They will be ignored by the model.".format(
|
||
|
[n for n in tensors.keys() if n not in ref_input_names]
|
||
|
),
|
||
|
stacklevel=2,
|
||
|
)
|
||
|
|
||
|
try:
|
||
|
# Flatten in the order `Input`s were passed during Model
|
||
|
# construction.
|
||
|
return [tensors[n] for n in ref_input_names]
|
||
|
except KeyError:
|
||
|
# TODO(b/151582614)
|
||
|
return tf.nest.flatten(tensors)
|
||
|
|
||
|
# Otherwise both self.inputs and tensors will already be in same order.
|
||
|
return tf.nest.flatten(tensors)
|
||
|
|
||
|
def _conform_to_reference_input(self, tensor, ref_input):
|
||
|
"""Set shape and dtype based on `keras.Input`s."""
|
||
|
if isinstance(tensor, tf.Tensor):
|
||
|
# Allow (None,) and (None, 1) Tensors to be passed interchangeably.
|
||
|
# Use the shape specified by the `keras.Input`.
|
||
|
t_shape = tensor.shape
|
||
|
t_rank = t_shape.rank
|
||
|
ref_shape = ref_input.shape
|
||
|
ref_rank = ref_shape.rank
|
||
|
keras_history = getattr(tensor, "_keras_history", None)
|
||
|
if t_rank is not None and ref_rank is not None:
|
||
|
# Should squeeze last dimension. True if tensor is (BATCH, ...,
|
||
|
# 1) and reference is (BATCH, ...).
|
||
|
if t_rank == ref_rank + 1 and t_shape[-1] == 1:
|
||
|
tensor = tf.squeeze(tensor, axis=-1)
|
||
|
# Should expand last_dimension. True if tensor is (BATCH, ...)
|
||
|
# and reference is (BATCH, ..., 1).
|
||
|
elif t_rank == ref_rank - 1 and ref_shape[-1] == 1:
|
||
|
tensor = tf.expand_dims(tensor, axis=-1)
|
||
|
if keras_history is not None: # Restore keras history.
|
||
|
tensor._keras_history = keras_history
|
||
|
|
||
|
# Dtype casting.
|
||
|
tensor = tf.cast(tensor, dtype=ref_input.dtype)
|
||
|
elif tf_utils.is_extension_type(tensor):
|
||
|
# Dtype casting (If the extension type has a non-variant dtype and
|
||
|
# supports being cast). Only cast if necessary (since some
|
||
|
# extension types may not implement tf.cast).
|
||
|
tensor_dtype = getattr(tensor, "dtype", None)
|
||
|
ref_input_dtype = getattr(ref_input, "dtype", None)
|
||
|
if (
|
||
|
ref_input_dtype is not None
|
||
|
and tensor_dtype is not None
|
||
|
and tensor_dtype != ref_input_dtype
|
||
|
and ref_input_dtype != tf.variant
|
||
|
):
|
||
|
tensor = tf.cast(tensor, dtype=ref_input_dtype)
|
||
|
|
||
|
return tensor
|
||
|
|
||
|
@generic_utils.default
|
||
|
def get_config(self):
|
||
|
# Prepare base arguments
|
||
|
config = {
|
||
|
"name": self.name,
|
||
|
"trainable": self.trainable,
|
||
|
}
|
||
|
|
||
|
if saved_model_utils.in_tf_saved_model_scope():
|
||
|
# SavedModel special case: need to preserve legacy (potentially
|
||
|
# incorrect) behavior.
|
||
|
return copy.deepcopy(get_network_config(self, config=config))
|
||
|
|
||
|
# Check whether the class has a constructor compatible with a Functional
|
||
|
# model or if it has a custom constructor.
|
||
|
if has_functional_like_constructor(self.__class__):
|
||
|
# Only return a Functional config if the constructor is the same
|
||
|
# as that of a Functional model. This excludes subclassed Functional
|
||
|
# models with a custom __init__.
|
||
|
config = copy.deepcopy(get_network_config(self, config=config))
|
||
|
else:
|
||
|
# Try to autogenerate config
|
||
|
xtra_args = set(config.keys())
|
||
|
if getattr(self, "_auto_get_config", False):
|
||
|
config.update(self._auto_config.config)
|
||
|
# Remove args non explicitly supported
|
||
|
argspec = tf_inspect.getfullargspec(self.__init__)
|
||
|
if argspec.varkw != "kwargs":
|
||
|
for key in xtra_args - xtra_args.intersection(argspec.args[1:]):
|
||
|
config.pop(key, None)
|
||
|
return config
|
||
|
|
||
|
def get_weight_paths(self):
|
||
|
result = {}
|
||
|
for layer in self.layers:
|
||
|
(
|
||
|
descendants,
|
||
|
object_paths_dict,
|
||
|
) = tf.__internal__.tracking.ObjectGraphView(
|
||
|
layer
|
||
|
).breadth_first_traversal()
|
||
|
for descendant in descendants:
|
||
|
if isinstance(descendant, tf.Variable):
|
||
|
trackable_references = object_paths_dict[descendant]
|
||
|
object_path = ".".join(
|
||
|
[t.name for t in trackable_references]
|
||
|
)
|
||
|
result[layer.name + "." + object_path] = descendant
|
||
|
return result
|
||
|
|
||
|
def _validate_graph_inputs_and_outputs(self):
|
||
|
"""Validates the inputs and outputs of a Graph Network."""
|
||
|
# Check for redundancy in inputs.
|
||
|
if len({id(i) for i in self.inputs}) != len(self.inputs):
|
||
|
raise ValueError(
|
||
|
"The list of inputs passed to the model "
|
||
|
"contains the same input multiple times. "
|
||
|
"All inputs should only appear once."
|
||
|
f"Received inputs={self.inputs}"
|
||
|
)
|
||
|
|
||
|
for x in self.inputs:
|
||
|
# Check that x has appropriate `_keras_history` metadata.
|
||
|
if not hasattr(x, "_keras_history"):
|
||
|
cls_name = self.__class__.__name__
|
||
|
raise ValueError(
|
||
|
f"Input tensors to a {cls_name} model "
|
||
|
"must come from `tf.keras.Input`. "
|
||
|
f"Received inputs={x} (missing previous layer metadata)."
|
||
|
)
|
||
|
# Check that x is an input tensor.
|
||
|
|
||
|
layer = x._keras_history.layer
|
||
|
if len(layer._inbound_nodes) > 1 or (
|
||
|
layer._inbound_nodes and not layer._inbound_nodes[0].is_input
|
||
|
):
|
||
|
cls_name = self.__class__.__name__
|
||
|
logging.warning(
|
||
|
f"{cls_name} model inputs must come from "
|
||
|
"`tf.keras.Input` (thus holding past layer metadata). "
|
||
|
"They cannot be the output of "
|
||
|
"a previous non-Input layer. "
|
||
|
"Here, a tensor specified as "
|
||
|
f'input to "{self.name}" was not an Input tensor, '
|
||
|
f'it was generated by layer "{layer.name}".\n'
|
||
|
"Note that input tensors are "
|
||
|
"instantiated via `tensor = tf.keras.Input(shape)`.\n"
|
||
|
f"The tensor that caused the issue was: {x}"
|
||
|
)
|
||
|
|
||
|
# Check compatibility of batch sizes of Input Layers.
|
||
|
input_batch_sizes = set(
|
||
|
[
|
||
|
training_utils.get_static_batch_size(x._keras_history.layer)
|
||
|
for x in self.inputs
|
||
|
]
|
||
|
)
|
||
|
input_batch_sizes.discard(None)
|
||
|
if len(input_batch_sizes) > 1:
|
||
|
logging.warning(
|
||
|
"Found incompatible static batch sizes among the "
|
||
|
f"inputs. Batch sizes: {sorted(input_batch_sizes)}"
|
||
|
)
|
||
|
|
||
|
for x in self.outputs:
|
||
|
if not hasattr(x, "_keras_history"):
|
||
|
cls_name = self.__class__.__name__
|
||
|
raise ValueError(
|
||
|
f"Output tensors of a {cls_name} model must be "
|
||
|
"the output of a TensorFlow `Layer` "
|
||
|
f"(thus holding past layer metadata). Found: {x}"
|
||
|
)
|
||
|
|
||
|
def _insert_layers(self, layers, relevant_nodes=None):
|
||
|
"""Inserts Layers into the Network after Network creation.
|
||
|
|
||
|
This is only valid for Keras Graph Networks. Layers added via this
|
||
|
function will be included in the `call` computation and `get_config` of
|
||
|
this Network. They will not be added to the Network's outputs.
|
||
|
|
||
|
Args:
|
||
|
layers: Arbitrary nested structure of Layers. Layers must be reachable
|
||
|
from one or more of the `keras.Input` Tensors that correspond to
|
||
|
this Network's inputs.
|
||
|
relevant_nodes: Nodes from the Layers that should be considered part
|
||
|
of this Network. If `None`, all Nodes will be considered part of
|
||
|
this Network.
|
||
|
|
||
|
Raises:
|
||
|
ValueError: If the layers depend on `Input`s not found in this Model.
|
||
|
"""
|
||
|
layers = tf.nest.flatten(layers)
|
||
|
tf_utils.assert_no_legacy_layers(layers)
|
||
|
node_to_depth = {}
|
||
|
for depth, nodes in self._nodes_by_depth.items():
|
||
|
node_to_depth.update({node: depth for node in nodes})
|
||
|
# The nodes of these Layers that are relevant to this Network. If not
|
||
|
# provided, assume all Nodes are relevant
|
||
|
if not relevant_nodes:
|
||
|
relevant_nodes = tf.nest.flatten(
|
||
|
[layer._inbound_nodes for layer in layers]
|
||
|
)
|
||
|
network_nodes = set(relevant_nodes + list(node_to_depth.keys()))
|
||
|
|
||
|
def _get_min_depth(node):
|
||
|
"""Gets the minimum depth at which node can be computed."""
|
||
|
min_depth = 0
|
||
|
for layer, node_id, _, _ in node.iterate_inbound():
|
||
|
inbound_node = layer._inbound_nodes[node_id]
|
||
|
if inbound_node in node_to_depth:
|
||
|
min_depth = min(min_depth, node_to_depth[inbound_node])
|
||
|
elif inbound_node not in network_nodes:
|
||
|
continue
|
||
|
else:
|
||
|
# Previous relevant nodes haven't been processed yet.
|
||
|
return None
|
||
|
# New node is one shallower than its shallowest input.
|
||
|
return min_depth - 1
|
||
|
|
||
|
# Insert nodes into `_nodes_by_depth` and other node attrs.
|
||
|
unprocessed_nodes = copy.copy(relevant_nodes)
|
||
|
i = 0
|
||
|
while unprocessed_nodes:
|
||
|
i += 1
|
||
|
# Do a sanity check. This can occur if `Input`s from outside this
|
||
|
# Model are being relied on.
|
||
|
if i > 10000:
|
||
|
raise ValueError(
|
||
|
"Layers could not be added due to missing dependencies."
|
||
|
)
|
||
|
|
||
|
node = unprocessed_nodes.pop(0)
|
||
|
depth = _get_min_depth(node)
|
||
|
if depth is None: # Defer until inbound nodes are processed.
|
||
|
unprocessed_nodes.append(node)
|
||
|
continue
|
||
|
node_key = _make_node_key(
|
||
|
node.layer.name, node.layer._inbound_nodes.index(node)
|
||
|
)
|
||
|
if node_key not in self._network_nodes:
|
||
|
node_to_depth[node] = depth
|
||
|
self._network_nodes.add(node_key)
|
||
|
self._nodes_by_depth[depth].append(node)
|
||
|
|
||
|
# Insert layers and update other layer attrs.
|
||
|
layer_set = set(self._self_tracked_trackables)
|
||
|
deferred_layers = []
|
||
|
for layer in layers:
|
||
|
if layer not in layer_set:
|
||
|
self._self_tracked_trackables.append(layer)
|
||
|
deferred_layers.append(layer)
|
||
|
self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(
|
||
|
layer.call
|
||
|
)
|
||
|
layer_set.add(layer)
|
||
|
self._handle_deferred_layer_dependencies(deferred_layers)
|
||
|
|
||
|
self._compute_tensor_usage_count()
|
||
|
|
||
|
def _compute_tensor_usage_count(self):
|
||
|
"""Compute the #. of tensor usages for all the output tensors of layers.
|
||
|
|
||
|
The computed tensor usage count is saved as `self._tensor_usage_count`.
|
||
|
This is later used for saving memory in eager computation by releasing
|
||
|
no-longer-needed tensors as early as possible.
|
||
|
"""
|
||
|
tensor_usage_count = collections.Counter()
|
||
|
available_tensors = set(str(id(tensor)) for tensor in self.inputs)
|
||
|
|
||
|
depth_keys = list(self._nodes_by_depth.keys())
|
||
|
depth_keys.sort(reverse=True)
|
||
|
depth_keys = depth_keys[1:]
|
||
|
|
||
|
for depth in depth_keys:
|
||
|
for node in self._nodes_by_depth[depth]:
|
||
|
input_tensors = {
|
||
|
str(id(tensor))
|
||
|
for tensor in tf.nest.flatten(node.keras_inputs)
|
||
|
}
|
||
|
if input_tensors.issubset(available_tensors):
|
||
|
for tensor in tf.nest.flatten(node.keras_inputs):
|
||
|
tensor_usage_count[str(id(tensor))] += 1
|
||
|
|
||
|
for output_tensor in tf.nest.flatten(node.outputs):
|
||
|
available_tensors.add(str(id(output_tensor)))
|
||
|
|
||
|
for tensor in self.outputs:
|
||
|
tensor_usage_count[str(id(tensor))] += 1
|
||
|
|
||
|
self._tensor_usage_count = tensor_usage_count
|
||
|
|
||
|
def _assert_weights_created(self):
|
||
|
# Override the implementation in Model.
|
||
|
# The Functional model should always have weight created already.
|
||
|
return
|
||
|
|
||
|
def _graph_network_add_loss(self, symbolic_loss):
|
||
|
new_nodes, new_layers = _map_subgraph_network(
|
||
|
self.inputs, [symbolic_loss]
|
||
|
)
|
||
|
# Losses must be keyed on inputs no matter what in order to be supported
|
||
|
# in DistributionStrategy.
|
||
|
add_loss_layer = base_layer.AddLoss(
|
||
|
unconditional=False, dtype=symbolic_loss.dtype
|
||
|
)
|
||
|
add_loss_layer(symbolic_loss)
|
||
|
new_nodes.extend(add_loss_layer.inbound_nodes)
|
||
|
new_layers.append(add_loss_layer)
|
||
|
self._insert_layers(new_layers, new_nodes)
|
||
|
|
||
|
def _graph_network_add_metric(self, value, aggregation, name):
|
||
|
new_nodes, new_layers = _map_subgraph_network(self.inputs, [value])
|
||
|
add_metric_layer = base_layer.AddMetric(
|
||
|
aggregation, name, dtype=value.dtype
|
||
|
)
|
||
|
add_metric_layer(value)
|
||
|
new_nodes.extend(add_metric_layer.inbound_nodes)
|
||
|
new_layers.append(add_metric_layer)
|
||
|
self._insert_layers(new_layers, new_nodes)
|
||
|
|
||
|
@property
|
||
|
def _trackable_saved_model_saver(self):
|
||
|
return network_serialization.NetworkSavedModelSaver(self)
|
||
|
|
||
|
def _get_save_spec(self, dynamic_batch=True, inputs_only=True):
|
||
|
if getattr(self, "_has_explicit_input_shape", True):
|
||
|
# Functional models and Sequential models that have an explicit
|
||
|
# input shape should use the batch size set by the input layer.
|
||
|
dynamic_batch = False
|
||
|
return super()._get_save_spec(dynamic_batch, inputs_only)
|
||
|
|
||
|
|
||
|
def _make_node_key(layer_name, node_index):
|
||
|
return layer_name + "_ib-" + str(node_index)
|
||
|
|
||
|
|
||
|
def _map_graph_network(inputs, outputs):
|
||
|
"""Validates a network's topology and gather its layers and nodes.
|
||
|
|
||
|
Args:
|
||
|
inputs: List of input tensors.
|
||
|
outputs: List of outputs tensors.
|
||
|
|
||
|
Returns:
|
||
|
A tuple `(nodes, nodes_by_depth, layers, layers_by_depth)`.
|
||
|
- nodes: list of Node instances.
|
||
|
- nodes_by_depth: dict mapping ints (depth) to lists of node instances.
|
||
|
- layers: list of Layer instances.
|
||
|
- layers_by_depth: dict mapping ints (depth) to lists of layer instances.
|
||
|
|
||
|
Raises:
|
||
|
ValueError: In case the network is not valid (e.g. disconnected graph).
|
||
|
"""
|
||
|
# "depth" is number of layers between output Node and the Node.
|
||
|
# Nodes are ordered from inputs -> outputs.
|
||
|
nodes_in_decreasing_depth, layer_indices = _build_map(outputs)
|
||
|
network_nodes = {
|
||
|
_make_node_key(node.layer.name, node.layer._inbound_nodes.index(node))
|
||
|
for node in nodes_in_decreasing_depth
|
||
|
}
|
||
|
|
||
|
nodes_depths = {} # dict {node: depth value}
|
||
|
layers_depths = {} # dict {layer: depth value}
|
||
|
|
||
|
for node in reversed(nodes_in_decreasing_depth):
|
||
|
# If the depth is not set, the node has no outbound nodes (depth 0).
|
||
|
depth = nodes_depths.setdefault(node, 0)
|
||
|
|
||
|
# Update the depth of the corresponding layer
|
||
|
previous_depth = layers_depths.get(node.layer, 0)
|
||
|
# If we've seen this layer before at a higher depth,
|
||
|
# we should use that depth instead of the node depth.
|
||
|
# This is necessary for shared layers that have inputs at different
|
||
|
# depth levels in the graph.
|
||
|
depth = max(depth, previous_depth)
|
||
|
layers_depths[node.layer] = depth
|
||
|
nodes_depths[node] = depth
|
||
|
|
||
|
# Update the depth of inbound nodes.
|
||
|
# The "depth" of a node is the max of the depths
|
||
|
# of all nodes it is connected to + 1.
|
||
|
for node_dep in node.parent_nodes:
|
||
|
previous_depth = nodes_depths.get(node_dep, 0)
|
||
|
nodes_depths[node_dep] = max(depth + 1, previous_depth)
|
||
|
|
||
|
# Handle inputs that are not connected to outputs.
|
||
|
# We do not error out here because the inputs may be used to compute losses
|
||
|
# and metrics.
|
||
|
for input_t in inputs:
|
||
|
input_layer = input_t._keras_history[0]
|
||
|
if input_layer not in layers_depths:
|
||
|
layers_depths[input_layer] = 0
|
||
|
layer_indices[input_layer] = -1
|
||
|
nodes_depths[input_layer._inbound_nodes[0]] = 0
|
||
|
network_nodes.add(_make_node_key(input_layer.name, 0))
|
||
|
|
||
|
# Build a dict {depth: list of nodes with this depth}
|
||
|
nodes_by_depth = collections.defaultdict(list)
|
||
|
for node, depth in nodes_depths.items():
|
||
|
nodes_by_depth[depth].append(node)
|
||
|
|
||
|
# Build a dict {depth: list of layers with this depth}
|
||
|
layers_by_depth = collections.defaultdict(list)
|
||
|
for layer, depth in layers_depths.items():
|
||
|
layers_by_depth[depth].append(layer)
|
||
|
|
||
|
# Get sorted list of layer depths.
|
||
|
depth_keys = list(layers_by_depth.keys())
|
||
|
depth_keys.sort(reverse=True)
|
||
|
|
||
|
# Set self.layers ordered by depth.
|
||
|
layers = []
|
||
|
for depth in depth_keys:
|
||
|
layers_for_depth = layers_by_depth[depth]
|
||
|
# Network.layers needs to have a deterministic order:
|
||
|
# here we order them by traversal order.
|
||
|
layers_for_depth.sort(key=lambda x: layer_indices[x])
|
||
|
layers.extend(layers_for_depth)
|
||
|
|
||
|
# Get sorted list of node depths.
|
||
|
depth_keys = list(nodes_by_depth.keys())
|
||
|
depth_keys.sort(reverse=True)
|
||
|
|
||
|
# Check that all tensors required are computable.
|
||
|
# computable_tensors: all tensors in the graph
|
||
|
# that can be computed from the inputs provided.
|
||
|
computable_tensors = set()
|
||
|
for x in inputs:
|
||
|
computable_tensors.add(id(x))
|
||
|
|
||
|
layers_with_complete_input = [] # To provide a better error msg.
|
||
|
for depth in depth_keys:
|
||
|
for node in nodes_by_depth[depth]:
|
||
|
layer = node.layer
|
||
|
if layer and not node.is_input:
|
||
|
for x in tf.nest.flatten(node.keras_inputs):
|
||
|
if id(x) not in computable_tensors:
|
||
|
raise ValueError(
|
||
|
"Graph disconnected: cannot obtain value for "
|
||
|
f'tensor {x} at layer "{layer.name}". '
|
||
|
"The following previous layers were accessed "
|
||
|
f"without issue: {layers_with_complete_input}"
|
||
|
)
|
||
|
for x in tf.nest.flatten(node.outputs):
|
||
|
computable_tensors.add(id(x))
|
||
|
layers_with_complete_input.append(layer.name)
|
||
|
|
||
|
# Ensure name unicity, which will be crucial for serialization
|
||
|
# (since serialized nodes refer to layers by their name).
|
||
|
all_names = [layer.name for layer in layers]
|
||
|
for name in all_names:
|
||
|
if all_names.count(name) != 1:
|
||
|
raise ValueError(
|
||
|
f'The name "{name}" is used {all_names.count(name)} '
|
||
|
"times in the model. All layer names should be unique."
|
||
|
)
|
||
|
return network_nodes, nodes_by_depth, layers, layers_by_depth
|
||
|
|
||
|
|
||
|
def _build_map(outputs):
|
||
|
"""This method topologically sorts nodes in order from inputs to outputs.
|
||
|
|
||
|
It uses a depth-first search to topologically sort nodes that appear in the
|
||
|
_keras_history connectivity metadata of `outputs`.
|
||
|
|
||
|
Args:
|
||
|
outputs: the output tensors whose _keras_history metadata should be
|
||
|
walked. This may be an arbitrary nested structure.
|
||
|
|
||
|
Returns:
|
||
|
A tuple like (ordered_nodes, layer_to_first_traversal_index)
|
||
|
ordered_nodes: list of nodes appearing in the keras history, topologically
|
||
|
sorted from original inputs to the `outputs`.
|
||
|
(If outputs have different sets of ancestors, the inputs to one output
|
||
|
may appear after a different output).
|
||
|
layer_to_first_traversal_index:
|
||
|
A dict mapping layer to the traversal index in the DFS where it is
|
||
|
seen. Note: if a layer is shared by several nodes, the dict will only
|
||
|
store the index corresponding to the *first* time the layer seen.
|
||
|
"""
|
||
|
finished_nodes = set()
|
||
|
nodes_in_progress = set()
|
||
|
nodes_in_decreasing_depth = [] # nodes from inputs -> outputs.
|
||
|
layer_indices = {} # layer -> in traversal order.
|
||
|
for output in tf.nest.flatten(outputs):
|
||
|
_build_map_helper(
|
||
|
output,
|
||
|
finished_nodes,
|
||
|
nodes_in_progress,
|
||
|
nodes_in_decreasing_depth,
|
||
|
layer_indices,
|
||
|
)
|
||
|
return nodes_in_decreasing_depth, layer_indices
|
||
|
|
||
|
|
||
|
def _build_map_helper(
|
||
|
tensor,
|
||
|
finished_nodes,
|
||
|
nodes_in_progress,
|
||
|
nodes_in_decreasing_depth,
|
||
|
layer_indices,
|
||
|
):
|
||
|
"""Recursive helper for `_build_map`."""
|
||
|
(
|
||
|
layer,
|
||
|
node_index,
|
||
|
_,
|
||
|
) = tensor._keras_history
|
||
|
node = layer._inbound_nodes[node_index]
|
||
|
|
||
|
# Don't repeat work for shared subgraphs
|
||
|
if node in finished_nodes:
|
||
|
return
|
||
|
|
||
|
# Prevent cycles.
|
||
|
if node in nodes_in_progress:
|
||
|
raise ValueError(
|
||
|
f'Tensor {tensor} from layer "{layer.name}" is part of a cycle.'
|
||
|
)
|
||
|
|
||
|
# Store the traversal order for layer sorting.
|
||
|
if layer not in layer_indices:
|
||
|
layer_indices[layer] = len(layer_indices)
|
||
|
|
||
|
# Propagate to all previous tensors connected to this node.
|
||
|
nodes_in_progress.add(node)
|
||
|
if not node.is_input:
|
||
|
for tensor in node.keras_inputs:
|
||
|
_build_map_helper(
|
||
|
tensor,
|
||
|
finished_nodes,
|
||
|
nodes_in_progress,
|
||
|
nodes_in_decreasing_depth,
|
||
|
layer_indices,
|
||
|
)
|
||
|
|
||
|
finished_nodes.add(node)
|
||
|
nodes_in_progress.remove(node)
|
||
|
nodes_in_decreasing_depth.append(node)
|
||
|
|
||
|
|
||
|
def _map_subgraph_network(inputs, outputs):
|
||
|
"""Returns the nodes and layers in the topology from `inputs` to `outputs`.
|
||
|
|
||
|
Args:
|
||
|
inputs: List of input tensors.
|
||
|
outputs: List of output tensors.
|
||
|
|
||
|
Returns:
|
||
|
A tuple of List{Node] and List[Layer].
|
||
|
"""
|
||
|
if not tf.compat.v1.executing_eagerly_outside_functions():
|
||
|
base_layer_utils.create_keras_history(outputs)
|
||
|
# Keep only nodes and layers in the topology between inputs and outputs.
|
||
|
_, nodes_by_depth, layers, _ = _map_graph_network(inputs, outputs)
|
||
|
return tf.nest.flatten([nodes for nodes in nodes_by_depth.values()]), layers
|
||
|
|
||
|
|
||
|
def _should_skip_first_node(layer):
|
||
|
"""Returns True if the first layer node should not be saved or loaded."""
|
||
|
# Networks that are constructed with an Input layer/shape start with a
|
||
|
# pre-existing node linking their input to output. This node is excluded
|
||
|
# from the network config.
|
||
|
if layer._self_tracked_trackables:
|
||
|
return (
|
||
|
isinstance(layer, Functional)
|
||
|
# Filter out Sequential models without an input shape.
|
||
|
and isinstance(
|
||
|
layer._self_tracked_trackables[0], input_layer_module.InputLayer
|
||
|
)
|
||
|
)
|
||
|
else:
|
||
|
return isinstance(layer, Functional)
|
||
|
|
||
|
|
||
|
def connect_ancillary_layers(model, created_layers):
|
||
|
"""Adds layers that are not connected to the outputs to the model."""
|
||
|
# Layers not connected to outputs, such as those added in `add_loss`.
|
||
|
ancillary_layers = [
|
||
|
layer for layer in created_layers.values() if layer not in model.layers
|
||
|
]
|
||
|
if ancillary_layers:
|
||
|
relevant_nodes = tf.nest.flatten(
|
||
|
[
|
||
|
layer.inbound_nodes[1:]
|
||
|
if _should_skip_first_node(layer)
|
||
|
else layer.inbound_nodes
|
||
|
for layer in created_layers.values()
|
||
|
]
|
||
|
)
|
||
|
model._insert_layers(ancillary_layers, relevant_nodes)
|
||
|
return model
|
||
|
|
||
|
|
||
|
def reconstruct_from_config(config, custom_objects=None, created_layers=None):
|
||
|
"""Reconstructs graph from config object.
|
||
|
|
||
|
Args:
|
||
|
config: Dictionary returned from Network.get_config()
|
||
|
custom_objects: Optional dictionary mapping names (strings) to custom
|
||
|
classes or functions to be considered during deserialization.
|
||
|
created_layers: Optional dictionary mapping names to Layer objects. Any
|
||
|
layer not in this dictionary will be created and added to the dict.
|
||
|
This function will add new nodes to all layers (excluding InputLayers),
|
||
|
instead of re-using pre-existing nodes in the layers.
|
||
|
|
||
|
Returns:
|
||
|
Tuple of (input tensors, output tensors, dictionary of created layers)
|
||
|
"""
|
||
|
# Layer instances created during the graph reconstruction process.
|
||
|
created_layers = created_layers or collections.OrderedDict()
|
||
|
|
||
|
# Maps input data (tuple of inbound layer name, node index) from the config
|
||
|
# to node indices in the newly generated model. The node indices may be
|
||
|
# different if the layers have already been called previously.
|
||
|
node_index_map = {}
|
||
|
node_count_by_layer = {}
|
||
|
|
||
|
# Dictionary mapping layer instances to
|
||
|
# node data that specifies a layer call.
|
||
|
# It acts as a queue that maintains any unprocessed
|
||
|
# layer call until it becomes possible to process it
|
||
|
# (i.e. until the input tensors to the call all exist).
|
||
|
unprocessed_nodes = collections.defaultdict(list)
|
||
|
|
||
|
def get_node_index(layer, config_node_index):
|
||
|
"""Returns node index in layer (might differ from config_node_index)."""
|
||
|
if isinstance(layer, input_layer_module.InputLayer):
|
||
|
return 0
|
||
|
return node_index_map.get((layer.name, config_node_index), None)
|
||
|
|
||
|
def _deserialize_keras_tensors(kwargs, layer_map):
|
||
|
"""Deserializes Keras Tensors passed to `call`.."""
|
||
|
|
||
|
def _deserialize_keras_tensor(t):
|
||
|
"""Deserializes a single Keras Tensor passed to `call`."""
|
||
|
if isinstance(t, tf_utils.ListWrapper):
|
||
|
t = t.as_list()
|
||
|
layer_name = t[0]
|
||
|
node_index = t[1]
|
||
|
tensor_index = t[2]
|
||
|
|
||
|
layer = layer_map[layer_name]
|
||
|
new_node_index = get_node_index(layer, node_index)
|
||
|
if new_node_index is None:
|
||
|
# The inbound node may not have been processed yet,
|
||
|
# (This can happen e.g. if it depends on a different set
|
||
|
# of inputs than those that have been processed already).
|
||
|
# raise an IndexError so that the current node puts itself
|
||
|
# back on the unprocessed queue.
|
||
|
# Caution: This may lead to infinite loops for malformed
|
||
|
# network configurations! (or when there is a bug in
|
||
|
# the network config loading code).
|
||
|
raise IndexError
|
||
|
node = layer._inbound_nodes[new_node_index]
|
||
|
return tf.nest.flatten(node.outputs)[tensor_index]
|
||
|
return t
|
||
|
|
||
|
kwargs = tf_utils.convert_inner_node_data(kwargs, wrap=True)
|
||
|
return tf.nest.map_structure(_deserialize_keras_tensor, kwargs)
|
||
|
|
||
|
def process_node(layer, node_data):
|
||
|
"""Deserialize a node.
|
||
|
|
||
|
Args:
|
||
|
layer: layer instance.
|
||
|
node_data: Nested structure of `ListWrapper`.
|
||
|
|
||
|
Returns:
|
||
|
Whether the node was processed (i.e. the layer was called on the
|
||
|
inputs specified by the node data)
|
||
|
|
||
|
Raises:
|
||
|
ValueError: In case of improperly formatted `node_data`.
|
||
|
"""
|
||
|
input_tensors = []
|
||
|
for input_data in tf.nest.flatten(node_data):
|
||
|
input_data = input_data.as_list()
|
||
|
if len(input_data) == 3:
|
||
|
kwargs = {}
|
||
|
elif len(input_data) == 4:
|
||
|
kwargs = input_data[3]
|
||
|
try:
|
||
|
kwargs = _deserialize_keras_tensors(kwargs, created_layers)
|
||
|
except IndexError:
|
||
|
# Happens if keras tensors in kwargs are still unprocessed
|
||
|
return False
|
||
|
else:
|
||
|
raise ValueError("Improperly formatted model config.")
|
||
|
|
||
|
if input_data[0] != node_module._CONSTANT_VALUE:
|
||
|
inbound_layer_name = input_data[0]
|
||
|
inbound_node_index = input_data[1]
|
||
|
inbound_tensor_index = input_data[2]
|
||
|
inbound_layer = created_layers[inbound_layer_name]
|
||
|
inbound_node_index = get_node_index(
|
||
|
inbound_layer, inbound_node_index
|
||
|
)
|
||
|
|
||
|
if inbound_node_index is None:
|
||
|
return False
|
||
|
inbound_node = inbound_layer._inbound_nodes[inbound_node_index]
|
||
|
input_tensors.append(
|
||
|
tf.nest.flatten(inbound_node.outputs)[inbound_tensor_index]
|
||
|
)
|
||
|
else:
|
||
|
# We received a constant w/ no Keras history attached,
|
||
|
# which means it is a constant tensor input.
|
||
|
# Input is a constant value.
|
||
|
# Format = [_CONSTANT_VALUE, -1, const_val, kwargs]
|
||
|
assert input_data[1] == -1
|
||
|
assert len(input_data) >= 3
|
||
|
const_val = input_data[2]
|
||
|
if (
|
||
|
isinstance(const_val, tuple)
|
||
|
and len(const_val) == 2
|
||
|
and const_val[0] == node_module._COMPOSITE_TYPE
|
||
|
):
|
||
|
# It is a composite tensor.
|
||
|
input_tensors.append(json_utils.decode(const_val[1]))
|
||
|
else:
|
||
|
input_tensors.append(const_val)
|
||
|
input_tensors = tf.nest.pack_sequence_as(node_data, input_tensors)
|
||
|
# Call layer on its inputs, thus creating the node
|
||
|
# and building the layer if needed.
|
||
|
if input_tensors is not None:
|
||
|
if not layer._preserve_input_structure_in_config:
|
||
|
input_tensors = base_layer_utils.unnest_if_single_tensor(
|
||
|
input_tensors
|
||
|
)
|
||
|
output_tensors = layer(input_tensors, **kwargs)
|
||
|
|
||
|
# Update node index map.
|
||
|
output_index = tf.nest.flatten(output_tensors)[
|
||
|
0
|
||
|
]._keras_history.node_index
|
||
|
node_index_map[
|
||
|
(layer.name, node_count_by_layer[layer])
|
||
|
] = output_index
|
||
|
node_count_by_layer[layer] += 1
|
||
|
return True
|
||
|
|
||
|
def process_layer(layer_data):
|
||
|
"""Deserializes a layer, then call it on appropriate inputs.
|
||
|
|
||
|
Args:
|
||
|
layer_data: layer config dict.
|
||
|
|
||
|
Raises:
|
||
|
ValueError: In case of improperly formatted `layer_data` dict.
|
||
|
"""
|
||
|
layer_name = layer_data["name"]
|
||
|
|
||
|
if layer_name in created_layers:
|
||
|
layer = created_layers[layer_name]
|
||
|
else:
|
||
|
# Instantiate layer.
|
||
|
from keras.layers import deserialize as deserialize_layer
|
||
|
|
||
|
layer = deserialize_layer(layer_data, custom_objects=custom_objects)
|
||
|
created_layers[layer_name] = layer
|
||
|
|
||
|
node_count_by_layer[layer] = int(_should_skip_first_node(layer))
|
||
|
|
||
|
# Gather layer inputs and convert to `ListWrapper` objects.
|
||
|
inbound_nodes_data = layer_data["inbound_nodes"]
|
||
|
inbound_nodes_data = tf_utils.convert_inner_node_data(
|
||
|
inbound_nodes_data, wrap=True
|
||
|
)
|
||
|
for node_data in inbound_nodes_data:
|
||
|
# We don't process nodes (i.e. make layer calls)
|
||
|
# on the fly because the inbound node may not yet exist,
|
||
|
# in case of layer shared at different topological depths
|
||
|
# (e.g. a model such as A(B(A(B(x)))))
|
||
|
unprocessed_nodes[layer].append(node_data)
|
||
|
|
||
|
# First, we create all layers and enqueue nodes to be processed
|
||
|
for layer_data in config["layers"]:
|
||
|
process_layer(layer_data)
|
||
|
# Then we process nodes in order of layer depth.
|
||
|
# Nodes that cannot yet be processed (if the inbound node
|
||
|
# does not yet exist) are re-enqueued, and the process
|
||
|
# is repeated until all nodes are processed.
|
||
|
while unprocessed_nodes:
|
||
|
for layer_data in config["layers"]:
|
||
|
layer = created_layers[layer_data["name"]]
|
||
|
if layer in unprocessed_nodes:
|
||
|
layer_nodes = unprocessed_nodes.pop(layer)
|
||
|
while layer_nodes:
|
||
|
node_data = layer_nodes[0]
|
||
|
if process_node(layer, node_data):
|
||
|
layer_nodes.pop(0)
|
||
|
else:
|
||
|
# If a node can't be processed, stop processing the
|
||
|
# nodes of the current layer to maintain node ordering.
|
||
|
unprocessed_nodes[layer] = layer_nodes
|
||
|
break
|
||
|
|
||
|
input_tensors = []
|
||
|
output_tensors = []
|
||
|
|
||
|
input_layers = tf_utils.convert_inner_node_data(
|
||
|
config["input_layers"], wrap=True
|
||
|
)
|
||
|
for layer_data in tf.nest.flatten(input_layers):
|
||
|
layer_name, node_index, tensor_index = layer_data.as_list()
|
||
|
assert layer_name in created_layers
|
||
|
layer = created_layers[layer_name]
|
||
|
node_index = get_node_index(layer, node_index)
|
||
|
layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
|
||
|
input_tensors.append(
|
||
|
tf.nest.flatten(layer_output_tensors)[tensor_index]
|
||
|
)
|
||
|
|
||
|
output_layers = tf_utils.convert_inner_node_data(
|
||
|
config["output_layers"], wrap=True
|
||
|
)
|
||
|
for layer_data in tf.nest.flatten(output_layers):
|
||
|
layer_name, node_index, tensor_index = layer_data.as_list()
|
||
|
assert layer_name in created_layers
|
||
|
layer = created_layers[layer_name]
|
||
|
node_index = get_node_index(layer, node_index)
|
||
|
layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
|
||
|
output_tensors.append(
|
||
|
tf.nest.flatten(layer_output_tensors)[tensor_index]
|
||
|
)
|
||
|
|
||
|
input_tensors = tf.nest.pack_sequence_as(input_layers, input_tensors)
|
||
|
output_tensors = tf.nest.pack_sequence_as(output_layers, output_tensors)
|
||
|
return input_tensors, output_tensors, created_layers
|
||
|
|
||
|
|
||
|
def get_network_config(network, serialize_layer_fn=None, config=None):
|
||
|
"""Build the config, which consists of the node graph and serialized layers.
|
||
|
|
||
|
Args:
|
||
|
network: A Network object.
|
||
|
serialize_layer_fn: Function used to serialize layers.
|
||
|
config: A dict to append more config entries into. If None, start with a
|
||
|
new dict for the config.
|
||
|
|
||
|
Returns:
|
||
|
Config dictionary.
|
||
|
"""
|
||
|
serialize_layer_fn = (
|
||
|
serialize_layer_fn or serialization.serialize_keras_object
|
||
|
)
|
||
|
config = config or {}
|
||
|
config["name"] = network.name
|
||
|
node_conversion_map = {}
|
||
|
for layer in network.layers:
|
||
|
kept_nodes = 1 if _should_skip_first_node(layer) else 0
|
||
|
for original_node_index, node in enumerate(layer._inbound_nodes):
|
||
|
node_key = _make_node_key(layer.name, original_node_index)
|
||
|
if node_key in network._network_nodes:
|
||
|
node_conversion_map[node_key] = kept_nodes
|
||
|
kept_nodes += 1
|
||
|
layer_configs = []
|
||
|
|
||
|
with serialization.SharedObjectSavingScope():
|
||
|
for layer in network.layers: # From the earliest layers on.
|
||
|
filtered_inbound_nodes = []
|
||
|
for original_node_index, node in enumerate(layer._inbound_nodes):
|
||
|
node_key = _make_node_key(layer.name, original_node_index)
|
||
|
if node_key in network._network_nodes and not node.is_input:
|
||
|
# The node is relevant to the model:
|
||
|
# add to filtered_inbound_nodes.
|
||
|
node_data = node.serialize(
|
||
|
_make_node_key, node_conversion_map
|
||
|
)
|
||
|
filtered_inbound_nodes.append(node_data)
|
||
|
|
||
|
layer_config = serialize_layer_fn(layer)
|
||
|
layer_config["name"] = layer.name
|
||
|
layer_config["inbound_nodes"] = filtered_inbound_nodes
|
||
|
layer_configs.append(layer_config)
|
||
|
config["layers"] = layer_configs
|
||
|
|
||
|
# Gather info about inputs and outputs.
|
||
|
model_inputs = []
|
||
|
for i in range(len(network._input_layers)):
|
||
|
layer, node_index, tensor_index = network._input_coordinates[i]
|
||
|
node_key = _make_node_key(layer.name, node_index)
|
||
|
if node_key not in network._network_nodes:
|
||
|
continue
|
||
|
new_node_index = node_conversion_map[node_key]
|
||
|
model_inputs.append(
|
||
|
tf_utils.ListWrapper([layer.name, new_node_index, tensor_index])
|
||
|
)
|
||
|
model_inputs = tf.nest.pack_sequence_as(
|
||
|
network._nested_inputs, model_inputs
|
||
|
)
|
||
|
# Preserve external Keras compat for Models with single input.
|
||
|
if not tf.nest.is_nested(model_inputs):
|
||
|
model_inputs = [model_inputs]
|
||
|
model_inputs = tf_utils.convert_inner_node_data(model_inputs)
|
||
|
config["input_layers"] = model_inputs
|
||
|
|
||
|
model_outputs = []
|
||
|
for i in range(len(network._output_layers)):
|
||
|
layer, node_index, tensor_index = network._output_coordinates[i]
|
||
|
node_key = _make_node_key(layer.name, node_index)
|
||
|
if node_key not in network._network_nodes:
|
||
|
continue
|
||
|
new_node_index = node_conversion_map[node_key]
|
||
|
model_outputs.append(
|
||
|
tf_utils.ListWrapper([layer.name, new_node_index, tensor_index])
|
||
|
)
|
||
|
model_outputs = tf.nest.pack_sequence_as(
|
||
|
network._nested_outputs, model_outputs
|
||
|
)
|
||
|
# Preserve external Keras compat for Models with single output.
|
||
|
if not tf.nest.is_nested(model_outputs):
|
||
|
model_outputs = [model_outputs]
|
||
|
model_outputs = tf_utils.convert_inner_node_data(model_outputs)
|
||
|
config["output_layers"] = model_outputs
|
||
|
return config
|
||
|
|
||
|
|
||
|
def shape_with_no_batch_size(x):
|
||
|
if x.shape.rank is None:
|
||
|
return None
|
||
|
shape = x.shape.as_list()
|
||
|
if shape:
|
||
|
shape[0] = None
|
||
|
return shape
|
||
|
|
||
|
|
||
|
class ModuleWrapper(base_layer.Layer):
|
||
|
"""Wrapper for `tf.Module`s to support the Functional and Sequential API."""
|
||
|
|
||
|
def __init__(self, module, method_name=None, **kwargs):
|
||
|
"""Initializes the wrapper Layer for this module.
|
||
|
|
||
|
Args:
|
||
|
module: The `tf.Module` instance to be wrapped.
|
||
|
method_name: (Optional) str. The name of the method to use as the
|
||
|
forward pass of the module. If not set, defaults to '__call__' if
|
||
|
defined, or 'call'.
|
||
|
**kwargs: Additional keywrod arguments. See `tf.keras.layers.Layer`.
|
||
|
|
||
|
Raises:
|
||
|
ValueError: If `method` is not defined on `module`.
|
||
|
"""
|
||
|
super().__init__(**kwargs)
|
||
|
if method_name is None:
|
||
|
if hasattr(module, "__call__"):
|
||
|
method_name = "__call__"
|
||
|
elif hasattr(module, "call"):
|
||
|
method_name = "call"
|
||
|
if method_name is None or not hasattr(module, method_name):
|
||
|
raise ValueError(f"{method_name} is not defined on object {module}")
|
||
|
|
||
|
self._module = module
|
||
|
self._method_name = method_name
|
||
|
|
||
|
# Check if module.__call__ has a `training` arg or accepts `**kwargs`.
|
||
|
method = getattr(module, method_name)
|
||
|
method_arg_spec = tf_inspect.getfullargspec(method)
|
||
|
self._call_spec.expects_training_arg = (
|
||
|
"training" in method_arg_spec.args
|
||
|
or method_arg_spec.varkw is not None
|
||
|
)
|
||
|
self._call_spec.expects_mask_arg = (
|
||
|
"mask" in method_arg_spec.args or method_arg_spec.varkw is not None
|
||
|
)
|
||
|
|
||
|
def call(self, *args, **kwargs):
|
||
|
if "training" in kwargs and not self._expects_training_arg:
|
||
|
kwargs.pop("training")
|
||
|
if "mask" in kwargs and not self._expects_mask_arg:
|
||
|
kwargs.pop("mask")
|
||
|
return getattr(self._module, self._method_name)(*args, **kwargs)
|
||
|
|
||
|
|
||
|
def has_functional_like_constructor(cls):
|
||
|
init_args = tf_inspect.getfullargspec(cls.__init__).args[1:]
|
||
|
functional_init_args = tf_inspect.getfullargspec(Functional.__init__).args[
|
||
|
1:
|
||
|
]
|
||
|
if init_args == functional_init_args:
|
||
|
return True
|
||
|
return False
|