1095 lines
43 KiB
Python
1095 lines
43 KiB
Python
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
|
|
"""MetaGraph and related functions."""
|
|
import copy
|
|
from packaging import version as packaging_version # pylint: disable=g-bad-import-order
|
|
import os.path
|
|
import re
|
|
import sys
|
|
|
|
from google.protobuf.any_pb2 import Any
|
|
from google.protobuf import text_format
|
|
|
|
from tensorflow.core.framework import attr_value_pb2
|
|
from tensorflow.core.framework import graph_pb2
|
|
from tensorflow.core.framework import op_def_pb2
|
|
from tensorflow.core.protobuf import meta_graph_pb2
|
|
from tensorflow.core.protobuf import saver_pb2
|
|
from tensorflow.python.client import pywrap_tf_session as c_api
|
|
from tensorflow.python.eager import context
|
|
from tensorflow.python.framework import byte_swap_tensor as bst
|
|
from tensorflow.python.framework import error_interpolation
|
|
from tensorflow.python.framework import graph_io
|
|
from tensorflow.python.framework import importer
|
|
from tensorflow.python.framework import op_def_registry
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.framework import tensor
|
|
from tensorflow.python.framework import versions
|
|
from tensorflow.python.lib.io import file_io
|
|
from tensorflow.python.platform import tf_logging as logging
|
|
from tensorflow.python.util import compat
|
|
|
|
|
|
# Marker prepended to the names of inputs that live outside the exported
# scope, so that such "unbound" inputs are easy to spot on re-import.
_UNBOUND_INPUT_PREFIX = "$unbound_inputs_"

# Collections that did not register proto conversion functions in earlier
# releases. Because of that, meta graphs exported by those releases may store
# the items of these collections with a different data type than expected.
_COMPAT_COLLECTION_LIST = [
    ops.GraphKeys.LOCAL_VARIABLES,
    ops.GraphKeys.MODEL_VARIABLES,
    ops.GraphKeys.METRIC_VARIABLES,
]
|
|
|
|
|
|
def _node_def(from_node_def, export_scope, unbound_inputs, clear_devices=False):
  """Create a `NodeDef` proto with export_scope stripped.

  Inputs that do not live under `export_scope` are renamed by inserting
  `_UNBOUND_INPUT_PREFIX` after any leading "^" control-dependency marker and
  are appended to `unbound_inputs`. Every other input, the node name, the
  `_class` colocation attribute and the `frame_name` of `Enter`/`RefEnter`
  nodes get `export_scope` stripped.

  Args:
    from_node_def: A `node_def_pb2.NodeDef` protocol buffer. It is not
      modified; a deep copy is edited and returned.
    export_scope: A `string` representing the name scope to remove.
    unbound_inputs: A list to which the renamed unbound input names are
      appended (mutated in place).
    clear_devices: Boolean which controls whether to clear device information
      from node_def. Default false.

  Returns:
    A `node_def_pb2.NodeDef` protocol buffer.
  """
  node_def = copy.deepcopy(from_node_def)
  for i, v in enumerate(node_def.input):
    if export_scope and not v.lstrip("^").startswith(export_scope):
      # Adds "$unbound_inputs_" prefix to the unbound name so they are easily
      # identifiable.
      node_def.input[i] = re.sub(r"([\^]|^)(.*)",
                                 r"\1" + _UNBOUND_INPUT_PREFIX + r"\2",
                                 compat.as_str(v))
      unbound_inputs.append(node_def.input[i])
    else:
      node_def.input[i] = ops.strip_name_scope(v, export_scope)
  node_def.name = compat.as_bytes(
      ops.strip_name_scope(from_node_def.name, export_scope))
  for k, v in from_node_def.attr.items():
    if k == "_class":
      # Colocation entries look like "loc:@<op name>". Keep only the entries
      # whose op lives under export_scope. `partition` (rather than
      # `split("@")[1]`) avoids an IndexError on entries without "@"; such
      # entries name no op and are dropped when exporting a scope.
      new_s = [
          compat.as_bytes(ops.strip_name_scope(s, export_scope))
          for s in v.list.s
          if not export_scope or
          compat.as_str(s).partition("@")[2].startswith(export_scope)
      ]
      node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(
          list=attr_value_pb2.AttrValue.ListValue(s=new_s)))
    elif node_def.op in ("Enter", "RefEnter") and k == "frame_name":
      if not export_scope or compat.as_str(v.s).startswith(export_scope):
        new_s = compat.as_bytes(ops.strip_name_scope(v.s, export_scope))
        node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(s=new_s))
      else:
        # Frame outside the exported scope: keep the original value rather
        # than writing an unbound or stale `new_s`.
        node_def.attr[k].CopyFrom(v)
    else:
      node_def.attr[k].CopyFrom(v)

  if clear_devices:
    node_def.device = ""

  return node_def
