184 lines
6.5 KiB
Python
184 lines
6.5 KiB
Python
![]() |
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# =============================================================================
|
||
|
"""Utility to convert a Graph to a FunctionDef."""
|
||
|
|
||
|
import re
|
||
|
|
||
|
from tensorflow.core.framework import function_pb2
|
||
|
from tensorflow.core.framework import op_def_pb2
|
||
|
from tensorflow.python.framework import op_def_registry
|
||
|
|
||
|
|
||
|
def _make_argname_from_tensor_name(name):
|
||
|
return re.sub(":0$", "", name).replace(":", "_o")
|
||
|
|
||
|
|
||
|
def _tensor_to_argdef(t, name=None, used_names=None):
|
||
|
"""Convert tensor t to an argdef, with a specified name or a unique name."""
|
||
|
arg = op_def_pb2.OpDef.ArgDef()
|
||
|
if name is None:
|
||
|
arg.name = _make_argname_from_tensor_name(t.name)
|
||
|
if used_names is not None:
|
||
|
if arg.name in used_names:
|
||
|
i = 0
|
||
|
while True:
|
||
|
new_name = "%s_U%d" % (arg.name, i)
|
||
|
if new_name not in used_names:
|
||
|
arg.name = new_name
|
||
|
break
|
||
|
i += 1
|
||
|
used_names.add(arg.name)
|
||
|
else:
|
||
|
arg.name = name
|
||
|
arg.type = t.dtype.as_datatype_enum
|
||
|
return arg
|
||
|
|
||
|
|
||
|
def _is_in_placeholders(op, func_arg_placeholders):
|
||
|
"""Checks whether any output of this op is in func_arg_placeholders."""
|
||
|
return op.values() and any(x.name in func_arg_placeholders
|
||
|
for x in op.values())
|
||
|
|
||
|
|
||
|
def _get_node_def(op):
|
||
|
return op.node_def # pylint: disable=protected-access
|
||
|
|
||
|
|
||
|
def _get_op_def(op):
|
||
|
return op.op_def or op_def_registry.get(op.type)
|
||
|
|
||
|
|
||
|
def _create_input_dict(function_graph,
|
||
|
func_arg_placeholders,
|
||
|
initial_value=None):
|
||
|
"""Create a mapping from graph tensor names to function tensor names."""
|
||
|
if initial_value is None:
|
||
|
input_dict = {}
|
||
|
else:
|
||
|
input_dict = dict(initial_value)
|
||
|
for op in function_graph.get_operations():
|
||
|
if _is_in_placeholders(op, func_arg_placeholders):
|
||
|
input_dict[op.name] = op.name
|
||
|
else:
|
||
|
op_def = _get_op_def(op)
|
||
|
attrs = _get_node_def(op).attr
|
||
|
o = 0
|
||
|
for arg_def in op_def.output_arg:
|
||
|
if arg_def.number_attr:
|
||
|
num = attrs[arg_def.number_attr].i
|
||
|
elif arg_def.type_list_attr:
|
||
|
num = len(attrs[arg_def.type_list_attr].list.type)
|
||
|
else:
|
||
|
num = 1
|
||
|
for i in range(num):
|
||
|
result = "%s:%s:%d" % (op.name, arg_def.name, i)
|
||
|
input_dict[op.values()[o].name] = result
|
||
|
if o == 0:
|
||
|
input_dict[op.name] = result
|
||
|
o += 1
|
||
|
return input_dict
|
||
|
|
||
|
|
||
|
def _add_op_node(op, func, input_dict):
|
||
|
"""Converts an op to a function def node and add it to `func`."""
|
||
|
# Add an entry in func.node_def
|
||
|
|
||
|
# Note that extend() makes a copy in this case, see:
|
||
|
# https://developers.google.com/protocol-buffers/docs/reference/python-generated#repeated-message-fields
|
||
|
func.node_def.extend([_get_node_def(op)])
|
||
|
node_def = func.node_def[-1]
|
||
|
for i in range(len(node_def.input)):
|
||
|
if not node_def.input[i].startswith("^"):
|
||
|
assert node_def.input[i] in input_dict, ("%s missing from %s" %
|
||
|
(node_def.input[i],
|
||
|
input_dict.items()))
|
||
|
node_def.input[i] = input_dict[node_def.input[i]]
|
||
|
# The function is stateful if any of its operations are stateful.
|
||
|
# NOTE(mrry): The "Const" node typically does not have an `OpDef` associated
|
||
|
# with it, so we assume any nodes without an `OpDef` are stateless.
|
||
|
# TODO(skyewm): Remove the `is not None` test after we transition to the C
|
||
|
# API.
|
||
|
if op.op_def is not None and op.op_def.is_stateful:
|
||
|
func.signature.is_stateful = True
|
||
|
|
||
|
|
||
|
def graph_to_function_def(graph, operations, inputs, outputs, out_names=None):
|
||
|
"""Returns `graph` as a `FunctionDef` protocol buffer.
|
||
|
|
||
|
This method creates a [`FunctionDef`](
|
||
|
https://www.tensorflow.org/code/tensorflow/core/framework/function.proto)
|
||
|
protocol buffer that contains all the ops in `operations`. The
|
||
|
operations become the body of the function.
|
||
|
|
||
|
The arguments `inputs` and `outputs` will be listed as the inputs
|
||
|
and outputs tensors of the function. They must be lists of
|
||
|
tensors present in the graph. The lists can optionally be empty.
|
||
|
|
||
|
Args:
|
||
|
graph: Graph.
|
||
|
operations: the operations to put in the function. Must be a subset of
|
||
|
the operations in the graph.
|
||
|
inputs: List of tensors. Inputs to the function.
|
||
|
outputs: List of tensors. Outputs of the function.
|
||
|
out_names: Optional list of string names for the outputs.
|
||
|
|
||
|
Returns:
|
||
|
A FunctionDef protocol buffer.
|
||
|
|
||
|
Raises:
|
||
|
ValueError: if out_names is specified and the wrong length.
|
||
|
"""
|
||
|
func = function_pb2.FunctionDef()
|
||
|
func.signature.name = "_"
|
||
|
used_names = set()
|
||
|
func.signature.input_arg.extend(
|
||
|
[_tensor_to_argdef(i, used_names=used_names) for i in inputs])
|
||
|
# Initializes the input map with all placeholder input tensors.
|
||
|
initial_dict = {}
|
||
|
for o, m in zip(inputs, func.signature.input_arg):
|
||
|
initial_dict[o.name] = m.name
|
||
|
if out_names is None:
|
||
|
used_names = set()
|
||
|
func.signature.output_arg.extend(
|
||
|
[_tensor_to_argdef(o, used_names=used_names) for o in outputs])
|
||
|
elif len(outputs) != len(out_names):
|
||
|
raise ValueError(
|
||
|
f"out_names must be either empty or equal in size to outputs. "
|
||
|
f"len(out_names) = {len(out_names)} len(outputs) = {len(outputs)}")
|
||
|
elif len(out_names) != len(set(out_names)):
|
||
|
raise ValueError(
|
||
|
f"Must not have duplicates in out_names. Received: {out_names}")
|
||
|
else:
|
||
|
func.signature.output_arg.extend(
|
||
|
[_tensor_to_argdef(o, name=n) for o, n in zip(outputs, out_names)])
|
||
|
func_arg_placeholders = set(i.name for i in inputs)
|
||
|
input_dict = _create_input_dict(graph, func_arg_placeholders,
|
||
|
initial_value=initial_dict)
|
||
|
|
||
|
for op in operations:
|
||
|
if _is_in_placeholders(op, func_arg_placeholders):
|
||
|
continue
|
||
|
_add_op_node(op, func, input_dict)
|
||
|
|
||
|
if out_names is None:
|
||
|
for index, o in enumerate(outputs):
|
||
|
k = func.signature.output_arg[index].name
|
||
|
func.ret[k] = input_dict[o.name]
|
||
|
else:
|
||
|
for o, n in zip(outputs, out_names):
|
||
|
func.ret[n] = input_dict[o.name]
|
||
|
|
||
|
return func
|