from collections import namedtuple
from typing import Optional, Any, Union, Type

import torch
import torch.nn as nn
from torch.ao.quantization.fake_quantize import (
    FakeQuantize,
    FakeQuantizeBase,
    default_fake_quant,
    default_dynamic_fake_quant,
    default_per_channel_weight_fake_quant,
    default_weight_fake_quant,
    default_fused_act_fake_quant,
    default_fused_wt_fake_quant,
    FusedMovingAvgObsFakeQuantize,
    default_fused_per_channel_wt_fake_quant,
    default_embedding_fake_quant,
    default_embedding_fake_quant_4bit,
    fused_wt_fake_quant_range_neg_127_to_127,
    fused_per_channel_wt_fake_quant_range_neg_127_to_127,
)

from .observer import (
    _PartialWrapper,
    MinMaxObserver,
    HistogramObserver,
    MovingAverageMinMaxObserver,
    NoopObserver,
    PlaceholderObserver,
    ReuseInputObserver,
    default_debug_observer,
    default_dynamic_quant_observer,
    default_float_qparams_observer,
    default_float_qparams_observer_4bit,
    default_observer,
    default_per_channel_weight_observer,
    default_placeholder_observer,
    default_weight_observer,
    weight_observer_range_neg_127_to_127,
    per_channel_weight_observer_range_neg_127_to_127,
    default_reuse_input_observer,
    ObserverBase,
)
import warnings
import copy

__all__ = [
    "QConfig",
    # TODO: deprecated, remove
    "QConfigDynamic",
    "default_qconfig",
    "default_debug_qconfig",
    "default_per_channel_qconfig",
    "default_dynamic_qconfig",
    "float16_dynamic_qconfig",
    "float16_static_qconfig",
    "per_channel_dynamic_qconfig",
    "float_qparams_weight_only_qconfig",
    "float_qparams_weight_only_qconfig_4bit",
    "default_quint8_weight_qconfig",
    "default_qat_qconfig",
    "default_dynamic_qat_qconfig",
    "default_weight_only_qconfig",
    "default_activation_only_qconfig",
    "default_qat_qconfig_v2",
    "default_reuse_input_qconfig",
    "default_symmetric_qnnpack_qconfig",
    "default_per_channel_symmetric_qnnpack_qconfig",
    "default_symmetric_qnnpack_qat_qconfig",
    "default_per_channel_symmetric_qnnpack_qat_qconfig",
    "default_embedding_qat_qconfig",
    "default_embedding_qat_qconfig_4bit",
    "get_default_qconfig",
    "get_default_qat_qconfig",
    "get_default_qconfig_dict",
    "get_default_qat_qconfig_dict",
    "QConfigAny",
    "qconfig_equals",
]


class QConfig(namedtuple('QConfig', ['activation', 'weight'])):
    """
    Describes how to quantize a layer or a part of the network by providing
    settings (observer classes) for activations and weights respectively.

    Note that QConfig needs to contain observer **classes** (like MinMaxObserver) or a callable that returns
    instances on invocation, not the concrete observer instances themselves.
    The quantization preparation function will instantiate observers multiple times, once for each layer.

    Observer classes usually have reasonable default arguments, but they can be overridden with the `with_args`
    method (which behaves like functools.partial)::

      my_qconfig = QConfig(
          activation=MinMaxObserver.with_args(dtype=torch.qint8),
          weight=default_observer.with_args(dtype=torch.qint8))

    """
    def __new__(cls, activation, weight):
        # catch common mistakes
        if isinstance(activation, nn.Module) or isinstance(weight, nn.Module):
            raise ValueError("QConfig received observer instance, please pass observer class instead. " +
                             "Use MyObserver.with_args(x=1) to override arguments to constructor if needed")
        return super().__new__(cls, activation, weight)
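

# Illustrative sketch (not part of the public API): QConfig must receive observer
# *constructors* (classes or `with_args` partials), never observer instances. Passing
# an instantiated observer, which is an nn.Module, triggers the ValueError above; use
# `with_args` to customize constructor arguments instead.
def _example_qconfig_constructor_check():
    ok = QConfig(activation=MinMaxObserver.with_args(dtype=torch.quint8),
                 weight=default_weight_observer)
    try:
        QConfig(activation=MinMaxObserver(), weight=default_weight_observer)
    except ValueError:
        pass  # expected: an observer *instance* was passed for `activation`
    return ok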


class QConfigDynamic(namedtuple('QConfigDynamic', ['activation', 'weight'])):
    """
    Describes how to dynamically quantize a layer or a part of the network by providing
    settings (observer classes) for weights.

    It's like QConfig, but for dynamic quantization.

    Note that QConfigDynamic needs to contain observer **classes** (like MinMaxObserver) or a callable that returns
    instances on invocation, not the concrete observer instances themselves.
    The quantization function will instantiate observers multiple times, once for each layer.

    Observer classes usually have reasonable default arguments, but they can be overridden with the `with_args`
    method (which behaves like functools.partial)::

      my_qconfig = QConfigDynamic(weight=default_observer.with_args(dtype=torch.qint8))

    """
    def __new__(cls, activation=torch.nn.Identity, weight=torch.nn.Identity):
        # catch common mistakes
        if isinstance(weight, nn.Module):
            raise ValueError("QConfigDynamic received observer instance, please pass observer class instead. " +
                             "Use MyObserver.with_args(x=1) to override arguments to constructor if needed")
        warnings.warn("QConfigDynamic is going to be deprecated in PyTorch 1.12, please use QConfig instead")
        return super().__new__(cls, activation, weight)


default_qconfig = QConfig(activation=default_observer,
                          weight=default_weight_observer)
"""
Default qconfig configuration.
"""

default_debug_qconfig = QConfig(weight=default_weight_observer,
                                activation=default_debug_observer)
"""
Default qconfig configuration for debugging.
"""

default_per_channel_qconfig = QConfig(activation=default_observer,
                                      weight=default_per_channel_weight_observer)
"""
Default qconfig configuration for per channel weight quantization.
"""

default_dynamic_qconfig = QConfig(activation=default_dynamic_quant_observer,
                                  weight=default_weight_observer)
"""
Default dynamic qconfig.
"""

float16_dynamic_qconfig = QConfig(activation=PlaceholderObserver.with_args(dtype=torch.float16, is_dynamic=True),
                                  weight=PlaceholderObserver.with_args(dtype=torch.float16))
"""
Dynamic qconfig with weights quantized to `torch.float16`.
"""

float16_static_qconfig = QConfig(activation=PlaceholderObserver.with_args(dtype=torch.float16),
                                 weight=PlaceholderObserver.with_args(dtype=torch.float16))
"""
Static qconfig with both activations and weights quantized to `torch.float16`.
"""

per_channel_dynamic_qconfig = QConfig(activation=default_dynamic_quant_observer,
                                      weight=default_per_channel_weight_observer)
"""
Dynamic qconfig with weights quantized per channel.
"""

float_qparams_weight_only_qconfig = QConfig(
    activation=default_placeholder_observer,
    weight=default_float_qparams_observer)
"""
Dynamic qconfig with weights quantized with a floating point zero_point.
"""

float_qparams_weight_only_qconfig_4bit = QConfig(
    activation=default_placeholder_observer,
    weight=default_float_qparams_observer_4bit)

default_qat_qconfig = QConfig(activation=default_fake_quant,
                              weight=default_weight_fake_quant)
"""
Default qconfig for QAT.
"""

default_dynamic_qat_qconfig = QConfig(activation=default_dynamic_fake_quant,
                                      weight=default_weight_fake_quant)
"""
Default qconfig for dynamic QAT.
"""

default_weight_only_qconfig = QConfig(activation=torch.nn.Identity,
                                      weight=default_weight_fake_quant)
"""
Default qconfig for quantizing weights only.
"""

default_activation_only_qconfig = QConfig(activation=default_fake_quant,
                                          weight=torch.nn.Identity)
"""
Default qconfig for quantizing activations only.
"""

# QAT config that uses fused observer + fake quant modules for optimized training performance.
# To modify the activation/weight observers, change the default entries in fake_quantize.py.
default_qat_qconfig_v2 = QConfig(activation=default_fused_act_fake_quant, weight=default_fused_wt_fake_quant)
"""
Fused version of `default_qat_qconfig`, with performance benefits.
"""

default_reuse_input_qconfig = QConfig(activation=default_reuse_input_observer,
                                      weight=NoopObserver)
"""
Default qconfig for operators that reuse the observers from the input Tensor, e.g. reshape.
"""


def get_default_qconfig(backend='x86', version=0):
    """
    Returns the default PTQ qconfig for the specified backend.

    Args:
      * `backend` (str): a string representing the target backend. Currently supports
        `x86` (default), `fbgemm`, `qnnpack` and `onednn`.

    Return:
        qconfig
    """
    supported_backends = ["fbgemm", "x86", "qnnpack", "onednn"]
    if backend not in supported_backends:
        raise AssertionError(
            "backend: " + str(backend) +
            f" not supported. backend must be one of {supported_backends}"
        )

    if version == 0:
        if backend == 'fbgemm':
            qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=True),
                              weight=default_per_channel_weight_observer)
        elif backend == 'qnnpack':
            # TODO: make this compatible with xnnpack constraints
            qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False),
                              weight=default_weight_observer)
        elif backend == 'onednn':
            if not torch.cpu._is_cpu_support_vnni():
                warnings.warn(
                    "Default qconfig of oneDNN backend with reduce_range of false may have accuracy issues "
                    "on CPU without Vector Neural Network Instruction support.")
            qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False),
                              weight=default_per_channel_weight_observer)
        elif backend == 'x86':
            qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=True),
                              weight=default_per_channel_weight_observer)
        else:
            # won't reach
            qconfig = default_qconfig
    else:
        raise AssertionError("Version number: " + str(version) +
                             " in get_default_qconfig is not supported. Version number must be 0")

    return qconfig
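

# Illustrative sketch (assuming the eager-mode workflow): how a backend-specific default
# qconfig is typically used for post training static quantization. The `model` and
# `calibration_data` arguments are hypothetical placeholders.
def _example_ptq_with_default_qconfig(model: nn.Module, calibration_data):
    import torch.ao.quantization as tq

    model = copy.deepcopy(model).eval()
    model.qconfig = get_default_qconfig("x86")
    prepared = tq.prepare(model)              # inserts observers according to the qconfig
    with torch.no_grad():
        for batch in calibration_data:        # a few representative batches suffice
            prepared(batch)
    return tq.convert(prepared)               # swaps float modules for quantized ones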


"""
Default, symmetric PTQ qconfig for the specified backend. And a per_channel
variant of the same.

Symmetric here applies to signed weights with zero point = 0, and additional
value restrictions. The activations are also signed 8-bit integers with this
qconfig.

* Once this change is merged [as of 3/17/22], with backend or qengine =
  'qnnpack', some quantized operators with this symmetric qconfig may use
  operators from the xnnpack library.

  ** Support for using xnnpack ops with the `qnnpack` backend for the asymmetric
     qconfig (returned by get_default_qconfig()) is not available yet.

* This qconfig uses signed activations and weights. Weights have added
  restrictions such as the zero point being forced to 0, making the weights
  symmetric, hence the name. The 8-bit quantized values are restricted
  to [-127, +127], excluding -128.

* xnnpack has a requantization scale value restriction, 0x1p-32 <=
  requantization_scale < 256.0, where `requantization_scale = (input_scale
  * kernel_scale) / (output_scale)`. Using this eps (with an assumed max value
  of 256) prevents requantization_scale from going below the xnnpack lower
  threshold.
"""
default_symmetric_qnnpack_qconfig = QConfig(activation=HistogramObserver.with_args(dtype=torch.qint8,
                                                                                   reduce_range=False,
                                                                                   eps=2 ** -12),
                                            weight=weight_observer_range_neg_127_to_127)

default_per_channel_symmetric_qnnpack_qconfig = QConfig(activation=HistogramObserver.with_args(dtype=torch.qint8,
                                                                                               reduce_range=False,
                                                                                               eps=2 ** -12),
                                                        weight=per_channel_weight_observer_range_neg_127_to_127)
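

# Worked check (illustrative only, under the assumptions stated in the comment above):
# with every scale floored at eps = 2 ** -12 and activations assumed to span at most
# [-256, 256] over the qint8 range [-128, 127], the smallest possible requantization
# scale stays above xnnpack's lower bound of 2 ** -32.
def _check_xnnpack_requantization_lower_bound() -> bool:
    eps = 2 ** -12
    min_input_scale = eps                                    # observers clamp scales at eps
    min_kernel_scale = eps
    max_output_scale = (256.0 - (-256.0)) / (127 - (-128))   # ~2.0 for the assumed range
    min_requantization_scale = (min_input_scale * min_kernel_scale) / max_output_scale
    return min_requantization_scale >= 2 ** -32              # True: roughly 2 ** -25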


default_embedding_qat_qconfig = QConfig(activation=NoopObserver.with_args(dtype=torch.float32),
                                        weight=default_embedding_fake_quant)

default_embedding_qat_qconfig_4bit = QConfig(activation=NoopObserver.with_args(dtype=torch.float32),
                                             weight=default_embedding_fake_quant_4bit)

default_quint8_weight_qconfig = QConfig(activation=HistogramObserver, weight=MinMaxObserver)


def get_default_qat_qconfig(backend='x86', version=1):
    """
    Returns the default QAT qconfig for the specified backend.

    Args:
      * `backend` (str): a string representing the target backend. Currently supports
        `x86` (default), `fbgemm`, `qnnpack` and `onednn`.
      * `version`: version, for backwards compatibility. Can be `0` or `1`.

    Return:
        qconfig
    """
    supported_backends = ["fbgemm", "x86", "qnnpack", "onednn"]
    if backend not in supported_backends:
        raise AssertionError(
            "backend: " + str(backend) +
            f" not supported. backend must be one of {supported_backends}"
        )

    # Histogram observer is too slow for quantization aware training
    if version == 0:
        if backend == 'fbgemm':
            qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                quant_min=0,
                                                                quant_max=255,
                                                                reduce_range=True),
                              weight=default_per_channel_weight_fake_quant)
        elif backend == 'qnnpack':
            qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                quant_min=0,
                                                                quant_max=255,
                                                                reduce_range=False),
                              weight=default_weight_fake_quant)
        elif backend == 'onednn':
            qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                quant_min=0,
                                                                quant_max=255),
                              weight=default_per_channel_weight_fake_quant)
        elif backend == 'x86':
            qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                quant_min=0,
                                                                quant_max=255,
                                                                reduce_range=True),
                              weight=default_per_channel_weight_fake_quant)
        else:
            qconfig = default_qat_qconfig
    # Use the fused observer + fake_quant modules for doing QAT.
    elif version == 1:
        if backend == 'fbgemm':
            qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                                 quant_min=0,
                                                                                 quant_max=255,
                                                                                 reduce_range=True),
                              weight=default_fused_per_channel_wt_fake_quant)
        elif backend == 'qnnpack':
            # TODO: make this compatible with xnnpack constraints
            qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                                 quant_min=0,
                                                                                 quant_max=255,
                                                                                 reduce_range=False),
                              weight=default_fused_wt_fake_quant)
        elif backend == 'onednn':
            qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                                 quant_min=0,
                                                                                 quant_max=255),
                              weight=default_fused_per_channel_wt_fake_quant)
        elif backend == 'x86':
            qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                                 quant_min=0,
                                                                                 quant_max=255,
                                                                                 reduce_range=True),
                              weight=default_fused_per_channel_wt_fake_quant)
        else:
            qconfig = default_qat_qconfig_v2
    else:
        raise AssertionError("Version number: " + str(version) +
                             " in get_default_qat_qconfig is not supported. Version number must be 0 or 1")

    return qconfig
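

# Illustrative sketch (assuming the eager-mode QAT workflow): how a backend-specific QAT
# qconfig is typically used. The `model` and `train_one_epoch` arguments are hypothetical
# placeholders for a real model and training loop.
def _example_qat_with_default_qconfig(model: nn.Module, train_one_epoch):
    import torch.ao.quantization as tq

    model = copy.deepcopy(model).train()
    model.qconfig = get_default_qat_qconfig("x86")
    prepared = tq.prepare_qat(model)      # inserts fused fake-quant + observer modules
    train_one_epoch(prepared)             # fine-tune with fake quantization in place
    return tq.convert(prepared.eval())    # produce the actual quantized model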


"""
Default symmetric QAT qconfig for qnnpack. And its per channel weight variant.
"""
default_symmetric_qnnpack_qat_qconfig = QConfig(
    activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                       quant_min=-128,
                                                       quant_max=127,
                                                       dtype=torch.qint8,
                                                       reduce_range=False,
                                                       eps=2 ** -12),
    weight=fused_wt_fake_quant_range_neg_127_to_127)

default_per_channel_symmetric_qnnpack_qat_qconfig = QConfig(
    activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                       quant_min=-128,
                                                       quant_max=127,
                                                       dtype=torch.qint8,
                                                       reduce_range=False,
                                                       eps=2 ** -12),
    weight=fused_per_channel_wt_fake_quant_range_neg_127_to_127)

_default_fp32_placeholder_qconfig = QConfig(
    activation=PlaceholderObserver.with_args(dtype=torch.float32),
    weight=PlaceholderObserver.with_args(dtype=torch.float32)
)

_default_quint8_placeholder_qconfig = QConfig(
    activation=PlaceholderObserver.with_args(dtype=torch.quint8),
    # operators using this qconfig don't have weights
    weight=None,
)


def get_default_qconfig_dict(backend='x86', version=0):
    warnings.warn(
        "torch.ao.quantization.get_default_qconfig_dict is deprecated and will be removed in "
        "a future version. Please use torch.ao.quantization.get_default_qconfig_mapping instead.")
    return torch.ao.quantization.get_default_qconfig_mapping(backend, version).to_dict()


def get_default_qat_qconfig_dict(backend='x86', version=1):
    warnings.warn(
        "torch.ao.quantization.get_default_qat_qconfig_dict is deprecated and will be removed in "
        "a future version. Please use torch.ao.quantization.get_default_qat_qconfig_mapping instead.")
    return torch.ao.quantization.get_default_qat_qconfig_mapping(backend, version).to_dict()


def _assert_valid_qconfig(qconfig: Optional[QConfig],
                          mod: torch.nn.Module) -> None:
    """
    Verifies that this `qconfig` is valid.
    """
    if qconfig is None:
        return
    is_conv_transpose_mod = (
        isinstance(mod, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d)))
    if is_conv_transpose_mod:
        if qconfig.weight is None:
            # for now, we assume that any qconfig for ConvTranspose without a weight is valid
            return
        example_observer = qconfig.weight()
        is_per_channel = (
            isinstance(example_observer, (torch.ao.quantization.PerChannelMinMaxObserver,
                                          torch.ao.quantization.MovingAveragePerChannelMinMaxObserver))
        )
        assert not is_per_channel, \
            'Per channel weight observer is not supported yet for ConvTranspose{n}d.'


QConfigAny = Optional[QConfig]
QConfigAny.__module__ = "torch.ao.quantization.qconfig"


def _add_module_to_qconfig_obs_ctr(
        qconfig: QConfigAny,
        module: Optional[nn.Module]) -> Any:
    r"""This is a helper function for use in quantization prepare that updates a qconfig so that
    the constructors stored in the qconfig will create observers on the same device that
    'module' is on. This is intended to be used when the qconfigs are propagated to each
    module in order to avoid potential device alignment issues.

    Args:
        qconfig: QConfig with obs constructors stored in activation and weight
        module: module which the qconfig is related to

    Return:
        qconfig: configured so that obs constructors are set to construct on the same device as module
    """

    if module is None or qconfig is None or qconfig._fields != ('activation', 'weight'):
        return qconfig

    def get_factory_kwargs_based_on_module_device():
        assert isinstance(module, torch.nn.Module)
        devices = {p.device for p in module.parameters()} | \
            {p.device for p in module.buffers()}
        device = next(iter(devices)) if len(devices) > 0 else None
        return None if device is None else {'device': device}

    def configure_constructor_to_put_obs_on_module_device(original_constructor):
        try:
            # check if constructor can accept factory_kwargs
            check = original_constructor.with_args(factory_kwargs=None)
            check()
            return original_constructor.with_callable_args(factory_kwargs=get_factory_kwargs_based_on_module_device)
        except AttributeError:  # qconfig doesn't have activation or weight
            return original_constructor
        except TypeError:  # the class doesn't accept factory_kwargs argument
            return original_constructor

    activation = configure_constructor_to_put_obs_on_module_device(qconfig.activation)
    weight = configure_constructor_to_put_obs_on_module_device(qconfig.weight)

    return QConfig(activation, weight)
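

# Illustrative sketch: given a module that lives on a particular device, the helper above
# rewrites the observer constructors so that the observers they create are allocated on
# that same device. The CUDA check keeps the sketch runnable on CPU-only builds.
def _example_device_aware_qconfig():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    linear = nn.Linear(4, 4).to(device)
    qconfig = _add_module_to_qconfig_obs_ctr(default_qconfig, linear)
    observer = qconfig.activation()    # constructed with factory_kwargs={'device': device}
    return observer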


_ObserverOrFakeQuantizeConstructor = Union[_PartialWrapper, Type[ObserverBase], Type[FakeQuantizeBase]]


def _obs_or_fq_ctr_equals(obs_or_fq1: _ObserverOrFakeQuantizeConstructor, obs_or_fq2: _ObserverOrFakeQuantizeConstructor):
    if isinstance(obs_or_fq1, _PartialWrapper) and isinstance(obs_or_fq2, _PartialWrapper):
        return _partial_wrapper_equals(obs_or_fq1, obs_or_fq2)
    return obs_or_fq1 == obs_or_fq2


def _partial_wrapper_equals(obs_or_fq1: _PartialWrapper, obs_or_fq2: _PartialWrapper):
    """
    Return whether the two partial wrappers are equal.
    """
    # functools.partial has no __eq__ operator defined so '==' defaults to 'is'
    obs_or_fq1_keywords = copy.copy(obs_or_fq1.p.keywords)
    obs_or_fq2_keywords = copy.copy(obs_or_fq2.p.keywords)
    keywords_equal = True
    # compare observer constructor with _obs_or_fq_ctr_equals since direct compare would fail
    if "observer" in obs_or_fq1_keywords and "observer" in obs_or_fq2_keywords:
        keywords_equal = keywords_equal and _obs_or_fq_ctr_equals(obs_or_fq1_keywords["observer"], obs_or_fq2_keywords["observer"])
        obs_or_fq1_keywords.pop("observer")
        obs_or_fq2_keywords.pop("observer")
    keywords_equal = keywords_equal and obs_or_fq1_keywords == obs_or_fq2_keywords
    return obs_or_fq1.p.func == obs_or_fq2.p.func and obs_or_fq1.p.args == obs_or_fq2.p.args and keywords_equal


def qconfig_equals(q1: QConfigAny, q2: QConfigAny):
    """
    Returns `True` if `q1` equals `q2`, and `False` otherwise.
    """
    if q1 is None or q2 is None:
        return q1 == q2
    else:
        assert q1 is not None and q2 is not None
        try:
            # Qconfig weight and activation can be either a partial wrapper,
            # or an observer class. Special handling is required (above) for
            # comparing partial wrappers.
            activation_same = _obs_or_fq_ctr_equals(q1.activation, q2.activation)
            weight_same = _obs_or_fq_ctr_equals(q1.weight, q2.weight)
            return activation_same and weight_same
        except AttributeError:
            return q1 == q2
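

# Illustrative sketch: qconfig_equals compares the underlying constructors (including
# `with_args` partials by their func/args/keywords) rather than object identity, so two
# independently built but identical qconfigs compare equal.
def _example_qconfig_equality() -> bool:
    a = QConfig(activation=MinMaxObserver.with_args(dtype=torch.quint8),
                weight=default_weight_observer)
    b = QConfig(activation=MinMaxObserver.with_args(dtype=torch.quint8),
                weight=default_weight_observer)
    return qconfig_equals(a, b)    # True, even though `a is not b`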


def _activation_is_memoryless(qconfig: QConfig):
    """
    Return whether the observer for activations defined in the given QConfig is memoryless.
    This means a MovingAverage observer with averaging constant equal to 1.
    """
    def _is_memoryless(observer):
        return hasattr(observer, "averaging_constant") and observer.averaging_constant == 1
    act = qconfig.activation()
    if isinstance(act, FakeQuantizeBase) and hasattr(act, "activation_post_process"):
        return _is_memoryless(act.activation_post_process)
    else:
        return _is_memoryless(act)


def _is_reuse_input_qconfig(qconfig: Optional[QConfig]):
    return qconfig is not None and \
        isinstance(qconfig.activation(), ReuseInputObserver) and \
        isinstance(qconfig.weight(), NoopObserver)