261 lines
11 KiB
Python
261 lines
11 KiB
Python
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Utilities for keras functional model."""
|
|
|
|
import tensorflow.compat.v2 as tf
|
|
|
|
from keras import backend
|
|
from keras.engine import input_layer as input_layer_module
|
|
from keras.engine import keras_tensor
|
|
from keras.engine import node as node_module
|
|
|
|
_KERAS_TENSOR_TYPE_CHECK_ERROR_MSG = (
|
|
"Found unexpected instance while processing input tensors for keras "
|
|
"functional model. Expecting KerasTensor which is from tf.keras.Input() "
|
|
"or output from keras layer call(). Got: {}"
|
|
)
|
|
|
|
|
|
def is_input_keras_tensor(tensor):
    """Return whether `tensor` was produced directly by `tf.keras.Input`.

    This check is useful while constructing a functional model: when the
    model is being built from tensors that are not direct inputs, the
    Nodes and KerasTensors have to be cloned first.

    Args:
      tensor: A `KerasTensor` used as an input to the functional model.

    Returns:
      bool. Whether the tensor comes directly from `tf.keras.Input`.

    Raises:
      ValueError: if the tensor is not a KerasTensor instance.
    """
    if node_module.is_keras_tensor(tensor):
        return tensor.node.is_input
    raise ValueError(_KERAS_TENSOR_TYPE_CHECK_ERROR_MSG.format(tensor))
|
|
|
|
|
|
def find_nodes_by_inputs_and_outputs(inputs, outputs):
    """Fetch all Nodes in the graph defined by "inputs" and "outputs".

    This method is used to find and then clone Nodes when creating a new
    sub-model from an existing functional model.

    Args:
      inputs: A nested structure of KerasTensor to use as model inputs.
      outputs: A nested structure of KerasTensor to use as model outputs.

    Returns:
      A list of Nodes that are connected to the inputs and outputs.

    Raises:
      ValueError: when inputs and outputs are disconnected or in case of
        unexpected objects in the inputs/outputs.
    """
    # The graph is walked bottom up: start from the output nodes and keep
    # tracing upstream until every input node is found. A top-down search
    # is not usable here because it is unknown up front whether a given
    # node lies on a path between inputs and outputs -- e.g. a functional
    # graph may have many outputs of which the user picked only a subset.
    # Walking bottom up guarantees every visited node is actually in use.
    # Reaching the top without encountering the user-specified `inputs`
    # is an error: the wrong inputs were given.
    start_keras_tensors = tf.nest.flatten(outputs)
    end_keras_tensors = tf.nest.flatten(inputs)

    for candidate in start_keras_tensors + end_keras_tensors:
        if not node_module.is_keras_tensor(candidate):
            raise ValueError(
                _KERAS_TENSOR_TYPE_CHECK_ERROR_MSG.format(candidate)
            )

    end_ids = {id(kt) for kt in end_keras_tensors}
    # IDs of user-specified input tensors encountered so far. If any are
    # still missing when the search finishes, the inputs are disconnected
    # from the outputs, which is an error.
    end_ids_found = set()

    nodes_in_graph = []
    visited_node_ids = set()
    # FIFO queue seeded with the nodes that produced the output tensors.
    pending = [kt.node for kt in start_keras_tensors]

    while pending:
        node = pending.pop(0)
        node_id = id(node)
        if node_id in visited_node_ids:
            continue
        visited_node_ids.add(node_id)
        nodes_in_graph.append(node)
        # Inspect every keras_tensor that feeds the current node.
        for kt in node.keras_inputs:
            if id(kt) in end_ids:
                # Reached one of the model inputs; stop tracing upstream.
                end_ids_found.add(id(kt))
                continue

            inbound_node = kt.node
            # A tf.keras.Input node marks the end of the upstream trace;
            # going further would loop forever. Since this input was not
            # among the user-specified inputs, raise an error.
            if inbound_node.is_input:
                raise ValueError(
                    "Found input tensor cannot be reached given provided "
                    "output tensors. Please make sure the tensor {} is "
                    "included in the model inputs when building "
                    "functional model.".format(kt)
                )
            pending.append(inbound_node)

    # Final sanity check: every user-specified input must have been seen.
    if end_ids != end_ids_found:
        unvisited_inputs = [
            kt for kt in end_keras_tensors if id(kt) not in end_ids_found
        ]
        raise ValueError(
            "Found unvisited input tensors that are disconnected from "
            "the outputs: {}".format(unvisited_inputs)
        )
    return nodes_in_graph
|
|
|
|
|
|
def clone_graph_nodes(inputs, outputs):
    """Clone the `Node` between the inputs and output tensors.

    This function is used to create a new functional model from any
    intermediate keras tensors. The cloned nodes mimic the behavior of
    reconstructing the functional graph network by re-executing all the
    __call__ methods; they are appended onto the layers.

    Note that a new tf.keras.Inputs will be created for any items in the
    `inputs`

    Args:
      inputs: A nested structure of keras_tensors.
      outputs: A nested structure of keras_tensors.

    Returns:
      A pair of inputs and outputs, with cloned keras_tensors. They can be
      used to create a new functional model.
    """
    nodes_to_clone = find_nodes_by_inputs_and_outputs(inputs, outputs)
    # Besides copying the Nodes (which mimics the layer calls), the
    # keras_tensors themselves must be cloned so the _keras_history
    # attached to the originals is not overridden. This dict maps the id()
    # of each original keras tensor to its cloned instance.
    kt_id_mapping = {}

    cloned_inputs = []
    for kt_input in tf.nest.flatten(inputs):
        if not kt_input.node.is_input:
            # Intermediate keras_tensors need a fresh tf.keras.Input.
            cloned_input = input_layer_module.Input(
                tensor=_clone_keras_tensor(kt_input)
            )
        else:
            # Tensors that already come from tf.keras.Input are kept as is.
            cloned_input = kt_input
        cloned_inputs.append(cloned_input)
        kt_id_mapping[id(kt_input)] = cloned_input
    cloned_inputs = tf.nest.pack_sequence_as(inputs, cloned_inputs)

    cloned_outputs = []
    for kt_output in tf.nest.flatten(outputs):
        cpy = _clone_keras_tensor(kt_output)
        # The old _keras_history is reused here: the Node constructor needs
        # it to pass the "is_keras_tensor()" check, and it will be
        # overridden by the Node constructor for the corresponding layer
        # output anyway.
        cpy._keras_history = kt_output._keras_history
        cloned_outputs.append(cpy)
        kt_id_mapping[id(kt_output)] = cpy
    cloned_outputs = tf.nest.pack_sequence_as(outputs, cloned_outputs)

    for node in nodes_to_clone:
        # Clone any keras_tensors to avoid overriding _keras_history, or
        # reuse an existing clone if one was already made.
        outputs_copy = clone_keras_tensors(node.output_tensors, kt_id_mapping)
        args_copy = clone_keras_tensors(node.call_args, kt_id_mapping)
        kwargs_copy = clone_keras_tensors(node.call_kwargs, kt_id_mapping)
        # Create a new node from the existing node's information. The Node
        # constructor wires itself up: it updates the layer's
        # self._inbound_nodes, sets _keras_history on the outputs, and adds
        # itself to the `_outbound_nodes` of the layers that produced the
        # inputs to this layer call.
        node_module.Node(
            node.layer,
            call_args=args_copy,
            call_kwargs=kwargs_copy,
            outputs=outputs_copy,
        )
    return cloned_inputs, cloned_outputs
|
|
|
|
|
|
def clone_keras_tensors(args, keras_tensor_mapping):
    """Clone the keras tensors from the inputs.

    For any KerasTensor instance in `args`, a new copy is created unless
    one already exists in `keras_tensor_mapping`. Objects of any other
    type are passed through unchanged. This is useful when cloning Nodes,
    since a KerasTensor cannot be reused across models.

    Args:
      args: A nested structure of objects, which could contain KerasTensor.
      keras_tensor_mapping: A dict contains the ID of original KerasTensor,
        and the cloned KerasTensor instance. The dict will be updated with
        newly copied KerasTensor instances within this method.

    Returns:
      Same structure as inputs, with KerasTensor cloned.
    """
    flat_result = []
    for item in tf.nest.flatten(args):
        if not node_module.is_keras_tensor(item):
            flat_result.append(item)
            continue
        # Memoized clone: only copy a keras_tensor the first time it is
        # seen (mapping values are KerasTensors, never None).
        cpy = keras_tensor_mapping.get(id(item))
        if cpy is None:
            cpy = _clone_keras_tensor(item)
            cpy._keras_history = item._keras_history
            keras_tensor_mapping[id(item)] = cpy
        flat_result.append(cpy)
    return tf.nest.pack_sequence_as(args, flat_result)
|
|
|
|
|
|
def _clone_keras_tensor(kt):
    """Create an identical keras_tensor based on the input.

    The round trip through keras_tensor_to_placeholder and
    keras_tensor_from_tensor preserves any inferred shape on the copy.

    Args:
      kt: the input KerasTensor.

    Returns:
      An identical copy of the input KerasTensor.
    """
    # The placeholder is only a vehicle for the copy, so it is created in a
    # throwaway scratch graph.
    with backend._scratch_graph() as scratch_graph, scratch_graph.as_default():
        placeholder = keras_tensor.keras_tensor_to_placeholder(kt)
        return keras_tensor.keras_tensor_from_tensor(placeholder)
|