|
|
|
|
|
|
def _read_file(filename):
  """Reads a file containing `GraphDef` and returns the protocol buffer.

  The file content is first parsed as a binary-serialized `GraphDef`. If that
  fails, it is parsed as a text-format `GraphDef`.

  Args:
    filename: `graph_def` filename including the path.

  Returns:
    A `GraphDef` protocol buffer.

  Raises:
    IOError: If the file doesn't exist, or cannot be successfully parsed.
  """
  graph_def = graph_pb2.GraphDef()
  if not file_io.file_exists(filename):
    raise IOError(f"File {filename} does not exist.")
  # First try to read it as a binary file.
  with file_io.FileIO(filename, "rb") as f:
    file_content = f.read()
  try:
    graph_def.ParseFromString(file_content)
    return graph_def
  except Exception:  # pylint: disable=broad-except
    # Deliberate best-effort: not a (valid) binary proto, so fall back to
    # the text format below.
    pass

  # A failed binary parse may leave partial data behind; start clean so it is
  # not merged into the text-format result.
  graph_def.Clear()
  # Next try to read it as a text file.
  try:
    text_format.Merge(file_content, graph_def)
  except (text_format.ParseError, UnicodeDecodeError) as e:
    # Non-UTF-8 content is also "cannot be parsed"; report it as IOError.
    raise IOError(f"Cannot parse file {filename}: {str(e)}.") from e

  return graph_def
|
|
|
|
|
|
def ops_used_by_graph_def(graph_def):
  """Collect the list of ops used by a graph.

  Does not validate that the ops are all registered.

  Library functions reached from the graph are scanned transitively. A
  function counts as reached when a node's op names it, or when it is the `f`
  attr of a `PartitionedCall`/`StatefulPartitionedCall` node. Names of library
  functions are not part of the result.

  Args:
    graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

  Returns:
    A list of strings, each naming an op used by the graph (in no particular
    order).
  """
  functions_by_name = {fn.signature.name: fn
                       for fn in graph_def.library.function}

  seen = set()  # Every op or function name encountered so far.
  pending = []  # Function definitions whose bodies still need scanning.

  def visit(op_name):
    if op_name in seen:
      return
    seen.add(op_name)
    fn = functions_by_name.get(op_name)
    if fn is not None:
      pending.append(fn)

  def visit_node(node):
    visit(node.op)
    if node.op in ("PartitionedCall", "StatefulPartitionedCall"):
      visit(node.attr["f"].func.name)

  for node in graph_def.node:
    visit_node(node)
  # Functions can call other functions, so drain the work list until empty.
  while pending:
    for node in pending.pop().node_def:
      visit_node(node)

  return [name for name in seen if name not in functions_by_name]
|
|
|
|
|
|
def stripped_op_list_for_graph(graph_def):
  """Collect the stripped OpDefs for ops used by a graph.

  This function computes the `stripped_op_list` field of `MetaGraphDef` and
  similar protos. The result can be communicated from the producer to the
  consumer, which can then use the C++ function
  `RemoveNewDefaultAttrsFromGraphDef` to improve forwards compatibility.

  Args:
    graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

  Returns:
    An `OpList` of ops used by the graph, sorted by op name. Ops for which
    `op_def_registry.get` returns None are omitted.
  """
  # Counterpart of the C++ StrippedOpListForGraph. Unlike the C++ version it
  # does not require every op to be registered (needed for Prelu fusion in
  # tfjs); unregistered ops are simply left out.
  looked_up = (op_def_registry.get(op_name)
               for op_name in sorted(ops_used_by_graph_def(graph_def)))
  return op_def_pb2.OpList(
      op=[op_def for op_def in looked_up if op_def is not None])
|
|
|
|
|
|
def _get_kind_name(item):
  """Returns the name of the `CollectionDef` kind field that can hold `item`.

  Args:
    item: A data item.

  Returns:
    "bytes_list" for str/bytes, "int64_list" for int (bool included, as it is
    an int subclass), "float_list" for float, "any_list" for a protobuf `Any`,
    and "node_list" for anything else.
  """
  if isinstance(item, (str, bytes)):
    return "bytes_list"
  if isinstance(item, int):
    return "int64_list"
  if isinstance(item, float):
    return "float_list"
  return "any_list" if isinstance(item, Any) else "node_list"
|
|
|
|
|
|
# Op types that identify the Save/Restore nodes created by a Saver.
SAVE_AND_RESTORE_OPS = [
    # Save ops.
    "SaveV2",
    "Save",
    "SaveSlice",
    "LegacySave",
    "LegacySaveSlice",
    # Restore ops.
    "RestoreV2",
    "Restore",
    "RestoreSlice",
    "LegacyRestore",
    "LegacyRestoreSlice",
]
|
|
|
|
|
|
def _get_scope(node_name):
  """Extract the scope name from a node name.

  The scope is everything before the final "/", ignoring a leading "^" that
  marks a control-dependency input.

  Args:
    node_name: the full name of an Op or a Tensor in the graph.

  Returns:
    The deepest named scope containing the node, or "" if the name contains
    no "/".

  Raises:
    ValueError: if node_name is None or empty.
  """
  if not node_name:
    raise ValueError(
        f"Node name cannot be empty or None. Received: {node_name}.")

  bare_name = node_name[1:] if node_name.startswith("^") else node_name
  # rpartition yields ("", "", name) when there is no "/", i.e. the root scope.
  scope, _, _ = bare_name.rpartition("/")
  return scope
