345 lines
14 KiB
Python
345 lines
14 KiB
Python
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
|
|
|
|
"""Contains the `Node` class."""
|
|
|
|
import collections
|
|
import copy
|
|
import json
|
|
|
|
import numpy as np
|
|
import tensorflow.compat.v2 as tf
|
|
|
|
from keras import backend
|
|
from keras.engine import base_layer_utils
|
|
from keras.saving.legacy.saved_model import json_utils
|
|
from keras.utils import tf_utils
|
|
|
|
# Tag placed in the first slot of a serialized first-argument entry when the
# value did not originate from a Keras tensor (see `Node.serialize`).
_CONSTANT_VALUE = "_CONSTANT_VALUE"

# Tag paired with the JSON encoding of a constant composite tensor in
# `Node.serialize`. A dict (rather than a plain string) is used so the tag
# cannot be confused with a constant string tensor.
_COMPOSITE_TYPE = {
    "_TYPE": "COMPOSITE",
}
|
|
|
|
|
|
class Node:
    """A `Node` describes a layer `__call__()` event.

    A Functional model is a DAG with `Node` instances as nodes, and
    `KerasTensor` instances as edges. Nodes aren't `Layer` instances, because a
    single layer could be called multiple times, which would result in graph
    cycles.

    A `__call__()` event involves input tensors (and other input arguments),
    the layer that was called, and the resulting output tensors.
    A `Node` will include all this information.

    Since a single `Layer` could be called multiple times, the `Node` instances
    are stored on layers as a list. Each time a layer is called a node is added
    to `layer._inbound_nodes`. Each time the output of a layer is used by
    another layer, a node is added to `layer._outbound_nodes`.

    Every `KerasTensor` instance has a `KerasHistory` object attached,
    which tracks the `Node` that records the `__call__()` event that created
    the tensor. By recursively walking through `Node` instances
    via the `KerasHistory` metadata of `KerasTensor` instances, one can
    retrieve the entire DAG of a Functional model.

    Args:
        layer: The layer that was called in the `Layer.__call__()`
          event that this node represents.
        call_args: The positional arguments the layer was called with.
        call_kwargs: The keyword arguments the layer was called with.
        outputs: The output tensors of the `Layer.__call__()`
    """

    def __init__(self, layer, call_args=None, call_kwargs=None, outputs=None):
        """Records the call and wires the node into the surrounding layers.

        Besides storing the arguments, construction has side effects: the node
        appends itself to `layer._inbound_nodes`, appends itself to the
        `_outbound_nodes` of every layer that produced one of its Keras
        inputs, and stamps a `KerasHistory` onto every flattened output.
        """
        call_args = [] if call_args is None else call_args
        call_kwargs = {} if call_kwargs is None else call_kwargs
        outputs = [] if outputs is None else outputs

        self.layer = layer
        # A node created without any call arguments is treated as an input
        # node.
        self.is_input = not call_args and not call_kwargs

        # These arguments are user-provided. Copy the structures here so that
        # future user modifications do not affect the node's metadata.
        # We copy using map_structure rather than python's shallow or deep copy,
        # because the args can be data structures (so shallow copy is
        # insufficient), but individual values might not support copy.copy
        # or be too expensive to deep copy.
        call_args = tf.nest.map_structure(lambda t: t, call_args)
        call_kwargs = tf.nest.map_structure(lambda t: t, call_kwargs)
        self.outputs = tf.nest.map_structure(lambda t: t, outputs)
        self.call_args = call_args
        self.call_kwargs = call_kwargs

        # Cached for performance.
        self._flat_arguments = tf.nest.flatten(
            (self.call_args, self.call_kwargs)
        )
        # Used to avoid expensive `nest` operations in the most common case.
        self._single_positional_tensor_passed = (
            not self.call_kwargs
            and len(self.call_args) == 1
            and tf.is_tensor(self.call_args[0])
        )

        if not tf.compat.v1.executing_eagerly_outside_functions():
            # Create TensorFlowOpLayers if needed (in TF1)
            for obj in self._flat_arguments:
                if isinstance(
                    obj, tf.Tensor
                ) and base_layer_utils.needs_keras_history(
                    obj, ignore_call_context=True
                ):
                    base_layer_utils.create_keras_history(obj)

        # Collect the Keras tensors among the flattened arguments, together
        # with (id, position) pairs so `map_arguments` can substitute them.
        self._keras_inputs = []
        self._keras_inputs_ids_and_indices = []
        for i, ele in enumerate(self._flat_arguments):
            if is_keras_tensor(ele):
                self._keras_inputs.append(ele)
                kt_id = str(id(ele))
                kt_index = i
                self._keras_inputs_ids_and_indices.append((kt_id, kt_index))

        # Wire up Node to Layers.
        self.layer._inbound_nodes.append(self)
        for kt in self.keras_inputs:
            inbound_layer = kt._keras_history.layer
            if inbound_layer is not None:  # `None` for `Input` tensors.
                inbound_layer._outbound_nodes.append(self)

        # Set metadata on outputs. This node was just appended, so its index
        # is the last position of `layer._inbound_nodes`.
        node_index = len(self.layer._inbound_nodes) - 1
        for i, tensor in enumerate(tf.nest.flatten(outputs)):
            tensor._keras_history = KerasHistory(
                layer=layer, node_index=node_index, tensor_index=i
            )

        # Cached for performance.
        self.flat_input_ids = [str(id(t)) for t in self._keras_inputs]
        self.flat_output_ids = [
            str(id(t)) for t in tf.nest.flatten(self.outputs)
        ]

    @property
    def keras_inputs(self):
        """Tensors input to this node that can be traced back to a
        `keras.Input`."""
        return self._keras_inputs

    @property
    def parent_nodes(self):
        """Returns all the `Node`s whose output this node immediately depends
        on."""
        node_deps = []
        for kt in self.keras_inputs:
            layer = kt._keras_history.layer
            node_index = kt._keras_history.node_index
            if layer is not None:  # `None` for `Input` tensors.
                node_deps.append(layer._inbound_nodes[node_index])
        return node_deps

    def iterate_inbound(self):
        """Yields tuples representing the data inbound from other nodes.

        Yields:
            tuples like: (inbound_layer, node_index, tensor_index, tensor).
        """
        for kt in self.keras_inputs:
            keras_history = kt._keras_history
            layer = keras_history.layer
            node_index = keras_history.node_index
            tensor_index = keras_history.tensor_index
            yield layer, node_index, tensor_index, kt

    def map_arguments(self, tensor_dict):
        """Maps Keras Tensors to computed Tensors using `tensor_dict`.

        Args:
            tensor_dict: Mapping from `str(id(keras_tensor))` to a container
              whose `pop()` returns the computed value for that tensor. One
              entry is popped for every Keras input of this node, so the
              containers are mutated.

        Returns:
            An `(args, kwargs)` pair shaped like the original call arguments,
            with every Keras tensor replaced by its popped value.
        """
        if self._single_positional_tensor_passed:
            # Performance optimization for most common case.
            kt_id, _ = self._keras_inputs_ids_and_indices[0]
            return (tensor_dict[kt_id].pop(),), {}
        else:
            # Work on a copy so the cached flat argument list stays intact.
            flat_arguments = copy.copy(self._flat_arguments)
            for kt_id, kt_index in self._keras_inputs_ids_and_indices:
                flat_arguments[kt_index] = tensor_dict[kt_id].pop()

            args, kwargs = tf.nest.pack_sequence_as(
                (self.call_args, self.call_kwargs), flat_arguments
            )
            return args, kwargs

    def serialize(self, make_node_key, node_conversion_map):
        """Serializes `Node` for Functional API's `get_config`.

        Args:
            make_node_key: Callable invoked as
              `make_node_key(layer_name, node_index)` to build the key that is
              looked up in `node_conversion_map`.
            node_conversion_map: Mapping from node key to the node index to
              write into the config. Keys that are missing map to `0`.

        Returns:
            The serialized first call argument (each entry carrying the
            serialized remaining arguments), after
            `tf_utils.convert_inner_node_data`.

        Raises:
            TypeError: If the non-first arguments are not JSON-serializable.
        """
        # Serialization still special-cases first argument.
        args, kwargs = self.call_args, self.call_kwargs
        inputs, args, kwargs = self.layer._call_spec.split_out_first_arg(
            args, kwargs
        )

        # Treat everything other than first argument as a kwarg.
        arguments = dict(zip(self.layer._call_spec.arg_names[1:], args))
        arguments.update(kwargs)
        kwargs = arguments

        def _history_reference(t):
            """Builds `[layer_name, node_index, tensor_index]` for `t`."""
            kh = t._keras_history
            node_key = make_node_key(kh.layer.name, kh.node_index)
            new_node_index = node_conversion_map.get(node_key, 0)
            return [kh.layer.name, new_node_index, kh.tensor_index]

        def _serialize_keras_tensor(t):
            """Serializes a single Tensor passed to `call`."""
            if is_keras_tensor(t):
                return _history_reference(t)

            if isinstance(t, np.ndarray):
                return t.tolist()

            if isinstance(t, tf.Tensor):
                return backend.get_value(t).tolist()

            # Not using json_utils to serialize both constant Tensor and
            # constant CompositeTensor for saving format backward compatibility.
            if isinstance(t, tf.__internal__.CompositeTensor):
                return (_COMPOSITE_TYPE, json_utils.Encoder().encode(t))

            return t

        kwargs = tf.nest.map_structure(_serialize_keras_tensor, kwargs)
        try:
            # Dry run only: the encoded string is discarded, we just want to
            # fail early if something cannot be written out.
            json.dumps(kwargs, default=json_utils.get_json_type)
        except TypeError as err:
            kwarg_types = tf.nest.map_structure(type, kwargs)
            raise TypeError(
                "Layer "
                + self.layer.name
                + " was passed non-JSON-serializable arguments. "
                + "Arguments had types: "
                + str(kwarg_types)
                + ". They cannot be serialized out when saving the model."
            ) from err

        # `kwargs` is added to each Tensor in the first arg. This should be
        # changed in a future version of the serialization format.
        def serialize_first_arg_tensor(t):
            if is_keras_tensor(t):
                data = _history_reference(t) + [kwargs]
            else:
                # If an element in the first call argument did not originate as
                # a keras tensor and is a constant value, we save it using the
                # format ['_CONSTANT_VALUE', -1,
                # serialized_tensor_or_python_constant] (potentially including
                # serialized kwargs in an optional 4th argument).
                data = [_CONSTANT_VALUE, -1, _serialize_keras_tensor(t), kwargs]
            return tf_utils.ListWrapper(data)

        data = tf.nest.map_structure(serialize_first_arg_tensor, inputs)
        if (
            not tf.nest.is_nested(data)
            and not self.layer._preserve_input_structure_in_config
        ):
            data = [data]
        data = tf_utils.convert_inner_node_data(data)
        return data

    #############################################################
    # Properties for Backwards compatibility.
    # These only check the first input argument
    # As nodes are internal, they may be removed in the future.
    #############################################################

    @property
    def input_tensors(self):
        """First call argument, or `[outputs]` for an input node."""
        if self.is_input:
            return [self.outputs]  # Used in `Layer.input`.
        return self.call_args[0]

    @property
    def output_tensors(self):
        """The node outputs, wrapped in a list for an input node."""
        if self.is_input:
            # Same list wrapping as `input_tensors` uses for input nodes.
            return [self.outputs]
        return self.outputs

    @property
    def input_shapes(self):
        """Static shapes of `input_tensors`, in the same structure.

        For a non-input node whose first argument is a one-element list or
        tuple, the single shape is returned unwrapped.
        """
        inputs = self.input_tensors
        input_shapes = tf.nest.map_structure(backend.int_shape, inputs)
        # Decide on the *inputs* container, not on the mapped result: for a
        # single tensor the result is the shape tuple itself, so testing its
        # length would strip a rank-1 shape down to its only dimension, and a
        # single-key dict cannot be indexed with `[0]` at all.
        if (
            not self.is_input
            and isinstance(inputs, (list, tuple))
            and len(inputs) == 1
        ):
            return input_shapes[0]
        return input_shapes

    @property
    def output_shapes(self):
        """Static shapes of `output_tensors`, in the same structure."""
        return tf.nest.map_structure(backend.int_shape, self.output_tensors)

    @property
    def outbound_layer(self):
        """The layer whose call this node records (alias of `layer`)."""
        return self.layer

    @property
    def inbound_layers(self):
        """Return all layers that feed into the current node."""
        if self.is_input:
            return []
        tensor_call_args = [
            x
            for x in self._flat_arguments
            if tf.is_tensor(x) and hasattr(x, "_keras_history")
        ]
        inbound_layers = tf.nest.map_structure(
            lambda t: t._keras_history.layer, tensor_call_args
        )
        # A single feeding layer is returned bare rather than in a list.
        if len(inbound_layers) == 1:
            return inbound_layers[0]
        return inbound_layers
