52 lines
2.0 KiB
Python
52 lines
2.0 KiB
Python
|
from typing import Any, List, Optional, Union
|
||
|
|
||
|
import torch
|
||
|
from torch import nn
|
||
|
|
||
|
|
||
|
def _replace_relu(module: nn.Module) -> None:
    """Recursively replace exact ``nn.ReLU`` / ``nn.ReLU6`` children with ``nn.ReLU(inplace=False)``.

    The walk is depth-first: each child is processed recursively before the
    child itself is considered for replacement. ``module`` is mutated in place
    and nothing is returned.

    Note that a ``nn.ReLU6`` child becomes a plain ``nn.ReLU``, so the upper
    clamp at 6 is dropped. This is presumably deliberate (e.g. to match the
    quantized models this helper is used with) — confirm with callers.

    Args:
        module: The module whose descendants are rewritten.
    """
    exact_targets = (nn.ReLU, nn.ReLU6)
    names_to_swap = []
    for child_name, child in module.named_children():
        _replace_relu(child)
        # Exact type match on purpose: subclasses of ReLU/ReLU6 are left alone.
        if type(child) in exact_targets:
            names_to_swap.append(child_name)

    # Swap after the walk so named_children() is never mutated mid-iteration.
    for child_name in names_to_swap:
        module._modules[child_name] = nn.ReLU(inplace=False)
|
||
|
|
||
|
|
||
|
def quantize_model(model: nn.Module, backend: str) -> None:
    """Quantize ``model`` in place for the given quantized ``backend``.

    Steps:
    1. Select the backend as the global quantized engine.
    2. Switch the model to eval mode and assign a backend-specific qconfig.
    3. Fuse the model via its own ``fuse_model()`` method.
    4. Insert observers, calibrate with a single random ``(1, 3, 299, 299)``
       input, and convert to quantized modules.

    Because calibration uses random data, the resulting quantization
    parameters are presumably placeholders that callers overwrite by loading
    serialized quantized weights afterwards — confirm against callers.

    Args:
        model: Module to quantize. It must define a ``fuse_model()`` method
            and accept a ``(1, 3, 299, 299)`` float tensor in ``forward``.
        backend: Quantized engine name. It must appear in
            ``torch.backends.quantized.supported_engines``.

    Raises:
        RuntimeError: If ``backend`` is not a supported quantized engine.

    Side effects:
        Sets the process-global ``torch.backends.quantized.engine`` to
        ``backend`` and puts ``model`` in eval mode.
    """
    supported = torch.backends.quantized.supported_engines
    if backend not in supported:
        raise RuntimeError(f"Quantized backend not supported: {backend!r} (supported engines: {supported})")
    torch.backends.quantized.engine = backend
    model.eval()

    # Make sure that weight qconfig matches that of the serialized models
    if backend == "fbgemm":
        model.qconfig = torch.ao.quantization.QConfig(  # type: ignore[assignment]
            activation=torch.ao.quantization.default_observer,
            weight=torch.ao.quantization.default_per_channel_weight_observer,
        )
    elif backend == "qnnpack":
        model.qconfig = torch.ao.quantization.QConfig(  # type: ignore[assignment]
            activation=torch.ao.quantization.default_observer,
            weight=torch.ao.quantization.default_weight_observer,
        )
    # NOTE(review): any other supported engine leaves model.qconfig untouched,
    # so prepare/convert below use whatever qconfig the model already carries.

    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    model.fuse_model()  # type: ignore[operator]
    torch.ao.quantization.prepare(model, inplace=True)

    # Calibration pass. Only observer state matters and the output is
    # discarded, so skip autograd bookkeeping.
    dummy_input = torch.rand(1, 3, 299, 299)
    with torch.no_grad():
        model(dummy_input)

    torch.ao.quantization.convert(model, inplace=True)
|
||
|
|
||
|
|
||
|
def _fuse_modules(
    model: nn.Module, modules_to_fuse: Union[List[str], List[List[str]]], is_qat: Optional[bool], **kwargs: Any
):
    """Fuse submodules of ``model``, choosing the QAT or eval fusion helper.

    Args:
        model: Module whose named submodules are fused.
        modules_to_fuse: Submodule names (or lists of names) to fuse, passed
            straight through to the torch fusion helper.
        is_qat: When truthy, ``torch.ao.quantization.fuse_modules_qat`` is
            used; when falsy, ``torch.ao.quantization.fuse_modules``. When
            ``None``, the choice follows ``model.training``.
        **kwargs: Forwarded unchanged to the selected fusion helper.

    Returns:
        Whatever the selected torch fusion helper returns.
    """
    qat_mode = model.training if is_qat is None else is_qat
    if qat_mode:
        return torch.ao.quantization.fuse_modules_qat(model, modules_to_fuse, **kwargs)
    return torch.ao.quantization.fuse_modules(model, modules_to_fuse, **kwargs)
|