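# File: Intelegentny_Pszczelarz/.venv/Lib/site-packages/keras/engine/functional.py
# (1688 lines, 68 KiB, Python; copied from a repository web view whose
# "Raw / Normal View / History" links were part of the page.)
# Snapshot date: 2023-06-19 00:49:18 +02:00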
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A `Network` is way to compose layers: the topological form of a `Model`."""
import collections
import copy
import itertools
import warnings
import tensorflow.compat.v2 as tf
from keras import backend
from keras.dtensor import layout_map as layout_map_lib
from keras.engine import base_layer
from keras.engine import base_layer_utils
from keras.engine import functional_utils
from keras.engine import input_layer as input_layer_module
from keras.engine import input_spec
from keras.engine import node as node_module
from keras.engine import training as training_lib
from keras.engine import training_utils
from keras.saving.legacy import serialization
from keras.saving.legacy.saved_model import json_utils
from keras.saving.legacy.saved_model import network_serialization
from keras.saving.legacy.saved_model import utils as saved_model_utils
from keras.utils import generic_utils
from keras.utils import tf_inspect
from keras.utils import tf_utils
# isort: off
from tensorflow.python.platform import tf_logging as logging
from tensorflow.tools.docs import doc_controls
class Functional(training_lib.Model):
"""A `Functional` model is a `Model` defined as a directed graph of layers.
Three types of `Model` exist: subclassed `Model`, `Functional` model,
and `Sequential` (a special case of `Functional`).
In general, more Keras features are supported with `Functional`
than with subclassed `Model`s, specifically:
- Model cloning (`keras.models.clone`)
- Serialization (`model.get_config()/from_config`, `model.to_json()`
- Whole-model saving (`model.save()`)
A `Functional` model can be instantiated by passing two arguments to
`__init__`. The first argument is the `keras.Input` Tensors that represent
the inputs to the model. The second argument specifies the output
tensors that represent the outputs of this model. Both arguments can be a
nested structure of tensors.
Example:
```
inputs = {'x1': keras.Input(shape=(10,)), 'x2': keras.Input(shape=(1,))}
t = keras.layers.Dense(1, activation='relu')(inputs['x1'])
outputs = keras.layers.Add()([t, inputs['x2'])
model = keras.Model(inputs, outputs)
```
A `Functional` model constructed using the Functional API can also include
raw TensorFlow functions, with the exception of functions that create
Variables or assign ops.
Example:
```python
inputs = keras.Input(shape=(10,))
x = keras.layers.Dense(1)(inputs)
outputs = tf.nn.relu(x)
model = keras.Model(inputs, outputs)
```
A new `Functional` model can also be created by using the
intermediate tensors. This enables you to quickly extract sub-components
of the model.
Example:
```python
inputs = keras.Input(shape=(None, None, 3))
processed = keras.layers.RandomCrop(width=32, height=32)(inputs)
conv = keras.layers.Conv2D(filters=2, kernel_size=3)(processed)
pooling = keras.layers.GlobalAveragePooling2D()(conv)
feature = keras.layers.Dense(10)(pooling)
full_model = keras.Model(inputs, feature)
backbone = keras.Model(processed, conv)
activations = keras.Model(conv, feature)
```
Note that the `backbone` and `activations` models are not
created with `keras.Input` objects, but with the tensors that are originated
from `keras.Input` objects. Under the hood, the layers and weights will
be shared across these models, so that user can train the `full_model`, and
use `backbone` or `activations` to do feature extraction.
The inputs and outputs of the model can be nested structures of tensors as
well, and the created models are standard `Functional` model that support
all the existing API.
Args:
inputs: List of input tensors (must be created via `tf.keras.Input()` or
originated from `tf.keras.Input()`).
outputs: List of output tensors.
name: String, optional. Name of the model.
trainable: Boolean, optional. If the model's variables should be
trainable.
"""
# See tf.Module for the usage of this property.
# The key of _layer_call_argspecs is a layer. tf.Module._flatten will fail
# to flatten the key since it is trying to convert Trackable/Layer to a
# string.
_TF_MODULE_IGNORED_PROPERTIES = frozenset(
itertools.chain(
(
"_layer_call_argspecs",
"_compiled_trainable_state",
"_output_mask_cache",
"_output_tensor_cache",
"_output_shape_cache",
),
training_lib.Model._TF_MODULE_IGNORED_PROPERTIES,
)
)
@tf.__internal__.tracking.no_automatic_dependency_tracking
def __init__(self, inputs, outputs, name=None, trainable=True, **kwargs):
# This is used by the Model class, since we have some logic to swap the
# class in the __new__ method, which will lead to __init__ get invoked
# twice. Using the skip_init to skip one of the invocation of __init__
# to avoid any side effects
skip_init = kwargs.pop("skip_init", False)
if skip_init:
return
generic_utils.validate_kwargs(kwargs, {})
super().__init__(name=name, trainable=trainable)
# Check if the inputs contain any intermediate `KerasTensor` (not
# created by tf.keras.Input()). In this case we need to clone the `Node`
# and `KerasTensor` objects to mimic rebuilding a new model from new
# inputs. This feature is only enabled in TF2 not in v1 graph mode.
if tf.compat.v1.executing_eagerly_outside_functions():
if not all(
[
functional_utils.is_input_keras_tensor(t)
for t in tf.nest.flatten(inputs)
]
):
inputs, outputs = functional_utils.clone_graph_nodes(
inputs, outputs
)
self._init_graph_network(inputs, outputs)
@tf.__internal__.tracking.no_automatic_dependency_tracking
def _init_graph_network(self, inputs, outputs):
# This method is needed for Sequential to reinitialize graph network
# when layer is added or removed.
base_layer.keras_api_gauge.get_cell("Functional").set(True)
self._is_graph_network = True
# Normalize and set self.inputs, self.outputs.
if isinstance(inputs, list) and len(tf.nest.flatten(inputs)) == 1:
inputs = inputs[0]
if isinstance(outputs, list) and len(tf.nest.flatten(outputs)) == 1:
outputs = outputs[0]
self._nested_inputs = inputs
self._nested_outputs = outputs
self.inputs = tf.nest.flatten(inputs)
self.outputs = tf.nest.flatten(outputs)
# Models constructed with a single Tensor or list of Tensors can
# be called with a dict, where the keys of the dict are the names
# of the `Input` objects. Extra keys are ignored with warning.
if not tf.nest.is_nested(self._nested_inputs):
self._enable_dict_to_input_mapping = True
elif isinstance(self._nested_inputs, (list, tuple)) and not any(
tf.nest.is_nested(t) for t in self._nested_inputs
):
self._enable_dict_to_input_mapping = True
elif isinstance(self._nested_inputs, dict) and not any(
tf.nest.is_nested(t) for t in self._nested_inputs.values()
):
self._enable_dict_to_input_mapping = True
else:
self._enable_dict_to_input_mapping = False
if not tf.compat.v1.executing_eagerly_outside_functions():
if any(
not hasattr(tensor, "_keras_history") for tensor in self.outputs
):
base_layer_utils.create_keras_history(self._nested_outputs)
self._validate_graph_inputs_and_outputs()
# A Network does not create weights of its own, thus it is already
# built.
self.built = True
self._build_input_shape = tf.nest.map_structure(
lambda x: x.shape, inputs
)
self._compute_output_and_mask_jointly = True
# `_expects_training_arg` is True since the `training` argument is
# always present in the signature of the `call` method of a graph
# network.
self._call_spec.expects_training_arg = True
self._call_spec.expects_mask_arg = True
# A graph network does not autocast inputs, as its layers will cast them
# instead.
self._autocast = False
self._input_layers = []
self._output_layers = []
self._input_coordinates = []
self._output_coordinates = []
# This is for performance optimization when calling the Network on new
# inputs. Every time the Network is called on a set on input tensors, we
# compute the output tensors, output masks and output shapes in one
# pass, then cache them here. When any of these outputs is queried
# later, we retrieve it from there instead of recomputing it.
self._output_mask_cache = {}
self._output_tensor_cache = {}
self._output_shape_cache = {}
# Build self._output_layers:
for x in self.outputs:
(
layer,
node_index,
tensor_index,
) = x._keras_history
self._output_layers.append(layer)
self._output_coordinates.append((layer, node_index, tensor_index))
# Build self._input_layers:
for x in self.inputs:
(
layer,
node_index,
tensor_index,
) = x._keras_history
# It's supposed to be an input layer, so only one node
# and one tensor output.
assert node_index == 0
assert tensor_index == 0
self._input_layers.append(layer)
self._input_coordinates.append((layer, node_index, tensor_index))
# Keep track of the network's nodes and layers.
nodes, nodes_by_depth, layers, _ = _map_graph_network(
self.inputs, self.outputs
)
self._network_nodes = nodes
self._nodes_by_depth = nodes_by_depth
self._self_tracked_trackables = layers
self._layer_call_argspecs = {}
for layer in self._self_tracked_trackables:
self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(
layer.call
)
# Build self.input_names and self.output_names.
self._set_output_names()
self.input_names = []
self._feed_input_names = []
self._feed_inputs = []
self._feed_input_shapes = []
for layer in self._input_layers:
self.input_names.append(layer.name)
if layer.is_placeholder:
self._feed_input_names.append(layer.name)
# Use batch_input_shape here because non-eager composite tensors
# may not have a shape attribute that's meaningful (sparse, for
# instance, has a tensor that's non-constant and needs to be
# fed). This means that input layers that create placeholders
# will need to have the batch_input_shape attr to allow for
# input shape validation.
self._feed_input_shapes.append(layer._batch_input_shape)
self._feed_inputs.append(layer.input)
self._compute_tensor_usage_count()
self._set_save_spec(self._nested_inputs)
tf_utils.assert_no_legacy_layers(self.layers)
# Note that this method is used by both functional and sequential
# models, so we can't just have this method in functional.__init__,
# which will miss the coverage of sequential model.
if self._layout_map is not None:
layout_map_lib._map_functional_model_variable(
self, self._layout_map
)
@property
def input(self):
"""Retrieves the input tensor(s) of a layer.
Only applicable if the layer has exactly one input,
i.e. if it is connected to one incoming layer.
Returns:
Input tensor or list of input tensors.
Raises:
RuntimeError: If called in Eager mode.
AttributeError: If no inbound nodes are found.
"""
return self._nested_inputs
@property
def input_shape(self):
"""Retrieves the input shape(s) of a layer.
Only applicable if the layer has exactly one input,
i.e. if it is connected to one incoming layer, or if all inputs
have the same shape.
Returns:
Input shape, as an integer shape tuple
(or list of shape tuples, one tuple per input tensor).
Raises:
AttributeError: if the layer has no defined input_shape.
RuntimeError: if called in Eager mode.
"""
return tf.nest.map_structure(backend.int_shape, self.input)
@property
def input_spec(self):
if hasattr(self, "_manual_input_spec"):
return self._manual_input_spec
if isinstance(self._nested_inputs, (dict, list, tuple)) and len(
self._nested_inputs
) != len(self.inputs):
# Case where we have a nested structure.
# In such a case we can't safely run any checks.
return None
if isinstance(self._nested_inputs, dict):
# Case where `_nested_inputs` is a plain dict of Inputs.
names = sorted(self._nested_inputs.keys())
return [
input_spec.InputSpec(
shape=shape_with_no_batch_size(self._nested_inputs[name]),
allow_last_axis_squeeze=True,
name=name,
)
for name in names
]
else:
# Single input, or list / tuple of inputs.
# The data may be passed as a dict keyed by input name.
return [
input_spec.InputSpec(
shape=shape_with_no_batch_size(x),
allow_last_axis_squeeze=True,
name=x._keras_history.layer.name,
)
for x in self.inputs
]
@input_spec.setter
def input_spec(self, value):
self._manual_input_spec = value
@property
def output(self):
"""Retrieves the output tensor(s) of a layer.
Only applicable if the layer has exactly one output,
i.e. if it is connected to one incoming layer.
Returns:
Output tensor or list of output tensors.
Raises:
AttributeError: if the layer is connected to more than one incoming
layers.
RuntimeError: if called in Eager mode.
"""
return self._nested_outputs
@property
def output_shape(self):
"""Retrieves the output shape(s) of a layer.
Only applicable if the layer has one output,
or if all outputs have the same shape.
Returns:
Output shape, as an integer shape tuple
(or list of shape tuples, one tuple per output tensor).
Raises:
AttributeError: if the layer has no defined output shape.
RuntimeError: if called in Eager mode.
"""
return tf.nest.map_structure(backend.int_shape, self.output)
def _set_output_names(self):
"""Assigns unique names to the Network's outputs.
Output layers with multiple output tensors would otherwise lead to
duplicate names in self.output_names.
"""
uniquified = []
output_names = set()
prefix_count = {}
for layer in self._output_layers:
proposal = layer.name
while proposal in output_names:
existing_count = prefix_count.get(layer.name, 1)
proposal = f"{layer.name}_{existing_count}"
prefix_count[layer.name] = existing_count + 1
output_names.add(proposal)
uniquified.append(proposal)
self.output_names = uniquified
@property
def _layer_checkpoint_dependencies(self):
"""Dictionary of layer dependencies to be included in the checkpoint."""
weight_layer_index = 0
dependencies = collections.OrderedDict()
for layer_index, layer in enumerate(self.layers):
try:
if layer.weights:
# Keep a separate index for layers which have weights. This
# allows users to insert Layers without weights anywhere in
# the network without breaking checkpoints.
dependencies[
"layer_with_weights-%d" % weight_layer_index
] = layer
weight_layer_index += 1
except ValueError:
# The layer might have weights, but may not be built yet. We
# just treat it as layer without weight.
pass
# Even if it doesn't have weights, we should still track everything
# in case it has/will have Trackable dependencies.
dependencies["layer-%d" % layer_index] = layer
return dependencies
def _trackable_children(self, save_type="checkpoint", **kwargs):
dependencies = self._layer_checkpoint_dependencies
dependencies.update(super()._trackable_children(save_type, **kwargs))
return dependencies
def _lookup_dependency(self, name):
layer_dependencies = self._layer_checkpoint_dependencies
if name in layer_dependencies:
return layer_dependencies[name]
return super()._lookup_dependency(name)
def _handle_deferred_layer_dependencies(self, layers):
"""Handles layer checkpoint dependencies that are added after init."""
layer_checkpoint_dependencies = self._layer_checkpoint_dependencies
layer_to_name = {v: k for k, v in layer_checkpoint_dependencies.items()}
for layer in layers:
if layer in layer_to_name:
self._handle_deferred_dependencies(
name=layer_to_name[layer], trackable=layer
)
@property
def _should_compute_mask(self):
return True
def compute_mask(self, inputs, mask):
# TODO(omalleyt): b/123540974 This function is not really safe to call
# by itself because it will duplicate any updates and losses in graph
# mode by `call`ing the Layers again.
output_tensors = self._run_internal_graph(inputs, mask=mask)
return tf.nest.map_structure(
lambda t: getattr(t, "_keras_mask", None), output_tensors
)
@doc_controls.do_not_doc_inheritable
def call(self, inputs, training=None, mask=None):
"""Calls the model on new inputs.
In this case `call` just reapplies
all ops in the graph to the new inputs
(e.g. build a new computational graph from the provided inputs).
Args:
inputs: A tensor or list of tensors.
training: Boolean or boolean scalar tensor, indicating whether to
run the `Network` in training mode or inference mode.
mask: A mask or list of masks. A mask can be
either a tensor or None (no mask).
Returns:
A tensor if there is a single output, or
a list of tensors if there are more than one outputs.
"""
return self._run_internal_graph(inputs, training=training, mask=mask)
def compute_output_shape(self, input_shape):
# Convert any shapes in tuple format to TensorShapes.
input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
if len(tf.nest.flatten(input_shape)) != len(
tf.nest.flatten(self._input_layers)
):
raise ValueError(
f"Invalid `input_shape` argument {input_shape}: "
f"the model expects {len(self._input_layers)} "
"input tensors."
)
# Use the tuple of TensorShape as the cache key, since tuple is hashable
# and can be used as hash key.
try:
cache_key = tuple(
tf_utils.convert_shapes(input_shape, to_tuples=True)
)
if cache_key in self._output_shape_cache:
# Cache hit. Return shapes as TensorShapes.
return self._output_shape_cache[cache_key]
except ValueError:
# In case there are unknown TensorShape, eg for sparse tensor input,
# We skip the caching since the shape is unknown.
pass
layers_to_output_shapes = {}
for layer, shape in zip(
self._input_layers, tf.nest.flatten(input_shape)
):
# It's an input layer: then `compute_output_shape` is identity,
# and there is only one node and one tensor..
shape_key = layer.name + "_0_0"
layers_to_output_shapes[shape_key] = shape
depth_keys = list(self._nodes_by_depth.keys())
depth_keys.sort(reverse=True)
# Iterate over nodes, by depth level.
if len(depth_keys) > 1:
for depth in depth_keys:
nodes = self._nodes_by_depth[depth]
for node in nodes:
layer = node.layer
if layer in self._input_layers:
# We've already covered the input layers
# a few lines above.
continue
# Get the input shapes for the first argument of the node
layer_input_shapes = []
layer_inputs = node.call_args[0]
for layer_input in tf.nest.flatten(layer_inputs):
kh = layer_input._keras_history
input_layer_key = kh.layer.name + "_%s_%s" % (
kh.node_index,
kh.tensor_index,
)
layer_input_shapes.append(
layers_to_output_shapes[input_layer_key]
)
layer_input_shapes = tf.nest.pack_sequence_as(
layer_inputs, layer_input_shapes
)
# Layers expect shapes to be tuples for
# `compute_output_shape`.
layer_input_shapes = tf_utils.convert_shapes(
layer_input_shapes, to_tuples=True
)
layer_output_shapes = layer.compute_output_shape(
layer_input_shapes
)
# Convert back to TensorShapes.
layer_output_shapes = tf_utils.convert_shapes(
layer_output_shapes, to_tuples=False
)
node_index = layer._inbound_nodes.index(node)
for j, shape in enumerate(
tf.nest.flatten(layer_output_shapes)
):
shape_key = layer.name + f"_{node_index}_{j}"
layers_to_output_shapes[shape_key] = shape
# Read final output shapes from layers_to_output_shapes.
output_shapes = []
for i in range(len(self._output_layers)):
layer, node_index, tensor_index = self._output_coordinates[i]
shape_key = layer.name + f"_{node_index}_{tensor_index}"
output_shapes.append(layers_to_output_shapes[shape_key])
output_shapes = tf.nest.pack_sequence_as(
self._nested_outputs, output_shapes
)
# Store in cache.
self._output_shape_cache[cache_key] = output_shapes
# Return shapes as TensorShapes.
return output_shapes
def _init_set_name(self, name, zero_based=True):
if not name:
cls_name = self.__class__.__name__
if self.__class__ == Functional:
# Hide the functional class name from user, since its not a
# public visible class. Use "Model" instead,
cls_name = "Model"
self._name = backend.unique_object_name(
generic_utils.to_snake_case(cls_name), zero_based=zero_based
)
else:
self._name = name
def _run_internal_graph(self, inputs, training=None, mask=None):
"""Computes output tensors for new inputs.
# Note:
- Can be run on non-Keras tensors.
Args:
inputs: Tensor or nested structure of Tensors.
training: Boolean learning phase.
mask: (Optional) Tensor or nested structure of Tensors.
Returns:
output_tensors
"""
inputs = self._flatten_to_reference_inputs(inputs)
if mask is None:
masks = [None] * len(inputs)
else:
masks = self._flatten_to_reference_inputs(mask)
for input_t, mask in zip(inputs, masks):
input_t._keras_mask = mask
# Dictionary mapping reference tensors to computed tensors.
tensor_dict = {}
tensor_usage_count = self._tensor_usage_count
for x, y in zip(self.inputs, inputs):
y = self._conform_to_reference_input(y, ref_input=x)
x_id = str(id(x))
tensor_dict[x_id] = [y] * tensor_usage_count[x_id]
nodes_by_depth = self._nodes_by_depth
depth_keys = list(nodes_by_depth.keys())
depth_keys.sort(reverse=True)
for depth in depth_keys:
nodes = nodes_by_depth[depth]
for node in nodes:
if node.is_input:
continue # Input tensors already exist.
if any(t_id not in tensor_dict for t_id in node.flat_input_ids):
continue # Node is not computable, try skipping.
args, kwargs = node.map_arguments(tensor_dict)
outputs = node.layer(*args, **kwargs)
# Update tensor_dict.
for x_id, y in zip(
node.flat_output_ids, tf.nest.flatten(outputs)
):
tensor_dict[x_id] = [y] * tensor_usage_count[x_id]
output_tensors = []
for x in self.outputs:
x_id = str(id(x))
assert x_id in tensor_dict, "Could not compute output " + str(x)
output_tensors.append(tensor_dict[x_id].pop())
return tf.nest.pack_sequence_as(self._nested_outputs, output_tensors)
def _flatten_to_reference_inputs(self, tensors):
"""Maps `tensors` to their respective `keras.Input`."""
if self._enable_dict_to_input_mapping and isinstance(tensors, dict):
ref_inputs = self._nested_inputs
if not tf.nest.is_nested(ref_inputs):
ref_inputs = [self._nested_inputs]
if isinstance(ref_inputs, dict):
# In the case that the graph is constructed with dict input
# tensors, We will use the original dict key to map with the
# keys in the input data. Note that the model.inputs is using
# nest.flatten to process the input tensors, which means the
# dict input tensors are ordered by their keys.
ref_input_names = sorted(ref_inputs.keys())
else:
ref_input_names = [
inp._keras_history.layer.name for inp in ref_inputs
]
# Raise an warning if there are more input data comparing to input
# tensor
if len(tensors) > len(ref_input_names):
warnings.warn(
"Input dict contained keys {} which did not match any "
"model input. They will be ignored by the model.".format(
[n for n in tensors.keys() if n not in ref_input_names]
),
stacklevel=2,
)
try:
# Flatten in the order `Input`s were passed during Model
# construction.
return [tensors[n] for n in ref_input_names]
except KeyError:
# TODO(b/151582614)
return tf.nest.flatten(tensors)
# Otherwise both self.inputs and tensors will already be in same order.
return tf.nest.flatten(tensors)
def _conform_to_reference_input(self, tensor, ref_input):
"""Set shape and dtype based on `keras.Input`s."""
if isinstance(tensor, tf.Tensor):
# Allow (None,) and (None, 1) Tensors to be passed interchangeably.
# Use the shape specified by the `keras.Input`.
t_shape = tensor.shape
t_rank = t_shape.rank
ref_shape = ref_input.shape
ref_rank = ref_shape.rank
keras_history = getattr(tensor, "_keras_history", None)
if t_rank is not None and ref_rank is not None:
# Should squeeze last dimension. True if tensor is (BATCH, ...,
# 1) and reference is (BATCH, ...).
if t_rank == ref_rank + 1 and t_shape[-1] == 1:
tensor = tf.squeeze(tensor, axis=-1)
# Should expand last_dimension. True if tensor is (BATCH, ...)
# and reference is (BATCH, ..., 1).
elif t_rank == ref_rank - 1 and ref_shape[-1] == 1:
tensor = tf.expand_dims(tensor, axis=-1)
if keras_history is not None: # Restore keras history.
tensor._keras_history = keras_history
# Dtype casting.
tensor = tf.cast(tensor, dtype=ref_input.dtype)
elif tf_utils.is_extension_type(tensor):
# Dtype casting (If the extension type has a non-variant dtype and
# supports being cast). Only cast if necessary (since some
# extension types may not implement tf.cast).
tensor_dtype = getattr(tensor, "dtype", None)
ref_input_dtype = getattr(ref_input, "dtype", None)
if (
ref_input_dtype is not None
and tensor_dtype is not None
and tensor_dtype != ref_input_dtype
and ref_input_dtype != tf.variant
):
tensor = tf.cast(tensor, dtype=ref_input_dtype)
return tensor
@generic_utils.default
def get_config(self):
# Prepare base arguments
config = {
"name": self.name,
"trainable": self.trainable,
}
if saved_model_utils.in_tf_saved_model_scope():
# SavedModel special case: need to preserve legacy (potentially
# incorrect) behavior.
return copy.deepcopy(get_network_config(self, config=config))
# Check whether the class has a constructor compatible with a Functional
# model or if it has a custom constructor.
if has_functional_like_constructor(self.__class__):
# Only return a Functional config if the constructor is the same
# as that of a Functional model. This excludes subclassed Functional
# models with a custom __init__.
config = copy.deepcopy(get_network_config(self, config=config))
else:
# Try to autogenerate config
xtra_args = set(config.keys())
if getattr(self, "_auto_get_config", False):
config.update(self._auto_config.config)
# Remove args non explicitly supported
argspec = tf_inspect.getfullargspec(self.__init__)
if argspec.varkw != "kwargs":
for key in xtra_args - xtra_args.intersection(argspec.args[1:]):
config.pop(key, None)
return config
def get_weight_paths(self):
result = {}
for layer in self.layers:
(
descendants,
object_paths_dict,
) = tf.__internal__.tracking.ObjectGraphView(
layer
).breadth_first_traversal()
for descendant in descendants:
if isinstance(descendant, tf.Variable):
trackable_references = object_paths_dict[descendant]
object_path = ".".join(
[t.name for t in trackable_references]
)
result[layer.name + "." + object_path] = descendant
return result
def _validate_graph_inputs_and_outputs(self):
"""Validates the inputs and outputs of a Graph Network."""
# Check for redundancy in inputs.
if len({id(i) for i in self.inputs}) != len(self.inputs):
raise ValueError(
"The list of inputs passed to the model "
"contains the same input multiple times. "
"All inputs should only appear once."
f"Received inputs={self.inputs}"
)
for x in self.inputs:
# Check that x has appropriate `_keras_history` metadata.
if not hasattr(x, "_keras_history"):
cls_name = self.__class__.__name__
raise ValueError(
f"Input tensors to a {cls_name} model "
"must come from `tf.keras.Input`. "
f"Received inputs={x} (missing previous layer metadata)."
)
# Check that x is an input tensor.
layer = x._keras_history.layer
if len(layer._inbound_nodes) > 1 or (
layer._inbound_nodes and not layer._inbound_nodes[0].is_input
):
cls_name = self.__class__.__name__
logging.warning(
f"{cls_name} model inputs must come from "
"`tf.keras.Input` (thus holding past layer metadata). "
"They cannot be the output of "
"a previous non-Input layer. "
"Here, a tensor specified as "
f'input to "{self.name}" was not an Input tensor, '
f'it was generated by layer "{layer.name}".\n'
"Note that input tensors are "
"instantiated via `tensor = tf.keras.Input(shape)`.\n"
f"The tensor that caused the issue was: {x}"
)
# Check compatibility of batch sizes of Input Layers.
input_batch_sizes = set(
[
training_utils.get_static_batch_size(x._keras_history.layer)
for x in self.inputs
]
)
input_batch_sizes.discard(None)
if len(input_batch_sizes) > 1:
logging.warning(
"Found incompatible static batch sizes among the "
f"inputs. Batch sizes: {sorted(input_batch_sizes)}"
)
for x in self.outputs:
if not hasattr(x, "_keras_history"):
cls_name = self.__class__.__name__
raise ValueError(
f"Output tensors of a {cls_name} model must be "
"the output of a TensorFlow `Layer` "
f"(thus holding past layer metadata). Found: {x}"
)
def _insert_layers(self, layers, relevant_nodes=None):
"""Inserts Layers into the Network after Network creation.
This is only valid for Keras Graph Networks. Layers added via this
function will be included in the `call` computation and `get_config` of
this Network. They will not be added to the Network's outputs.
Args:
layers: Arbitrary nested structure of Layers. Layers must be reachable
from one or more of the `keras.Input` Tensors that correspond to
this Network's inputs.
relevant_nodes: Nodes from the Layers that should be considered part
of this Network. If `None`, all Nodes will be considered part of
this Network.
Raises:
ValueError: If the layers depend on `Input`s not found in this Model.
"""
layers = tf.nest.flatten(layers)
tf_utils.assert_no_legacy_layers(layers)
node_to_depth = {}
for depth, nodes in self._nodes_by_depth.items():
node_to_depth.update({node: depth for node in nodes})
# The nodes of these Layers that are relevant to this Network. If not
# provided, assume all Nodes are relevant
if not relevant_nodes:
relevant_nodes = tf.nest.flatten(
[layer._inbound_nodes for layer in layers]
)
network_nodes = set(relevant_nodes + list(node_to_depth.keys()))
def _get_min_depth(node):
"""Gets the minimum depth at which node can be computed."""
min_depth = 0
for layer, node_id, _, _ in node.iterate_inbound():
inbound_node = layer._inbound_nodes[node_id]
if inbound_node in node_to_depth:
min_depth = min(min_depth, node_to_depth[inbound_node])
elif inbound_node not in network_nodes:
continue
else:
# Previous relevant nodes haven't been processed yet.
return None
# New node is one shallower than its shallowest input.
return min_depth - 1
# Insert nodes into `_nodes_by_depth` and other node attrs.
unprocessed_nodes = copy.copy(relevant_nodes)
i = 0
while unprocessed_nodes:
i += 1
# Do a sanity check. This can occur if `Input`s from outside this
# Model are being relied on.
if i > 10000:
raise ValueError(
"Layers could not be added due to missing dependencies."
)
node = unprocessed_nodes.pop(0)
depth = _get_min_depth(node)
if depth is None: # Defer until inbound nodes are processed.
unprocessed_nodes.append(node)
continue
node_key = _make_node_key(
node.layer.name, node.layer._inbound_nodes.index(node)
)
if node_key not in self._network_nodes:
node_to_depth[node] = depth
self._network_nodes.add(node_key)
self._nodes_by_depth[depth].append(node)
# Insert layers and update other layer attrs.
layer_set = set(self._self_tracked_trackables)
deferred_layers = []
for layer in layers:
if layer not in layer_set:
self._self_tracked_trackables.append(layer)
deferred_layers.append(layer)
self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(
layer.call
)
layer_set.add(layer)
self._handle_deferred_layer_dependencies(deferred_layers)
self._compute_tensor_usage_count()
def _compute_tensor_usage_count(self):
"""Compute the #. of tensor usages for all the output tensors of layers.
The computed tensor usage count is saved as `self._tensor_usage_count`.
This is later used for saving memory in eager computation by releasing
no-longer-needed tensors as early as possible.
"""
tensor_usage_count = collections.Counter()
available_tensors = set(str(id(tensor)) for tensor in self.inputs)
depth_keys = list(self._nodes_by_depth.keys())
depth_keys.sort(reverse=True)
depth_keys = depth_keys[1:]
for depth in depth_keys:
for node in self._nodes_by_depth[depth]:
input_tensors = {
str(id(tensor))
for tensor in tf.nest.flatten(node.keras_inputs)
}
if input_tensors.issubset(available_tensors):
for tensor in tf.nest.flatten(node.keras_inputs):
tensor_usage_count[str(id(tensor))] += 1
for output_tensor in tf.nest.flatten(node.outputs):
available_tensors.add(str(id(output_tensor)))
for tensor in self.outputs:
tensor_usage_count[str(id(tensor))] += 1
self._tensor_usage_count = tensor_usage_count
def _assert_weights_created(self):
# Override the implementation in Model.
# The Functional model should always have weight created already.
return
def _graph_network_add_loss(self, symbolic_loss):
    """Attach a symbolic loss tensor to the graph network.

    Maps the subgraph from `self.inputs` to `symbolic_loss`, calls a new
    `AddLoss` layer on the loss, and inserts the collected layers and nodes
    (including the `AddLoss` layer and its inbound nodes) into the network.
    """
    subgraph_nodes, subgraph_layers = _map_subgraph_network(
        self.inputs, [symbolic_loss]
    )
    # Losses must be keyed on inputs no matter what in order to be supported
    # in DistributionStrategy.
    loss_layer = base_layer.AddLoss(
        unconditional=False, dtype=symbolic_loss.dtype
    )
    loss_layer(symbolic_loss)
    self._insert_layers(
        [*subgraph_layers, loss_layer],
        [*subgraph_nodes, *loss_layer.inbound_nodes],
    )
def _graph_network_add_metric(self, value, aggregation, name):
    """Attach a symbolic metric tensor to the graph network.

    Maps the subgraph from `self.inputs` to `value`, calls a new `AddMetric`
    layer (built with `aggregation`, `name` and `value.dtype`) on it, and
    inserts the collected layers and nodes into the network.
    """
    subgraph_nodes, subgraph_layers = _map_subgraph_network(
        self.inputs, [value]
    )
    metric_layer = base_layer.AddMetric(aggregation, name, dtype=value.dtype)
    metric_layer(value)
    self._insert_layers(
        [*subgraph_layers, metric_layer],
        [*subgraph_nodes, *metric_layer.inbound_nodes],
    )
@property
def _trackable_saved_model_saver(self):
    """A `NetworkSavedModelSaver` bound to this network."""
    saver_cls = network_serialization.NetworkSavedModelSaver
    return saver_cls(self)
def _get_save_spec(self, dynamic_batch=True, inputs_only=True):
    """Delegate to the parent save spec, pinning the batch size if needed.

    When `_has_explicit_input_shape` is truthy (or absent, which counts as
    truthy), `dynamic_batch` is forced to False so the batch size set by the
    input layer is used. This covers Functional models and Sequential models
    with an explicit input shape.
    """
    has_explicit_shape = getattr(self, "_has_explicit_input_shape", True)
    batch_is_dynamic = False if has_explicit_shape else dynamic_batch
    return super()._get_save_spec(batch_is_dynamic, inputs_only)
def _make_node_key(layer_name, node_index):
return layer_name + "_ib-" + str(node_index)
def _map_graph_network(inputs, outputs):
    """Validates a network's topology and gathers its layers and nodes.

    Args:
      inputs: List of input tensors.
      outputs: List of outputs tensors.

    Returns:
      A tuple `(network_nodes, nodes_by_depth, layers, layers_by_depth)`.
      - network_nodes: set of node keys (strings built by `_make_node_key`)
          identifying the nodes that belong to the network.
      - nodes_by_depth: dict mapping ints (depth) to lists of node instances.
      - layers: list of Layer instances, ordered by decreasing depth and then
          by DFS traversal order.
      - layers_by_depth: dict mapping ints (depth) to lists of layer instances.

    Raises:
      ValueError: In case the network is not valid (e.g. disconnected graph),
        or if two layers in the network share the same name.
    """
    # "depth" is number of layers between output Node and the Node.
    # Nodes are ordered from inputs -> outputs.
    nodes_in_decreasing_depth, layer_indices = _build_map(outputs)
    network_nodes = {
        _make_node_key(node.layer.name, node.layer._inbound_nodes.index(node))
        for node in nodes_in_decreasing_depth
    }

    nodes_depths = {}  # dict {node: depth value}
    layers_depths = {}  # dict {layer: depth value}

    for node in reversed(nodes_in_decreasing_depth):
        # If the depth is not set, the node has no outbound nodes (depth 0).
        depth = nodes_depths.setdefault(node, 0)

        # Update the depth of the corresponding layer
        previous_depth = layers_depths.get(node.layer, 0)
        # If we've seen this layer before at a higher depth,
        # we should use that depth instead of the node depth.
        # This is necessary for shared layers that have inputs at different
        # depth levels in the graph.
        depth = max(depth, previous_depth)
        layers_depths[node.layer] = depth
        nodes_depths[node] = depth

        # Update the depth of inbound nodes.
        # The "depth" of a node is the max of the depths
        # of all nodes it is connected to + 1.
        for node_dep in node.parent_nodes:
            previous_depth = nodes_depths.get(node_dep, 0)
            nodes_depths[node_dep] = max(depth + 1, previous_depth)

    # Handle inputs that are not connected to outputs.
    # We do not error out here because the inputs may be used to compute losses
    # and metrics.
    for input_t in inputs:
        input_layer = input_t._keras_history[0]
        if input_layer not in layers_depths:
            layers_depths[input_layer] = 0
            layer_indices[input_layer] = -1
            nodes_depths[input_layer._inbound_nodes[0]] = 0
            network_nodes.add(_make_node_key(input_layer.name, 0))

    # Build a dict {depth: list of nodes with this depth}
    nodes_by_depth = collections.defaultdict(list)
    for node, depth in nodes_depths.items():
        nodes_by_depth[depth].append(node)

    # Build a dict {depth: list of layers with this depth}
    layers_by_depth = collections.defaultdict(list)
    for layer, depth in layers_depths.items():
        layers_by_depth[depth].append(layer)

    # Set self.layers ordered by depth (deepest first).
    layers = []
    for depth in sorted(layers_by_depth, reverse=True):
        layers_for_depth = layers_by_depth[depth]
        # Network.layers needs to have a deterministic order:
        # here we order them by traversal order.
        layers_for_depth.sort(key=lambda x: layer_indices[x])
        layers.extend(layers_for_depth)

    # Get sorted list of node depths.
    depth_keys = sorted(nodes_by_depth, reverse=True)

    # Check that all tensors required are computable.
    # computable_tensors: all tensors in the graph
    # that can be computed from the inputs provided.
    computable_tensors = set()
    for x in inputs:
        computable_tensors.add(id(x))

    layers_with_complete_input = []  # To provide a better error msg.
    for depth in depth_keys:
        for node in nodes_by_depth[depth]:
            layer = node.layer
            if layer and not node.is_input:
                for x in tf.nest.flatten(node.keras_inputs):
                    if id(x) not in computable_tensors:
                        raise ValueError(
                            "Graph disconnected: cannot obtain value for "
                            f'tensor {x} at layer "{layer.name}". '
                            "The following previous layers were accessed "
                            f"without issue: {layers_with_complete_input}"
                        )
            for x in tf.nest.flatten(node.outputs):
                computable_tensors.add(id(x))
            layers_with_complete_input.append(layer.name)

    # Ensure name unicity, which will be crucial for serialization
    # (since serialized nodes refer to layers by their name).
    # Count once (O(n)) instead of calling `list.count` per name (O(n^2));
    # the first offending name in layer order is still the one reported.
    all_names = [layer.name for layer in layers]
    name_counts = collections.Counter(all_names)
    for name in all_names:
        if name_counts[name] != 1:
            raise ValueError(
                f'The name "{name}" is used {name_counts[name]} '
                "times in the model. All layer names should be unique."
            )
    return network_nodes, nodes_by_depth, layers, layers_by_depth
def _build_map(outputs):
    """Topologically sort the nodes feeding `outputs`, from inputs to outputs.

    Runs a depth-first search over the `_keras_history` connectivity metadata
    of every tensor in `outputs`. The search state is shared across outputs,
    so common subgraphs are visited only once.

    Args:
      outputs: the output tensors whose _keras_history metadata should be
        walked. This may be an arbitrary nested structure.

    Returns:
      A tuple `(ordered_nodes, layer_to_first_traversal_index)`:
      ordered_nodes: list of nodes appearing in the keras history,
        topologically sorted from original inputs to the `outputs`. If outputs
        have different sets of ancestors, the inputs to one output may appear
        after a different output.
      layer_to_first_traversal_index: dict mapping each layer to the DFS
        traversal index at which it was *first* seen (a layer shared by
        several nodes keeps only its first index).
    """
    done = set()
    on_path = set()
    ordered_nodes = []  # Appended in post-order, i.e. inputs -> outputs.
    first_seen_index = {}  # layer -> traversal order of first visit.
    for output_tensor in tf.nest.flatten(outputs):
        _build_map_helper(
            output_tensor,
            done,
            on_path,
            ordered_nodes,
            first_seen_index,
        )
    return ordered_nodes, first_seen_index
def _build_map_helper(
tensor,
finished_nodes,
nodes_in_progress,
nodes_in_decreasing_depth,
layer_indices,
):
"""Recursive helper for `_build_map`."""
(
layer,
node_index,
_,
) = tensor._keras_history
node = layer._inbound_nodes[node_index]
# Don't repeat work for shared subgraphs
if node in finished_nodes:
return
# Prevent cycles.
if node in nodes_in_progress:
raise ValueError(
f'Tensor {tensor} from layer "{layer.name}" is part of a cycle.'
)
# Store the traversal order for layer sorting.
if layer not in layer_indices:
layer_indices[layer] = len(layer_indices)
# Propagate to all previous tensors connected to this node.
nodes_in_progress.add(node)
if not node.is_input:
for tensor in node.keras_inputs:
_build_map_helper(
tensor,
finished_nodes,
nodes_in_progress,
nodes_in_decreasing_depth,
layer_indices,
)
finished_nodes.add(node)
nodes_in_progress.remove(node)
nodes_in_decreasing_depth.append(node)
def _map_subgraph_network(inputs, outputs):
    """Returns the nodes and layers in the topology from `inputs` to `outputs`.

    Args:
      inputs: List of input tensors.
      outputs: List of output tensors.

    Returns:
      A tuple of List[Node] and List[Layer].
    """
    if not tf.compat.v1.executing_eagerly_outside_functions():
        base_layer_utils.create_keras_history(outputs)
    # Keep only nodes and layers in the topology between inputs and outputs.
    _, nodes_by_depth, layers, _ = _map_graph_network(inputs, outputs)
    # Flatten the per-depth node lists into a single list.
    return tf.nest.flatten(list(nodes_by_depth.values())), layers
def _should_skip_first_node(layer):
    """Returns True if the first layer node should not be saved or loaded.

    Networks constructed with an Input layer/shape start with a pre-existing
    node linking their input to their output; that node is excluded from the
    network config. Only `Functional` instances qualify. When the layer tracks
    sub-objects, the first one must also be an `InputLayer` (this filters out
    Sequential models built without an input shape).
    """
    trackables = layer._self_tracked_trackables
    is_functional = isinstance(layer, Functional)
    if not trackables:
        return is_functional
    return is_functional and isinstance(
        trackables[0], input_layer_module.InputLayer
    )
def connect_ancillary_layers(model, created_layers):
    """Adds layers that are not connected to the outputs to the model.

    Args:
      model: Model to extend. Its `layers` are read and, if needed, its
        `_insert_layers` method is called.
      created_layers: Dict mapping layer names to layer instances.

    Returns:
      `model`, with any ancillary layers inserted in place.
    """
    # Layers not connected to outputs, such as those added in `add_loss`.
    # Read `model.layers` once and use a set: the original re-evaluated the
    # property and scanned the resulting list for every created layer.
    model_layers = set(model.layers)
    ancillary_layers = [
        layer for layer in created_layers.values() if layer not in model_layers
    ]
    if ancillary_layers:
        relevant_nodes = tf.nest.flatten(
            [
                layer.inbound_nodes[1:]
                if _should_skip_first_node(layer)
                else layer.inbound_nodes
                for layer in created_layers.values()
            ]
        )
        model._insert_layers(ancillary_layers, relevant_nodes)
    return model
def reconstruct_from_config(config, custom_objects=None, created_layers=None):
    """Reconstructs graph from config object.

    Args:
      config: Dictionary returned from Network.get_config()
      custom_objects: Optional dictionary mapping names (strings) to custom
        classes or functions to be considered during deserialization.
      created_layers: Optional dictionary mapping names to Layer objects. Any
        layer not in this dictionary will be created and added to the dict.
        This function will add new nodes to all layers (excluding InputLayers),
        instead of re-using pre-existing nodes in the layers.

    Returns:
      Tuple of (input tensors, output tensors, dictionary of created layers)

    Raises:
      ValueError: In case of improperly formatted node data, or if some layer
        calls can never be replayed because their inbound tensors are never
        produced (a malformed config).
    """
    # Layer instances created during the graph reconstruction process.
    # Test against None (not truthiness) so that an empty dict passed by the
    # caller is filled in place, as documented above.
    if created_layers is None:
        created_layers = collections.OrderedDict()

    # Maps input data (tuple of inbound layer name, node index) from the config
    # to node indices in the newly generated model. The node indices may be
    # different if the layers have already been called previously.
    node_index_map = {}
    node_count_by_layer = {}

    # Dictionary mapping layer instances to
    # node data that specifies a layer call.
    # It acts as a queue that maintains any unprocessed
    # layer call until it becomes possible to process it
    # (i.e. until the input tensors to the call all exist).
    unprocessed_nodes = collections.defaultdict(list)

    def get_node_index(layer, config_node_index):
        """Returns node index in layer (might differ from config_node_index)."""
        if isinstance(layer, input_layer_module.InputLayer):
            return 0
        return node_index_map.get((layer.name, config_node_index), None)

    def _deserialize_keras_tensors(kwargs, layer_map):
        """Deserializes Keras Tensors passed to `call`."""

        def _deserialize_keras_tensor(t):
            """Deserializes a single Keras Tensor passed to `call`."""
            if isinstance(t, tf_utils.ListWrapper):
                t = t.as_list()
                layer_name = t[0]
                node_index = t[1]
                tensor_index = t[2]

                layer = layer_map[layer_name]
                new_node_index = get_node_index(layer, node_index)
                if new_node_index is None:
                    # The inbound node may not have been processed yet,
                    # (This can happen e.g. if it depends on a different set
                    # of inputs than those that have been processed already).
                    # raise an IndexError so that the current node puts itself
                    # back on the unprocessed queue.
                    raise IndexError
                node = layer._inbound_nodes[new_node_index]
                return tf.nest.flatten(node.outputs)[tensor_index]
            return t

        kwargs = tf_utils.convert_inner_node_data(kwargs, wrap=True)
        return tf.nest.map_structure(_deserialize_keras_tensor, kwargs)

    def process_node(layer, node_data):
        """Deserialize a node.

        Args:
          layer: layer instance.
          node_data: Nested structure of `ListWrapper`.

        Returns:
          Whether the node was processed (i.e. the layer was called on the
          inputs specified by the node data). When False is returned, no layer
          has been called.

        Raises:
          ValueError: In case of improperly formatted `node_data`.
        """
        input_tensors = []
        for input_data in tf.nest.flatten(node_data):
            input_data = input_data.as_list()
            if len(input_data) == 3:
                kwargs = {}
            elif len(input_data) == 4:
                kwargs = input_data[3]
                try:
                    kwargs = _deserialize_keras_tensors(kwargs, created_layers)
                except IndexError:
                    # Happens if keras tensors in kwargs are still unprocessed
                    return False
            else:
                raise ValueError("Improperly formatted model config.")

            if input_data[0] != node_module._CONSTANT_VALUE:
                inbound_layer_name = input_data[0]
                inbound_node_index = input_data[1]
                inbound_tensor_index = input_data[2]
                inbound_layer = created_layers[inbound_layer_name]
                inbound_node_index = get_node_index(
                    inbound_layer, inbound_node_index
                )

                if inbound_node_index is None:
                    return False
                inbound_node = inbound_layer._inbound_nodes[inbound_node_index]
                input_tensors.append(
                    tf.nest.flatten(inbound_node.outputs)[inbound_tensor_index]
                )
            else:
                # We received a constant w/ no Keras history attached,
                # which means it is a constant tensor input.
                # Input is a constant value.
                # Format = [_CONSTANT_VALUE, -1, const_val, kwargs]
                assert input_data[1] == -1
                assert len(input_data) >= 3
                const_val = input_data[2]
                if (
                    isinstance(const_val, tuple)
                    and len(const_val) == 2
                    and const_val[0] == node_module._COMPOSITE_TYPE
                ):
                    # It is a composite tensor.
                    input_tensors.append(json_utils.decode(const_val[1]))
                else:
                    input_tensors.append(const_val)
        input_tensors = tf.nest.pack_sequence_as(node_data, input_tensors)
        # Call layer on its inputs, thus creating the node
        # and building the layer if needed.
        if input_tensors is not None:
            if not layer._preserve_input_structure_in_config:
                input_tensors = base_layer_utils.unnest_if_single_tensor(
                    input_tensors
                )
            output_tensors = layer(input_tensors, **kwargs)

            # Update node index map.
            output_index = tf.nest.flatten(output_tensors)[
                0
            ]._keras_history.node_index
            node_index_map[
                (layer.name, node_count_by_layer[layer])
            ] = output_index
            node_count_by_layer[layer] += 1
        return True

    def process_layer(layer_data):
        """Deserializes a layer, then enqueues its calls for processing.

        Args:
          layer_data: layer config dict.

        Raises:
          ValueError: In case of improperly formatted `layer_data` dict.
        """
        layer_name = layer_data["name"]

        if layer_name in created_layers:
            layer = created_layers[layer_name]
        else:
            # Instantiate layer.
            from keras.layers import deserialize as deserialize_layer

            layer = deserialize_layer(layer_data, custom_objects=custom_objects)
            created_layers[layer_name] = layer

        node_count_by_layer[layer] = int(_should_skip_first_node(layer))

        # Gather layer inputs and convert to `ListWrapper` objects.
        inbound_nodes_data = layer_data["inbound_nodes"]
        inbound_nodes_data = tf_utils.convert_inner_node_data(
            inbound_nodes_data, wrap=True
        )
        for node_data in inbound_nodes_data:
            # We don't process nodes (i.e. make layer calls)
            # on the fly because the inbound node may not yet exist,
            # in case of layer shared at different topological depths
            # (e.g. a model such as A(B(A(B(x)))))
            unprocessed_nodes[layer].append(node_data)

    # First, we create all layers and enqueue nodes to be processed
    for layer_data in config["layers"]:
        process_layer(layer_data)
    # Then we process nodes in order of layer depth.
    # Nodes that cannot yet be processed (if the inbound node
    # does not yet exist) are re-enqueued, and the process
    # is repeated until all nodes are processed.
    while unprocessed_nodes:
        made_progress = False
        for layer_data in config["layers"]:
            layer = created_layers[layer_data["name"]]
            if layer not in unprocessed_nodes:
                continue
            layer_nodes = unprocessed_nodes.pop(layer)
            while layer_nodes:
                node_data = layer_nodes[0]
                if process_node(layer, node_data):
                    layer_nodes.pop(0)
                    made_progress = True
                else:
                    # If a node can't be processed, stop processing the
                    # nodes of the current layer to maintain node ordering.
                    unprocessed_nodes[layer] = layer_nodes
                    break
        if not made_progress:
            # A failed `process_node` has no side effects, so a pass that
            # processed nothing would repeat forever. Fail loudly instead.
            pending_layers = [layer.name for layer in unprocessed_nodes]
            raise ValueError(
                "Could not reconstruct the model from its config: the inbound "
                f"tensors of layers {pending_layers} could never be resolved. "
                "The config may be malformed or contain a cycle."
            )

    input_tensors = []
    output_tensors = []

    input_layers = tf_utils.convert_inner_node_data(
        config["input_layers"], wrap=True
    )
    for layer_data in tf.nest.flatten(input_layers):
        layer_name, node_index, tensor_index = layer_data.as_list()
        assert layer_name in created_layers
        layer = created_layers[layer_name]
        node_index = get_node_index(layer, node_index)
        layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
        input_tensors.append(
            tf.nest.flatten(layer_output_tensors)[tensor_index]
        )

    output_layers = tf_utils.convert_inner_node_data(
        config["output_layers"], wrap=True
    )
    for layer_data in tf.nest.flatten(output_layers):
        layer_name, node_index, tensor_index = layer_data.as_list()
        assert layer_name in created_layers
        layer = created_layers[layer_name]
        node_index = get_node_index(layer, node_index)
        layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
        output_tensors.append(
            tf.nest.flatten(layer_output_tensors)[tensor_index]
        )

    input_tensors = tf.nest.pack_sequence_as(input_layers, input_tensors)
    output_tensors = tf.nest.pack_sequence_as(output_layers, output_tensors)
    return input_tensors, output_tensors, created_layers
def get_network_config(network, serialize_layer_fn=None, config=None):
    """Build the config, which consists of the node graph and serialized layers.

    Args:
      network: A Network object.
      serialize_layer_fn: Function used to serialize layers. Defaults to
        `serialization.serialize_keras_object`.
      config: A dict to append more config entries into. If None, start with a
        new dict for the config.

    Returns:
      Config dictionary.
    """
    serialize_layer_fn = (
        serialize_layer_fn or serialization.serialize_keras_object
    )
    # Only replace `None`: an empty dict supplied by the caller is filled in
    # place, as documented.
    if config is None:
        config = {}

    config["name"] = network.name
    # Map the key of every node kept in the network to its index among the
    # layer's serialized nodes (the first node is skipped when
    # `_should_skip_first_node` says so).
    node_conversion_map = {}
    for layer in network.layers:
        kept_nodes = 1 if _should_skip_first_node(layer) else 0
        for original_node_index, _ in enumerate(layer._inbound_nodes):
            node_key = _make_node_key(layer.name, original_node_index)
            if node_key in network._network_nodes:
                node_conversion_map[node_key] = kept_nodes
                kept_nodes += 1

    layer_configs = []
    with serialization.SharedObjectSavingScope():
        for layer in network.layers:  # From the earliest layers on.
            filtered_inbound_nodes = []
            for original_node_index, node in enumerate(layer._inbound_nodes):
                node_key = _make_node_key(layer.name, original_node_index)
                if node_key in network._network_nodes and not node.is_input:
                    # The node is relevant to the model:
                    # add to filtered_inbound_nodes.
                    node_data = node.serialize(
                        _make_node_key, node_conversion_map
                    )
                    filtered_inbound_nodes.append(node_data)

            layer_config = serialize_layer_fn(layer)
            layer_config["name"] = layer.name
            layer_config["inbound_nodes"] = filtered_inbound_nodes
            layer_configs.append(layer_config)
        config["layers"] = layer_configs

    def _serialize_endpoints(num_endpoints, coordinates, nested_structure):
        """Serializes the model's input or output coordinates for the config."""
        endpoints = []
        for i in range(num_endpoints):
            layer, node_index, tensor_index = coordinates[i]
            node_key = _make_node_key(layer.name, node_index)
            if node_key not in network._network_nodes:
                continue
            new_node_index = node_conversion_map[node_key]
            endpoints.append(
                tf_utils.ListWrapper([layer.name, new_node_index, tensor_index])
            )
        endpoints = tf.nest.pack_sequence_as(nested_structure, endpoints)
        # Preserve external Keras compat for Models with a single
        # input/output.
        if not tf.nest.is_nested(endpoints):
            endpoints = [endpoints]
        return tf_utils.convert_inner_node_data(endpoints)

    # Gather info about inputs and outputs.
    config["input_layers"] = _serialize_endpoints(
        len(network._input_layers),
        network._input_coordinates,
        network._nested_inputs,
    )
    config["output_layers"] = _serialize_endpoints(
        len(network._output_layers),
        network._output_coordinates,
        network._nested_outputs,
    )
    return config
def shape_with_no_batch_size(x):
    """Return `x.shape` as a list with its leading (batch) dimension unset.

    Returns None if the rank of `x` is unknown. A rank-0 shape yields `[]`.
    Otherwise the first entry is replaced by None.
    """
    if x.shape.rank is None:
        return None
    dims = x.shape.as_list()
    return [None] + dims[1:] if dims else dims
class ModuleWrapper(base_layer.Layer):
    """Wrapper for `tf.Module`s to support the Functional and Sequential API."""

    def __init__(self, module, method_name=None, **kwargs):
        """Initializes the wrapper Layer for this module.

        Args:
          module: The `tf.Module` instance to be wrapped.
          method_name: (Optional) str. The name of the method to use as the
            forward pass of the module. If not set, defaults to '__call__' if
            defined, or 'call'.
          **kwargs: Additional keyword arguments. See `tf.keras.layers.Layer`.

        Raises:
          ValueError: If `method_name` is not defined on `module`.
        """
        super().__init__(**kwargs)
        if method_name is None:
            if hasattr(module, "__call__"):
                method_name = "__call__"
            elif hasattr(module, "call"):
                method_name = "call"
        if method_name is None or not hasattr(module, method_name):
            raise ValueError(f"{method_name} is not defined on object {module}")

        self._module = module
        self._method_name = method_name

        # Check if the wrapped method accepts `training` / `mask` (positionally,
        # as keyword-only parameters, or through `**kwargs`). Keyword-only
        # parameters such as `def __call__(self, x, *, training=False)` are
        # listed in `kwonlyargs`, not `args`, so both lists are consulted.
        method = getattr(module, method_name)
        method_arg_spec = tf_inspect.getfullargspec(method)
        accepted_args = set(method_arg_spec.args) | set(
            method_arg_spec.kwonlyargs or ()
        )
        accepts_var_kwargs = method_arg_spec.varkw is not None
        self._call_spec.expects_training_arg = (
            "training" in accepted_args or accepts_var_kwargs
        )
        self._call_spec.expects_mask_arg = (
            "mask" in accepted_args or accepts_var_kwargs
        )

    def call(self, *args, **kwargs):
        """Forwards the call to the wrapped module's method.

        `training` and `mask` keyword arguments are dropped when the wrapped
        method was not detected to accept them.
        """
        if "training" in kwargs and not self._expects_training_arg:
            kwargs.pop("training")
        if "mask" in kwargs and not self._expects_mask_arg:
            kwargs.pop("mask")
        return getattr(self._module, self._method_name)(*args, **kwargs)
def has_functional_like_constructor(cls):
    """Returns True if `cls.__init__` has the same named positional parameters
    (excluding `self`, compared in order) as `Functional.__init__`."""
    init_args = tf_inspect.getfullargspec(cls.__init__).args[1:]
    functional_init_args = tf_inspect.getfullargspec(
        Functional.__init__
    ).args[1:]
    return init_args == functional_init_args