import torch from torch.fx import GraphModule # type: ignore from torch.fx.symbolic_trace import Tracer # type: ignore from torch.fx.node import Target, Node, Argument # type: ignore from .fx import Fuser # noqa: F401 from .fx import Quantizer # noqa: F401 from .fx.utils import graph_pretty_str # noqa: F401 from .fx.utils import get_custom_module_class_keys # noqa: F401 from torch.nn.intrinsic import _FusedModule from typing import Dict, Any, List, Callable, Tuple, Optional def _check_is_graph_module(model: torch.nn.Module) -> None: if not isinstance(model, GraphModule): raise ValueError( 'input model must be a GraphModule, ' + 'Got type:' + str(type(model)) + ' Please make ' + 'sure to follow the tutorials.') def _swap_ff_with_fxff(model: torch.nn.Module) -> None: r""" Swap FloatFunctional with FXFloatFunctional """ modules_to_swap = [] for name, module in model.named_children(): if isinstance(module, torch.nn.quantized.FloatFunctional): modules_to_swap.append(name) else: _swap_ff_with_fxff(module) for name in modules_to_swap: del model._modules[name] model._modules[name] = torch.nn.quantized.FXFloatFunctional() def _fuse_fx( graph_module: GraphModule, fuse_custom_config_dict: Dict[str, Any] = None) -> GraphModule: r""" Internal helper function to fuse modules in preparation for quantization Args: graph_module: GraphModule object from symbolic tracing (torch.fx.symbolic_trace) """ _check_is_graph_module(graph_module) fuser = Fuser() return fuser.fuse(graph_module, fuse_custom_config_dict) class Scope(object): """ Scope object that records the module path and the module type of a module. Scope is used to track the information of the module that contains a Node in a Graph of GraphModule. For example: class Sub(torch.nn.Module): def forward(self, x): # This will be a call_method Node in GraphModule, # scope for this would be (module_path="sub", module_type=Sub) return x.transpose(1, 2) class M(torch.nn.Module): def __init__(self): self.sub = Sub() def forward(self, x): # This will be a call_method Node as well, # scope for this would be (module_path="", None) x = x.transpose(1, 2) x = self.sub(x) return x """ def __init__(self, module_path: str, module_type: Any): super().__init__() self.module_path = module_path self.module_type = module_type class ScopeContextManager(object): """ A context manager to track the Scope of Node during symbolic tracing. When entering a forward function of a Module, we'll update the scope information of the current module, and when we exit, we'll restore the previous scope information. """ def __init__( self, scope: Scope, current_module: torch.nn.Module, current_module_path: str): super().__init__() self.prev_module_type = scope.module_type self.prev_module_path = scope.module_path self.scope = scope self.scope.module_path = current_module_path self.scope.module_type = type(current_module) def __enter__(self): return def __exit__(self, *args): self.scope.module_path = self.prev_module_path self.scope.module_type = self.prev_module_type return class QuantizationTracer(Tracer): def __init__( self, skipped_module_names: List[str], skipped_module_classes: List[Callable]): super().__init__() self.skipped_module_names = skipped_module_names self.skipped_module_classes = skipped_module_classes # NB: initialized the module_type of top level module to None # we are assuming people won't configure the model with the type of top level # module here, since people can use "" for global config # We can change this if there is a use case that configures # qconfig using top level module type self.scope = Scope("", None) self.node_name_to_scope : Dict[str, Tuple[str, type]] = {} def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool: return (m.__module__.startswith("torch.nn") and not isinstance(m, torch.nn.Sequential)) or \ module_qualified_name in self.skipped_module_names or \ type(m) in self.skipped_module_classes or \ isinstance(m, _FusedModule) def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args : Tuple[Any, ...], kwargs : Dict[str, Any]) -> Any: module_qualified_name = self.path_of_module(m) # Creating scope with information of current module # scope will be restored automatically upon exit with ScopeContextManager(self.scope, m, module_qualified_name): return super().call_module(m, forward, args, kwargs) def create_node(self, kind : str, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None, type_expr : Optional[Any] = None) -> Node: node = super().create_node(kind, target, args, kwargs, name, type_expr) self.node_name_to_scope[node.name] = (self.scope.module_path, self.scope.module_type) return node def _prepare_fx(model: torch.nn.Module, qconfig_dict: Any, prepare_custom_config_dict: Dict[str, Any] = None, is_standalone_module: bool = False) -> GraphModule: r""" Internal helper function for prepare_fx Args: `model`, `qconfig_dict`, `prepare_custom_config_dict`: see docs for :func:`~torch.quantization.prepare_fx` `is_standalone_module`: a boolean flag indicates whether we are quantizing a standalone module or not, a standalone module is a submodule of the parent module that is not inlined in the forward graph of the parent module, the way we quantize standalone module is described in: :func:`~torch.quantization._prepare_standalone_module_fx` """ if prepare_custom_config_dict is None: prepare_custom_config_dict = {} skipped_module_names = prepare_custom_config_dict.get("non_traceable_module_name", []) skipped_module_classes = prepare_custom_config_dict.get("non_traceable_module_class", []) # swap FloatFunctional with FXFloatFunctional _swap_ff_with_fxff(model) # symbolically trace the model if not is_standalone_module: # standalone module and custom module config are applied in top level module standalone_module_name_configs = prepare_custom_config_dict.get("standalone_module_name", []) skipped_module_names += [config[0] for config in standalone_module_name_configs] standalone_module_class_configs = prepare_custom_config_dict.get("standalone_module_class", []) skipped_module_classes += [config[0] for config in standalone_module_class_configs] float_custom_module_classes = get_custom_module_class_keys( prepare_custom_config_dict, "float_to_observed_custom_module_class") skipped_module_classes += float_custom_module_classes tracer = QuantizationTracer( skipped_module_names, skipped_module_classes) graph_module = GraphModule(model, tracer.trace(model)) graph_module = _fuse_fx(graph_module, prepare_custom_config_dict) quantizer = Quantizer() prepared = quantizer.prepare( graph_module, qconfig_dict, tracer.node_name_to_scope, prepare_custom_config_dict=prepare_custom_config_dict, is_standalone_module=is_standalone_module) preserved_attributes = prepare_custom_config_dict.get("preserved_attributes", []) for attr_name in preserved_attributes: setattr(prepared, attr_name, getattr(model, attr_name)) return prepared def _prepare_standalone_module_fx( model: torch.nn.Module, qconfig_dict: Any, prepare_custom_config_dict: Dict[str, Any] = None) -> GraphModule: r""" [Internal use only] Prepare a standalone module, so that it can be used when quantizing the parent module. standalone_module means it a submodule that is not inlined in parent module, and will be quantized separately as one unit. How the standalone module is observed is specified by `input_quantized_idxs` and `output_quantized_idxs` in the prepare_custom_config for the standalone module Returns: model(GraphModule): prepared standalone module attributes: _standalone_module_input_quantized_idxs(List[Int]): a list of indexes for the graph input that is expected to be quantized, same as input_quantized_idxs configuration provided for the standalone module _standalone_module_output_quantized_idxs(List[Int]): a list of indexs for the graph output that is quantized same as input_quantized_idxs configuration provided for the standalone module """ return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, is_standalone_module=True) def fuse_fx(model: torch.nn.Module, fuse_custom_config_dict: Dict[str, Any] = None) -> GraphModule: r""" Fuse modules like conv+bn, conv+bn+relu etc, model must be in eval mode. Fusion rules are defined in torch.quantization.fx.fusion_pattern.py Args: `model`: a torch.nn.Module model `fuse_custom_config_dict`: Dictionary for custom configurations for fuse_fx, e.g. fuse_custom_config_dict = { "additional_fuser_method_mapping": { (Module1, Module2): fuse_module1_module2 } } Example: ```python from torch.quantization import fuse_fx m = Model().eval() m = fuse_fx(m) ``` """ torch._C._log_api_usage_once("quantization_api.quantize_fx.fuse_fx") assert not model.training, 'fuse_fx only works on models in eval mode' graph_module = torch.fx.symbolic_trace(model) # type: ignore return _fuse_fx(graph_module, fuse_custom_config_dict) def prepare_fx( model: torch.nn.Module, qconfig_dict: Any, prepare_custom_config_dict: Dict[str, Any] = None) -> GraphModule: r""" Prepare a model for post training static quantization Args: `model`: torch.nn.Module model, must be in eval mode `qconfig_dict`: qconfig_dict is a dictionary with the following configurations: qconfig_dict = { # optional, global config "": qconfig?, # optional, used for module and function types # could also be split into module_types and function_types if we prefer "object_type": [ (torch.nn.Conv2d, qconfig?), (torch.nn.functional.add, qconfig?), ..., ], # optional, used for module names "module_name": [ ("foo.bar", qconfig?) ..., ], # optional, matched in order, first match takes precedence "module_name_regex": [ ("foo.*bar.*conv[0-9]+", qconfig?) ..., ], # priority (in increasing order): global, object_type, module_name_regex, module_name # qconfig == None means fusion and quantization should be skipped for anything # matching the rule } `prepare_custom_config_dict`: customization configuration dictionary for quantization tool: prepare_custom_config_dict = { # optional: specify the path for standalone modules # These modules are symbolically traced and quantized as one unit "standalone_module_name": [ # module_name, qconfig_dict, prepare_custom_config_dict ("submodule.standalone", None, # qconfig_dict for the prepare function called in the submodule, # None means use qconfig from parent qconfig_dict {"input_quantized_idxs": [], "output_quantized_idxs": []}) # prepare_custom_config_dict ], "standalone_module_class": [ # module_class, qconfig_dict, prepare_custom_config_dict (StandaloneModule, None, # qconfig_dict for the prepare function called in the submodule, # None means use qconfig from parent qconfig_dict {"input_quantized_idxs": [0], "output_quantized_idxs": [0]}) # prepare_custom_config_dict ], # user will manually define the corresponding observed # module class which has a from_float class method that converts # float custom module to observed custom module # (only needed for static quantization) "float_to_observed_custom_module_class": { "static": { CustomModule: ObservedCustomModule } }, # the qualified names for the submodule that are not symbolically traceable "non_traceable_module_name": [ "non_traceable_module" ], # the module classes that are not symbolically traceable # we'll also put dynamic/weight_only custom module here "non_traceable_module_class": [ NonTraceableModule ], # Additional fuser_method mapping "additional_fuser_method_mapping": { (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn }, # Additioanl module mapping for qat "additional_qat_module_mapping": { torch.nn.intrinsic.ConvBn2d: torch.nn.qat.ConvBn2d }, # Additional fusion patterns "additional_fusion_pattern": { (torch.nn.BatchNorm2d, torch.nn.Conv2d): ConvReluFusionhandler }, # Additional quantization patterns "additional_quant_pattern": { torch.nn.Conv2d: ConvReluQuantizeHandler, (torch.nn.ReLU, torch.nn.Conv2d): ConvReluQuantizeHandler, } # By default, inputs and outputs of the graph are assumed to be in # fp32. Providing `input_quantized_idxs` will set the inputs with the # corresponding indices to be quantized. Providing # `output_quantized_idxs` will set the outputs with the corresponding # indices to be quantized. "input_quantized_idxs": [0], "output_quantized_idxs": [0], # Attributes that are not used in forward function will # be removed when constructing GraphModule, this is a list of attributes # to preserve as an attribute of the GraphModule even when they are # not used in the code "preserved_attributes": ["preserved_attr"], } Return: A GraphModule with observer (configured by qconfig_dict), ready for calibration Example: ```python import torch from torch.quantization import get_default_qconfig from torch.quantization import prepare_fx float_model.eval() graph_module = torch.fx.symbolic_trace(float_model) qconfig = get_default_qconfig('fbgemm') def calibrate(model, data_loader): model.eval() with torch.no_grad(): for image, target in data_loader: model(image) qconfig_dict = {"": qconfig} prepared_model = prepare_fx(graph_module, qconfig_dict) # Run calibration calibrate(prepared_model, sample_inference_data) ``` """ torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_fx") assert not model.training, 'prepare_fx only works for models in ' + \ 'eval mode' return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict) def prepare_qat_fx( model: torch.nn.Module, qconfig_dict: Any, prepare_custom_config_dict: Dict[str, Any] = None) -> GraphModule: r""" Prepare a model for quantization aware training Args: `model`: torch.nn.Module model, must be in train mode `qconfig_dict`: see :func:`~torch.quantization.prepare_fx` `prepare_custom_config_dict`: see :func:`~torch.quantization.prepare_fx` Return: A GraphModule with fake quant modules (configured by qconfig_dict), ready for quantization aware training Example: ```python import torch from torch.quantization import get_default_qat_qconfig from torch.quantization import prepare_fx qconfig = get_default_qat_qconfig('fbgemm') def train_loop(model, train_data): model.train() for image, target in data_loader: ... float_model.train() qconfig_dict = {"": qconfig} prepared_model = prepare_fx(float_model, qconfig_dict) # Run calibration train_loop(prepared_model, train_loop) ``` """ torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_qat_fx") assert model.training, 'prepare_qat_fx only works for models in ' + \ 'train mode' return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict) def _convert_fx( graph_module: GraphModule, debug: bool, convert_custom_config_dict: Dict[str, Any] = None, is_standalone_module: bool = False) -> GraphModule: """ `is_standalone_module`: see docs in :func:`~torch.quantization.prepare_standalone_module_fx` """ if convert_custom_config_dict is None: convert_custom_config_dict = {} _check_is_graph_module(graph_module) quantizer = Quantizer() quantized = quantizer.convert(graph_module, debug, convert_custom_config_dict, is_standalone_module) preserved_attributes = convert_custom_config_dict.get("preserved_attributes", []) for attr_name in preserved_attributes: setattr(quantized, attr_name, getattr(graph_module, attr_name)) return quantized def convert_fx( graph_module: GraphModule, debug: bool = False, convert_custom_config_dict: Dict[str, Any] = None) -> GraphModule: r""" Convert a calibrated or trained model to a quantized model Args: `graph_module`: A prepared and calibrated/trained model (GraphModule) `debug`: flag for producing a debug friendly model (preserve weight attribute) `convert_custom_config_dict`: dictionary for custom configurations for convert function: convert_custom_config_dict = { # addtional object (module/operator) mappings that will overwrite the default # module mappingn "additional_object_mapping": { "static": { FloatModule: QuantizedModule, float_op: quantized_op }, "dynamic": { FloatModule: DynamicallyQuantizedModule, float_op: dynamically_quantized_op }, }, # user will manually define the corresponding quantized # module class which has a from_observed class method that converts # observed custom module to quantized custom module "observed_to_quantized_custom_module_class": { "static": { ObservedCustomModule: QuantizedCustomModule }, "dynamic": { ObservedCustomModule: QuantizedCustomModule }, "weight_only": { ObservedCustomModule: QuantizedCustomModule } }, # Attributes that are not used in forward function will # be removed when constructing GraphModule, this is a list of attributes # to preserve as an attribute of the GraphModule even when they are # not used in the code "preserved_attributes": ["preserved_attr"], } Return: A quantized model (GraphModule) Example: ```python # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training quantized_model = convert_fx(prepared_model) ``` """ torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_fx") return _convert_fx(graph_module, debug, convert_custom_config_dict) def _convert_standalone_module_fx( graph_module: GraphModule, debug: bool = False, convert_custom_config_dict: Dict[str, Any] = None) -> GraphModule: r""" [Internal use only] Convert a model produced by :func:`~torch.quantization.prepare_standalone_module_fx` and convert it to a quantized model Returns a quantized standalone module, whether input/output is quantized is specified by prepare_custom_config_dict, with input_quantized_idxs, output_quantized_idxs, please see docs for prepare_fx for details """ return _convert_fx(graph_module, debug, convert_custom_config_dict, is_standalone_module=True)