327 lines
13 KiB
Python
327 lines
13 KiB
Python
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# =============================================================================
|
|
"""Utility to convert FunctionDef to GraphDef and Graph."""
|
|
|
|
import itertools
|
|
|
|
|
|
from tensorflow.core.framework import function_pb2
|
|
from tensorflow.core.framework import graph_pb2
|
|
from tensorflow.core.framework import tensor_shape_pb2
|
|
from tensorflow.core.framework import types_pb2
|
|
from tensorflow.core.framework import versions_pb2
|
|
from tensorflow.python.eager import context
|
|
from tensorflow.python.framework import cpp_shape_inference_pb2
|
|
from tensorflow.python.framework import importer
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.framework import versions
|
|
from tensorflow.python.framework.func_graph import FuncGraph
|
|
from tensorflow.python.ops import resource_variable_ops
|
|
|
|
|
|
def function_def_to_graph(fdef,
                          structured_input_signature=None,
                          structured_outputs=None,
                          input_shapes=None,
                          propagate_device_spec=False):
  """Converts a FunctionDef to a FuncGraph (sub-class Graph).

  The returned FuncGraph's `name`, `inputs`, `outputs` and `control_outputs`
  fields will be set. The input tensors are represented as placeholders.

  Note: `FuncGraph.captures` is not set by this function and may be set by the
  caller.

  Args:
    fdef: FunctionDef.
    structured_input_signature: Optional. The structured input signature to
      use for initializing the FuncGraph. See the docstring for FuncGraph for
      more information.
    structured_outputs: Optional. The structured outputs to use for
      initializing the FuncGraph. See the docstring for FuncGraph for more
      information.
    input_shapes: Optional. A list of TensorShape objects of the shapes of
      function inputs. Defaults to the function's "_input_shapes" attribute. If
      specified, its length must match length of `fdef.signature.input_arg`. If
      a shape is None, the corresponding input placeholder will have unknown
      shape.
    propagate_device_spec: Optional. Whether to propagate assigned device
      information when constructing a new Graph from a FunctionDef.

  Returns:
    A FuncGraph.
  """
  func_graph = FuncGraph(fdef.signature.name,
                         structured_input_signature=structured_input_signature,
                         structured_outputs=structured_outputs)
  if input_shapes is None:
    input_shapes_attr = fdef.attr.get("_input_shapes", None)
    if input_shapes_attr is not None:
      # Replace resource handle shapes in the inputs to disable shape inference.
      # Setting the shape to either the variable handle shape (which is always
      # `[]`) or the variable shape can cause shape inference issues.
      input_shapes = [
          None if (arg_def.type == types_pb2.DT_RESOURCE and
                   arg_def.handle_data) else input_shape
          for input_shape, arg_def in zip(input_shapes_attr.list.shape,
                                          fdef.signature.input_arg)
      ]

  graph_def, nested_to_flat_tensor_name = function_def_to_graph_def(
      fdef, input_shapes)

  with func_graph.as_default():
    # Add all function nodes to the graph.
    importer.import_graph_def_for_function(
        graph_def, name="", propagate_device_spec=propagate_device_spec)

    # Initialize fields specific to FuncGraph.

    # inputs
    input_tensor_names = [
        nested_to_flat_tensor_name[arg.name] for arg in fdef.signature.input_arg
    ]
    func_graph.inputs = [
        func_graph.get_tensor_by_name(name) for name in input_tensor_names
    ]

    # outputs
    output_tensor_names = [
        nested_to_flat_tensor_name[fdef.ret[arg.name]]
        for arg in fdef.signature.output_arg
    ]
    # Keep a local handle on the resolved tensors so they can be reused when
    # building `_output_names` below instead of being looked up a second time.
    output_tensors = [
        func_graph.get_tensor_by_name(name) for name in output_tensor_names
    ]
    func_graph.outputs = output_tensors
    func_graph.control_outputs = [
        func_graph.get_operation_by_name(fdef.control_ret[ret_name])
        for ret_name in fdef.signature.control_output
    ]

    _set_handle_data(func_graph, fdef)

    for node in graph_def.node:
      output_shapes = node.attr.get("_output_shapes", None)
      if output_shapes is not None:
        op = func_graph.get_operation_by_name(node.name)
        # _output_shapes for functions can sometimes be too long because the
        # output-intermediates-for-gradients version of the function was
        # substituted before saving. We'll accept that here. (See b/133666530).
        for output_index, shape in enumerate(
            output_shapes.list.shape[:len(op.outputs)]):
          op.outputs[output_index].set_shape(shape)

    # Map each output tensor's id to the name of the signature output it
    # implements.
    func_graph._output_names = {  # pylint: disable=protected-access
        ops.tensor_id(tensor): ret_arg_def.name
        for ret_arg_def, tensor in zip(fdef.signature.output_arg,
                                       output_tensors)
    }
  return func_graph
