561 lines
23 KiB
Python
561 lines
23 KiB
Python
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
|
||
|
"""Home of the `Sequential` model."""
|
||
|
|
||
|
import copy
|
||
|
|
||
|
import tensorflow.compat.v2 as tf
|
||
|
|
||
|
from keras import layers as layer_module
|
||
|
from keras.engine import base_layer
|
||
|
from keras.engine import functional
|
||
|
from keras.engine import input_layer
|
||
|
from keras.engine import training
|
||
|
from keras.engine import training_utils
|
||
|
from keras.saving.legacy import serialization
|
||
|
from keras.saving.legacy.saved_model import model_serialization
|
||
|
from keras.utils import generic_utils
|
||
|
from keras.utils import layer_utils
|
||
|
from keras.utils import tf_inspect
|
||
|
from keras.utils import tf_utils
|
||
|
from keras.utils import traceback_utils
|
||
|
|
||
|
# isort: off
|
||
|
from tensorflow.python.platform import tf_logging as logging
|
||
|
from tensorflow.python.util.tf_export import keras_export
|
||
|
|
||
|
# Message of the `ValueError` raised when a layer in the stack yields anything
# other than exactly one output tensor.
SINGLE_LAYER_OUTPUT_ERROR_MSG = (
    "All layers in a Sequential model should have a single output tensor. "
    "For multi-output layers, use the functional API."
)
|
||
|
|
||
|
|
||
|
@keras_export("keras.Sequential", "keras.models.Sequential")
class Sequential(functional.Functional):
    """`Sequential` groups a linear stack of layers into a `tf.keras.Model`.

    `Sequential` provides training and inference features on this model.

    Examples:

    ```python
    # Optionally, the first layer can receive an `input_shape` argument:
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dense(8, input_shape=(16,)))
    # Afterwards, we do automatic shape inference:
    model.add(tf.keras.layers.Dense(4))

    # This is identical to the following:
    model = tf.keras.Sequential()
    model.add(tf.keras.Input(shape=(16,)))
    model.add(tf.keras.layers.Dense(8))

    # Note that you can also omit the `input_shape` argument.
    # In that case the model doesn't have any weights until the first call
    # to a training/evaluation method (since it isn't yet built):
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dense(8))
    model.add(tf.keras.layers.Dense(4))
    # model.weights not created yet

    # Whereas if you specify the input shape, the model gets built
    # continuously as you are adding layers:
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dense(8, input_shape=(16,)))
    model.add(tf.keras.layers.Dense(4))
    len(model.weights)
    # Returns "4"

    # When using the delayed-build pattern (no input shape specified), you can
    # choose to manually build your model by calling
    # `build(batch_input_shape)`:
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dense(8))
    model.add(tf.keras.layers.Dense(4))
    model.build((None, 16))
    len(model.weights)
    # Returns "4"

    # Note that when using the delayed-build pattern (no input shape specified),
    # the model gets built the first time you call `fit`, `eval`, or `predict`,
    # or the first time you call the model on some input data.
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dense(8))
    model.add(tf.keras.layers.Dense(1))
    model.compile(optimizer='sgd', loss='mse')
    # This builds the model for the first time:
    model.fit(x, y, batch_size=32, epochs=10)
    ```
    """

    @tf.__internal__.tracking.no_automatic_dependency_tracking
    @traceback_utils.filter_traceback
    def __init__(self, layers=None, name=None):
        """Creates a `Sequential` model instance.

        Args:
          layers: Optional list of layers to add to the model. A single
            non-list, non-tuple value is treated as a one-element list.
          name: Optional name for the model.
        """
        # Skip the init in FunctionalModel since model doesn't have input/output
        # yet
        super(functional.Functional, self).__init__(name=name, autocast=False)
        base_layer.keras_api_gauge.get_cell("Sequential").set(True)
        self.supports_masking = True
        self._compute_output_and_mask_jointly = True
        self._auto_track_sub_layers = False
        self._inferred_input_shape = None
        self._has_explicit_input_shape = False
        self._input_dtype = None
        self._layer_call_argspecs = {}
        self._created_nodes = set()
        # Flag that indicates whether the sequential network topology has been
        # created. It is false when there isn't any layer, or the layers don't
        # have an input shape.
        self._graph_initialized = False

        # Unfortunately some Sequential models using custom layers or
        # FeatureColumn layers have multiple inputs. This is fundamentally
        # incompatible with most of the Sequential API, and we have to disable a
        # number of features for such models.
        self._use_legacy_deferred_behavior = False

        # Add to the model any layers passed to the constructor.
        if layers:
            if not isinstance(layers, (list, tuple)):
                layers = [layers]
            for layer in layers:
                self.add(layer)

    @property
    def layers(self):
        """Layers of the model, without a leading `InputLayer`.

        Returns:
          A new list: the parent class's `layers` with the first element
          dropped when it is an `InputLayer`, otherwise a shallow copy.
        """
        # Historically, `sequential.layers` only returns layers that were added
        # via `add`, and omits the auto-generated `InputLayer` that comes at the
        # bottom of the stack.
        # `Trackable` manages the `_layers` attributes and does filtering
        # over it.
        layers = super().layers
        if layers and isinstance(layers[0], input_layer.InputLayer):
            return layers[1:]
        return layers[:]

    @tf.__internal__.tracking.no_automatic_dependency_tracking
    @traceback_utils.filter_traceback
    def add(self, layer):
        """Adds a layer instance on top of the layer stack.

        Args:
          layer: layer instance.

        Raises:
          TypeError: If `layer` is not a layer instance.
          ValueError: In case the `layer` argument does not
            know its input shape.
          ValueError: In case the `layer` argument has
            multiple output tensors, or is already connected
            somewhere else (forbidden in `Sequential` models).
        """
        # If we are passed a Keras tensor created by keras.Input(), we can
        # extract the input layer from its keras history and use that without
        # any loss of generality.
        if hasattr(layer, "_keras_history"):
            origin_layer = layer._keras_history[0]
            if isinstance(origin_layer, input_layer.InputLayer):
                layer = origin_layer

        if isinstance(layer, tf.Module):
            # Plain `tf.Module`s are accepted by wrapping them in a layer.
            if not isinstance(layer, base_layer.Layer):
                layer = functional.ModuleWrapper(layer)
        else:
            raise TypeError(
                "The added layer must be an instance of class Layer. "
                f"Received: layer={layer} of type {type(layer)}."
            )

        tf_utils.assert_no_legacy_layers([layer])
        if not self._is_layer_name_unique(layer):
            raise ValueError(
                "All layers added to a Sequential model "
                f'should have unique names. Name "{layer.name}" is already '
                "the name of a layer in this model. Update the `name` argument "
                "to pass a unique name."
            )

        self.built = False
        set_inputs = False
        self._maybe_create_attribute("_self_tracked_trackables", [])
        if not self._self_tracked_trackables:
            if isinstance(layer, input_layer.InputLayer):
                # Case where the user passes an Input or InputLayer layer via
                # `add`.
                set_inputs = True
            else:
                batch_shape, dtype = training_utils.get_input_shape_and_dtype(
                    layer
                )
                if batch_shape:
                    # Instantiate an input layer.
                    x = input_layer.Input(
                        batch_shape=batch_shape,
                        dtype=dtype,
                        name=layer.name + "_input",
                    )
                    # This will build the current layer
                    # and create the node connecting the current layer
                    # to the input layer we just created.
                    layer(x)
                    set_inputs = True

            if set_inputs:
                outputs = tf.nest.flatten(layer._inbound_nodes[-1].outputs)
                if len(outputs) != 1:
                    raise ValueError(SINGLE_LAYER_OUTPUT_ERROR_MSG)
                self.outputs = outputs
                self.inputs = layer_utils.get_source_inputs(self.outputs[0])
                self.built = True
                self._has_explicit_input_shape = True

        elif self.outputs:
            # If the model is being built continuously on top of an input layer:
            # refresh its output.
            output_tensor = layer(self.outputs[0])
            if len(tf.nest.flatten(output_tensor)) != 1:
                raise ValueError(SINGLE_LAYER_OUTPUT_ERROR_MSG)
            self.outputs = [output_tensor]
            self.built = True

        if set_inputs or self._graph_initialized:
            self._init_graph_network(self.inputs, self.outputs)
            self._graph_initialized = True
        else:
            # Deferred mode: only record the layer; the graph network is
            # created later, once an input shape is available.
            self._self_tracked_trackables.append(layer)
            self._handle_deferred_layer_dependencies([layer])

        self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)

    @tf.__internal__.tracking.no_automatic_dependency_tracking
    @traceback_utils.filter_traceback
    def pop(self):
        """Removes the last layer in the model.

        Raises:
          TypeError: if there are no layers in the model.
        """
        if not self.layers:
            raise TypeError("There are no layers in the model.")

        layer = self._self_tracked_trackables.pop()
        self._layer_call_argspecs.pop(layer)
        if not self.layers:
            # The stack is now empty: reset to the unbuilt, deferred state.
            self.outputs = None
            self.inputs = None
            self.built = False
            self._inferred_input_shape = None
            self._has_explicit_input_shape = False
            self._graph_initialized = False
        elif self._graph_initialized:
            # Re-root the graph network on the new last layer.
            self.layers[-1]._outbound_nodes = []
            self.outputs = [self.layers[-1].output]
            self._init_graph_network(self.inputs, self.outputs)
            self.built = True

    @tf.__internal__.tracking.no_automatic_dependency_tracking
    def _build_graph_network_for_inferred_shape(
        self, input_shape, input_dtype=None
    ):
        """Builds or rebuilds the graph network for an observed input shape.

        Does nothing when `input_shape` is `None`, when the model has no
        layers, when TF2 behavior is disabled or eager execution outside
        functions is off, when the model has an explicit input shape, or when
        it has switched to the legacy deferred behavior.

        Args:
          input_shape: Shape of the incoming data; converted with `tuple()`.
          input_dtype: Optional dtype passed to the created `Input`.

        Raises:
          ValueError: If a layer produces more than one output tensor.
        """
        if input_shape is None or not self.layers:
            return
        if (
            not tf.__internal__.tf2.enabled()
            or not tf.compat.v1.executing_eagerly_outside_functions()
        ):
            # This behavior is disabled in V1 or when eager execution is
            # disabled.
            return
        if (
            not self._has_explicit_input_shape
            and not self._use_legacy_deferred_behavior
        ):
            # Determine whether the input shape is novel, i.e. whether the model
            # should be rebuilt.
            input_shape = tuple(input_shape)
            if self._inferred_input_shape is None:
                new_shape = input_shape
            else:
                new_shape = relax_input_shape(
                    self._inferred_input_shape, input_shape
                )
            if (
                new_shape is not None
                and new_shape != self._inferred_input_shape
            ):
                # A novel shape has been received: we need to rebuild the model.
                # In case we are inside a graph function, we step out of it.
                with tf.init_scope():
                    inputs = input_layer.Input(
                        batch_shape=new_shape,
                        dtype=input_dtype,
                        name=self.layers[0].name + "_input",
                    )
                    layer_input = inputs
                    created_nodes = set()
                    for layer in self.layers:
                        # Clear nodes previously created via this method. This
                        # prevents node accumulation and ensures that e.g.
                        # `layer.output` is always connected to `model.inputs`
                        # (this is important e.g. for the feature extraction use
                        # case). We don't just do `layer._inbound_nodes = []`
                        # in order not to break shared layers added to
                        # Sequential models (which is technically illegal as per
                        # the `add()` docstring, but wasn't previously
                        # disabled).
                        clear_previously_created_nodes(
                            layer, self._created_nodes
                        )
                        try:
                            # Create Functional API connection by calling the
                            # current layer
                            layer_output = layer(layer_input)
                        except:  # noqa: E722
                            # Functional API calls may fail for a number of
                            # reasons: 1) The layer may be buggy. In this case
                            # it will be easier for the user to debug if we fail
                            # on the first call on concrete data, instead of our
                            # own call on a symbolic input. 2) The layer is
                            # dynamic (graph-incompatible) and hasn't overridden
                            # `compute_output_shape`. In this case, it is
                            # impossible to build a graph network. 3) The layer
                            # is otherwise incompatible with the Functional API
                            # (e.g. this is the case for some probabilistic
                            # layers that rely on hacks and that do not return
                            # tensors). In all these cases, we should avoid
                            # creating a graph network (or we simply can't).
                            self._use_legacy_deferred_behavior = True
                            return
                        if len(tf.nest.flatten(layer_output)) != 1:
                            raise ValueError(SINGLE_LAYER_OUTPUT_ERROR_MSG)
                        # Keep track of nodes just created above
                        track_nodes_created_by_last_call(layer, created_nodes)
                        layer_input = layer_output
                        outputs = layer_output
                    self._created_nodes = created_nodes
                    try:
                        # Initialize a graph Network. This call will never fail
                        # for a stack of valid Keras layers. However some users
                        # have layers that are fundamentally incompatible with
                        # the Functional API, which do not return tensors. In
                        # this case, we fall back to the legacy deferred
                        # behavior.
                        # TODO(fchollet): consider raising here, as we should
                        # not be supporting such layers.
                        self._init_graph_network(inputs, outputs)
                        self._graph_initialized = True
                    except:  # noqa: E722
                        self._use_legacy_deferred_behavior = True
                self._inferred_input_shape = new_shape

    @generic_utils.default
    def build(self, input_shape=None):
        """Builds the model.

        When the graph network already exists it is re-initialized from the
        current `inputs`/`outputs`. Otherwise a graph network is attempted for
        `input_shape`, and if the model is still not built the parent class's
        `build` is used.

        Args:
          input_shape: Input shape; required unless the graph network has
            already been initialized.

        Raises:
          ValueError: If `input_shape` is `None` and the graph network has
            not been initialized.
        """
        if self._graph_initialized:
            self._init_graph_network(self.inputs, self.outputs)
        else:
            if input_shape is None:
                raise ValueError("You must provide an `input_shape` argument.")
            self._build_graph_network_for_inferred_shape(input_shape)
            if not self.built:
                input_shape = tuple(input_shape)
                self._build_input_shape = input_shape
                super().build(input_shape)
        self.built = True

    def call(self, inputs, training=None, mask=None):
        """Runs `inputs` through the layer stack.

        Args:
          inputs: Input tensor. A non-tensor structure switches the model to
            the legacy deferred behavior.
          training: Forwarded to layers whose `call` accepts `training`.
          mask: Forwarded to layers whose `call` accepts `mask`.

        Returns:
          The output of the last layer, or `inputs` unchanged if the model
          has no layers.
        """
        # If applicable, update the static input shape of the model.
        if not self._has_explicit_input_shape:
            if not tf.is_tensor(inputs) and not isinstance(inputs, tf.Tensor):
                # This is a Sequential with multiple inputs. This is technically
                # an invalid use case of Sequential, but we tolerate it for
                # backwards compatibility.
                self._use_legacy_deferred_behavior = True
                self._build_input_shape = tf.nest.map_structure(
                    _get_shape_tuple, inputs
                )
                if tf.__internal__.tf2.enabled():
                    logging.warning(
                        "Layers in a Sequential model should only have a "
                        f"single input tensor. Received: inputs={inputs}. "
                        "Consider rewriting this model with the Functional "
                        "API."
                    )
            else:
                self._build_graph_network_for_inferred_shape(
                    inputs.shape, inputs.dtype
                )

        if self._graph_initialized:
            if not self.built:
                self._init_graph_network(self.inputs, self.outputs)
            return super().call(inputs, training=training, mask=mask)

        # Defined once, outside the loop, instead of once per layer.
        def _get_mask_from_keras_tensor(kt):
            return getattr(kt, "_keras_mask", None)

        outputs = inputs  # handle the corner case where self.layers is empty
        for layer in self.layers:
            # During each iteration, `inputs` are the inputs to `layer`, and
            # `outputs` are the outputs of `layer` applied to `inputs`. At the
            # end of each iteration `inputs` is set to `outputs` to prepare for
            # the next layer.
            kwargs = {}
            argspec = self._layer_call_argspecs[layer].args
            if "mask" in argspec:
                kwargs["mask"] = mask
            if "training" in argspec:
                kwargs["training"] = training

            outputs = layer(inputs, **kwargs)

            inputs = outputs

            # The mask attached to this layer's output feeds the next layer.
            mask = tf.nest.map_structure(_get_mask_from_keras_tensor, outputs)
        return outputs

    def compute_output_shape(self, input_shape):
        """Chains `compute_output_shape` through every layer in order.

        Args:
          input_shape: Shape fed to the first layer.

        Returns:
          The shape returned by the last layer, or `input_shape` unchanged
          if the model has no layers.
        """
        shape = input_shape
        for layer in self.layers:
            shape = layer.compute_output_shape(shape)
        return shape

    def compute_mask(self, inputs, mask):
        """Returns the `_keras_mask` of the model output, or `None`.

        Args:
          inputs: Inputs passed to `call`.
          mask: Mask passed to `call`.
        """
        # TODO(omalleyt): b/123540974 This function is not really safe to call
        # by itself because it will duplicate any updates and losses in graph
        # mode by `call`ing the Layers again.
        outputs = self.call(inputs, mask=mask)
        return getattr(outputs, "_keras_mask", None)

    def get_config(self):
        """Returns the model configuration.

        Returns:
          The `training.Model` config extended with `"name"`, `"layers"`
          (serialized layers, including the `InputLayer` if any) and, for a
          non-graph-network model with a known build shape,
          `"build_input_shape"`.
        """
        layer_configs = []
        for layer in super().layers:
            # `super().layers` include the InputLayer if available (it is
            # filtered out of `self.layers`). Note that
            # `self._self_tracked_trackables` is managed by the tracking
            # infrastructure and should not be used.
            layer_configs.append(serialization.serialize_keras_object(layer))
        config = training.Model.get_config(self)
        config["name"] = self.name
        config["layers"] = copy.deepcopy(layer_configs)
        if not self._is_graph_network and self._build_input_shape is not None:
            config["build_input_shape"] = self._build_input_shape
        return config

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Instantiates a model from its config.

        Args:
          config: Either a mapping with a `"layers"` entry (and optional
            `"name"` / `"build_input_shape"`), or, when it has no `"name"`
            key, the iterable of layer configs itself.
          custom_objects: Forwarded to `layer_module.deserialize`.

        Returns:
          A new instance of `cls` with the deserialized layers added. It is
          built when it has no inputs and a non-empty tuple/list
          `build_input_shape` was recorded.
        """
        if "name" in config:
            name = config["name"]
            build_input_shape = config.get("build_input_shape")
            layer_configs = config["layers"]
        else:
            # `config` is itself the sequence of layer configs, so no build
            # shape is recorded. It must still be bound: the check below reads
            # it and would otherwise raise `UnboundLocalError`.
            name = None
            build_input_shape = None
            layer_configs = config
        model = cls(name=name)
        for layer_config in layer_configs:
            layer = layer_module.deserialize(
                layer_config, custom_objects=custom_objects
            )
            model.add(layer)

        if (
            not model.inputs
            and build_input_shape
            and isinstance(build_input_shape, (tuple, list))
        ):
            model.build(build_input_shape)

        return model

    @property
    def input_spec(self):
        """Input spec of the model.

        Returns:
          The value assigned through the setter if there is one; otherwise
          the parent's `input_spec` when the model has an explicit input
          shape; otherwise `None`.
        """
        if hasattr(self, "_manual_input_spec"):
            return self._manual_input_spec
        if self._has_explicit_input_shape:
            return super().input_spec
        return None

    @input_spec.setter
    def input_spec(self, value):
        self._manual_input_spec = value

    @property
    def _trackable_saved_model_saver(self):
        """A `SequentialSavedModelSaver` wrapping this model."""
        return model_serialization.SequentialSavedModelSaver(self)

    def _is_layer_name_unique(self, layer):
        """Returns `False` if a different layer in the model shares its name."""
        for ref_layer in self.layers:
            if layer.name == ref_layer.name and ref_layer is not layer:
                return False
        return True

    def _assert_weights_created(self):
        """Skips the weight check once the graph network is initialized."""
        if self._graph_initialized:
            return
        # When the graph has not been initialized, use the Model's
        # implementation to check if the weights have been created.
        super(functional.Functional, self)._assert_weights_created()
|
||
|
|
||
|
|
||
|
def _get_shape_tuple(t):
    """Return the shape of `t` as a tuple, or `None` when it is unavailable.

    Args:
      t: Any object. Only its `shape` attribute, if present, is inspected.

    Returns:
      `t.shape` itself when it is already a tuple; `tuple(shape.as_list())`
      when the shape object reports a non-`None` `rank`; `None` when `t` has
      no `shape` attribute or the rank is `None`.
    """
    try:
        shape = t.shape
    except AttributeError:
        # No `shape` attribute at all: nothing to report.
        return None
    if isinstance(shape, tuple):
        return shape
    return tuple(shape.as_list()) if shape.rank is not None else None
|
||
|
|
||
|
|
||
|
def relax_input_shape(shape_1, shape_2):
    """Merge two shapes, replacing every disagreeing dimension with `None`.

    Args:
      shape_1: First shape (a sized iterable of dimensions) or `None`.
      shape_2: Second shape (a sized iterable of dimensions) or `None`.

    Returns:
      A tuple with the common dimension where both shapes agree and `None`
      where they differ, or `None` if either shape is `None` or their
      lengths differ.
    """
    either_missing = shape_1 is None or shape_2 is None
    if either_missing or len(shape_1) != len(shape_2):
        return None
    relaxed = []
    for dim_1, dim_2 in zip(shape_1, shape_2):
        relaxed.append(None if dim_1 != dim_2 else dim_1)
    return tuple(relaxed)
|
||
|
|
||
|
|
||
|
def clear_previously_created_nodes(layer, created_nodes):
    """Remove nodes from `created_nodes` from the layer's inbound_nodes.

    The same nodes are also removed from the `_outbound_nodes` of every layer
    feeding one of `layer`'s inbound nodes.

    Args:
      layer: Layer whose `_inbound_nodes` list is filtered (it is rebound to
        a new list).
      created_nodes: Container of nodes to drop; only membership is tested.
    """

    def _surviving(nodes):
        # Keep the original order, dropping everything in `created_nodes`.
        return [node for node in nodes if node not in created_nodes]

    for inbound_node in layer._inbound_nodes:
        for upstream in tf.nest.flatten(inbound_node.inbound_layers):
            upstream._outbound_nodes = _surviving(upstream._outbound_nodes)
    layer._inbound_nodes = _surviving(layer._inbound_nodes)
|
||
|
|
||
|
|
||
|
def track_nodes_created_by_last_call(layer, created_nodes):
    """Adds to `created_nodes` the nodes created by the last call to `layer`.

    Records the most recent inbound node of `layer` and, for each layer
    feeding that node, its most recent outbound node (if it has any).

    Args:
      layer: Layer that was just called.
      created_nodes: Set updated in place. It is left untouched when `layer`
        has no inbound nodes.
    """
    inbound = layer._inbound_nodes
    if inbound:
        latest = inbound[-1]
        created_nodes.add(latest)
        for upstream in tf.nest.flatten(latest.inbound_layers):
            outbound = upstream._outbound_nodes
            if outbound:
                created_nodes.add(outbound[-1])
|