|
|
|
|
|
|
class KerasHistory(
    collections.namedtuple(
        "KerasHistory",
        ("layer", "node_index", "tensor_index"),
    )
):
    """Records which Layer call produced a Tensor in a Keras Graph Network.

    While a Keras Graph Network is being built, every Tensor that comes out
    of a Layer (beginning with an `InputLayer`) receives one of these records.
    The records let Keras trace how each Tensor came to be; the
    `keras.engine.Network` class later follows them back to rebuild the
    Keras Graph Network.

    Attributes:
      layer: The Layer that produced the Tensor.
      node_index: Which call of the Layer produced the Tensor. A Layer may be
        called several times to share weights, and each call creates a new
        node; the node for this call lives at
        `layer._inbound_nodes[node_index]`.
      tensor_index: Position of the Tensor among the outputs of that call.
        It is always zero when the Layer has a single output. For nested
        output structures the index is assigned deterministically via
        `nest.flatten`.
    """

    # An empty `__slots__` keeps the subclass as lean as the underlying
    # `namedtuple`: instances get no per-instance `__dict__`.
    __slots__ = ()
|
|
|
|
|
|
def is_keras_tensor(obj):
    """Returns whether `obj` carries a `_keras_history` attribute.

    Only the presence of the attribute is checked; its value (even `None`)
    does not matter.
    """
    missing = object()
    return getattr(obj, "_keras_history", missing) is not missing
|