|
|
|
|
|
|
def is_function(fname):
  """Checks for a function definition with `fname` in the current context.

  When executing eagerly, the eager context is queried. Otherwise the default
  graph is checked, followed by each graph reachable through successive
  `outer_graph` attributes.

  Args:
    fname: Name of the function to look for.

  Returns:
    True if a function named `fname` is found, False otherwise.
  """
  if context.executing_eagerly():
    return context.context().has_function(fname)

  graph = ops.get_default_graph()
  while graph is not None:
    if graph._is_function(fname):  # pylint: disable=protected-access
      return True
    # A graph without an `outer_graph` attribute, or whose `outer_graph` is
    # None, ends the search.
    graph = getattr(graph, "outer_graph", None)
  # Always return a real boolean, including when the chain ends on a None
  # `outer_graph` (previously this path fell through and returned None).
  return False
|
|
|
|
|
|
def function_def_to_graph_def(fdef, input_shapes=None):
  """Convert a FunctionDef to a GraphDef.

  Steps:
  1. Creates placeholder nodes corresponding to inputs in
     `FunctionDef.signature.input_arg`.
  2. Adds NodeDefs in `FunctionDef.node_def` to `GraphDef.node`.
  3. Renames inputs of all nodes to use the convention of GraphDef instead of
     FunctionDef. See comment on `FunctionDef.node_def` on how the tensor naming
     in FunctionDefs is different from GraphDefs.

  Functions that are used as op types by body nodes and are found in the
  default graph (or one of its outer graphs) are copied into the returned
  GraphDef's function library, together with a GradientDef when the function
  has a gradient function name.

  Args:
    fdef: FunctionDef.
    input_shapes: Optional. A list of TensorShape objects of the shapes of
      function inputs. If specified, its length must match length of
      `fdef.signature.input_arg`. If a shape is None, the corresponding input
      placeholder will have unknown shape.

  Returns:
    A tuple of (GraphDef, dict<string, string>). The dict contains a mapping
    from nested tensor names (in FunctionDef) to flattened names (in GraphDef).

  Raises:
    ValueError: If the length of input_shapes does not match the number of
      input_args, if a function referenced through a `func` / `list(func)`
      attribute cannot be found, or if the FunctionDef is otherwise invalid.
  """
  graph_def = graph_pb2.GraphDef()
  graph_def.versions.CopyFrom(
      versions_pb2.VersionDef(
          producer=versions.GRAPH_DEF_VERSION,
          min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER))

  default_graph = ops.get_default_graph()

  copied_functions = set()

  if input_shapes and len(input_shapes) != len(fdef.signature.input_arg):
    raise ValueError("Length of `input_shapes` must match the number "
                     f"of `input_arg`s in `fdef`. Got "
                     f"{len(input_shapes)} `input_shapes` and "
                     f"{len(fdef.signature.input_arg)} `input_arg`s.")

  # 1. Create placeholders for input nodes.
  for i, arg_def in enumerate(fdef.signature.input_arg):
    node_def = graph_def.node.add()
    node_def.name = arg_def.name
    node_def.op = "Placeholder"
    node_def.attr["dtype"].type = arg_def.type
    if input_shapes and input_shapes[i] is not None:
      input_shape = input_shapes[i]
      if not isinstance(input_shape, tensor_shape_pb2.TensorShapeProto):
        input_shape = input_shape.as_proto()
      node_def.attr["shape"].shape.CopyFrom(input_shape)

    # Indexing a protobuf map with a missing key inserts a default entry, which
    # would silently modify the caller's FunctionDef. Test membership first.
    if i not in fdef.arg_attr:
      continue
    for key, attr_value in fdef.arg_attr[i].attr.items():
      # Only copy internal attributes. Normal attributes for nodes cannot be
      # applied to these Placeholder nodes.
      if key == "_output_shapes":
        value_kind = attr_value.WhichOneof("value")
        if value_kind == "list":
          node_def.attr["shape"].shape.CopyFrom(attr_value.list.shape[0])
        elif value_kind == "shape":
          node_def.attr["shape"].shape.CopyFrom(attr_value.shape)
      elif key.startswith("_"):
        node_def.attr[key].CopyFrom(attr_value)

  # 2. Copy all body NodeDefs to the GraphDef.
  graph_def.node.extend(fdef.node_def)

  # 3. Perform the renaming.

  # Build the tensor name mapping then flatten the tensor names.
  # See comment on `FunctionDef.node_def` on how the tensor naming in
  # FunctionDefs is different from GraphDefs.
  nested_to_flat_tensor_name = {}

  for arg_def in fdef.signature.input_arg:
    nested_to_flat_tensor_name[arg_def.name] = f"{arg_def.name}:0"
    control_name = "^" + arg_def.name
    nested_to_flat_tensor_name[control_name] = control_name

  for node_def in fdef.node_def:
    # Look for a function named after the node's op type, starting at the
    # default graph and walking outwards. Stop when a graph has no
    # `outer_graph` attribute or when that attribute is None.
    f = None
    graph = default_graph
    while graph is not None:
      f = graph._functions.get(node_def.op, None)  # pylint: disable=protected-access
      if f is not None:
        break
      graph = getattr(graph, "outer_graph", None)

    if f is not None:
      # Use a distinct local name so the `fdef` argument is not rebound.
      callee_fdef = f.definition
      op_def = callee_fdef.signature
      if node_def.op not in copied_functions:
        # Since this function is referenced as an op type, we have no choice but
        # to copy it into the GraphDef if we want downstream tools to process
        # it.
        graph_def.library.function.add().CopyFrom(callee_fdef)
        copied_functions.add(node_def.op)
        if f.grad_func_name:
          grad_def = function_pb2.GradientDef()
          grad_def.function_name = f.name
          grad_def.gradient_func = f.grad_func_name
          graph_def.library.gradient.extend([grad_def])
    else:
      op_def = default_graph._get_op_def(node_def.op)  # pylint: disable=protected-access

    # Every function referenced through a function-valued attribute must be
    # resolvable in the current context.
    for attr in op_def.attr:
      if attr.type == "func":
        fname = node_def.attr[attr.name].func.name
        if not is_function(fname):
          raise ValueError(f"Function {fname} was not found. Please make sure "
                           "the FunctionDef `fdef` is correct.")
      elif attr.type == "list(func)":
        for fn in node_def.attr[attr.name].list.func:
          fname = fn.name
          if not is_function(fname):
            raise ValueError(f"Function {fname} was not found. Please make "
                             "sure the FunctionDef `fdef` is correct.")

    # Iterate over output_args in op_def to build the map.
    # Index of the output tensor in the flattened list of *all* output
    # tensors of the op.
    flattened_index = 0
    for arg_def in op_def.output_arg:
      num_args = _get_num_args(arg_def, node_def)
      for i in range(num_args):
        # Map tensor names from "node_name:output_arg_name:index" to
        # "node_name:flattened_index".
        nested_name = f"{node_def.name}:{arg_def.name}:{i}"
        flat_name = f"{node_def.name}:{flattened_index}"
        nested_to_flat_tensor_name[nested_name] = flat_name
        flattened_index += 1
    control_name = "^" + node_def.name
    nested_to_flat_tensor_name[control_name] = control_name

  # Update inputs of all nodes in graph.
  for node_def in graph_def.node:
    # Index-based on purpose: the repeated field is rewritten in place.
    for i in range(len(node_def.input)):
      node_def.input[i] = nested_to_flat_tensor_name[node_def.input[i]]

  return graph_def, nested_to_flat_tensor_name
