# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions used by multiple converter files."""

import copy
import datetime
import sys

from absl import logging
import flatbuffers

from tensorflow.core.protobuf import config_pb2 as _config_pb2
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
from tensorflow.lite.python import conversion_metadata_schema_py_generated as conversion_metadata_fb
from tensorflow.lite.python import schema_py_generated as schema_fb
from tensorflow.lite.python import schema_util
from tensorflow.lite.python import tflite_keras_util as _tflite_keras_util
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs
from tensorflow.lite.python.op_hint import find_all_hinted_output_nodes
from tensorflow.lite.tools import flatbuffer_utils
from tensorflow.python.eager import function
from tensorflow.python.framework import convert_to_constants as _convert_to_constants
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import error_interpolation as _error_interpolation
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.training.saver import export_meta_graph as _export_meta_graph

# The field name of conversion metadata in the flatbuffer file.
CONVERSION_METADATA_FIELD_NAME = "CONVERSION_METADATA"

# Keras functions used by TFLite
model_input_signature = _tflite_keras_util.model_input_signature
trace_model_call = _tflite_keras_util.trace_model_call

# Jax functions used by TFLite
# pylint: disable=g-import-not-at-top
# pylint: disable=unused-import
try:
  from jax import xla_computation as _xla_computation
except ImportError:
  _xla_computation = None
# pylint: enable=g-import-not-at-top
# pylint: enable=unused-import

# Defined as per TFLite schema
_MAP_TFLITE_ENUM_TO_TF_TYPES = {
    0: dtypes.float32,
    1: dtypes.float16,
    2: dtypes.int32,
    3: dtypes.uint8,
    4: dtypes.int64,
    5: dtypes.string,
    6: dtypes.bool,
    7: dtypes.int16,
    8: dtypes.complex64,
    9: dtypes.int8,
    10: dtypes.float64,
    11: dtypes.complex128,
    16: dtypes.uint32,
}

_TFLITE_FILE_IDENTIFIER = b"TFL3"

_MAP_QUANT_TO_IO_TYPES = {
    dtypes.int8: {dtypes.int8, dtypes.uint8},
    dtypes.int16: {dtypes.int16},
}


def _convert_tflite_enum_type_to_tf_type(tflite_enum_type):
  """Converts tflite enum type (eg: 0) to tf type (eg: tf.float32).

  Args:
    tflite_enum_type: tflite enum type (eg: 0, that corresponds to float32)

  Raises:
    ValueError: If an invalid tflite enum type is provided.

  Returns:
    tf type (eg: tf.float32)
  """
  tf_type = _MAP_TFLITE_ENUM_TO_TF_TYPES.get(tflite_enum_type)
  if tf_type is None:
    raise ValueError(
        "Unsupported enum {}. The valid map of enum to tf types is : {}"
        .format(tflite_enum_type, _MAP_TFLITE_ENUM_TO_TF_TYPES))
  return tf_type


def get_tf_type_name(tf_type):
  """Converts tf.dtype (eg: tf.float32) to str (eg: "tf.float32")."""
  return "tf." + tf_type.name if tf_type else None


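# A minimal, illustrative sketch of the two helpers above (values follow
# directly from _MAP_TFLITE_ENUM_TO_TF_TYPES; not part of the module API):
#
#   _convert_tflite_enum_type_to_tf_type(0)   # -> tf.float32
#   _convert_tflite_enum_type_to_tf_type(9)   # -> tf.int8
#   get_tf_type_name(dtypes.float32)          # -> "tf.float32"
#   get_tf_type_name(None)                    # -> None

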
def get_tensor_name(tensor):
  """Returns name of the input tensor.

  Args:
    tensor: tf.Tensor

  Returns:
    str
  """
  parts = tensor.name.split(":")
  if len(parts) > 2:
    raise ValueError("Tensor name invalid. Expect 0 or 1 colon, got {0}".format(
        len(parts) - 1))

  # To be consistent with the tensor naming scheme in tensorflow, we need to
  # drop the ':0' suffix for the first tensor.
  if len(parts) > 1 and parts[1] != "0":
    return tensor.name
  return parts[0]


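# Illustrative behaviour (assuming `t` is a tf.Tensor):
#   t.name == "input:0"  ->  get_tensor_name(t) == "input"
#   t.name == "split:1"  ->  get_tensor_name(t) == "split:1"

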
def get_tensors_from_tensor_names(graph, tensor_names):
  """Gets the Tensors associated with the `tensor_names` in the provided graph.

  Args:
    graph: TensorFlow Graph.
    tensor_names: List of strings that represent names of tensors in the graph.

  Returns:
    A list of Tensor objects in the same order the names are provided.

  Raises:
    ValueError:
      tensor_names contains an invalid tensor name.
  """
  # Get the list of all of the tensors.
  tensor_name_to_tensor = {}
  for op in graph.get_operations():
    for tensor in op.values():
      tensor_name_to_tensor[get_tensor_name(tensor)] = tensor

  # Get the tensors associated with tensor_names.
  tensors = []
  invalid_tensors = []
  for name in tensor_names:
    if not isinstance(name, str):
      raise ValueError("Invalid type for a tensor name in the provided graph. "
                       "Expected type for a tensor name is 'str', instead got "
                       "type '{}' for tensor name '{}'".format(
                           type(name), name))

    tensor = tensor_name_to_tensor.get(name)
    if tensor is None:
      invalid_tensors.append(name)
    else:
      tensors.append(tensor)

  # Throw ValueError if any user input names are not valid tensors.
  if invalid_tensors:
    raise ValueError("Invalid tensors '{}' were found.".format(
        ",".join(invalid_tensors)))
  return tensors


def set_tensor_shapes(tensors, shapes):
  """Sets Tensor shape for each tensor if the shape is defined.

  Args:
    tensors: TensorFlow tensor.Tensor.
    shapes: Dict of strings representing input tensor names to list of
      integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).

  Raises:
    ValueError:
      `shapes` contains an invalid tensor.
      `shapes` contains an invalid shape for a valid tensor.
  """
  if shapes:
    tensor_names_to_tensor = {
        get_tensor_name(tensor): tensor for tensor in tensors
    }
    for name, shape in shapes.items():
      if name not in tensor_names_to_tensor:
        raise ValueError("Invalid tensor \'{}\' found in tensor shapes "
                         "map.".format(name))
      if shape is not None:
        tensor = tensor_names_to_tensor[name]
        try:
          tensor.set_shape(shape)
        except ValueError as error:
          message = ("The shape of tensor '{0}' cannot be changed from {1} to "
                     "{2}. {3}".format(name, tensor.shape, shape, str(error)))
          raise ValueError(message)


def get_grappler_config(optimizers_list):
  """Creates a tf.compat.v1.ConfigProto for configuring Grappler.

  Args:
    optimizers_list: List of strings that represents the list of optimizers.

  Returns:
    tf.ConfigProto.
  """
  config = _config_pb2.ConfigProto()
  rewrite_options = config.graph_options.rewrite_options
  for optimizer in optimizers_list:
    rewrite_options.optimizers.append(optimizer)
  return config


def run_graph_optimizations(graph_def,
                            input_arrays,
                            output_arrays,
                            config,
                            graph=None):
  """Apply standard TensorFlow optimizations to the graph_def.

  Args:
    graph_def: Frozen GraphDef to be optimized.
    input_arrays: List of arrays that are considered inputs of the graph.
    output_arrays: List of arrays that are considered outputs of the graph.
    config: tf.ConfigProto.
    graph: TensorFlow Graph. Required when Eager mode is enabled. (default None)

  Returns:
    A new, optimized GraphDef.
  """
  meta_graph = _export_meta_graph(graph_def=graph_def, graph=graph)

  signature = _meta_graph_pb2.SignatureDef()
  for array in input_arrays:
    signature.inputs[array.name].name = array.name
    signature.inputs[array.name].dtype = array.dtype.as_datatype_enum
    signature.inputs[array.name].tensor_shape.CopyFrom(array.shape.as_proto())

  for array in output_arrays:
    signature.outputs[array.name].name = array.name
    signature.outputs[array.name].dtype = array.dtype.as_datatype_enum
    signature.outputs[array.name].tensor_shape.CopyFrom(array.shape.as_proto())

  meta_graph.signature_def["not_used_key"].CopyFrom(signature)

  # We need to add a collection called 'train_op' so that grappler
  # knows what the outputs are.
  fetch_collection = _meta_graph_pb2.CollectionDef()
  for array in input_arrays + output_arrays:
    fetch_collection.node_list.value.append(array.name)
  meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

  return tf_optimizer.OptimizeGraph(config, meta_graph)


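# Illustrative sketch: build a Grappler config that runs function inlining and
# constant folding, then apply it to a frozen graph (`graph_def`,
# `input_tensors`, `output_tensors`, and `graph` are assumed to exist):
#
#   config = get_grappler_config(["function", "constfold"])
#   optimized_graph_def = run_graph_optimizations(
#       graph_def, input_tensors, output_tensors, config, graph=graph)

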
def _convert_op_hints_if_present(sess, graph_def, output_tensors,
                                 hinted_outputs_nodes):
  if is_frozen_graph(sess):
    raise ValueError("Try to convert op hints, needs unfrozen graph.")
  output_arrays = [get_tensor_name(tensor) for tensor in output_tensors]
  graph_def = _convert_to_constants.convert_variables_to_constants(
      sess, graph_def, output_arrays + hinted_outputs_nodes)
  graph_def = convert_op_hints_to_stubs(graph_def=graph_def)
  return graph_def


def freeze_graph(sess, input_tensors, output_tensors):
  """Returns a frozen GraphDef.

  Runs a Grappler pass and freezes a graph with Variables in it. Otherwise the
  existing GraphDef is returned. The Grappler pass is only run on models that
  are frozen in order to inline the functions in the graph.
  If OpHints is present, it will try to convert the OpHint graph.

  Args:
    sess: TensorFlow Session.
    input_tensors: List of input tensors.
    output_tensors: List of output tensors (only .name is used from this).

  Returns:
    Frozen GraphDef.
  """
  # Runs a Grappler pass in order to inline any functions in the graph.
  # Aside from inlining any simple function, Grappler will also try to lower
  # while loop into switch merge representation which is undesired for Ophints,
  # so we simply remove those attributes to prevent Grappler from doing so.
  graph_def = _convert_to_constants.disable_lower_using_switch_merge(
      sess.graph_def)
  config = get_grappler_config(["function"])
  graph_def = run_graph_optimizations(
      graph_def, input_tensors, output_tensors, config, graph=sess.graph)

  # If ophints are present, just convert them.
  hinted_outputs_nodes = find_all_hinted_output_nodes(sess)
  if hinted_outputs_nodes:
    return _convert_op_hints_if_present(sess, graph_def, output_tensors,
                                        hinted_outputs_nodes)

  if not is_frozen_graph(sess):
    output_node_names = [tensor.name.split(":")[0] for tensor in output_tensors]
    return _convert_to_constants.convert_variables_to_constants(
        sess, graph_def, output_node_names
    )
  else:
    return sess.graph_def


def is_frozen_graph(sess):
  """Determines if the graph is frozen.

  Determines if a graph has previously been frozen by checking for any
  operations of type Variable*. If variables are found, the graph is not
  frozen.

  Args:
    sess: TensorFlow Session.

  Returns:
    Bool.
  """
  for op in sess.graph.get_operations():
    if op.type.startswith("Variable") or op.type.endswith("VariableOp"):
      return False
  return True


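# Illustrative use with a TF1-style session (hypothetical names; `inp` and
# `out` stand for tensors that live in the session's graph):
#
#   with tf.compat.v1.Session() as sess:
#     sess.run(tf.compat.v1.global_variables_initializer())
#     frozen_graph_def = freeze_graph(sess, [inp], [out])

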
def build_debug_info_func(original_graph):
  """Returns a method to retrieve the `GraphDebugInfo` from the original graph.

  Args:
    original_graph: The original `Graph` containing all the op stack traces.

  Returns:
    A function which retrieves the stack traces from the original graph and
    converts them to a `GraphDebugInfo` for a given set of nodes.
  """

  def f(original_nodes):
    """Function to create `GraphDebugInfo` for the given `original_nodes`."""
    if not original_graph:
      return None
    # For the given nodes, gets all the op definitions in the original graph.
    useful_ops = []
    for func, name in original_nodes:
      try:
        if not func:
          useful_ops.append((func, original_graph.get_operation_by_name(name)))
        else:
          sub_func = original_graph._get_function(func)  # pylint: disable=protected-access
          if isinstance(sub_func, function.AtomicFunction):  # pylint: disable=protected-access
            useful_ops.append(
                (func, sub_func.graph.get_operation_by_name(name)))
          else:
            sys.stderr.write(
                "Use '@tf.function' or '@defun' to decorate the function.\n")
            continue
      except KeyError:
        # New node created by graph optimizer. No stack trace from source code.
        continue
    # Convert all the op definitions to stack traces in terms of GraphDebugInfo.
    return _error_interpolation.create_graph_debug_info_def(useful_ops)

  return f


def convert_debug_info_func(saved_debug_info):
  """Returns a method to retrieve the `GraphDebugInfo` from the original graph.

  Args:
    saved_debug_info: The `GraphDebugInfo` containing all the debug info.

  Returns:
    A function which retrieves the stack traces from the original graph and
    converts them to a `GraphDebugInfo` for a given set of nodes.
  """

  def f(original_nodes):
    """Function to create `GraphDebugInfo` for the given `original_nodes`."""
    del original_nodes
    return saved_debug_info

  return f


def get_debug_info(nodes_to_debug_info_func, converted_graph):
  """Returns the debug info for the original nodes in the `converted_graph`.

  Args:
    nodes_to_debug_info_func: The method to collect the op debug info for the
      nodes.
    converted_graph: A `GraphDef` after optimization and transformation.

  Returns:
    `GraphDebugInfo` for all the original nodes in `converted_graph`.
  """
  if not nodes_to_debug_info_func:
    return None

  # Collect all the debug info nodes from the converted_graph
  original_nodes = set()
  for node in converted_graph.node:
    debug_nodes = node.experimental_debug_info.original_node_names
    debug_funcs = node.experimental_debug_info.original_func_names
    # If the `original_node_names` are empty, uses the node name directly.
    if not debug_nodes:
      original_nodes.add(("", node.name))
    else:
      for i in range(len(debug_nodes)):
        debug_func = "" if i >= len(debug_funcs) else debug_funcs[i]
        original_nodes.add((debug_func, debug_nodes[i]))

  # Convert the nodes to the debug info proto object.
  return nodes_to_debug_info_func(original_nodes)


def convert_bytes_to_c_source(data,
                              array_name,
                              max_line_width=80,
                              include_guard=None,
                              include_path=None,
                              use_tensorflow_license=False):
  """Returns strings representing a C constant array containing `data`.

  Args:
    data: Byte array that will be converted into a C constant.
    array_name: String to use as the variable name for the constant array.
    max_line_width: The longest line length, for formatting purposes.
    include_guard: Name to use for the include guard macro definition.
    include_path: Optional path to include in the source file.
    use_tensorflow_license: Whether to include the standard TensorFlow Apache2
      license in the generated files.

  Returns:
    Text that can be compiled as a C source file to link in the data as a
    literal array of values.
    Text that can be used as a C header file to reference the literal array.
  """

  starting_pad = " "
  array_lines = []
  array_line = starting_pad
  for value in bytearray(data):
    if (len(array_line) + 4) > max_line_width:
      array_lines.append(array_line + "\n")
      array_line = starting_pad
    array_line += " 0x%02x," % (value,)
  if len(array_line) > len(starting_pad):
    array_lines.append(array_line + "\n")
  array_values = "".join(array_lines)

  if include_guard is None:
    include_guard = "TENSORFLOW_LITE_UTIL_" + array_name.upper() + "_DATA_H_"

  if include_path is not None:
    include_line = "#include \"{include_path}\"\n".format(
        include_path=include_path)
  else:
    include_line = ""

  if use_tensorflow_license:
    license_text = """
/* Copyright {year} The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
""".format(year=datetime.date.today().year)
  else:
    license_text = ""

  source_template = """{license_text}
// This is a TensorFlow Lite model file that has been converted into a C data
// array using the tensorflow.lite.util.convert_bytes_to_c_source() function.
// This form is useful for compiling into a binary for devices that don't have a
// file system.

{include_line}
// We need to keep the data array aligned on some architectures.
#ifdef __has_attribute
#define HAVE_ATTRIBUTE(x) __has_attribute(x)
#else
#define HAVE_ATTRIBUTE(x) 0
#endif
#if HAVE_ATTRIBUTE(aligned) || (defined(__GNUC__) && !defined(__clang__))
#define DATA_ALIGN_ATTRIBUTE __attribute__((aligned(4)))
#else
#define DATA_ALIGN_ATTRIBUTE
#endif

const unsigned char {array_name}[] DATA_ALIGN_ATTRIBUTE = {{
{array_values}}};
const int {array_name}_len = {array_length};
"""

  source_text = source_template.format(
      array_name=array_name,
      array_length=len(data),
      array_values=array_values,
      license_text=license_text,
      include_line=include_line)

  header_template = """
{license_text}

// This is a TensorFlow Lite model file that has been converted into a C data
// array using the tensorflow.lite.util.convert_bytes_to_c_source() function.
// This form is useful for compiling into a binary for devices that don't have a
// file system.

#ifndef {include_guard}
#define {include_guard}

extern const unsigned char {array_name}[];
extern const int {array_name}_len;

#endif  // {include_guard}
"""

  header_text = header_template.format(
      array_name=array_name,
      include_guard=include_guard,
      license_text=license_text)

  return source_text, header_text


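# Illustrative use (hypothetical file names): embed a converted model as a C
# array for devices without a file system.
#
#   source, header = convert_bytes_to_c_source(
#       tflite_model, "g_model", include_path="model_data.h")
#   with open("model_data.cc", "w") as f:
#     f.write(source)
#   with open("model_data.h", "w") as f:
#     f.write(header)

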
def _convert_model_from_bytearray_to_object(model_bytearray):
  """Converts a tflite model from a bytearray into a parsable object."""
  model_object = schema_fb.Model.GetRootAsModel(model_bytearray, 0)
  model_object = schema_fb.ModelT.InitFromObj(model_object)
  model_object = copy.deepcopy(model_object)
  return model_object


def _convert_model_from_object_to_bytearray(model_object):
  """Converts a tflite model from a parsable object into a bytearray."""
  # Initial size of the buffer, which will grow automatically if needed
  builder = flatbuffers.Builder(1024)
  model_offset = model_object.Pack(builder)
  builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
  return bytes(builder.Output())


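# Sketch of the round trip these two helpers provide (serialized flatbuffer
# bytes -> mutable schema_fb.ModelT object -> bytes again):
#
#   model_object = _convert_model_from_bytearray_to_object(tflite_model)
#   # ... mutate model_object.subgraphs, buffers, metadata, etc. ...
#   new_tflite_model = _convert_model_from_object_to_bytearray(model_object)

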
def get_quantize_opcode_idx(model):
  """Returns the quantize op idx."""
  quant_opcode_idxs = []
  for idx, opcode in enumerate(model.operatorCodes):
    builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
    if builtin_code == schema_fb.BuiltinOperator.QUANTIZE:
      quant_opcode_idxs.append(idx)
  return quant_opcode_idxs


def get_dequantize_opcode_idx(model):
  """Returns the dequantize op idx."""
  quant_opcode_idxs = []
  for idx, opcode in enumerate(model.operatorCodes):
    builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
    if builtin_code == schema_fb.BuiltinOperator.DEQUANTIZE:
      quant_opcode_idxs.append(idx)
  return quant_opcode_idxs


def _update_signature_def_tensors(tensor_maps, map_old_to_new_tensors):
  """Update the tensors in the SignatureDef's TensorMaps."""
  for i in range(len(tensor_maps)):
    if tensor_maps[i].tensorIndex in map_old_to_new_tensors:
      tensor_maps[i].tensorIndex = (
          map_old_to_new_tensors[tensor_maps[i].tensorIndex])


def _remove_tensors_from_model(model, remove_tensors_idxs):
  """Remove tensors from model."""
  if not remove_tensors_idxs:
    return
  if len(model.subgraphs) > 1:
    logging.info("Skipping the removal of dangled tensors since the model has "
                 "multiple subgraphs and tensors can be used in the different "
                 "subgraph(s)")
    return
  subgraph = model.subgraphs[0]
  tensors = subgraph.tensors
  operators = subgraph.operators

  logging.debug("Removing tensors at indices : %s", remove_tensors_idxs)
  # An optimized check to validate if "remove_tensors_idxs" (eg: [4,5,6]) is an
  # exact subset, with ordering, of "tensors" indices (eg: [0,1,2,3,4,5,6]).
  if min(remove_tensors_idxs) == len(tensors) - len(remove_tensors_idxs):
    logging.debug("Removing tensors only at the end of the tensor list")
    del tensors[min(remove_tensors_idxs):]
  else:
    logging.debug("Removing tensors requires updating the model")
    # Map the old tensor indices to new tensor indices
    d_old_to_new_tensors = {}
    left_shift_by = 0
    for idx in range(len(tensors)):
      if idx in remove_tensors_idxs:
        left_shift_by += 1
      else:
        d_old_to_new_tensors[idx] = idx - left_shift_by
    logging.debug("Old to new tensors map: %s", d_old_to_new_tensors.__str__())
    # Update tensor indices referenced throughout the model
    def update_tensors(tensor_idxs):
      for i, ti in enumerate(tensor_idxs):
        tensor_idxs[i] = d_old_to_new_tensors.get(ti, -1)
    update_tensors(subgraph.inputs)
    update_tensors(subgraph.outputs)
    for op in operators:
      update_tensors(op.inputs)
      update_tensors(op.outputs)
    if model.signatureDefs:
      signature_def = model.signatureDefs[0]
      _update_signature_def_tensors(signature_def.inputs, d_old_to_new_tensors)
      _update_signature_def_tensors(signature_def.outputs, d_old_to_new_tensors)
    # Delete the tensors
    for idx in sorted(remove_tensors_idxs, reverse=True):
      tensors.pop(idx)
    logging.debug("Removed tensors marked for deletion")


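# Illustrative index remapping performed above: with four tensors [0, 1, 2, 3]
# and remove_tensors_idxs == {1}, the old-to-new map becomes {0: 0, 2: 1, 3: 2}
# and every tensor reference in inputs/outputs/operators is rewritten with it.

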
def _modify_model_input_type(model, inference_input_type=dtypes.float32):
  """Modify model input type."""
  if inference_input_type == dtypes.float32:
    return

  if not model.signatureDefs:
    _modify_model_input_type_per_subgraph(model, 0, -1, inference_input_type)
    return

  for signature_index, signature_def in enumerate(model.signatureDefs):
    _modify_model_input_type_per_subgraph(model, signature_def.subgraphIndex,
                                          signature_index, inference_input_type)


def _modify_model_input_type_per_subgraph(model, subgraph_index,
                                          signature_index,
                                          inference_input_type):
  """Modify model input type per subgraph."""
  subgraph = model.subgraphs[subgraph_index]
  tensors = subgraph.tensors
  operators = subgraph.operators

  # Find all quantize operators
  quant_opcode_idxs = get_quantize_opcode_idx(model)
  if operators and not quant_opcode_idxs:
    for input_idx in subgraph.inputs:
      input_type = _convert_tflite_enum_type_to_tf_type(tensors[input_idx].type)
      if input_type == dtypes.float32:
        raise ValueError("Model input is not quantized.")
    # None of the inputs is float32, so they must already be int16, int8, or
    # bool.
    return

  # Validate that the model input is quantized
  input_quant_ops = []
  for op in operators:
    # Find operators that quantize model input
    if op.opcodeIndex in quant_opcode_idxs and op.inputs[0] in subgraph.inputs:
      float_tensor, quant_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
      # If found, validate that the operator's input type is float
      float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type)
      if float_type != dtypes.float32:
        if float_type == inference_input_type:
          continue
        else:
          raise ValueError(
              "Initial model input type must be tf.float32. Expected type for "
              "tensor with name '{}' is tf.float32, instead type is {}".format(
                  float_tensor.name, get_tf_type_name(float_type)))
      # If found, validate that the operator output is quantized and compatible
      # with the final model input type
      quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
      if quant_type not in _MAP_QUANT_TO_IO_TYPES:
        raise ValueError(
            "Initial model input is not quantized. Expected type for "
            "tensor with name '{}' should be in {}, instead type is {}".format(
                quant_tensor.name,
                tuple(get_tf_type_name(t) for t in
                      _MAP_QUANT_TO_IO_TYPES.keys()),
                get_tf_type_name(quant_type)))
      else:
        inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
        if inference_input_type not in inference_io_types:
          raise ValueError(
              "Unsupported `inference_input_type` value. Expected to be in "
              "{}, instead got {}.".format(
                  tuple(get_tf_type_name(t) for t in inference_io_types),
                  get_tf_type_name(inference_input_type)))
      input_quant_ops.append(op)

  if len(subgraph.inputs) != len(input_quant_ops):
    logging.warning(
        "For model inputs containing unsupported operations which cannot be "
        "quantized, the `inference_input_type` attribute will default to the "
        "original type."
    )

  # Modify model input type
  if inference_input_type == dtypes.uint8:
    # Change quant op (float to int8) to quant op (uint8 to int8)
    for op in input_quant_ops:
      int8_quantization = tensors[op.outputs[0]].quantization
      uint8_quantization = schema_fb.QuantizationParametersT()
      uint8_quantization.scale = [int8_quantization.scale[0]]
      uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
      tensors[op.inputs[0]].quantization = uint8_quantization
      tensors[op.inputs[0]].type = schema_fb.TensorType.UINT8
  elif inference_input_type in _MAP_QUANT_TO_IO_TYPES:
    # Remove the inputs and the quant operator
    remove_tensors_idxs = set()
    for op in input_quant_ops:
      subgraph.inputs[subgraph.inputs == op.inputs[0]] = op.outputs[0]
      if signature_index >= 0:
        signature_def = model.signatureDefs[signature_index]
        for i in range(len(signature_def.inputs)):
          if signature_def.inputs[i].tensorIndex == op.inputs[0]:
            signature_def.inputs[i].tensorIndex = op.outputs[0]
      remove_tensors_idxs.add(op.inputs[0])
      operators.remove(op)
    # Remove tensors marked for deletion.
    _remove_tensors_from_model(model, remove_tensors_idxs)
  else:
    raise ValueError(
        "Unsupported `inference_input_type` value {}.".format(
            get_tf_type_name(inference_input_type)))


def _modify_model_output_type(model, inference_output_type=dtypes.float32):
  """Modify model output type."""
  if inference_output_type == dtypes.float32:
    return

  if not model.signatureDefs:
    _modify_model_output_type_per_subgraph(model, 0, -1, inference_output_type)
    return

  for signature_index, signature_def in enumerate(model.signatureDefs):
    _modify_model_output_type_per_subgraph(model, signature_def.subgraphIndex,
                                           signature_index,
                                           inference_output_type)


def _modify_model_output_type_per_subgraph(model, subgraph_index,
                                           signature_index,
                                           inference_output_type):
  """Modify model output type per subgraph."""
  subgraph = model.subgraphs[subgraph_index]
  tensors = subgraph.tensors
  operators = subgraph.operators

  # Find all dequantize operators
  dequant_opcode_idxs = get_dequantize_opcode_idx(model)
  if operators and not dequant_opcode_idxs:
    for output in subgraph.outputs:
      output_type = _convert_tflite_enum_type_to_tf_type(tensors[output].type)
      if output_type == dtypes.float32:
        raise ValueError("Model output is not dequantized.")
    # None of the outputs is float32, so they must already be int16, int8, or
    # bool.
    return

  # Validate that the model output is dequantized
  output_dequant_ops = []
  for op in operators:
    # Find operators that dequantize model output
    if (op.opcodeIndex in dequant_opcode_idxs and
        op.outputs[0] in subgraph.outputs):
      # If found, validate that the operator's output type is float
      quant_tensor, float_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
      float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type)
      if float_type != dtypes.float32:
        if float_type == inference_output_type:
          continue
        else:
          raise ValueError(
              "Initial model output type must be tf.float32. Expected type for "
              "tensor with name '{}' is tf.float32, instead type is {}".format(
                  float_tensor.name, get_tf_type_name(float_type)))
      # If found, validate that the operator input is quantized and compatible
      # with the final model output type
      quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
      if quant_type not in _MAP_QUANT_TO_IO_TYPES:
        raise ValueError(
            "Initial model output is not dequantized. Expected type for "
            "tensor with name '{}' should be in {}, instead type is {}".format(
                quant_tensor.name,
                tuple(get_tf_type_name(t) for t in
                      _MAP_QUANT_TO_IO_TYPES.keys()),
                get_tf_type_name(quant_type)))
      else:
        inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
        if inference_output_type not in inference_io_types:
          raise ValueError(
              "Unsupported `inference_output_type` value. Expected to be in "
              "{}, instead got {}.".format(
                  tuple(get_tf_type_name(t) for t in inference_io_types),
                  get_tf_type_name(inference_output_type)))
      output_dequant_ops.append(op)

  if len(subgraph.outputs) != len(output_dequant_ops):
    logging.warning(
        "For model outputs containing unsupported operations which cannot be "
        "quantized, the `inference_output_type` attribute will default to the "
        "original type."
    )

  # Modify model output type
  if inference_output_type == dtypes.uint8:
    # Find a quantize operator
    quant_opcode_idx = -1
    for idx, opcode in enumerate(model.operatorCodes):
      builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
      if builtin_code == schema_fb.BuiltinOperator.QUANTIZE:
        quant_opcode_idx = idx
        break
    # Create a quantize operator, if none exist
    if quant_opcode_idx == -1:
      quant_op = schema_fb.OperatorCodeT()
      quant_op.builtinCode = schema_fb.BuiltinOperator.QUANTIZE
      quant_op.deprecatedBuiltinCode = schema_fb.BuiltinOperator.QUANTIZE
      model.operatorCodes.append(quant_op)
      quant_opcode_idx = len(model.operatorCodes) - 1
    # Change dequant op (int8 to float) to quant op (int8 to uint8)
    for op in output_dequant_ops:
      op.opcodeIndex = quant_opcode_idx
      int8_quantization = tensors[op.inputs[0]].quantization
      uint8_quantization = schema_fb.QuantizationParametersT()
      uint8_quantization.scale = [int8_quantization.scale[0]]
      uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
      tensors[op.outputs[0]].quantization = uint8_quantization
      tensors[op.outputs[0]].type = schema_fb.TensorType.UINT8
  elif inference_output_type in _MAP_QUANT_TO_IO_TYPES:
    # Remove the outputs and the dequant operator
    remove_tensors_idxs = set()
    for op in output_dequant_ops:
      subgraph.outputs[subgraph.outputs == op.outputs[0]] = op.inputs[0]
      if signature_index >= 0:
        signature_def = model.signatureDefs[signature_index]
        for i in range(len(signature_def.outputs)):
          if signature_def.outputs[i].tensorIndex == op.outputs[0]:
            signature_def.outputs[i].tensorIndex = op.inputs[0]
      remove_tensors_idxs.add(op.outputs[0])
      operators.remove(op)
    # Remove tensors marked for deletion.
    _remove_tensors_from_model(model, remove_tensors_idxs)
  else:
    raise ValueError(
        "Unsupported `inference_output_type` value {}.".format(
            get_tf_type_name(inference_output_type)))


def _remove_redundant_quantize_ops(model):
  """Finds back-to-back quantize ops and removes the first quantize op."""
  if not model.signatureDefs:
    _remove_redundant_quantize_ops_per_subgraph(model, 0, -1)
    return

  for signature_index, signature_def in enumerate(model.signatureDefs):
    _remove_redundant_quantize_ops_per_subgraph(model,
                                                signature_def.subgraphIndex,
                                                signature_index)


def _remove_redundant_quantize_ops_per_subgraph(model, subgraph_index,
                                                signature_index):
  """Remove redundant quantize ops per subgraph."""
  subgraph = model.subgraphs[subgraph_index]
  tensors = subgraph.tensors
  operators = subgraph.operators

  # Find all quantize operators.
  quant_opcode_idxs = get_quantize_opcode_idx(model)
  dequant_opcode_idxs = get_dequantize_opcode_idx(model)

  # Find all redundant quant tensors.
  all_quant_ops = []
  redundant_quant_tensors = {}
  output_dequant_tensors = {}
  for op in operators:
    if op.opcodeIndex in quant_opcode_idxs:
      all_quant_ops.append(op)
      input_tensor = tensors[op.inputs[0]]
      output_tensor = tensors[op.outputs[0]]
      input_type = _convert_tflite_enum_type_to_tf_type(input_tensor.type)
      output_type = _convert_tflite_enum_type_to_tf_type(output_tensor.type)
      # This is a requantize op, so write down its input tensor index.
      if input_type != dtypes.float32 and output_type != dtypes.float32:
        redundant_quant_tensors[op.inputs[0]] = op
    if (op.opcodeIndex in dequant_opcode_idxs and
        op.outputs[0] in subgraph.outputs):
      output_dequant_tensors[op.inputs[0]] = op

  # Remove all the quant ops which produce the redundant quant tensors.
  for op in all_quant_ops:
    output_tensor_idx = op.outputs[0]
    if output_tensor_idx in redundant_quant_tensors:
      requantize_op = redundant_quant_tensors[output_tensor_idx]
      if model.signatureDefs:
        signature_def = model.signatureDefs[0]
        for output in signature_def.outputs:
          if output.tensorIndex == op.outputs[0]:
            output.tensorIndex = op.inputs[0]
      deleted_tensor = requantize_op.inputs[0]
      # Reset the input of the requantize op to the float input
      requantize_op.inputs[0] = op.inputs[0]
      # Migrate other operator users to output tensor of requantize op
      for op_user in operators:
        if deleted_tensor in op_user.inputs and op_user != requantize_op:
          for idx, input_tensor in enumerate(op_user.inputs):
            if input_tensor == deleted_tensor:
              op_user.inputs[idx] = requantize_op.outputs[0]
      operators.remove(op)

  # Remove all the quant ops which connect to the output dequant op.
  for op in all_quant_ops:
    output_tensor_idx = op.outputs[0]
    if output_tensor_idx in output_dequant_tensors:
      dequant_op = output_dequant_tensors[output_tensor_idx]
      subgraph.outputs[subgraph.outputs == dequant_op.outputs[0]] = op.inputs[0]
      if signature_index >= 0:
        signature_def = model.signatureDefs[signature_index]
        for output in signature_def.outputs:
          if output.tensorIndex == dequant_op.outputs[0]:
            output.tensorIndex = op.inputs[0]
      operators.remove(op)
      operators.remove(dequant_op)


def modify_model_io_type(
    model, inference_input_type=dtypes.float32,
    inference_output_type=dtypes.float32):
  """Modify the input/output type of a tflite model.

  Args:
    model: A tflite model.
    inference_input_type: tf.DType representing modified input type.
      (default tf.float32. If model input is int8 quantized, it must be in
      {tf.float32, tf.int8, tf.uint8}, else if model input is int16 quantized,
      it must be in {tf.float32, tf.int16}, else it must be tf.float32)
    inference_output_type: tf.DType representing modified output type.
      (default tf.float32. If model output is int8 dequantized, it must be in
      {tf.float32, tf.int8, tf.uint8}, else if model output is int16
      dequantized, it must be in {tf.float32, tf.int16}, else it must be
      tf.float32)

  Returns:
    A tflite model with modified input/output type.

  Raises:
    ValueError: If `inference_input_type`/`inference_output_type` is unsupported
      or a supported integer type is specified for a model whose input/output is
      not quantized/dequantized.
    RuntimeError: If the modification was unsuccessful.
  """
  if (inference_input_type == dtypes.float32 and
      inference_output_type == dtypes.float32):
    return model

  model_object = _convert_model_from_bytearray_to_object(model)

  _modify_model_input_type(model_object, inference_input_type)

  _modify_model_output_type(model_object, inference_output_type)

  _remove_redundant_quantize_ops(model_object)

  return _convert_model_from_object_to_bytearray(model_object)


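# Illustrative end-to-end use with a full-integer quantized model (the variable
# `quantized_tflite_model` is a hypothetical flatbuffer produced by the
# converter):
#
#   final_model = modify_model_io_type(
#       quantized_tflite_model,
#       inference_input_type=dtypes.uint8,
#       inference_output_type=dtypes.uint8)

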
def get_sparsity_modes(model_object):
  """Get sparsity modes used in a tflite model.

  The sparsity modes are listed in conversion_metadata.fbs file.

  Args:
    model_object: A tflite model in object form.

  Returns:
    The list of sparsity modes used in the model.
  """
  if not model_object or not model_object.metadata:
    return []

  result = set()
  for subgraph in model_object.subgraphs:
    for tensor in subgraph.tensors:
      if not tensor.sparsity:
        continue

      # Block map is the list of indices where the block size is larger than 1.
      # So an empty block map means it is random sparsity.
      if not tensor.sparsity.blockMap:
        result.add(
            conversion_metadata_fb.ModelOptimizationMode.RANDOM_SPARSITY)
      else:
        result.add(
            conversion_metadata_fb.ModelOptimizationMode.BLOCK_SPARSITY)

  return list(result)


def populate_conversion_metadata(model_object, metadata):
  """Add or update conversion metadata to a tflite model.

  Args:
    model_object: A tflite model in object form.
    metadata: The conversion metadata.

  Returns:
    A tflite model object with embedded conversion metadata.
  """
  try:
    metadata_builder = flatbuffers.Builder(0)
    metadata_builder.Finish(metadata.Pack(metadata_builder))
    buffer_field = schema_fb.BufferT()
    buffer_field.data = metadata_builder.Output()

    if not model_object.metadata:
      model_object.metadata = []
    else:
      # Check if metadata has already been populated.
      for meta in model_object.metadata:
        if meta.name.decode("utf-8") == CONVERSION_METADATA_FIELD_NAME:
          model_object.buffers[meta.buffer] = buffer_field
          return model_object

    if not model_object.buffers:
      model_object.buffers = []
    model_object.buffers.append(buffer_field)
    # Creates a new metadata field.
    metadata_field = schema_fb.MetadataT()
    metadata_field.name = CONVERSION_METADATA_FIELD_NAME
    metadata_field.buffer = len(model_object.buffers) - 1
    model_object.metadata.append(metadata_field)

    return model_object
  except Exception:  # pylint: disable=broad-except
    return model_object


def get_conversion_metadata(model_buffer):
  """Read conversion metadata from a tflite model.

  Args:
    model_buffer: A tflite model.

  Returns:
    The conversion metadata or None if it is not populated.
  """
  model_object = flatbuffer_utils.convert_bytearray_to_object(model_buffer)
  if not model_object or not model_object.metadata:
    return None

  for meta in model_object.metadata:
    if meta.name.decode("utf-8") == CONVERSION_METADATA_FIELD_NAME:
      metadata_buf = model_object.buffers[meta.buffer].data.tobytes()
      return conversion_metadata_fb.ConversionMetadataT.InitFromObj(
          conversion_metadata_fb.ConversionMetadata.GetRootAsConversionMetadata(
              metadata_buf, 0))

  return None
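

# Illustrative read of the metadata written by populate_conversion_metadata.
# The field access on the returned ConversionMetadataT object is an assumption
# about the generated conversion_metadata schema, shown only as a sketch:
#
#   metadata = get_conversion_metadata(tflite_model)
#   if metadata is not None:
#     print(metadata.environment.tensorflowVersion)  # hypothetical field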