|
|
|
|
|
|
def _find_extraneous_saver_nodes(graph_def, saver_def):
  """Identifies any nodes in the graph_def related to unused Savers.

  This approach assumes that each Saver is cleanly isolated in its own name
  scope, so we need only identify the scopes associated with extraneous Savers
  and return all the nodes in those scopes.

  Args:
    graph_def: a GraphDef proto to evaluate.
    saver_def: a SaverDef proto referencing Save/Restore ops to be retained,
      or None when there is no saver to keep.

  Returns:
    A set of node names that may be safely omitted.
  """
  # TODO(soergel): confirm that the assumption of scope isolation is valid.
  # If not, we need to walk up the graph from any restore_all nodes, and walk
  # down the graph from any Save/Restore nodes. I drafted that approach too,
  # but it seems unnecessarily complex given the name scope solution.

  # Lightweight view of the graph, built without a full Graph object:
  # node name -> (names of the ops feeding it, op type).
  graph_nodes = {
      nd.name: ({tensor.get_op_name(inp) for inp in nd.input}, nd.op)
      for nd in graph_def.node
  }

  keep_save_scope = None
  keep_restore_scope = None
  # It's possible to have no saver if the graph has no Variables.
  if saver_def is not None:
    save_op = tensor.get_op_name(saver_def.save_tensor_name)
    restore_op = tensor.get_op_name(saver_def.restore_op_name)
    # The save and restore scopes should always be the same, but if they
    # differ for some reason, both are retained to be safe.
    keep_restore_scope = _get_scope(restore_op) + "/"
    keep_save_scope = _get_scope(save_op) + "/"

  saver_op_names = {
      name for name, (_, op_type) in graph_nodes.items()
      if op_type in SAVE_AND_RESTORE_OPS
  }
  saver_scopes = {_get_scope(name) for name in saver_op_names} - saver_op_names
  doomed_scopes = ({scope + "/" for scope in saver_scopes} -
                   {keep_save_scope, keep_restore_scope})

  return {
      name for name in graph_nodes
      if any(name.startswith(scope) for scope in doomed_scopes)
  }
|
|
|
|
|
|
def _should_include_node(node_or_node_name, export_scope, exclude_nodes):
  """Returns `True` if a node should be included.

  Args:
    node_or_node_name: A node or `string` node name.
    export_scope: `string`. Name scope under which to extract the subgraph. The
      scope name will be stripped from the node definitions for easy import
      later into new name scopes.
    exclude_nodes: An iterable of nodes or `string` node names to omit from the
      export, or None. Note no sanity-checking is done, so this list must be
      carefully constructed to avoid producing an invalid graph.

  Returns:
    `True` if the node should be included. Objects without a `name`
    attribute are always kept.
  """
  if isinstance(node_or_node_name, str):
    node_name = node_or_node_name
  else:
    try:
      node_name = node_or_node_name.name
    except AttributeError:
      # Keep the object that we don't know how to process.
      return True

  # Either the object itself or its name may appear in exclude_nodes.
  if exclude_nodes and (node_or_node_name in exclude_nodes or
                        node_name in exclude_nodes):
    return False

  if node_name.startswith(_UNBOUND_INPUT_PREFIX):
    return True
  return not export_scope or node_name.startswith(export_scope)
|
|
|
|
|
|
def add_collection_def(meta_graph_def, key, graph=None,
                       export_scope=None, exclude_nodes=None,
                       override_contents=None):
  """Adds a collection to MetaGraphDef protocol buffer.

  Items are filtered with `_should_include_node`; if nothing remains, no
  collection entry is created. Serialization problems are logged as a warning
  and any partially written entry for `key` is removed instead of raising.

  Args:
    meta_graph_def: MetaGraphDef protocol buffer.
    key: One of the GraphKeys or user-defined string.
    graph: The `Graph` from which to get collections.
    export_scope: Optional `string`. Name scope to remove.
    exclude_nodes: An iterable of nodes or `string` node names to omit from the
      collection, or None.
    override_contents: An iterable of values to place in the collection,
      ignoring the current values (if set).

  Raises:
    TypeError: If `graph` is given and is not a `Graph`.
  """
  if graph and not isinstance(graph, ops.Graph):
    raise TypeError(
        f"graph must be of type Graph. Received type: {type(graph)}.")

  if not isinstance(key, (str, bytes)):
    logging.warning("Only collections with string type keys will be "
                    "serialized. This key has %s", type(key))
    return

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  if override_contents:
    collection_list = override_contents
  else:
    collection_list = graph.get_collection(key)

  # Remove nodes that should not be exported from the collection list.
  collection_list = [x for x in collection_list if
                     _should_include_node(x, export_scope, exclude_nodes)]
  if not collection_list:
    return

  try:
    col_def = meta_graph_def.collection_def[key]
    to_proto = ops.get_to_proto_function(key)
    proto_type = ops.get_collection_proto_type(key)
    if to_proto:
      kind = "bytes_list"
      for x in collection_list:
        proto = to_proto(x, export_scope=export_scope)
        if proto:
          # Make sure the returned proto is indeed what we expect. An explicit
          # check is used instead of `assert`, which is stripped under -O.
          if not isinstance(proto, proto_type):
            raise TypeError(
                f"to_proto for collection {key!r} returned {type(proto)}; "
                f"expected {proto_type}.")
          getattr(col_def, kind).value.append(proto.SerializeToString())
    else:
      kind = _get_kind_name(collection_list[0])
      if kind == "node_list":
        for x in collection_list:
          if not export_scope or x.name.startswith(export_scope):
            getattr(col_def, kind).value.append(
                ops.strip_name_scope(x.name, export_scope))
      elif kind == "bytes_list":
        # NOTE(opensource): This force conversion is to work around the fact
        # that Python3 distinguishes between bytes and strings.
        getattr(col_def, kind).value.extend(
            [compat.as_bytes(x) for x in collection_list])
      else:
        getattr(col_def, kind).value.extend(list(collection_list))
  except Exception as e:  # pylint: disable=broad-except
    logging.warning("Issue encountered when serializing %s.\n"
                    "Type is unsupported, or the types of the items don't "
                    "match field type in CollectionDef. Note this is a warning "
                    "and probably safe to ignore.\n%s", key, str(e))
    # Drop the half-populated entry so the MetaGraphDef stays consistent.
    if key in meta_graph_def.collection_def:
      del meta_graph_def.collection_def[key]
    return
