704 lines
25 KiB
Python
704 lines
25 KiB
Python
"""
|
|
Utils shared by different modes of quantization (eager/graph)
|
|
"""
|
|
import functools
|
|
import warnings
|
|
from collections import OrderedDict
|
|
from inspect import getfullargspec, signature
|
|
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
|
|
|
import torch
|
|
from torch.ao.quantization.quant_type import QuantType
|
|
from torch.fx import Node
|
|
from torch.nn.utils.parametrize import is_parametrized
|
|
|
|
NodePattern = Union[Tuple[Node, Node], Tuple[Node, Tuple[Node, Node]], Any]
|
|
NodePattern.__module__ = "torch.ao.quantization.utils"
|
|
|
|
# This is the Quantizer class instance from torch/quantization/fx/quantize.py.
|
|
# Define separately to prevent circular imports.
|
|
# TODO(future PR): improve this.
|
|
# make this public once fixed (can't be public as is because setting the module directly
|
|
# doesn't work)
|
|
QuantizerCls = Any
|
|
|
|
# Type for fusion patterns, it can be more complicated than the following actually,
|
|
# see pattern.md for docs
|
|
# TODO: not sure if typing supports recursive data types
|
|
Pattern = Union[
|
|
Callable, Tuple[Callable, Callable], Tuple[Callable, Tuple[Callable, Callable]], Any
|
|
]
|
|
Pattern.__module__ = "torch.ao.quantization.utils"
|
|
|
|
# TODO: maybe rename this to MatchInputNode
|
|
class MatchAllNode:
|
|
""" A node pattern that matches all nodes, used in defining
|
|
fusion patterns in FX Graph Mode Quantization
|
|
"""
|
|
pass
|
|
|
|
module_type_list = {
|
|
torch.nn.ReLU,
|
|
torch.nn.ReLU6,
|
|
torch.nn.AdaptiveAvgPool1d,
|
|
torch.nn.AdaptiveAvgPool2d,
|
|
torch.nn.AdaptiveAvgPool3d,
|
|
torch.nn.AvgPool1d,
|
|
torch.nn.AvgPool2d,
|
|
torch.nn.AvgPool3d,
|
|
torch.nn.MaxPool1d,
|
|
torch.nn.MaxPool2d,
|
|
torch.nn.MaxPool3d,
|
|
torch.nn.Identity,
|
|
torch.nn.Hardsigmoid,
|
|
torch.nn.Sigmoid,
|
|
torch.nn.Tanh,
|
|
}
|
|
func_list = {
|
|
torch.nn.functional.adaptive_avg_pool1d,
|
|
torch.nn.functional.adaptive_avg_pool2d,
|
|
torch.nn.functional.adaptive_avg_pool3d,
|
|
torch.nn.functional.elu,
|
|
torch.nn.functional.hardswish,
|
|
torch.nn.functional.instance_norm,
|
|
torch.nn.functional.layer_norm,
|
|
torch.nn.functional.leaky_relu,
|
|
torch.nn.functional.silu,
|
|
torch.nn.functional.mish,
|
|
torch.nn.functional.dropout,
|
|
torch.nn.functional.max_pool1d,
|
|
torch.nn.functional.max_pool2d,
|
|
torch.nn.functional.max_pool3d,
|
|
torch.nn.functional.relu,
|
|
torch.nn.functional.hardtanh,
|
|
torch.nn.functional.hardtanh_,
|
|
torch.nn.functional.hardsigmoid,
|
|
torch.nn.functional.sigmoid,
|
|
torch.transpose,
|
|
torch.repeat_interleave,
|
|
torch.sigmoid,
|
|
torch.squeeze,
|
|
torch.stack,
|
|
torch.sum,
|
|
torch.tanh,
|
|
torch.unsqueeze,
|
|
torch.cat,
|
|
}
|
|
method_list = {
|
|
torch.mean,
|
|
'relu',
|
|
'relu_',
|
|
'contiguous',
|
|
'detach',
|
|
'detach_',
|
|
'hardsigmoid',
|
|
'hardsigmoid_',
|
|
'permute',
|
|
'repeat',
|
|
'repeat_interleave',
|
|
'reshape',
|
|
'resize_',
|
|
'shape',
|
|
'sigmoid',
|
|
'sigmoid_',
|
|
'size',
|
|
'squeeze',
|
|
'squeeze_',
|
|
'tanh',
|
|
'tanh_',
|
|
'transpose',
|
|
'unsqueeze',
|
|
'unsqueeze_',
|
|
'view',
|
|
}
|
|
|
|
# TODO: not used now, remove
|
|
def check_node(node, modules):
|
|
# TODO: reuse is_fixed_qparam_node after we move this function to _lower_to_native_backend.py
|
|
is_call_function = node.op == "call_function" and node.target in func_list
|
|
is_call_method = node.op == "call_method" and node.target in method_list
|
|
is_call_module = node.op == "call_module" and type(modules[str(node.target)]) in module_type_list
|
|
return is_call_function, is_call_method, is_call_module
|
|
|
|
def get_combined_dict(default_dict, additional_dict):
|
|
d = default_dict.copy()
|
|
d.update(additional_dict)
|
|
return d
|
|
|
|
def is_per_tensor(qscheme):
|
|
return qscheme == torch.per_tensor_affine or \
|
|
qscheme == torch.per_tensor_symmetric
|
|
|
|
def is_per_channel(qscheme):
|
|
return qscheme in [torch.per_channel_affine,
|
|
torch.per_channel_affine_float_qparams,
|
|
torch.per_channel_symmetric]
|
|
|
|
def getattr_from_fqn(obj: Any, fqn: str) -> Any:
|
|
"""
|
|
Given an obj and a fqn such as "foo.bar.baz", returns gm.foo.bar.baz.
|
|
"""
|
|
return functools.reduce(getattr, fqn.split("."), obj)
|
|
|
|
def to_underlying_dtype(qdtype):
|
|
DTYPE_MAPPING = {
|
|
torch.quint8: torch.uint8,
|
|
torch.qint8: torch.int8,
|
|
torch.qint32: torch.int32,
|
|
torch.quint4x2: torch.uint8,
|
|
torch.quint2x4: torch.uint8,
|
|
torch.uint8: torch.uint8,
|
|
torch.int8: torch.int8,
|
|
torch.int16: torch.int16,
|
|
torch.int32: torch.int32,
|
|
}
|
|
assert qdtype in DTYPE_MAPPING, "Unsupported dtype: " + str(qdtype)
|
|
return DTYPE_MAPPING[qdtype]
|
|
|
|
def get_qparam_dict(observer_or_fake_quant):
|
|
from torch.ao.quantization.observer import PlaceholderObserver
|
|
|
|
qscheme = getattr(observer_or_fake_quant, "qscheme", None)
|
|
dtype = observer_or_fake_quant.dtype
|
|
qparams = {"qscheme": qscheme, "dtype": dtype}
|
|
|
|
if not qscheme or isinstance(observer_or_fake_quant, PlaceholderObserver):
|
|
return {"qscheme": None, "dtype": dtype}
|
|
|
|
if is_per_tensor(qscheme):
|
|
qscheme = torch.per_tensor_affine
|
|
elif is_per_channel(qscheme):
|
|
# change symmetric to affine since we do not have symmetric
|
|
# quantized Tensor
|
|
if qscheme == torch.per_channel_symmetric:
|
|
qscheme = torch.per_channel_affine
|
|
qparams["axis"] = observer_or_fake_quant.ch_axis
|
|
else:
|
|
raise RuntimeError(f"Unrecognized qscheme: {qscheme}")
|
|
# update qscheme, since we don't have symmetric quant qscheme
|
|
# in quantized Tensor
|
|
qparams["qscheme"] = qscheme
|
|
|
|
scale, zero_point = observer_or_fake_quant.calculate_qparams()
|
|
qparams["scale"] = scale
|
|
qparams["zero_point"] = zero_point
|
|
|
|
if hasattr(observer_or_fake_quant, "quant_min"):
|
|
qparams["quant_min"] = observer_or_fake_quant.quant_min
|
|
if hasattr(observer_or_fake_quant, "quant_max"):
|
|
qparams["quant_max"] = observer_or_fake_quant.quant_max
|
|
|
|
return qparams
|
|
|
|
|
|
def get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig):
|
|
""" Get the observed/quantized custom module class that we need
|
|
to swap `custom_module` to
|
|
Input:
|
|
custom_module: input, can be an instance of either a float or observed custom module
|
|
custom_module_class_mapping: the float to observed or observed to quantized custom module class mapping
|
|
qconfig: qconfig configured for the custom module
|
|
|
|
Output:
|
|
corresponding observed/quantized custom module class for input custom module instance
|
|
"""
|
|
quant_type = get_quant_type(qconfig)
|
|
class_mapping = custom_module_class_mapping.get(quant_type, {})
|
|
assert type(custom_module) in class_mapping, "did not find corresponding observed " \
|
|
f"module class for {type(custom_module)} in mapping: {class_mapping}"
|
|
return class_mapping[type(custom_module)]
|
|
|
|
def activation_dtype(qconfig):
|
|
assert qconfig is not None
|
|
activation = qconfig.activation()
|
|
return activation.dtype
|
|
|
|
def weight_dtype(qconfig):
|
|
assert qconfig is not None
|
|
weight = qconfig.weight()
|
|
return weight.dtype
|
|
|
|
def activation_is_statically_quantized(qconfig):
|
|
""" Given a qconfig, decide if the activation needs to be
|
|
quantized or not, this includes quantizing to quint8, qint8 and qint32 and float16
|
|
"""
|
|
return (
|
|
activation_dtype(qconfig) in [
|
|
torch.quint8,
|
|
torch.qint8,
|
|
torch.qint32,
|
|
torch.float16,
|
|
torch.uint8,
|
|
torch.int8,
|
|
torch.int16,
|
|
torch.int32
|
|
]
|
|
and (not activation_is_dynamically_quantized(qconfig))
|
|
)
|
|
|
|
def activation_is_dynamically_quantized(qconfig):
|
|
""" Given a qconfig, decide if the activation needs to be
|
|
dynamically quantized or not, this includes dynamically quantizing to
|
|
quint8, qint8 and float16
|
|
"""
|
|
activation_dtype, _, activation_is_dynamic = \
|
|
get_qconfig_dtypes(qconfig)
|
|
return activation_is_dynamic
|
|
|
|
def activation_is_int8_quantized(qconfig):
|
|
""" Given a qconfig, decide if the activation needs to be
|
|
quantized to int8 or not, this includes quantizing to quint8, qint8
|
|
"""
|
|
return activation_dtype(qconfig) in [torch.quint8, torch.qint8, torch.uint8, torch.int8]
|
|
|
|
def activation_is_int32_quantized(qconfig):
|
|
""" Given a qconfig, decide if the activation needs to be
|
|
quantized to int32 or not
|
|
"""
|
|
return activation_dtype(qconfig) in [torch.qint32, torch.int32]
|
|
|
|
def weight_is_quantized(qconfig):
|
|
""" Given a qconfig, decide if the weight needs to be
|
|
quantized or not
|
|
"""
|
|
return weight_dtype(qconfig) in [
|
|
torch.quint8,
|
|
torch.qint8,
|
|
torch.float16,
|
|
torch.quint4x2,
|
|
torch.uint8,
|
|
torch.int8,
|
|
torch.int16,
|
|
torch.int32
|
|
]
|
|
|
|
def weight_is_statically_quantized(qconfig):
|
|
""" Given a qconfig, decide if the weight needs to be statically
|
|
quantized or not
|
|
"""
|
|
return weight_dtype(qconfig) in [torch.quint8, torch.qint8, torch.uint8, torch.int8]
|
|
|
|
def op_is_int8_dynamically_quantized(qconfig) -> bool:
|
|
""" Given a qconfig, returns True if this op is using int8 dynamic
|
|
quantization
|
|
"""
|
|
activation_dtype, weight_dtype, activation_is_dynamic = \
|
|
get_qconfig_dtypes(qconfig)
|
|
return (
|
|
activation_dtype in [torch.quint8, torch.uint8] and
|
|
# for now, the lines below assume fbgemm or qnnpack
|
|
weight_dtype in [torch.qint8, torch.int8] and
|
|
activation_is_dynamic
|
|
)
|
|
|
|
def get_qconfig_dtypes(qconfig):
|
|
r""" returns the qconfig tuple for qconfig:
|
|
(activation_dtype, weight_dtype, activation_is_dynamic)
|
|
"""
|
|
assert qconfig is not None
|
|
activation = qconfig.activation()
|
|
weight = qconfig.weight()
|
|
act_is_dynamic = getattr(activation, "is_dynamic", False)
|
|
return (activation.dtype, weight.dtype, act_is_dynamic)
|
|
|
|
def get_quant_type(qconfig):
|
|
assert qconfig is not None
|
|
activation = qconfig.activation()
|
|
weight = qconfig.weight()
|
|
static_dtypes = [torch.quint8, torch.qint8, torch.quint4x2, torch.qint32, torch.uint8, torch.int8, torch.int16, torch.int32]
|
|
if weight.dtype in static_dtypes:
|
|
if hasattr(activation, 'is_dynamic') and activation.is_dynamic:
|
|
return QuantType.DYNAMIC
|
|
elif activation.dtype in static_dtypes:
|
|
return QuantType.STATIC
|
|
else:
|
|
return QuantType.WEIGHT_ONLY
|
|
|
|
if weight.dtype == torch.float16:
|
|
if hasattr(activation, 'is_dynamic') and activation.is_dynamic:
|
|
return QuantType.DYNAMIC
|
|
elif activation.dtype == torch.float16:
|
|
return QuantType.STATIC
|
|
|
|
raise Exception(f"Unrecognized dtype combination in get_quant_type: activation({activation.dtype}),"
|
|
f"weight({weight.dtype})")
|
|
|
|
def check_min_max_valid(min_val: torch.Tensor, max_val: torch.Tensor) -> bool:
|
|
""" Checks if the given minimum and maximum values are valid, meaning that
|
|
they exist and the min value is less than the max value.
|
|
"""
|
|
if min_val.numel() == 0 or max_val.numel() == 0:
|
|
warnings.warn(
|
|
"must run observer before calling calculate_qparams. " +
|
|
"Returning default values."
|
|
)
|
|
return False
|
|
|
|
if min_val.dim() == 0 or max_val.dim() == 0:
|
|
if min_val == float("inf") and max_val == float("-inf"):
|
|
warnings.warn(
|
|
"must run observer before calling calculate_qparams. " +
|
|
"Returning default values."
|
|
)
|
|
|
|
return False
|
|
|
|
assert min_val <= max_val, f"min {min_val} should be less than max {max_val}"
|
|
else:
|
|
assert torch.all(
|
|
min_val <= max_val
|
|
), f"min {min_val} should be less than max {max_val}"
|
|
|
|
return True
|
|
|
|
|
|
def calculate_qmin_qmax(quant_min: int, quant_max: int, has_customized_qrange: bool, dtype: torch.dtype,
|
|
reduce_range: bool) -> Tuple[int, int]:
|
|
r"""Calculates actual qmin and qmax based on the quantization range,
|
|
observer datatype and if range is reduced.
|
|
"""
|
|
# TODO(jerryzh): Figure out why custom quant_min/quant_max are still adjusted.
|
|
if has_customized_qrange:
|
|
# This initialization here is to be resolve TorchScript compilation issues and allow
|
|
# using of refinement to decouple initial_qmin and initial_qmax from quantization range.
|
|
# The actual values of initial_qmin and initial_qmax will be reset below.
|
|
if dtype in [torch.qint32, torch.int32]:
|
|
initial_quant_min, initial_quant_max = 0, 2**32 - 1
|
|
else:
|
|
initial_quant_min, initial_quant_max = 0, 255
|
|
# The following assignment of self.qmin and self.qmax to the local variables and the if check refine the
|
|
# attribute from Optional valid integers for use, based on TorchScript's requirements.
|
|
custom_quant_min, custom_quant_max = quant_min, quant_max
|
|
if custom_quant_min is not None and custom_quant_max is not None:
|
|
initial_quant_min, initial_quant_max = (
|
|
custom_quant_min,
|
|
custom_quant_max,
|
|
)
|
|
|
|
qrange_len = initial_quant_max - initial_quant_min + 1
|
|
if dtype in [torch.qint8, torch.int8]:
|
|
assert (
|
|
0 < qrange_len <= 256
|
|
), "quantization range should be positive and not exceed the maximum bit range (=256)."
|
|
elif dtype in [torch.qint32, torch.int32]:
|
|
assert (
|
|
0 < qrange_len <= 2**32
|
|
), "quantization range should be positive and not exceed the maximum bit range (=4294967296)."
|
|
if reduce_range:
|
|
quant_min, quant_max = quant_min // 2, quant_max // 2
|
|
else:
|
|
# Fallback onto default 8-bit qmin and qmax calculation if dynamic range is not used.
|
|
if dtype in [torch.qint8, torch.int8]:
|
|
if reduce_range:
|
|
quant_min, quant_max = -64, 63
|
|
else:
|
|
quant_min, quant_max = -128, 127
|
|
elif dtype in [torch.quint8, torch.uint8]:
|
|
if reduce_range:
|
|
quant_min, quant_max = 0, 127
|
|
else:
|
|
quant_min, quant_max = 0, 255
|
|
elif dtype in [torch.qint32, torch.int32]:
|
|
quant_min, quant_max = -1 * (2 ** 31), (2 ** 31) - 1
|
|
else:
|
|
quant_min, quant_max = 0, 15
|
|
return quant_min, quant_max
|
|
|
|
|
|
def _parent_name(target):
|
|
"""
|
|
Turn 'foo.bar' into ['foo', 'bar']
|
|
"""
|
|
r = target.rsplit('.', 1)
|
|
if len(r) == 1:
|
|
return '', r[0]
|
|
else:
|
|
return r[0], r[1]
|
|
|
|
def has_no_children_ignoring_parametrizations(module):
|
|
"""
|
|
Checks if module._modules is empty or
|
|
if module is a parametrization, checks that module._modules only has
|
|
the 'parametrizations' module
|
|
"""
|
|
if len(module._modules) == 0:
|
|
return True
|
|
elif is_parametrized(module):
|
|
return len(module._modules) == 1 and 'parametrizations' in module._modules
|
|
else:
|
|
return False
|
|
|
|
def _get_path_of_module(root: torch.nn.Module, submodule: torch.nn.Module) -> Optional[str]:
|
|
""" Get the path (fully qualified name) of a submodule
|
|
|
|
Example::
|
|
|
|
>> class M(torch.nn.Module):
|
|
def __init__(self):
|
|
self.linear = torch.nn.Linear(5, 5)
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
>> m = M()
|
|
>> l = m.linear
|
|
>> _get_path_of_module(m, l)
|
|
"linear"
|
|
"""
|
|
for n, p in root.named_modules():
|
|
if submodule is p:
|
|
return n
|
|
return None
|
|
|
|
def _get_signature_locals(f: Callable, loc: Dict[str, Any]) -> Dict[str, Any]:
|
|
""" Get local keyword arguments
|
|
|
|
Example::
|
|
|
|
>> def f(self, a, b=9):
|
|
pass
|
|
>> loc = {"a": 6, "c": 7}
|
|
>> _get_signature_locals(f, loc)
|
|
{"a": 6}
|
|
"""
|
|
return {k: v for k, v in loc.items() if k in signature(f).parameters}
|
|
|
|
def _get_default_kwargs(f: Callable) -> "OrderedDict[str, Any]":
|
|
""" Get all default keyword arguments from function signature
|
|
|
|
Example::
|
|
|
|
>> def f(self, a, b=9):
|
|
pass
|
|
>> _get_default_kwargs(f)
|
|
{"b": 9}
|
|
"""
|
|
kwargs = {}
|
|
for name, param in signature(f).parameters.items():
|
|
if param.default is not param.empty:
|
|
kwargs[name] = param.default
|
|
elif param.kind is param.VAR_POSITIONAL:
|
|
kwargs[name] = ()
|
|
elif param.kind is param.VAR_KEYWORD:
|
|
kwargs[name] = {}
|
|
return OrderedDict(kwargs)
|
|
|
|
def _normalize_kwargs(func: Callable, loc: Dict[str, Any]) -> "OrderedDict[str, Any]":
|
|
""" Given a function and local function arguments, normalize the keyword
|
|
arguments by filling in default arguments from function signature
|
|
|
|
Example::
|
|
|
|
>> def f(self, key1=3, key2=3):
|
|
pass
|
|
>> loc = {"key2": 6}
|
|
>> _normalize_kwargs(f, loc)
|
|
{"key1": 3, "key2": 6}
|
|
"""
|
|
default_kwargs = _get_default_kwargs(func)
|
|
local_kwargs = _get_signature_locals(func, loc)
|
|
normalized_kwargs = default_kwargs.copy()
|
|
for attr, val in local_kwargs.items():
|
|
if attr in normalized_kwargs:
|
|
# override the default keyword arguments
|
|
normalized_kwargs[attr] = val
|
|
return normalized_kwargs
|
|
|
|
def validate_qmin_qmax(quant_min: int, quant_max: int) -> None:
|
|
r"""Validates that the user-specified quantization range is properly initialized
|
|
and within the given bound supported by the observer dtype.
|
|
|
|
To accommodate lower-bit quantization with respect to the existing torch.qint8 and
|
|
torch.quint8 datatypes, the user can choose to use dynamic quantization range by passing
|
|
in a tuple of initial qmin and qmax values. One use case is these customized qmin and qmax
|
|
values are used to calculate static estimates of the scale and zero point for aggressive lower-bit
|
|
fake quantization. These estimates are compared against parameters learned through backpropagation.
|
|
The related literatures for scale and zero point via backpropagation are as follows:
|
|
|
|
Learned Step Size Quantization: https://openreview.net/pdf?id=rkgO66VKDS
|
|
Trained Quantization Thresholds: https://arxiv.org/pdf/1903.08066.pdf
|
|
"""
|
|
# The variable names are prefixed with "initial" because their values (qmin and qmax) might be adjusted
|
|
# based on whether quantization range is reduced and the datatype (signed/unsigned) used by the observer.
|
|
assert (
|
|
quant_min <= 0 <= quant_max
|
|
), "Used-specified quantization range must include 0."
|
|
assert (
|
|
quant_min < quant_max
|
|
), "qmin must be strictly less than qmax for user-specified quantization range."
|
|
|
|
|
|
# Functionally equivalent to '_calculate_qparams' in observer.py. Observers must be torchscriptable however and qscheme
|
|
# as far as I can tell is not allowed to passed as a parameter in torchscript functions. This makes refactoring observer
|
|
# to use this utility a massive pain and very gross. For now Im opting just to duplicate as this code seems unlikey to change
|
|
# (last update over 1 year ago) and when torchscript is fully deprecated we can refactor. TODO(jakeszwe, jerryzh168)
|
|
def determine_qparams(
|
|
min_val: torch.Tensor, max_val: torch.Tensor, quant_min: int, quant_max: int,
|
|
dtype: torch.dtype, eps: torch.Tensor, has_customized_qrange: bool,
|
|
qscheme: torch.qscheme = torch.per_tensor_affine) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
r"""Calculates the quantization parameters, given min and max
|
|
value tensors. Works for both per tensor and per channel cases
|
|
|
|
Args:
|
|
min_val: Minimum values per channel
|
|
max_val: Maximum values per channel
|
|
|
|
Returns:
|
|
scales: Scales tensor of shape (#channels,)
|
|
zero_points: Zero points tensor of shape (#channels,)
|
|
"""
|
|
if not check_min_max_valid(min_val, max_val):
|
|
return torch.tensor([1.0], device=min_val.device.type), torch.tensor([0], device=min_val.device.type)
|
|
|
|
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
|
|
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
|
|
|
|
device = min_val_neg.device
|
|
scale = torch.ones(min_val_neg.size(), dtype=torch.double, device=device)
|
|
zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
|
|
|
|
if (
|
|
qscheme == torch.per_tensor_symmetric
|
|
or qscheme == torch.per_channel_symmetric
|
|
):
|
|
max_val_pos = torch.max(-min_val_neg, max_val_pos)
|
|
scale = max_val_pos / (float(quant_max - quant_min) / 2)
|
|
scale = torch.max(scale, eps)
|
|
if dtype in [torch.uint8, torch.quint8]:
|
|
if has_customized_qrange:
|
|
# When customized quantization range is used, down-rounded midpoint of the range is chosen.
|
|
zero_point = zero_point.new_full(
|
|
zero_point.size(), (quant_min + quant_max) // 2
|
|
)
|
|
else:
|
|
zero_point = zero_point.new_full(zero_point.size(), 128)
|
|
elif qscheme == torch.per_channel_affine_float_qparams:
|
|
scale = (max_val - min_val) / float(quant_max - quant_min)
|
|
scale = torch.where(scale > eps, scale, torch.ones_like(scale))
|
|
# We use the quantize function
|
|
# xq = Round(Xf * inv_scale + zero_point),
|
|
# setting zero_point to (-1 * min *inv_scale) we get
|
|
# Xq = Round((Xf - min) * inv_scale)
|
|
zero_point = -1 * min_val / scale
|
|
else:
|
|
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
|
|
scale = torch.max(scale, eps)
|
|
zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
|
|
zero_point = torch.clamp(zero_point, quant_min, quant_max)
|
|
|
|
# For scalar values, cast them to Tensors of size 1 to keep the shape
|
|
# consistent with default values in FakeQuantize.
|
|
if len(scale.shape) == 0:
|
|
# TODO: switch to scale.item() after adding JIT support
|
|
scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device)
|
|
if len(zero_point.shape) == 0:
|
|
# TODO: switch to zero_point.item() after adding JIT support
|
|
zero_point = torch.tensor(
|
|
[int(zero_point)], dtype=zero_point.dtype, device=device
|
|
)
|
|
if qscheme == torch.per_channel_affine_float_qparams:
|
|
zero_point = torch.tensor(
|
|
[float(zero_point)], dtype=zero_point.dtype, device=device
|
|
)
|
|
|
|
return scale.to(torch.double), zero_point.to(torch.int64)
|
|
|
|
def _get_num_pos_args(f: Callable) -> int:
|
|
""" Get number of positional args for a function
|
|
|
|
Example::
|
|
|
|
>> def f(self, key1=3, key2=3):
|
|
pass
|
|
>> _get_num_pos_args(f)
|
|
3
|
|
"""
|
|
return len(getfullargspec(f).args)
|
|
|
|
def get_fqn_to_example_inputs(
|
|
model: torch.nn.Module,
|
|
example_inputs: Tuple[Any, ...]
|
|
) -> Dict[str, Tuple[Any, ...]]:
|
|
""" Given a model and its example inputs, return a dictionary from
|
|
fully qualified name of submodules to example_inputs for that submodule,
|
|
e.g. {"linear1": (tensor1,), "linear2": (tensor2,), "sub": (tensor3,),
|
|
"sub.linear1": (tensor4,), ...}
|
|
|
|
Used to make quantizing submodules easier now that FX Graph Mode Quantization requires
|
|
example inputs.
|
|
|
|
Also works for keyword arguments with default values, we would flatten keyword
|
|
arguments as positional arguments and fill in the missing keyword args with default
|
|
values, e.g. if we have a forward function:
|
|
def forward(self, x, key1=3, key2=3):
|
|
...
|
|
|
|
and we call it with self.submodule(x, key2=6)
|
|
we'll get example_inputs: (x, 3, 6)
|
|
|
|
user can also override `key1` with positional arguments as well:
|
|
for self.submodule(x, 5, key2=6)
|
|
we'll get: (x, 5, 6)
|
|
|
|
variable positional arguments and variable positional keyword arguments in forward
|
|
function are not supported currently, so please make sure no submodules is using
|
|
them.
|
|
"""
|
|
root = model
|
|
fqn_to_example_inputs = {}
|
|
|
|
def _patched_module_call(self, *args, **kwargs):
|
|
submodule_example_inputs = list(args).copy()
|
|
normalized_kwargs = _normalize_kwargs(self.forward, kwargs)
|
|
# minus 1 to skipping counting `self`
|
|
num_args = _get_num_pos_args(self.forward) - 1
|
|
num_to_pop = num_args - len(submodule_example_inputs)
|
|
while num_to_pop and normalized_kwargs:
|
|
normalized_kwargs.popitem(last=False)
|
|
num_to_pop -= 1
|
|
submodule_example_inputs.extend(normalized_kwargs.values())
|
|
submodule_example_inputs_tuple = tuple(submodule_example_inputs)
|
|
fqn = _get_path_of_module(root, self)
|
|
if fqn is not None:
|
|
fqn_to_example_inputs[fqn] = submodule_example_inputs_tuple
|
|
return orig_module_call(self, *args, **kwargs)
|
|
|
|
orig_module_call = torch.nn.Module.__call__
|
|
torch.nn.Module.__call__ = _patched_module_call # type: ignore[method-assign]
|
|
try:
|
|
model(*example_inputs)
|
|
finally:
|
|
# restore the module call even if there is an exception
|
|
torch.nn.Module.__call__ = orig_module_call # type: ignore[method-assign]
|
|
return fqn_to_example_inputs
|
|
|
|
__all__ = [
|
|
"NodePattern",
|
|
"Pattern",
|
|
"MatchAllNode",
|
|
"check_node",
|
|
"get_combined_dict",
|
|
"is_per_tensor",
|
|
"is_per_channel",
|
|
"getattr_from_fqn",
|
|
"get_qparam_dict",
|
|
"get_swapped_custom_module_class",
|
|
"activation_dtype",
|
|
"weight_dtype",
|
|
"activation_is_statically_quantized",
|
|
"activation_is_dynamically_quantized",
|
|
"activation_is_int8_quantized",
|
|
"activation_is_int32_quantized",
|
|
"weight_is_quantized",
|
|
"weight_is_statically_quantized",
|
|
"op_is_int8_dynamically_quantized",
|
|
"get_qconfig_dtypes",
|
|
"get_quant_type",
|
|
"check_min_max_valid",
|
|
"calculate_qmin_qmax",
|
|
"has_no_children_ignoring_parametrizations",
|
|
"get_fqn_to_example_inputs",
|
|
"to_underlying_dtype",
|
|
"determine_qparams",
|
|
"validate_qmin_qmax",
|
|
]
|