3RNN/Lib/site-packages/tensorflow/python/tools/strip_unused.py
2024-05-26 19:49:15 +02:00

105 lines
3.6 KiB
Python

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Removes unneeded nodes from a GraphDef file.
This script is designed to help streamline models, by taking the input and
output nodes that will be used by an application and figuring out the smallest
set of operations that are required to run for those arguments. The resulting
minimal graph is then saved out.
The advantages of running this script are:
- You may be able to shrink the file size.
- Operations that are unsupported on your platform but still present can be
safely removed.
The resulting graph may not be as flexible as the original though, since any
input nodes that weren't explicitly mentioned may not be accessible any more.
An example of command-line usage is:
bazel build tensorflow/python/tools:strip_unused && \
bazel-bin/tensorflow/python/tools/strip_unused \
--input_graph=some_graph_def.pb \
--output_graph=/tmp/stripped_graph.pb \
--input_node_names=input0 \
--output_node_names=softmax
You can also look at strip_unused_test.py for an example of how to use it.
"""
import argparse
import sys
from absl import app
from tensorflow.python.framework import dtypes
from tensorflow.python.tools import strip_unused_lib
# Parsed command-line flags. Stays None on import; the __main__ block assigns
# the argparse namespace here before app.run() invokes main().
FLAGS = None
def main(unused_args):
    """Strip unused nodes from a GraphDef file using the parsed flags.

    Reads every setting from the module-level ``FLAGS`` namespace and forwards
    them, in order, to ``strip_unused_lib.strip_unused_from_files``.

    Args:
        unused_args: Argument list handed over by ``app.run``; not used.
    """
    flags = FLAGS
    strip_unused_lib.strip_unused_from_files(
        flags.input_graph,
        flags.input_binary,
        flags.output_graph,
        flags.output_binary,
        flags.input_node_names,
        flags.output_node_names,
        flags.placeholder_type_enum,
    )
if __name__ == '__main__':
    # Keyword sets shared by several flags, so each flag below only spells out
    # what is unique to it.
    text_flag = {'type': str, 'default': ''}
    # nargs='?' with const=True lets a bare "--flag" (no value) mean True.
    switch_flag = {'nargs': '?', 'const': True, 'type': 'bool'}

    parser = argparse.ArgumentParser()
    # Custom 'bool' type: only the string "true" (any letter case) yields True.
    parser.register('type', 'bool', lambda v: v.lower() == 'true')

    # Flags are registered in the same order as before so --help is unchanged.
    parser.add_argument(
        '--input_graph',
        help='TensorFlow \'GraphDef\' file to load.',
        **text_flag)
    parser.add_argument(
        '--input_binary',
        default=False,
        help='Whether the input files are in binary format.',
        **switch_flag)
    parser.add_argument(
        '--output_graph',
        help='Output \'GraphDef\' file name.',
        **text_flag)
    parser.add_argument(
        '--output_binary',
        default=True,
        help='Whether to write a binary format graph.',
        **switch_flag)
    parser.add_argument(
        '--input_node_names',
        help='The name of the input nodes, comma separated.',
        **text_flag)
    parser.add_argument(
        '--output_node_names',
        help='The name of the output nodes, comma separated.',
        **text_flag)
    parser.add_argument(
        '--placeholder_type_enum',
        type=int,
        default=dtypes.float32.as_datatype_enum,
        help='The AttrValue enum to use for placeholders.')

    # Arguments argparse does not recognise are passed on to app.run,
    # prefixed with the program name.
    FLAGS, unparsed = parser.parse_known_args()
    app.run(main=main, argv=[sys.argv[0], *unparsed])