3RNN/Lib/site-packages/tensorflow/lite/python/convert_phase.py
2024-05-26 19:49:15 +02:00

220 lines
7.8 KiB
Python

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for collecting TFLite metrics."""
import collections
import enum
import functools
from typing import Text
from tensorflow.lite.python.metrics import converter_error_data_pb2
from tensorflow.lite.python.metrics import metrics
class Component(enum.Enum):
    """Top-level stages of the TFLite converter.

    Every member's value is the same string as its name; that value string is
    what `convert_phase` stores into the `component` field of reported error
    data.
    """

    # Stage 1: validate the given input, then prepare and optimize the
    # TensorFlow model.
    PREPARE_TF_MODEL = "PREPARE_TF_MODEL"

    # Stage 2: convert the TensorFlow model into the TFLite model format.
    CONVERT_TF_TO_TFLITE_MODEL = "CONVERT_TF_TO_TFLITE_MODEL"

    # Stage 3: run quantization and sparsification.
    OPTIMIZE_TFLITE_MODEL = "OPTIMIZE_TFLITE_MODEL"
# Immutable (name, component) pair used as the value of every SubComponent
# member; `component` is a Component member, or None for "UNSPECIFIED".
SubComponentItem = collections.namedtuple(
    "SubComponentItem", "name component")
class SubComponent(SubComponentItem, enum.Enum):
    """Converter subcomponents known on the Python side.

    Each member's value is a `SubComponentItem` holding the subcomponent's
    name and the `Component` it belongs to (None for UNSPECIFIED). Only the
    Python-side subcomponents are listed here; the C++ converter may define
    more of its own.
    """

    @property
    def name(self):
        # The name stored in the SubComponentItem value.
        item = self.value
        return item.name

    @property
    def component(self):
        # The owning Component stored in the SubComponentItem value.
        item = self.value
        return item.component

    def __str__(self):
        # Render as the bare subcomponent name, e.g. "CALIBRATE".
        return self.name

    # No particular subcomponent was specified.
    UNSPECIFIED = SubComponentItem("UNSPECIFIED", None)

    # --- Component.PREPARE_TF_MODEL ---

    # Validate the given input and parameters.
    VALIDATE_INPUTS = SubComponentItem("VALIDATE_INPUTS",
                                       Component.PREPARE_TF_MODEL)
    # Load a GraphDef from a SavedModel.
    LOAD_SAVED_MODEL = SubComponentItem("LOAD_SAVED_MODEL",
                                        Component.PREPARE_TF_MODEL)
    # Convert a SavedModel into a frozen graph.
    FREEZE_SAVED_MODEL = SubComponentItem("FREEZE_SAVED_MODEL",
                                          Component.PREPARE_TF_MODEL)
    # Save a Keras model as a SavedModel.
    CONVERT_KERAS_TO_SAVED_MODEL = SubComponentItem(
        "CONVERT_KERAS_TO_SAVED_MODEL", Component.PREPARE_TF_MODEL)
    # Save concrete functions as a SavedModel.
    CONVERT_CONCRETE_FUNCTIONS_TO_SAVED_MODEL = SubComponentItem(
        "CONVERT_CONCRETE_FUNCTIONS_TO_SAVED_MODEL", Component.PREPARE_TF_MODEL)
    # Convert a Keras model into a frozen graph.
    FREEZE_KERAS_MODEL = SubComponentItem("FREEZE_KERAS_MODEL",
                                          Component.PREPARE_TF_MODEL)
    # Replace every variable in a ConcreteFunction with a constant.
    FREEZE_CONCRETE_FUNCTION = SubComponentItem("FREEZE_CONCRETE_FUNCTION",
                                                Component.PREPARE_TF_MODEL)
    # Run Grappler optimization.
    OPTIMIZE_TF_MODEL = SubComponentItem("OPTIMIZE_TF_MODEL",
                                         Component.PREPARE_TF_MODEL)

    # --- Component.CONVERT_TF_TO_TFLITE_MODEL ---

    # Convert with the old, deprecated TOCO converter.
    CONVERT_GRAPHDEF_USING_DEPRECATED_CONVERTER = SubComponentItem(
        "CONVERT_GRAPHDEF_USING_DEPRECATED_CONVERTER",
        Component.CONVERT_TF_TO_TFLITE_MODEL)
    # Convert a GraphDef into a TFLite model.
    CONVERT_GRAPHDEF = SubComponentItem("CONVERT_GRAPHDEF",
                                        Component.CONVERT_TF_TO_TFLITE_MODEL)
    # Convert a SavedModel into a TFLite model.
    CONVERT_SAVED_MODEL = SubComponentItem("CONVERT_SAVED_MODEL",
                                           Component.CONVERT_TF_TO_TFLITE_MODEL)
    # Convert Jax HLO into a TFLite model.
    CONVERT_JAX_HLO = SubComponentItem("CONVERT_JAX_HLO",
                                       Component.CONVERT_TF_TO_TFLITE_MODEL)

    # --- Component.OPTIMIZE_TFLITE_MODEL ---

    # Quantize with the deprecated quantizer.
    QUANTIZE_USING_DEPRECATED_QUANTIZER = SubComponentItem(
        "QUANTIZE_USING_DEPRECATED_QUANTIZER", Component.OPTIMIZE_TFLITE_MODEL)
    # Run calibration.
    CALIBRATE = SubComponentItem("CALIBRATE", Component.OPTIMIZE_TFLITE_MODEL)
    # Quantize with MLIR.
    QUANTIZE = SubComponentItem("QUANTIZE", Component.OPTIMIZE_TFLITE_MODEL)
    # Sparsify with MLIR.
    SPARSIFY = SubComponentItem("SPARSIFY", Component.OPTIMIZE_TFLITE_MODEL)
class ConverterError(Exception):
    """Raised when an error occurs during model conversion.

    Attributes:
      errors: List of `converter_error_data_pb2.ConverterErrorData` records
        attached to this exception. Records are added by callers through
        `append_error`, or created at construction time when the message
        matches a known error pattern.
    """

    def __init__(self, message):
        super().__init__(message)
        self.errors = []
        # Must run after `errors` exists: a match appends to it.
        self._parse_error_message(message)

    def append_error(self,
                     error_data: converter_error_data_pb2.ConverterErrorData):
        """Attaches one structured error record to this exception."""
        self.errors.append(error_data)

    def _parse_error_message(self, message):
        """Attaches an error code if `message` matches a known pattern.

        Some errors on the MLIR side are hard to tag with an error code at the
        source, e.g. errors thrown by components other than TFLite or errors
        not emitted through mlir::emitError. This method recognizes them by
        their message text and records the corresponding error code. At most
        one record is attached: the one for the first matching pattern.

        Args:
          message: The error message of this exception.
        """
        # (substring to look for, error code to record), checked in order.
        known_patterns = (
            ("Failed to functionalize Control Flow V1 ops. Consider using Control "
             "Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/"
             "tf/compat/v1/enable_control_flow_v2.",
             converter_error_data_pb2.ConverterErrorData
             .ERROR_UNSUPPORTED_CONTROL_FLOW_V1),
        )
        matched_code = next(
            (code for pattern, code in known_patterns if pattern in message),
            None)
        if matched_code is None:
            return
        error_data = converter_error_data_pb2.ConverterErrorData()
        error_data.error_message = message
        error_data.error_code = matched_code
        self.append_error(error_data)
def convert_phase(component, subcomponent=SubComponent.UNSPECIFIED):
    """The decorator to identify converter component and subcomponent.

    Any exception raised by the decorated function is reported to the TFLite
    converter metrics, tagged with `component` and `subcomponent`, and then
    re-raised.

    Args:
      component: Converter component, given as a `Component` member or its
        value (e.g. "PREPARE_TF_MODEL").
      subcomponent: Converter subcomponent, given as a `SubComponent` member or
        its value. It must belong to `component` unless it is
        `SubComponent.UNSPECIFIED`.

    Returns:
      A decorator whose wrapper forwards the result from the wrapped function.

    Raises:
      ValueError: if component or subcomponent is not valid, or if the
        subcomponent does not belong to the component.
    """
    # Normalize through an enum lookup instead of testing `x in Enum`. The
    # `in` check raises TypeError for non-members on Python 3.8-3.11. On 3.12+
    # it accepts raw values, which then broke `component.value` /
    # `subcomponent.name` below, but only once an error was being reported.
    # The lookup returns members unchanged, maps valid values to their
    # members, and raises ValueError otherwise.
    try:
        component = Component(component)
    except ValueError:
        raise ValueError("Given component name not found") from None
    try:
        subcomponent = SubComponent(subcomponent)
    except ValueError:
        raise ValueError("Given subcomponent name not found") from None
    if (subcomponent != SubComponent.UNSPECIFIED and
            subcomponent.component != component):
        raise ValueError("component and subcomponent name don't match")

    def report_error(error_data: converter_error_data_pb2.ConverterErrorData):
        # Always overwrites the component information, but only overwrites the
        # subcomponent if it is not available.
        error_data.component = component.value
        if not error_data.subcomponent:
            error_data.subcomponent = subcomponent.name
        tflite_metrics = metrics.TFLiteConverterMetrics()
        tflite_metrics.set_converter_error(error_data)

    def report_error_message(error_message: Text):
        # Wraps a plain message in a fresh ConverterErrorData and reports it.
        error_data = converter_error_data_pb2.ConverterErrorData()
        error_data.error_message = error_message
        report_error(error_data)

    def actual_decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except ConverterError as converter_error:
                # Prefer the structured records; fall back to the message text.
                if converter_error.errors:
                    for error_data in converter_error.errors:
                        report_error(error_data)
                else:
                    report_error_message(str(converter_error))
                # Re-throws the same exception object. `from None` clears any
                # `__cause__` and suppresses the implicit context in printed
                # tracebacks. NOTE(review): presumably deliberate, to keep
                # converter tracebacks short; it also drops a cause set by
                # `raise ... from e` inside `func`.
                raise converter_error from None
            except Exception as error:
                report_error_message(str(error))
                raise error from None  # Re-throws the exception (see above).
        return wrapper

    return actual_decorator