188 lines
6.4 KiB
Python
188 lines
6.4 KiB
Python
|
"""This file exports ONNX ops for opset 16.

Note [ONNX Operators that are added/updated in opset 16]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-16-of-the-default-onnx-operator-set
New operators:
    GridSample https://github.com/onnx/onnx/pull/3557

Updated operators:
    Identity
    If
    LeakyRelu
    Loop
    PRelu
    RoiAlign
    Scan
    ScatterElements
    ScatterND
    Where
    GreaterOrEqual
    LessOrEqual
"""
|
||
|
|
||
|
# EDITING THIS FILE? READ THIS FIRST!
|
||
|
# see Note [Edit Symbolic Files] in README.md
|
||
|
|
||
|
import functools
|
||
|
|
||
|
import torch
|
||
|
from torch.nn.functional import (
|
||
|
GRID_SAMPLE_INTERPOLATION_MODES,
|
||
|
GRID_SAMPLE_PADDING_MODES,
|
||
|
)
|
||
|
from torch.onnx import _type_utils, errors, symbolic_helper, utils
|
||
|
from torch.onnx._internal import _beartype, jit_utils, registration
|
||
|
|
||
|
# `registration.onnx_symbolic` with `opset=16` pre-bound. It is applied below
# as a decorator that takes the aten operator name (e.g. "aten::grid_sampler").
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=16)
|
||
|
|
||
|
|
||
|
# note (mkozuki): Why `grid_sampler` instead of `grid_sample`?
|
||
|
# Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`.
|
||
|
@_onnx_symbolic("aten::grid_sampler")
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
@_beartype.beartype
def grid_sampler(
    g: jit_utils.GraphContext,
    input,
    grid,
    mode_enum,
    padding_mode_enum,
    align_corners,
):
    """Emit an ONNX ``GridSample`` node for ``aten::grid_sampler``.

    ``mode_enum`` and ``padding_mode_enum`` are translated back to their
    string names by inverting ``GRID_SAMPLE_INTERPOLATION_MODES`` and
    ``GRID_SAMPLE_PADDING_MODES``; those names are passed as the ``mode`` and
    ``padding_mode`` string attributes. ``align_corners`` is passed as an
    integer attribute.

    If ``symbolic_helper._get_tensor_rank(input)`` is 5, the result of
    ``symbolic_helper._onnx_unsupported`` is returned instead of a node.

    Raises:
        KeyError: if either enum value has no entry in its lookup table.
    """
    # Reject 5D (volumetric) input before building anything.
    input_rank = symbolic_helper._get_tensor_rank(input)
    if input_rank == 5:
        return symbolic_helper._onnx_unsupported("GridSample with 5D volumetric input")

    # The tables map name -> enum value; invert them to look up the name.
    interpolation_names = {
        enum_value: name for name, enum_value in GRID_SAMPLE_INTERPOLATION_MODES.items()
    }
    interpolation_name = interpolation_names[mode_enum]  # type: ignore[call-arg]

    padding_names = {
        enum_value: name for name, enum_value in GRID_SAMPLE_PADDING_MODES.items()
    }
    padding_name = padding_names[padding_mode_enum]  # type: ignore[call-arg]

    return g.op(
        "GridSample",
        input,
        grid,
        align_corners_i=int(align_corners),
        mode_s=interpolation_name,
        padding_mode_s=padding_name,
    )
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::scatter_add")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def scatter_add(g: jit_utils.GraphContext, self, dim, index, src):
    """Emit ONNX ``ScatterElements`` with ``reduction="add"`` for ``aten::scatter_add``.

    Steps, in order:

    * When ``symbolic_helper.is_caffe2_aten_fallback()`` is true, return an
      ATen ``scatter`` op (overload ``src``) and do nothing else.
    * When ``src`` and ``index`` have a different number of dimensions, return
      the result of ``symbolic_helper._unimplemented``.
    * When the known sizes of ``src`` and ``index`` differ, or ``index`` has
      an unknown (``None``) size, slice ``src`` from the origin to the runtime
      shape of ``index``.
    * When ``src`` is not a graph value after ``_maybe_get_scalar`` and its
      scalar type differs from that of ``self``, cast it to the type of
      ``self``.
    """
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("scatter", self, dim, index, src, overload_name="src")

    # Capture the scalar type of `src` now; `src` may be rebound below.
    src_scalar_type = _type_utils.JitScalarType.from_value(
        src, _type_utils.JitScalarType.UNDEFINED
    )
    src_dims = symbolic_helper._get_tensor_sizes(src)
    index_dims = symbolic_helper._get_tensor_sizes(index)

    if len(src_dims) != len(index_dims):
        return symbolic_helper._unimplemented(
            "scatter_add",
            f"`index` ({index_dims}) should have the same dimensionality as `src` ({src_dims})",
        )

    # PyTorch only allows index shape <= src shape, so `index` is treated as
    # selecting a sub-region of `src`. When the static sizes disagree or some
    # axis of `index` is dynamic, trim `src` to the shape of `index`.
    must_trim_src = None in index_dims or src_dims != index_dims
    if must_trim_src:
        index_shape = g.op("Shape", index)
        origin = g.op("Constant", value_t=torch.tensor([0] * len(index_dims)))
        src = g.op("Slice", src, origin, index_shape)

    src = symbolic_helper._maybe_get_scalar(src)
    if not symbolic_helper._is_value(src):
        # PyTorch allows a scalar `src` whose type differs from `self` (but
        # not a tensor `src`), so align the scalar with `self` via Cast.
        self_scalar_type = _type_utils.JitScalarType.from_value(self)
        if self_scalar_type != src_scalar_type:
            src = g.op("Cast", src, to_i=self_scalar_type.onnx_type())

    return g.op("ScatterElements", self, index, src, axis_i=dim, reduction_s="add")
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::scatter_reduce")
@symbolic_helper.parse_args("v", "i", "v", "v", "s", "b")
@_beartype.beartype
def scatter_reduce(
    g: jit_utils.GraphContext,
    self: torch._C.Value,
    dim: int,
    index: torch._C.Value,
    src: torch._C.Value,
    reduce: str,
    include_self: bool,
):
    """Emit ONNX ``ScatterElements`` with a reduction for ``aten::scatter_reduce``.

    The emitted graph has three stages:

    1. An ``If`` on "rank of ``self`` equals 0" with three outputs. The true
       branch reshapes ``self``, ``index`` and ``src`` with shape ``[-1]``;
       the false branch forwards them through ``Identity``.
    2. ``ScatterElements`` over those three outputs along ``dim`` with the
       translated reduction name.
    3. A second ``If`` on the same condition: ``Squeeze`` on the result in
       the true branch, ``Identity`` in the false branch.

    Raises:
        errors.OnnxExporterError: if ``reduce == "mean"`` or
            ``include_self`` is false.
        KeyError: if ``reduce`` is not a key of the translation table.
    """
    if reduce == "mean":
        raise errors.OnnxExporterError(
            "ONNX does not support mean reduction for scatter_reduce"
        )
    if not include_self:
        raise errors.OnnxExporterError(
            "ONNX does not support include_self=False for scatter_reduce"
        )

    # torch reduction name -> ONNX `reduction` attribute value.
    torch_to_onnx_reduction = {
        # "mean" is rejected above, so this entry is never selected here.
        "mean": "none",
        "sum": "add",
        "prod": "mul",
        "amin": "min",
        "amax": "max",
    }
    onnx_reduction = torch_to_onnx_reduction[reduce]

    # Runtime check: is `self` a rank-0 tensor?
    rank = g.op("Size", g.op("Shape", self))
    zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
    rank_is_zero = g.op("Equal", rank, zero)

    # Stage 1: reshape the operands with [-1] when rank is 0, else pass through.
    # Block outputs are added in the order (self, index, src) in both branches.
    operands, (scalar_branch, tensor_branch), _ = jit_utils.add_op_with_blocks(
        g, "If", rank_is_zero, n_blocks=2, outputs=3
    )
    flat_shape = scalar_branch.op(
        "Constant", value_t=torch.tensor([-1], dtype=torch.int64)
    )
    for operand in (self, index, src):
        flattened = scalar_branch.op("Reshape", operand, flat_shape)
        utils._add_output_to_block(scalar_branch.block, flattened)

    for operand in (self, index, src):
        forwarded = tensor_branch.op("Identity", operand)
        utils._add_output_to_block(tensor_branch.block, forwarded)

    # Stage 2: the scatter itself, fed by the three outputs of the first If.
    scattered = g.op(
        "ScatterElements", *operands, axis_i=dim, reduction_s=onnx_reduction
    )

    # Stage 3: squeeze the result when rank was 0, else pass it through.
    restore_op, (squeeze_branch, keep_branch), _ = jit_utils.add_op_with_blocks(
        g, "If", rank_is_zero, n_blocks=2, outputs=1
    )
    squeezed = squeeze_branch.op("Squeeze", scattered)
    utils._add_output_to_block(squeeze_branch.block, squeezed)
    unchanged = keep_branch.op("Identity", scattered)
    utils._add_output_to_block(keep_branch.block, unchanged)

    return restore_op.node().output()
|