1234 lines
37 KiB
Python
1234 lines
37 KiB
Python
from __future__ import annotations
|
||
|
||
import functools
|
||
import sys
|
||
import warnings
|
||
from typing import List, Optional, Sequence, Tuple, Union
|
||
|
||
import torch
|
||
import torch._C._onnx as _C_onnx
|
||
import torch.onnx
|
||
from torch import _C
|
||
|
||
# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics
|
||
from torch.onnx import (
|
||
_constants,
|
||
_type_utils,
|
||
errors,
|
||
symbolic_helper,
|
||
symbolic_opset9 as opset9,
|
||
)
|
||
from torch.onnx._globals import GLOBALS
|
||
from torch.onnx._internal import _beartype, jit_utils, registration
|
||
|
||
# EDITING THIS FILE? READ THIS FIRST!
|
||
# see Note [Edit Symbolic Files] in README.md
|
||
|
||
# This file exports ONNX ops for opset 10
|
||
# Opset 10 is supported by ONNX release 1.5.0
|
||
# release on 04/24/19
|
||
|
||
|
||
# Names exported by ``from ... import *``. Registration of the symbolic
# functions with the exporter happens through the ``_onnx_symbolic``
# decorators on each definition, not through this list.
__all__ = [
    "dequantize",
    "div",
    "embedding_bag",
    "fake_quantize_per_tensor_affine",
    "flip",
    "fmod",
    "isfinite",
    "isinf",
    "nan_to_num",
    "quantize_per_tensor",
    "quantized_add_relu",
    "quantized_add",
    "quantized_cat",
    "quantized_conv1d_relu",
    "quantized_conv2d_relu",
    "quantized_conv3d_relu",
    "quantized_conv1d",
    "quantized_conv2d",
    "quantized_conv3d",
    "quantized_conv_transpose1d",
    "quantized_conv_transpose2d",
    "quantized_conv_transpose3d",
    "quantized_group_norm",
    "quantized_hardswish",
    "quantized_instance_norm",
    "quantized_layer_norm",
    "quantized_leaky_relu",
    "quantized_linear",
    "quantized_linear_relu",
    "quantized_mul",
    "quantized_sigmoid",
    "slice",
    "sort",
    "topk",
]
|
||
|
||
|
||
# Decorator factory that registers a symbolic function for opset 10:
# ``registration.onnx_symbolic`` with ``opset=10`` pre-bound.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=10)
|
||
|
||
|
||
def _apply_params(*args, **kwargs):
|
||
"""Returns a decorator that calls the decorated (higher-order) function with the given parameters."""
|
||
|
||
def _apply(fn):
|
||
return fn(*args, **kwargs)
|
||
|
||
return _apply
|
||
|
||
|
||
@_onnx_symbolic("aten::div")
@_beartype.beartype
def div(g: jit_utils.GraphContext, self, other, *args):
    """Export ``aten::div``.

    Without extra arguments this is true division. Any extra positional
    arguments (the rounding mode) are forwarded to ``_div_rounding_mode``.
    """
    if args:
        return _div_rounding_mode(g, self, other, *args)
    return opset9.true_divide(g, self, other)
|
||
|
||
|
||
@symbolic_helper.parse_args("v", "v", "s")
@_beartype.beartype
def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode):
    """Export division with an explicit rounding mode.

    ``"floor"`` uses this module's ``_floor_divide``, which emits ONNX ``Mod``
    for integers. Every other mode is delegated to the opset 9 implementation.
    """
    if rounding_mode != "floor":
        return opset9._div_rounding_mode(g, self, other, rounding_mode)
    return _floor_divide(g, self, other)
|
||
|
||
|
||
@_onnx_symbolic("aten::_floor_divide")
@_beartype.beartype
def _floor_divide(g: jit_utils.GraphContext, self, other):
    """Export floor division.

    Floating-point operands: ``Floor(true_divide(self, other))``.

    Integer operands: ONNX ``Div`` truncates toward zero. The quotient is
    therefore decremented by one wherever the operand signs differ and the
    division leaves a non-zero remainder.
    """
    if not (symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other)):
        # Integer division does truncation rounding.
        truncated = g.op("Div", self, other)
        # NOTE(review): the zero/one constants are int64; presumably the
        # operands are int64 too — confirm for other integer dtypes.
        zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
        # The true quotient is negative exactly when (self < 0) != (other < 0).
        signs_differ = g.op(
            "Xor", g.op("Less", self, zero), g.op("Less", other, zero)
        )
        # A non-zero remainder means truncation rounded a negative quotient up,
        # so one must be subtracted to round it down instead.
        remainder = g.op("Mod", self, other, fmod_i=0)
        inexact = g.op("Not", g.op("Equal", remainder, zero))
        needs_fixup = g.op("And", signs_differ, inexact)

        one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
        return g.op("Where", needs_fixup, g.op("Sub", truncated, one), truncated)

    quotient = opset9.true_divide(g, self, other)
    return g.op("Floor", quotient)
|
||
|
||
|
||
@_onnx_symbolic("aten::sort")
@symbolic_helper.parse_args("v", "i", "i", "none")
@_beartype.beartype
def sort(g: jit_utils.GraphContext, self, dim, decending, out=None):
    """Export ``aten::sort`` via ``symbolic_helper._sort_helper``.

    ``decending`` keeps its historical spelling because it is also the keyword
    argument name that ``_sort_helper`` is called with.
    """
    result = symbolic_helper._sort_helper(
        g, self, dim, decending=decending, out=out
    )
    return result
|
||
|
||
|
||
@_onnx_symbolic("aten::topk")
@symbolic_helper.parse_args("v", "v", "i", "i", "i", "none")
@_beartype.beartype
def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
    """Export ``aten::topk`` via ``symbolic_helper._topk_helper``.

    ``k`` is passed through as a graph value, while ``dim``, ``largest`` and
    ``sorted`` arrive as Python ints (per ``parse_args``).
    """
    result = symbolic_helper._topk_helper(
        g,
        self,
        k,
        dim,
        largest=largest,
        sorted=sorted,
        out=out,
    )
    return result