|
|
|
|
|
|
def _is_default_attr_value(op_def, attr_name, attr_value):
  """Checks if given attribute matches the default value in the op def.

  Returns False when `op_def` declares no attr called `attr_name`, or when
  that attr has no default value.
  """
  attr_def = next((a for a in op_def.attr if a.name == attr_name), None)
  if attr_def is None or not attr_def.HasField("default_value"):
    return False
  # c_api.EqualAttrValueWrapper returns an empty string
  # if both arguments represent an equivalent AttrValue instance.
  difference = c_api.EqualAttrValueWrapper(
      attr_value.SerializeToString(),
      attr_def.default_value.SerializeToString())
  return not difference
|
|
|
|
|
|
def strip_graph_default_valued_attrs(meta_graph_def):
  """Strips default valued attributes for node defs in given MetaGraphDef.

  Both the top-level graph nodes and the nodes inside every library function
  are processed. Nodes whose op is a library function, or whose op is not
  found by `op_def_registry.get`, are left untouched. This method also sets
  `meta_info_def.stripped_default_attrs` in the given `MetaGraphDef` proto to
  True.

  Args:
    meta_graph_def: `MetaGraphDef` protocol buffer, modified in place.

  Returns:
    None.
  """
  library_functions = meta_graph_def.graph_def.library.function
  # Nodes that invoke library functions have no registered OpDef to compare
  # against, so they are skipped.
  function_names = {fn.signature.name for fn in library_functions}

  def _strip(node_def):
    """Deletes the attrs of `node_def` that equal their registered default."""
    if node_def.op in function_names:
      return
    op_def = op_def_registry.get(node_def.op)
    if op_def is None:
      return
    defaulted = {name for name, value in node_def.attr.items()
                 if _is_default_attr_value(op_def, name, value)}
    for name in defaulted:
      del node_def.attr[name]

  # Top-level graph nodes first, then the nodes inside each library function.
  for node_def in meta_graph_def.graph_def.node:
    _strip(node_def)
  for fn in library_functions:
    for node_def in fn.node_def:
      _strip(node_def)

  # Tell consumers of this graph that default valued attrs have been stripped.
  meta_graph_def.meta_info_def.stripped_default_attrs = True
