import importlib
import inspect

from torch.onnx import symbolic_helper, symbolic_opset9 as opset9
from torch.onnx._internal import jit_utils, registration


def register_quantized_ops(domain: str, version: int):
    # Register all quantized ops
    module = importlib.import_module("torch.onnx.symbolic_caffe2")
    quant_version_ops = inspect.getmembers(module)
    aten_q_ops = {
        "relu",
        "_empty_affine_quantized",
        "dequantize",
        "quantize_per_tensor",
        "upsample_nearest2d",
        "avg_pool2d",
        "reshape",
        "slice",
        "cat",
        "max_pool2d",
        "sigmoid",
    }
    for op, func in quant_version_ops:
        name = f"{domain}::{op}"
        if inspect.isfunction(func) and not registration.registry.is_registered_op(
            name, version
        ):
            if op in aten_q_ops:
                # Override the builtin aten ops
                registration.registry.register(
                    f"aten::{op}", version, func, custom=True
                )
            registration.registry.register(name, version, func)


def _permute_helper(g: jit_utils.GraphContext, input, axes):
    quant_args = {
        "axes_i": axes,
        "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"),
    }
    output = g.op("_caffe2::Int8Transpose", input, **quant_args)
    symbolic_helper._quantized_ops.add(output)
    return output


def nchw2nhwc(g: jit_utils.GraphContext, input):
    axes = [0, 2, 3, 1]
    return _permute_helper(g, input, axes)


def nhwc2nchw(g: jit_utils.GraphContext, input):
    axes = [0, 3, 1, 2]
    return _permute_helper(g, input, axes)


def linear_prepack(g: jit_utils.GraphContext, weight, bias):
    # Mapping to a dummy caffe2 prepack node.
    # During the onnx -> c2 conversion we can look up original weight and bias
    # from this node
    output = g.op("_caffe2::WeightPrepack", weight, bias)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v", "v", "v", "f", "i")
def linear(g: jit_utils.GraphContext, input, weight, bias, scale, zero_point):
    kwargs = {
        "Y_scale_f": scale,
        "Y_zero_point_i": zero_point,
    }
    output = g.op("_caffe2::Int8FC", input, weight, bias, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


def conv_prepack(
    g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups
):
    # Mapping to a dummy caffe2 prepack node.
    # During the onnx -> c2 conversion we can look up original weight and bias
    # from this node
    output = g.op("_caffe2::WeightPrepack", input, weight, bias)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "f", "i")
def conv2d(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    scale,
    zero_point,
):
    kernel_size = weight.node()["shape"][1:3]
    kwargs = {
        "strides_i": stride,
        "pads_i": padding + padding,
        "dilations_i": dilation,
        "group_i": groups,
        "kernels_i": kernel_size,
        "order_s": "NHWC",
        "Y_scale_f": scale,
        "Y_zero_point_i": zero_point,
    }
    output = g.op("_caffe2::Int8Conv", input, weight, bias, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "f", "i")
def conv2d_relu(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    scale,
    zero_point,
):
    kernel_size = weight.node()["shape"][1:3]
    kwargs = {
        "strides_i": stride,
        "pads_i": padding + padding,
        "dilations_i": dilation,
        "group_i": groups,
        "kernels_i": kernel_size,
        "order_s": "NHWC",
        "Y_scale_f": scale,
        "Y_zero_point_i": zero_point,
    }
    output = g.op("_caffe2::Int8ConvRelu", input, weight, bias, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v", "v", "f", "i")
def add(g: jit_utils.GraphContext, input_a, input_b, scale, zero_point):
    kwargs = {
        "Y_scale_f": scale,
        "Y_zero_point_i": zero_point,
    }
    output = g.op("_caffe2::Int8Add", input_a, input_b, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v")
def relu(g: jit_utils.GraphContext, input):
    if input not in symbolic_helper._quantized_ops:
        return opset9.relu(g, input)
    kwargs = {
        "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"),
    }
    output = g.op("_caffe2::Int8Relu", input, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v", "f", "i", "t")
def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dtype):
    kwargs = {
        "Y_scale_f": scale,
        "Y_zero_point_i": zero_point,
    }
    output = g.op("_caffe2::Int8Quantize", input, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v")
def dequantize(g: jit_utils.GraphContext, input):
    return g.op("_caffe2::Int8Dequantize", input)


@symbolic_helper.parse_args("v", "t", "t", "t", "t", "t", "t", "t")
def _empty_affine_quantized(
    g: jit_utils.GraphContext,
    input,
    shape,
    scale,
    zero_point,
    dtype,
    pin_memory,
    memory_format,
    layout,
):
    return input


def upsample_nearest2d(
    g: jit_utils.GraphContext,
    input,
    output_size,
    align_corners=None,
    scales_h=None,
    scales_w=None,
):
    if input not in symbolic_helper._quantized_ops:
        return opset9.upsample_nearest2d(g, input, output_size, align_corners)  # type: ignore[attr-defined]

    output_size = symbolic_helper._parse_arg(output_size, "is")
    kwargs = {
        "output_size_i": output_size,
        "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"),
    }
    input = nchw2nhwc(g, input)
    output = g.op("_caffe2::Int8ResizeNearest", input, **kwargs)
    output = nhwc2nchw(g, output)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
def max_pool2d(
    g: jit_utils.GraphContext,
    input,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode,
):
    if input not in symbolic_helper._quantized_ops:
        return opset9.max_pool2d(  # type: ignore[attr-defined]
            g, input, kernel_size, stride, padding, dilation, ceil_mode
        )
    kwargs = {
        "strides_i": stride,
        "pads_i": padding + padding,
        "kernel_i": kernel_size[0],
        "order_s": "NHWC",
        "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"),
    }
    input = nchw2nhwc(g, input)
    output = g.op("_caffe2::Int8MaxPool", input, **kwargs)
    output = nhwc2nchw(g, output)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
def avg_pool2d(
    g: jit_utils.GraphContext,
    input,
    kernel_size,
    stride,
    padding,
    ceil_mode,
    count_include_pad,
    divisor_override=None,
):
    if input not in symbolic_helper._quantized_ops:
        return opset9.avg_pool2d(  # type: ignore[attr-defined]
            g,
            input,
            kernel_size,
            stride,
            padding,
            ceil_mode,
            count_include_pad,
            divisor_override,
        )
    kwargs = {
        "strides_i": stride,
        "pads_i": padding + padding,
        "kernel_i": kernel_size[0],
        "order_s": "NHWC",
        "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"),
    }
    input = nchw2nhwc(g, input)
    output = g.op("_caffe2::Int8AveragePool", input, **kwargs)
    output = nhwc2nchw(g, output)
    symbolic_helper._quantized_ops.add(output)
    return output


def reshape(g: jit_utils.GraphContext, input, shape):
    if input not in symbolic_helper._quantized_ops:
        return opset9.reshape(g, input, shape)

    kwargs = {
        "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"),
    }
    output = g.op("_caffe2::Int8Reshape", input, shape, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v", "v", "v", "v", "i")
def slice(g: jit_utils.GraphContext, input, dim, start, end, step):
    if input not in symbolic_helper._quantized_ops:
        return opset9.slice(g, input, dim, start, end, step)

    if step != 1:
        raise RuntimeError("ONNX quantized slice export only works for step 1.")
    start = symbolic_helper._parse_arg(start, "i")
    end = symbolic_helper._parse_arg(end, "i")
    dim = symbolic_helper._parse_arg(dim, "i")

    kwargs = {
        "start_idx_i": start,
        "end_idx_i": end,
        "dim_i": dim,
        "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"),
    }
    output = g.op("_caffe2::Int8Slice", input, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


def cat(g: jit_utils.GraphContext, tensor_list, dim, scale=None, zero_point=None):
    tensors = symbolic_helper._unpack_list(tensor_list)
    input = tensors[0]
    if input not in symbolic_helper._quantized_ops:
        return opset9.cat(g, tensor_list, dim)

    dim = symbolic_helper._parse_arg(dim, "i")
    kwargs = {
        "Y_scale_f": tensors[0].node()["Y_scale"],
        "Y_zero_point_i": tensors[0].node()["Y_zero_point"],
    }
    output = g.op("_caffe2::Int8Concat", *tensors, axis_i=dim, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v")
def sigmoid(g: jit_utils.GraphContext, input):
    if input not in symbolic_helper._quantized_ops:
        return opset9.sigmoid(g, input)
    # Caffe2 expects the output scale to be 1/2^8
    # and output zero_point to be 0 (quint8 type)
    out_scale = 1.0 / 256
    zero_point = 0
    kwargs = {
        "Y_scale_f": out_scale,
        "Y_zero_point_i": zero_point,
    }
    output = g.op("_caffe2::Int8Sigmoid", input, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output
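

# Usage sketch (illustrative, not part of this module): the symbolic functions
# above only take effect once register_quantized_ops() has been called for a
# domain/opset pair before export. The domain name and opset version shown here
# are assumptions for the example rather than values defined in this file.
#
#     from torch.onnx import symbolic_caffe2
#     symbolic_caffe2.register_quantized_ops("caffe2", 9)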