|
||
|
||
|
||
def _aten_max_pool_onnx(
    g: jit_utils.GraphContext,
    self: _C.Value,
    kernel_shape: Sequence[int],
    strides: Sequence[int],
    pads: Sequence[int],
    dilations: Sequence[int],
    ceil_mode: bool,
    unbatched_rank: int,
) -> _C.Value:
    """Emit an ONNX ``MaxPool`` and return only the pooled values.

    Args:
        g: Graph context used to create nodes.
        self: Input to pool.
        kernel_shape, strides, pads, dilations: ONNX ``MaxPool`` attributes.
        ceil_mode: Forwarded as the ``ceil_mode`` attribute.
        unbatched_rank: Rank of an input without a batch dimension (the
            caller passes ``spatial_dims + 1``). Such an input gets a leading
            axis of size 1 before pooling, and that axis is removed afterwards.

    Returns:
        The pooled tensor.
    """
    # The rank must be checked statically. Comparing the graph value produced
    # by ``Size(Shape(self))`` with a Python int is always False, which used to
    # skip the unbatched handling entirely. An unknown rank is treated as
    # batched.
    self_rank = symbolic_helper._get_tensor_rank(self)
    is_unbatched = self_rank is not None and self_rank == unbatched_rank
    if is_unbatched:  # C,H,W -> N,C,H,W and N=1
        # The helper picks the Unsqueeze form valid for the export opset.
        self = symbolic_helper._unsqueeze_helper(g, self, [0])

    pool_result, _ = g.op(
        "MaxPool",
        self,
        outputs=2,
        ceil_mode_i=ceil_mode,
        dilations_i=dilations,
        kernel_shape_i=kernel_shape,
        pads_i=pads,
        strides_i=strides,
    )

    if is_unbatched:
        pool_result = symbolic_helper._squeeze_helper(g, pool_result, [0])

    return pool_result
|
||
|
||
|
||
# For MaxPool
|
||
def _adjust_attributes_of_max_pool(
|
||
expand_size: int,
|
||
kernel_size: Union[Sequence[int], int],
|
||
stride: Union[Sequence[int], int],
|
||
padding: Union[Sequence[int], int],
|
||
dilation: Union[Sequence[int], int],
|
||
) -> Tuple[Sequence[int], Sequence[int], Sequence[int], Sequence[int]]:
|
||
"""Adjust attributes of avg_pool to match ONNX specification."""
|
||
|
||
if isinstance(dilation, int):
|
||
dilation = [dilation] * expand_size
|
||
|
||
if isinstance(kernel_size, int):
|
||
kernel_shape = [kernel_size] * expand_size
|
||
else:
|
||
kernel_shape = kernel_size # type: ignore[assignment]
|
||
|
||
if isinstance(padding, int):
|
||
pads = [padding] * expand_size * 2 # type: ignore[operator, assignment]
|
||
elif len(padding) == 1:
|
||
pads = padding * expand_size * 2 # type: ignore[operator, assignment]
|
||
elif len(padding) == 2:
|
||
# 2D padding
|
||
pads = padding * 2 # type: ignore[operator, assignment]
|
||
elif len(padding) == 3:
|
||
# 3D padding
|
||
pads = padding * 2 # type: ignore[operator, assignment]
|
||
else:
|
||
# When padding is already done for all dimensions,
|
||
# we don't need to double it
|
||
# eg: (1, 1, 1, 1, 1, 1)
|
||
pads = padding # type: ignore[assignment]
|
||
|
||
if isinstance(stride, int):
|
||
strides = [stride] * expand_size
|
||
elif not stride:
|
||
strides = kernel_shape
|
||
else:
|
||
strides = stride # type: ignore[assignment]
|
||
|
||
return (kernel_shape, strides, pads, dilation)
|
||
|
||
|
||
def _aten_max_pool_with_indices_onnx(
    g: jit_utils.GraphContext,
    self: _C.Value,
    kernel_shape: Sequence[int],
    strides: Sequence[int],
    pads: Sequence[int],
    dilations: Sequence[int],
    ceil_mode: bool,
    unbatched_rank: int,
    n_dims_one: Sequence[int],
    n_dims_zero: Sequence[int],
    n_dims_axes: Sequence[int],
) -> Tuple[_C.Value, _C.Value]:
    """Emit ONNX ``MaxPool`` returning the pooled values and adjusted indices.

    A second ``MaxPool`` with unit kernel and stride yields the ONNX index of
    every input element. Slicing out the first spatial element of each plane
    gives a per-plane base index. That base is subtracted from the first
    pool's indices so they count from the start of each plane.

    Args:
        g: Graph context used to create nodes.
        self: Input to pool.
        kernel_shape, strides, pads, dilations: ONNX ``MaxPool`` attributes.
        ceil_mode: Forwarded as the ``ceil_mode`` attribute.
        unbatched_rank: Rank of an input without a batch dimension. Such an
            input gets a leading axis of size 1 before pooling, and that axis
            is removed from both outputs afterwards.
        n_dims_one: ``[1] * spatial_dims``. Used as the unit kernel/stride and
            as the ``Slice`` ends.
        n_dims_zero: ``[0] * spatial_dims``, used as the ``Slice`` starts.
        n_dims_axes: Spatial axes ``[2, 3, ...]`` for the ``Slice``.

    Returns:
        ``(pool_result, indices)``, both graph values.
    """
    # Static rank check. A graph value compared with a Python int is never
    # equal, which previously made the unbatched branches dead code.
    self_rank = symbolic_helper._get_tensor_rank(self)
    is_unbatched = self_rank is not None and self_rank == unbatched_rank
    if is_unbatched:  # C,H,W -> N,C,H,W and N=1
        self = symbolic_helper._unsqueeze_helper(g, self, [0])

    pool_result, indices = g.op(
        "MaxPool",
        self,
        outputs=2,
        ceil_mode_i=ceil_mode,
        dilations_i=dilations,
        kernel_shape_i=kernel_shape,
        pads_i=pads,
        strides_i=strides,
    )
    _, flatten_indices = g.op(
        "MaxPool",
        self,
        outputs=2,
        dilations_i=dilations,
        kernel_shape_i=n_dims_one,
        strides_i=n_dims_one,
    )

    ends = g.op("Constant", value_t=torch.tensor(n_dims_one))
    starts = g.op("Constant", value_t=torch.tensor(n_dims_zero))
    axes = g.op("Constant", value_t=torch.tensor(n_dims_axes))

    delta = g.op("Slice", flatten_indices, starts, ends, axes)
    indices = g.op("Sub", indices, delta)

    if is_unbatched:
        # Previously this passed a meaningless ``value_t`` attribute to
        # Squeeze. Use the helper that emits the opset-correct form.
        pool_result = symbolic_helper._squeeze_helper(g, pool_result, [0])
        indices = symbolic_helper._squeeze_helper(g, indices, [0])

    return (pool_result, indices)