|
|
|
|
|
|
def create_meta_graph_def(meta_info_def=None,
                          graph_def=None,
                          saver_def=None,
                          collection_list=None,
                          graph=None,
                          export_scope=None,
                          exclude_nodes=None,
                          clear_extraneous_savers=False,
                          strip_default_attrs=False):
  # pylint: disable=line-too-long
  """Construct and returns a `MetaGraphDef` protocol buffer.

  Args:
    meta_info_def: `MetaInfoDef` protocol buffer. It is copied into the result
      and is not modified; the copy's TensorFlow version fields are set to the
      current build.
    graph_def: `GraphDef` protocol buffer.
    saver_def: `SaverDef` protocol buffer.
    collection_list: List of string keys to collect.
    graph: The `Graph` to create `MetaGraphDef` out of.
    export_scope: Optional `string`. Name scope to remove.
    exclude_nodes: An iterable of nodes or `string` node names to omit from all
      collection, or None.
    clear_extraneous_savers: Remove any preexisting SaverDefs from the SAVERS
        collection.  Note this method does not alter the graph, so any
        extraneous Save/Restore ops should have been removed already, as needed.
    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).

  Returns:
    MetaGraphDef protocol buffer.

  Raises:
    TypeError: If the arguments are not of the correct proto buffer type.
  """
  # pylint: enable=line-too-long
  # Type check.
  if graph and not isinstance(graph, ops.Graph):
    raise TypeError(
        f"graph must be of type Graph. Received type: {type(graph)}.")
  if meta_info_def and not isinstance(meta_info_def,
                                      meta_graph_pb2.MetaGraphDef.MetaInfoDef):
    raise TypeError(
        "meta_info_def must be of type MetaInfoDef. "
        f"Received type: {type(meta_info_def)}.")
  if graph_def and not isinstance(graph_def, graph_pb2.GraphDef):
    raise TypeError(
        "graph_def must be of type GraphDef. "
        f"Received type: {type(graph_def)}.")
  if saver_def and not isinstance(saver_def, saver_pb2.SaverDef):
    raise TypeError(
        "saver_def must be of type SaverDef. "
        f"Received type: {type(saver_def)}.")

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Creates a MetaGraphDef proto.
  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  # Adds meta_info_def. Merge first and stamp the version strings on the copy
  # afterwards, so the caller's proto is not mutated. The resulting fields are
  # the same as stamping the caller's proto before merging.
  if meta_info_def:
    meta_graph_def.meta_info_def.MergeFrom(meta_info_def)

  # Set the tf version strings to the current tf build.
  meta_graph_def.meta_info_def.tensorflow_version = versions.__version__
  meta_graph_def.meta_info_def.tensorflow_git_version = (
      versions.__git_version__)

  # Adds graph_def or the default.
  if not graph_def:
    meta_graph_def.graph_def.MergeFrom(graph.as_graph_def(add_shapes=True))
  else:
    meta_graph_def.graph_def.MergeFrom(graph_def)

  # Fills in meta_info_def.stripped_op_list using the ops from graph_def.
  # pylint: disable=g-explicit-length-test
  if len(meta_graph_def.meta_info_def.stripped_op_list.op) == 0:
    meta_graph_def.meta_info_def.stripped_op_list.MergeFrom(
        stripped_op_list_for_graph(meta_graph_def.graph_def))
  # pylint: enable=g-explicit-length-test

  # Strip default valued attributes in graph_def.
  if strip_default_attrs:
    strip_graph_default_valued_attrs(meta_graph_def)

  # Adds saver_def.
  if saver_def:
    meta_graph_def.saver_def.MergeFrom(saver_def)

  # Adds collection_list.
  if collection_list is not None:
    clist = collection_list
  else:
    clist = graph.get_all_collection_keys()

  for ctype in clist:
    if clear_extraneous_savers and ctype == ops.GraphKeys.SAVERS:
      # Avoid importing Saver here
      from_proto = ops.get_from_proto_function(ctype)
      add_collection_def(meta_graph_def, ctype,
                         graph=graph,
                         export_scope=export_scope,
                         exclude_nodes=exclude_nodes,
                         override_contents=[from_proto(saver_def)])
    else:
      add_collection_def(meta_graph_def, ctype,
                         graph=graph,
                         export_scope=export_scope,
                         exclude_nodes=exclude_nodes)
  return meta_graph_def
|
|
|
|
|
|
def read_meta_graph_file(filename):
  """Reads a file containing `MetaGraphDef` and returns the protocol buffer.

  The file content is first parsed as a binary-serialized `MetaGraphDef`. If
  that fails, it is decoded as UTF-8 and parsed as a text-format
  `MetaGraphDef`. On big-endian hosts (`sys.byteorder == "big"`), tensor
  content is then byte-swapped from little- to big-endian.

  Args:
    filename: `meta_graph_def` filename including the path.

  Returns:
    A `MetaGraphDef` protocol buffer.

  Raises:
    IOError: If the file doesn't exist, or cannot be successfully parsed.
  """
  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  if not file_io.file_exists(filename):
    raise IOError(f"File does not exist. Received: {filename}.")
  # First try to read it as a binary file.
  with file_io.FileIO(filename, "rb") as f:
    file_content = f.read()
  try:
    meta_graph_def.ParseFromString(file_content)
    parsed_as_binary = True
  except Exception:  # pylint: disable=broad-except
    # Deliberate best-effort: fall back to the text format below.
    parsed_as_binary = False

  if not parsed_as_binary:
    # A failed binary parse may leave partial data behind; start clean so it
    # is not merged into the text-format result.
    meta_graph_def.Clear()
    # Next try to read it as a text file.
    try:
      text_format.Merge(file_content.decode("utf-8"), meta_graph_def)
    except (text_format.ParseError, UnicodeDecodeError) as e:
      # Content that is neither a binary proto nor UTF-8 text is also
      # unparseable; report it as the documented IOError.
      raise IOError(f"Cannot parse file {filename}: {str(e)}.") from e

  # Kept outside the parse try-blocks so a byte-swap failure surfaces directly
  # instead of being mistaken for "not a binary proto".
  if sys.byteorder == "big":
    bst.swap_tensor_content_in_graph_function(meta_graph_def, "little", "big")

  return meta_graph_def
