"""
|
||
|
Utils shared by different modes of quantization (eager/graph)
|
||
|
"""
|
||
|
import torch
|
||
|
from .quant_type import QuantType, quant_type_to_str
|
||
|
|
||
|
def get_combined_dict(default_dict, additional_dict):
    d = default_dict.copy()
    d.update(additional_dict)
    return d

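# Illustrative usage sketch: the dicts below are hypothetical. Entries in
# `additional_dict` override entries in `default_dict`, and neither input is
# mutated.
#
#   defaults = {"linear": "per_channel", "conv": "per_channel"}
#   overrides = {"conv": "per_tensor"}
#   get_combined_dict(defaults, overrides)
#   # -> {"linear": "per_channel", "conv": "per_tensor"}
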
def is_per_tensor(qscheme):
    return qscheme == torch.per_tensor_affine or \
        qscheme == torch.per_tensor_symmetric

def is_per_channel(qscheme):
    return qscheme in [torch.per_channel_affine,
                       torch.per_channel_affine_float_qparams,
                       torch.per_channel_symmetric]

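# Illustrative usage sketch: both helpers take one of the torch qscheme enum
# values.
#
#   is_per_tensor(torch.per_tensor_affine)       # True
#   is_per_tensor(torch.per_channel_affine)      # False
#   is_per_channel(torch.per_channel_symmetric)  # True
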
def get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig):
    """ Get the observed/quantized custom module class that we need
    to swap `custom_module` to.

    Input:
        custom_module: an instance of either a float or an observed custom module
        custom_module_class_mapping: the float to observed or observed to quantized
            custom module class mapping
        qconfig: qconfig configured for the custom module

    Output:
        corresponding observed/quantized custom module class for the input
        custom module instance
    """
    quant_type = get_quant_type(qconfig)
    quant_type_str = quant_type_to_str(quant_type)
    class_mapping = custom_module_class_mapping.get(quant_type_str, {})
    assert type(custom_module) in class_mapping, "did not find corresponding observed/quantized " \
        "module class for {} in mapping: {}".format(type(custom_module), class_mapping)
    return class_mapping[type(custom_module)]

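# Illustrative shape of `custom_module_class_mapping`: the classes below are
# hypothetical, and the "static" key assumes quant_type_to_str(QuantType.STATIC)
# returns that string.
#
#   mapping = {"static": {MyFloatCustomModule: MyObservedCustomModule}}
#   # For a qconfig that resolves to QuantType.STATIC and an instance of
#   # MyFloatCustomModule, get_swapped_custom_module_class(instance, mapping,
#   # qconfig) returns MyObservedCustomModule.
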
def activation_is_statically_quantized(qconfig):
    """ Given a qconfig, decide if the activation needs to be
    statically quantized or not
    """
    assert qconfig is not None
    activation = qconfig.activation()
    return activation.dtype in [torch.quint8, torch.qint8]

def weight_dtype(qconfig):
    assert qconfig is not None
    weight = qconfig.weight()
    return weight.dtype

def weight_is_statically_quantized(qconfig):
    """ Given a qconfig, decide if the weight needs to be
    quantized or not
    """
    return weight_dtype(qconfig) in [torch.quint8, torch.qint8]

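# Illustrative usage sketch, assuming the default static qconfig (quint8
# activation observer, qint8 weight observer):
#
#   from torch.quantization import default_qconfig
#   activation_is_statically_quantized(default_qconfig)  # True
#   weight_is_statically_quantized(default_qconfig)      # True
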
def get_qconfig_dtypes(qconfig):
    r""" Returns the dtypes configured in the qconfig as a tuple:
    (activation_dtype, weight_dtype, activation_compute_dtype)
    """
    assert qconfig is not None
    activation = qconfig.activation()
    weight = qconfig.weight()
    compute_dtype = activation.compute_dtype if hasattr(activation, 'compute_dtype') else None
    return (activation.dtype, weight.dtype, compute_dtype)

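# Illustrative usage sketch, assuming the default static qconfig, whose
# observers do not define `compute_dtype`:
#
#   get_qconfig_dtypes(torch.quantization.default_qconfig)
#   # -> (torch.quint8, torch.qint8, None)
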
def get_quant_type(qconfig):
    """ Given a qconfig, return the QuantType (static, dynamic or weight only)
    implied by its activation and weight dtypes
    """
    assert qconfig is not None
    activation = qconfig.activation()
    weight = qconfig.weight()
    static_dtypes = [torch.quint8, torch.qint8]
    if weight.dtype in static_dtypes:
        if activation.dtype in static_dtypes:
            return QuantType.STATIC
        elif hasattr(activation, 'compute_dtype') and activation.compute_dtype in static_dtypes:
            return QuantType.DYNAMIC
        else:
            return QuantType.WEIGHT_ONLY

    if weight.dtype == torch.float16:
        if activation.dtype == torch.float:
            return QuantType.DYNAMIC

    raise Exception("Unrecognized dtype combination in get_quant_type: activation({}), "
                    "weight({})".format(activation.dtype, weight.dtype))

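# Summary of the decision in get_quant_type, derived from the branches above:
#
#   weight dtype      activation dtype / compute_dtype       -> QuantType
#   --------------    -----------------------------------       -----------
#   quint8 / qint8    quint8 / qint8                            STATIC
#   quint8 / qint8    float with quint8/qint8 compute_dtype     DYNAMIC
#   quint8 / qint8    anything else                             WEIGHT_ONLY
#   float16           torch.float                               DYNAMIC
#
# Any other combination raises an Exception. For example, the default static
# qconfig (quint8 activation, qint8 weight) resolves to QuantType.STATIC.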