Intelegentny_Pszczelarz/.venv/Lib/site-packages/tensorflow/python/framework/graph_to_function_def.py
2023-06-19 00:49:18 +02:00

184 lines
6.5 KiB
Python

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Utility to convert a Graph to a FunctionDef."""
import re
from tensorflow.core.framework import function_pb2
from tensorflow.core.framework import op_def_pb2
from tensorflow.python.framework import op_def_registry
def _make_argname_from_tensor_name(name):
return re.sub(":0$", "", name).replace(":", "_o")
def _tensor_to_argdef(t, name=None, used_names=None):
  """Builds an `OpDef.ArgDef` describing tensor `t`.

  Args:
    t: Tensor-like object; must expose `dtype.as_datatype_enum`, and also
      `name` when `name` is None.
    name: Explicit argument name. When given it is used verbatim, and
      `used_names` is neither consulted nor updated.
    used_names: Optional set of names already taken. When `name` is None, the
      name derived from `t.name` gets a "_U<i>" suffix (smallest free i) if it
      collides, and the chosen name is added to this set.

  Returns:
    The ArgDef with its `name` and `type` fields populated.
  """
  arg = op_def_pb2.OpDef.ArgDef()
  if name is not None:
    arg.name = name
  else:
    base_name = _make_argname_from_tensor_name(t.name)
    chosen_name = base_name
    if used_names is not None:
      suffix = 0
      # Probe base_U0, base_U1, ... until an unused name is found.
      while chosen_name in used_names:
        chosen_name = "%s_U%d" % (base_name, suffix)
        suffix += 1
    arg.name = chosen_name
    if used_names is not None:
      used_names.add(arg.name)
  arg.type = t.dtype.as_datatype_enum
  return arg
def _is_in_placeholders(op, func_arg_placeholders):
"""Checks whether any output of this op is in func_arg_placeholders."""
return op.values() and any(x.name in func_arg_placeholders
for x in op.values())
def _get_node_def(op):
return op.node_def # pylint: disable=protected-access
def _get_op_def(op):
return op.op_def or op_def_registry.get(op.type)
def _create_input_dict(function_graph,
                       func_arg_placeholders,
                       initial_value=None):
  """Create a mapping from graph tensor names to function tensor names.

  Every op returned by `function_graph.get_operations()` is visited:
    * an op with an output in `func_arg_placeholders` maps its op name to
      itself;
    * for any other op, each output tensor name is mapped to
      "<op name>:<output arg name>:<index within that arg>", and the bare op
      name is mapped to the same string as the op's first output.

  Args:
    function_graph: Graph whose operations are walked.
    func_arg_placeholders: Names of tensors that are function arguments.
    initial_value: Optional mapping copied into the result before the walk
      (e.g. argument tensor names -> argument names). Entries may be
      overwritten by the walk.

  Returns:
    A new dict; `initial_value` itself is not modified.
  """
  input_dict = dict(initial_value) if initial_value is not None else {}
  for op in function_graph.get_operations():
    if _is_in_placeholders(op, func_arg_placeholders):
      input_dict[op.name] = op.name
      continue

    # NOTE(review): assumes _get_op_def returns an OpDef for every
    # non-placeholder op; a None here would raise AttributeError below.
    op_def = _get_op_def(op)
    attrs = _get_node_def(op).attr
    flat_index = 0  # Position in op.values() across all output args.
    for arg_def in op_def.output_arg:
      # An output arg may expand to several tensors, sized by an int attr
      # (number_attr) or by the length of a type-list attr.
      if arg_def.number_attr:
        count = attrs[arg_def.number_attr].i
      elif arg_def.type_list_attr:
        count = len(attrs[arg_def.type_list_attr].list.type)
      else:
        count = 1
      for position in range(count):
        func_name = "%s:%s:%d" % (op.name, arg_def.name, position)
        input_dict[op.values()[flat_index].name] = func_name
        if flat_index == 0:
          # The bare op name resolves to the op's first output.
          input_dict[op.name] = func_name
        flat_index += 1
  return input_dict
def _add_op_node(op, func, input_dict):
  """Appends a copy of `op`'s NodeDef to `func` with its inputs renamed.

  Every input not starting with "^" (i.e. every non-control input) is
  replaced by its entry in `input_dict`; "^"-prefixed inputs are kept as-is.
  If `op.op_def` is present and stateful, the function signature is marked
  stateful.

  Args:
    op: Graph operation to copy into the function body.
    func: FunctionDef proto under construction; mutated in place.
    input_dict: Mapping from graph tensor/op names to function tensor names.

  Raises:
    AssertionError: if a data input has no entry in `input_dict` (when
      asserts are enabled; otherwise the lookup raises KeyError).
  """
  # extend() stores a copy of the message, see:
  # https://developers.google.com/protocol-buffers/docs/reference/python-generated#repeated-message-fields
  # so the rewrites below modify only the copy inside `func`.
  func.node_def.extend([_get_node_def(op)])
  copied = func.node_def[-1]
  for idx, source in enumerate(list(copied.input)):
    if source.startswith("^"):
      continue
    assert source in input_dict, ("%s missing from %s" %
                                  (source, input_dict.items()))
    copied.input[idx] = input_dict[source]

  # The function is stateful if any of its operations are stateful.
  # NOTE(mrry): The "Const" node typically does not have an `OpDef` associated
  # with it, so we assume any nodes without an `OpDef` are stateless.
  # TODO(skyewm): Remove the `is not None` test after we transition to the C
  # API.
  own_op_def = op.op_def
  if own_op_def is not None and own_op_def.is_stateful:
    func.signature.is_stateful = True
def graph_to_function_def(graph, operations, inputs, outputs, out_names=None):
  """Returns `graph` as a `FunctionDef` protocol buffer.

  This method creates a [`FunctionDef`](
  https://www.tensorflow.org/code/tensorflow/core/framework/function.proto)
  protocol buffer that contains all the ops in `operations`.  The
  operations become the body of the function.

  The arguments `inputs` and `outputs` will be listed as the inputs
  and outputs tensors of the function.  They must be lists of
  tensors present in the graph.  The lists can optionally be empty.

  Args:
    graph: Graph.
    operations: the operations to put in the function. Must be a subset of
      the operations in the graph.
    inputs: List of tensors. Inputs to the function.
    outputs: List of tensors. Outputs of the function.
    out_names: Optional list of string names for the outputs. When None,
      output names are derived from the output tensor names and made unique
      among the outputs (independently of the input names).

  Returns:
    A FunctionDef protocol buffer. Its `signature.name` is set to "_"
    (presumably renamed by the caller).

  Raises:
    ValueError: if out_names is specified and the wrong length, or if it
      contains duplicate names.
  """
  func = function_pb2.FunctionDef()
  func.signature.name = "_"

  # Input argument names are derived from tensor names and de-duplicated.
  used_names = set()
  func.signature.input_arg.extend(
      [_tensor_to_argdef(i, used_names=used_names) for i in inputs])

  # Initializes the input map with all placeholder input tensors.
  initial_dict = {
      t.name: arg.name for t, arg in zip(inputs, func.signature.input_arg)
  }

  if out_names is None:
    # Output names are de-duplicated separately from the input names.
    used_names = set()
    func.signature.output_arg.extend(
        [_tensor_to_argdef(o, used_names=used_names) for o in outputs])
  elif len(outputs) != len(out_names):
    raise ValueError(
        "out_names must be either empty or equal in size to outputs. "
        f"len(out_names) = {len(out_names)} len(outputs) = {len(outputs)}")
  elif len(out_names) != len(set(out_names)):
    raise ValueError(
        f"Must not have duplicates in out_names. Received: {out_names}")
  else:
    func.signature.output_arg.extend(
        [_tensor_to_argdef(o, name=n) for o, n in zip(outputs, out_names)])

  func_arg_placeholders = {i.name for i in inputs}
  input_dict = _create_input_dict(graph, func_arg_placeholders,
                                  initial_value=initial_dict)

  for op in operations:
    # Ops producing argument tensors are represented by the signature's
    # input args rather than by body nodes.
    if _is_in_placeholders(op, func_arg_placeholders):
      continue
    _add_op_node(op, func, input_dict)

  # Bind each output arg to the function-internal name of its tensor. When
  # out_names was supplied, the output_arg names are exactly out_names, so a
  # single loop covers both branches above.
  for arg, o in zip(func.signature.output_arg, outputs):
    func.ret[arg.name] = input_dict[o.name]

  return func