|
||
|
||
|
||
@_onnx_symbolic(
    "aten::max_pool1d",
    decorate=[_apply_params("max_pool1d", 1, return_indices=False)],
)
@_onnx_symbolic(
    "aten::max_pool2d",
    decorate=[_apply_params("max_pool2d", 2, return_indices=False)],
)
@_onnx_symbolic(
    "aten::max_pool3d",
    decorate=[_apply_params("max_pool3d", 3, return_indices=False)],
)
@_onnx_symbolic(
    "aten::max_pool1d_with_indices",
    decorate=[_apply_params("max_pool1d_with_indices", 1, return_indices=True)],
)
@_onnx_symbolic(
    "aten::max_pool2d_with_indices",
    decorate=[_apply_params("max_pool2d_with_indices", 2, return_indices=True)],
)
@_onnx_symbolic(
    "aten::max_pool3d_with_indices",
    decorate=[_apply_params("max_pool3d_with_indices", 3, return_indices=True)],
)
@_beartype.beartype
def _max_pool(name: str, expand_size: int, return_indices: bool):
    """Create the symbolic for ``aten::max_pool{1,2,3}d`` and its ``_with_indices`` variant.

    Args:
        name: ATen op name (not used by the generated symbolic).
        expand_size: Number of spatial dimensions pooled over.
        return_indices: Whether the symbolic also returns the argmax indices.
    """

    @symbolic_helper.quantized_args(True, False, False, False, False, False)
    @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
    def symbolic_fn(
        g: jit_utils.GraphContext,
        input: _C.Value,
        kernel_size: Sequence[int],
        stride: Sequence[int],
        padding: Union[int, Sequence[int]],
        dilation: Sequence[int],
        ceil_mode: bool,
    ):
        kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool(
            expand_size, kernel_size, stride, padding, dilation
        )
        # Arguments common to both lowering helpers. An unbatched input has
        # rank spatial_dims + 1 (no N axis).
        shared_args = (
            g,
            input,
            kernel_shape,
            strides,
            pads,
            dilations,
            ceil_mode,
            expand_size + 1,
        )
        if not return_indices:
            return _aten_max_pool_onnx(*shared_args)
        return _aten_max_pool_with_indices_onnx(
            *shared_args,
            [1] * expand_size,
            [0] * expand_size,
            list(range(2, 2 + expand_size)),
        )

    return symbolic_fn
|
||
|
||
|
||
# For AvgPool
|
||
def _adjust_attributes_of_avg_pool(
|
||
expand_size: int,
|
||
kernel_size: Union[Sequence[int], int],
|
||
stride: Union[Sequence[int], int],
|
||
padding: Union[Sequence[int], int],
|
||
) -> Tuple[Sequence[int], Sequence[int], Sequence[int]]:
|
||
"""Adjust attributes of avg_pool to match ONNX specification."""
|
||
|
||
if isinstance(kernel_size, int):
|
||
kernel_shape = [kernel_size] * expand_size
|
||
else:
|
||
kernel_shape = kernel_size # type: ignore[assignment]
|
||
|
||
if isinstance(padding, int):
|
||
pads = [padding] * expand_size * 2
|
||
elif len(padding) == 1:
|
||
pads = padding * expand_size * 2 # type: ignore[operator, assignment]
|
||
elif len(padding) == 2:
|
||
pads = padding * expand_size # type: ignore[operator, assignment]
|
||
else:
|
||
pads = padding * 2 # type: ignore[operator, assignment]
|
||
|
||
if isinstance(stride, int):
|
||
strides = [stride] * expand_size
|
||
elif not stride:
|
||
strides = kernel_shape
|
||
else:
|
||
strides = stride # type: ignore[assignment]
|
||
|
||
return (kernel_shape, strides, pads)
|
||
|
||
|
||
@_onnx_symbolic(
    "aten::avg_pool1d",
    decorate=[_apply_params("avg_pool1d", 1)],
)
@_onnx_symbolic(
    "aten::avg_pool2d",
    decorate=[_apply_params("avg_pool2d", 2)],
)
@_onnx_symbolic(
    "aten::avg_pool3d",
    decorate=[_apply_params("avg_pool3d", 3)],
)
@_beartype.beartype
def _avg_pool(name, expand_size):
    """Create the symbolic for ``aten::avg_pool{1,2,3}d``.

    Args:
        name: ATen op name, used when reporting unsupported arguments.
        expand_size: Number of spatial dimensions pooled over.
    """

    @symbolic_helper.quantized_args(True, False, False, False, False, False, False)
    @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
    @_beartype.beartype
    def symbolic_fn(
        g,
        input: _C.Value,
        kernel_size: Sequence[int],
        stride: Sequence[int],
        padding: Union[int, Sequence[int]],
        ceil_mode: int,
        count_include_pad: int,
        divisor_override=None,
    ):
        # ONNX AveragePool has no attribute for a custom divisor. Dropping it
        # silently (the previous behavior) would export different values.
        if not symbolic_helper._is_none(divisor_override):
            return symbolic_helper._unimplemented(name, "divisor_override", input)

        kernel_shape, strides, pads = _adjust_attributes_of_avg_pool(
            expand_size, kernel_size, stride, padding
        )

        return g.op(
            "AveragePool",
            input,
            ceil_mode_i=ceil_mode,
            count_include_pad_i=count_include_pad,
            kernel_shape_i=kernel_shape,
            pads_i=pads,
            strides_i=strides,
        )

    return symbolic_fn
