433 lines
14 KiB
Python
433 lines
14 KiB
Python
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
"""Utility functions for FlatBuffers.
|
||
|
|
||
|
All functions that are commonly used to work with FlatBuffers.
|
||
|
|
||
|
Refer to the tensorflow lite flatbuffer schema here:
|
||
|
tensorflow/lite/schema/schema.fbs
|
||
|
"""
|
||
|
|
||
|
import copy
|
||
|
import random
|
||
|
import re
|
||
|
import struct
|
||
|
import sys
|
||
|
|
||
|
import flatbuffers
|
||
|
|
||
|
from tensorflow.lite.python import schema_py_generated as schema_fb
|
||
|
from tensorflow.lite.python import schema_util
|
||
|
from tensorflow.python.platform import gfile
|
||
|
|
||
|
_TFLITE_FILE_IDENTIFIER = b'TFL3'
|
||
|
|
||
|
|
||
|
def convert_bytearray_to_object(model_bytearray):
|
||
|
"""Converts a tflite model from a bytearray to an object for parsing."""
|
||
|
model_object = schema_fb.Model.GetRootAsModel(model_bytearray, 0)
|
||
|
return schema_fb.ModelT.InitFromObj(model_object)
|
||
|
|
||
|
|
||
|
def read_model(input_tflite_file):
|
||
|
"""Reads a tflite model as a python object.
|
||
|
|
||
|
Args:
|
||
|
input_tflite_file: Full path name to the input tflite file
|
||
|
|
||
|
Raises:
|
||
|
RuntimeError: If input_tflite_file path is invalid.
|
||
|
IOError: If input_tflite_file cannot be opened.
|
||
|
|
||
|
Returns:
|
||
|
A python object corresponding to the input tflite file.
|
||
|
"""
|
||
|
if not gfile.Exists(input_tflite_file):
|
||
|
raise RuntimeError('Input file not found at %r\n' % input_tflite_file)
|
||
|
with gfile.GFile(input_tflite_file, 'rb') as input_file_handle:
|
||
|
model_bytearray = bytearray(input_file_handle.read())
|
||
|
model = convert_bytearray_to_object(model_bytearray)
|
||
|
if sys.byteorder == 'big':
|
||
|
byte_swap_tflite_model_obj(model, 'little', 'big')
|
||
|
return model
|
||
|
|
||
|
|
||
|
def read_model_with_mutable_tensors(input_tflite_file):
|
||
|
"""Reads a tflite model as a python object with mutable tensors.
|
||
|
|
||
|
Similar to read_model() with the addition that the returned object has
|
||
|
mutable tensors (read_model() returns an object with immutable tensors).
|
||
|
|
||
|
NOTE: This API only works for TFLite generated with
|
||
|
_experimental_use_buffer_offset=false
|
||
|
|
||
|
Args:
|
||
|
input_tflite_file: Full path name to the input tflite file
|
||
|
|
||
|
Raises:
|
||
|
RuntimeError: If input_tflite_file path is invalid.
|
||
|
IOError: If input_tflite_file cannot be opened.
|
||
|
|
||
|
Returns:
|
||
|
A mutable python object corresponding to the input tflite file.
|
||
|
"""
|
||
|
return copy.deepcopy(read_model(input_tflite_file))
|
||
|
|
||
|
|
||
|
def convert_object_to_bytearray(model_object, extra_buffer=b''):
|
||
|
"""Converts a tflite model from an object to a immutable bytearray."""
|
||
|
# Initial size of the buffer, which will grow automatically if needed
|
||
|
builder = flatbuffers.Builder(1024)
|
||
|
model_offset = model_object.Pack(builder)
|
||
|
builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
|
||
|
model_bytearray = bytes(builder.Output())
|
||
|
model_bytearray = model_bytearray + extra_buffer
|
||
|
return model_bytearray
|
||
|
|
||
|
|
||
|
def write_model(model_object, output_tflite_file):
|
||
|
"""Writes the tflite model, a python object, into the output file.
|
||
|
|
||
|
NOTE: This API only works for TFLite generated with
|
||
|
_experimental_use_buffer_offset=false
|
||
|
|
||
|
Args:
|
||
|
model_object: A tflite model as a python object
|
||
|
output_tflite_file: Full path name to the output tflite file.
|
||
|
|
||
|
Raises:
|
||
|
IOError: If output_tflite_file path is invalid or cannot be opened.
|
||
|
"""
|
||
|
if sys.byteorder == 'big':
|
||
|
model_object = copy.deepcopy(model_object)
|
||
|
byte_swap_tflite_model_obj(model_object, 'big', 'little')
|
||
|
model_bytearray = convert_object_to_bytearray(model_object)
|
||
|
with gfile.GFile(output_tflite_file, 'wb') as output_file_handle:
|
||
|
output_file_handle.write(model_bytearray)
|
||
|
|
||
|
|
||
|
def strip_strings(model):
|
||
|
"""Strips all nonessential strings from the model to reduce model size.
|
||
|
|
||
|
We remove the following strings:
|
||
|
(find strings by searching ":string" in the tensorflow lite flatbuffer schema)
|
||
|
1. Model description
|
||
|
2. SubGraph name
|
||
|
3. Tensor names
|
||
|
We retain OperatorCode custom_code and Metadata name.
|
||
|
|
||
|
Args:
|
||
|
model: The model from which to remove nonessential strings.
|
||
|
"""
|
||
|
|
||
|
model.description = None
|
||
|
for subgraph in model.subgraphs:
|
||
|
subgraph.name = None
|
||
|
for tensor in subgraph.tensors:
|
||
|
tensor.name = None
|
||
|
# We clear all signature_def structure, since without names it is useless.
|
||
|
model.signatureDefs = None
|
||
|
|
||
|
|
||
|
def type_to_name(tensor_type):
|
||
|
"""Converts a numerical enum to a readable tensor type."""
|
||
|
for name, value in schema_fb.TensorType.__dict__.items():
|
||
|
if value == tensor_type:
|
||
|
return name
|
||
|
return None
|
||
|
|
||
|
|
||
|
def randomize_weights(model, random_seed=0, buffers_to_skip=None):
|
||
|
"""Randomize weights in a model.
|
||
|
|
||
|
Args:
|
||
|
model: The model in which to randomize weights.
|
||
|
random_seed: The input to the random number generator (default value is 0).
|
||
|
buffers_to_skip: The list of buffer indices to skip. The weights in these
|
||
|
buffers are left unmodified.
|
||
|
"""
|
||
|
|
||
|
# The input to the random seed generator. The default value is 0.
|
||
|
random.seed(random_seed)
|
||
|
|
||
|
# Parse model buffers which store the model weights
|
||
|
buffers = model.buffers
|
||
|
buffer_ids = range(1, len(buffers)) # ignore index 0 as it's always None
|
||
|
if buffers_to_skip is not None:
|
||
|
buffer_ids = [idx for idx in buffer_ids if idx not in buffers_to_skip]
|
||
|
|
||
|
buffer_types = {}
|
||
|
for graph in model.subgraphs:
|
||
|
for op in graph.operators:
|
||
|
if op.inputs is None:
|
||
|
break
|
||
|
for input_idx in op.inputs:
|
||
|
tensor = graph.tensors[input_idx]
|
||
|
buffer_types[tensor.buffer] = type_to_name(tensor.type)
|
||
|
|
||
|
for i in buffer_ids:
|
||
|
buffer_i_data = buffers[i].data
|
||
|
buffer_i_size = 0 if buffer_i_data is None else buffer_i_data.size
|
||
|
if buffer_i_size == 0:
|
||
|
continue
|
||
|
|
||
|
# Raw data buffers are of type ubyte (or uint8) whose values lie in the
|
||
|
# range [0, 255]. Those ubytes (or unint8s) are the underlying
|
||
|
# representation of each datatype. For example, a bias tensor of type
|
||
|
# int32 appears as a buffer 4 times it's length of type ubyte (or uint8).
|
||
|
# For floats, we need to generate a valid float and then pack it into
|
||
|
# the raw bytes in place.
|
||
|
buffer_type = buffer_types.get(i, 'INT8')
|
||
|
if buffer_type.startswith('FLOAT'):
|
||
|
format_code = 'e' if buffer_type == 'FLOAT16' else 'f'
|
||
|
for offset in range(0, buffer_i_size, struct.calcsize(format_code)):
|
||
|
value = random.uniform(-0.5, 0.5) # See http://b/152324470#comment2
|
||
|
struct.pack_into(format_code, buffer_i_data, offset, value)
|
||
|
else:
|
||
|
for j in range(buffer_i_size):
|
||
|
buffer_i_data[j] = random.randint(0, 255)
|
||
|
|
||
|
|
||
|
def rename_custom_ops(model, map_custom_op_renames):
|
||
|
"""Rename custom ops so they use the same naming style as builtin ops.
|
||
|
|
||
|
Args:
|
||
|
model: The input tflite model.
|
||
|
map_custom_op_renames: A mapping from old to new custom op names.
|
||
|
"""
|
||
|
for op_code in model.operatorCodes:
|
||
|
if op_code.customCode:
|
||
|
op_code_str = op_code.customCode.decode('ascii')
|
||
|
if op_code_str in map_custom_op_renames:
|
||
|
op_code.customCode = map_custom_op_renames[op_code_str].encode('ascii')
|
||
|
|
||
|
|
||
|
def opcode_to_name(model, op_code):
|
||
|
"""Converts a TFLite op_code to the human readable name.
|
||
|
|
||
|
Args:
|
||
|
model: The input tflite model.
|
||
|
op_code: The op_code to resolve to a readable name.
|
||
|
|
||
|
Returns:
|
||
|
A string containing the human readable op name, or None if not resolvable.
|
||
|
"""
|
||
|
op = model.operatorCodes[op_code]
|
||
|
code = max(op.builtinCode, op.deprecatedBuiltinCode)
|
||
|
for name, value in vars(schema_fb.BuiltinOperator).items():
|
||
|
if value == code:
|
||
|
return name
|
||
|
return None
|
||
|
|
||
|
|
||
|
def xxd_output_to_bytes(input_cc_file):
|
||
|
"""Converts xxd output C++ source file to bytes (immutable).
|
||
|
|
||
|
Args:
|
||
|
input_cc_file: Full path name to th C++ source file dumped by xxd
|
||
|
|
||
|
Raises:
|
||
|
RuntimeError: If input_cc_file path is invalid.
|
||
|
IOError: If input_cc_file cannot be opened.
|
||
|
|
||
|
Returns:
|
||
|
A bytearray corresponding to the input cc file array.
|
||
|
"""
|
||
|
# Match hex values in the string with comma as separator
|
||
|
pattern = re.compile(r'\W*(0x[0-9a-fA-F,x ]+).*')
|
||
|
|
||
|
model_bytearray = bytearray()
|
||
|
|
||
|
with open(input_cc_file) as file_handle:
|
||
|
for line in file_handle:
|
||
|
values_match = pattern.match(line)
|
||
|
|
||
|
if values_match is None:
|
||
|
continue
|
||
|
|
||
|
# Match in the parentheses (hex array only)
|
||
|
list_text = values_match.group(1)
|
||
|
|
||
|
# Extract hex values (text) from the line
|
||
|
# e.g. 0x1c, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c,
|
||
|
values_text = filter(None, list_text.split(','))
|
||
|
|
||
|
# Convert to hex
|
||
|
values = [int(x, base=16) for x in values_text]
|
||
|
model_bytearray.extend(values)
|
||
|
|
||
|
return bytes(model_bytearray)
|
||
|
|
||
|
|
||
|
def xxd_output_to_object(input_cc_file):
|
||
|
"""Converts xxd output C++ source file to object.
|
||
|
|
||
|
Args:
|
||
|
input_cc_file: Full path name to th C++ source file dumped by xxd
|
||
|
|
||
|
Raises:
|
||
|
RuntimeError: If input_cc_file path is invalid.
|
||
|
IOError: If input_cc_file cannot be opened.
|
||
|
|
||
|
Returns:
|
||
|
A python object corresponding to the input tflite file.
|
||
|
"""
|
||
|
model_bytes = xxd_output_to_bytes(input_cc_file)
|
||
|
return convert_bytearray_to_object(model_bytes)
|
||
|
|
||
|
|
||
|
def byte_swap_buffer_content(buffer, chunksize, from_endiness, to_endiness):
|
||
|
"""Helper function for byte-swapping the buffers field."""
|
||
|
to_swap = [
|
||
|
buffer.data[i : i + chunksize]
|
||
|
for i in range(0, len(buffer.data), chunksize)
|
||
|
]
|
||
|
buffer.data = b''.join(
|
||
|
[
|
||
|
int.from_bytes(byteswap, from_endiness).to_bytes(
|
||
|
chunksize, to_endiness
|
||
|
)
|
||
|
for byteswap in to_swap
|
||
|
]
|
||
|
)
|
||
|
|
||
|
|
||
|
def byte_swap_string_content(buffer, from_endiness, to_endiness):
|
||
|
"""Helper function for byte-swapping the string buffer.
|
||
|
|
||
|
Args:
|
||
|
buffer: TFLite string buffer of from_endiness format.
|
||
|
from_endiness: The original endianness format of the string buffer.
|
||
|
to_endiness: The destined endianness format of the string buffer.
|
||
|
"""
|
||
|
num_of_strings = int.from_bytes(buffer.data[0:4], from_endiness)
|
||
|
string_content = bytearray(buffer.data[4 * (num_of_strings + 2) :])
|
||
|
prefix_data = b''.join(
|
||
|
[
|
||
|
int.from_bytes(buffer.data[i : i + 4], from_endiness).to_bytes(
|
||
|
4, to_endiness
|
||
|
)
|
||
|
for i in range(0, (num_of_strings + 1) * 4 + 1, 4)
|
||
|
]
|
||
|
)
|
||
|
buffer.data = prefix_data + string_content
|
||
|
|
||
|
|
||
|
def byte_swap_tflite_model_obj(model, from_endiness, to_endiness):
|
||
|
"""Byte swaps the buffers field in a TFLite model.
|
||
|
|
||
|
Args:
|
||
|
model: TFLite model object of from_endiness format.
|
||
|
from_endiness: The original endianness format of the buffers in model.
|
||
|
to_endiness: The destined endianness format of the buffers in model.
|
||
|
"""
|
||
|
if model is None:
|
||
|
return
|
||
|
# Get all the constant buffers, byte swapping them as per their data types
|
||
|
buffer_swapped = []
|
||
|
types_of_16_bits = [
|
||
|
schema_fb.TensorType.FLOAT16,
|
||
|
schema_fb.TensorType.INT16,
|
||
|
schema_fb.TensorType.UINT16,
|
||
|
]
|
||
|
types_of_32_bits = [
|
||
|
schema_fb.TensorType.FLOAT32,
|
||
|
schema_fb.TensorType.INT32,
|
||
|
schema_fb.TensorType.COMPLEX64,
|
||
|
schema_fb.TensorType.UINT32,
|
||
|
]
|
||
|
types_of_64_bits = [
|
||
|
schema_fb.TensorType.INT64,
|
||
|
schema_fb.TensorType.FLOAT64,
|
||
|
schema_fb.TensorType.COMPLEX128,
|
||
|
schema_fb.TensorType.UINT64,
|
||
|
]
|
||
|
for subgraph in model.subgraphs:
|
||
|
for tensor in subgraph.tensors:
|
||
|
if (
|
||
|
tensor.buffer > 0
|
||
|
and tensor.buffer < len(model.buffers)
|
||
|
and tensor.buffer not in buffer_swapped
|
||
|
and model.buffers[tensor.buffer].data is not None
|
||
|
):
|
||
|
if tensor.type == schema_fb.TensorType.STRING:
|
||
|
byte_swap_string_content(
|
||
|
model.buffers[tensor.buffer], from_endiness, to_endiness
|
||
|
)
|
||
|
elif tensor.type in types_of_16_bits:
|
||
|
byte_swap_buffer_content(
|
||
|
model.buffers[tensor.buffer], 2, from_endiness, to_endiness
|
||
|
)
|
||
|
elif tensor.type in types_of_32_bits:
|
||
|
byte_swap_buffer_content(
|
||
|
model.buffers[tensor.buffer], 4, from_endiness, to_endiness
|
||
|
)
|
||
|
elif tensor.type in types_of_64_bits:
|
||
|
byte_swap_buffer_content(
|
||
|
model.buffers[tensor.buffer], 8, from_endiness, to_endiness
|
||
|
)
|
||
|
else:
|
||
|
continue
|
||
|
buffer_swapped.append(tensor.buffer)
|
||
|
|
||
|
|
||
|
def byte_swap_tflite_buffer(tflite_model, from_endiness, to_endiness):
|
||
|
"""Generates a new model byte array after byte swapping its buffers field.
|
||
|
|
||
|
Args:
|
||
|
tflite_model: TFLite flatbuffer in a byte array.
|
||
|
from_endiness: The original endianness format of the buffers in
|
||
|
tflite_model.
|
||
|
to_endiness: The destined endianness format of the buffers in tflite_model.
|
||
|
|
||
|
Returns:
|
||
|
TFLite flatbuffer in a byte array, after being byte swapped to to_endiness
|
||
|
format.
|
||
|
"""
|
||
|
if tflite_model is None:
|
||
|
return None
|
||
|
# Load TFLite Flatbuffer byte array into an object.
|
||
|
model = convert_bytearray_to_object(tflite_model)
|
||
|
|
||
|
# Byte swapping the constant buffers as per their data types
|
||
|
byte_swap_tflite_model_obj(model, from_endiness, to_endiness)
|
||
|
|
||
|
# Return a TFLite flatbuffer as a byte array.
|
||
|
return convert_object_to_bytearray(model)
|
||
|
|
||
|
|
||
|
def count_resource_variables(model):
|
||
|
"""Calculates the number of unique resource variables in a model.
|
||
|
|
||
|
Args:
|
||
|
model: the input tflite model, either as bytearray or object.
|
||
|
|
||
|
Returns:
|
||
|
An integer number representing the number of unique resource variables.
|
||
|
"""
|
||
|
if not isinstance(model, schema_fb.ModelT):
|
||
|
model = convert_bytearray_to_object(model)
|
||
|
unique_shared_names = set()
|
||
|
for subgraph in model.subgraphs:
|
||
|
if subgraph.operators is None:
|
||
|
continue
|
||
|
for op in subgraph.operators:
|
||
|
builtin_code = schema_util.get_builtin_code_from_operator_code(
|
||
|
model.operatorCodes[op.opcodeIndex]
|
||
|
)
|
||
|
if builtin_code == schema_fb.BuiltinOperator.VAR_HANDLE:
|
||
|
unique_shared_names.add(op.builtinOptions.sharedName)
|
||
|
return len(unique_shared_names)
|