Intelegentny_Pszczelarz/.venv/Lib/site-packages/keras/engine/functional_utils.py

261 lines
11 KiB
Python
Raw Normal View History

2023-06-19 00:49:18 +02:00
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for keras functional model."""
import collections

import tensorflow.compat.v2 as tf

from keras import backend
from keras.engine import input_layer as input_layer_module
from keras.engine import keras_tensor
from keras.engine import node as node_module
_KERAS_TENSOR_TYPE_CHECK_ERROR_MSG = (
"Found unexpected instance while processing input tensors for keras "
"functional model. Expecting KerasTensor which is from tf.keras.Input() "
"or output from keras layer call(). Got: {}"
)
def is_input_keras_tensor(tensor):
    """Return whether `tensor` was produced directly by `tf.keras.Input`.

    Functional-model construction relies on this: when a model is built from
    tensors that are not Input tensors, its Nodes and KerasTensors have to be
    cloned.

    Args:
      tensor: A `KerasTensor` passed as an input to the functional model.

    Returns:
      bool. True when the node that produced `tensor` is an input node.

    Raises:
      ValueError: if `tensor` is not a KerasTensor instance.
    """
    if node_module.is_keras_tensor(tensor):
        return tensor.node.is_input
    raise ValueError(_KERAS_TENSOR_TYPE_CHECK_ERROR_MSG.format(tensor))
def find_nodes_by_inputs_and_outputs(inputs, outputs):
    """Fetch all Nodes in the graph defined by "inputs" and "outputs".

    This method is used to find and then clone Nodes when creating a new
    sub-model from an existing functional model.

    Args:
      inputs: A nested structure of KerasTensor to use as model inputs.
      outputs: A nested structure of KerasTensor to use as model outputs.

    Returns:
      A list of Nodes that are connected to the inputs and outputs, in the
      breadth-first order in which they were reached, starting from the nodes
      that produced `outputs`.

    Raises:
      ValueError: when inputs and outputs are disconnected or in case of
        unexpected objects in the inputs/outputs.
    """
    # We walk the graph bottom up, starting from output nodes, and keep tracing
    # the upstream node, until we find all the inputs nodes. We don't use top
    # down search here since we don't know whether a certain node is in the
    # graph between inputs and outputs, e.g. a functional graph could have
    # multiple outputs, and the user could choose a subset of them to build the
    # model. The bottom up approach will ensure all the nodes we visit are
    # actually in use. If we reach the top and didn't find the nodes in the
    # `inputs`, that's an error, since the user didn't specify the correct
    # inputs.
    start_keras_tensors = tf.nest.flatten(outputs)
    end_keras_tensors = tf.nest.flatten(inputs)

    for t in start_keras_tensors + end_keras_tensors:
        if not node_module.is_keras_tensor(t):
            raise ValueError(_KERAS_TENSOR_TYPE_CHECK_ERROR_MSG.format(t))
    end_ids = {id(kt) for kt in end_keras_tensors}
    # Track all the end tensors we found so far, if we didn't reach all the
    # user-specified keras inputs after we finish the search, then that's an
    # error since the inputs are disconnected from the outputs.
    end_ids_found = set()

    # FIFO work queue for the breadth-first walk. A deque gives O(1)
    # popleft(), whereas list.pop(0) shifts the whole list on every pop.
    nodes_to_visit = collections.deque(t.node for t in start_keras_tensors)
    nodes_in_graph = []
    node_id_visited = set()

    while nodes_to_visit:
        node = nodes_to_visit.popleft()
        if id(node) in node_id_visited:
            continue
        node_id_visited.add(id(node))
        nodes_in_graph.append(node)
        # Any input keras_tensor that produce the current node.
        for kt in node.keras_inputs:
            if id(kt) in end_ids:
                # We found the inputs of the model, stop tracing upstream nodes
                end_ids_found.add(id(kt))
                continue

            inbound_node = kt.node
            # In case this is the tf.keras.Input node, we have reached the end
            # of the tracing of upstream nodes. Any further tracing will just be
            # an infinite loop. We should raise an error here since we didn't
            # find the input in the user-specified inputs.
            if inbound_node.is_input:
                raise ValueError(
                    "Found input tensor cannot be reached given provided "
                    "output tensors. Please make sure the tensor {} is "
                    "included in the model inputs when building "
                    "functional model.".format(kt)
                )
            nodes_to_visit.append(inbound_node)

    # Do a final check and make sure we have reached all the user-specified
    # inputs
    if end_ids != end_ids_found:
        unvisited_inputs = [
            kt for kt in end_keras_tensors if id(kt) not in end_ids_found
        ]
        raise ValueError(
            "Found unvisited input tensors that are disconnected from "
            "the outputs: {}".format(unvisited_inputs)
        )
    return nodes_in_graph
def clone_graph_nodes(inputs, outputs):
    """Clone the `Node`s that lie between the `inputs` and `outputs` tensors.

    Used to build a new functional model from arbitrary (possibly
    intermediate) KerasTensors. Cloning the Nodes mimics re-executing every
    layer `__call__` of the functional graph, so each cloned Node is attached
    to its layer just as a real call would attach one.

    Every item of `inputs` that was not produced by `tf.keras.Input` is
    replaced by a newly created `tf.keras.Input`.

    Args:
      inputs: A nested structure of KerasTensors.
      outputs: A nested structure of KerasTensors.

    Returns:
      A pair `(inputs, outputs)` with the same nested structures as the
      arguments, holding the cloned KerasTensors. They can be used to create
      a new functional model.

    Raises:
      ValueError: propagated from `find_nodes_by_inputs_and_outputs` when the
        inputs/outputs are invalid or disconnected.
    """
    nodes_to_clone = find_nodes_by_inputs_and_outputs(inputs, outputs)

    # Besides copying the Nodes (mimicking the layer calls), the KerasTensors
    # are copied too, so that wiring up the new Nodes does not overwrite the
    # `_keras_history` of the original tensors. Keys are the `id()` of an
    # original KerasTensor; values are the tensor replacing it in the clone.
    replacements = {}

    flat_inputs = []
    for source in tf.nest.flatten(inputs):
        if source.node.is_input:
            # Tensors coming straight from tf.keras.Input are reused as is.
            substitute = source
        else:
            # An intermediate tensor gets a brand-new tf.keras.Input built on
            # top of a copy of it.
            substitute = input_layer_module.Input(
                tensor=_clone_keras_tensor(source)
            )
        flat_inputs.append(substitute)
        replacements[id(source)] = substitute

    flat_outputs = []
    for source in tf.nest.flatten(outputs):
        substitute = _clone_keras_tensor(source)
        # Borrow the original (stale) history for now: the Node constructor
        # uses it to recognize the copy as a keras tensor, and the Node built
        # below for the producing layer overrides it.
        substitute._keras_history = source._keras_history
        flat_outputs.append(substitute)
        replacements[id(source)] = substitute

    for node in nodes_to_clone:
        # Copy this node's tensors, reusing any copy already recorded in
        # `replacements` so tensors shared between nodes stay shared.
        outputs_copy = clone_keras_tensors(node.output_tensors, replacements)
        args_copy = clone_keras_tensors(node.call_args, replacements)
        kwargs_copy = clone_keras_tensors(node.call_kwargs, replacements)
        # Building the Node wires it into the graph: the constructor updates
        # the layer's `_inbound_nodes`, sets `_keras_history` on the outputs,
        # and adds itself to the `_outbound_nodes` of the layers that produced
        # its inputs.
        node_module.Node(
            node.layer,
            call_args=args_copy,
            call_kwargs=kwargs_copy,
            outputs=outputs_copy,
        )

    return (
        tf.nest.pack_sequence_as(inputs, flat_inputs),
        tf.nest.pack_sequence_as(outputs, flat_outputs),
    )
def clone_keras_tensors(args, keras_tensor_mapping):
    """Replace every KerasTensor in `args` with its clone.

    A KerasTensor is copied only the first time it is seen. The copy is
    recorded in `keras_tensor_mapping`, and later occurrences reuse it. Every
    other object in `args` is passed through untouched. Cloning is needed
    when copying Nodes because a KerasTensor can't be shared across models.

    Args:
      args: A nested structure of objects that may contain KerasTensors.
      keras_tensor_mapping: A dict from `id()` of an original KerasTensor to
        its cloned KerasTensor. Newly created copies are added to it in place.

    Returns:
      A structure shaped like `args`, with the KerasTensors swapped for their
      clones.
    """

    def _lookup_or_clone(item):
        if not node_module.is_keras_tensor(item):
            return item
        key = id(item)
        if key not in keras_tensor_mapping:
            # First time we meet this tensor: copy it, but keep the original
            # history attached to the copy.
            fresh = _clone_keras_tensor(item)
            fresh._keras_history = item._keras_history
            keras_tensor_mapping[key] = fresh
        return keras_tensor_mapping[key]

    swapped = [_lookup_or_clone(item) for item in tf.nest.flatten(args)]
    return tf.nest.pack_sequence_as(args, swapped)
def _clone_keras_tensor(kt):
    """Return a new KerasTensor equivalent to `kt`.

    The copy round-trips through a placeholder: `keras_tensor_to_placeholder`
    followed by `keras_tensor_from_tensor`. This keeps the inferred shape from
    being lost during the copy.

    Args:
      kt: The KerasTensor to copy.

    Returns:
      A new KerasTensor rebuilt from a placeholder version of `kt`.
    """
    # The placeholder is only a conversion intermediate, so it is created in a
    # throwaway scratch graph.
    with backend._scratch_graph() as graph, graph.as_default():
        stand_in = keras_tensor.keras_tensor_to_placeholder(kt)
        duplicate = keras_tensor.keras_tensor_from_tensor(stand_in)
    return duplicate