from typing import Any, List, Optional, Union

import torch
from torch import nn


def _replace_relu(module: nn.Module) -> None:
    """Recursively replace every ReLU / ReLU6 submodule with a fresh non-inplace ReLU."""
    reassign = {}
    for name, mod in module.named_children():
        _replace_relu(mod)
        # Checking for explicit type instead of instance
        # as we only want to replace modules of the exact type
        # not inherited classes
        if type(mod) is nn.ReLU or type(mod) is nn.ReLU6:
            reassign[name] = nn.ReLU(inplace=False)

    for key, value in reassign.items():
        module._modules[key] = value


def quantize_model(model: nn.Module, backend: str) -> None:
    """Fuse, prepare, calibrate on dummy data, and convert ``model`` in place for ``backend``."""
    _dummy_input_data = torch.rand(1, 3, 299, 299)
    if backend not in torch.backends.quantized.supported_engines:
        raise RuntimeError(f"Quantized backend not supported: {backend}")
    torch.backends.quantized.engine = backend
    model.eval()
    # Make sure that weight qconfig matches that of the serialized models
    if backend == "fbgemm":
        model.qconfig = torch.ao.quantization.QConfig(  # type: ignore[assignment]
            activation=torch.ao.quantization.default_observer,
            weight=torch.ao.quantization.default_per_channel_weight_observer,
        )
    elif backend == "qnnpack":
        model.qconfig = torch.ao.quantization.QConfig(  # type: ignore[assignment]
            activation=torch.ao.quantization.default_observer,
            weight=torch.ao.quantization.default_weight_observer,
        )

    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    model.fuse_model()  # type: ignore[operator]
    torch.ao.quantization.prepare(model, inplace=True)
    model(_dummy_input_data)
    torch.ao.quantization.convert(model, inplace=True)


def _fuse_modules(
    model: nn.Module, modules_to_fuse: Union[List[str], List[List[str]]], is_qat: Optional[bool], **kwargs: Any
) -> nn.Module:
    """Dispatch to the QAT or post-training fusion helper depending on ``is_qat``."""
    if is_qat is None:
        is_qat = model.training
    method = torch.ao.quantization.fuse_modules_qat if is_qat else torch.ao.quantization.fuse_modules
    return method(model, modules_to_fuse, **kwargs)
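

# --- Usage sketch (illustrative only, not part of the original module) ---
# A minimal example of calling ``quantize_model`` on torchvision's quantizable
# ResNet-18. The model choice and the "fbgemm" backend are assumptions for the
# example: fbgemm is typically available on x86; on ARM you would likely use
# "qnnpack" instead.
if __name__ == "__main__":
    from torchvision.models.quantization import resnet18

    model = resnet18()  # quantizable (float) variant with randomly initialized weights
    quantize_model(model, "fbgemm")  # fuse, prepare, calibrate, and convert in place
    with torch.no_grad():
        out = model(torch.rand(1, 3, 224, 224))
    print(out.shape)  # expected: torch.Size([1, 1000])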