|
||
|
||
|
||
@_onnx_symbolic(
    "aten::upsample_nearest1d",
    decorate=[_apply_params("upsample_nearest1d", 3, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest2d",
    decorate=[_apply_params("upsample_nearest2d", 4, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest3d",
    decorate=[_apply_params("upsample_nearest3d", 5, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_linear1d",
    decorate=[_apply_params("upsample_linear1d", 3, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_bilinear2d",
    decorate=[_apply_params("upsample_bilinear2d", 4, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_trilinear3d",
    decorate=[_apply_params("upsample_trilinear3d", 5, "linear")],
)
@_beartype.beartype
def _interpolate(name, dim, interpolate_mode):
    """Create the symbolic for the ``aten::upsample_*`` family as ONNX ``Resize``.

    Args:
        name: ATen op name, used when reporting ``align_corners=True``.
        dim: Input rank, used to turn an output size into scales.
        interpolate_mode: ``Resize`` mode (``"nearest"`` or ``"linear"``).
    """

    @symbolic_helper.quantized_args(True, False, False)
    @_beartype.beartype
    def symbolic_fn(g, input, output_size, *args):
        scales, align_corners = symbolic_helper._get_interpolate_attributes(
            g, interpolate_mode, args
        )
        symbolic_helper._interpolate_warning(interpolate_mode)
        # This opset's export does not support align_corners.
        if symbolic_helper._maybe_get_scalar(align_corners):
            return symbolic_helper._unimplemented(name, "align_corners == True", input)

        # Without explicit scales, derive them from the requested output size.
        resize_scales = (
            scales
            if scales is not None
            else symbolic_helper._interpolate_size_to_scales(
                g, input, output_size, dim
            )
        )
        return g.op("Resize", input, resize_scales, mode_s=interpolate_mode)

    return symbolic_fn
|
||
|
||
|
||
@_onnx_symbolic("aten::__interpolate")
@_beartype.beartype
def __interpolate(
    g: jit_utils.GraphContext,
    input,
    size,
    scale_factor,
    mode,
    align_corners,
    recompute_scale_factor,
    antialias,
):
    """Export ``aten::__interpolate`` as ONNX ``Resize``.

    ``recompute_scale_factor`` and ``antialias`` are accepted to match the ATen
    signature but are not used by this export.
    """
    resize_scales, resize_mode = symbolic_helper._interpolate_get_scales_and_mode(
        g, input, size, scale_factor, mode, align_corners
    )
    return g.op("Resize", input, resize_scales, mode_s=resize_mode)
|
||
|
||
|
||
@_beartype.beartype
def _slice(
    g: jit_utils.GraphContext,
    input: torch._C.Value,
    axes: Union[List, torch.Tensor, torch._C.Value],
    starts: Union[List, torch.Tensor, torch._C.Value],
    ends: Union[List, torch.Tensor, torch._C.Value],
    steps: Optional[Union[List, torch.Tensor, torch._C.Value]] = None,
):
    """Emit an ONNX ``Slice`` whose starts/ends/axes/steps are node inputs.

    Args:
        g: Graph context used to create nodes.
        input: Value to slice.
        axes: Axes to slice, as a Python list, a tensor, or a graph value.
        starts: Start indices. A None / None-constant value becomes ``[0]``.
        ends: End indices. A None / None-constant value becomes
            ``[INT64_MAX]``.
        steps: Optional steps. When the argument is Python ``None``, no
            ``steps`` input is added to ``Slice``. A None-constant graph value
            becomes ``[1]``.

    Returns:
        ``input`` itself when the slice is recognizably a no-op (single start
        0, single end INT64_MAX, step absent or 1). Otherwise the ``Slice``
        output.

    Raises:
        errors.SymbolicValueError: if a graph-value argument's rank is not
            statically known to be 0 or 1.
    """

    def is_none_value(value):
        # True for Python None or a prim::Constant whose type is NoneType.
        if value is None:
            return True
        return (
            isinstance(value, torch._C.Value)
            and value.node().kind() == "prim::Constant"
            and isinstance(value.type(), _C.NoneType)
        )

    def to_slice_input(list_or_value, default_value=None):
        # Convert input param into a 1D torch.Value.
        if is_none_value(list_or_value) and default_value is not None:
            list_or_value = [default_value]

        # Python lists / tensors are baked into a Constant node.
        if isinstance(list_or_value, (list, torch.Tensor)):
            return g.op("Constant", value_t=torch.tensor(list_or_value))

        # Graph values: scalars are unsqueezed to 1-D, 1-D values pass through.
        # Any other rank, including an unknown (None) rank, is rejected.
        rank = symbolic_helper._get_tensor_rank(list_or_value)
        if rank == 0:
            return symbolic_helper._unsqueeze_helper(g, list_or_value, [0])
        if rank == 1:
            return list_or_value
        raise errors.SymbolicValueError(
            f"Rank must be 0 or 1, not {rank}", list_or_value
        )

    def get_const_value(list_or_value):
        # Single-element lists/tensors yield that element. Longer ones yield
        # None, so multi-axis slices are never treated as no-ops here. Graph
        # values are read as an int constant when possible.
        if isinstance(list_or_value, (list, torch.Tensor)):
            if len(list_or_value) == 1:
                return list_or_value[0]
            return None
        return symbolic_helper._maybe_get_const(list_or_value, "i")

    # Check if slice is a no-op
    if (
        get_const_value(starts) == 0
        and get_const_value(ends) == _constants.INT64_MAX
        and (steps is None or get_const_value(steps) == 1)
    ):
        return input

    axes = to_slice_input(axes)
    starts = to_slice_input(starts, default_value=0)
    ends = to_slice_input(ends, default_value=_constants.INT64_MAX)
    # Only a Python-None ``steps`` omits the optional Slice input.
    if steps is None:
        return g.op("Slice", input, starts, ends, axes)
    steps = to_slice_input(steps, default_value=1)
    return g.op("Slice", input, starts, ends, axes, steps)
|
||
|
||
|
||
@_onnx_symbolic("aten::slice")
@_beartype.beartype
def slice(g: jit_utils.GraphContext, self, *args):
    """Export both ``aten::slice`` overloads via ``symbolic_helper._slice_helper``.

    The overloads are:

    - Tensor form, four extra arguments:
      ``aten::slice(Tensor self, int dim, int? start=None, int? end=None, int step=1) -> Tensor``
    - List form, three extra arguments; it always slices axis 0:
      ``aten::slice(t[] l, int? start=None, int? end=None, int step=1) -> t[]``

    Raises:
        errors.SymbolicValueError: for any other number of arguments.
    """
    if len(args) not in (3, 4):
        raise errors.SymbolicValueError("Unknown aten::slice signature", self)
    if len(args) == 3:
        args = ([0], *args)
    dims, start, end, step = args

    return symbolic_helper._slice_helper(
        g, self, axes=dims, starts=start, ends=end, steps=step
    )
|
||
|
||
|
||
@_onnx_symbolic("aten::flip")
@symbolic_helper.parse_args("v", "is")
@_beartype.beartype
def flip(g: jit_utils.GraphContext, input, dims):
    """Export ``aten::flip`` as a reversed ``Slice`` over each axis in ``dims``.

    Each axis is traversed from index -1 with step -1. The end is
    ``-INT64_MAX``, so the walk runs past the first element.
    """
    axis_count = len(dims)
    return symbolic_helper._slice_helper(
        g,
        input,
        axes=dims,
        starts=[-1] * axis_count,
        ends=[-_constants.INT64_MAX] * axis_count,
        steps=[-1] * axis_count,
    )
|
||
|
||
|
||
@_onnx_symbolic("aten::fmod")
@_beartype.beartype
def fmod(g: jit_utils.GraphContext, input, other):
    """Export ``aten::fmod`` as ONNX ``Mod`` with ``fmod=1``.

    ``fmod=1`` selects C-style fmod, where the result takes the dividend's sign.
    """
    remainder = g.op("Mod", input, other, fmod_i=1)
    return remainder
|
||
|
||
|
||
@_onnx_symbolic("aten::embedding_bag")
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
@_beartype.beartype
def embedding_bag(
    g: jit_utils.GraphContext,
    embedding_matrix,
    indices,
    offsets,
    scale_grad_by_freq,
    mode,
    sparse,
    per_sample_weights,
    include_last_offset,
    padding_idx,
):
    """Export ``aten::embedding_bag`` by unrolling one Gather + reduction per bag.

    The number of bags must be statically known from ``offsets.shape[0]``.
    Bag ``i`` slices ``indices[offsets[i]:offsets[i + 1]]``, gathers those rows
    of ``embedding_matrix``, and optionally multiplies them by the matching
    ``per_sample_weights``. The rows are then reduced: sum for mode 0, mean
    for mode 1, max otherwise.

    Returns:
        ``(output, None, None, None)``. aten::embedding_bag returns a tuple of
        4 elements: output, offset2bag, bag_size, max_indices. Only the first
        is produced here.

    Raises:
        RuntimeError: if ``padding_idx`` is set (non-negative).
    """
    if scale_grad_by_freq and GLOBALS.export_training:
        return symbolic_helper._onnx_unsupported(
            "embedding_bag with scale_grad_by_freq for training mode"
        )
    if padding_idx is not None and padding_idx >= 0:
        raise RuntimeError("embedding_bag with padding_idx")

    warnings.warn(
        "Export of embedding_bag with dynamic input/offsets shape is not supported in opset 10. "
        "Please use opset 11 or higher to export model for dynamic input shape."
    )
    offsets_dim_0 = symbolic_helper._get_tensor_dim_size(offsets, 0)
    if offsets_dim_0 is None:
        return symbolic_helper._onnx_unsupported(
            "embedding_bag with unknown shape of offsets for opset 10 is not supported. "
            "please use opset 11 or higher."
        )

    if include_last_offset:
        # offsets already ends with the end position of the last bag.
        offset_len = offsets_dim_0 - 1
        offsets_extended = offsets
    else:
        # Append a sys.maxsize sentinel so bag i always spans
        # offsets_extended[i]:offsets_extended[i + 1]. (ONNX Slice clamps
        # out-of-range ends.)
        offset_len = offsets_dim_0
        offsets_extended = g.op(
            "Concat",
            offsets,
            g.op("Constant", value_t=torch.tensor([sys.maxsize])),
            axis_i=0,
        )

    # Every bag slices along axis 0, so this constant is created only once.
    axes_ = g.op("Constant", value_t=torch.tensor([0]))
    list_ = []
    for i in range(offset_len):
        start_ = symbolic_helper._unsqueeze_helper(
            g,
            opset9.select(g, offsets_extended, torch.tensor(0), torch.tensor(i)),
            [0],
        )
        end_ = symbolic_helper._unsqueeze_helper(
            g,
            opset9.select(
                g, offsets_extended, torch.tensor(0), torch.tensor(i + 1)
            ),
            [0],
        )
        indices_row = g.op("Slice", indices, start_, end_, axes_)

        embeddings = g.op("Gather", embedding_matrix, indices_row)
        if not symbolic_helper._is_none(per_sample_weights):
            per_sample_weights_row = g.op(
                "Slice", per_sample_weights, start_, end_, axes_
            )
            # Add a trailing axis so the weights broadcast across the
            # embedding dimension.
            per_sample_weights_row = symbolic_helper._unsqueeze_helper(
                g, per_sample_weights_row, [1]
            )
            embeddings = g.op("Mul", embeddings, per_sample_weights_row)
        if mode == 0:
            embeddings = symbolic_helper._reducesum_helper(
                g, embeddings, axes_i=[0], keepdims_i=0
            )
        elif mode == 1:
            embeddings = g.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0)
        else:
            embeddings = g.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0)

        embeddings = symbolic_helper._unsqueeze_helper(g, embeddings, [0])
        list_.append(embeddings)

    output = g.op("Concat", *list_, axis_i=0)
    # The last three outputs are not used in torch.nn.EmbeddingBag or
    # torch.nn.functional.embedding_bag.
    return output, None, None, None
|
||
|
||
|
||
@_onnx_symbolic("aten::fake_quantize_per_tensor_affine")
@symbolic_helper.parse_args("v", "v", "v", "i", "i")
@_beartype.beartype
def fake_quantize_per_tensor_affine(
    g: jit_utils.GraphContext,
    inputs,
    scale,
    zero_point,
    quant_min=-128,
    quant_max=127,
):
    """Export fake quantization as ``DequantizeLinear(QuantizeLinear(x))``.

    Only the ranges (0, 255) (uint8) and (-128, 127) (int8) are exportable at
    opset 10, and ``scale`` must be a constant.

    Raises:
        errors.SymbolicValueError: for other quantization ranges. Unsupported
            ranges and non-constant scales are also reported through
            ``_onnx_opset_unsupported_detailed``.
    """
    # NOTE: (0, 127) is a special case. PyTorch restricts activations to be in the range (0, 127).
    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
    if (quant_min, quant_max) == (0, 127):
        symbolic_helper._onnx_opset_unsupported_detailed(
            "fake_quantize_per_tensor_affine",
            10,
            13,
            "Quantize range (0, 127) not supported, requires opset 13 Clip",
            inputs,
        )
    if (quant_min, quant_max) not in [(0, 255), (-128, 127)]:
        raise errors.SymbolicValueError(
            f"For (quant_min, quant_max), ONNX allows only (0, 255) and (-128, 127). "
            f"Got ({quant_min}, {quant_max})",
            inputs,
        )
    scale = symbolic_helper._maybe_get_scalar(scale)
    # Only a constant tensor can be converted with ``.float().data`` below.
    # Checking ``is None`` alone let a non-constant graph value through, which
    # then crashed with AttributeError instead of a clear export error.
    # NOTE(review): presumably ``_maybe_get_scalar`` returns the original
    # graph value for non-constant scales — confirm.
    if not isinstance(scale, torch.Tensor):
        symbolic_helper._onnx_opset_unsupported_detailed(
            "fake_quantize_per_tensor_affine",
            10,
            13,
            "Non-constant scale not supported",
            inputs,
        )
    scale = scale.float().data  # Avoid exporter generating double type
    # The zero-point dtype selects uint8 vs int8 quantization.
    if quant_min == 0:
        zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
    else:
        zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8)
    return g.op(
        "DequantizeLinear",
        g.op("QuantizeLinear", inputs, scale, zero_point),
        scale,
        zero_point,
    )
|
||
|
||
|
||
@_onnx_symbolic("aten::isinf")
@_beartype.beartype
def isinf(g: jit_utils.GraphContext, input):
    """Export ``aten::isinf``: cast the input to double, then apply ONNX ``IsInf``."""
    as_double = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE)
    return g.op("IsInf", as_double)
|
||
|
||
|
||
@_onnx_symbolic("aten::isfinite")
@_beartype.beartype
def isfinite(g: jit_utils.GraphContext, input):
    """Export ``aten::isfinite`` as ``not (isinf(x) or isnan(x))``."""
    not_finite = opset9.__or_(g, isinf(g, input), opset9.isnan(g, input))
    return opset9.__not_(g, not_finite)
|
||
|
||
|
||
@_onnx_symbolic("aten::quantize_per_tensor")
@_beartype.beartype
def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dtype):
    """Export ``aten::quantize_per_tensor`` via ``symbolic_helper.quantize_helper``.

    The zero point is cast to the ONNX type of the requested quantized dtype,
    and the scale is cast to float.
    """
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    zero_point_onnx_type = _type_utils.JitScalarType(dtype).onnx_type()
    # TODO(justinchuby): Extract all the cast ops into a helper function.
    zero_point = g.op("Cast", zero_point, to_i=zero_point_onnx_type)
    scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    return symbolic_helper.quantize_helper(g, input, scale, zero_point)
|
||
|
||
|
||
@_onnx_symbolic("aten::dequantize")
@_beartype.beartype
def dequantize(g: jit_utils.GraphContext, input):
    """Export ``aten::dequantize``, returning only the dequantized value."""
    dequantized, *_ = symbolic_helper.dequantize_helper(g, input)
    return dequantized
|
||
|
||
|
||
@_onnx_symbolic("aten::nan_to_num")
@symbolic_helper.parse_args("v", "f", "f", "f")
@_beartype.beartype
def nan_to_num(g: jit_utils.GraphContext, input, nan, posinf, neginf):
    """Export ``aten::nan_to_num`` as a chain of three ``Where`` replacements.

    Replacements:

    - NaN becomes ``nan`` (default 0.0).
    - +inf becomes ``posinf``; None means the largest finite value of the
      input dtype.
    - -inf becomes ``neginf``; None means the lowest finite value of the
      input dtype.

    Non-floating inputs cannot hold inf/NaN and are returned unchanged.
    """
    if not symbolic_helper._is_fp(input):
        return input
    input_dtype = _type_utils.JitScalarType.from_value(input).dtype()

    def typed_constant(value):
        # One-element constant in the input dtype, used as a Where branch.
        return g.op("Constant", value_t=torch.tensor([value], dtype=input_dtype))

    def zero_constant():
        return g.op("Constant", value_t=torch.LongTensor([0]))

    nan_replacement = 0.0 if nan is None else nan
    result = g.op(
        "Where", opset9.isnan(g, input), typed_constant(nan_replacement), input
    )

    finfo = torch.finfo(input_dtype)

    posinf_replacement = finfo.max if posinf is None else posinf
    is_pos_inf = opset9.logical_and(
        g, isinf(g, result), opset9.gt(g, result, zero_constant())
    )
    result = g.op("Where", is_pos_inf, typed_constant(posinf_replacement), result)

    neginf_replacement = finfo.min if neginf is None else neginf
    is_neg_inf = opset9.logical_and(
        g, isinf(g, result), opset9.lt(g, result, zero_constant())
    )
    return g.op("Where", is_neg_inf, typed_constant(neginf_replacement), result)
|
||
|
||
|
||
# Quantized symbolics ---------------------------------------------------------
|
||
# https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export
|
||
# Support starts from opset 10 because `DequantizeLinear` and `QuantizeLinear` were
|
||
# introduced in opset version 10.
|
||
@_onnx_symbolic("quantized::linear")
@_beartype.beartype
def quantized_linear(
    g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
):
    """Export ``quantized::linear`` as dequantize -> float linear -> quantize.

    The bias is requantized with the input and weight scales
    (``requantize_bias_helper``) and then dequantized back to float.
    """
    float_input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    float_weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale
    )
    float_bias = symbolic_helper.dequantize_helper(g, requantized_bias)[0]

    linear_out = opset9.linear(g, float_input, float_weight, float_bias)
    return symbolic_helper.quantize_helper(g, linear_out, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::linear_relu")
@_beartype.beartype
def quantized_linear_relu(
    g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
):
    """Export ``quantized::linear_relu`` as dequantize -> linear + ReLU -> quantize.

    The bias is requantized with the input and weight scales and then
    dequantized back to float.
    """
    float_input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    float_weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale
    )
    float_bias = symbolic_helper.dequantize_helper(g, requantized_bias)[0]

    activated = opset9.relu(
        g, opset9.linear(g, float_input, float_weight, float_bias)
    )
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::add")
@_beartype.beartype
def quantized_add(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
    """Export ``quantized::add``: dequantize both operands, add in float, requantize."""
    lhs = symbolic_helper.dequantize_helper(g, x)[0]
    rhs = symbolic_helper.dequantize_helper(g, y)[0]
    total = opset9.add(g, lhs, rhs)
    return symbolic_helper.quantize_helper(g, total, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::add_relu")
@_beartype.beartype
def quantized_add_relu(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
    """Export ``quantized::add_relu``: dequantize, add, ReLU, requantize."""
    lhs = symbolic_helper.dequantize_helper(g, x)[0]
    rhs = symbolic_helper.dequantize_helper(g, y)[0]
    activated = opset9.relu(g, opset9.add(g, lhs, rhs))
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::mul")
@_beartype.beartype
def quantized_mul(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
    """Export ``quantized::mul``: dequantize both operands, multiply, requantize."""
    lhs = symbolic_helper.dequantize_helper(g, x)[0]
    rhs = symbolic_helper.dequantize_helper(g, y)[0]
    product = opset9.mul(g, lhs, rhs)
    return symbolic_helper.quantize_helper(g, product, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::hardswish")
@_beartype.beartype
def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
    """Export ``quantized::hardswish``: dequantize, float hardswish, requantize."""
    float_x = symbolic_helper.dequantize_helper(g, x)[0]
    activated = opset9.hardswish(g, float_x)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::sigmoid")
@_beartype.beartype
def quantized_sigmoid(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
    """Export ``quantized::sigmoid``: dequantize, float sigmoid, requantize."""
    float_x = symbolic_helper.dequantize_helper(g, x)[0]
    activated = opset9.sigmoid(g, float_x)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::leaky_relu")
@_beartype.beartype
def quantized_leaky_relu(
    g: jit_utils.GraphContext, x, negative_slope, inplace, op_scale, op_zero_point
):
    """Export ``quantized::leaky_relu``: dequantize, float leaky ReLU, requantize."""
    float_x = symbolic_helper.dequantize_helper(g, x)[0]
    activated = opset9.leaky_relu(g, float_x, negative_slope, inplace)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::layer_norm")
@_beartype.beartype
def quantized_layer_norm(
    g: jit_utils.GraphContext,
    x,
    normalized_shape,
    weight,
    bias,
    eps,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::layer_norm``: dequantize, float layer norm, requantize.

    The trailing ``False`` passed to ``opset9.layer_norm`` is forwarded
    unchanged; its meaning is defined by opset 9's ``layer_norm`` signature.
    """
    float_x = symbolic_helper.dequantize_helper(g, x)[0]
    normalized = opset9.layer_norm(
        g, float_x, normalized_shape, weight, bias, eps, False
    )
    return symbolic_helper.quantize_helper(g, normalized, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::group_norm")
@_beartype.beartype
def quantized_group_norm(
    g: jit_utils.GraphContext,
    x,
    num_groups,
    weight,
    bias,
    eps,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::group_norm``: dequantize, float group norm, requantize.

    The trailing ``False`` passed to ``opset9.group_norm`` is forwarded
    unchanged; its meaning is defined by opset 9's ``group_norm`` signature.
    """
    float_x = symbolic_helper.dequantize_helper(g, x)[0]
    normalized = opset9.group_norm(g, float_x, num_groups, weight, bias, eps, False)
    return symbolic_helper.quantize_helper(g, normalized, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::instance_norm")
@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v")
@_beartype.beartype
def quantized_instance_norm(
    g: jit_utils.GraphContext,
    q_input,
    weight,
    bias,
    eps,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::instance_norm``: dequantize, float instance norm, requantize.

    ``opset9.instance_norm`` is called with these fixed arguments:
    running mean/var ``None``, flags ``False``, momentum ``0.0``, then the
    parsed ``eps``.
    """
    float_input = symbolic_helper.dequantize_helper(g, q_input)[0]
    normalized = opset9.instance_norm(
        g, float_input, weight, bias, None, None, False, 0.0, eps, False
    )
    return symbolic_helper.quantize_helper(g, normalized, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::conv1d_relu")
@_beartype.beartype
def quantized_conv1d_relu(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv1d_relu``: dequantize, float conv1d + ReLU, requantize.

    The bias is requantized with the input and weight scales and then
    dequantized back to float.
    """
    float_input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    float_weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale
    )
    float_bias = symbolic_helper.dequantize_helper(g, requantized_bias)[0]

    conv_out = opset9.conv1d(
        g, float_input, float_weight, float_bias, stride, padding, dilation, groups
    )
    activated = opset9.relu(g, conv_out)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::conv2d_relu")
@_beartype.beartype
def quantized_conv2d_relu(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv2d_relu``: dequantize, float conv2d + ReLU, requantize.

    The bias is requantized with the input and weight scales and then
    dequantized back to float.
    """
    float_input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    float_weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale
    )
    float_bias = symbolic_helper.dequantize_helper(g, requantized_bias)[0]

    conv_out = opset9.conv2d(
        g, float_input, float_weight, float_bias, stride, padding, dilation, groups
    )
    activated = opset9.relu(g, conv_out)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::conv3d_relu")
@_beartype.beartype
def quantized_conv3d_relu(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Symbolic for ``quantized::conv3d_relu``.

    Input and weight are dequantized. The bias is passed through
    ``requantize_bias_helper`` together with the input and weight scales, then
    dequantized. ``opset9.conv3d`` followed by ``opset9.relu`` runs in float,
    and the activation is quantized with ``op_scale`` / ``op_zero_point``.
    """
    dq_input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    dq_weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    bias_q = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale
    )
    dq_bias, _, _, _ = symbolic_helper.dequantize_helper(g, bias_q)

    result = opset9.conv3d(
        g, dq_input, dq_weight, dq_bias, stride, padding, dilation, groups
    )
    result = opset9.relu(g, result)
    return symbolic_helper.quantize_helper(g, result, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::conv1d")
@_beartype.beartype
def quantized_conv1d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Symbolic for ``quantized::conv1d``.

    Input and weight are dequantized. The bias is passed through
    ``requantize_bias_helper`` with the input and weight scales and then
    dequantized. ``opset9.conv1d`` runs in float, and its output is quantized
    with ``op_scale`` / ``op_zero_point``.
    """
    fp_input, in_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    fp_weight, w_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(
        g, bias, in_scale, w_scale
    )
    fp_bias, _, _, _ = symbolic_helper.dequantize_helper(g, requantized_bias)

    conv_out = opset9.conv1d(
        g, fp_input, fp_weight, fp_bias, stride, padding, dilation, groups
    )
    return symbolic_helper.quantize_helper(g, conv_out, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::conv2d")
@_beartype.beartype
def quantized_conv2d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Symbolic for ``quantized::conv2d``.

    Input and weight are dequantized. The bias is passed through
    ``requantize_bias_helper`` with the input and weight scales and then
    dequantized. ``opset9.conv2d`` runs in float, and its output is quantized
    with ``op_scale`` / ``op_zero_point``.
    """
    fp_input, in_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    fp_weight, w_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    fp_bias, _, _, _ = symbolic_helper.dequantize_helper(
        g, symbolic_helper.requantize_bias_helper(g, bias, in_scale, w_scale)
    )
    return symbolic_helper.quantize_helper(
        g,
        opset9.conv2d(
            g, fp_input, fp_weight, fp_bias, stride, padding, dilation, groups
        ),
        op_scale,
        op_zero_point,
    )
|
||
|
||
|
||
@_onnx_symbolic("quantized::conv3d")
@_beartype.beartype
def quantized_conv3d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Symbolic for ``quantized::conv3d``.

    Input and weight are dequantized. The bias is passed through
    ``requantize_bias_helper`` with the input and weight scales and then
    dequantized. ``opset9.conv3d`` runs in float, and its output is quantized
    with ``op_scale`` / ``op_zero_point``.
    """
    dq_input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    dq_weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    bias_q = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale
    )
    dq_bias, _, _, _ = symbolic_helper.dequantize_helper(g, bias_q)

    result = opset9.conv3d(
        g, dq_input, dq_weight, dq_bias, stride, padding, dilation, groups
    )
    return symbolic_helper.quantize_helper(g, result, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::conv_transpose1d")
@_beartype.beartype
def quantized_conv_transpose1d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    output_padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Symbolic for ``quantized::conv_transpose1d``.

    Input and weight are dequantized. The bias is passed through
    ``requantize_bias_helper`` with the input and weight scales and then
    dequantized. The float transposed convolution is emitted through
    ``opset9.conv_transpose1d``, and the result is quantized with
    ``op_scale`` / ``op_zero_point``.

    Args:
        g: Graph context the nodes are emitted into.
        q_input, q_weight: Quantized activation and weight values.
        bias: Bias value, requantized and then dequantized before use.
        stride, padding, output_padding, dilation, groups: Convolution
            parameters, forwarded to ``opset9.conv_transpose1d``.
        op_scale, op_zero_point: Output quantization parameters.
    """
    fp_input, in_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    fp_weight, w_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(
        g, bias, in_scale, w_scale
    )
    fp_bias, _, _, _ = symbolic_helper.dequantize_helper(g, requantized_bias)

    # Use the rank-matched opset9 symbolic (previously conv_transpose2d was
    # called here, a copy-paste slip relative to the 2d/3d siblings).
    # Note the opset9 call takes ``groups`` before ``dilation``, the reverse of
    # this function's parameter order.
    output = opset9.conv_transpose1d(
        g, fp_input, fp_weight, fp_bias, stride, padding, output_padding, groups, dilation
    )

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::conv_transpose2d")
@_beartype.beartype
def quantized_conv_transpose2d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    output_padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Symbolic for ``quantized::conv_transpose2d``.

    Input and weight are dequantized. The bias is passed through
    ``requantize_bias_helper`` with the input and weight scales and then
    dequantized. ``opset9.conv_transpose2d`` runs in float, and its output is
    quantized with ``op_scale`` / ``op_zero_point``.
    """
    fp_input, in_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    fp_weight, w_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    fp_bias, _, _, _ = symbolic_helper.dequantize_helper(
        g, symbolic_helper.requantize_bias_helper(g, bias, in_scale, w_scale)
    )

    # The opset9 call receives ``groups`` before ``dilation`` -- swapped
    # relative to this function's own parameter order.
    deconv_out = opset9.conv_transpose2d(
        g, fp_input, fp_weight, fp_bias, stride, padding, output_padding, groups, dilation
    )
    return symbolic_helper.quantize_helper(g, deconv_out, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::conv_transpose3d")
@_beartype.beartype
def quantized_conv_transpose3d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    output_padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Symbolic for ``quantized::conv_transpose3d``.

    Input and weight are dequantized. The bias is passed through
    ``requantize_bias_helper`` with the input and weight scales and then
    dequantized. ``opset9.conv_transpose3d`` runs in float, and its output is
    quantized with ``op_scale`` / ``op_zero_point``.
    """
    dq_input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    dq_weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    bias_q = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale
    )
    dq_bias, _, _, _ = symbolic_helper.dequantize_helper(g, bias_q)

    # groups/dilation are passed in the opposite order to this signature.
    result = opset9.conv_transpose3d(
        g, dq_input, dq_weight, dq_bias, stride, padding, output_padding, groups, dilation
    )
    return symbolic_helper.quantize_helper(g, result, op_scale, op_zero_point)
|
||
|
||
|
||
@_onnx_symbolic("quantized::cat")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def quantized_cat(
    g: jit_utils.GraphContext,
    q_inputs: _C.Value,
    dim: int,
    op_scale: _C.Value,
    op_zero_point: _C.Value,
) -> _C.Value:
    """Symbolic for ``quantized::cat``.

    Every element of the ``q_inputs`` list value is dequantized. The float
    tensors are joined with an ONNX ``Concat`` node along ``axis=dim``, and
    the result is quantized with the single ``op_scale`` / ``op_zero_point``
    pair.

    Args:
        g: Graph context the nodes are emitted into.
        q_inputs: List value, unpacked with ``symbolic_helper._unpack_list``.
        dim: Concatenation axis (parsed to ``int`` by ``parse_args``).
        op_scale, op_zero_point: Output quantization parameters.

    Returns:
        The quantized concatenation.
    """
    members = symbolic_helper._unpack_list(q_inputs)
    # Only the dequantized value (element 0) of each helper result is needed.
    float_members = tuple(
        symbolic_helper.dequantize_helper(g, member)[0] for member in members
    )
    joined = g.op("Concat", *float_members, axis_i=dim)
    return symbolic_helper.quantize_helper(g, joined, op_scale, op_zero_point)
|