# Projekt_AI-Automatyczny_saper/venv/Lib/site-packages/torch/quantization/utils.py
# 2021-06-01 17:38:31 +02:00
#
# 87 lines
# 3.2 KiB
# Python

"""
Utils shared by different modes of quantization (eager/graph)
"""
import torch
from .quant_type import QuantType, quant_type_to_str
def get_combined_dict(default_dict, additional_dict):
    """ Return a copy of `default_dict` overlaid with `additional_dict`.

    Input:
      default_dict: mapping providing the base entries; it is copied, not
        modified
      additional_dict: entries applied on top of the copy via `update`, so
        they win whenever a key is present in both

    Output:
      the copy of `default_dict` after `additional_dict` has been applied
    """
    combined = default_dict.copy()
    combined.update(additional_dict)
    return combined
def is_per_tensor(qscheme):
    """ Return True if `qscheme` is `torch.per_tensor_affine` or
    `torch.per_tensor_symmetric`, False otherwise.
    """
    return qscheme in (torch.per_tensor_affine, torch.per_tensor_symmetric)
def is_per_channel(qscheme):
    """ Return True if `qscheme` is one of the per-channel schemes checked
    here: `torch.per_channel_affine`,
    `torch.per_channel_affine_float_qparams` or
    `torch.per_channel_symmetric`. Returns False otherwise.
    """
    per_channel_schemes = (
        torch.per_channel_affine,
        torch.per_channel_affine_float_qparams,
        torch.per_channel_symmetric,
    )
    return qscheme in per_channel_schemes
def get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig):
    """ Get the observed/quantized custom module class that we need
    to swap `custom_module` to

    Input:
      custom_module: input, can be an instance of either a float or observed custom module
      custom_module_class_mapping: the float to observed or observed to quantized custom module class mapping,
        keyed first by the quant type string and then by module class
      qconfig: qconfig configured for the custom module

    Output:
      corresponding observed/quantized custom module class for input custom module instance

    Raises:
      AssertionError: if the class of `custom_module` has no entry in the
        mapping selected for the quant type of `qconfig`
    """
    quant_type_str = quant_type_to_str(get_quant_type(qconfig))
    # A quant type that is absent from the mapping is treated as an empty
    # mapping, so the failure below is the descriptive assertion rather than
    # a bare KeyError.
    class_mapping = custom_module_class_mapping.get(quant_type_str, {})
    module_class = type(custom_module)
    assert module_class in class_mapping, "did not find corresponding observed " \
        "module class for {} in mapping: {}".format(module_class, class_mapping)
    return class_mapping[module_class]
def activation_is_statically_quantized(qconfig):
    """ Given a qconfig, decide if the activation needs to be
    statically quantized or not

    Input:
      qconfig: object whose `activation` attribute is a callable returning an
        object with a `dtype` attribute; must not be None

    Output:
      True if that dtype is `torch.quint8` or `torch.qint8`, False otherwise

    Raises:
      AssertionError: if `qconfig` is None
    """
    assert qconfig is not None
    activation_dtype = qconfig.activation().dtype
    return activation_dtype in (torch.quint8, torch.qint8)
def weight_dtype(qconfig):
    """ Return the dtype of the object produced by `qconfig.weight()`.

    Input:
      qconfig: object whose `weight` attribute is a callable returning an
        object with a `dtype` attribute; must not be None

    Raises:
      AssertionError: if `qconfig` is None
    """
    assert qconfig is not None
    return qconfig.weight().dtype
def weight_is_statically_quantized(qconfig):
    """ Given a qconfig, decide if the weight needs to be
    quantized or not

    Output:
      True if `weight_dtype(qconfig)` is `torch.quint8` or `torch.qint8`,
      False otherwise
    """
    quantized_dtypes = (torch.quint8, torch.qint8)
    return weight_dtype(qconfig) in quantized_dtypes
def get_qconfig_dtypes(qconfig):
    r""" returns the qconfig tuple for qconfig:
    (activation_dtype, weight_dtype, activation_compute_dtype)

    The third element is None when the object returned by
    `qconfig.activation()` has no `compute_dtype` attribute.

    Raises:
      AssertionError: if `qconfig` is None
    """
    assert qconfig is not None
    activation = qconfig.activation()
    weight = qconfig.weight()
    # `compute_dtype` is optional on the activation object, so fall back to
    # None instead of raising AttributeError.
    compute_dtype = getattr(activation, 'compute_dtype', None)
    return (activation.dtype, weight.dtype, compute_dtype)
def get_quant_type(qconfig):
    """ Given a qconfig, return the QuantType implied by the dtypes of the
    objects returned by `qconfig.activation()` and `qconfig.weight()`

    With a `torch.quint8` / `torch.qint8` weight dtype:
      - activation dtype is also one of those two: QuantType.STATIC
      - otherwise, activation has a `compute_dtype` that is one of those
        two: QuantType.DYNAMIC
      - otherwise: QuantType.WEIGHT_ONLY
    With a `torch.float16` weight dtype and a `torch.float` activation
    dtype: QuantType.DYNAMIC

    Raises:
      AssertionError: if `qconfig` is None
      Exception: for any other activation/weight dtype combination
    """
    assert qconfig is not None
    activation = qconfig.activation()
    weight = qconfig.weight()
    static_dtypes = (torch.quint8, torch.qint8)

    if weight.dtype in static_dtypes:
        if activation.dtype in static_dtypes:
            return QuantType.STATIC
        # `compute_dtype` is optional on the activation object; a missing
        # attribute is treated the same as a non-matching one.
        if getattr(activation, 'compute_dtype', None) in static_dtypes:
            return QuantType.DYNAMIC
        return QuantType.WEIGHT_ONLY

    if weight.dtype == torch.float16 and activation.dtype == torch.float:
        return QuantType.DYNAMIC

    raise Exception("Unrecognized dtype combination in get_quant_type: activation({}),"
                    "weight({})".format(activation.dtype, weight.dtype))