|
|
|
|
|
|
# Based on implementation in core/framework/node_def_util.cc::ComputeArgRange.
def _get_num_args(arg_def, node_def):
  """Returns the number of tensors `arg_def` stands for on `node_def`.

  Args:
    arg_def: Argument definition whose `number_attr`, `type_list_attr`,
      `type_attr` and `type` fields are inspected, in that order.
    node_def: Node whose `attr` map supplies the value of the attribute named
      by `arg_def.number_attr` or `arg_def.type_list_attr`.

  Returns:
    The integer stored in the `number_attr` attribute, or the length of the
    `type_list_attr` type list, or 1 for a single typed argument.

  Raises:
    ValueError: If `arg_def` specifies none of the above.
  """
  # Repeated argument: the count is stored in an integer attribute.
  if arg_def.number_attr:
    count_attr = node_def.attr[arg_def.number_attr]
    return count_attr.i

  # Heterogeneous list argument: one tensor per listed type.
  if arg_def.type_list_attr:
    types_attr = node_def.attr[arg_def.type_list_attr]
    return len(types_attr.list.type)

  # Single tensor, typed either through an attribute or directly.
  if arg_def.type_attr or arg_def.type != types_pb2.DT_INVALID:
    return 1

  raise ValueError(f"Invalid arg_def:\n\n{arg_def}. Please make sure the "
                   "FunctionDef `fdef` is correct.")
|
|
|
|
|
|
def _set_handle_data(func_graph, fdef):
  """Adds handle data for resource type inputs and outputs.

  Each tensor in `func_graph.inputs` / `func_graph.outputs` is paired with the
  matching entry of `fdef.signature.input_arg` / `fdef.signature.output_arg`.
  Pairs whose argument definition has no `handle_data` are left untouched.

  Args:
    func_graph: Graph object providing the `inputs` and `outputs` tensors.
    fdef: FunctionDef whose signature describes those tensors.
  """
  tensors_with_arg_defs = itertools.chain(
      zip(func_graph.inputs, fdef.signature.input_arg),
      zip(func_graph.outputs, fdef.signature.output_arg))

  for tensor, arg_def in tensors_with_arg_defs:
    if not arg_def.handle_data:
      continue

    # The shape of the handle itself is [], while the variable shape is
    # saved in `handle_data`. Previously, the shape of the resource handle
    # was set to `None`. Correct both shapes here.
    tensor.set_shape([])

    # Only the first `handle_data` entry is used.
    first_entry = arg_def.handle_data[0]
    inference_result = cpp_shape_inference_pb2.CppShapeInferenceResult
    shape_and_type = inference_result.HandleShapeAndType(
        shape=first_entry.shape, dtype=first_entry.dtype)

    handle_data = inference_result.HandleData()
    handle_data.is_set = True
    handle_data.shape_and_type.append(shape_and_type)

    resource_variable_ops._set_handle_shapes_and_types(  # pylint: disable=protected-access
        tensor, handle_data, True)
|