|
|
|
|
|
|
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs",
                             restore_collections_predicate=(lambda key: True)):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If the
  argument is a file containing a `MetaGraphDef` protocol buffer, it
  constructs a protocol buffer from the file content. The function then adds
  all the nodes from the `graph_def` field to the current graph, recreates the
  desired collections, and returns a dictionary of all the Variables imported
  into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used
  to:

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A
      collection named c (i.e whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.

  Returns:
    A dictionary of all the `Variables` imported into the name scope.

  Raises:
    ValueError: If the graph_def contains unbound inputs, or if eager
      execution is enabled.
  """
  imported = import_scoped_meta_graph_with_return_elements(
      meta_graph_or_file,
      clear_devices=clear_devices,
      graph=graph,
      import_scope=import_scope,
      input_map=input_map,
      unbound_inputs_col_name=unbound_inputs_col_name,
      restore_collections_predicate=restore_collections_predicate)
  # The helper returns (variables, return_elements); only the variables are
  # part of this function's contract.
  return imported[0]
|
|
|
|
|
|
def import_scoped_meta_graph_with_return_elements(
    meta_graph_or_file,
    clear_devices=False,
    graph=None,
    import_scope=None,
    input_map=None,
    unbound_inputs_col_name="unbound_inputs",
    restore_collections_predicate=(lambda key: True),
    return_elements=None):
  """Imports graph from `MetaGraphDef` and returns vars and return elements.

  This function takes a `MetaGraphDef` protocol buffer as input. If the
  argument is a file containing a `MetaGraphDef` protocol buffer, it
  constructs a protocol buffer from the file content. The function then adds
  all the nodes from the `graph_def` field to the current graph, recreates the
  desired collections, and returns a dictionary of all the Variables imported
  into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used
  to:

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false. The caller's `MetaGraphDef` is not
      modified; devices are cleared on a copy.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A
      collection named c (i.e whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.
    return_elements: A list of strings containing operation names in the
      `MetaGraphDef` that will be returned as `Operation` objects; and/or
      tensor names in `MetaGraphDef` that will be returned as `Tensor` objects.

  Returns:
    A tuple of (
      dictionary of all the `Variables` imported into the name scope,
      list of `Operation` or `Tensor` objects from the `return_elements` list).

  Raises:
    ValueError: If eager execution is enabled, or if the graph_def contains
      unbound inputs whose names are not exactly the keys of `input_map`.
  """
  if context.executing_eagerly():
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "eager execution is enabled.")
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        field = getattr(col_def, kind)
        # Stored values are bytes while input_map keys are strings, so all
        # comparisons are done on the str form.
        unbound_names = [compat.as_str(v) for v in field.value]
        if unbound_names and (
            not input_map or sorted(unbound_names) != sorted(input_map)):
          missing = [name for name in unbound_names
                     if not input_map or name not in input_map]
          # When every name is mapped but input_map has extra keys, `missing`
          # is empty; list all unbound inputs rather than an empty string.
          raise ValueError("Graph contains unbound inputs: %s. Must "
                           "provide these inputs through input_map." %
                           ",".join(missing or unbound_names))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    if clear_devices:
      # Work on a copy so the caller's MetaGraphDef keeps its device fields.
      input_graph_def = copy.deepcopy(input_graph_def)
      for node in input_graph_def.node:
        node.device = ""

    scope_to_prepend_to_names = graph.unique_name(
        import_scope or "", mark_as_used=False)

    imported_return_elements = importer.import_graph_def(
        input_graph_def,
        name=(import_scope or scope_to_prepend_to_names),
        input_map=input_map,
        producer_op_list=producer_op_list,
        return_elements=return_elements)

    # TensorFlow versions before 1.9 (not inclusive) exported SavedModels
    # without a VariableDef.trainable field set.
    tf_version = meta_graph_def.meta_info_def.tensorflow_version
    if not tf_version:
      variables_have_trainable = True
    else:
      variables_have_trainable = (
          packaging_version.parse(tf_version) >= packaging_version.parse("1.9"))

    # Sort collections so we see TRAINABLE_VARIABLES first and can default these
    # variables to trainable if the value is not set in their VariableDef.
    sorted_collections = []
    if ops.GraphKeys.TRAINABLE_VARIABLES in meta_graph_def.collection_def:
      sorted_collections.append(
          (ops.GraphKeys.TRAINABLE_VARIABLES,
           meta_graph_def.collection_def[ops.GraphKeys.TRAINABLE_VARIABLES]))
    for key, value in sorted(meta_graph_def.collection_def.items()):
      if key != ops.GraphKeys.TRAINABLE_VARIABLES:
        sorted_collections.append((key, value))

    # Restores all the other collections.
    variable_objects = {}
    for key, col_def in sorted_collections:
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue
      if not restore_collections_predicate(key):
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue
      from_proto = ops.get_from_proto_function(key)

      # Temporary change to allow the TFMA evaluator to read metric variables
      # saved as a bytes list.
      # TODO(kathywu): Remove this hack once cl/248406059 has been submitted.
      if key == ops.GraphKeys.METRIC_VARIABLES:
        # Metric variables will use the same proto functions as GLOBAL_VARIABLES
        from_proto = ops.get_from_proto_function(ops.GraphKeys.GLOBAL_VARIABLES)
      if from_proto and kind == "bytes_list":
        proto_type = ops.get_collection_proto_type(key)
        if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
          for value in col_def.bytes_list.value:
            # Reuse the variable object if the same serialized proto was
            # already restored from an earlier collection.
            variable = variable_objects.get(value, None)
            if variable is None:
              proto = proto_type()
              proto.ParseFromString(value)
              if not variables_have_trainable:
                # If the VariableDef proto does not contain a "trainable"
                # property because it was exported before that property was
                # added, we default it to whether the variable is in the
                # TRAINABLE_VARIABLES collection. We've sorted
                # TRAINABLE_VARIABLES to be first, so trainable variables will
                # be created from that collection.
                proto.trainable = (key == ops.GraphKeys.TRAINABLE_VARIABLES)
              variable = from_proto(
                  proto, import_scope=scope_to_prepend_to_names)
              variable_objects[value] = variable
            graph.add_to_collection(key, variable)
        else:
          for value in col_def.bytes_list.value:
            proto = proto_type()
            proto.ParseFromString(value)
            graph.add_to_collection(
                key, from_proto(
                    proto, import_scope=scope_to_prepend_to_names))
      else:
        field = getattr(col_def, kind)
        if key in _COMPAT_COLLECTION_LIST:
          logging.warning(
              "The saved meta_graph is possibly from an older release:\n"
              "'%s' collection should be of type 'byte_list', but instead "
              "is of type '%s'.", key, kind)
        if kind == "node_list":
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, scope_to_prepend_to_names))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, scope_to_prepend_to_names))

  var_list = {}
  variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                   scope=scope_to_prepend_to_names)
  for v in variables:
    var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v

  return var_list, imported_return_elements
|
|
|
|
|
|
def export_scoped_meta_graph(filename=None,
                             graph_def=None,
                             graph=None,
                             export_scope=None,
                             as_text=False,
                             unbound_inputs_col_name="unbound_inputs",
                             clear_devices=False,
                             saver_def=None,
                             clear_extraneous_savers=False,
                             strip_default_attrs=False,
                             save_debug_info=False,
                             **kwargs):
  """Returns `MetaGraphDef` proto. Optionally writes it to filename.

  This function exports the graph, saver, and collection objects into
  `MetaGraphDef` protocol buffer with the intention of it being imported
  at a later time or location to restart training, run inference, or be
  a subgraph.

  Args:
    filename: Optional filename including the path for writing the
      generated `MetaGraphDef` protocol buffer.
    graph_def: `GraphDef` protocol buffer.
    graph: The `Graph` to export. If `None`, use the default graph.
    export_scope: Optional `string`. Name scope under which to extract
      the subgraph. The scope name will be stripped from the node definitions
      for easy import later into new name scopes. If `None`, the whole graph
      is exported.
    as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
    unbound_inputs_col_name: Optional `string`. If provided, a string collection
      with the given name will be added to the returned `MetaGraphDef`,
      containing the names of tensors that must be remapped when importing the
      `MetaGraphDef`. Note that this clears and refills the collection of that
      name in `graph` itself. This only happens when `export_scope`,
      `clear_extraneous_savers` or `clear_devices` is set.
    clear_devices: Boolean which controls whether to clear device information
      before exporting the graph.
    saver_def: `SaverDef` protocol buffer.
    clear_extraneous_savers: Remove any Saver-related information from the
      graph (both Save/Restore ops and SaverDefs) that are not associated
      with the provided SaverDef.
    strip_default_attrs: Set to true if default valued attributes must be
      removed while exporting the GraphDef.
    save_debug_info: If `True` and `filename` is given, also save the
      GraphDebugInfo to a separate file in the same directory as `filename`.
      The debug file's name is `filename` with its extension replaced by
      `.debug` (e.g. `model.meta` -> `model.debug`). Ignored when `filename`
      is not provided.
    **kwargs: Optional keyed arguments, including meta_info_def and
      collection_list.

  Returns:
    A tuple `(meta_graph_def, var_list)`. The first element is the
    `MetaGraphDef` proto. The second is a dictionary that maps variable names,
    with `export_scope` stripped, to the `Variables` in the exported name scope.

  Raises:
    ValueError: When the `GraphDef` is larger than 2GB.
    ValueError: When executing in Eager mode and either `graph_def` or `graph`
      is undefined.
  """
  if context.executing_eagerly() and not (graph_def is not None and
                                          graph is not None):
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "Eager Execution is enabled.")
  graph = graph or ops.get_default_graph()

  exclude_nodes = None
  unbound_inputs = []
  # The GraphDef is rebuilt only when nodes must be re-scoped, filtered, or
  # stripped of devices; otherwise it is passed through unchanged.
  if export_scope or clear_extraneous_savers or clear_devices:
    if graph_def:
      # Rebuild from the caller-provided GraphDef.
      new_graph_def = graph_pb2.GraphDef()
      new_graph_def.versions.CopyFrom(graph_def.versions)
      new_graph_def.library.CopyFrom(graph_def.library)

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph_def, saver_def)

      for node_def in graph_def.node:
        if _should_include_node(node_def.name, export_scope, exclude_nodes):
          new_node_def = _node_def(node_def, export_scope, unbound_inputs,
                                   clear_devices=clear_devices)
          new_graph_def.node.extend([new_node_def])
      graph_def = new_graph_def
    else:
      # Only do this complicated work if we want to remove a name scope.
      graph_def = graph_pb2.GraphDef()
      # pylint: disable=protected-access
      graph_def.versions.CopyFrom(graph.graph_def_versions)
      bytesize = 0

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph.as_graph_def(),
                                                     saver_def)

      # Walk the graph's operations in ascending node-id order.
      for key in sorted(graph._nodes_by_id):
        value = graph._nodes_by_id[key]
        # pylint: enable=protected-access
        if not _should_include_node(value.name, export_scope, exclude_nodes):
          continue
        node_def = _node_def(value.node_def, export_scope, unbound_inputs,
                             clear_devices=clear_devices)
        graph_def.node.extend([node_def])
        if value.outputs:
          # Attach the ops' static output shapes to the exported node copy.
          assert "_output_shapes" not in graph_def.node[-1].attr
          graph_def.node[-1].attr["_output_shapes"].list.shape.extend([
              output.get_shape().as_proto() for output in value.outputs])
        bytesize += value.node_def.ByteSize()
        if bytesize >= (1 << 31) or bytesize < 0:
          raise ValueError(
              "GraphDef cannot be larger than 2GB. "
              f"Received size: {bytesize}.")

      graph._copy_functions_to_graph_def(graph_def, bytesize)  # pylint: disable=protected-access

    # It's possible that not all the inputs are in the export_scope.
    # If we would like such information included in the exported meta_graph,
    # add them to a special unbound_inputs collection.
    if unbound_inputs_col_name:
      # Clears the unbound_inputs collections.
      graph.clear_collection(unbound_inputs_col_name)
      for k in unbound_inputs:
        graph.add_to_collection(unbound_inputs_col_name, k)

  var_list = {}
  variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                   scope=export_scope)
  for v in variables:
    if _should_include_node(v, export_scope, exclude_nodes):
      var_list[ops.strip_name_scope(v.name, export_scope)] = v

  scoped_meta_graph_def = create_meta_graph_def(
      graph_def=graph_def,
      graph=graph,
      export_scope=export_scope,
      exclude_nodes=exclude_nodes,
      clear_extraneous_savers=clear_extraneous_savers,
      saver_def=saver_def,
      strip_default_attrs=strip_default_attrs,
      **kwargs)

  if filename:
    graph_io.write_graph(
        scoped_meta_graph_def,
        os.path.dirname(filename),
        os.path.basename(filename),
        as_text=as_text)
    if save_debug_info:
      # The debug file replaces the extension of `filename` with ".debug".
      name, _ = os.path.splitext(filename)
      debug_filename = f"{name}.debug"

      # Look up, in `graph`, the operation behind every node of the exported
      # GraphDef by re-applying `export_scope` to its name.
      # NOTE(review): an earlier comment claimed variable nodes are excluded
      # here, but no filtering happens in this lookup. It also assumes every
      # exported node exists in `graph`, which may not hold when an external
      # `graph_def` was passed. Confirm.
      # TODO(liufengdb): fix this for functions.
      ops_to_export = [
          ("", graph.get_operation_by_name(
              ops.prepend_name_scope(node.name, export_scope)))
          for node in scoped_meta_graph_def.graph_def.node
      ]

      graph_debug_info = error_interpolation.create_graph_debug_info_def(
          ops_to_export)

      graph_io.write_graph(
          graph_debug_info,
          os.path.dirname(debug_filename),
          os.path.basename(debug_filename),
          as_text=as_text)

  return scoped_meta_graph_def, var_list
|
|
|
|
|
|
def copy_scoped_meta_graph(from_scope, to_scope,
                           from_graph=None, to_graph=None):
  """Copies a sub-meta_graph from one scope to another.

  The subgraph under `from_scope` is exported with `export_scoped_meta_graph`
  and then re-imported under `to_scope` with `import_scoped_meta_graph`.
  When `from_scope` is non-empty, the export step also resets the
  "unbound_inputs" collection of `from_graph`.

  Args:
    from_scope: `String` name scope containing the subgraph to be copied.
    to_scope: `String` name scope under which the copied subgraph will reside.
    from_graph: Optional `Graph` from which to copy the subgraph. If `None`, the
      default graph is used.
    to_graph: Optional `Graph` to which to copy the subgraph. If `None`, the
      default graph is used.

  Returns:
    A dictionary of `Variables` that has been copied into `to_scope`.

  Raises:
    ValueError: If `from_scope` and `to_scope` are the same while
      `from_graph` and `to_graph` are also the same.
    ValueError: Propagated from `export_scoped_meta_graph` when Eager
      Execution is enabled.
  """
  from_graph = from_graph or ops.get_default_graph()
  to_graph = to_graph or ops.get_default_graph()

  if from_graph == to_graph and from_scope == to_scope:
    raise ValueError("'from_scope' and 'to_scope' need to be different "
                     "when performing copy in the same graph. "
                     f"Received: 'from_graph': {from_graph}, "
                     f"'to_graph': {to_graph}, "
                     f"'from_scope': {from_scope}, 'to_scope': {to_scope}.")

  # Only the MetaGraphDef is needed from the export. The variables it reports
  # belong to `from_scope`; the result comes from the import into `to_scope`.
  orig_meta_graph, _ = export_scoped_meta_graph(
      export_scope=from_scope, graph=from_graph)
  return import_scoped_meta_graph(orig_meta_graph,
                                  graph=to_graph,
                                  import_scope=to_scope)
|