7208 lines
236 KiB
Python
7208 lines
236 KiB
Python
|
"""This file exports ONNX ops for opset 9.
|
||
|
|
||
|
Opset 9 is supported by ONNX release 1.4.1,
released on 01/23/19.
|
||
|
"""
|
||
|
from __future__ import annotations
|
||
|
|
||
|
import builtins
|
||
|
import functools
|
||
|
import math
|
||
|
import sys
|
||
|
import warnings
|
||
|
from typing import Callable, List, Optional, Sequence, Tuple, Union
|
||
|
|
||
|
import torch
|
||
|
import torch._C._onnx as _C_onnx
|
||
|
import torch.nn.modules.utils
|
||
|
import torch.onnx
|
||
|
from torch import _C
|
||
|
|
||
|
# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics
|
||
|
from torch.onnx import _constants, _deprecation, _type_utils, errors, symbolic_helper
|
||
|
from torch.onnx._globals import GLOBALS
|
||
|
from torch.onnx._internal import _beartype, jit_utils, registration
|
||
|
from torch.types import Number
|
||
|
|
||
|
# EDITING THIS FILE? READ THIS FIRST!
|
||
|
# see Note [Edit Symbolic Files] in README.md
|
||
|
|
||
|
__all__ = [
|
||
|
"abs",
|
||
|
"acos",
|
||
|
"add",
|
||
|
"addcmul",
|
||
|
"addmm",
|
||
|
"alias",
|
||
|
"amax",
|
||
|
"amin",
|
||
|
"aminmax",
|
||
|
"arange",
|
||
|
"argmax",
|
||
|
"argmin",
|
||
|
"as_strided",
|
||
|
"as_tensor",
|
||
|
"asin",
|
||
|
"atan",
|
||
|
"atan2",
|
||
|
"baddbmm",
|
||
|
"batch_norm",
|
||
|
"bernoulli",
|
||
|
"bitwise_not",
|
||
|
"bitwise_or",
|
||
|
"bmm",
|
||
|
"broadcast_tensors",
|
||
|
"broadcast_to",
|
||
|
"bucketize",
|
||
|
"cat",
|
||
|
"cdist",
|
||
|
"ceil",
|
||
|
"clamp_max",
|
||
|
"clamp_min",
|
||
|
"clamp",
|
||
|
"clone",
|
||
|
"constant_pad_nd",
|
||
|
"contiguous",
|
||
|
"conv_tbc",
|
||
|
"conv_transpose1d",
|
||
|
"conv_transpose2d",
|
||
|
"conv_transpose3d",
|
||
|
"conv1d",
|
||
|
"conv2d",
|
||
|
"conv3d",
|
||
|
"convert_element_type",
|
||
|
"convolution",
|
||
|
"cos",
|
||
|
"cosine_similarity",
|
||
|
"cross",
|
||
|
"cumsum",
|
||
|
"detach",
|
||
|
"dim",
|
||
|
"div",
|
||
|
"dot",
|
||
|
"dropout",
|
||
|
"elu",
|
||
|
"embedding_bag",
|
||
|
"embedding",
|
||
|
"empty_like",
|
||
|
"empty",
|
||
|
"eq",
|
||
|
"erf",
|
||
|
"exp",
|
||
|
"expand_as",
|
||
|
"expand",
|
||
|
"eye",
|
||
|
"fill",
|
||
|
"flatten",
|
||
|
"floor_divide",
|
||
|
"floor",
|
||
|
"floordiv",
|
||
|
"frobenius_norm",
|
||
|
"full_like",
|
||
|
"full",
|
||
|
"gather",
|
||
|
"ge",
|
||
|
"gelu",
|
||
|
"get_pool_ceil_padding",
|
||
|
"glu",
|
||
|
"group_norm",
|
||
|
"gt",
|
||
|
"hann_window",
|
||
|
"hardshrink",
|
||
|
"hardsigmoid",
|
||
|
"hardswish",
|
||
|
"hardtanh",
|
||
|
"index_add",
|
||
|
"index_copy",
|
||
|
"index_fill",
|
||
|
"index_put",
|
||
|
"index_select",
|
||
|
"index",
|
||
|
"instance_norm",
|
||
|
"is_floating_point",
|
||
|
"is_pinned",
|
||
|
"isnan",
|
||
|
"item",
|
||
|
"kl_div",
|
||
|
"layer_norm",
|
||
|
"le",
|
||
|
"leaky_relu",
|
||
|
"lerp",
|
||
|
"lift",
|
||
|
"linalg_cross",
|
||
|
"linalg_matrix_norm",
|
||
|
"linalg_norm",
|
||
|
"linalg_vector_norm",
|
||
|
"linear",
|
||
|
"linspace",
|
||
|
"log_sigmoid",
|
||
|
"log_softmax",
|
||
|
"log",
|
||
|
"log10",
|
||
|
"log1p",
|
||
|
"log2",
|
||
|
"logical_and",
|
||
|
"logical_not",
|
||
|
"logical_or",
|
||
|
"logical_xor",
|
||
|
"logit",
|
||
|
"logsumexp",
|
||
|
"lstm_cell",
|
||
|
"lstm",
|
||
|
"lt",
|
||
|
"masked_fill",
|
||
|
"masked_fill_",
|
||
|
"matmul",
|
||
|
"max_pool1d_with_indices",
|
||
|
"max_pool2d_with_indices",
|
||
|
"max_pool3d_with_indices",
|
||
|
"max",
|
||
|
"maximum",
|
||
|
"meshgrid",
|
||
|
"min",
|
||
|
"minimum",
|
||
|
"mish",
|
||
|
"mm",
|
||
|
"movedim",
|
||
|
"mse_loss",
|
||
|
"mul",
|
||
|
"multinomial",
|
||
|
"mv",
|
||
|
"narrow",
|
||
|
"native_layer_norm",
|
||
|
"ne",
|
||
|
"neg",
|
||
|
"new_empty",
|
||
|
"new_full",
|
||
|
"new_ones",
|
||
|
"new_zeros",
|
||
|
"nonzero_numpy",
|
||
|
"nonzero",
|
||
|
"norm",
|
||
|
"numel",
|
||
|
"numpy_T",
|
||
|
"one_hot",
|
||
|
"ones_like",
|
||
|
"ones",
|
||
|
"onnx_placeholder",
|
||
|
"overload_by_arg_count",
|
||
|
"pad",
|
||
|
"pairwise_distance",
|
||
|
"permute",
|
||
|
"pixel_shuffle",
|
||
|
"pixel_unshuffle",
|
||
|
"pow",
|
||
|
"prelu",
|
||
|
"prim_constant_chunk",
|
||
|
"prim_constant_split",
|
||
|
"prim_constant",
|
||
|
"prim_data",
|
||
|
"prim_device",
|
||
|
"prim_dtype",
|
||
|
"prim_if",
|
||
|
"prim_layout",
|
||
|
"prim_list_construct",
|
||
|
"prim_list_unpack",
|
||
|
"prim_loop",
|
||
|
"prim_max",
|
||
|
"prim_min",
|
||
|
"prim_shape",
|
||
|
"prim_tolist",
|
||
|
"prim_tuple_construct",
|
||
|
"prim_type",
|
||
|
"prim_unchecked_cast",
|
||
|
"prim_uninitialized",
|
||
|
"rand_like",
|
||
|
"rand",
|
||
|
"randint_like",
|
||
|
"randint",
|
||
|
"randn_like",
|
||
|
"randn",
|
||
|
"reciprocal",
|
||
|
"reflection_pad",
|
||
|
"relu",
|
||
|
"relu6",
|
||
|
"remainder",
|
||
|
"repeat_interleave",
|
||
|
"repeat",
|
||
|
"replication_pad",
|
||
|
"reshape_as",
|
||
|
"reshape",
|
||
|
"roll",
|
||
|
"rrelu",
|
||
|
"rsqrt",
|
||
|
"rsub",
|
||
|
"scalar_tensor",
|
||
|
"scatter_add",
|
||
|
"scatter",
|
||
|
"select",
|
||
|
"selu",
|
||
|
"sigmoid",
|
||
|
"sign",
|
||
|
"silu",
|
||
|
"sin",
|
||
|
"size",
|
||
|
"slice",
|
||
|
"softmax",
|
||
|
"softplus",
|
||
|
"softshrink",
|
||
|
"sort",
|
||
|
"split_with_sizes",
|
||
|
"split",
|
||
|
"sqrt",
|
||
|
"square",
|
||
|
"squeeze",
|
||
|
"stack",
|
||
|
"std_mean",
|
||
|
"std",
|
||
|
"sub",
|
||
|
"t",
|
||
|
"take",
|
||
|
"tan",
|
||
|
"tanh",
|
||
|
"tanhshrink",
|
||
|
"tensor",
|
||
|
"threshold",
|
||
|
"to",
|
||
|
"topk",
|
||
|
"transpose",
|
||
|
"true_divide",
|
||
|
"type_as",
|
||
|
"unbind",
|
||
|
"unfold",
|
||
|
"unsafe_chunk",
|
||
|
"unsafe_split_with_sizes",
|
||
|
"unsafe_split",
|
||
|
"unsqueeze",
|
||
|
"unsupported_complex_operators",
|
||
|
"noop_complex_operators",
|
||
|
"unused",
|
||
|
"var_mean",
|
||
|
"var",
|
||
|
"view_as",
|
||
|
"view",
|
||
|
"where",
|
||
|
"wrap_logical_op_with_cast_to",
|
||
|
"wrap_logical_op_with_negation",
|
||
|
"zeros_like",
|
||
|
"zeros",
|
||
|
"zero",
|
||
|
]
|
||
|
|
||
|
|
||
|
# Registration decorator used throughout this module: `registration.onnx_symbolic`
# with the `opset` keyword pre-bound to 9.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=9)
|
||
|
|
||
|
|
||
|
def _apply_params(*args, **kwargs):
|
||
|
"""Returns a decorator that calls the decorated (higher-order) function with the given parameters."""
|
||
|
|
||
|
def _apply(fn):
|
||
|
return fn(*args, **kwargs)
|
||
|
|
||
|
return _apply
|
||
|
|
||
|
|
||
|
def _export(name: str):
    """Build a decorator that publishes a function under ``name`` in this module.

    Applying the decorator binds the function to ``name`` in the module
    globals, appends ``name`` to ``__all__`` and returns the function unchanged.
    """

    def wrapper(func):
        module_namespace = globals()
        module_namespace[name] = func
        __all__.append(name)
        return func

    return wrapper
|
||
|
|
||
|
|
||
|
@_beartype.beartype
def unused(g):
    """Represents "missing" optional inputs.

    Emits a ``prim::Constant`` node through ``g`` and sets its type to the
    optional-tensor type before returning it.
    """
    placeholder = g.op("prim::Constant")
    optional_tensor_type = _C.OptionalType.ofTensor()
    placeholder.setType(optional_tensor_type)
    return placeholder
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::_shape_as_tensor")
@_beartype.beartype
def _shape_as_tensor(g: jit_utils.GraphContext, input):
    """Symbolic for ``aten::_shape_as_tensor``: emit a ``Shape`` op on ``input``."""
    shape_value = g.op("Shape", input)
    return shape_value
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::_reshape_from_tensor")
@_beartype.beartype
def _reshape_from_tensor(g: jit_utils.GraphContext, input, shape):
    """Symbolic for ``aten::_reshape_from_tensor``.

    When ``shape`` is a Python list its entries are concatenated along axis 0
    into a single shape value first; the reshape itself is delegated to
    :func:`reshape`.
    """
    target_shape = (
        g.op("Concat", *shape, axis_i=0) if isinstance(shape, list) else shape
    )
    return reshape(g, input, target_shape)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::reshape")
@symbolic_helper.quantized_args(True)
@_beartype.beartype
def reshape(g: jit_utils.GraphContext, self, shape):
    """Symbolic for ``aten::reshape``; delegates to ``symbolic_helper._reshape_helper``."""
    reshaped = symbolic_helper._reshape_helper(g, self, shape)
    return reshaped
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::reshape_as")
@symbolic_helper.quantized_args(True)
@_beartype.beartype
def reshape_as(g: jit_utils.GraphContext, self, other):
    """Symbolic for ``aten::reshape_as``: reshape ``self`` to ``Shape(other)``."""
    return reshape(g, self, g.op("Shape", other))
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::add")
@_beartype.beartype
def add(g: jit_utils.GraphContext, self, other, alpha=None):
    """Symbolic for ``aten::add``: ``self + alpha * other``.

    A tensor-list ``self`` is reported through
    ``symbolic_helper._onnx_opset_unsupported_detailed``. ``other`` is multiplied
    by ``alpha`` only when ``alpha`` is truthy and its scalar value differs
    from 1.
    """
    self_is_tensor_list = symbolic_helper._is_value(
        self
    ) and symbolic_helper._is_tensor_list(self)
    if self_is_tensor_list:
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "Add", 9, 11, "Add between list of tensors not supported", self
        )

    addend = other
    if alpha:
        alpha_scalar = symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha))
        if alpha_scalar != 1:
            addend = g.op("Mul", other, alpha)
    return g.op("Add", self, addend)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::sub")
@_beartype.beartype
def sub(g: jit_utils.GraphContext, self, other, alpha=None):
    """Symbolic for ``aten::sub``: ``self - alpha * other``.

    ``other`` is multiplied by ``alpha`` only when ``alpha`` is truthy and its
    scalar value differs from 1.
    """
    subtrahend = other
    if alpha:
        alpha_scalar = symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha))
        if alpha_scalar != 1:
            subtrahend = g.op("Mul", other, alpha)
    return g.op("Sub", self, subtrahend)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::rsub")
@_beartype.beartype
def rsub(g: jit_utils.GraphContext, self, other, alpha=None):
    """Symbolic for ``aten::rsub``: :func:`sub` with the two operands swapped."""
    minuend, subtrahend = other, self
    return sub(g, minuend, subtrahend, alpha=alpha)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::mul")
@_beartype.beartype
def mul(g: jit_utils.GraphContext, self, other):
    """Symbolic for ``aten::mul``.

    ONNX Mul doesn't support Boolean, so ``And`` is emitted as an equivalent
    operator when both operands are bool; every other case emits ``Mul``.
    """
    both_bool = symbolic_helper._is_bool(self) and symbolic_helper._is_bool(other)
    onnx_op_name = "And" if both_bool else "Mul"
    return g.op(onnx_op_name, self, other)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::div")
@_beartype.beartype
def div(g: jit_utils.GraphContext, self, other, *args):
    """Symbolic for ``aten::div``.

    Without extra positional arguments this is :func:`true_divide`; otherwise
    the extra arguments are forwarded to :func:`_div_rounding_mode`.
    """
    if args:
        return _div_rounding_mode(g, self, other, *args)
    return true_divide(g, self, other)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::addcmul")
@symbolic_helper.parse_args("v", "v", "v", "f")
@_beartype.beartype
def addcmul(g: jit_utils.GraphContext, self, tensor1, tensor2, value=1.0):
    """Symbolic for ``aten::addcmul``: ``self + (tensor1 * tensor2) * value``.

    ``value`` is materialised as a one-element ``Constant`` and the arithmetic
    is composed from this module's :func:`mul` and :func:`add` symbolics.
    """
    scale = g.op("Constant", value_t=torch.tensor([value]))
    product = mul(g, tensor1, tensor2)
    scaled_product = mul(g, product, scale)
    return add(g, self, scaled_product)
|
||
|
|
||
|
|
||
|
@symbolic_helper.parse_args("v", "v", "s")
@_beartype.beartype
def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode):
    """Dispatch a division to the symbolic matching ``rounding_mode``.

    ``None`` selects :func:`true_divide`, ``"floor"`` selects
    :func:`_floor_divide` and ``"trunc"`` selects :func:`_trunc_divide`.

    Raises:
        errors.SymbolicValueError: for any other ``rounding_mode``.
    """
    if rounding_mode is None:
        return true_divide(g, self, other)

    handlers = {"floor": _floor_divide, "trunc": _trunc_divide}
    handler = handlers.get(rounding_mode)
    if handler is None:
        raise errors.SymbolicValueError(
            f'Unsupported rounding mode: "{rounding_mode}". Expected None, "floor" or "trunc"',
            self,
        )
    return handler(g, self, other)
|
||
|
|
||
|
|
||
|
@_beartype.beartype
def _trunc_divide(g: jit_utils.GraphContext, self, other):
    """Emit a division whose result is truncated via a round-trip through INT64.

    The quotient is cast to INT64 and then cast to the output type:

    - ``self``'s own scalar type when it is known, except that
    - FLOAT is used when ``self`` is not floating point but ``other`` is, and
    - FLOAT is used when ``self``'s scalar type is unknown.
    """
    quotient = g.op("Div", self, other)
    # ONNX has no truncate op and Floor would behave differently for negative
    # numbers (e.g. -0.1 should become -0), so the truncation is expressed as
    # a cast to INT64.
    truncated = g.op("Cast", quotient, to_i=_C_onnx.TensorProtoDataType.INT64)

    # Pick the final type following the rules listed in the docstring.
    self_type = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.UNDEFINED
    )
    if self_type == _type_utils.JitScalarType.UNDEFINED:
        target_type = _C_onnx.TensorProtoDataType.FLOAT
    elif not symbolic_helper._is_fp(self) and symbolic_helper._is_fp(other):
        target_type = _C_onnx.TensorProtoDataType.FLOAT
    else:
        target_type = self_type.onnx_type()
    return g.op("Cast", truncated, to_i=target_type)
|
||
|
|
||
|
|
||
|
@_beartype.beartype
def _floor_divide(g: jit_utils.GraphContext, self, other):
    """Emit a division rounded towards negative infinity.

    If either operand is floating point the result is ``Floor`` of
    :func:`true_divide`. Otherwise ``Div`` is emitted on the operands and 1 is
    subtracted wherever the operands have opposite signs and the division
    leaves a non-zero remainder.
    """
    if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other):
        return g.op("Floor", true_divide(g, self, other))

    # Integer division does truncation rounding
    quotient = g.op("Div", self, other)

    # Division is negative if: self < 0 != other < 0
    zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
    self_negative = symbolic_helper._lt_helper(g, self, zero)
    other_negative = symbolic_helper._lt_helper(g, other, zero)
    signs_differ = g.op("Xor", self_negative, other_negative)

    # For negative numbers with self % other != 0, subtract 1 to round down instead of up
    reconstructed = g.op("Mul", quotient, other)
    remainder = g.op("Sub", self, reconstructed)
    remainder_is_zero = g.op("Equal", remainder, zero)
    has_remainder = g.op("Not", remainder_is_zero)
    needs_fixup = g.op("And", signs_differ, has_remainder)

    one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
    correction = g.op("Mul", needs_fixup, one)
    return g.op("Sub", quotient, correction)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::floor_divide")
@_beartype.beartype
def floor_divide(g: jit_utils.GraphContext, self, other):
    """Symbolic for ``aten::floor_divide``.

    Deprecated behavior: despite its name, floor_divide actually truncates,
    so this forwards to :func:`_trunc_divide`.
    """
    truncated = _trunc_divide(g, self, other)
    return truncated
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::floordiv")
@_beartype.beartype
def floordiv(g: jit_utils.GraphContext, self, other):
    """Symbolic for ``aten::floordiv``; forwards to :func:`floor_divide`."""
    result = floor_divide(g, self, other)
    return result
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::true_divide")
@_beartype.beartype
def true_divide(g: jit_utils.GraphContext, self, other):
    """Division where both inputs are cast to floating types

    If at least one input is floating, ``Div`` is emitted on the inputs as they
    are; implicit casting is left to the scalar type analysis pass.
    If neither input is a floating type, both inputs are cast to the default
    scalar type (FLOAT, or DOUBLE when the default dtype is ``torch.double``)
    before the ``Div``.
    """
    neither_is_fp = not (symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other))
    if neither_is_fp:
        # Case 2: neither is floating
        # Casts both inputs to the default scalar type
        default_dtype = torch.get_default_dtype()
        assert default_dtype is torch.float or default_dtype is torch.double
        if default_dtype is torch.double:
            cast_target = _C_onnx.TensorProtoDataType.DOUBLE
        else:
            cast_target = _C_onnx.TensorProtoDataType.FLOAT
        self = g.op("Cast", self, to_i=cast_target)
        other = g.op("Cast", other, to_i=cast_target)

    # Case 1 (either value floating) reaches this point with untouched inputs.
    return g.op("Div", self, other)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::reciprocal")
@_beartype.beartype
def reciprocal(g: jit_utils.GraphContext, self):
    """Symbolic for ``aten::reciprocal``.

    torch.reciprocal implicitly casts to float, so a non-floating-point input
    is cast to FLOAT before the ``Reciprocal`` op.
    """
    operand = (
        self
        if symbolic_helper._is_fp(self)
        else g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    )
    return g.op("Reciprocal", operand)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::cat")
@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def cat(g: jit_utils.GraphContext, tensor_list, dim):
    """Symbolic for ``aten::cat``: ``Concat`` of the list's tensors along ``dim``.

    Constant inputs whose first dimension has a falsy size are dropped, and the
    node producing ``tensor_list`` is rewritten in place so that it only holds
    the remaining inputs.
    """
    # torch.cat ignores empty tensors such as `torch.Tensor([])`
    # These needs to be removed as input from ONNX's concat too, otherwise shape inference
    # will likely fail due to inputs with different ranks (0 for empty tensor, > 0 for anything else)
    nonempty_tensors = [
        candidate
        for candidate in symbolic_helper._unpack_list(tensor_list)
        if not symbolic_helper._is_constant(candidate)
        or symbolic_helper._get_tensor_dim_size(candidate, 0)
    ]
    assert len(nonempty_tensors) > 0

    # Every known rank has to agree with the rank of the first kept tensor.
    rank_of = symbolic_helper._get_tensor_rank
    reference = nonempty_tensors[0]
    assert all(
        rank_of(reference) is None
        or rank_of(candidate) is None
        or rank_of(candidate) == rank_of(reference)
        for candidate in nonempty_tensors
    )

    # Replace the inputs of the list-producing node with the kept tensors.
    list_node = tensor_list.node()
    list_node.removeAllInputs()
    for kept in nonempty_tensors:
        list_node.addInput(kept)

    concat_inputs = symbolic_helper._unpack_list(tensor_list)
    return g.op("Concat", *concat_inputs, axis_i=dim)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::stack")
@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def stack(g: jit_utils.GraphContext, tensor_list, dim):
    """Symbolic for ``aten::stack``.

    Each tensor of the list is unsqueezed at ``dim`` and the results are
    concatenated along that same axis.
    """
    unsqueezed = []
    for element in symbolic_helper._unpack_list(tensor_list):
        unsqueezed.append(symbolic_helper._unsqueeze_helper(g, element, [dim]))
    return g.op("Concat", *unsqueezed, axis_i=dim)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::list")
@_beartype.beartype
def _list(g: jit_utils.GraphContext, self):
    """Symbolic for ``aten::list``: a pass-through that returns ``self`` untouched."""
    passthrough = self
    return passthrough
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::mm")
@_beartype.beartype
def mm(g: jit_utils.GraphContext, self, other):
    """Symbolic for ``aten::mm``, emitted as ``Gemm`` with ``alpha=1`` and ``beta=0``."""
    # Gemm takes a third input C. A dummy one-element constant is supplied only
    # to satisfy the op's signature; it is passed together with beta = 0.
    dummy_c = g.op("Constant", value_t=torch.tensor([1]))
    return g.op("Gemm", self, other, dummy_c, beta_f=0.0, alpha_f=1.0)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::bmm")
@_beartype.beartype
def bmm(g: jit_utils.GraphContext, self, other):
    """Symbolic for ``aten::bmm``: emit a ``MatMul`` op on the two operands."""
    product = g.op("MatMul", self, other)
    return product
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::matmul")
@_beartype.beartype
def matmul(g: jit_utils.GraphContext, self, other):
    """Symbolic for ``aten::matmul``: emit a ``MatMul`` op on the two operands."""
    product = g.op("MatMul", self, other)
    return product
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::addmm")
@symbolic_helper.parse_args("v", "v", "v", "t", "t")
@_beartype.beartype
def addmm(g: jit_utils.GraphContext, self, mat1, mat2, beta, alpha):
    """Symbolic for ``aten::addmm``: ``beta * self + alpha * (mat1 @ mat2)``.

    Args:
        g: Graph context used to emit ops.
        self: The addend.
        mat1: Left matrix operand.
        mat2: Right matrix operand.
        beta: Scale applied to ``self`` (parsed as a tensor).
        alpha: Scale applied to the matrix product (parsed as a tensor).

    Returns:
        A single ``Gemm`` op in the general case. When a scalar type can be
        determined from the inputs and either matrix has a known rank other
        than 2, the result is composed from ``MatMul``, optional ``Mul``
        scalings and ``Add`` instead.
    """
    # First available scalar type among self, mat1 and mat2 (in that order).
    candidate_types = (
        symbolic_helper._try_get_scalar_type(self),
        symbolic_helper._try_get_scalar_type(mat1),
        symbolic_helper._try_get_scalar_type(mat2),
    )
    scalar_type = next(
        (candidate for candidate in candidate_types if candidate is not None), None
    )

    mat1_rank = symbolic_helper._get_tensor_rank(mat1)
    mat2_rank = symbolic_helper._get_tensor_rank(mat2)
    has_non_2d_operand = any(
        rank is not None and rank != 2 for rank in (mat1_rank, mat2_rank)
    )

    if scalar_type is not None and has_non_2d_operand:
        product = g.op("MatMul", mat1, mat2)
        addend = self

        alpha_value = symbolic_helper._scalar(alpha)
        beta_value = symbolic_helper._scalar(beta)

        if alpha_value != 1:
            alpha_const = g.op(
                "Constant",
                value_t=torch.tensor(alpha_value, dtype=scalar_type.dtype()),
            )
            product = g.op("Mul", product, alpha_const)
        if beta_value != 1:
            # `beta_value` is already a converted scalar; it must not be run
            # through `_scalar` a second time (mirrors the alpha handling).
            beta_const = g.op(
                "Constant",
                value_t=torch.tensor(beta_value, dtype=scalar_type.dtype()),
            )
            addend = g.op("Mul", addend, beta_const)

        return g.op("Add", product, addend)

    return g.op(
        "Gemm",
        mat1,
        mat2,
        self,
        beta_f=symbolic_helper._scalar(beta),
        alpha_f=symbolic_helper._scalar(alpha),
    )
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::neg")
@_beartype.beartype
def neg(g: jit_utils.GraphContext, self):
    """Symbolic for ``aten::neg``: emit a ``Neg`` op on ``self``."""
    negated = g.op("Neg", self)
    return negated
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::sqrt")
@_beartype.beartype
def sqrt(g: jit_utils.GraphContext, self):
    """Symbolic for ``aten::sqrt``.

    torch converts all int inputs to sqrt to float, so inputs whose scalar type
    is UINT8, INT8, INT16, INT or INT64 are cast to FLOAT before ``Sqrt``.
    """
    integral_types = {
        _type_utils.JitScalarType.UINT8,
        _type_utils.JitScalarType.INT8,
        _type_utils.JitScalarType.INT16,
        _type_utils.JitScalarType.INT,
        _type_utils.JitScalarType.INT64,
    }
    input_type = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.UNDEFINED
    )
    if input_type in integral_types:
        self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT)

    return g.op("Sqrt", self)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::rsqrt")
@_beartype.beartype
def rsqrt(g: jit_utils.GraphContext, self):
    """Symbolic for ``aten::rsqrt``: ``Div`` of a ones tensor by :func:`sqrt` of ``self``.

    The numerator is ``torch.ones(1)`` passed through
    ``symbolic_helper._if_scalar_type_as`` together with ``self``.
    """
    numerator = symbolic_helper._if_scalar_type_as(torch.ones(1), self)
    denominator = sqrt(g, self)
    return g.op("Div", numerator, denominator)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::tanh")
# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qtanh.cpp
@symbolic_helper.quantized_args(True, scale=2.0 / 256.0, zero_point=128)
@_beartype.beartype
def tanh(g: jit_utils.GraphContext, self):
    """Symbolic for ``aten::tanh``: emit a ``Tanh`` op on ``self``."""
    activated = g.op("Tanh", self)
    return activated
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::sin")
@_beartype.beartype
def sin(g: jit_utils.GraphContext, self):
    """Symbolic for ``aten::sin``: emit a ``Sin`` op on ``self``."""
    result = g.op("Sin", self)
    return result
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::cos")
@_beartype.beartype
def cos(g: jit_utils.GraphContext, self):
    """Symbolic for ``aten::cos``: emit a ``Cos`` op on ``self``."""
    result = g.op("Cos", self)
    return result
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::tan")
@_beartype.beartype
def tan(g: jit_utils.GraphContext, self):
    """Symbolic for ``aten::tan``: emit a ``Tan`` op on ``self``."""
    result = g.op("Tan", self)
    return result
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::asin")
@_beartype.beartype
def asin(g: jit_utils.GraphContext, self):
    """Symbolic for ``aten::asin``: emit an ``Asin`` op on ``self``."""
    result = g.op("Asin", self)
    return result
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::acos")
@_beartype.beartype
def acos(g: jit_utils.GraphContext, self):
    """Symbolic for ``aten::acos``: emit an ``Acos`` op on ``self``."""
    result = g.op("Acos", self)
    return result
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::atan")
@_beartype.beartype
def atan(g: jit_utils.GraphContext, self):
    """Symbolic for ``aten::atan``: emit an ``Atan`` op on ``self``."""
    result = g.op("Atan", self)
    return result
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::atan2")
@_beartype.beartype
def atan2(g: jit_utils.GraphContext, self, other):
    """Symbolic for ``aten::atan2`` built from ``Atan`` plus a quadrant correction.

    ``self`` is the y coordinate and ``other`` the x coordinate.
    ``Atan(self / other)`` is returned where ``other < 0`` does not hold;
    elsewhere it is shifted by ``+pi`` where ``self > 0`` and by ``-pi``
    otherwise.
    """
    # NOTE(review): for other < 0 and self == 0 the "-pi" branch is selected —
    # confirm this is the intended result for a positive-zero y.
    ratio = g.op("Div", self, other)
    base_angle = g.op("Atan", ratio)
    zero = g.op("Constant", value_t=torch.tensor(0))
    pi = g.op("Constant", value_t=torch.tensor(math.pi))

    # Correction used when x is negative (second or third quadrant).
    y_is_positive = g.op("Greater", self, zero)
    angle_plus_pi = g.op("Add", base_angle, pi)
    angle_minus_pi = g.op("Sub", base_angle, pi)
    corrected_angle = g.op("Where", y_is_positive, angle_plus_pi, angle_minus_pi)

    # Keep the plain arctangent unless x is negative.
    x_is_negative = g.op("Less", other, zero)
    return g.op("Where", x_is_negative, corrected_angle, base_angle)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::sigmoid")
# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qsigmoid.cpp
@symbolic_helper.quantized_args(True, scale=1.0 / 256.0, zero_point=0)
@_beartype.beartype
def sigmoid(g: jit_utils.GraphContext, self):
    """Symbolic for ``aten::sigmoid``: emit a ``Sigmoid`` op on ``self``."""
    activated = g.op("Sigmoid", self)
    return activated
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::sign")
@_beartype.beartype
def sign(g: jit_utils.GraphContext, self):
    """Symbolic for ``aten::sign``: emit a ``Sign`` op on ``self``."""
    result = g.op("Sign", self)
    return result
|
||
|
|
||
|
|
||
|
@symbolic_helper.quantized_args(True)
@_beartype.beartype
def _slice(g: jit_utils.GraphContext, input, axes, starts, ends):
    """Emit a ``Slice`` op with attribute-style ``axes``/``starts``/``ends``.

    ``starts`` and ``ends`` must have the same length. A single slice running
    from 0 to ``_constants.INT64_MAX`` is a no-op, so ``input`` is returned
    directly in that case.
    """
    assert len(starts) == len(ends)
    covers_everything = (
        len(starts) == 1 and starts[0] == 0 and ends[0] == _constants.INT64_MAX
    )
    if covers_everything:
        return input
    return g.op("Slice", input, axes_i=axes, starts_i=starts, ends_i=ends)
|
||
|
|
||
|
|
||
|
@_beartype.beartype
def _maybe_cast_reduce_op_input(g: jit_utils.GraphContext, self):
    """Cast the input of a reduce op to INT64 when it is a narrower non-float type.

    The value is returned unchanged when its scalar type is unknown, when it is
    floating point, or when it is already INT64.
    """
    input_type = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.UNDEFINED
    )
    # This check only covers traced modules where dtype is present
    if input_type == _type_utils.JitScalarType.UNDEFINED:
        return self
    if symbolic_helper._is_fp(self) or input_type == _type_utils.JitScalarType.INT64:
        return self
    # pytorch reduce-ops cast all other integral types to int64
    return g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.INT64)
|
||
|
|
||
|
|
||
|
@_beartype.beartype
def _reduce_op_symbolic(onnx_op_name, allow_multi_dim_support=True):
    """Build the symbolic function for the reduction op ``onnx_op_name``.

    Args:
        onnx_op_name: Name of the ONNX op to emit.
        allow_multi_dim_support: When true ``dim`` is read as an int list
            (``"is"``); otherwise as a single int (``"i"``) that is wrapped
            into a one-element list for the ``axes`` attribute.

    Returns:
        ``symbolic(g, self, dim=None, keepdim=None)``.
    """
    dim_descriptor = "is" if allow_multi_dim_support else "i"

    @_beartype.beartype
    def symbolic(g, self, dim=None, keepdim=None):
        self = _maybe_cast_reduce_op_input(g, self)
        # Dim can be 0, which will cause (not dim) == True. So we don't want to do
        # (not dim)
        if dim is None or dim == ():
            # all-reduce path
            return symbolic_helper._handle_reduce_dim_none(g, self, onnx_op_name)

        # dim-reduce path
        dim = symbolic_helper._get_const(dim, dim_descriptor, "dim")
        keepdim = symbolic_helper._get_const(keepdim, "i", "keepdim")
        axes = dim if allow_multi_dim_support else [dim]
        return g.op(onnx_op_name, self, axes_i=axes, keepdims_i=keepdim)

    return symbolic
|
||
|
|
||
|
|
||
|
@_beartype.beartype
def overload_by_arg_count(fn):
    """Turn ``fn`` into a dispatcher over the overloads it returns.

    ``fn(g, *args)`` must return an iterable of callables, each carrying an
    ``_arg_descriptors`` attribute. The first one whose descriptor count equals
    the number of positional arguments following ``g`` is called with
    ``(g, *args)``. If none matches, the result of
    ``symbolic_helper._unimplemented`` is returned.
    """

    @functools.wraps(fn)
    @_beartype.beartype
    def wrapper(g, *args):
        arg_count = len(args)
        candidates = fn(g, *args)
        matching = next(
            (
                candidate
                for candidate in candidates
                if len(candidate._arg_descriptors) == arg_count
            ),
            None,
        )
        if matching is not None:
            return matching(g, *args)
        return symbolic_helper._unimplemented(
            f"aten::{fn.__name__}", f"with {len(args)} arguments"
        )

    return wrapper
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::sum", decorate=[_apply_params("ReduceSum", "sum")])
@_onnx_symbolic("aten::mean", decorate=[_apply_params("ReduceMean", "mean")])
# torch.prod does not support multidimensional "dim"
@_onnx_symbolic(
    "aten::prod",
    decorate=[_apply_params("ReduceProd", "prod", allow_multi_dim_support=False)],
)
@_beartype.beartype
def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = True):
    """Build the dtype-aware symbolic for a reduction op.

    Args:
        onnx_op: Name of the ONNX reduction op to emit (e.g. ``"ReduceSum"``).
        name: Operator name reported through ``symbolic_helper._unimplemented``.
        allow_multi_dim_support: Whether ``dim`` is parsed as an int list
            (``"is"``) or as a single int (``"i"``).

    Returns:
        A symbolic that dispatches on argument count between the
        ``(self, dtype)`` and ``(self, dim, keepdim, dtype)`` overloads. In
        both, a ``dtype`` given as ``onnx::Constant`` casts the input before
        the reduction and, if the result's type differs, the result afterwards;
        a ``dtype`` that is neither ``onnx::Constant`` nor ``prim::Constant``
        is reported as unimplemented.
    """
    symbolic = _reduce_op_symbolic(
        onnx_op, allow_multi_dim_support=allow_multi_dim_support
    )

    @overload_by_arg_count
    def reduce(g, *args, **kwargs):
        dim_desc = "is" if allow_multi_dim_support else "i"

        @symbolic_helper.quantized_args(True)
        @symbolic_helper.parse_args("v", "none")
        def reduce_nodim(g, self, dtype):
            dtype_kind = dtype.node().kind()
            dtype_onnx = None
            if dtype_kind == "onnx::Constant":
                dtype = symbolic_helper._get_const(dtype, "i", "dtype")
                dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type()
                self = g.op("Cast", self, to_i=dtype_onnx)
            elif dtype_kind != "prim::Constant":
                return symbolic_helper._unimplemented(name, "dtype", dtype)

            result = symbolic(g, self)
            if dtype_onnx is None:
                return result
            # Cast back only when the reduction changed the element type.
            produced_onnx = _type_utils.JitScalarType.from_value(result).onnx_type()
            if produced_onnx != dtype_onnx:
                result = g.op("Cast", result, to_i=dtype_onnx)
            return result

        @symbolic_helper.quantized_args(True)
        @symbolic_helper.parse_args("v", dim_desc, "i", "none")  # type: ignore[arg-type]
        def reduce_dim(g, self, dim, keepdim, dtype):
            dtype_kind = dtype.node().kind()
            dtype_onnx = None
            if dtype_kind == "onnx::Constant":
                dtype = symbolic_helper._get_const(dtype, "i", "dtype")
                dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type()
                self = g.op("Cast", self, to_i=dtype_onnx)
            elif dtype_kind != "prim::Constant":
                return symbolic_helper._unimplemented(name, "dtype", dtype)

            result = symbolic(g, self, dim, keepdim)
            if dtype_onnx is None:
                return result
            # Cast back only when the reduction changed the element type.
            produced_onnx = _type_utils.JitScalarType.from_value(result).onnx_type()
            if produced_onnx != dtype_onnx:
                result = g.op("Cast", result, to_i=dtype_onnx)
            return result

        return reduce_nodim, reduce_dim

    return reduce
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::cumsum")
@symbolic_helper.parse_args("v", "i", "none")
@_beartype.beartype
def cumsum(g: jit_utils.GraphContext, input, dim, dtype):
    """Symbolic for ``aten::cumsum`` in opset 9.

    With the Caffe2 ATen fallback enabled the op is emitted through ``g.at``
    (a ``dtype`` that is not a ``prim::Constant`` node is reported as
    unimplemented). Otherwise the op is reported as unsupported in opset 9,
    pointing at opset 11.
    """
    if symbolic_helper.is_caffe2_aten_fallback():
        if dtype.node().kind() != "prim::Constant":
            return symbolic_helper._unimplemented("cumsum", "dtype", dtype)
        return g.at("cumsum", input, dim_i=dim)

    # Propagate the helper's outcome like the sibling symbolics do instead of
    # falling through to an implicit ``None``.
    return symbolic_helper._onnx_opset_unsupported("cumsum", 9, 11, input)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::_sample_dirichlet")
@_beartype.beartype
def _sample_dirichlet(g: jit_utils.GraphContext, self, generator):
    """Symbolic for ``aten::_sample_dirichlet``.

    Only exportable through the Caffe2 ATen fallback, and only when
    ``generator`` is none; every other case is reported through the
    ``symbolic_helper`` unsupported/unimplemented helpers.
    """
    if not symbolic_helper.is_caffe2_aten_fallback():
        return symbolic_helper._onnx_unsupported("_sample_dirichlet", self)

    if not symbolic_helper._is_none(generator):
        return symbolic_helper._unimplemented(
            "_sample_dirichlet", "We are not able to export generator", self
        )
    return g.at("_sample_dirichlet", self)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::_standard_gamma")
@_beartype.beartype
def _standard_gamma(g: jit_utils.GraphContext, self, generator):
    """Symbolic for ``aten::_standard_gamma``.

    Only exportable through the Caffe2 ATen fallback, and only when
    ``generator`` is none; every other case is reported through the
    ``symbolic_helper`` unsupported/unimplemented helpers.
    """
    if not symbolic_helper.is_caffe2_aten_fallback():
        return symbolic_helper._onnx_unsupported("_standard_gamma", self)

    if not symbolic_helper._is_none(generator):
        return symbolic_helper._unimplemented(
            "_standard_gamma", "not able to export generator", self
        )
    return g.at("_standard_gamma", self)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::t")
@_beartype.beartype
def t(g: jit_utils.GraphContext, self):
    """Symbolic for ``aten::t``.

    A tensor of known rank 2 or more gets a ``Transpose`` with ``perm=(1, 0)``.
    When the rank is unknown or below 2, an ``Identity`` node is emitted.
    """
    rank = symbolic_helper._get_tensor_rank(self)
    if rank is not None and rank >= 2:
        return g.op("Transpose", self, perm_i=(1, 0))

    # The transpose of a 1d or 0d tensor is itself. ONNX does not define the behavior
    # clearly and onnxruntime fails on these cases. So we add an Identity node to
    # mirror the behavior of eager mode.
    return g.op("Identity", self)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::numpy_T")
@symbolic_helper.quantized_args(True)
@_beartype.beartype
def numpy_T(g: jit_utils.GraphContext, input):
    """Symbolic for ``aten::numpy_T``: ``Transpose`` with all axes reversed.

    The rank of ``input`` must be known (asserted).
    """
    ndim = symbolic_helper._get_tensor_rank(input)
    assert ndim is not None
    reversed_axes = list(range(ndim))[::-1]
    return g.op("Transpose", input, perm_i=reversed_axes)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::expand")
@symbolic_helper.quantized_args(True)
@_beartype.beartype
def expand(g: jit_utils.GraphContext, self, size, implicit):
    """Symbolic for ``aten::expand``: emit an ``Expand`` op on ``self``.

    A constant ``size`` becomes an int64 ``Constant``. A packed-list ``size``
    is stacked and flattened into a 1-D shape in which every ``-1`` entry is
    replaced by ``1``. ``implicit`` is not used.
    """
    size = symbolic_helper._maybe_get_const(size, "is")
    if symbolic_helper._is_value(size):
        if symbolic_helper._is_packed_list(size):
            # Expand with -1 dim value means dim is unchanged.
            # Since onnx::expand supports two-way broadcasting,
            # -1 dim value can be exported to onnx as 1
            stacked = stack(g, size, 0)
            flat_shape = g.op("Constant", value_t=torch.tensor([-1]))
            size = symbolic_helper._reshape_helper(g, stacked, flat_shape)
            ones = ones_like(g, size, _type_utils.JitScalarType.INT64)
            minus_one = g.op("Constant", value_t=torch.tensor(-1))
            neg_ones = mul(g, ones, minus_one)
            is_minus_one = g.op("Equal", size, neg_ones)
            size = where(g, is_minus_one, ones, size)
    else:
        size = g.op("Constant", value_t=torch.LongTensor(size))
    return g.op("Expand", self, size)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::broadcast_to")
@symbolic_helper.quantized_args(True)
@_beartype.beartype
def broadcast_to(g: jit_utils.GraphContext, self, size):
    """Symbolic for ``aten::broadcast_to``: emit an ``Expand`` op on ``self``.

    A constant ``size`` becomes an int64 ``Constant``. A packed-list ``size``
    is stacked and flattened into a 1-D shape in which every ``-1`` entry is
    replaced by ``1``.
    """
    size = symbolic_helper._maybe_get_const(size, "is")
    if symbolic_helper._is_value(size):
        if symbolic_helper._is_packed_list(size):
            # Expand with -1 dim value means dim is unchanged.
            # Since onnx::expand supports two-way broadcasting,
            # -1 dim value can be exported to onnx as 1
            stacked = stack(g, size, 0)
            flat_shape = g.op("Constant", value_t=torch.tensor([-1]))
            size = symbolic_helper._reshape_helper(g, stacked, flat_shape)
            ones = ones_like(g, size, _type_utils.JitScalarType.INT64)
            minus_one = g.op("Constant", value_t=torch.tensor(-1))
            neg_ones = mul(g, ones, minus_one)
            is_minus_one = g.op("Equal", size, neg_ones)
            size = where(g, is_minus_one, ones, size)
    else:
        size = g.op("Constant", value_t=torch.LongTensor(size))
    return g.op("Expand", self, size)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::expand_as")
@symbolic_helper.quantized_args(True, True)
@_beartype.beartype
def expand_as(g: jit_utils.GraphContext, self, other):
    """Emit an ONNX ``Expand`` of ``self`` to the runtime shape of ``other``.

    When ``self`` is a constant tensor, every dimension along which it is
    uniform (equal to its own mean broadcast back) is collapsed so that a
    smaller Constant is expanded instead.
    """
    const_self = symbolic_helper._maybe_get_const(self, "t")
    if isinstance(const_self, torch.Tensor):
        source_dtype = const_self.dtype
        as_double = const_self.to(torch.double)
        uniform_dims = []
        for axis in range(as_double.dim()):
            rebroadcast = as_double.mean(axis).unsqueeze(axis).expand_as(as_double)
            if not torch.equal(rebroadcast, as_double):
                continue
            uniform_dims.append(axis)
            collapsed = as_double.mean(uniform_dims, keepdim=True).to(source_dtype)
            self = g.op("Constant", value_t=collapsed)

    other_shape = g.op("Shape", other)
    return g.op("Expand", self, other_shape)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::embedding")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "v", "i", "b", "v")
@_beartype.beartype
def embedding(
    g: jit_utils.GraphContext,
    weight,
    indices,
    padding_idx,
    scale_grad_by_freq,
    sparse,
):
    """Export ``aten::embedding`` as an ONNX ``Gather`` on ``weight``.

    Raises:
        errors.SymbolicValueError: if ``scale_grad_by_freq`` is set while
            exporting in training mode.
    """
    exporting_for_training = GLOBALS.export_training

    if scale_grad_by_freq and exporting_for_training:
        message = (
            "Unsupported: ONNX export of embedding with scale_grad_by_freq=True "
            "for training mode. ONNX does not support scaling the gradients."
        )
        raise errors.SymbolicValueError(message, weight)

    if padding_idx >= 0 and exporting_for_training:
        # Export still proceeds; the user is only warned.
        message = (
            "Warning: ONNX export of embedding with padding_idx >= 0 "
            "for training mode. "
            "ONNX does not support not updating the embedding vector at padding_idx during training."
        )
        warnings.warn(message)

    return g.op("Gather", weight, indices)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::embedding_bag")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
@_beartype.beartype
def embedding_bag(
    g: jit_utils.GraphContext,
    embedding_matrix,
    indices,
    offsets,
    scale_grad_by_freq,
    mode,
    sparse,
    per_sample_weights,
    include_last_offset,
    padding_idx,
):
    """Export ``aten::embedding_bag``.

    Only the Caffe2 ATen fallback path produces a node (with four outputs);
    every other configuration is reported as unsupported.
    """
    if not symbolic_helper._is_none(per_sample_weights):
        return symbolic_helper._onnx_unsupported(
            "embedding_bag with per_sample_weights"
        )

    if not symbolic_helper.is_caffe2_aten_fallback():
        return symbolic_helper._onnx_unsupported("embedding_bag", embedding_matrix)

    aten_attributes = {
        "outputs": 4,
        "scale_grad_by_freq_i": scale_grad_by_freq,
        "mode_i": mode,
        "sparse_i": sparse,
        "include_last_offset_i": include_last_offset,
        "padding_idx_i": padding_idx,
    }
    return g.at("embedding_bag", embedding_matrix, indices, offsets, **aten_attributes)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::size")
@symbolic_helper.quantized_args(True, quantize_output=False)
@_beartype.beartype
def size(g: jit_utils.GraphContext, self, dim=None):
    """Export ``aten::size``: the full ``Shape`` or the size of one dimension.

    A negative ``dim`` is rewritten to ``dim + rank`` as a Constant when the
    rank of ``self`` is known; otherwise ``dim`` is passed through unchanged.
    """
    if dim is None:
        return g.op("Shape", self)

    dim_value = symbolic_helper._maybe_get_const(dim, "i")
    if dim_value < 0:
        rank = symbolic_helper._get_tensor_rank(self)
        if rank is not None:
            dim = g.op("Constant", value_t=torch.tensor(dim_value + rank))
    return symbolic_helper._size_helper(g, self, dim)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::transpose")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "i", "i")
@_beartype.beartype
def transpose(g: jit_utils.GraphContext, self, dim0, dim1):
    """Export ``aten::transpose`` as an ONNX ``Transpose`` swapping two axes.

    Raises:
        errors.SymbolicValueError: if the rank of ``self`` is unknown and the
            Caffe2 ATen fallback is not active.
    """
    if dim0 == dim1:  # micro-optimization
        return self

    rank = symbolic_helper._get_tensor_rank(self)
    if rank is None:
        if symbolic_helper.is_caffe2_aten_fallback():
            # Without dim information a permutation cannot be emitted,
            # so fall back to ATen.
            return g.at(
                "transpose", self, overload_name="int", dim0_i=dim0, dim1_i=dim1
            )
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of transpose for tensor of unknown rank.",
            self,
        )

    # NB: Transpose in ONNX is actually a Permute
    permutation = list(range(rank))
    permutation[dim0], permutation[dim1] = permutation[dim1], permutation[dim0]
    return g.op("Transpose", self, perm_i=permutation)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::permute")
@symbolic_helper.parse_args("v", "is")
@_beartype.beartype
def permute(g: jit_utils.GraphContext, self, dims):
    """Export ``aten::permute`` as ``Transpose``; an identity order is a no-op."""
    identity_order = list(range(len(dims)))
    if dims != identity_order:
        return g.op("Transpose", self, perm_i=dims)
    return self
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::view")
@symbolic_helper.quantized_args(True)
@_beartype.beartype
def view(g: jit_utils.GraphContext, self, size):
    """Export ``aten::view`` by delegating to ``reshape``."""
    reshaped = reshape(g, self, size)
    return reshaped
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::view_as")
@_beartype.beartype
def view_as(g: jit_utils.GraphContext, self, other):
    """Export ``aten::view_as``: reshape ``self`` to the ``Shape`` of ``other``."""
    target_shape = g.op("Shape", other)
    reshaped = reshape(g, self, target_shape)
    return reshaped
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::unsafe_chunk")
@symbolic_helper.parse_args("v", "i", "i", "i")
@_beartype.beartype
def unsafe_chunk(g: jit_utils.GraphContext, self, chunks, dim, _outputs=None):
    """Export ``aten::unsafe_chunk`` as an ONNX ``Split`` with static sizes.

    Requires a known number of outputs and a known size along ``dim``.
    """
    if _outputs is None:
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "unsafe_chunk", 9, 11, "Dynamic number of outputs not supported", self
        )

    dim_extent = symbolic_helper._get_tensor_dim_size(self, dim)
    if dim_extent is None:
        return symbolic_helper._unimplemented(
            "unsafe_chunk", "unknown dimension size", self
        )

    # Ceil-divide so that every chunk but possibly the last has the same size.
    chunk_len = (dim_extent + chunks - 1) // chunks
    full_chunks, remainder = divmod(dim_extent, chunk_len)
    part_sizes = [chunk_len] * full_chunks
    if remainder:
        part_sizes.append(remainder)
    return g.op("Split", self, split_i=part_sizes, axis_i=dim, outputs=_outputs)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::split")
@symbolic_helper.parse_args("v", "v", "i", "i")
@_beartype.beartype
def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None):
    """Export ``aten::split`` as an ONNX ``Split`` with static sizes.

    A non-scalar ``split_size_or_sizes`` is forwarded to ``split_with_sizes``.
    For a scalar split size, the size along ``dim`` is read from the input or,
    if unknown, assumed to be ``split_size * _outputs``.
    """
    if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs):
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "split", 9, 11, "Dynamic number of outputs not supported", self
        )

    split_value = symbolic_helper._node_get(split_size_or_sizes.node(), "value")
    if split_value.dim() > 0:
        return split_with_sizes(g, self, split_size_or_sizes, dim, _outputs)
    chunk_len = symbolic_helper._get_const(split_size_or_sizes, "i", "split_size")

    dim_extent = symbolic_helper._get_tensor_dim_size(self, dim)
    if dim_extent is None:
        if _outputs is None:
            return symbolic_helper._onnx_opset_unsupported_detailed(
                "split", 9, 11, "Unknown dimension size not supported", self
            )
        dim_extent = chunk_len * _outputs

    full_chunks, remainder = divmod(dim_extent, chunk_len)
    part_sizes = [chunk_len] * full_chunks
    if remainder:
        part_sizes.append(remainder)
    return g.op("Split", self, split_i=part_sizes, axis_i=dim, outputs=_outputs)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::unsafe_split")
@_beartype.beartype
def unsafe_split(
    g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None
):
    """Export ``aten::unsafe_split`` exactly like ``split``."""
    pieces = split(g, self, split_size_or_sizes, dim, _outputs)
    return pieces
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::split_with_sizes")
@symbolic_helper.parse_args("v", "is", "i", "i")
@_beartype.beartype
def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None):
    """Export ``aten::split_with_sizes`` as an ONNX ``Split`` with given sizes."""
    if symbolic_helper._is_split_static(split_sizes, _outputs):
        return g.op("Split", self, split_i=split_sizes, axis_i=dim, outputs=_outputs)

    return symbolic_helper._onnx_opset_unsupported_detailed(
        "split_with_sizes", 9, 11, "Dynamic number of outputs not supported", self
    )
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::unsafe_split_with_sizes")
@_beartype.beartype
def unsafe_split_with_sizes(
    g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None
):
    """Export ``aten::unsafe_split_with_sizes`` exactly like ``split_with_sizes``."""
    pieces = split_with_sizes(g, self, split_sizes, dim, _outputs)
    return pieces
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::unbind")
@symbolic_helper.parse_args("v", "i", "i")
@_beartype.beartype
def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None):
    """Export ``aten::unbind``: split into size-1 slices along ``dim`` and squeeze each."""
    if _outputs is None:
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "unbind", 9, 11, "Dynamic number of outputs not supported", self
        )

    slices = g.op("Split", self, split_i=[1] * _outputs, axis_i=dim, outputs=_outputs)
    if _outputs == 1:
        # A single-output op yields one value rather than a sequence.
        slices = [slices]

    unbound = []
    for piece in slices:
        unbound.append(symbolic_helper._squeeze_helper(g, piece, [dim]))
    return unbound
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::select")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "i", "v")
@_beartype.beartype
def select(g: jit_utils.GraphContext, self, dim, index):
    """Export ``aten::select``.

    A constant negative index is exported as a one-element slice followed by
    a squeeze; everything else is exported as ``Gather`` along ``dim``.
    """
    index = symbolic_helper._maybe_get_scalar(index)
    if symbolic_helper._is_value(index) or not (index < 0):
        # FIXME(justinchuby): can index be an int and not a value?
        return g.op("Gather", self, index, axis_i=dim)

    # index == -1 selects the last element, so the slice must run to the end.
    stop = _constants.INT64_MAX if index == -1 else index + 1
    picked = symbolic_helper._slice_helper(
        g, self, axes=[dim], starts=[index], ends=[stop]
    )
    return symbolic_helper._squeeze_helper(g, picked, [dim])
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::square")
@_beartype.beartype
def square(g: jit_utils.GraphContext, self):
    """Export ``aten::square`` as ``Mul(self, self)``."""
    squared = g.op("Mul", self, self)
    return squared
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::squeeze")
@_beartype.beartype
def squeeze(g: jit_utils.GraphContext, self, dim=None):
    """Export ``aten::squeeze``.

    Without ``dim`` a bare ``Squeeze`` is emitted. With ``dim``, the axis is
    normalized if negative, and a squeeze on that axis is emitted unless the
    dimension is statically known to be larger than 1, in which case the input
    is returned untouched. A warning is issued on every ``dim`` path.
    """
    if dim is None:
        return g.op("Squeeze", self)

    squeeze_dim = symbolic_helper._get_const(dim, "i", "dim")

    # Handle negative dims
    if squeeze_dim < 0:
        rank = symbolic_helper._get_tensor_rank(self)
        if rank is None:
            return symbolic_helper._unimplemented(
                "squeeze", "negative axis with unknown input rank", self
            )
        normalized_dim = squeeze_dim + rank
        warnings.warn(
            f"ONNX export squeeze with negative axis {squeeze_dim}"
            " might cause the onnx model to be incorrect. "
            "Negative axis is not supported in ONNX. "
            f"Axis is converted to {normalized_dim}"
            " based on input shape at export time. "
            "Passing an tensor of different rank in execution will be incorrect."
        )
        squeeze_dim = normalized_dim

    intro = f"This model contains a squeeze operation on dimension {squeeze_dim}"
    dim_size = symbolic_helper._get_tensor_dim_size(self, squeeze_dim)

    if dim_size is not None and dim_size > 1:
        # Statically non-singleton: squeezing would be a no-op, so skip the node.
        warnings.warn(
            intro
            + ". The size of "
            "this dimension in the given input is "
            f"{dim_size}"
            ". The model will "
            "be exported without the squeeze node. If the model is intended to be used with dynamic "
            "input shapes, please use opset version 11 to "
            "export the model."
        )
        return self

    if dim_size is None:
        warnings.warn(
            intro
            + " on an input "
            "with unknown shape. Note that if the size of dimension "
            f"{squeeze_dim}"
            " of the input "
            "is not 1, the ONNX model will return an error. Opset version 11 supports squeezing on "
            "non-singleton dimensions, it is recommended to export this model using opset "
            "version 11 or higher."
        )
    else:
        warnings.warn(
            intro
            + ". If the model is "
            "intended to be used with dynamic input shapes, please use opset version 11 to export the model."
        )
    return symbolic_helper._squeeze_helper(g, self, axes_i=[squeeze_dim])
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::prelu")
@_beartype.beartype
def prelu(g: jit_utils.GraphContext, self, weight):
    """Export ``aten::prelu`` as an ONNX ``PRelu`` node.

    ``weight`` is reshaped where needed: for inputs of rank > 2 it gets
    trailing singleton axes so it broadcasts unidirectionally, and for a
    rank-0 input a ``[1]``-shaped weight is squeezed to a scalar.

    Raises:
        AssertionError: if both ranks are known and the input rank is smaller
            than the weight rank.
    """
    self_rank = symbolic_helper._get_tensor_rank(self)
    weight_sizes = symbolic_helper._get_tensor_sizes(weight)
    # The sizes may be unavailable (None); in that case the rank is unknown
    # too, and the rank check below is skipped instead of crashing in len().
    weight_rank = None if weight_sizes is None else len(weight_sizes)
    if self_rank is not None:
        if self_rank > 2:
            # make weight unidirectional broadcastable
            weight = symbolic_helper._unsqueeze_helper(
                g, weight, list(range(1, self_rank - 1))
            )
        elif self_rank == 0 and weight_sizes == [1]:
            # self and weight are both scalar but weight has rank == 1, squeeze weight.
            weight = symbolic_helper._squeeze_helper(g, weight, [0])
            weight_rank = 0

    if self_rank is not None and weight_rank is not None:
        assert (
            self_rank >= weight_rank
        ), f"rank(x) should be >= rank(slope) but got {self_rank} < {weight_rank}"
    return g.op("PRelu", self, weight)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::silu")
@_beartype.beartype
def silu(g: jit_utils.GraphContext, input):
    """Export ``aten::silu`` as ``input * Sigmoid(input)``."""
    gate = g.op("Sigmoid", input)
    return g.op("Mul", input, gate)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::mish")
@_beartype.beartype
def mish(g: jit_utils.GraphContext, input):
    """Export ``aten::mish`` as ``input * Tanh(Softplus(input))``."""
    soft = g.op("Softplus", input)
    gate = g.op("Tanh", soft)
    return g.op("Mul", input, gate)
|
||
|
|
||
|
|
||
|
@_beartype.beartype
def _op_with_optional_float_cast(g: jit_utils.GraphContext, op_name, *args, **kwargs):
    """Some PyTorch operators (e.g., Clip/Min/ReLU/Pad) are super set of ONNX in terms of data types.
    This function maximizes the exportability of PyTorch-ONNX by allowing ONNX-unsupported PyTorch
    operator data type. For example, `Cast<int>(Clip<float>(Cast<float>(INPUT)))` can be used to mimic
    `Clip<int>(INPUT)` (opset version < 12).

    Args:
        g (torch._C.Graph): graph to write the ONNX representation into.
        op_name (str): operator name in ONNX.
        *args (tuple): operands to the operator.
        **kwargs (dict): attributes to the operator along with "opset_before" (optional, None by default)
            indicating the smallest opset version to trigger such casting behavior and "target_float_t"
            (optional, torch.onnx.JitScalarType.FLOAT by default) indicating the data type of internal operator.

    Returns:
        The output(s) of ``g.op(op_name, ...)``, cast back to the dtype of the
        first operand when a float cast was inserted.

    Raises:
        errors.SymbolicValueError: if a cast is required and a complete-tensor
            operand has a dtype different from the first operand.
    """
    opset_before = kwargs.pop("opset_before", None)
    target_float_t = kwargs.pop("target_float_t", _type_utils.JitScalarType.FLOAT)

    inputs = list(args)
    dtype_0 = _type_utils.JitScalarType.from_value(inputs[0])

    # Cast only when the first operand is not floating point and the target
    # opset is older than ``opset_before`` (or no bound was given).
    require_cast = not symbolic_helper._is_fp(inputs[0]) and (
        opset_before is None or GLOBALS.export_onnx_opset_version < opset_before
    )

    if require_cast:
        # Validate every operand before emitting any Cast node.
        for operand in inputs:
            if operand.isCompleteTensor():
                operand_scalar_type = _type_utils.JitScalarType.from_value(operand)
                if operand_scalar_type != dtype_0:
                    raise errors.SymbolicValueError(
                        f"Inputs of {op_name} must have same dtype. "
                        f"Got {dtype_0.scalar_name()} and {operand_scalar_type.scalar_name()}",
                        operand,
                    )
        for i, operand in enumerate(inputs):
            if operand.isCompleteTensor() and not symbolic_helper._is_fp(operand):
                inputs[i] = g.op(
                    "Cast",
                    operand,
                    to_i=target_float_t.onnx_type(),
                )

    self = g.op(op_name, *inputs, **kwargs)

    if require_cast:
        # Restore the dtype of the first operand.
        self = g.op("Cast", self, to_i=dtype_0.onnx_type())

    return self
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::relu")
@symbolic_helper.quantized_args(True)
@_beartype.beartype
def relu(g: jit_utils.GraphContext, input):
    """Export ``aten::relu`` as ``Relu``, with the optional float cast below opset 14."""
    activated = _op_with_optional_float_cast(g, "Relu", input, opset_before=14)
    return activated
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::relu6")
@symbolic_helper.quantized_args(True)
@_beartype.beartype
def relu6(g: jit_utils.GraphContext, input):
    """Export ``aten::relu6`` as a clamp of ``input`` to the range [0, 6]."""
    lower_bound, upper_bound = 0, 6
    return clamp(g, input, lower_bound, upper_bound)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::ceil")
@_beartype.beartype
def ceil(g: jit_utils.GraphContext, input):
    """Export ``aten::ceil`` as an ONNX ``Ceil`` node."""
    rounded_up = g.op("Ceil", input)
    return rounded_up
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::floor")
@_beartype.beartype
def floor(g: jit_utils.GraphContext, input):
    """Export ``aten::floor`` as an ONNX ``Floor`` node."""
    rounded_down = g.op("Floor", input)
    return rounded_down
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::len")
@_beartype.beartype
def _len(g: jit_utils.GraphContext, self):
    """Export ``aten::len`` as the squeezed size of dimension 0."""
    first_dim = g.op("Constant", value_t=torch.LongTensor([0]))
    leading_size = size(g, self, first_dim)
    return symbolic_helper._squeeze_helper(g, leading_size, [0])
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::threshold")
@symbolic_helper.parse_args("v", "t", "t")
@_beartype.beartype
def threshold(g: jit_utils.GraphContext, self, threshold, value):
    """Export ``aten::threshold`` as ``Relu``; only threshold 0 and value 0 are supported."""
    # See Note [Export inplace]
    if symbolic_helper._scalar(threshold) != 0:
        unsupported_reason = "non-zero threshold"
    elif symbolic_helper._scalar(value) != 0:
        unsupported_reason = "non-zero value"
    else:
        return g.op("Relu", self)
    return symbolic_helper._unimplemented("threshold", unsupported_reason, self)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::leaky_relu")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "f", "b")
@_beartype.beartype
def leaky_relu(
    g: jit_utils.GraphContext,
    input: _C.Value,
    negative_slope: float,
    inplace: bool = False,
):
    """Export ``aten::leaky_relu`` as ``LeakyRelu`` with ``alpha = negative_slope``.

    ``inplace`` is not used by the export.
    """
    # See Note [Export inplace]
    activated = g.op("LeakyRelu", input, alpha_f=negative_slope)
    return activated
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::glu")
@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def glu(g: jit_utils.GraphContext, input, dim):
    """Export ``aten::glu``: split in two along ``dim`` and gate the first half.

    The result is ``first * Sigmoid(second)``. When the size along ``dim`` is
    known it is asserted to be even.
    """
    extent = symbolic_helper._get_tensor_dim_size(input, dim)
    assert extent is None or extent % 2 == 0

    value_half, gate_half = g.op("Split", input, axis_i=dim, outputs=2)
    gate = g.op("Sigmoid", gate_half)
    return g.op("Mul", value_half, gate)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::softmax")
@symbolic_helper.parse_args("v", "i", "none")
@_beartype.beartype
def softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
    """Export ``aten::softmax``.

    Softmax normalizes at vector level, but PyTorch and ONNX split the input
    into vectors differently, so ``dim`` and ``axis`` do not mean the same
    thing. PyTorch slices the input into vectors along the ``dim``-th
    dimension. ONNX reshapes the input into a 2-D tensor, and ``axis`` marks
    where the input is coerced. For a 2 x 3 input of all 1.0:

        with dim = 0 every entry of the result is 0.5;
        with axis = 0 every entry of the result is 0.167.

    The two agree only when both equal ``ndim - 1`` (the last dimension).
    So when the rank is known, ``Softmax`` is applied on the last axis,
    transposing ``dim`` to the end first and back afterwards when needed.
    When the rank is unknown at export time, softmax is computed with a
    subgraph of other operators.
    """

    def _cast_if_requested(value):
        # Append a Cast only when a non-constant ``dtype`` value was supplied.
        if dtype and dtype.node().kind() != "prim::Constant":
            parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
            return g.op(
                "Cast",
                value,
                to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type(),
            )
        return value

    rank = symbolic_helper._get_tensor_rank(input)
    if rank is None:
        # Apply max normalization.
        row_max = g.op("ReduceMax", input, axes_i=[dim], keepdims_i=1)
        shifted = g.op("Sub", input, row_max)

        exponentials = g.op("Exp", shifted)
        normalizer = symbolic_helper._reducesum_helper(g, exponentials, axes_i=[dim])
        return _cast_if_requested(g.op("Div", exponentials, normalizer))

    # TODO: remove this as onnx opset 11 spec allows negative axes
    if dim < 0:
        dim = rank + dim

    last_axis = rank - 1
    needs_transpose = dim != last_axis
    if needs_transpose:
        # Swap ``dim`` with the last axis; the same swap undoes itself later.
        swap_perm = list(range(rank))
        swap_perm[dim], swap_perm[-1] = swap_perm[-1], swap_perm[dim]
        input = g.op("Transpose", input, perm_i=swap_perm)
        dim = last_axis

    result = _cast_if_requested(g.op("Softmax", input, axis_i=dim))

    if needs_transpose:
        result = g.op("Transpose", result, perm_i=swap_perm)
    return result
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::softplus")
@_beartype.beartype
def softplus(g: jit_utils.GraphContext, self, beta, threshold):
    """Export ``aten::softplus``.

    With ``beta == 1`` a plain ``Softplus`` is emitted; otherwise the result
    is ``Softplus(self * beta) / beta``. ``threshold`` is not used.
    """
    beta_value = symbolic_helper._maybe_get_const(beta, "f")
    if beta_value != 1:
        scaled_input = g.op("Mul", self, beta)
        activated = g.op("Softplus", scaled_input)
        return g.op("Div", activated, beta)
    return g.op("Softplus", self)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::get_pool_ceil_padding")
@_beartype.beartype
def get_pool_ceil_padding(input, kernel_size, stride, padding):
    """Compute the extra per-axis padding used to emulate ``ceil_mode`` pooling.

    Only the trailing ``len(padding)`` dimensions of ``input`` are considered,
    and all of them must be statically known.

    Returns:
        A list with one int per pooled axis, or the result of
        ``symbolic_helper._unimplemented`` when the sizes are not accessible.
    """
    # TODO(justinchuby): Looks like this op is deprecated in torch
    sizes = symbolic_helper._get_tensor_sizes(input)
    dim = None if sizes is None else sizes[-len(padding) :]
    if dim is None or any(extent is None for extent in dim):
        return symbolic_helper._unimplemented(
            "get_pool_ceil_padding", "input size not accessible", input
        )

    padding_ceil = []
    for axis in range(len(padding)):
        extent = dim[axis]
        pad = padding[axis]
        kernel = kernel_size[axis]
        step = stride[axis]

        out_len = int(math.ceil((extent + 2 * pad - kernel) / float(step))) + 1
        # ensure last pooling starts inside
        if (out_len - 1) * step >= extent + pad:
            out_len -= 1

        if step == 1:
            extra = 0
        else:
            extra = kernel - (extent + 2 * pad - ((out_len - 1) * step + 1))

        # ensure padding is not > kernel_size
        if extra + 2 * pad >= kernel and not (extra < kernel - 1):
            extra = kernel - 1
        padding_ceil.append(int(extra))
    return padding_ceil
|
||
|
|
||
|
|
||
|
@_onnx_symbolic(
    "aten::max_pool1d",
    decorate=[
        _apply_params(
            "max_pool1d", torch.nn.modules.utils._single, 1, return_indices=False
        ),
        _export("max_pool1d"),
    ],
)
@_onnx_symbolic(
    "aten::max_pool2d",
    decorate=[
        _apply_params(
            "max_pool2d", torch.nn.modules.utils._pair, 2, return_indices=False
        ),
        _export("max_pool2d"),
    ],
)
@_onnx_symbolic(
    "aten::max_pool3d",
    decorate=[
        _apply_params(
            "max_pool3d", torch.nn.modules.utils._triple, 3, return_indices=False
        ),
        _export("max_pool3d"),
    ],
)
@_beartype.beartype
def _max_pool(name, tuple_fn, ndims, return_indices):
    """Build the symbolic function for an ``ndims``-dimensional max pool.

    Args:
        name: operator name used in "unimplemented" reports.
        tuple_fn: callable applied to the kernel/stride/padding/dilation arguments.
        ndims: number of pooled (spatial) dimensions.
        return_indices: whether the symbolic also returns the argmax indices.
    """

    @symbolic_helper.quantized_args(True, False, False, False, False, False)
    @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
    @_beartype.beartype
    def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
        if set(tuple_fn(dilation)) != {1}:
            return symbolic_helper._unimplemented(name, "dilation", input)
        if not stride:
            stride = kernel_size

        begin_pads = tuple(tuple_fn(padding))
        if ceil_mode:
            padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, begin_pads)
            end_pads = tuple(
                extra + base for (extra, base) in zip(padding_ceil, begin_pads)
            )
            pads = begin_pads + end_pads
        else:
            pads = begin_pads * 2

        pool_attrs = {
            "kernel_shape_i": tuple_fn(kernel_size),
            "pads_i": pads,
            "strides_i": tuple_fn(stride),
        }

        if not return_indices:
            pooled = g.op("MaxPool", input, outputs=1, **pool_attrs)
            return pooled

        # Easy but hacky way to get flattened indices values, used to convert
        # the indices values to non-flattened ones.
        # In ONNX the indices are computed as a flattened 1-D tensor, so the
        # values in indices are in [0, N x C x D1 x ... x Dn).
        # To convert them to the format used by PyTorch, we first execute a
        # maxpool with a kernel and stride of 1 on the same input. This yields
        # a tensor of indices in which each index has its own value.
        # Using that tensor as a reference, we extract the first index of each
        # axis and subtract it from each index of this axis in the indices to
        # convert. The result is a tensor where each dimension has values of
        # indices within the dimension it is in.
        # For more information :
        # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
        pooled, indices = g.op("MaxPool", input, outputs=2, **pool_attrs)
        _, flattened_indices = g.op(
            "MaxPool",
            input,
            outputs=2,
            kernel_shape_i=[1] * ndims,
            strides_i=[1] * ndims,
        )
        # convert indices to have non-flattened indices values
        axis_origin = symbolic_helper._slice_helper(
            g,
            flattened_indices,
            axes=list(range(2, 2 + ndims)),
            starts=list(tuple_fn(0)),
            ends=list(tuple_fn(1)),
        )
        indices = sub(g, indices, axis_origin)
        return pooled, indices

    return symbolic_fn
|
||
|
|
||
|
|
||
|
# Register the ``*_with_indices`` variants of max pooling. Each one is built by
# ``_max_pool`` with ``return_indices=True`` and is kept as a module-level name
# so that other symbolics in this file can pass it around as a fallback.
max_pool1d_with_indices = _onnx_symbolic("aten::max_pool1d_with_indices")(
    _max_pool(
        "max_pool1d_with_indices",
        torch.nn.modules.utils._single,
        1,
        return_indices=True,
    )
)
max_pool2d_with_indices = _onnx_symbolic("aten::max_pool2d_with_indices")(
    _max_pool(
        "max_pool2d_with_indices",
        torch.nn.modules.utils._pair,
        2,
        return_indices=True,
    )
)
max_pool3d_with_indices = _onnx_symbolic("aten::max_pool3d_with_indices")(
    _max_pool(
        "max_pool3d_with_indices",
        torch.nn.modules.utils._triple,
        3,
        return_indices=True,
    )
)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic(
    "aten::avg_pool1d",
    decorate=[
        _apply_params("avg_pool1d", torch.nn.modules.utils._single),
        _export("avg_pool1d"),
    ],
)
@_onnx_symbolic(
    "aten::avg_pool2d",
    decorate=[
        _apply_params("avg_pool2d", torch.nn.modules.utils._pair),
        _export("avg_pool2d"),
    ],
)
@_onnx_symbolic(
    "aten::avg_pool3d",
    decorate=[
        _apply_params("avg_pool3d", torch.nn.modules.utils._triple),
        _export("avg_pool3d"),
    ],
)
@_beartype.beartype
def _avg_pool(name, tuple_fn):
    """Build the symbolic function for an average pool.

    Args:
        name: operator name forwarded to ``symbolic_helper._avgpool_helper``.
        tuple_fn: callable applied to the kernel and stride arguments.
    """

    @symbolic_helper.quantized_args(True)
    @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
    @_beartype.beartype
    def symbolic_fn(
        g,
        input: _C.Value,
        kernel_size: Sequence[int],
        stride: Sequence[int],
        padding: Union[int, Sequence[int]],
        ceil_mode: int,
        count_include_pad: int,
        divisor_override=None,
    ):
        if not stride:
            stride = kernel_size
        padding = symbolic_helper._avgpool_helper(
            tuple_fn, padding, kernel_size, stride, divisor_override, name
        )
        assert isinstance(padding, tuple)

        begin_pads = padding
        # Although onnx::AvgPool provides count_include_pad,
        # the corner case of Average Pooling with ceil_mode on
        # PyTorch allows the sliding window to go off bound, which leads to
        # this accommodation.
        # More detail on https://github.com/pytorch/pytorch/issues/57178
        if count_include_pad:
            # Pad explicitly (no padding on the first two axes) and let
            # AveragePool run without padding of its own.
            explicit_pads = ((0,) * 2 + padding) * 2
            input = _op_with_optional_float_cast(
                g,
                "Pad",
                input,
                pads_i=explicit_pads,
                mode_s="constant",
                value_f=0.0,
                opset_before=11,
            )
            begin_pads = (0,) * len(padding)

        if ceil_mode:
            padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
            end_pads = tuple(
                extra + base for (extra, base) in zip(padding_ceil, begin_pads)
            )
            pool_pads = begin_pads + end_pads
        else:
            pool_pads = begin_pads * 2

        return g.op(
            "AveragePool",
            input,
            kernel_shape_i=tuple_fn(kernel_size),
            strides_i=tuple_fn(stride),
            pads_i=pool_pads,
        )

    return symbolic_fn
|
||
|
|
||
|
|
||
|
@_onnx_symbolic(
    "aten::adaptive_avg_pool1d",
    decorate=[
        _apply_params(
            "adaptive_avg_pool1d", "AveragePool", torch.nn.modules.utils._single
        ),
        _export("adaptive_avg_pool1d"),
    ],
)
@_onnx_symbolic(
    "aten::adaptive_avg_pool2d",
    decorate=[
        _apply_params(
            "adaptive_avg_pool2d", "AveragePool", torch.nn.modules.utils._pair
        ),
        _export("adaptive_avg_pool2d"),
    ],
)
@_onnx_symbolic(
    "aten::adaptive_avg_pool3d",
    decorate=[
        _apply_params(
            "adaptive_avg_pool3d", "AveragePool", torch.nn.modules.utils._triple
        ),
        _export("adaptive_avg_pool3d"),
    ],
)
@_onnx_symbolic(
    "aten::adaptive_max_pool1d",
    decorate=[
        _apply_params(
            "adaptive_max_pool1d",
            "MaxPool",
            torch.nn.modules.utils._single,
            max_pool1d_with_indices,
        ),
        _export("adaptive_max_pool1d"),
    ],
)
@_onnx_symbolic(
    "aten::adaptive_max_pool2d",
    decorate=[
        _apply_params(
            "adaptive_max_pool2d",
            "MaxPool",
            torch.nn.modules.utils._pair,
            max_pool2d_with_indices,
        ),
        _export("adaptive_max_pool2d"),
    ],
)
@_onnx_symbolic(
    "aten::adaptive_max_pool3d",
    decorate=[
        _apply_params(
            "adaptive_max_pool3d",
            "MaxPool",
            torch.nn.modules.utils._triple,
            max_pool3d_with_indices,
        ),
        _export("adaptive_max_pool3d"),
    ],
)
@_beartype.beartype
def _adaptive_pool(name, type, tuple_fn, fn=None):
    """Build the symbolic function for one ``aten::adaptive_*_pool*d`` variant.

    Args:
        name: Operator name reported in "unimplemented" diagnostics.
        type: ONNX pooling op to emit, ``"AveragePool"`` or ``"MaxPool"``.
        tuple_fn: Applied to the computed kernel size before it is attached
            as the ``kernel_shape`` / ``strides`` attributes.
        fn: Symbolic used for the ``"MaxPool"`` case so that indices are
            produced as well (one of the ``max_pool*d_with_indices``
            functions in the registrations above).

    Returns:
        A symbolic function taking ``(g, input, output_size)``.
    """

    @symbolic_helper.quantized_args(True, False)
    @_beartype.beartype
    def symbolic_fn(g, input, output_size):
        # Adaptive pooling can be exported in two situations:
        #   * output_size is 1 along every pooled dimension -> a Global*Pool op;
        #   * every input extent is a multiple of the matching output extent,
        #     so kernel size and stride are uniform along that dimension.
        # GlobalMaxPool does not return indices. For MaxPool we therefore
        # prefer max_poolxd_with_indices and fall back to GlobalMaxPool
        # (returning None for the indices) only when the input sizes are
        # not statically known.
        output_size_value = output_size
        try:
            output_size = symbolic_helper._parse_arg(output_size, "is")
        except Exception:
            # FIXME(justinchuby): Avoid catching Exception.
            # Catch a more specific exception instead.
            return symbolic_helper._onnx_unsupported(
                "adaptive pooling, since output_size is not constant.", input
            )

        is_global = output_size == [1] * len(output_size)
        if is_global and type == "AveragePool":
            return g.op("GlobalAveragePool", input)

        # NOTE(review): assumes _get_tensor_sizes returns either None or a
        # list whose unknown entries are None, as the check below implies.
        sizes = symbolic_helper._get_tensor_sizes(input)
        dim = sizes[2:] if sizes is not None else None
        if dim is None or any(extent is None for extent in dim):
            if is_global:
                return g.op("GlobalMaxPool", input), None
            return symbolic_helper._unimplemented(
                name, "input size not accessible", input
            )

        # Every input extent must be divisible by the requested output extent.
        # An all-ones output_size always passes (x % 1 == 0), so no global
        # pooling fallback is needed on this path.
        if any(extent % output_size[i] for i, extent in enumerate(dim)):
            return symbolic_helper._unimplemented(
                name, "output size that are not factor of input size", output_size_value
            )

        # The division is exact here; floor division keeps it in integers.
        k = [extent // output_size[i] for i, extent in enumerate(dim)]
        # call max_poolxd_with_indices to get indices in the output
        if type == "MaxPool":
            return fn(g, input, k, k, (0,) * len(dim), (1,) * len(dim), False)
        return g.op(type, input, kernel_shape_i=tuple_fn(k), strides_i=tuple_fn(k))

    return symbolic_fn
|
||
|
|
||
|
|
||
|
@_beartype.beartype
|
||
|
def _prepare_onnx_paddings(dim: int, pad):
    """Reorder PyTorch paddings into the layout used by the ONNX ``Pad`` op.

    Args:
        dim: Number of dimensions of the tensor being padded.
        pad: Paddings in PyTorch order, i.e. starting from the last dimension:
            dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ...

    Returns:
        A list ordered as
        dim_0_begin, dim_1_begin, ..., dim_n_begin, dim_0_end, ..., dim_n_end.
    """
    # Dimensions that ``pad`` does not mention (the leading ones) get zeros.
    missing = dim * 2 - len(pad)
    full = list(pad) + [0] * missing
    # Stepping backwards by two walks the dimensions from first to last:
    # even slots hold the "begin" values, odd slots hold the "end" values.
    begins = full[-2::-2]
    ends = full[-1::-2]
    return begins + ends
|
||
|
|
||
|
|
||
|
@_beartype.beartype
def _convert_padding_node(input):
    """Resolve a padding argument to a list of constant integers when possible.

    The value is first passed through ``_maybe_get_const``. If the result is
    still a graph value holding a packed list, each element is read as a
    constant int; when that fails the "opset unsupported" diagnostic for
    ``Pad`` is returned instead.
    """
    padding = symbolic_helper._maybe_get_const(input, "is")
    is_packed_value = symbolic_helper._is_value(
        padding
    ) and symbolic_helper._is_packed_list(padding)
    if not is_packed_value:
        return padding

    elements = symbolic_helper._unpack_list(padding)
    try:
        return [
            symbolic_helper._get_const(element, "i", "padding") for element in elements
        ]
    except Exception:
        # FIXME(justinchuby): Avoid catching Exception.
        # Catch a more specific exception instead.
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "Pad", 9, 11, "The sizes of the padding must be constant", input
        )
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::constant_pad_nd")
@_beartype.beartype
def constant_pad_nd(g: jit_utils.GraphContext, input, padding, value):
    """Export ``aten::constant_pad_nd`` as an ONNX ``Pad`` in "constant" mode.

    The fill value must be a constant; otherwise the "opset unsupported"
    diagnostic for ``Pad`` is returned.
    """
    try:
        fill_value = symbolic_helper._get_const(value, "f", "value")
    except Exception:
        # FIXME(justinchuby): Avoid catching Exception.
        # Catch a more specific exception instead.
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "Pad", 9, 11, "The value for the padding must be constant", value
        )

    pytorch_pads = _convert_padding_node(padding)
    rank = symbolic_helper._get_tensor_rank(input)
    onnx_pads = _prepare_onnx_paddings(rank, pytorch_pads)
    return _op_with_optional_float_cast(
        g,
        "Pad",
        input,
        pads_i=onnx_pads,
        mode_s="constant",
        value_f=fill_value,
        opset_before=11,
    )
|
||
|
|
||
|
|
||
|
@_beartype.beartype
def _pad_circular(g: jit_utils.GraphContext, input: _C.Value, pad: _C.Value):
    """Export circular padding as per-axis ``Slice`` + ``Concat`` ops.

    Args:
        g: Graph context used to emit ONNX ops.
        input: Tensor to pad.
        pad: Padding amounts in PyTorch order (last dimension first), two
            entries (left, right) per padded dimension.

    Returns:
        The padded value.

    The last (left, right) pair in ``pad`` is applied to axis 2, the pair
    before it to axis 3, and so on. A negative amount trims that side; a
    positive amount wraps elements around from the opposite side of the
    (already trimmed) tensor.
    """
    padding = _convert_padding_node(pad)
    assert len(padding) % 2 == 0
    ndim = len(padding) // 2

    cur = input
    for idx in range(ndim):
        axis = 2 + idx
        pad_r = padding[-(2 * idx + 1)]
        pad_l = padding[-(2 * idx + 2)]

        # Trim first. Only a negative right padding moves the end of the
        # slice; otherwise the slice must run to the end of the axis (an
        # end of 0 would select nothing).
        middle = cur
        if pad_l < 0 or pad_r < 0:
            start = builtins.max(0, -pad_l)
            end = pad_r if pad_r < 0 else _constants.INT64_MAX
            middle = symbolic_helper._slice_helper(
                g,
                cur,
                axes=[axis],
                starts=[start],
                ends=[end],
            )

        # Wrapped pieces are taken from the trimmed tensor so that a mix of
        # positive and negative padding wraps the elements that survive.
        tensors = []
        if pad_l > 0:
            left = symbolic_helper._slice_helper(
                g, middle, axes=[axis], starts=[-(pad_l)], ends=[_constants.INT64_MAX]
            )
            tensors.append(left)

        tensors.append(middle)

        if pad_r > 0:
            right = symbolic_helper._slice_helper(
                g, middle, axes=[axis], starts=[0], ends=[pad_r]
            )
            tensors.append(right)

        cur = g.op("Concat", *tensors, axis_i=axis)

    return cur
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::reflection_pad1d")
@_onnx_symbolic("aten::reflection_pad2d")
@_onnx_symbolic("aten::reflection_pad3d")
@_beartype.beartype
def reflection_pad(g: jit_utils.GraphContext, input, padding):
    """Export ``aten::reflection_pad{1,2,3}d`` as ONNX ``Pad`` in "reflect" mode."""
    pytorch_pads = _convert_padding_node(padding)
    rank = symbolic_helper._get_tensor_rank(input)
    onnx_pads = _prepare_onnx_paddings(rank, pytorch_pads)
    return _op_with_optional_float_cast(
        g, "Pad", input, pads_i=onnx_pads, mode_s="reflect", opset_before=11
    )
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::replication_pad1d")
@_onnx_symbolic("aten::replication_pad2d")
@_onnx_symbolic("aten::replication_pad3d")
@_beartype.beartype
def replication_pad(g: jit_utils.GraphContext, input, padding):
    """Export ``aten::replication_pad{1,2,3}d`` as ONNX ``Pad`` in "edge" mode."""
    pytorch_pads = _convert_padding_node(padding)
    rank = symbolic_helper._get_tensor_rank(input)
    onnx_pads = _prepare_onnx_paddings(rank, pytorch_pads)
    return _op_with_optional_float_cast(
        g, "Pad", input, pads_i=onnx_pads, mode_s="edge", opset_before=11
    )
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::pad")
@_beartype.beartype
def pad(
    g: jit_utils.GraphContext,
    input: _C.Value,
    pad: _C.Value,
    mode: _C.Value,
    value: _C.Value,
):
    """Export ``aten::pad`` by dispatching on the padding mode string.

    Raises:
        errors.SymbolicValueError: If the mode is none of "replicate",
            "reflect", "constant" or "circular".
    """
    mode_name = symbolic_helper._parse_arg(mode, "s")
    if mode_name == "replicate":
        return replication_pad(g, input, pad)
    if mode_name == "reflect":
        return reflection_pad(g, input, pad)
    if mode_name == "constant":
        # Only constant padding consumes the fill value.
        return constant_pad_nd(g, input, pad, value)
    if mode_name == "circular":
        return _pad_circular(g, input, pad)
    raise errors.SymbolicValueError(f"Unrecognized padding mode {mode_name}", input)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic(
    "aten::upsample_nearest1d",
    decorate=[
        _apply_params("upsample_nearest1d", 3, "nearest"),
        _export("upsample_nearest1d"),
    ],
)
@_onnx_symbolic(
    "aten::upsample_nearest2d",
    decorate=[
        _apply_params("upsample_nearest2d", 4, "nearest"),
        _export("upsample_nearest2d"),
    ],
)
@_onnx_symbolic(
    "aten::upsample_nearest3d",
    decorate=[
        _apply_params("upsample_nearest3d", 5, "nearest"),
        _export("upsample_nearest3d"),
    ],
)
@_onnx_symbolic(
    "aten::upsample_linear1d",
    decorate=[
        _apply_params("upsample_linear1d", 3, "linear"),
        _export("upsample_linear1d"),
    ],
)
@_onnx_symbolic(
    "aten::upsample_bilinear2d",
    decorate=[
        _apply_params("upsample_bilinear2d", 4, "linear"),
        _export("upsample_bilinear2d"),
    ],
)
@_onnx_symbolic(
    "aten::upsample_trilinear3d",
    decorate=[
        _apply_params("upsample_trilinear3d", 5, "linear"),
        _export("upsample_trilinear3d"),
    ],
)
@_beartype.beartype
def _interpolate(name: str, dim: int, interpolate_mode: str):
    """Build the symbolic function for one ``aten::upsample_*`` variant.

    Args:
        name: Operator name reported in "unimplemented" diagnostics.
        dim: Forwarded to ``_interpolate_size_to_scales`` when scales must be
            derived from ``output_size``.
        interpolate_mode: Value of the ``mode`` attribute on the emitted
            ``Upsample`` op.

    Returns:
        A symbolic function taking ``(g, input, output_size, *args)``.
    """

    def symbolic_fn(g, input, output_size, *args):
        scales, raw_align_corners = symbolic_helper._get_interpolate_attributes(
            g, interpolate_mode, args
        )
        symbolic_helper._interpolate_warning(interpolate_mode)
        if symbolic_helper._maybe_get_scalar(raw_align_corners):
            return symbolic_helper._unimplemented(name, "align_corners == True", input)
        if scales is None:
            # No explicit scales were given: derive them from the target size.
            scales = symbolic_helper._interpolate_size_to_scales(
                g, input, output_size, dim
            )
        return g.op("Upsample", input, scales, mode_s=interpolate_mode)

    return symbolic_fn
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::__interpolate")
@_beartype.beartype
def __interpolate(
    g: jit_utils.GraphContext,
    input,
    size,
    scale_factor,
    mode,
    align_corners,
    recompute_scale_factor,
    antialias,
):
    """Export ``aten::__interpolate`` as an ONNX ``Upsample``.

    ``recompute_scale_factor`` and ``antialias`` are accepted but not used.
    """
    resolved = symbolic_helper._interpolate_get_scales_and_mode(
        g, input, size, scale_factor, mode, align_corners
    )
    scales, upsample_mode = resolved
    return g.op("Upsample", input, scales, mode_s=upsample_mode)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::bitwise_not")
@_beartype.beartype
def bitwise_not(g: jit_utils.GraphContext, input):
    """Export ``aten::bitwise_not`` as ONNX ``Not``; only boolean input is accepted.

    Raises:
        errors.SymbolicValueError: If ``input`` is not boolean.
    """
    if symbolic_helper._is_bool(input):
        return g.op("Not", input)
    raise errors.SymbolicValueError(
        "ONNX export does NOT support exporting bitwise Not "
        "for non-boolean input values",
        input,
    )
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::bitwise_or")
@_beartype.beartype
def bitwise_or(g, self, other):
    """Export ``aten::bitwise_or`` as ONNX ``Or``; both operands must be boolean.

    Raises:
        errors.SymbolicValueError: If either operand is not boolean. ``self``
            is checked before ``other``.
    """
    for label, operand in (("self", self), ("other", other)):
        if not symbolic_helper._is_bool(operand):
            raise errors.SymbolicValueError(
                "ONNX export does NOT support exporting bitwise OR "
                f"for non-boolean input values. {label}: ",
                operand,
            )
    return g.op("Or", self, other)
|
||
|
|
||
|
|
||
|
@_beartype.beartype
def wrap_logical_op_with_cast_to(to_type):
    """Return a decorator that casts both operands of a binary symbolic.

    The cast function is the module-level ``_cast_<to_type>``. It is looked
    up when the wrapped symbolic is called, not when the decorator is applied.
    """

    def decorator(fn):
        @functools.wraps(fn)
        def cast_then_call(g, input, other):
            cast = globals()[f"_cast_{to_type}"]
            cast_input = cast(g, input, False)
            cast_other = cast(g, other, False)
            return fn(g, cast_input, cast_other)

        return cast_then_call

    return decorator
|
||
|
|
||
|
|
||
|
@_beartype.beartype
|
||
|
def wrap_logical_op_with_negation(func: Callable) -> Callable:
    """Wrap a binary symbolic so that its result is passed through ONNX ``Not``."""

    @functools.wraps(func)
    def negated(g, input, other):
        inner = func(g, input, other)
        return g.op("Not", inner)

    return negated
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::__not_")
@_beartype.beartype
def __not_(g: jit_utils.GraphContext, self):
    """Export ``aten::__not_`` as ONNX ``Not``; only boolean input is accepted.

    Raises:
        errors.SymbolicValueError: If ``self`` is not boolean.
    """
    if symbolic_helper._is_bool(self):
        return g.op("Not", self)
    raise errors.SymbolicValueError(
        "ONNX export does NOT support exporting bitwise Not "
        "for non-boolean input values",
        self,
    )
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::eq")
@symbolic_helper.quantized_args(True, True)
@_beartype.beartype
def eq(g: jit_utils.GraphContext, self, other):
    """Export ``aten::eq``.

    Two special cases are folded into a boolean ``Constant``: comparing two
    device objects (always True) and comparing two constant strings.
    Everything else becomes an ONNX ``Equal``.
    """
    both_devices = isinstance(self.type(), _C.DeviceObjType) and isinstance(
        other.type(), _C.DeviceObjType
    )
    if both_devices:
        # ONNX doesn't have devices, so consider them all to be equal.
        # The no-op check for equality will get constant-folded.
        return g.op("Constant", value_t=torch.tensor(True, dtype=torch.bool))

    lhs_node, rhs_node = self.node(), other.node()
    both_constants = (lhs_node.kind(), rhs_node.kind()) == (
        "onnx::Constant",
        "onnx::Constant",
    )
    if both_constants and (
        lhs_node.kindOf("value"),
        rhs_node.kindOf("value"),
    ) == ("s", "s"):
        # Exporting strings to ONNX is not supported.
        # If both strings are constant, we can compare them directly.
        # The no-op check for equality will get constant-folded.
        strings_match = lhs_node.s("value") == rhs_node.s("value")
        return g.op("Constant", value_t=torch.tensor(strings_match, dtype=torch.bool))

    return g.op("Equal", self, other)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::ne")
@symbolic_helper.quantized_args(True, True)
@wrap_logical_op_with_negation
@_beartype.beartype
def ne(g: jit_utils.GraphContext, self, other):
    """Export ``aten::ne``: the body builds ``eq`` and the decorator negates it."""
    equal = eq(g, self, other)
    return equal
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::gt")
@symbolic_helper.quantized_args(True, True)
@_beartype.beartype
def gt(g: jit_utils.GraphContext, input, other):
    """Export ``aten::gt`` through the shared ``_gt_impl``."""
    greater = _gt_impl(g, input, other)
    return greater
|
||
|
|
||
|
|
||
|
@_beartype.beartype
def _gt_impl(g: jit_utils.GraphContext, input, other):
    """Emit ONNX ``Greater``; two boolean operands are first cast to INT32."""
    lhs, rhs = input, other
    if symbolic_helper._is_bool(input) and symbolic_helper._is_bool(other):
        int32 = _C_onnx.TensorProtoDataType.INT32
        lhs = g.op("Cast", input, to_i=int32)
        rhs = g.op("Cast", other, to_i=int32)
    return g.op("Greater", lhs, rhs)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::lt")
@symbolic_helper.quantized_args(True, True)
@_beartype.beartype
def lt(g: jit_utils.GraphContext, input, other):
    """Export ``aten::lt`` through the shared ``_lt_impl``."""
    less = _lt_impl(g, input, other)
    return less
|
||
|
|
||
|
|
||
|
@_beartype.beartype
def _lt_impl(g: jit_utils.GraphContext, input, other):
    """Emit ONNX ``Less``; two boolean operands are first cast to INT32."""
    lhs, rhs = input, other
    if symbolic_helper._is_bool(input) and symbolic_helper._is_bool(other):
        int32 = _C_onnx.TensorProtoDataType.INT32
        lhs = g.op("Cast", input, to_i=int32)
        rhs = g.op("Cast", other, to_i=int32)
    return g.op("Less", lhs, rhs)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::ge")
@symbolic_helper.quantized_args(True, True)
@wrap_logical_op_with_negation
@_beartype.beartype
def ge(g: jit_utils.GraphContext, input, other):
    """Export ``aten::ge`` as the negation of "less than".

    The body only builds ``Less``; the ``Not`` is added by the
    ``wrap_logical_op_with_negation`` decorator.
    """
    less = _lt_impl(g, input, other)
    return less
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::le")
@symbolic_helper.quantized_args(True, True)
@wrap_logical_op_with_negation
@_beartype.beartype
def le(g: jit_utils.GraphContext, input, other):
    """Export ``aten::le`` as the negation of "greater than".

    The body only builds ``Greater``; the ``Not`` is added by the
    ``wrap_logical_op_with_negation`` decorator.
    """
    greater = _gt_impl(g, input, other)
    return greater
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::__and_")
@_beartype.beartype
def __and_(g: jit_utils.GraphContext, input, other):
    """Export ``aten::__and_`` as ONNX ``And``; both operands must be boolean.

    Raises:
        errors.SymbolicValueError: If either operand is not boolean. ``input``
            is checked before ``other``.
    """
    for operand in (input, other):
        if not symbolic_helper._is_bool(operand):
            raise errors.SymbolicValueError(
                "ONNX export does NOT support exporting bitwise AND "
                "for non-boolean input values",
                operand,
            )
    return g.op("And", input, other)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::__or_")
@_beartype.beartype
def __or_(g: jit_utils.GraphContext, input, other):
    """Export ``aten::__or_`` as ONNX ``Or``; both operands must be boolean.

    Raises:
        errors.SymbolicValueError: If either operand is not boolean. ``input``
            is checked before ``other``.
    """
    for operand in (input, other):
        if not symbolic_helper._is_bool(operand):
            raise errors.SymbolicValueError(
                "ONNX export does NOT support exporting bitwise OR "
                "for non-boolean input values",
                operand,
            )
    return g.op("Or", input, other)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::__xor_")
@_beartype.beartype
def __xor_(g: jit_utils.GraphContext, input, other):
    """Export ``aten::__xor_`` as ONNX ``Xor``; both operands must be boolean.

    Raises:
        errors.SymbolicValueError: If either operand is not boolean. ``input``
            is checked before ``other``.
    """
    for operand in (input, other):
        if not symbolic_helper._is_bool(operand):
            raise errors.SymbolicValueError(
                "ONNX export does NOT support exporting bitwise XOR "
                "for non-boolean input values",
                operand,
            )
    return g.op("Xor", input, other)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::logical_and")
@wrap_logical_op_with_cast_to("Bool")
@_beartype.beartype
def logical_and(g: jit_utils.GraphContext, input, other):
    """Export ``aten::logical_and`` as ONNX ``And``.

    The operands are cast by the ``wrap_logical_op_with_cast_to("Bool")``
    decorator before this body runs.
    """
    conjunction = g.op("And", input, other)
    return conjunction
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::logical_or")
@wrap_logical_op_with_cast_to("Bool")
@_beartype.beartype
def logical_or(g: jit_utils.GraphContext, input, other):
    """Export ``aten::logical_or`` as ONNX ``Or``.

    The operands are cast by the ``wrap_logical_op_with_cast_to("Bool")``
    decorator before this body runs.
    """
    disjunction = g.op("Or", input, other)
    return disjunction
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::logical_xor")
@wrap_logical_op_with_cast_to("Bool")
@_beartype.beartype
def logical_xor(g: jit_utils.GraphContext, input, other):
    """Export ``aten::logical_xor`` as ONNX ``Xor``.

    The operands are cast by the ``wrap_logical_op_with_cast_to("Bool")``
    decorator before this body runs.
    """
    exclusive_or = g.op("Xor", input, other)
    return exclusive_or
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::logical_not")
@_beartype.beartype
def logical_not(g: jit_utils.GraphContext, input):
    """Export ``aten::logical_not`` as a ``Cast`` to BOOL followed by ``Not``."""
    as_bool = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.BOOL)
    return g.op("Not", as_bool)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::__rshift_")
@_beartype.beartype
def __rshift_(g: jit_utils.GraphContext, self, other):
    """Export ``aten::__rshift_`` as ``self / 2**other`` using Pow, Cast and Div."""
    # make sure to cast other to self's type
    # (when self is long, make sure that other is not float)
    self_scalar_type = _type_utils.JitScalarType.from_value(self)
    other_scalar_type = _type_utils.JitScalarType.from_value(
        other, _type_utils.JitScalarType.UNDEFINED
    )
    shift = other
    if other_scalar_type != self_scalar_type:
        shift = g.op("Cast", shift, to_i=self_scalar_type.onnx_type())

    base = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32))
    # exponent (same type as self) has to be float or double in onnx::Pow
    if not symbolic_helper._is_fp(self):
        shift = g.op("Cast", shift, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    divisor = g.op("Pow", base, shift)
    # Bring the power of two back to self's type before dividing.
    divisor = g.op("Cast", divisor, to_i=self_scalar_type.onnx_type())
    return g.op("Div", self, divisor)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::__lshift_")
@_beartype.beartype
def __lshift_(g: jit_utils.GraphContext, self, other):
    """Export ``aten::__lshift_`` as ``self * 2**other`` using Pow, Cast and Mul."""
    # make sure to cast other to self's type
    # (when self is long, make sure that other is not float)
    self_scalar_type = _type_utils.JitScalarType.from_value(self)
    other_scalar_type = _type_utils.JitScalarType.from_value(
        other, _type_utils.JitScalarType.UNDEFINED
    )
    shift = other
    if other_scalar_type != self_scalar_type:
        shift = g.op("Cast", shift, to_i=self_scalar_type.onnx_type())

    base = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32))
    # exponent (same type as self) has to be float or double in onnx::Pow
    if not symbolic_helper._is_fp(self):
        shift = g.op("Cast", shift, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    factor = g.op("Pow", base, shift)
    # Bring the power of two back to self's type before multiplying.
    factor = g.op("Cast", factor, to_i=self_scalar_type.onnx_type())
    return g.op("Mul", self, factor)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::where")
@symbolic_helper.parse_args("v", "v", "v", "i")
@_beartype.beartype
def where(g: jit_utils.GraphContext, condition, self=None, other=None, _outputs=None):
    """Export ``aten::where``.

    With ``self`` given this is an ONNX ``Where``. With only ``condition``,
    the result of ``nonzero`` is unbound along axis 1 into ``_outputs``
    values.
    """
    # Assumes that torch.where's first argument takes only Bool and Byte tensors.
    mask = condition
    if not symbolic_helper._is_bool(mask):
        mask = g.op("Cast", mask, to_i=_C_onnx.TensorProtoDataType.BOOL)

    if self is not None:
        return g.op("Where", mask, self, other)

    indices = nonzero(g, mask)
    unbind_axis = g.op("Constant", value_t=torch.tensor(1))
    return symbolic_helper._unbind_helper(g, indices, unbind_axis, _outputs)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::log_softmax")
@symbolic_helper.parse_args("v", "i", "none")
@_beartype.beartype
def log_softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
    """Export ``aten::log_softmax`` as ONNX ``LogSoftmax`` over the last axis.

    When ``dim`` is not the last axis, the input is transposed so that it
    is, and the result is transposed back with the same permutation. The
    input rank must be known; otherwise an "unimplemented" diagnostic is
    returned.
    """
    # PyTorch dim and ONNX axis have different meanings.
    # See Softmax comment for details.
    # TODO: remove this as onnx opset 11 spec allows negative axes
    rank = symbolic_helper._get_tensor_rank(input)
    if rank is None:
        return symbolic_helper._unimplemented(
            "dim",
            "ONNX and PyTorch use different strategies to split the input. "
            "Input rank must be known at export time.",
        )

    axis = rank + dim if dim < 0 else dim

    # ONNX only supports log_softmax with dim = -1. Transpose must be added
    # before and after log_softmax to support other cases.
    perm = None
    if rank != axis + 1:
        perm = list(range(rank))
        # Swapping two axes is its own inverse, so one perm serves both ways.
        perm[axis], perm[-1] = perm[-1], perm[axis]
        input = g.op("Transpose", input, perm_i=perm)
        axis = rank - 1

    result = g.op("LogSoftmax", input, axis_i=axis)
    if dtype and dtype.node().kind() != "prim::Constant":
        parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        target_type = _type_utils.JitScalarType(parsed_dtype).onnx_type()
        result = g.op("Cast", result, to_i=target_type)
    if perm is not None:
        result = g.op("Transpose", result, perm_i=perm)
    return result
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::_log_softmax")
@symbolic_helper.parse_args("v", "i", "i")
@_beartype.beartype
def _log_softmax(g: jit_utils.GraphContext, input, dim, half_to_float):
    """Export ``aten::_log_softmax`` by delegating to ``log_softmax``.

    When ``half_to_float`` is set and the input is HALF, the input is cast
    to FLOAT first.
    """
    if half_to_float:
        input_type = _type_utils.JitScalarType.from_value(
            input, _type_utils.JitScalarType.UNDEFINED
        )
        if input_type == _type_utils.JitScalarType.HALF:
            input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    return log_softmax(g, input, dim)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::_convolution")
@symbolic_helper.parse_args(
    "v", "v", "v", "is", "is", "is", "i", "is", "i", "i", "i", "i", "i"
)
@_beartype.beartype
def _convolution(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    dilation,
    transposed,
    output_padding,
    groups,
    benchmark,
    deterministic,
    cudnn_enabled,
    allow_tf32=None,
):
    """Export ``aten::_convolution`` as ONNX ``Conv`` or ``ConvTranspose``.

    ``benchmark``, ``deterministic``, ``cudnn_enabled`` and ``allow_tf32`` are
    accepted but not used.

    Returns:
        The convolution output. A bias whose rank is not 1 is not passed to
        the convolution op; it is added afterwards with ``Add``.

    Raises:
        errors.SymbolicValueError: If the kernel's spatial shape is not
            statically known.
    """
    # NOTE(review): assumes _get_tensor_sizes returns either None or a list
    # whose unknown entries are None, as the check below implies.
    weight_size = symbolic_helper._get_tensor_sizes(weight)
    kernel_shape = None if weight_size is None else weight_size[2:]

    if kernel_shape is None or any(extent is None for extent in kernel_shape):
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of convolution for kernel of unknown shape.",
            input,
        )

    has_bias = not symbolic_helper._is_none(bias)
    bias_is_1d = has_bias and symbolic_helper._get_tensor_rank(bias) == 1

    args = [input, weight]
    # ONNX only supports 1D bias
    if bias_is_1d:
        args.append(bias)

    kwargs = {
        "kernel_shape_i": kernel_shape,
        "strides_i": stride,
        # NB: ONNX supports asymmetric padding, whereas PyTorch supports only
        # symmetric padding
        "pads_i": padding + padding,
        "dilations_i": dilation,
        "group_i": groups,
    }

    if any(o != 0 for o in output_padding):
        # ONNX supports both output_shape and output_padding; they are
        # equally expressive. output_padding is more straightforward, so we
        # use it here.
        # output_shape = stride * (input_shape - 1) + output_padding + kernel_shape - padding * 2
        assert transposed
        assert len(stride) == len(output_padding)
        kwargs["output_padding_i"] = output_padding

    n = g.op("ConvTranspose" if transposed else "Conv", *args, **kwargs)

    if has_bias and not bias_is_1d:
        return g.op("Add", n, bias)
    return n
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::_convolution_mode")
@symbolic_helper.parse_args(
    "v",
    "v",
    "v",
    "is",
    "s",
    "is",
    "i",
)
@_beartype.beartype
def _convolution_mode(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
):
    """Export ``aten::_convolution_mode`` (string padding) as ONNX ``Conv``.

    ``"valid"`` and ``"same"`` are translated to the ``auto_pad`` values
    ``"VALID"`` and ``"SAME_UPPER"``; any other string is passed through
    unchanged.

    Returns:
        The convolution output. A bias whose rank is not 1 is not passed to
        ``Conv``; it is added afterwards with ``Add``.

    Raises:
        errors.SymbolicValueError: If the kernel's spatial shape is not
            statically known.
    """
    # NOTE(review): assumes _get_tensor_sizes returns either None or a list
    # whose unknown entries are None, as the check below implies.
    weight_size = symbolic_helper._get_tensor_sizes(weight)
    kernel_shape = None if weight_size is None else weight_size[2:]

    if kernel_shape is None or any(extent is None for extent in kernel_shape):
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of convolution for kernel of unknown shape.",
            input,
        )

    has_bias = not symbolic_helper._is_none(bias)
    bias_is_1d = has_bias and symbolic_helper._get_tensor_rank(bias) == 1

    args = [input, weight]
    # ONNX only supports 1D bias
    if bias_is_1d:
        args.append(bias)

    if padding == "valid":
        padding = "VALID"
    elif padding == "same":
        padding = "SAME_UPPER"
    kwargs = {
        "kernel_shape_i": kernel_shape,
        "strides_i": stride,
        "auto_pad_s": padding,
        "dilations_i": dilation,
        "group_i": groups,
    }

    n = g.op("Conv", *args, **kwargs)

    if has_bias and not bias_is_1d:
        return g.op("Add", n, bias)
    return n
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::convolution")
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is", "i")
@_beartype.beartype
def convolution(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    dilation,
    transposed,
    output_padding,
    groups,
):
    """Export ``aten::convolution`` by forwarding to ``_convolution``.

    The four trailing arguments of ``_convolution`` (benchmark,
    deterministic, cudnn_enabled, allow_tf32) are passed as ``None``.
    """
    backend_flags = (None, None, None, None)
    # Arguments are passed positionally, matching the original call.
    return _convolution(
        g,
        input,
        weight,
        bias,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
        *backend_flags,
    )
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::conv1d")
@symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i")
@_beartype.beartype
def conv1d(
    g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups
):
    """Export ``aten::conv1d``.

    ``padding`` is first parsed as a string: "valid" and "same" go through
    ``_convolution_mode``; anything else is re-parsed as an int list and
    goes through ``_convolution`` as a non-transposed convolution.
    """
    str_padding = symbolic_helper._parse_arg(padding, "s")
    if str_padding in ("valid", "same"):
        return _convolution_mode(
            g, input, weight, bias, stride, str_padding, dilation, groups
        )

    int_padding = symbolic_helper._parse_arg(padding, "is")
    backend_flags = (None, None, None, None)
    return _convolution(
        g,
        input,
        weight,
        bias,
        stride,
        int_padding,
        dilation,
        False,  # transposed
        (),  # output_padding
        groups,
        *backend_flags,
    )
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::conv2d")
@symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i")
@_beartype.beartype
def conv2d(
    g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups
):
    """Export ``aten::conv2d``.

    ``padding`` is first parsed as a string: "valid" and "same" go through
    ``_convolution_mode``; anything else is re-parsed as an int list and
    goes through ``_convolution`` as a non-transposed convolution.
    """
    str_padding = symbolic_helper._parse_arg(padding, "s")
    if str_padding in ("valid", "same"):
        return _convolution_mode(
            g, input, weight, bias, stride, str_padding, dilation, groups
        )

    int_padding = symbolic_helper._parse_arg(padding, "is")
    backend_flags = (None, None, None, None)
    return _convolution(
        g,
        input,
        weight,
        bias,
        stride,
        int_padding,
        dilation,
        False,  # transposed
        (),  # output_padding
        groups,
        *backend_flags,
    )
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::conv3d")
@symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i")
@_beartype.beartype
def conv3d(
    g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups
):
    """Export ``aten::conv3d``.

    ``padding`` is first parsed as a string: "valid" and "same" go through
    ``_convolution_mode``; anything else is re-parsed as an int list and
    goes through ``_convolution`` as a non-transposed convolution.
    """
    str_padding = symbolic_helper._parse_arg(padding, "s")
    if str_padding in ("valid", "same"):
        return _convolution_mode(
            g, input, weight, bias, stride, str_padding, dilation, groups
        )

    int_padding = symbolic_helper._parse_arg(padding, "is")
    backend_flags = (None, None, None, None)
    return _convolution(
        g,
        input,
        weight,
        bias,
        stride,
        int_padding,
        dilation,
        False,  # transposed
        (),  # output_padding
        groups,
        *backend_flags,
    )
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::conv_transpose1d")
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is")
@_beartype.beartype
def conv_transpose1d(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    output_padding,
    groups,
    dilation,
):
    """Export ``aten::conv_transpose1d`` via ``_convolution`` with ``transposed=True``.

    The arguments are reordered to the positional order ``_convolution``
    expects (dilation before output_padding and groups).
    """
    backend_flags = (None, None, None, None)
    return _convolution(
        g,
        input,
        weight,
        bias,
        stride,
        padding,
        dilation,
        True,  # transposed
        output_padding,
        groups,
        *backend_flags,
    )
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::conv_transpose2d")
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is")
@_beartype.beartype
def conv_transpose2d(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    output_padding,
    groups,
    dilation,
):
    """Export ``aten::conv_transpose2d`` via ``_convolution`` with ``transposed=True``.

    The arguments are reordered to the positional order ``_convolution``
    expects (dilation before output_padding and groups).
    """
    backend_flags = (None, None, None, None)
    return _convolution(
        g,
        input,
        weight,
        bias,
        stride,
        padding,
        dilation,
        True,  # transposed
        output_padding,
        groups,
        *backend_flags,
    )
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::conv_transpose3d")
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is")
@_beartype.beartype
def conv_transpose3d(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    output_padding,
    groups,
    dilation,
):
    """Export ``aten::conv_transpose3d`` via ``_convolution`` with ``transposed=True``.

    The arguments are reordered to the positional order ``_convolution``
    expects (dilation before output_padding and groups).
    """
    backend_flags = (None, None, None, None)
    return _convolution(
        g,
        input,
        weight,
        bias,
        stride,
        padding,
        dilation,
        True,  # transposed
        output_padding,
        groups,
        *backend_flags,
    )
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::batch_norm")
@symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i")
@_beartype.beartype
def batch_norm(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    running_mean,
    running_var,
    training,
    momentum,
    eps,
    cudnn_enabled,
):
    """Export ``aten::batch_norm`` as ONNX ``BatchNormalization``.

    In inference mode the op has one output, which is returned. In training
    mode it has five; only the first is returned, the running statistics get
    the types of the incoming ones, and the saved mean / variance are renamed
    with a ``batch_norm_dead_output-`` prefix. ``cudnn_enabled`` is accepted
    but not used.
    """
    symbolic_helper.check_training_mode(training, "batch_norm")

    mixed_dtypes_under_autocast = (
        torch.is_autocast_enabled()
        and not symbolic_helper.args_have_same_dtype(
            [input, weight, bias, running_mean, running_var]
        )
        and GLOBALS.export_onnx_opset_version < 15
    )
    if mixed_dtypes_under_autocast:
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "BatchNormalization",
            9,
            15,
            "All input tensors must have the same `dtype`."
            " Turn off Autocast or export using opset version 15.",
            input,
        )

    weight, bias, running_mean, running_var = symbolic_helper._batchnorm_helper(
        g, input, weight, bias, running_mean, running_var
    )
    num_outputs = 5 if training else 1
    out = g.op(
        "BatchNormalization",
        input,
        weight,
        bias,
        running_mean,
        running_var,
        epsilon_f=eps,
        # The ONNX attribute is written as the complement of PyTorch's momentum.
        momentum_f=1 - momentum,
        outputs=num_outputs,
    )
    if not training:
        return out

    res, new_running_mean, new_running_var, saved_mean, saved_var = out
    new_running_mean.setType(running_mean.type())
    new_running_var.setType(running_var.type())
    for dead_output in (saved_mean, saved_var):
        dead_output.setDebugName("batch_norm_dead_output-" + dead_output.debugName())
    return res
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::native_layer_norm")
@symbolic_helper.quantized_args(True, False, False, False)
@symbolic_helper.parse_args("v", "is", "v", "v", "f")
@_beartype.beartype
def native_layer_norm(
    g: jit_utils.GraphContext,
    input: _C.Value,
    normalized_shape: Sequence[int],
    weight: _C.Value,
    bias: _C.Value,
    eps: float,
) -> Tuple[_C.Value, _C.Value, _C.Value]:
    """Export ``aten::native_layer_norm`` from primitive ONNX ops.

    Returns:
        A ``(normalized, mean, rdenominator)`` triple where ``rdenominator``
        is ``1 / sqrt(variance + eps)``.

    When the centered input is half precision, the variance/sqrt/div part
    is carried out in the dtype of the ``eps`` constant and the results are
    cast back to the input dtype afterwards.
    """
    # Reduce over the trailing len(normalized_shape) axes: [-k, ..., -1].
    axes = list(range(-len(normalized_shape), 0))

    two_cst = symbolic_helper._generate_wrapped_number(g, 2.0)
    eps_cst = symbolic_helper._generate_wrapped_number(g, eps)

    mean = g.op("ReduceMean", input, axes_i=axes)
    centered = sub(g, input, mean)

    # Cast it to eps dtype to avoid precision loss
    is_type_half = (
        _type_utils.JitScalarType.from_value(centered)
        == _type_utils.JitScalarType.HALF
    )
    if is_type_half:
        eps_dtype = _type_utils.JitScalarType.from_value(eps_cst)
        centered = g.op(
            "Cast", centered, to_i=_type_utils.JitScalarType(eps_dtype).onnx_type()
        )

    # variance = E[(x - E[x])^2]; (x - E[x]) is also the layer_norm numerator.
    squared = pow(g, centered, two_cst)
    variance = g.op("ReduceMean", squared, axes_i=axes)
    denominator = sqrt(g, g.op("Add", variance, eps_cst))
    normalized = g.op("Div", centered, denominator)

    # Cast back to input type as eps related ops are all done
    if is_type_half:
        input_dtype = _type_utils.JitScalarType.from_value(input)
        normalized = g.op(
            "Cast", normalized, to_i=_type_utils.JitScalarType(input_dtype).onnx_type()
        )

    # Optional affine transform.
    if weight is not None and not symbolic_helper._is_none(weight):
        normalized = mul(g, normalized, weight)
    if bias is not None and not symbolic_helper._is_none(bias):
        normalized = add(g, normalized, bias)

    # rdenominator := 1 / sqrt(variance + eps)
    # According to aten::native_layer_norm, rdenominator should have the same dtype as input,
    # mean and normalized, so we need to Cast it back
    if is_type_half:
        denominator = g.op(
            "Cast", denominator, to_i=_type_utils.JitScalarType(input_dtype).onnx_type()  # type: ignore[possibly-undefined]
        )
        rdenominator = g.op("Reciprocal", denominator)
    else:
        rdenominator = reciprocal(g, denominator)

    return normalized, mean, rdenominator
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::layer_norm")
@symbolic_helper.quantized_args(True, False, False, False)
@symbolic_helper.parse_args("v", "is", "v", "v", "f", "b")
@_beartype.beartype
def layer_norm(
    g: jit_utils.GraphContext,
    input: _C.Value,
    normalized_shape: Sequence[int],
    weight: _C.Value,
    bias: _C.Value,
    eps: float,
    cudnn_enable: bool,
) -> _C.Value:
    """Export ``aten::layer_norm``.

    Outside the Caffe2 ATen fallback this reuses ``native_layer_norm`` and
    keeps only its first (normalized) output; with the fallback enabled an
    ATen ``layer_norm`` node is emitted instead.
    """
    if not symbolic_helper.is_caffe2_aten_fallback():
        normalized, _, _ = native_layer_norm(
            g, input, normalized_shape, weight, bias, eps
        )
        return normalized

    return g.at(
        "layer_norm",
        input,
        weight,
        bias,
        normalized_shape_i=normalized_shape,
        eps_f=eps,
        cudnn_enable_i=cudnn_enable,
    )
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::instance_norm")
@symbolic_helper.parse_args("v", "v", "v", "v", "v", "b", "f", "f", "b")
@_beartype.beartype
def instance_norm(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    running_mean,
    running_var,
    use_input_stats: bool,
    momentum: Number,
    eps: Number,
    cudnn_enabled: bool,
):
    """Export ``aten::instance_norm``.

    Missing ``weight``/``bias`` are replaced by constant ones/zeros of length
    ``C`` (the size of input dimension 1). Without running statistics a single
    ``InstanceNormalization`` node is emitted. With running statistics the
    input is reshaped from ``[N, C, ...]`` to ``[1, N * C, ...]``, the
    per-channel parameters are repeated ``N`` times, ``batch_norm`` is applied
    and the result is viewed back to the original shape.

    Raises:
        errors.SymbolicValueError: if the channel size is unknown while a
            default ``weight``/``bias`` must be synthesized, or if the input
            shape is not fully static in the running-statistics path.
    """
    symbolic_helper.check_training_mode(use_input_stats, "instance_norm")
    channel_size = symbolic_helper._get_tensor_dim_size(input, 1)

    if weight is None or symbolic_helper._is_none(weight):
        if channel_size is None:
            raise errors.SymbolicValueError(
                "Unsupported: ONNX export of instance_norm for unknown channel size.",
                input,
            )
        weight_value = torch.tensor(
            [1.0] * channel_size,
            dtype=_type_utils.JitScalarType.from_value(input).dtype(),
        )
        weight = g.op("Constant", value_t=weight_value)

    if bias is None or symbolic_helper._is_none(bias):
        if channel_size is None:
            raise errors.SymbolicValueError(
                "Unsupported: ONNX export of instance_norm for unknown channel size.",
                input,
            )
        bias_value = torch.tensor(
            [0.0] * channel_size,
            dtype=_type_utils.JitScalarType.from_value(input).dtype(),
        )
        bias = g.op("Constant", value_t=bias_value)

    has_running_stats = not (
        running_mean is None
        or symbolic_helper._is_none(running_mean)
        or running_var is None
        or symbolic_helper._is_none(running_var)
    )
    if not has_running_stats:
        return g.op("InstanceNormalization", input, weight, bias, epsilon_f=eps)

    input_size = symbolic_helper._get_tensor_sizes(input)
    # The shape is baked into Reshape/view constants below, so it has to be
    # fully known. Fail with an export error instead of an AttributeError or
    # TypeError raised from deep inside the tensor construction.
    if input_size is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of instance_norm training for unknown "
            "input rank.",
            input,
        )
    # If input shape is [N, C, H, W], reshape to [1, N * C, H, W] and call batch_norm.
    # For more information instance_norm():
    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Normalization.cpp#L542
    n = input_size[0]
    if n is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of instance_norm training for unknown "
            "batch size.",
            input,
        )
    if any(dim_size is None for dim_size in input_size[1:]):
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of instance_norm training for input "
            "with unknown dimension sizes.",
            input,
        )
    c = input_size[1]
    input_size_reshape = input_size.copy()
    input_size_reshape[0] = 1
    input_size_reshape[1] = n * c

    def _repeat_per_batch(per_channel):
        # Tile a length-C vector N times so it lines up with the N * C
        # channels of the reshaped input. A fresh Constant is created per call.
        return repeat(
            g,
            per_channel,
            g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)),
        )

    weight_ = _repeat_per_batch(weight)
    bias_ = _repeat_per_batch(bias)
    running_mean_ = _repeat_per_batch(running_mean)
    running_var_ = _repeat_per_batch(running_var)
    input_reshaped = g.op(
        "Reshape",
        input,
        g.op("Constant", value_t=torch.LongTensor(input_size_reshape)),
    )
    out = batch_norm(
        g,
        input_reshaped,
        weight_,
        bias_,
        running_mean_,
        running_var_,
        use_input_stats,
        momentum,
        eps,
        cudnn_enabled,
    )
    return view(g, out, g.op("Constant", value_t=torch.tensor(input_size)))
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::unfold")
@symbolic_helper.parse_args("v", "i", "i", "i")
@_beartype.beartype
def unfold(g: jit_utils.GraphContext, input, dimension, size, step):
    """Export ``aten::unfold`` as a Concat of transposed, unsqueezed slices.

    One slice of length ``size`` is taken along ``dimension`` for every
    window start ``0, step, 2 * step, ...``. Each slice has its sliced axis
    moved to the end, gets a new axis inserted at ``dimension``, and all
    slices are concatenated along that new axis.

    The size of ``dimension`` must be statically known; otherwise the op is
    reported as unimplemented. Negative ``dimension`` values are normalized
    before being used as Unsqueeze/Concat axes.
    """
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("unfold", input, dimension_i=dimension, size_i=size, step_i=step)

    sizes = symbolic_helper._get_tensor_sizes(input)
    try:
        sizedim = sizes[dimension]
    except (TypeError, IndexError):
        # TypeError: sizes is None (shape unknown);
        # IndexError: dimension lies outside the known rank.
        sizedim = None
    if sizedim is None:
        return symbolic_helper._unimplemented(
            "Unfold", "input size not accessible", input
        )

    ndim = len(sizes)
    if dimension < 0:
        # The new window axis is inserted into an (ndim + 1)-rank result, so a
        # raw negative index would address a different position there than
        # it does in the input.
        dimension += ndim

    low_indices = range(0, sizedim, step)
    hi_indices = range(size, sizedim + 1, step)
    stack = [
        symbolic_helper._slice_helper(
            g, input, axes=[dimension], starts=[low], ends=[hi]
        )
        for low, hi in zip(low_indices, hi_indices)
    ]
    # Move the sliced axis to the last position.
    perm = list(range(0, ndim))
    perm.append(perm.pop(dimension))
    unsqueeze = [
        symbolic_helper._unsqueeze_helper(
            g, g.op("Transpose", t, perm_i=perm), [dimension]
        )
        for t in stack
    ]
    return g.op("Concat", *unsqueeze, axis_i=dimension)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::elu")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "t", "t", "t")
@_beartype.beartype
def elu(g: jit_utils.GraphContext, input, alpha, scale, input_scale):
    """Export ``aten::elu`` as an ONNX ``Elu`` node.

    Only ``alpha`` is forwarded to the node; a ``scale`` or ``input_scale``
    that is set and different from 1.0 is reported as unimplemented.
    """
    unsupported_factors = (
        ("scale", "does not support scale in Elu", scale),
        ("input_scale", "does not support input_scale in Elu", input_scale),
    )
    for op_name, message, factor in unsupported_factors:
        if factor and factor != 1.0:
            return symbolic_helper._unimplemented(op_name, message, factor)
    # See Note [Export inplace]
    alpha_value = symbolic_helper._scalar(alpha)
    return g.op("Elu", input, alpha_f=alpha_value)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::selu")
@symbolic_helper.quantized_args(True)
@_beartype.beartype
def selu(g: jit_utils.GraphContext, input):
    """Export ``aten::selu`` as a single ONNX ``Selu`` node."""
    selu_node = g.op("Selu", input)
    return selu_node
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::index_select")
@symbolic_helper.parse_args("v", "i", "v")
@_beartype.beartype
def index_select(g: jit_utils.GraphContext, self, dim, index):
    """Export ``aten::index_select`` through ``symbolic_helper._select_helper``.

    For a scalar ``index``, index_select keeps the rank of the input. To
    mirror that in ONNX the index has to be a 1D tensor so the gather keeps
    the rank as well (presumably handled inside ``_select_helper`` -- this
    function only delegates).
    """
    selected = symbolic_helper._select_helper(g, self, dim, index)
    return selected
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::index_put")
@_beartype.beartype
def index_put(g: jit_utils.GraphContext, self, indices_list_value, values, accumulate):
    """Export ``aten::index_put`` for opset 9.

    Only the degenerate case of an empty index list is handled here: the
    result is ``self + values`` when ``accumulate`` is set and ``values``
    otherwise. Any non-empty index list is reported as requiring opset 11.
    With the Caffe2 ATen fallback an ATen ``index_put`` node is emitted.
    """
    indices_list = (
        symbolic_helper._unpack_list(indices_list_value)
        if symbolic_helper._is_packed_list(indices_list_value)
        else [indices_list_value]
    )

    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("index_put", self, *indices_list, values, accumulate)

    accumulate = symbolic_helper._parse_arg(accumulate, "b")

    if indices_list:
        symbolic_helper._onnx_opset_unsupported("index_put", 9, 11, self)
        return None

    return add(g, self, values) if accumulate else values
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::index_fill")
@_beartype.beartype
def index_fill(g: jit_utils.GraphContext, self, dim, index, value):
    """Export ``aten::index_fill`` as a scatter of an expanded fill value.

    The index is reshaped/expanded by
    ``symbolic_helper._index_fill_reshape_helper``, ``value`` is expanded to
    the same shape, and both are handed to ``scatter`` along ``dim``.
    """
    dim_value = symbolic_helper._parse_arg(dim, "i")
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at(
            "index_fill",
            self,
            index,
            value,
            overload_name="int_Scalar",
            dim_i=dim_value,
        )

    index_shape, scatter_index = symbolic_helper._index_fill_reshape_helper(
        g, self, dim, index
    )
    # Resolve the fill value to a scalar typed like ``self`` before expanding.
    fill = symbolic_helper._if_scalar_type_as(
        symbolic_helper._maybe_get_scalar(value), self
    )
    scatter_src = expand(g, fill, index_shape, None)

    return scatter(g, self, dim, scatter_index, scatter_src)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::index_copy")
@_beartype.beartype
def index_copy(g: jit_utils.GraphContext, self, dim, index, source):
    """Export ``aten::index_copy`` as a scatter of ``source`` along ``dim``.

    The index is first reshaped/expanded by
    ``symbolic_helper._index_fill_reshape_helper``; only the expanded index
    (not the accompanying shape) is needed here.
    """
    dim_value = symbolic_helper._parse_arg(dim, "i")
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("index_copy", self, index, source, dim_i=dim_value)

    _, scatter_index = symbolic_helper._index_fill_reshape_helper(g, self, dim, index)
    return scatter(g, self, dim, scatter_index, source)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::bucketize")
@symbolic_helper.parse_args("v", "v", "b", "b")
@_beartype.beartype
def bucketize(
    g: jit_utils.GraphContext, self, boundaries, out_int32=False, right=False
):
    """Export ``aten::bucketize`` by counting the boundaries below each element.

    ``boundaries`` is broadcast against every element of ``self``, compared
    with ``>=`` (``right=True``) or ``>`` (``right=False``), and the resulting
    0/1 mask is summed over the boundary axis. The output is INT32 when
    ``out_int32`` is set and INT64 otherwise.
    """
    out_type = (
        _C_onnx.TensorProtoDataType.INT32
        if out_int32
        else _C_onnx.TensorProtoDataType.INT64
    )
    # A tensor expanded_boundaries is created such that it
    # contains a copy of boundaries for each element of self.
    boundaries_shape = g.op("Shape", boundaries)
    self_shape = g.op("Shape", self)
    new_shape = g.op("Concat", boundaries_shape, self_shape, axis_i=0)
    # Unsqueeze step is performed to respect ONNX's numpy style broadcasting for comparison ops
    # https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md
    tensor_rank = symbolic_helper._get_tensor_rank(self)
    assert tensor_rank is not None
    unsqueeze_axes = [axis for axis in range(1, tensor_rank + 1)]
    unsqueezed_boundaries = symbolic_helper._unsqueeze_helper(
        g, boundaries, unsqueeze_axes
    )
    expanded_boundaries = expand(g, unsqueezed_boundaries, new_shape, None)
    # Compare each element of self to boundaries to get a tensor
    # with leading 1s and trailing 0s.
    # e.g., 4 > [1, 3, 4] = [1, 1, 0]
    # The index of the last 1 is the bucket where the element should go.
    compare = ge if right else gt
    cond = compare(g, self, expanded_boundaries)
    cond_out = g.op("Cast", cond, to_i=out_type)
    # Sum to get the number of 1s corresponding to each element,
    # which is the same as the bucket index.
    # e.g., sum(4 > [1, 3, 4]) = sum([1, 1, 0]) = 2
    return symbolic_helper._reducesum_helper(g, cond_out, axes_i=[0], keepdims_i=0)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::type_as")
@_beartype.beartype
def type_as(g: jit_utils.GraphContext, self, other):
    """Export ``aten::type_as`` as a ``Cast`` to the dtype of ``other``.

    ``self`` is returned untouched when both dtypes are known and equal.
    When the dtype of ``other`` cannot be determined, an ATen node is emitted
    under the Caffe2 ATen fallback and an error is raised otherwise.

    Raises:
        errors.SymbolicValueError: dtype of ``other`` is unknown and the
            Caffe2 ATen fallback is not active.
    """
    self_dtype = symbolic_helper._try_get_scalar_type(self)
    other_dtype = symbolic_helper._try_get_scalar_type(other)

    if other_dtype is None:
        if symbolic_helper.is_caffe2_aten_fallback():
            # We don't know the type of other, bail by emitting ATen
            return g.at("type_as", self, other)
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of type_as for tensor "
            "of unknown dtype. Please check if the dtype of the "
            "parameter passed to the type_as function is correct.",
            other,
        )

    # other_dtype is known here, so equality implies self_dtype is known too.
    if self_dtype == other_dtype:
        return self

    target_type = other_dtype.onnx_type()
    return g.op("Cast", self, to_i=target_type)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::cosine_similarity")
@symbolic_helper.parse_args("v", "v", "i", "f")
@_beartype.beartype
def cosine_similarity(g: jit_utils.GraphContext, x1, x2, dim, eps):
    """Export ``aten::cosine_similarity`` from elementwise and reduction ops.

    Computes ``sum(x1 * x2) / max(sqrt(sum(x1 * x1) * sum(x2 * x2)), eps)``
    with every sum taken over ``dim`` (the reduced axis is dropped).
    """
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("cosine_similarity", x1, x2, dim_i=dim, eps_f=eps)

    def _sum_of_products(lhs, rhs):
        # sum(lhs * rhs) over ``dim`` without keeping the reduced axis.
        return symbolic_helper._reducesum_helper(
            g, mul(g, lhs, rhs), axes_i=[dim], keepdims_i=0
        )

    cross = _sum_of_products(x1, x2)
    x1_l2 = _sum_of_products(x1, x1)
    x2_l2 = _sum_of_products(x2, x2)

    norm_product = sqrt(g, mul(g, x1_l2, x2_l2))
    eps_const = g.op("Constant", value_t=torch.tensor([eps]))
    # Clamp the denominator from below by eps.
    div_tens = max(g, norm_product, eps_const)
    return div(g, cross, div_tens)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::pairwise_distance")
@_beartype.beartype
def pairwise_distance(g: jit_utils.GraphContext, input1, input2, p, eps, keepdim):
    """Export ``aten::pairwise_distance``.

    The emitted graph computes
    ``sum((input1 - input2) ** p, axis=-1) ** (1 / (p + eps))``.

    NOTE(review): the difference is raised to ``p`` without taking an
    absolute value and ``eps`` is added to the exponent denominator --
    confirm this matches the intended semantics for all ``p``.
    """
    if not symbolic_helper._is_value(eps):
        eps = g.op("Constant", value_t=torch.tensor([eps]))

    one = g.op("Constant", value_t=torch.tensor([1], dtype=torch.float))
    inv_p = div(g, one, add(g, p, eps))

    powered_diff = pow(g, sub(g, input1, input2), p)
    keepdims = symbolic_helper._parse_arg(keepdim, "i")
    summation = symbolic_helper._reducesum_helper(
        g,
        powered_diff,
        axes_i=[-1],
        keepdims_i=keepdims,
    )
    return pow(g, summation, inv_p)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::clone")
@_beartype.beartype
def clone(g: jit_utils.GraphContext, input, unused_memory_format):
    """Export ``aten::clone`` as a no-op that returns ``input`` unchanged.

    Clone operators that are inserted by PyTorch autograd are ignored; no
    node is added to the graph and the memory format argument is not used.
    """
    return input
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::abs")
@_beartype.beartype
def abs(g: jit_utils.GraphContext, self):
    """Export ``aten::abs`` as a single ONNX ``Abs`` node."""
    abs_node = g.op("Abs", self)
    return abs_node
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::log")
@_beartype.beartype
def log(g: jit_utils.GraphContext, self):
    """Export ``aten::log`` as a single ONNX ``Log`` node."""
    log_node = g.op("Log", self)
    return log_node
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::log1p")
@_beartype.beartype
def log1p(g: jit_utils.GraphContext, self):
    """Export ``aten::log1p`` as ``log(1 + self)``."""
    # The constant one is passed through _if_scalar_type_as with ``self``
    # as the reference before being added.
    one = symbolic_helper._if_scalar_type_as(torch.ones(1), self)
    shifted = add(g, one, self)
    return log(g, shifted)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::log10")
@_beartype.beartype
def log10(g: jit_utils.GraphContext, self):
    """Export ``aten::log10`` as ``log(self) / ln(10)``."""
    _ln10 = 2.30258509299404568401
    natural_log = log(g, self)
    ln10_const = g.op("Constant", value_t=torch.tensor([_ln10]))
    return g.op("Div", natural_log, ln10_const)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::pow")
@_beartype.beartype
def pow(g: jit_utils.GraphContext, self, exponent):
    """Export ``aten::pow`` as an ONNX ``Pow`` node on floating-point inputs.

    A non floating-point base is cast to FLOAT. A non floating-point
    exponent is cast to the (possibly updated) dtype of the base.
    """
    base_dtype = _type_utils.JitScalarType.from_value(self)
    if not symbolic_helper._is_fp(self):
        base_dtype = _type_utils.JitScalarType.FLOAT
        self = g.op("Cast", self, to_i=base_dtype.onnx_type())
    if not symbolic_helper._is_fp(exponent):
        exponent = g.op("Cast", exponent, to_i=base_dtype.onnx_type())
    return g.op("Pow", self, exponent)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::clamp")
@_beartype.beartype
def clamp(g: jit_utils.GraphContext, self, min, max):
    """Export ``aten::clamp``.

    A missing bound dispatches to ``clamp_max`` / ``clamp_min``. Two constant
    bounds become one ``Clip`` node; otherwise ``clamp_min`` and ``clamp_max``
    are chained.
    """
    # min or max may be None that we need to dispatch to
    # Clip separately, as ONNX does not have None syntax
    if symbolic_helper._is_none(min):
        return clamp_max(g, self, max)
    if symbolic_helper._is_none(max):
        return clamp_min(g, self, min)

    if not (symbolic_helper._is_constant(min) and symbolic_helper._is_constant(max)):
        lower_clamped = clamp_min(g, self, min)
        return clamp_max(g, lower_clamped, max)

    return _op_with_optional_float_cast(
        g,
        "Clip",
        self,
        min_f=symbolic_helper._parse_arg(min, "f"),
        max_f=symbolic_helper._parse_arg(max, "f"),
        opset_before=12,
    )
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::clamp_min")
@symbolic_helper.parse_args("v", "v")
@_beartype.beartype
def clamp_min(g: jit_utils.GraphContext, self, min):
    """Export ``aten::clamp_min``.

    A constant bound becomes the ``min`` attribute of a ``Clip`` node; a
    non-constant bound is cast to the dtype of ``self`` and combined through
    an elementwise ``Max``.
    """
    if not symbolic_helper._is_constant(min):
        self_dtype = _type_utils.JitScalarType.from_value(self)
        lower_bound = g.op("Cast", min, to_i=self_dtype.onnx_type())
        return _op_with_optional_float_cast(
            g, "Max", self, lower_bound, opset_before=12
        )

    lower_value = symbolic_helper._parse_arg(min, "f")
    return _op_with_optional_float_cast(
        g, "Clip", self, min_f=lower_value, opset_before=12
    )
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::clamp_max")
@symbolic_helper.parse_args("v", "v")
@_beartype.beartype
def clamp_max(g: jit_utils.GraphContext, self, max):
    """Export ``aten::clamp_max``.

    A constant bound becomes the ``max`` attribute of a ``Clip`` node; a
    non-constant bound is cast to the dtype of ``self`` and combined through
    an elementwise ``Min``.
    """
    if not symbolic_helper._is_constant(max):
        self_dtype = _type_utils.JitScalarType.from_value(self)
        upper_bound = g.op("Cast", max, to_i=self_dtype.onnx_type())
        return _op_with_optional_float_cast(
            g, "Min", self, upper_bound, opset_before=12
        )

    upper_value = symbolic_helper._parse_arg(max, "f")
    return _op_with_optional_float_cast(
        g, "Clip", self, max_f=upper_value, opset_before=12
    )
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::max")
# torch.max (same for torch.min) actually has two interfaces smashed together:
# torch.max(x, dim, keepdim) and torch.max(x, y)
# TODO(justinchuby): Support multiple quantized args in output
@_beartype.beartype
def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
    """Export the three call forms of ``aten::max``.

    * ``max(input)``: full ``ReduceMax`` with ``keepdims=0``.
    * ``max(input, other)``: elementwise ``Max``.
    * ``max(input, dim, keepdim)``: a ``(values, indices)`` pair from
      ``ReduceMax`` and ``ArgMax`` along ``dim``.
    """
    if keepdim is None:
        if dim_or_y is None:
            # torch.max(input)
            return g.op("ReduceMax", self, keepdims_i=0)
        # torch.max(input, other)
        return _op_with_optional_float_cast(g, "Max", self, dim_or_y, opset_before=12)

    # torch.max(input, dim, keepdim)
    dim = symbolic_helper._get_const(dim_or_y, "i", "dim")
    keepdim = symbolic_helper._get_const(keepdim, "i", "keepdim")
    max_values = g.op("ReduceMax", self, axes_i=[dim], keepdims_i=keepdim)
    max_indices = g.op("ArgMax", self, axis_i=dim, keepdims_i=keepdim)
    return max_values, max_indices
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::maximum")
@symbolic_helper.quantized_args(True, True)
@_beartype.beartype
def maximum(g: jit_utils.GraphContext, input, other):
    """Export ``aten::maximum`` via the two-tensor form of ``max``."""
    elementwise_max = max(g, input, dim_or_y=other)
    return elementwise_max
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::min")
# TODO(justinchuby): Support multiple quantized args in output
@_beartype.beartype
def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
    """Export the three call forms of ``aten::min``.

    * ``min(input)``: full ``ReduceMin`` with ``keepdims=0``.
    * ``min(input, other)``: elementwise ``Min``.
    * ``min(input, dim, keepdim)``: a ``(values, indices)`` pair from
      ``ReduceMin`` and ``ArgMin`` along ``dim``.
    """
    if keepdim is None:
        if dim_or_y is None:
            # torch.min(input)
            return g.op("ReduceMin", self, keepdims_i=0)
        # torch.min(input, other)
        return _op_with_optional_float_cast(g, "Min", self, dim_or_y, opset_before=12)

    # torch.min(input, dim, keepdim)
    dim = symbolic_helper._get_const(dim_or_y, "i", "dim")
    keepdim = symbolic_helper._get_const(keepdim, "i", "keepdim")
    min_values = g.op("ReduceMin", self, axes_i=[dim], keepdims_i=keepdim)
    min_indices = g.op("ArgMin", self, axis_i=dim, keepdims_i=keepdim)
    return min_values, min_indices
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::minimum")
@symbolic_helper.quantized_args(True, True)
@_beartype.beartype
def minimum(g: jit_utils.GraphContext, input, other):
    """Export ``aten::minimum`` via the two-tensor form of ``min``."""
    elementwise_min = min(g, input, dim_or_y=other)
    return elementwise_min
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::amax")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "is", "i")
@_beartype.beartype
def amax(g: jit_utils.GraphContext, self, dim, keepdim):
    """Export ``aten::amax`` as ``ReduceMax`` over the axes in ``dim``."""
    reduced = g.op("ReduceMax", self, axes_i=dim, keepdims_i=keepdim)
    return reduced
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::amin")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "is", "i")
@_beartype.beartype
def amin(g: jit_utils.GraphContext, self, dim, keepdim):
    """Export ``aten::amin`` as ``ReduceMin`` over the axes in ``dim``."""
    reduced = g.op("ReduceMin", self, axes_i=dim, keepdims_i=keepdim)
    return reduced
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::aminmax")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "v", "i")
@_beartype.beartype
def aminmax(g: jit_utils.GraphContext, self, dim, keepdim):
    """Export ``aten::aminmax`` as a ``(ReduceMin, ReduceMax)`` pair.

    When ``dim`` is None-like no ``axes`` attribute is set on the reduce
    nodes; otherwise both reduce over the single constant axis ``dim``.
    """
    if symbolic_helper._is_none(dim):
        reduce_kwargs = {"keepdims_i": keepdim}
    else:
        axis = symbolic_helper._get_const(dim, "i", "dim")
        reduce_kwargs = {"keepdims_i": keepdim, "axes_i": [axis]}

    min_values = g.op("ReduceMin", self, **reduce_kwargs)
    max_values = g.op("ReduceMax", self, **reduce_kwargs)
    return min_values, max_values
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::exp")
@_beartype.beartype
def exp(g: jit_utils.GraphContext, self):
    """Export ``aten::exp`` as a single ONNX ``Exp`` node."""
    exp_node = g.op("Exp", self)
    return exp_node
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::dropout_")
@_onnx_symbolic("aten::dropout")
@symbolic_helper.parse_args("v", "f", "i")
@_beartype.beartype
def dropout(g: jit_utils.GraphContext, input, p, train):
    """Export ``aten::dropout`` / ``aten::dropout_``.

    In training mode a two-output ``Dropout`` node with ratio ``p`` is
    emitted and its first output returned; otherwise dropout is a no-op and
    ``input`` is returned as is.
    """
    symbolic_helper.check_training_mode(train, "dropout")
    if train:
        dropped, _mask = g.op("Dropout", input, ratio_f=p, outputs=2)
        return dropped
    # if train is False, dropout is no-op
    return input
|
||
|
|
||
|
|
||
|
@_onnx_symbolic(
    "aten::alpha_dropout_", decorate=[_apply_params("aten::alpha_dropout_")]
)  # See Note [Export inplace]
@_onnx_symbolic(
    "aten::feature_alpha_dropout_",
    decorate=[_apply_params("aten::feature_alpha_dropout_")],
)
@_onnx_symbolic(
    "aten::feature_dropout_", decorate=[_apply_params("aten::feature_dropout_")]
)
@_onnx_symbolic(
    "aten::feature_alpha_dropout",
    decorate=[_apply_params("aten::feature_alpha_dropout")],
)
@_onnx_symbolic("aten::alpha_dropout", decorate=[_apply_params("aten::alpha_dropout")])
@_onnx_symbolic(
    "aten::feature_dropout", decorate=[_apply_params("aten::feature_dropout")]
)
@_beartype.beartype
def _unsupported_dropout(name: str):
    """Build the symbolic function for a dropout variant without ONNX support.

    Args:
        name: Operator name used in the "unimplemented" report.

    Returns:
        A symbolic function that passes ``input`` through unchanged in
        inference mode and reports training mode as unimplemented.
    """

    @symbolic_helper.parse_args("v", "none", "b")
    @_beartype.beartype
    def feature_dropout(g, input, p, train):
        # NB: In inference mode, FeatureDropout is exported as an identity op.
        if not train:
            return input
        return symbolic_helper._unimplemented(name, "training mode", input)

    return feature_dropout
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::norm")
@symbolic_helper.parse_args("v", "t", "is", "i", "v")
@_beartype.beartype
def norm(g: jit_utils.GraphContext, self, p, dim, keepdim, dtype=None):
    """Export ``aten::norm`` for ``p`` equal to 1 or 2.

    ``p == 1`` maps to ``ReduceL1`` and ``p == 2`` to ``ReduceL2``. When
    ``dtype`` is given the result is additionally cast to it.

    Raises:
        errors.SymbolicValueError: ``p`` is neither 1 nor 2.
    """
    if p == 1:
        reduce_name = "ReduceL1"
    elif p == 2:
        reduce_name = "ReduceL2"
    else:
        raise errors.SymbolicValueError(
            "ONNX export only p-norms with p of 1 or 2", self
        )
    reduce_fn = _reduce_op_symbolic(reduce_name)
    result = reduce_fn(g, self, dim=dim, keepdim=keepdim)
    if dtype is not None:
        dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        target_type = _type_utils.JitScalarType(dtype).onnx_type()
        result = g.op("Cast", result, to_i=target_type)
    return result
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::conv_tbc")
@symbolic_helper.parse_args("v", "v", "v", "i")
@_beartype.beartype
def conv_tbc(g: jit_utils.GraphContext, input, weight, bias, pad):
    """Export ``aten::conv_tbc`` as Transpose -> ``conv1d`` -> Transpose.

    The input and weight are permuted, ``conv1d`` is applied with stride 1,
    padding ``pad``, dilation 1 and a single group, and the result is
    permuted again with ``[2, 0, 1]``.
    """
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("conv_tbc", input, weight, bias, pad_i=pad)

    # input must have 3 dimensions, see:
    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvolutionTBC.cpp#L8-L10
    # input = (time, batch, in_channels)
    # weight = (kernel_width, in_channels, out_channels)
    # bias = (out_channels,)
    permuted_input = g.op("Transpose", input, perm_i=[1, 2, 0])
    permuted_weight = g.op("Transpose", weight, perm_i=[2, 1, 0])
    conv = conv1d(g, permuted_input, permuted_weight, bias, [1], [pad], [1], 1)
    return g.op("Transpose", conv, perm_i=[2, 0, 1])
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::_unique")
@symbolic_helper.parse_args("v", "i", "i")
@_beartype.beartype
def _unique(g: jit_utils.GraphContext, input, sorted, return_inverse):
    """Export ``aten::_unique``; only available through the Caffe2 ATen fallback.

    Without the fallback the operator is reported as unsupported in ONNX.
    """
    if not symbolic_helper.is_caffe2_aten_fallback():
        return symbolic_helper._onnx_unsupported("_unique", input)

    return g.at(
        "_unique",
        input,
        sorted_i=sorted,
        return_inverse_i=return_inverse,
        outputs=2,
    )
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::_unique2")
@symbolic_helper.parse_args("v", "i", "i", "i")
@_beartype.beartype
def _unique2(g: jit_utils.GraphContext, input, sorted, return_inverse, return_counts):
    """Export ``aten::_unique2``.

    With the Caffe2 ATen fallback a three-output ATen node is emitted.
    Otherwise the operator is reported as requiring opset 11.
    """
    if symbolic_helper.is_caffe2_aten_fallback():
        aten_attrs = {
            "sorted_i": sorted,
            "return_inverse_i": return_inverse,
            "return_counts_i": return_counts,
            "outputs": 3,
        }
        return g.at("_unique2", input, **aten_attrs)

    symbolic_helper._onnx_opset_unsupported("_unique2", 9, 11, input)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::_cast_Byte")
@_deprecation.deprecated(
    "2.0",
    "the future",
    "Avoid using this function and create a Cast node instead",
)
@_beartype.beartype
def _cast_Byte(g: jit_utils.GraphContext, input, non_blocking):
    """Export ``aten::_cast_Byte`` as a ``Cast`` to UINT8 (``non_blocking`` is unused)."""
    target_type = _C_onnx.TensorProtoDataType.UINT8
    return g.op("Cast", input, to_i=target_type)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::_cast_Char")
@_deprecation.deprecated(
    "2.0",
    "the future",
    "Avoid using this function and create a Cast node instead",
)
@_beartype.beartype
def _cast_Char(g: jit_utils.GraphContext, input, non_blocking):
    """Export ``aten::_cast_Char`` as a ``Cast`` to INT8 (``non_blocking`` is unused)."""
    target_type = _C_onnx.TensorProtoDataType.INT8
    return g.op("Cast", input, to_i=target_type)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::_cast_Short")
@_deprecation.deprecated(
    "2.0",
    "the future",
    "Avoid using this function and create a Cast node instead",
)
@_beartype.beartype
def _cast_Short(g: jit_utils.GraphContext, input, non_blocking):
    """Export ``aten::_cast_Short`` as a ``Cast`` to INT16 (``non_blocking`` is unused)."""
    target_type = _C_onnx.TensorProtoDataType.INT16
    return g.op("Cast", input, to_i=target_type)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::_cast_Int")
@_deprecation.deprecated(
    "2.0",
    "the future",
    "Avoid using this function and create a Cast node instead",
)
@_beartype.beartype
def _cast_Int(g: jit_utils.GraphContext, input, non_blocking):
    """Export ``aten::_cast_Int`` as a ``Cast`` to INT32 (``non_blocking`` is unused)."""
    target_type = _C_onnx.TensorProtoDataType.INT32
    return g.op("Cast", input, to_i=target_type)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::_cast_Long")
@_deprecation.deprecated(
    "2.0",
    "the future",
    "Avoid using this function and create a Cast node instead",
)
@_beartype.beartype
def _cast_Long(g: jit_utils.GraphContext, input, non_blocking):
    """Export ``aten::_cast_Long`` as a ``Cast`` to INT64 (``non_blocking`` is unused)."""
    target_type = _C_onnx.TensorProtoDataType.INT64
    return g.op("Cast", input, to_i=target_type)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::_cast_Half")
@_deprecation.deprecated(
    "2.0",
    "the future",
    "Avoid using this function and create a Cast node instead",
)
@_beartype.beartype
def _cast_Half(g: jit_utils.GraphContext, input, non_blocking):
    """Export ``aten::_cast_Half`` as a ``Cast`` to FLOAT16 (``non_blocking`` is unused)."""
    target_type = _C_onnx.TensorProtoDataType.FLOAT16
    return g.op("Cast", input, to_i=target_type)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::_cast_Float")
@_deprecation.deprecated(
    "2.0",
    "the future",
    "Avoid using this function and create a Cast node instead",
)
@_beartype.beartype
def _cast_Float(g: jit_utils.GraphContext, input, non_blocking):
    """Export ``aten::_cast_Float`` as a ``Cast`` to FLOAT (``non_blocking`` is unused)."""
    target_type = _C_onnx.TensorProtoDataType.FLOAT
    return g.op("Cast", input, to_i=target_type)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::_cast_Double")
@_deprecation.deprecated(
    "2.0",
    "the future",
    "Avoid using this function and create a Cast node instead",
)
@_beartype.beartype
def _cast_Double(g: jit_utils.GraphContext, input, non_blocking):
    """Export ``aten::_cast_Double`` as a ``Cast`` to DOUBLE (``non_blocking`` is unused)."""
    target_type = _C_onnx.TensorProtoDataType.DOUBLE
    return g.op("Cast", input, to_i=target_type)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::_cast_Bool")
@_deprecation.deprecated(
    "2.0",
    "the future",
    "Avoid using this function and create a Cast node instead",
)
@_beartype.beartype
def _cast_Bool(g: jit_utils.GraphContext, input, non_blocking):
    """Export ``aten::_cast_Bool`` as a ``Cast`` to BOOL (``non_blocking`` is unused)."""
    target_type = _C_onnx.TensorProtoDataType.BOOL
    return g.op("Cast", input, to_i=target_type)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::empty")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
@_beartype.beartype
def empty(
    g: jit_utils.GraphContext,
    sizes,
    dtype,
    layout,
    device,
    pin_memory=False,
    memory_format=None,
):
    """Export ``aten::empty`` by delegating to :func:`zeros`.

    The exported tensor is therefore zero-filled. ``memory_format`` is
    accepted but not forwarded.
    """
    # Positional call: ``zeros`` is wrapped by ``parse_args``.
    zero_filled = zeros(g, sizes, dtype, layout, device, pin_memory)
    return zero_filled
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::empty_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
@_beartype.beartype
def empty_like(
    g: jit_utils.GraphContext,
    input,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=False,
    memory_format=None,
):
    """Export ``aten::empty_like`` by delegating to :func:`zeros_like`.

    The exported tensor is therefore zero-filled. ``memory_format`` is
    accepted but not forwarded.
    """
    # Positional call: ``zeros_like`` is wrapped by ``parse_args``.
    zero_filled = zeros_like(g, input, dtype, layout, device, pin_memory)
    return zero_filled
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::new_empty")
@_beartype.beartype
def new_empty(
    g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False
):
    """Export ``aten::new_empty`` through :func:`empty`.

    When ``dtype`` is a None value and the scalar type of ``self`` can be
    determined, that scalar type is used as the dtype.
    """
    inherited_dtype = symbolic_helper._try_get_scalar_type(self)
    if inherited_dtype is not None and symbolic_helper._is_none(dtype):
        dtype = inherited_dtype
    return empty(g, sizes, dtype, layout, device, pin_memory)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::scalar_tensor")
@_beartype.beartype
def scalar_tensor(g: jit_utils.GraphContext, scalar, dtype, *options):
    """Export ``aten::scalar_tensor`` as a ``Cast`` of ``scalar``.

    ``dtype`` must be a constant; when it resolves to None the cast target is
    FLOAT. Any further ``options`` are ignored.
    """
    resolved = symbolic_helper._get_const(dtype, "i", "dtype")
    if resolved is None:
        resolved = _type_utils.JitScalarType.FLOAT
    onnx_type = _type_utils.JitScalarType(resolved).onnx_type()
    return g.op("Cast", scalar, to_i=onnx_type)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::tensor")
@_beartype.beartype
def tensor(
    g: jit_utils.GraphContext, data, dtype=None, device=None, requires_grad=False
):
    """Export ``aten::tensor``.

    A packed list is unpacked; every element is reshaped to shape ``[1]``,
    cast to the target dtype and the pieces are concatenated along axis 0.
    Any other input is cast directly, after a ``ConcatFromSequence`` when it
    is a tensor list or scalar list. If ``dtype`` is None it is taken from
    the data (from the first element for a packed list). ``device`` and
    ``requires_grad`` are not used.
    """
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")

    if not symbolic_helper._is_packed_list(data):
        if dtype is None:
            dtype = _type_utils.JitScalarType.from_value(data)
        is_sequence = symbolic_helper._is_list(data) and (
            symbolic_helper._is_tensor_list(data)
            or symbolic_helper._is_scalar_list(data)
        )
        if is_sequence:
            data = g.op("ConcatFromSequence", data, axis_i=0, new_axis_i=1)
        return g.op("Cast", data, to_i=_type_utils.JitScalarType(dtype).onnx_type())

    elements = symbolic_helper._unpack_list(data)
    if dtype is None:
        dtype = _type_utils.JitScalarType.from_value(elements[0])

    pieces = []
    for element in elements:
        # A fresh shape constant is emitted per element, as in the original graph.
        one_elem_shape = g.op("Constant", value_t=torch.LongTensor([1]))
        flattened = symbolic_helper._reshape_helper(g, element, one_elem_shape)
        casted = g.op(
            "Cast", flattened, to_i=_type_utils.JitScalarType(dtype).onnx_type()
        )
        pieces.append(casted)
    return g.op("Concat", *pieces, axis_i=0)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::as_tensor")
@_beartype.beartype
def as_tensor(g: jit_utils.GraphContext, data, dtype=None, device=None):
    """Export ``aten::as_tensor`` by forwarding to :func:`tensor`."""
    converted = tensor(g, data, dtype, device)
    return converted
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::zeros")
@symbolic_helper.parse_args("v", "i", "v", "v", "v")
@_beartype.beartype
def zeros(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False):
    """Export ``aten::zeros`` as ``ConstantOfShape`` with a zero fill value.

    The fill dtype defaults to FLOAT when ``dtype`` is None. A constant empty
    ``sizes`` list is replaced by an explicit empty int64 ``Constant``.
    """
    # NOTE: no way to set device, layout and pin_memory in ONNX, so we ignore it
    scalar_type = (
        _type_utils.JitScalarType.FLOAT
        if dtype is None
        else _type_utils.JitScalarType(dtype)
    )
    const_sizes = symbolic_helper._maybe_get_const(sizes, "is")
    if isinstance(const_sizes, list) and not const_sizes:
        sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64))
    fill = torch.tensor([0], dtype=scalar_type.dtype())
    return g.op("ConstantOfShape", sizes, value_t=fill)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::zeros_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
@_beartype.beartype
def zeros_like(
    g: jit_utils.GraphContext,
    input,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=False,
    memory_format=None,
):
    """Export ``aten::zeros_like`` as ``ConstantOfShape(Shape(input))`` filled with 0.

    When ``dtype`` is None the scalar type is taken from ``input``, falling
    back to FLOAT. ``layout``, ``device``, ``pin_memory`` and
    ``memory_format`` are not used.
    """
    shape = g.op("Shape", input)
    if not symbolic_helper._is_none(dtype):
        scalar_type = _type_utils.JitScalarType(dtype)
    else:
        scalar_type = _type_utils.JitScalarType.from_value(
            input, _type_utils.JitScalarType.FLOAT
        )
    fill = torch.tensor([0], dtype=scalar_type.dtype())
    return g.op("ConstantOfShape", shape, value_t=fill)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::new_zeros")
@_beartype.beartype
def new_zeros(
    g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False
):
    """Export ``aten::new_zeros`` through :func:`zeros`.

    When ``dtype`` is a None value and the scalar type of ``self`` can be
    determined, that scalar type is used as the dtype.
    """
    inherited_dtype = symbolic_helper._try_get_scalar_type(self)
    if inherited_dtype is not None and symbolic_helper._is_none(dtype):
        dtype = inherited_dtype
    return zeros(g, sizes, dtype, layout, device, pin_memory)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::zero")
@_beartype.beartype
def zero(g: jit_utils.GraphContext, self):
    """Export ``aten::zero`` as :func:`zeros_like` of ``self`` with its own scalar type."""
    own_dtype = symbolic_helper._try_get_scalar_type(self)
    zeroed = zeros_like(g, self, own_dtype)
    return zeroed
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::ones")
@symbolic_helper.parse_args("v", "i", "v", "v", "v")
@_beartype.beartype
def ones(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False):
    """Export ``aten::ones`` as ``ConstantOfShape`` with a fill value of one.

    The fill dtype defaults to FLOAT when ``dtype`` is None. A constant empty
    ``sizes`` list is replaced by an explicit empty int64 ``Constant``.
    ``layout``, ``device`` and ``pin_memory`` are not used.
    """
    scalar_type = (
        _type_utils.JitScalarType.FLOAT
        if dtype is None
        else _type_utils.JitScalarType(dtype)
    )
    const_sizes = symbolic_helper._maybe_get_const(sizes, "is")
    if isinstance(const_sizes, list) and not const_sizes:
        sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64))
    fill = torch.tensor([1], dtype=scalar_type.dtype())
    return g.op("ConstantOfShape", sizes, value_t=fill)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::ones_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
@_beartype.beartype
def ones_like(
    g: jit_utils.GraphContext,
    input,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=False,
    memory_format=None,
):
    """Export ``aten::ones_like`` as ``ConstantOfShape(Shape(input))`` filled with 1.

    When ``dtype`` is None the scalar type is taken from ``input``, falling
    back to FLOAT. ``layout``, ``device``, ``pin_memory`` and
    ``memory_format`` are not used.
    """
    shape = g.op("Shape", input)
    if not symbolic_helper._is_none(dtype):
        scalar_type = _type_utils.JitScalarType(dtype)
    else:
        scalar_type = _type_utils.JitScalarType.from_value(
            input, _type_utils.JitScalarType.FLOAT
        )
    fill = torch.tensor([1], dtype=scalar_type.dtype())
    return g.op("ConstantOfShape", shape, value_t=fill)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::new_ones")
@_beartype.beartype
def new_ones(
    g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False
):
    """Export ``aten::new_ones`` through :func:`ones`.

    When ``dtype`` is a None value and the scalar type of ``self`` can be
    determined, that scalar type is used as the dtype.
    """
    inherited_dtype = symbolic_helper._try_get_scalar_type(self)
    if inherited_dtype is not None and symbolic_helper._is_none(dtype):
        dtype = inherited_dtype
    return ones(g, sizes, dtype, layout, device, pin_memory)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::full")
@_beartype.beartype
def full(
    g: jit_utils.GraphContext, sizes, value, dtype, layout, device, pin_memory=False
):
    """Export ``aten::full``.

    With a constant fill ``value`` a single ``ConstantOfShape`` is emitted
    (dtype defaults to FLOAT). Otherwise a zero tensor is built with
    :func:`zeros` and the dynamic value is added to it. ``pin_memory`` is not
    used.
    """
    const_value = symbolic_helper._maybe_get_const(value, "t")

    if not symbolic_helper._is_value(const_value):
        # Constant fill value: bake it into the ConstantOfShape attribute.
        dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        scalar_type = (
            _type_utils.JitScalarType.FLOAT
            if dtype is None
            else _type_utils.JitScalarType(dtype)
        )
        const_sizes = symbolic_helper._maybe_get_const(sizes, "is")
        if isinstance(const_sizes, list) and not const_sizes:
            sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64))
        fill = const_value.view(1).to(scalar_type.dtype())
        return g.op("ConstantOfShape", sizes, value_t=fill)

    # Dynamic fill value: zeros + value.
    if dtype is None:
        dtype = _type_utils.JitScalarType.FLOAT
    base = zeros(g, sizes, dtype, layout, device)
    alpha = g.op("Constant", value_t=torch.tensor(1))
    return add(g, base, value, alpha)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::full_like")
@_beartype.beartype
def full_like(
    g: jit_utils.GraphContext,
    input,
    fill_value,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=False,
    memory_format=None,
):
    """Export ``aten::full_like``.

    With a constant ``fill_value`` a ``ConstantOfShape`` over ``Shape(input)``
    is emitted. Otherwise a zero tensor from :func:`zeros_like` is added to
    the fill value cast to the target type. When ``dtype`` is None the scalar
    type comes from ``input``, falling back to FLOAT. ``pin_memory`` and
    ``memory_format`` are not used.
    """
    fill_value = symbolic_helper._maybe_get_const(fill_value, "f")
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    if dtype is not None:
        scalar_type = _type_utils.JitScalarType(dtype)
    else:
        scalar_type = _type_utils.JitScalarType.from_value(
            input, _type_utils.JitScalarType.FLOAT
        )

    if not symbolic_helper._is_value(fill_value):
        # Constant fill value: bake it into the ConstantOfShape attribute.
        shape = g.op("Shape", input)
        fill = torch.tensor([fill_value], dtype=scalar_type.dtype())
        return g.op("ConstantOfShape", shape, value_t=fill)

    # Dynamic fill value: zeros_like + cast(value).
    base = zeros_like(g, input, dtype, layout, device)
    fill_value = g.op("Cast", fill_value, to_i=scalar_type.onnx_type())
    alpha = g.op("Constant", value_t=torch.tensor(1))
    return add(g, base, fill_value, alpha)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::new_full")
@_beartype.beartype
def new_full(
    g: jit_utils.GraphContext,
    self,
    size,
    fill_value,
    dtype,
    layout,
    device,
    pin_memory=False,
):
    """Export ``aten::new_full`` through :func:`full`.

    When ``dtype`` is a None value and the scalar type of ``self`` can be
    determined, that scalar type is used as the dtype.
    """
    inherited_dtype = symbolic_helper._try_get_scalar_type(self)
    if inherited_dtype is not None and symbolic_helper._is_none(dtype):
        dtype = inherited_dtype
    return full(g, size, fill_value, dtype, layout, device, pin_memory)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::eye")
@_beartype.beartype
def eye(g: jit_utils.GraphContext, *args):
    """Export ``aten::eye`` as ``EyeLike`` applied to a zero matrix.

    Accepts the 5-argument form ``(n, dtype, layout, device, pin_memory)``
    and the 6-argument form ``(n, m, dtype, layout, device, pin_memory)``;
    any other arity is reported through ``_unimplemented``.
    """
    arg_count = len(args)
    if arg_count == 5:
        # aten::eye(n, dtype, layout, device, pin_memory)
        n, dtype, layout, device, pin_memory = args
        side = symbolic_helper._unsqueeze_helper(g, n, [0])
        extents = (side, side)
    elif arg_count == 6:
        # aten::eye(n, m, dtype, layout, device, pin_memory)
        n, m, dtype, layout, device, pin_memory = args
        extents = (
            symbolic_helper._unsqueeze_helper(g, n, [0]),
            symbolic_helper._unsqueeze_helper(g, m, [0]),
        )
    else:
        return symbolic_helper._unimplemented(
            "aten::eye", f"with {len(args)} arguments"
        )

    shape = g.op("Concat", *extents, axis_i=0)
    zero_matrix = zeros(g, shape, dtype, layout, device)
    return g.op("EyeLike", zero_matrix)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::slice")
@_beartype.beartype
def slice(g: jit_utils.GraphContext, self, *args):
    """Export ``aten::slice`` for the tensor (4-arg) and list (3-arg) overloads.

    For the tensor overload only ``step == 1`` is accepted. When ``start``,
    ``end`` and ``dim`` are all static (None or ``onnx::Constant``) a static
    slice is emitted via ``_slice_helper``; otherwise ``DynamicSlice`` is
    emitted, except under the plain ONNX export type, which raises.

    Raises:
        errors.SymbolicValueError: for ``step != 1``, or for dynamic inputs
            when exporting with ``OperatorExportTypes.ONNX``.
    """

    def _is_none_constant(value) -> bool:
        # A ``prim::Constant`` node whose type is NoneType stands for Python None.
        return value.node().kind() == "prim::Constant" and isinstance(
            value.type(), _C.NoneType
        )

    arg_count = len(args)

    if arg_count == 4:
        # aten::slice(Tensor self, int dim, int start, int end, int step) -> Tensor
        dim, start, end, step = args
        step = symbolic_helper._parse_arg(step, "i")
        if step != 1:
            raise errors.SymbolicValueError("step!=1 is currently not supported", self)

        start_is_none = _is_none_constant(start)
        end_is_none = _is_none_constant(end)
        start_is_static = start_is_none or start.node().kind() == "onnx::Constant"
        end_is_static = end_is_none or end.node().kind() == "onnx::Constant"

        if (
            start_is_static
            and end_is_static
            and dim.node().kind() == "onnx::Constant"
        ):
            start = 0 if start_is_none else symbolic_helper._parse_arg(start, "i")
            end = (
                _constants.INT64_MAX
                if end_is_none
                else symbolic_helper._parse_arg(end, "i")
            )
            dim = symbolic_helper._parse_arg(dim, "i")
            return symbolic_helper._slice_helper(
                g, self, axes=[dim], starts=[start], ends=[end]
            )

        if GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX:
            raise errors.SymbolicValueError(
                "Unsupported: ONNX export of Slice with dynamic inputs. DynamicSlice "
                "is a deprecated experimental op. Please use statically allocated "
                "variables or export to a higher opset version.",
                self,
            )

        start_unsqueezed = symbolic_helper._unsqueeze_helper(g, start, [0])
        end_unsqueezed = symbolic_helper._unsqueeze_helper(g, end, [0])
        dim_unsqueezed = symbolic_helper._unsqueeze_helper(g, dim, [0])
        return g.op(
            "DynamicSlice",
            self,
            start_unsqueezed,
            end_unsqueezed,
            dim_unsqueezed,
        )

    if arg_count == 3:
        # aten::slice(t[] l, int start, int end, int step) -> t[]
        start, end, step = args
        dim = 0
        start_is_none = _is_none_constant(start)
        end_is_none = _is_none_constant(end)
        start = 0 if start_is_none else symbolic_helper._parse_arg(start, "i")
        end = (
            _constants.INT64_MAX
            if end_is_none
            else symbolic_helper._parse_arg(end, "i")
        )
        return symbolic_helper._slice_helper(
            g, self, axes=[dim], starts=[start], ends=[end]
        )

    return symbolic_helper._unimplemented("aten::slice", f"with {len(args)} arguments")
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::hardtanh")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "f", "f")
@_beartype.beartype
def hardtanh(g: jit_utils.GraphContext, self: _C.Value, min_val: float, max_val: float):
    """Export ``aten::hardtanh`` as a ``Clip`` between ``min_val`` and ``max_val``."""
    clipped = _op_with_optional_float_cast(
        g, "Clip", self, min_f=min_val, max_f=max_val, opset_before=12
    )
    return clipped
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::hardswish")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v")
@_beartype.beartype
def hardswish(g: jit_utils.GraphContext, self):
    """Export ``aten::hardswish`` as ``self * hardsigmoid(self)``."""
    gate = hardsigmoid(g, self)
    return g.op("Mul", self, gate)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::hardsigmoid")
# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qhardsigmoid.cpp
@symbolic_helper.quantized_args(True, scale=1.0 / 256.0, zero_point=0)
@symbolic_helper.parse_args("v")
@_beartype.beartype
def hardsigmoid(g: jit_utils.GraphContext, self):
    """Export ``aten::hardsigmoid`` as ONNX ``HardSigmoid`` with ``alpha = 1/6``."""
    # alpha is set to 1 / 6 so the op matches PyTorch's definition of Hardsigmoid.
    # See https://pytorch.org/docs/stable/generated/torch.nn.Hardsigmoid.html
    alpha = 1 / 6
    return g.op("HardSigmoid", self, alpha_f=alpha)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::tanhshrink")
@symbolic_helper.parse_args("v")
@_beartype.beartype
def tanhshrink(g: jit_utils.GraphContext, self):
    """Export ``aten::tanhshrink`` as ``self - tanh(self)``."""
    squashed = tanh(g, self)
    return g.op("Sub", self, squashed)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::hardshrink")
@symbolic_helper.parse_args("v", "f")
@_beartype.beartype
def hardshrink(g: jit_utils.GraphContext, self, lambd):
    """Export ``aten::hardshrink``.

    Keeps ``self`` where ``self > lambd`` or ``self < -lambd`` and emits 0
    elsewhere. Constants use the scalar type of ``self``, falling back to
    FLOAT.
    """
    scalar_type = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.FLOAT
    )
    const_dtype = scalar_type.dtype()
    threshold = g.op("Constant", value_t=torch.tensor(lambd, dtype=const_dtype))

    # Ops are emitted in the same order as a single nested expression would.
    above = gt(g, self, threshold)
    neg_threshold = neg(g, threshold)
    below = lt(g, self, neg_threshold)
    outside_band = logical_or(g, above, below)

    zero = g.op("Constant", value_t=torch.tensor(0, dtype=scalar_type.dtype()))
    return g.op("Where", outside_band, self, zero)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::softshrink")
@symbolic_helper.parse_args("v", "f")
@_beartype.beartype
def softshrink(g: jit_utils.GraphContext, self, lambd):
    """Export ``aten::softshrink``.

    Emits ``self - lambd`` where ``self > lambd``, ``self + lambd`` where
    ``self < -lambd`` and 0 elsewhere, as the sum of two ``Where`` results.
    Constants use the scalar type of ``self``, falling back to FLOAT.
    """
    scalar_type = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.FLOAT
    )
    const_dtype = scalar_type.dtype()
    threshold = g.op("Constant", value_t=torch.tensor(lambd, dtype=const_dtype))

    # Upper side: self - lambd where self > lambd, else 0.
    above = gt(g, self, threshold)
    shifted_down = sub(g, self, threshold)
    upper_zero = g.op("Constant", value_t=torch.tensor(0, dtype=scalar_type.dtype()))
    upper_part = g.op("Where", above, shifted_down, upper_zero)

    # Lower side: self + lambd where self < -lambd, else 0.
    neg_threshold = neg(g, threshold)
    below = lt(g, self, neg_threshold)
    shifted_up = add(g, self, threshold)
    lower_zero = g.op("Constant", value_t=torch.tensor(0, dtype=scalar_type.dtype()))
    lower_part = g.op("Where", below, shifted_up, lower_zero)

    return add(g, upper_part, lower_part)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::alias")
@_beartype.beartype
def alias(g: jit_utils.GraphContext, self):
    """Export ``aten::alias`` as a no-op: the input value is returned unchanged."""
    aliased = self
    return aliased
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::unsqueeze")
@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def unsqueeze(g: jit_utils.GraphContext, self, dim):
    """Export ``aten::unsqueeze``.

    A negative ``dim`` is converted to ``dim + rank + 1`` using the rank known
    at export time (with a warning); when the rank is unknown the case is
    reported through ``_unimplemented``.
    """
    if dim < 0:
        rank = symbolic_helper._get_tensor_rank(self)
        if rank is None:
            return symbolic_helper._unimplemented(
                "unsqueeze", "negative axis with unknown input rank", self
            )
        converted = dim + rank + 1
        warnings.warn(
            f"ONNX export unsqueeze with negative axis {dim}"
            " might cause the onnx model to be incorrect. "
            "Negative axis is not supported in ONNX. "
            f"Axis is converted to {converted}"
            " based on input shape at export time. "
            "Passing an tensor of different rank in execution will be incorrect."
        )
        dim = converted

    return symbolic_helper._unsqueeze_helper(g, self, axes_i=[dim])
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::sort")
# TODO(justinchuby): Support multiple quantized args in output
@symbolic_helper.parse_args("v", "i", "i", "none")
@_beartype.beartype
def sort(g: jit_utils.GraphContext, self, dim, decending, out=None):
    """Export ``aten::sort`` as a ``TopK`` over the full extent of ``dim``.

    Args:
        g: Graph context the nodes are added to.
        self: Tensor value to sort.
        dim: Axis to sort along; its size must be known at export time.
        decending: Accepted for signature compatibility; this symbolic does
            not consult it.
        out: Must be None; otherwise the case is reported through
            ``_unimplemented``.

    Returns:
        The two outputs of the emitted ``TopK`` node, or the result of
        ``_unimplemented`` when the size of ``dim`` is not accessible.
    """
    if out is not None:
        symbolic_helper._unimplemented(
            "Sort", "Out parameter is not supported for sort", self
        )

    self_sizes = symbolic_helper._get_tensor_sizes(self)
    try:
        dim_size = self_sizes[dim]
    except (TypeError, IndexError):
        # TypeError: the size list is not subscriptable (e.g. it is None) or
        # ``dim`` is not an integer; IndexError: ``dim`` is outside the rank.
        # Anything else is a genuine bug and should propagate.
        dim_size = None

    if dim_size is None:
        return symbolic_helper._unimplemented("Sort", "input size not accessible", self)

    return g.op("TopK", self, k_i=dim_size, axis_i=dim, outputs=2)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::numel")
@_beartype.beartype
def numel(g: jit_utils.GraphContext, self):
    """Export ``aten::numel`` as ``ReduceProd`` over ``Shape(self)`` without kept dims."""
    extents = g.op("Shape", self)
    element_count = g.op("ReduceProd", extents, keepdims_i=0)
    return element_count
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::topk")
# TODO(justinchuby): Support multiple quantized args in output
@symbolic_helper.parse_args("v", "i", "i", "i", "i", "none")
@_beartype.beartype
def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
    """Export ``aten::topk`` as an ONNX ``TopK`` with two outputs.

    A non-None ``out`` and ``largest == False`` are each reported through
    ``_unimplemented``. ``sorted`` is not consulted.
    """
    unsupported_cases = (
        (out is not None, "Out parameter is not supported for topk"),
        (not largest, "Ascending TopK is not supported"),
    )
    for is_requested, reason in unsupported_cases:
        if is_requested:
            symbolic_helper._unimplemented("TopK", reason, self)

    return g.op("TopK", self, k_i=k, axis_i=dim, outputs=2)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("prim::convert_element_type")
@_beartype.beartype
def convert_element_type(g: jit_utils.GraphContext, self, *args):
    """Export ``prim::convert_element_type`` as a ``Cast`` to the constant dtype in ``args[0]``."""
    target_dtype = symbolic_helper._get_const(args[0], "i", "dtype")
    onnx_type = _type_utils.JitScalarType(target_dtype).onnx_type()
    return g.op("Cast", self, to_i=onnx_type)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::to")
@_beartype.beartype
def to(g: jit_utils.GraphContext, self, *args):
    """Export ``aten::to`` as a ``Cast``, or as a no-op for device-only calls.

    The overload is identified by the number of extra arguments (4 to 7).
    Device-only overloads return ``self`` unchanged; layout, device and
    memory_format arguments are ignored; an unrecognised arity is reported
    through ``_onnx_unsupported``.
    """

    @_beartype.beartype
    def is_aten_to_device_only(args):
        arg_count = len(args)
        if arg_count == 4:
            # aten::to(Tensor, Device, bool, bool, memory_format)
            first = args[0]
            return (
                first.node().kind() == "prim::device"
                or first.type().isSubtypeOf(_C.ListType.ofInts())
                or isinstance(first.type(), _C.DeviceObjType)
            )
        if arg_count == 5:
            # aten::to(Tensor, Device, ScalarType, bool, bool, memory_format)
            # When dtype is None, this is a aten::to(device) call
            return symbolic_helper._get_const(args[1], "i", "dtype") is None
        if arg_count in (6, 7):
            # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format) -> Tensor
            # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format) -> Tensor
            # When dtype is None, this is a aten::to(device) call
            return symbolic_helper._get_const(args[0], "i", "dtype") is None
        return False

    # ONNX doesn't have a concept of a device, so we ignore device-only casts
    if is_aten_to_device_only(args):
        return self

    arg_count = len(args)

    if arg_count == 4:
        # TestONNXRuntime::test_ones_bool shows args[0] of aten::to() can be onnx::Constant[value=<Tensor>]()
        # In this case, the constant value is a tensor not int,
        # so symbolic_helper._maybe_get_const(args[0], 'i') would not work.
        first = args[0]
        dtype = first
        if symbolic_helper._is_value(first) and first.node().kind() == "onnx::Constant":
            tval = symbolic_helper._node_get(first.node(), "value")
            if isinstance(tval, torch.Tensor):
                # A 0-d constant carries the dtype as an integer code.
                dtype = int(tval.item()) if len(tval.shape) == 0 else tval

        if symbolic_helper._is_value(dtype) or isinstance(dtype, torch.Tensor):
            # aten::to(Tensor, Tensor, bool, bool, memory_format)
            scalar_type = _type_utils.JitScalarType.from_value(first)
        else:
            # aten::to(Tensor, ScalarType, bool, bool, memory_format)
            # memory_format is ignored
            scalar_type = _type_utils.JitScalarType(dtype)
        return g.op("Cast", self, to_i=scalar_type.onnx_type())

    if arg_count == 5:
        # aten::to(Tensor, Device, ScalarType, bool, bool, memory_format)
        dtype_arg = args[1]
    elif arg_count in (6, 7):
        # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format) -> Tensor
        # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format) -> Tensor
        dtype_arg = args[0]
    else:
        return symbolic_helper._onnx_unsupported("Unknown aten::to signature", self)

    # Layout, device and memory_format are ignored
    dtype = symbolic_helper._get_const(dtype_arg, "i", "dtype")
    return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type())
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::repeat")
@_beartype.beartype
def repeat(g: jit_utils.GraphContext, self, repeats):
    """Export ``aten::repeat`` as ``Expand`` followed by ``Tile``.

    ``self`` is first expanded against an int64 ones tensor shaped like
    ``repeats`` and the result is tiled by ``repeats``.
    """
    ones_shape = ones_like(g, repeats, _type_utils.JitScalarType.INT64)
    expanded = g.op("Expand", self, ones_shape)
    return g.op("Tile", expanded, repeats)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::repeat_interleave")
@_beartype.beartype
def repeat_interleave(
    g: jit_utils.GraphContext, self, repeats, dim=None, output_size=None
):
    """Export ``aten::repeat_interleave``.

    A scalar or single-element ``repeats`` is handled by
    ``_repeat_interleave_single_value_repeat_helper``. A 1-d ``repeats`` is
    handled by splitting ``repeats`` and ``self``, expanding each input slice
    by its repeat count, and concatenating the results along ``dim``.
    ``output_size`` is not used.

    Raises:
        errors.SymbolicValueError: when the rank/sizes of ``repeats`` or the
            sizes of ``self`` are unknown, or ``repeats`` has rank above 1.
    """
    repeats_dim = symbolic_helper._get_tensor_rank(repeats)
    repeats_sizes = symbolic_helper._get_tensor_sizes(repeats)
    input_sizes = symbolic_helper._get_tensor_sizes(self)

    required_static_info = (
        (repeats_dim, "repeats rank"),
        (repeats_sizes, "repeats size"),
        (input_sizes, "input size"),
    )
    for static_info, description in required_static_info:
        if static_info is None:
            raise errors.SymbolicValueError(
                f"Unsupported: ONNX export of repeat_interleave for unknown {description}.",
                self,
            )

    # if dim is None flatten
    # By default, use the flattened input array, and return a flat output array
    if symbolic_helper._is_none(dim):
        self = symbolic_helper._reshape_helper(
            g, self, g.op("Constant", value_t=torch.tensor([-1]))
        )
        dim = torch.tensor(0, dtype=torch.int64)
    else:
        dim = symbolic_helper._maybe_get_scalar(dim)

    # Handle cases where dim is negative
    if dim < 0:
        dim += len(input_sizes)

    # Unknown extents become 0 in ``input_sizes`` and -1 in the working copy.
    input_sizes_temp = list(input_sizes)
    for idx, extent in enumerate(input_sizes):
        if extent is None:
            input_sizes[idx] = 0
            input_sizes_temp[idx] = -1

    if repeats_dim not in (0, 1):
        raise errors.SymbolicValueError("repeats must be 0-dim or 1-dim tensor", self)

    if input_sizes[dim] == 0:
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "repeat_interleave",
            9,
            13,
            "Unsupported along dimension with unknown input size",
            self,
        )

    # Cases where repeats is an int or single value tensor
    if repeats_dim == 0 or repeats_sizes[0] == 1:
        return symbolic_helper._repeat_interleave_single_value_repeat_helper(
            g, self, repeats, dim
        )

    # Cases where repeats is a 1 dim Tensor
    if repeats_sizes[0] is None:
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "repeat_interleave",
            9,
            13,
            "Unsupported for cases with dynamic repeats",
            self,
        )
    assert (
        repeats_sizes[0] == input_sizes[dim]
    ), "repeats must have the same size as input along dim"
    reps = repeats_sizes[0]

    r_splits = symbolic_helper._repeat_interleave_split_helper(g, repeats, reps, 0)
    i_splits = symbolic_helper._repeat_interleave_split_helper(g, self, reps, dim)
    input_sizes[dim] = -1
    input_sizes_temp[dim] = 1

    final_splits = []
    for idx, r_split in enumerate(r_splits):
        i_split = unsqueeze(g, i_splits[idx], dim + 1)
        leading = g.op(
            "Constant", value_t=torch.LongTensor(input_sizes_temp[: dim + 1])
        )
        trailing = g.op(
            "Constant", value_t=torch.LongTensor(input_sizes_temp[dim + 1 :])
        )
        expand_shape = g.op("Concat", leading, r_split, trailing, axis_i=0)
        i_split = expand(g, i_split, expand_shape, None)
        i_split = symbolic_helper._reshape_helper(
            g,
            i_split,
            g.op("Constant", value_t=torch.LongTensor(input_sizes)),
            allowzero=0,
        )
        final_splits.append(i_split)
    return g.op("Concat", *final_splits, axis_i=dim)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::pixel_shuffle")
@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def pixel_shuffle(g: jit_utils.GraphContext, self, upscale_factor):
    """Export ``aten::pixel_shuffle`` as Reshape / Transpose / Reshape.

    Args:
        g: Graph context the nodes are added to.
        self: Input value; only 4d inputs are supported.
        upscale_factor: Integer factor moved from the channel axis into the
            two trailing axes.

    Returns:
        The reshaped output value, or the result of ``_unimplemented`` when
        the input sizes are unavailable or the input is not 4d.
    """
    dims = symbolic_helper._get_tensor_sizes(self)
    # A missing size list is reported like any other non-4d input instead of
    # failing with a TypeError on ``len(None)``.
    if dims is None or len(dims) != 4:
        return symbolic_helper._unimplemented(
            "pixel_shuffle", "only support 4d input", self
        )

    if any(i is None for i in dims[1:]):
        # Dynamic shapes: 0 copies the input extent, -1 infers it.
        after_view = symbolic_helper._reshape_helper(
            g,
            symbolic_helper._unsqueeze_helper(g, self, [2, 3]),
            g.op(
                "Constant",
                value_t=torch.tensor([0, -1, upscale_factor, upscale_factor, 0, 0]),
            ),
            allowzero=0,
        )
        after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3])
        # For dynamic input shapes, two reshapes are performed
        reshape_h = symbolic_helper._reshape_helper(
            g,
            after_transpose,
            g.op("Constant", value_t=torch.tensor([0, 0, -1, 1, 0, 0])),
            allowzero=0,
        )
        reshape_w = symbolic_helper._reshape_helper(
            g,
            reshape_h,
            g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, 1])),
            allowzero=0,
        )
        return symbolic_helper._squeeze_helper(g, reshape_w, [3, 5])

    # Static shapes: all target extents are computed at export time.
    output_channel = dims[1] // upscale_factor // upscale_factor
    split_shape = [
        -1,
        output_channel,
        upscale_factor,
        upscale_factor,
        dims[2],
        dims[3],
    ]
    after_view = symbolic_helper._reshape_helper(
        g,
        self,
        g.op("Constant", value_t=torch.tensor(split_shape)),
        allowzero=0,
    )
    after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3])
    merged_shape = [
        -1,
        output_channel,
        dims[2] * upscale_factor,
        dims[3] * upscale_factor,
    ]
    return symbolic_helper._reshape_helper(
        g,
        after_transpose,
        g.op("Constant", value_t=torch.tensor(merged_shape)),
        allowzero=0,
    )
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::pixel_unshuffle")
@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def pixel_unshuffle(g: jit_utils.GraphContext, self, downscale_factor):
    """Export ``aten::pixel_unshuffle`` as Reshape / Transpose / Reshape.

    Args:
        g: Graph context the nodes are added to.
        self: Input value; only 4d inputs are supported.
        downscale_factor: Integer factor moved from the two trailing axes
            into the channel axis.

    Returns:
        The reshaped output value, or the result of ``_unimplemented`` when
        the input sizes are unavailable or the input is not 4d.
    """
    dims = symbolic_helper._get_tensor_sizes(self)
    # A missing size list is reported like any other non-4d input instead of
    # failing with a TypeError on ``len(None)``. The message names this
    # operator rather than pixel_shuffle.
    if dims is None or len(dims) != 4:
        return symbolic_helper._unimplemented(
            "pixel_unshuffle", "only support 4d input", self
        )

    if any(i is None for i in dims[1:]):
        # For dynamic input shapes, two reshapes are performed
        reshape_h = symbolic_helper._reshape_helper(
            g,
            symbolic_helper._unsqueeze_helper(g, self, [3]),
            g.op("Constant", value_t=torch.tensor([0, 0, -1, downscale_factor, 0])),
            allowzero=0,
        )
        reshape_w = symbolic_helper._reshape_helper(
            g,
            reshape_h,
            g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, downscale_factor])),
            allowzero=0,
        )
        after_transpose = g.op("Transpose", reshape_w, perm_i=[0, 1, 3, 5, 2, 4])
        final_reshape = symbolic_helper._reshape_helper(
            g,
            after_transpose,
            g.op("Constant", value_t=torch.tensor([0, -1, 1, 1, 0, 0])),
            allowzero=0,
        )
        return symbolic_helper._squeeze_helper(g, final_reshape, [2, 3])

    # Static shapes: all target extents are computed at export time.
    output_channel = dims[1] * downscale_factor * downscale_factor
    split_shape = [
        -1,
        dims[1],
        dims[2] // downscale_factor,
        downscale_factor,
        dims[3] // downscale_factor,
        downscale_factor,
    ]
    after_view = symbolic_helper._reshape_helper(
        g,
        self,
        g.op("Constant", value_t=torch.tensor(split_shape)),
        allowzero=0,
    )
    after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 3, 5, 2, 4])
    merged_shape = [
        -1,
        output_channel,
        dims[2] // downscale_factor,
        dims[3] // downscale_factor,
    ]
    return symbolic_helper._reshape_helper(
        g,
        after_transpose,
        g.op("Constant", value_t=torch.tensor(merged_shape)),
        allowzero=0,
    )
|
||
|
|
||
|
|
||
|
@_beartype.beartype
def _generic_rnn(
    g: jit_utils.GraphContext,
    variant,
    input,
    initial_states,
    all_weights,
    has_biases,
    num_layers,
    dropout,
    train,
    bidirectional,
    batch_first=None,
    batch_sizes=None,
):
    """Shared lowering of RNN / GRU / LSTM to the matching ONNX operator.

    One ONNX ``RNN``/``GRU``/``LSTM`` node is emitted per layer; the output
    of each layer feeds the next one.

    Args:
        g: Graph context used to emit nodes.
        variant: ``"LSTM"``, ``"GRU"`` or a name starting with ``"RNN"``
            whose suffix after the 4th character names the activation
            (looked up case-insensitively, e.g. ``"RNN_TANH"``).
        input: Input sequence value.
        initial_states: ``h0`` for RNN/GRU, an ``(h0, c0)`` pair for LSTM.
        all_weights: Flat sequence of weight values, ``weights_per_layer``
            entries per layer and direction.
        has_biases: Whether each layer/direction carries two bias values.
        num_layers: Number of stacked layers.
        dropout: Dropout probability; non-zero together with ``train`` is
            rejected.
        train: Training-mode flag.
        bidirectional: Truthy for bidirectional networks.
        batch_first: If truthy, input/output are transposed with
            ``perm=[1, 0, 2]`` around the recurrent ops.
        batch_sizes: Optional sequence-lengths value; when ``None`` an
            ``unused(g)`` placeholder is passed to the ONNX op.

    Returns:
        ``(output, h_n)`` for RNN/GRU and ``(output, h_n, c_n)`` for LSTM, or
        the result of ``symbolic_helper._unimplemented`` for unsupported
        configurations.
    """
    warnings.warn(
        "Exporting a model to ONNX with a batch_size other than 1, "
        "with a variable length with "
        + variant
        + " can cause an error "
        "when running the ONNX model with a different batch size. "
        "Make sure to save the model with a batch size of 1, "
        "or define the initial states (h0/c0) as inputs of the model. "
    )

    onnx_activations = (
        "Relu",
        "Tanh",
        "Sigmoid",
        "Affine",
        "LeakyRelu",
        "ThresholdedRelu",
        "ScaledTanh",
        "HardSigmoid",
        "Elu",
        "Softsign",
        "Softplus",
    )
    # Lower-cased name -> canonical ONNX activation name.
    activation_by_lowercase = {name.lower(): name for name in onnx_activations}

    weights_per_layer = 4 if has_biases else 2
    expected_weight_count = num_layers * weights_per_layer * (1 + bidirectional)
    # A mismatching count means that projections are used inside the LSTM,
    # so the user has to be told that this is not supported.
    if variant == "LSTM" and len(all_weights) != expected_weight_count:
        return symbolic_helper._unimplemented("LSTM", "LSTMs with projections", input)
    assert len(all_weights) == expected_weight_count

    # Group the flat weight list into one chunk per (layer, direction).
    layer_weights = [
        all_weights[offset : offset + weights_per_layer]
        for offset in range(0, len(all_weights), weights_per_layer)
    ]
    if batch_first:
        # batch, seq, feat -> seq, batch, feat
        input = g.op("Transpose", input, perm_i=[1, 0, 2])
    if dropout and train:
        return symbolic_helper._unimplemented(
            "RNN/GRU/LSTM", "dropout in training mode", input
        )

    if variant.startswith("RNN"):
        nonlinearity = activation_by_lowercase[variant[4:].lower()]
        variant = "RNN"

    # The hidden size is read from dim 1 of the first hidden-hidden weight.
    hidden_size = symbolic_helper._get_tensor_dim_size(all_weights[1], 1)
    if hidden_size is None:
        return symbolic_helper._unimplemented(
            "RNN/GRU/LSTM", "unknown hidden size", input
        )

    unidirectional = not bidirectional
    prev_output = input

    h_outs = []
    if variant in ("RNN", "GRU"):
        h0 = initial_states
    elif variant == "LSTM":
        h0, c0 = initial_states
        c_outs = []

    sequence_lens = batch_sizes if batch_sizes is not None else unused(g)

    if variant == "GRU":
        # pytorch is reset, input, hidden
        # onnx is    input, reset, hidden
        reform_permutation = [(1, 2), (0, 1), (2, 3)]
    elif variant == "LSTM":
        # pytorch is input, forget, cell, output.
        # onnx is    input, output, forget, cell.
        reform_permutation = [(0, 1), (3, 4), (1, 3)]

    @_beartype.beartype
    def reform_weights(g, w, n, intervals):
        # Slice `w` along axis 0 into [x*n, y*n) blocks, in the order given
        # by `intervals`, and concatenate them back together.
        pieces = []
        for begin, stop in intervals:
            pieces.append(
                symbolic_helper._slice_helper(
                    g, w, axes=[0], starts=[begin * n], ends=[stop * n]
                )
            )
        return g.op("Concat", *pieces, axis_i=0)

    @_beartype.beartype
    def transform_weights_no_bias(layer_index):
        # Returns (weight_ih, weight_hh), each with a new leading axis.
        weights = layer_weights[layer_index]
        if variant == "RNN":
            weight_ih, weight_hh = weights
        elif variant == "GRU" or variant == "LSTM":
            weight_ih, weight_hh = [
                reform_weights(g, w, hidden_size, reform_permutation) for w in weights
            ]
        return (
            symbolic_helper._unsqueeze_helper(g, weight_ih, [0]),  # type: ignore[possibly-undefined]
            symbolic_helper._unsqueeze_helper(g, weight_hh, [0]),  # type: ignore[possibly-undefined]
        )

    @_beartype.beartype
    def transform_weights(layer_index):
        # Returns (weight_ih, weight_hh, bias), each with a new leading axis;
        # the two bias values are concatenated into one.
        weights = layer_weights[layer_index]
        if variant == "RNN":
            weight_ih, weight_hh, bias_ih, bias_hh = weights
        elif variant == "GRU" or variant == "LSTM":
            weight_ih, weight_hh, bias_ih, bias_hh = [
                reform_weights(g, w, hidden_size, reform_permutation) for w in weights
            ]
        bias_concat = g.op("Concat", bias_ih, bias_hh, axis_i=0)  # type: ignore[possibly-undefined]
        return (
            symbolic_helper._unsqueeze_helper(g, weight_ih, [0]),  # type: ignore[possibly-undefined]
            symbolic_helper._unsqueeze_helper(g, weight_hh, [0]),  # type: ignore[possibly-undefined]
            symbolic_helper._unsqueeze_helper(g, bias_concat, [0]),
        )

    @_beartype.beartype
    def retrieve_state(x, start, end):
        # With a single layer the whole state is used as-is.
        if num_layers == 1:
            return x
        return symbolic_helper._slice_helper(
            g, x, axes=[0], starts=[start], ends=[end]
        )

    with_bias = weights_per_layer == 4
    extra_kwargs = {} if unidirectional else {"direction_s": "bidirectional"}

    for i in range(num_layers):
        if unidirectional:
            if with_bias:
                weight_ih, weight_hh, bias_concat = transform_weights(i)
            else:
                weight_ih, weight_hh = transform_weights_no_bias(i)
                bias_concat = unused(g)

            state_indices = i, i + 1
        else:
            # Forward weights live at index 2*i, backward ones at 2*i + 1.
            if with_bias:
                weight_ih_f, weight_hh_f, bias_f = transform_weights(2 * i)
                weight_ih_b, weight_hh_b, bias_b = transform_weights(2 * i + 1)
                bias_concat = g.op("Concat", bias_f, bias_b, axis_i=0)
            else:
                weight_ih_f, weight_hh_f = transform_weights_no_bias(2 * i)
                weight_ih_b, weight_hh_b = transform_weights_no_bias(2 * i + 1)
                bias_concat = unused(g)

            weight_ih = g.op("Concat", weight_ih_f, weight_ih_b, axis_i=0)
            weight_hh = g.op("Concat", weight_hh_f, weight_hh_b, axis_i=0)

            state_indices = 2 * i, 2 * i + 2

        inputs = [
            prev_output,
            weight_ih,
            weight_hh,
            bias_concat,
            sequence_lens,
            retrieve_state(h0, *state_indices),  # type: ignore[possibly-undefined]
        ]
        if variant == "LSTM":
            inputs.append(retrieve_state(c0, *state_indices))  # type: ignore[possibly-undefined]

        if variant == "RNN":
            # One activation name per direction.
            activation = (
                [nonlinearity, nonlinearity]  # type: ignore[possibly-undefined]
                if bidirectional
                else [nonlinearity]  # type: ignore[possibly-undefined]
            )
            prev_output, h_out = g.op(
                "RNN",
                *inputs,
                outputs=2,
                hidden_size_i=hidden_size,
                activations_s=activation,
                **extra_kwargs,
            )
        elif variant == "GRU":
            prev_output, h_out = g.op(
                "GRU",
                *inputs,
                outputs=2,
                hidden_size_i=hidden_size,
                linear_before_reset_i=1,
                **extra_kwargs,
            )
        elif variant == "LSTM":
            prev_output, h_out, c_out = g.op(
                "LSTM", *inputs, outputs=3, hidden_size_i=hidden_size, **extra_kwargs
            )

        if bidirectional:
            # The ONNX RNN/GRU/LSTM produce an output of dimensions
            #   seq_len, num_directions, batch, hidden_size
            # We have to convert to match pytorch's expected
            #   seq_len, batch, num_directions * hidden_size
            # by first moving num_directions before hidden_size with
            # Transpose, and then combining it with hidden_size
            # with Reshape.
            prev_output = g.op("Transpose", prev_output, perm_i=[0, 2, 1, 3])
            prev_output = symbolic_helper._reshape_helper(
                g,
                prev_output,
                g.op("Constant", value_t=torch.LongTensor([0, 0, -1])),
                allowzero=0,
            )
        else:
            prev_output = symbolic_helper._squeeze_helper(g, prev_output, [1])

        h_outs.append(h_out)  # type: ignore[possibly-undefined]
        if variant == "LSTM":
            c_outs.append(c_out)  # type: ignore[possibly-undefined]

    if batch_first:
        # seq, batch, num_directions * hidden_size -> batch, seq, num_directions * hidden_size
        prev_output = g.op("Transpose", prev_output, perm_i=[1, 0, 2])

    # A single layer needs no Concat of the per-layer states.
    h_outs = h_out if num_layers == 1 else g.op("Concat", *h_outs, axis_i=0)  # type: ignore[possibly-undefined]
    if variant in ("RNN", "GRU"):
        return prev_output, h_outs
    elif variant == "LSTM":
        c_outs = c_out if num_layers == 1 else g.op("Concat", *c_outs, axis_i=0)  # type: ignore[possibly-undefined]
        return prev_output, h_outs, c_outs
|
||
|
|
||
|
|
||
|
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "f", "i", "i", "i")
@_beartype.beartype
def _lstm_full(
    g: jit_utils.GraphContext,
    input,
    hidden_v,
    weight_v,
    has_biases,
    num_layers,
    dropout,
    train,
    bidirectional,
    batch_first,
):
    """Lower the ``aten::lstm`` overload that carries ``batch_first``.

    ``hidden_v`` and ``weight_v`` are list values; they are unpacked and the
    call is forwarded to ``_generic_rnn`` with variant ``"LSTM"``.
    """
    hidden = symbolic_helper._unpack_list(hidden_v)
    weight = symbolic_helper._unpack_list(weight_v)
    return _generic_rnn(
        g,
        "LSTM",
        input,
        hidden,
        weight,
        has_biases,
        num_layers,
        dropout,
        train,
        bidirectional,
        batch_first,
    )
|
||
|
|
||
|
|
||
|
@symbolic_helper.parse_args("v", "v", "v", "v", "i", "i", "f", "i", "i")
@_beartype.beartype
def _lstm_packed(
    g: jit_utils.GraphContext,
    input,
    batch_sizes,
    hidden_v,
    weight_v,
    has_biases,
    num_layers,
    dropout,
    train,
    bidirectional,
):
    """Lower the ``aten::lstm`` overload that carries ``batch_sizes``.

    The list values ``hidden_v`` and ``weight_v`` are unpacked and the call
    is forwarded to ``_generic_rnn`` with variant ``"LSTM"``; ``batch_sizes``
    is passed by keyword and ``batch_first`` keeps its default.
    """
    hidden = symbolic_helper._unpack_list(hidden_v)
    weight = symbolic_helper._unpack_list(weight_v)
    return _generic_rnn(
        g,
        "LSTM",
        input,
        hidden,
        weight,
        has_biases,
        num_layers,
        dropout,
        train,
        bidirectional,
        batch_sizes=batch_sizes,
    )
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::lstm")
@_beartype.beartype
def lstm(g: jit_utils.GraphContext, *args):
    """Dispatch ``aten::lstm`` to the packed or the full lowering.

    The overload is chosen by inspecting ``args[3]``: when it is a tensor
    list the packed variant is used, otherwise the full one.
    """
    handler = _lstm_packed if symbolic_helper._is_tensor_list(args[3]) else _lstm_full
    return handler(g, *args)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::lstm_cell")
@_beartype.beartype
def lstm_cell(g: jit_utils.GraphContext, self, hidden, w_ih, w_hh, b_ih, b_hh):
    """Lower ``aten::lstm_cell`` as a one-layer, unidirectional LSTM.

    A leading axis is added to the input and to each hidden state before
    calling ``_generic_rnn``; the same axis is squeezed from the returned
    hidden and cell states. Biases are used only when ``b_ih`` is a tensor.

    Returns:
        A ``(h, c)`` pair of values.
    """
    input = symbolic_helper._unsqueeze_helper(g, self, [0])
    hidden = [
        symbolic_helper._unsqueeze_helper(g, state, [0])
        for state in symbolic_helper._unpack_list(hidden)
    ]
    has_biases = bool(symbolic_helper._is_tensor(b_ih))
    weight = (w_ih, w_hh, b_ih, b_hh) if has_biases else (w_ih, w_hh)
    _, h_outs, c_outs = _generic_rnn(
        g,
        "LSTM",
        input,
        hidden,
        weight,
        has_biases,
        num_layers=1,
        dropout=0,
        train=0,
        bidirectional=False,
        batch_first=False,
    )
    h_result = symbolic_helper._squeeze_helper(g, h_outs, [0])
    c_result = symbolic_helper._squeeze_helper(g, c_outs, [0])
    return h_result, c_result
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::gru", decorate=[_apply_params("GRU"), _export("gru")])
@_onnx_symbolic(
    "aten::rnn_tanh", decorate=[_apply_params("RNN_TANH"), _export("rnn_tanh")]
)
@_onnx_symbolic(
    "aten::rnn_relu", decorate=[_apply_params("RNN_RELU"), _export("rnn_relu")]
)
def _one_hidden_rnn(kind: str):
    """Build the symbolic function for a recurrent op with one hidden state.

    Args:
        kind: Variant name forwarded to ``_generic_rnn`` (the decorators
            above supply ``"GRU"``, ``"RNN_TANH"`` and ``"RNN_RELU"``).

    Returns:
        A ``symbolic(g, *args)`` callable that picks the packed or the full
        lowering depending on whether ``args[3]`` is a tensor list.
    """

    @symbolic_helper.parse_args("v", "v", "v", "i", "i", "f", "i", "i", "i")
    @_beartype.beartype
    def _rnn_full(
        g,
        input,
        hidden,
        weight_v,
        has_biases,
        num_layers,
        dropout,
        train,
        bidirectional,
        batch_first,
    ):
        # Overload carrying `batch_first`.
        return _generic_rnn(
            g,
            kind,
            input,
            hidden,
            symbolic_helper._unpack_list(weight_v),
            has_biases,
            num_layers,
            dropout,
            train,
            bidirectional,
            batch_first,
        )

    @symbolic_helper.parse_args("v", "v", "v", "v", "i", "i", "f", "i", "i")
    def _rnn_packed(
        g,
        input,
        batch_sizes,
        hidden,
        weight_v,
        has_biases,
        num_layers,
        dropout,
        train,
        bidirectional,
    ):
        # Overload carrying `batch_sizes`.
        return _generic_rnn(
            g,
            kind,
            input,
            hidden,
            symbolic_helper._unpack_list(weight_v),
            has_biases,
            num_layers,
            dropout,
            train,
            bidirectional,
            batch_sizes=batch_sizes,
        )

    def symbolic(g, *args):
        chosen = _rnn_packed if symbolic_helper._is_tensor_list(args[3]) else _rnn_full
        return chosen(g, *args)

    return symbolic
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::_dim_arange")
@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def _dim_arange(g: jit_utils.GraphContext, like, dim):
    """Emit a range whose end is the size of ``like`` along ``dim``.

    The size is obtained at runtime via ``Shape`` + ``Gather``. With the
    Caffe2 ATen fallback enabled a ``_caffe2::Range`` node is emitted;
    otherwise the work is delegated to ``arange``.
    """
    shape_of_like = g.op("Shape", like)
    dim_index = g.op("Constant", value_t=torch.tensor(dim))
    stop = g.op("Gather", shape_of_like, dim_index, axis_i=0)
    if not symbolic_helper.is_caffe2_aten_fallback():
        # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
        return arange(g, stop, 4, None, None, None)
    return g.op("_caffe2::Range", stop)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::detach")
@_beartype.beartype
def detach(g: jit_utils.GraphContext, input):
    """Return ``input`` unchanged; no node is emitted for ``aten::detach``."""
    # ONNX is inference only, so aten::detach nodes are simply erased.
    passthrough = input
    return passthrough
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::contiguous")
@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def contiguous(g: jit_utils.GraphContext, input, memory_format):
    """Return ``input`` unchanged for supported memory formats.

    Raises:
        errors.SymbolicValueError: If ``memory_format`` is greater than 2.
    """
    # Allowed values are any, preserve and contiguous_format.
    if memory_format <= 2:
        return input
    raise errors.SymbolicValueError(
        "onnx memory_format support is not implemented", input
    )
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::_pack_padded_sequence")
@symbolic_helper.parse_args("v", "v", "i")
@_beartype.beartype
def _pack_padded_sequence(g: jit_utils.GraphContext, input, lengths, batch_first):
    """Emit a placeholder ``prim::PackPadded`` node with two outputs.

    ONNX has no PackPadded operator; a later optimization pass is relied on
    to remove this node, and it is an error if it cannot be optimized out.

    Raises:
        errors.SymbolicValueError: If ``lengths`` is not a tensor.
    """
    if batch_first:
        input = g.op("Transpose", input, perm_i=[1, 0, 2])

    lengths_is_tensor = lengths.type().isSubtypeOf(torch._C.TensorType.get())
    if not lengths_is_tensor:
        raise errors.SymbolicValueError(
            "'lengths' must be a Tensor for ONNX export", input
        )

    # `lengths` is known to be a tensor here, so its scalar type can be read.
    # The cast is really only necessary because those operators expand to
    # something that only works with int32 types in Caffe2...
    lengths_dtype = _type_utils.JitScalarType.from_value(
        lengths, _type_utils.JitScalarType.UNDEFINED
    )
    if lengths_dtype != _type_utils.JitScalarType.INT:
        lengths = g.op("Cast", lengths, to_i=_C_onnx.TensorProtoDataType.INT32)
    return g.op("prim::PackPadded", input, lengths, outputs=2)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::_pad_packed_sequence")
@symbolic_helper.parse_args("v", "v", "i", "t", "v")
@_beartype.beartype
def _pad_packed_sequence(
    g: jit_utils.GraphContext,
    data,
    batch_sizes,
    batch_first,
    padding_value,
    total_length,
):
    """Emit a placeholder ``prim::PadPacked`` node and return its two outputs.

    ``padding_value`` and ``total_length`` are not used. ``total_length`` is
    ignored because it is only useful when training with a data_parallel
    model, so it should not be relevant for ONNX.

    Returns:
        ``(data, lengths)``; ``data`` is transposed with ``perm=[1, 0, 2]``
        when ``batch_first`` is truthy.
    """
    padded, lengths = g.op("prim::PadPacked", data, batch_sizes, outputs=2)
    if not batch_first:
        return padded, lengths
    return g.op("Transpose", padded, perm_i=[1, 0, 2]), lengths
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::randint")
@_beartype.beartype
def randint(g: jit_utils.GraphContext, low, high, shapes, dtype, *options):
    """Lower ``aten::randint`` to a uniform sample that is cast to integer.

    ``low`` and ``high`` must be export-time constants. The sample is drawn
    with ``RandomUniform`` (constant shape) or ``RandomUniformLike`` over a
    ``ConstantOfShape`` tensor (dynamic shape), cast to INT64, and then cast
    to the requested dtype when that differs. ``options`` is ignored.
    """
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    low_i = symbolic_helper._get_const(low, "i", "low")
    high_i = symbolic_helper._get_const(high, "i", "high")
    scalar_type = (
        _type_utils.JitScalarType.INT64
        if dtype is None
        else _type_utils.JitScalarType(dtype)
    )
    if low_i is None:
        raise symbolic_helper._onnx_unsupported("randint", low)
    if high_i is None:
        raise symbolic_helper._onnx_unsupported("randint", high)

    bounds = {"low_f": low_i, "high_f": high_i}
    shape = symbolic_helper._maybe_get_const(shapes, "is")
    if symbolic_helper._is_value(shape):
        # Shape only known at runtime: sample "like" a zero tensor of that shape.
        shape_const = g.op(
            "ConstantOfShape",
            shapes,
            value_t=torch.tensor([0], dtype=torch.float),
        )
        uniform = g.op("RandomUniformLike", shape_const, **bounds)
    else:
        uniform = g.op("RandomUniform", shape_i=shape, **bounds)

    # cast to integer type
    int_dtype = _type_utils.JitScalarType.INT64
    result = g.op("Cast", uniform, to_i=int_dtype.onnx_type())
    if int_dtype != scalar_type:
        result = g.op("Cast", result, to_i=scalar_type.onnx_type())
    return result
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::randint_like")
@_beartype.beartype
def randint_like(g: jit_utils.GraphContext, self, low, high, dtype, *options):
    """Lower ``aten::randint_like`` to ``RandomUniformLike`` plus integer casts.

    Args:
        g: Graph context used to emit nodes.
        self: Value whose shape the sample follows.
        low: Lower bound; must be an export-time constant.
        high: Upper bound; must be an export-time constant.
        dtype: Requested dtype constant, or a None constant for INT64.
        *options: Remaining arguments; ignored.

    Returns:
        The sampled value, cast to INT64 and then to the requested dtype when
        that differs.

    Raises:
        Whatever ``symbolic_helper._onnx_unsupported`` raises/returns when
        ``low`` or ``high`` is not constant; the error now names
        ``randint_like`` instead of ``randint``.
    """
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    low_i = symbolic_helper._get_const(low, "i", "low")
    high_i = symbolic_helper._get_const(high, "i", "high")
    if dtype is None:
        scalar_type = _type_utils.JitScalarType.INT64
    else:
        scalar_type = _type_utils.JitScalarType(dtype)
    if low_i is None:
        raise symbolic_helper._onnx_unsupported("randint_like", low)
    if high_i is None:
        raise symbolic_helper._onnx_unsupported("randint_like", high)

    uniform = g.op(
        "RandomUniformLike",
        self,
        low_f=low_i,
        high_f=high_i,
    )

    # cast to integer type
    int_dtype = _type_utils.JitScalarType.INT64
    result = g.op("Cast", uniform, to_i=int_dtype.onnx_type())
    if int_dtype != scalar_type:
        result = g.op("Cast", result, to_i=scalar_type.onnx_type())
    return result
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::randn")
@_beartype.beartype
def randn(g: jit_utils.GraphContext, shapes, dtype, *options):
    """Lower ``aten::randn`` to ``RandomNormal`` / ``RandomNormalLike``.

    ``dtype`` defaults to FLOAT when its constant is None. A constant shape
    is passed as the ``shape`` attribute; a dynamic shape goes through a
    ``ConstantOfShape`` tensor and ``RandomNormalLike``. ``options`` is
    ignored.
    """
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    scalar_type = (
        _type_utils.JitScalarType.FLOAT
        if dtype is None
        else _type_utils.JitScalarType(dtype)
    )
    shape = symbolic_helper._maybe_get_const(shapes, "is")
    if not symbolic_helper._is_value(shape):
        return g.op(
            "RandomNormal",
            shape_i=shape,
            dtype_i=scalar_type.onnx_type(),
        )
    template = g.op(
        "ConstantOfShape",
        shapes,
        value_t=torch.tensor([0], dtype=torch.float),
    )
    return g.op(
        "RandomNormalLike",
        template,
        dtype_i=scalar_type.onnx_type(),
    )
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::rand")
@_beartype.beartype
def rand(g: jit_utils.GraphContext, shapes, dtype, *options):
    """Lower ``aten::rand`` to ``RandomUniform`` / ``RandomUniformLike``.

    ``dtype`` defaults to FLOAT when its constant is None. A constant shape
    is passed as the ``shape`` attribute; a dynamic shape goes through a
    ``ConstantOfShape`` tensor and ``RandomUniformLike``. ``options`` is
    ignored.
    """
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    scalar_type = (
        _type_utils.JitScalarType.FLOAT
        if dtype is None
        else _type_utils.JitScalarType(dtype)
    )
    shape = symbolic_helper._maybe_get_const(shapes, "is")
    if not symbolic_helper._is_value(shape):
        return g.op(
            "RandomUniform",
            shape_i=shape,
            dtype_i=scalar_type.onnx_type(),
        )
    template = g.op(
        "ConstantOfShape",
        shapes,
        value_t=torch.tensor([0], dtype=torch.float),
    )
    return g.op(
        "RandomUniformLike",
        template,
        dtype_i=scalar_type.onnx_type(),
    )
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::randn_like")
@_beartype.beartype
def randn_like(
    g: jit_utils.GraphContext,
    self,
    dtype,
    layout=None,
    device=None,
    pin_memory=False,
    memory_format=None,
):
    """Lower ``aten::randn_like`` to ``RandomNormalLike``.

    When the ``dtype`` constant is None the scalar type of ``self`` is used,
    with FLOAT as the fallback. ``layout``, ``device``, ``pin_memory`` and
    ``memory_format`` are not used.
    """
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    if dtype is not None:
        scalar_type = _type_utils.JitScalarType(dtype)
    else:
        scalar_type = _type_utils.JitScalarType.from_value(
            self, _type_utils.JitScalarType.FLOAT
        )
    onnx_dtype = scalar_type.onnx_type()
    return g.op("RandomNormalLike", self, dtype_i=onnx_dtype)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::rand_like")
@_beartype.beartype
def rand_like(
    g: jit_utils.GraphContext,
    self,
    dtype,
    layout=None,
    device=None,
    pin_memory=False,
    memory_format=None,
):
    """Lower ``aten::rand_like`` to ``RandomUniformLike``.

    When the ``dtype`` constant is None the scalar type of ``self`` is used,
    with FLOAT as the fallback. ``layout``, ``device``, ``pin_memory`` and
    ``memory_format`` are not used.
    """
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    if dtype is None:
        dtype = _type_utils.JitScalarType.from_value(
            self, _type_utils.JitScalarType.FLOAT
        )
    onnx_dtype = _type_utils.JitScalarType(dtype).onnx_type()
    return g.op("RandomUniformLike", self, dtype_i=onnx_dtype)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::rrelu")
@symbolic_helper.parse_args("v", "f", "f", "i", "none")
@_beartype.beartype
def rrelu(g: jit_utils.GraphContext, input, lower, upper, training, generator):
    """Lower ``aten::rrelu``.

    In training mode a slope is sampled per element with
    ``RandomUniformLike`` in ``[lower, upper]`` and applied via ``PRelu``.
    Otherwise a ``LeakyRelu`` with the midpoint slope ``(upper + lower) / 2``
    is emitted. ``generator`` is not used.
    """
    if training:
        sampled_slope = g.op("RandomUniformLike", input, high_f=upper, low_f=lower)
        return g.op("PRelu", input, sampled_slope)
    mean_slope = (upper + lower) / 2.0
    return g.op("LeakyRelu", input, alpha_f=mean_slope)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::bernoulli")
@_beartype.beartype
def bernoulli(g: jit_utils.GraphContext, input, p=None, generator=None, out=None):
    """Lower ``aten::bernoulli`` as ``Cast(Less(uniform_sample, prob))``.

    The probability is ``p`` when it is given and not a None value, otherwise
    ``input`` itself. The result is cast back to the dtype of ``input``.
    """

    def _is_given(value):
        # True for an argument that was passed and is not a None value.
        return value is not None and not symbolic_helper._is_none(value)

    # Note: the results of these two calls are not returned, so execution
    # continues below if the helper does not raise.
    if _is_given(out):
        symbolic_helper._unimplemented(
            "Bernoulli", "out parameter is not supported for bernoulli", input
        )
    if _is_given(generator):
        symbolic_helper._unimplemented(
            "Bernoulli", "generator is not supported for bernoulli", input
        )

    dtype = _type_utils.JitScalarType.from_value(
        input, _type_utils.JitScalarType.UNDEFINED
    )
    if dtype == _type_utils.JitScalarType.UNDEFINED:
        return symbolic_helper._unimplemented(
            "Bernoulli", "input dtype not accessible", input
        )

    uniform = g.op(
        "RandomUniformLike",
        input,
        high_f=1.0,
        low_f=0.0,
        dtype_i=dtype.onnx_type(),
    )
    prob = p if _is_given(p) else input
    drawn = g.op("Less", uniform, prob)
    return g.op("Cast", drawn, to_i=dtype.onnx_type())
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::log_sigmoid")
@symbolic_helper.parse_args("v")
@_beartype.beartype
def log_sigmoid(g: jit_utils.GraphContext, input):
    """Lower ``aten::log_sigmoid`` as ``Log(Sigmoid(input))``."""
    return g.op("Log", g.op("Sigmoid", input))
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::erf")
@symbolic_helper.parse_args("v")
@_beartype.beartype
def erf(g: jit_utils.GraphContext, input):
    """Lower ``aten::erf`` to the ONNX ``Erf`` operator."""
    erf_node = g.op("Erf", input)
    return erf_node
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::flatten")
@symbolic_helper.quantized_args(True, False, False)
@symbolic_helper.parse_args("v", "i", "i")
@_beartype.beartype
def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim):
    """Lower ``aten::flatten``.

    The input rank must be known at export time. Rank 0 is reshaped to
    ``[1]``, rank 1 becomes ``Identity``. The ONNX ``Flatten`` operator is
    used for the two cases whose output is 2-d; everything else goes through
    ``symbolic_helper._flatten_helper``.
    """
    rank = symbolic_helper._get_tensor_rank(input)
    if rank is None:
        return symbolic_helper._unimplemented(
            "dim",
            "ONNX and PyTorch use different strategies to split the input. "
            "Input rank must be known at export time.",
            input,
        )

    if rank == 0:
        return symbolic_helper._reshape_helper(g, input, [1])
    if rank == 1:
        return g.op("Identity", input)

    # TODO: remove this as onnx opset 11 spec allows negative axes
    if end_dim < 0:
        end_dim += rank

    # ONNX Flatten always yields a 2-d result, so it only fits these spans.
    span = (start_dim, end_dim)
    if span == (1, rank - 1):
        return g.op("Flatten", input, axis_i=start_dim)
    if span == (0, rank - 2):
        return g.op("Flatten", input, axis_i=end_dim + 1)

    return symbolic_helper._flatten_helper(g, input, start_dim, end_dim, rank)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::nonzero")
@symbolic_helper.parse_args("v")
@_beartype.beartype
def nonzero(g: jit_utils.GraphContext, input):
    """Emitted from `torch.nonzero(x, as_tuple=False)`"""
    # The ONNX NonZero result is passed through `t` before being returned.
    onnx_indices = g.op("NonZero", input)
    return t(g, onnx_indices)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::nonzero_numpy")
@_beartype.beartype
def nonzero_numpy(g: jit_utils.GraphContext, input, _outputs=None):
    """Emitted from `torch.nonzero(x, as_tuple=True)`.

    The result of ``nonzero`` is split along dim 1 with ``unbind``.
    """
    indices = nonzero(g, input)
    return unbind(g, indices, 1, _outputs=_outputs)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::isnan")
@symbolic_helper.parse_args("v")
@_beartype.beartype
def isnan(g: jit_utils.GraphContext, input):
    """Lower ``aten::isnan`` to the ONNX ``IsNaN`` operator."""
    return g.op("IsNaN", input)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::any")
@_beartype.beartype
def _any(g: jit_utils.GraphContext, *args):
    """Lower ``aten::any`` as ``ReduceSum(Cast(input, INT64)) > 0``.

    Accepts either ``(self,)`` or ``(self, dim, keepdim)``; in the latter
    form ``dim`` may be an int list or a single int.
    """
    if len(args) == 1:
        # aten::any(Tensor self)
        (input,) = args
        dim, keepdim = None, 0
    else:
        # aten::any(Tensor self, int[]? dim, bool keepdim)
        input, dim, keepdim = args
        # Can be int list or single int
        dim_tensor = symbolic_helper._parse_arg(dim, "t")
        dim = [int(d) for d in dim_tensor.view(-1)]
        keepdim = symbolic_helper._parse_arg(keepdim, "i")

    as_int64 = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT64)
    total = symbolic_helper._reducesum_helper(
        g, as_int64, axes_i=dim, keepdims_i=keepdim
    )
    zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.long))
    return gt(g, total, zero)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::all")
@_beartype.beartype
def _all(g: jit_utils.GraphContext, *args):
    """Lower ``aten::all`` as ``Not(any(Not(input)))``.

    Accepts either ``(self,)`` or ``(self, dim, keepdim)``.
    """
    negated = g.op("Not", args[0])
    if len(args) == 1:
        # aten::all(Tensor self)
        any_false = _any(g, negated)
    else:
        # aten::all(Tensor self, int[]? dim, bool keepdim)
        any_false = _any(g, negated, args[1], args[2])
    return g.op("Not", any_false)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::narrow")
@symbolic_helper.parse_args("v", "i", "i", "i")
@_beartype.beartype
def narrow(g: jit_utils.GraphContext, input, dim, start, length):
    """Lower ``aten::narrow`` to a slice ``[start, start + length)`` on ``dim``."""
    stop = start + length
    return symbolic_helper._slice_helper(
        g, input, axes=[dim], starts=[start], ends=[stop]
    )
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::argmax")
@symbolic_helper.parse_args("v", "v", "b")
@_beartype.beartype
def argmax(
    g: jit_utils.GraphContext,
    input: torch._C.Value,
    dim: torch._C.Value,
    keepdim: bool,
):
    """Lower ``aten::argmax`` via the shared ArgMin/ArgMax helper."""
    onnx_op = "ArgMax"
    return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, onnx_op)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::argmin")
@symbolic_helper.parse_args("v", "v", "b")
@_beartype.beartype
def argmin(
    g: jit_utils.GraphContext,
    input: torch._C.Value,
    dim: torch._C.Value,
    keepdim: bool,
):
    """Lower ``aten::argmin`` via the shared ArgMin/ArgMax helper."""
    onnx_op = "ArgMin"
    return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, onnx_op)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::scatter")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def scatter(g: jit_utils.GraphContext, self, dim, index, src):
    """Lower ``aten::scatter`` to the ONNX ``Scatter`` operator.

    A tensor ``src`` is used directly. A scalar ``src`` is cast to the
    scalar type of ``self`` when the two differ and is then expanded to the
    shape of ``index``.
    """
    src_type = _type_utils.JitScalarType.from_value(
        src, _type_utils.JitScalarType.UNDEFINED
    )
    src = symbolic_helper._maybe_get_scalar(src)
    if symbolic_helper._is_value(src):
        return g.op("Scatter", self, index, src, axis_i=dim)

    # Check if scalar "src" has same type as self (PyTorch allows different
    # type for scalar src (but not when src is tensor)). If not, insert Cast node.
    self_scalar_type = _type_utils.JitScalarType.from_value(self)
    if self_scalar_type != src_type:
        src = g.op("Cast", src, to_i=self_scalar_type.onnx_type())
    updates = expand_as(g, src, index)
    return g.op("Scatter", self, index, updates, axis_i=dim)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::scatter_add")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def scatter_add(g: jit_utils.GraphContext, self, dim, index, src):
    """Lower ``aten::scatter_add`` as ``self + scatter(zeros, index, src)``.

    The zero tensor is a Constant when static sizes are available for
    ``self``, otherwise it is produced with ``zeros_like``.
    """
    scalar_type = symbolic_helper._try_get_scalar_type(self)
    if scalar_type is None:
        return symbolic_helper._unimplemented(
            "scatter_add", "input dtype not accessible", self
        )

    sizes = symbolic_helper._get_tensor_sizes(self, allow_nonstatic=False)
    if not sizes:
        zero_base = zeros_like(g, self, scalar_type)
    else:
        zero_base = g.op(
            "Constant", value_t=torch.zeros(sizes, dtype=scalar_type.dtype())
        )
    scattered = symbolic_helper._scatter_helper(g, zero_base, dim, index, src)
    return add(g, self, scattered)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::log2")
@_beartype.beartype
def log2(g: jit_utils.GraphContext, self):
    """Lower ``aten::log2`` as ``log(self) / ln(2)``."""
    _ln2 = 0.693147180559945309
    natural_log = log(g, self)
    ln2_const = g.op("Constant", value_t=torch.tensor(_ln2))
    return g.op("Div", natural_log, ln2_const)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::is_floating_point")
@_beartype.beartype
def is_floating_point(g: jit_utils.GraphContext, self):
    """Emit a one-element boolean Constant telling whether ``self`` is floating point."""
    flag = 1 if symbolic_helper._is_fp(self) else 0
    return g.op("Constant", value_t=torch.BoolTensor([flag]))
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::__is_")
@_beartype.beartype
def __is_(g: jit_utils.GraphContext, self, other):
    """Lower ``aten::__is_``.

    When ``other`` is a None value the answer is known at export time and is
    emitted as a one-element boolean Constant; otherwise the comparison is
    delegated to ``eq``.
    """
    if not symbolic_helper._is_none(other):
        return eq(g, self, other)
    flag = 1 if symbolic_helper._is_none(self) else 0
    return g.op("Constant", value_t=torch.BoolTensor([flag]))
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::__isnot_")
@wrap_logical_op_with_negation
@_beartype.beartype
def __isnot_(g: jit_utils.GraphContext, self, other):
    """Lower ``aten::__isnot_`` by reusing ``__is_``.

    The body returns the ``__is_`` result; the
    ``wrap_logical_op_with_negation`` decorator is applied on top of it.
    """
    is_same = __is_(g, self, other)
    return is_same
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::one_hot")
@_beartype.beartype
def one_hot(g: jit_utils.GraphContext, self, num_classes):
    """Lower ``aten::one_hot`` to ONNX ``OneHot`` on the last axis.

    The off/on values are the INT64 constants ``[0, 1]``. ``num_classes`` is
    cast to INT64 when its scalar type is one of UINT8, INT8, INT or INT16.
    """
    values = g.op("Constant", value_t=torch.LongTensor([0, 1]))
    # onnxruntime supports limited type combinations for OneHot.
    narrow_int_types = {
        _type_utils.JitScalarType.UINT8,
        _type_utils.JitScalarType.INT8,
        _type_utils.JitScalarType.INT,
        _type_utils.JitScalarType.INT16,
    }
    depth_type = _type_utils.JitScalarType.from_value(
        num_classes, _type_utils.JitScalarType.UNDEFINED
    )
    if depth_type in narrow_int_types:
        num_classes = g.op("Cast", num_classes, to_i=_C_onnx.TensorProtoDataType.INT64)
    return g.op("OneHot", self, num_classes, values, axis_i=-1)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::gather")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def gather(g: jit_utils.GraphContext, self, dim, index, sparse_grad=False):
    """Lower ``aten::gather`` through a one-hot mask.

    ``index`` is expanded with ``OneHot`` along ``dim``, multiplied with
    ``self`` (unsqueezed at ``dim + 1``) and summed over ``dim``.
    ``sparse_grad == True`` is reported as unimplemented.
    """
    if symbolic_helper._maybe_get_const(sparse_grad, "i"):
        return symbolic_helper._unimplemented("gather", "sparse_grad == True", self)
    # NOTE: This workaround is needed since GatherElement is only supported
    #       since opset 11, and Gather in ONNX is not the same as torch.gather.
    scalar_type = _type_utils.JitScalarType.from_value(self)
    values = g.op("Constant", value_t=torch.LongTensor([0, 1]))
    dim_const = g.op("Constant", value_t=torch.LongTensor([dim]))
    depth = size(g, self, dim_const)
    one_hot_index = g.op("OneHot", index, depth, values, axis_i=dim)
    mask = g.op("Cast", one_hot_index, to_i=scalar_type.onnx_type())
    expanded_self = symbolic_helper._unsqueeze_helper(g, self, [dim + 1])
    product = g.op("Mul", expanded_self, mask)
    return symbolic_helper._reducesum_helper(g, product, axes_i=[dim], keepdims_i=0)
|
||
|
|
||
|
|
||
|
@symbolic_helper.parse_args("v", "is", "i", "i")
@_beartype.beartype
def _var_mean(g: jit_utils.GraphContext, input, dim, correction, keepdim):
    """Emit ONNX nodes computing the variance and mean of ``input``.

    Args:
        g: Graph context used to emit ONNX nodes.
        input: Value to reduce.
        dim: Axes to reduce over, or ``None`` to reduce over everything.
        correction: Amount subtracted from the element count in the variance
            denominator; ``None`` is treated as 1.
        keepdim: ``keepdims`` attribute used when ``dim`` is given.

    Returns:
        A ``(variance, mean)`` pair of graph values.
    """
    reduce_all = dim is None
    if reduce_all:
        mean = g.op("ReduceMean", input, keepdims_i=0)
        broadcast_mean = mean
        count = numel(g, input)
    else:
        mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=keepdim)
        # A keepdims=1 mean is needed so it broadcasts against `input`.
        broadcast_mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=1)
        # dim could contain one or multiple dimensions
        reduced_sizes = g.op(
            "Gather",
            g.op("Shape", input),
            g.op("Constant", value_t=torch.tensor(dim)),
            axis_i=0,
        )
        count = g.op("ReduceProd", reduced_sizes, keepdims_i=0)

    centered = g.op("Sub", input, broadcast_mean)
    squared = g.op("Mul", centered, centered)
    variance_keepdims = 0 if reduce_all else keepdim
    variance = g.op("ReduceMean", squared, axes_i=dim, keepdims_i=variance_keepdims)

    # Correct bias in calculating variance, by dividing it over (N - correction) instead on N
    if correction is None:
        correction = 1
    if correction != 0:
        count = g.op("Cast", count, to_i=_C_onnx.TensorProtoDataType.FLOAT)
        correction_value = g.op(
            "Constant", value_t=torch.tensor(correction, dtype=torch.float)
        )
        scaled = g.op("Mul", variance, count)
        variance = g.op("Div", scaled, g.op("Sub", count, correction_value))
    return variance, mean
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::std")
@_beartype.beartype
def std(g: jit_utils.GraphContext, input, *args):
    """Export ``aten::std`` as the square root of the variance from ``var_mean``."""
    variance, _unused_mean = var_mean(g, input, *args)
    return g.op("Sqrt", variance)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::var")
@_beartype.beartype
def var(g: jit_utils.GraphContext, input, *args):
    """Export ``aten::var`` by returning the variance half of ``var_mean``."""
    variance, _unused_mean = var_mean(g, input, *args)
    return variance
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::var_mean")
@_beartype.beartype
def var_mean(g: jit_utils.GraphContext, input, *args):
    """Export ``aten::var_mean`` by dispatching on the overload in use.

    Returns:
        Whatever ``_var_mean`` returns for the resolved arguments.
    """
    # var_mean (and all variance-related functions) has multiple signatures, so need to manually figure
    # out the correct arguments:
    # aten::var_mean(Tensor self, bool unbiased)
    # aten::var_mean(Tensor self, int[1] dim, bool unbiased, bool keepdim=False)
    # aten::var_mean(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False)
    if len(args) != 1:
        return _var_mean(g, input, *args)
    # Single extra argument: it is forwarded as the correction, with no dim
    # and no keepdim.
    (unbiased,) = args
    return _var_mean(g, input, None, unbiased, None)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::std_mean")
@_beartype.beartype
def std_mean(g: jit_utils.GraphContext, input, *args):
    """Export ``aten::std_mean`` as ``(Sqrt(variance), mean)`` from ``var_mean``."""
    variance, mean = var_mean(g, input, *args)
    std_dev = g.op("Sqrt", variance)
    return std_dev, mean
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::logsumexp")
@symbolic_helper.parse_args("v", "is", "i")
@_beartype.beartype
def logsumexp(g: jit_utils.GraphContext, input, dim, keepdim):
    """Export ``aten::logsumexp`` as a single ONNX ``ReduceLogSumExp`` node."""
    reduce_attrs = {"axes_i": dim, "keepdims_i": keepdim}
    return g.op("ReduceLogSumExp", input, **reduce_attrs)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::arange")
@_beartype.beartype
def arange(g: jit_utils.GraphContext, *args):
    """Export ``aten::arange``, dispatching on the number of arguments.

    The overload is identified purely by ``len(args)`` (2/5, 4/7 or 6; see the
    per-branch comments). In every supported branch the range is built from
    ``nonzero`` applied to a ``ones`` tensor sized by the element count, then
    scaled/shifted as needed and cast to the resolved dtype.

    Args:
        g: Graph context used to emit ONNX nodes.
        *args: Positional arguments of the matched ``aten::arange`` overload.

    Returns:
        The graph value holding the range, or the result of
        ``_unimplemented`` for an unsupported argument count.
    """
    # Under the caffe2 ATen fallback the op is emitted as an ATen node as-is.
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("arange", *args)

    @_beartype.beartype
    def _get_arange_dtype(dtype):
        # Resolve the dtype argument to a constant int when possible.
        dtype = symbolic_helper._maybe_get_const(dtype, "i")
        return dtype

    @_beartype.beartype
    def _float_step_convert(range_tensor):
        # A floating-point element count is rounded up (Ceil) and cast to
        # INT64; other inputs are returned unchanged.
        if symbolic_helper._is_fp(range_tensor):
            range_tensor = g.op(
                "Cast",
                g.op("Ceil", range_tensor),
                to_i=_type_utils.JitScalarType.INT64.onnx_type(),
            )
        return range_tensor

    if len(args) == 2 or len(args) == 5:
        if len(args) == 2:
            # aten::arange(Scalar end, Tensor out)
            dtype = None
        else:
            # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
            dtype = _get_arange_dtype(args[1])
        dtype, end, start, step = symbolic_helper._arange_cast_helper(
            g, end=args[0], dtype=dtype
        )
        end = symbolic_helper._unsqueeze_helper(g, end, [0])
        range_tensor = _float_step_convert(end)
        # Indices of a ones tensor of length `end`, with axis 1 squeezed away.
        arange_tensor = symbolic_helper._squeeze_helper(
            g, nonzero(g, ones(g, range_tensor, dtype, None, None)), [1]
        )
        return g.op(
            "Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type()
        )
    elif len(args) == 4 or len(args) == 7:
        if len(args) == 4:
            # aten::arange(Scalar start, Scalar end, Scalar step, Tensor out)
            dtype = None
        else:
            # aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory)
            dtype = _get_arange_dtype(args[3])
        dtype, end, start, step = symbolic_helper._arange_cast_helper(
            g, start=args[0], end=args[1], step=args[2], dtype=dtype
        )
        step = symbolic_helper._unsqueeze_helper(g, step, [0])
        end = symbolic_helper._unsqueeze_helper(g, end, [0])
        start = symbolic_helper._unsqueeze_helper(g, start, [0])
        # Element count is (end - start) / step, rounded up when floating point.
        range_tensor = _float_step_convert(g.op("Div", g.op("Sub", end, start), step))
        # NOTE: unlike the other branches, `ones` is called with dtype=None here.
        arange_tensor = symbolic_helper._squeeze_helper(
            g, nonzero(g, ones(g, range_tensor, None, None, None)), [1]
        )
        # Scale the indices by `step` and shift them by `start`.
        arange_tensor = g.op("Add", g.op("Mul", arange_tensor, step), start)
        return g.op(
            "Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type()
        )
    elif len(args) == 6:
        # aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
        dtype = _get_arange_dtype(args[2])
        dtype, end, start, step = symbolic_helper._arange_cast_helper(
            g, start=args[0], end=args[1], dtype=dtype
        )
        end = symbolic_helper._unsqueeze_helper(g, end, [0])
        start = symbolic_helper._unsqueeze_helper(g, start, [0])
        range_tensor = _float_step_convert(g.op("Sub", end, start))
        # Indices 0..(end - start) shifted by `start`; the remaining
        # arguments (args[3:]) are forwarded to `ones`.
        arange_tensor = g.op(
            "Add",
            symbolic_helper._squeeze_helper(
                g, nonzero(g, ones(g, range_tensor, dtype, *(args[3:]))), [1]
            ),
            start,
        )
        return g.op(
            "Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type()
        )

    # Any other argument count does not match a handled overload.
    return symbolic_helper._unimplemented("aten::arange", f"with {len(args)} arguments")
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::linspace")
@_beartype.beartype
def linspace(
    g: jit_utils.GraphContext, start, end, steps, dtype, layout, device, pin_memory
):
    """Export ``aten::linspace`` as ``arange(steps) * step + start``.

    The step size is ``(end - start) / (steps - 1)``. The ``dtype``,
    ``layout``, ``device`` and ``pin_memory`` arguments are not used.
    """
    positions = symbolic_helper._arange_helper(g, steps, None)
    span = sub(g, end, start)
    one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
    interval_count = sub(g, steps, one)
    increment = div(g, span, interval_count)
    return add(g, mul(g, positions, increment), start)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::lift")
@_beartype.beartype
def lift(g: jit_utils.GraphContext, self):
    """Export ``aten::lift`` as an identity: the input value is returned unchanged."""
    # at::lift() is a no-op from the perspective of tracing for onnx
    return self
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::masked_fill")
@_beartype.beartype
def masked_fill(g: jit_utils.GraphContext, self, mask, value):
    """Export ``aten::masked_fill`` as ``Where(mask, value, self)``.

    The mask is cast to BOOL first, and ``value`` is passed through
    ``_maybe_get_scalar`` / ``_if_scalar_type_as`` before being used.
    """
    bool_mask = g.op("Cast", mask, to_i=_C_onnx.TensorProtoDataType.BOOL)
    scalar_value = symbolic_helper._maybe_get_scalar(value)
    fill_value = symbolic_helper._if_scalar_type_as(scalar_value, self)
    return g.op("Where", bool_mask, fill_value, self)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::masked_fill_")
@_beartype.beartype
def masked_fill_(g: jit_utils.GraphContext, self, mask, value):
    """Export the in-place ``aten::masked_fill_`` exactly like ``masked_fill``."""
    filled = masked_fill(g, self, mask, value)
    return filled
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::index")
@_beartype.beartype
def index(g: jit_utils.GraphContext, self, index):
    """Export ``aten::index`` (tensor indexing) to ONNX.

    ``index`` is either a packed list of index values or a single value.
    UINT8/bool indices are first converted to positions via ``nonzero``.
    A single index goes through ``_select_helper``; with several indices, the
    entries that are not none are treated as advanced indices and lowered
    with Transpose, Flatten, index_select and Reshape (see the comments in
    the body).

    Args:
        g: Graph context used to emit ONNX nodes.
        self: Value being indexed.
        index: Packed list of index values, or a single index value.

    Returns:
        The graph value holding the indexed result, or the result of
        ``_unimplemented`` when the rank of ``self`` is unknown.

    Raises:
        errors.SymbolicValueError: A mask index is used with ``g.opset < 9``.
    """
    # Under the caffe2 ATen fallback the op is emitted as an ATen node as-is.
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("index", self, index, overload_name="Tensor")

    if symbolic_helper._is_packed_list(index):
        indices = symbolic_helper._unpack_list(index)
    else:
        indices = [index]

    @_beartype.beartype
    def try_mask_to_index(index):
        # Convert a UINT8/bool mask into integer positions; any other index
        # (including none) is returned unchanged.
        if not symbolic_helper._is_none(index) and (
            _type_utils.JitScalarType.from_value(
                index, _type_utils.JitScalarType.UNDEFINED
            )
            == _type_utils.JitScalarType.UINT8
            or symbolic_helper._is_bool(index)
        ):
            if g.opset < 9:
                raise errors.SymbolicValueError(
                    "Exporting masked indices are only supported after ONNX opset 9.",
                    self,
                )
            warnings.warn(
                "Exporting aten::index operator with indices of type Byte. "
                "Only 1-D indices are supported. In any other case, "
                "this will produce an incorrect ONNX graph."
            )
            index = symbolic_helper._squeeze_helper(g, nonzero(g, index), [1])
        return index

    indices = [try_mask_to_index(idx) for idx in indices]
    if len(indices) == 1:
        return symbolic_helper._select_helper(
            g, self, 0, indices[0], apply_reshape=False
        )
    else:
        # Multiple tensors as indices. Each tensor could either be
        #   1. prim::Constant()
        #           representing ":" in python indexing. E.g. tensor[:, :]
        #   2. prim::Constant[value=...] or tensor output
        #           representing advanced indexing. E.g. tensor[[0, 1], [2, 0]].
        # For more info on advanced indexing,
        # check https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing

        # Consider a general case of
        #       t: [x_1, y_1, y_2, ..., x_m, ..., y_n]
        # where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes for ":".
        # Same results can be achieved through transposing t into
        #       t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n]
        # and use gatherND. However ONNX does not have gatherND, to use 1d gather we'll need to flatten t
        # and process the tensor indices.
        #       t: [x_1 * x_2 * ... * x_m, y_1 * y_2 * ... * y_n]
        #       tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j))
        # After gather, reshape and transpose back.
        adv_idx_indices = [
            i for i, idx in enumerate(indices) if not symbolic_helper._is_none(idx)
        ]

        if len(adv_idx_indices) == 0:
            # Every entry is ":" -- nothing to select.
            return self
        elif len(adv_idx_indices) == 1:
            # A single advanced index is a plain index_select on that axis.
            return index_select(
                g, self, adv_idx_indices[0], indices[adv_idx_indices[0]]
            )
        else:
            rank = symbolic_helper._get_tensor_rank(self)
            if rank is None:
                return symbolic_helper._unimplemented(
                    "aten::index",
                    "operator of advanced indexing on tensor of unknown rank. "
                    "Try turning on shape inference during export: "
                    "torch.onnx._export(..., onnx_shape_inference=True).",
                    self,
                )
            # TODO: If indexing is supported natively in ONNX in future opsets,
            #       update the warning to recommend exporting with higher opset version.
            warnings.warn(
                "Exporting aten::index operator of advanced indexing in opset "
                f"{GLOBALS.export_onnx_opset_version}"
                " is achieved by combination of multiple ONNX operators, "
                "including Reshape, Transpose, Concat, and Gather. "
                "If indices include negative values, the exported graph will produce incorrect results."
            )
            adv_idx_count = len(adv_idx_indices)
            shape_tensor = _shape_as_tensor(g, self)
            # One 1-element tensor per axis, holding that axis' size.
            dim_tensor_list = [
                g.op(
                    "Gather",
                    shape_tensor,
                    g.op("Constant", value_t=torch.LongTensor([dim])),
                    axis_i=0,
                )
                for dim in range(rank)
            ]

            # Move the advanced-index axes to the front, then flatten them
            # into a single leading axis.
            self = g.op(
                "Transpose",
                self,
                perm_i=adv_idx_indices
                + [i for i in range(rank) if i not in adv_idx_indices],
            )
            self = g.op("Flatten", self, axis_i=adv_idx_count)

            # Note that tensor indices will be broadcasted while accumulating. Thus we get the final subarray shape as well.
            cum_adv_index = indices[adv_idx_indices[-1]]
            multiplier = dim_tensor_list[adv_idx_indices[-1]]
            # Walk the remaining advanced indices from last to first,
            # accumulating the flattened index and the running stride.
            for i in range(adv_idx_count - 2, -1, -1):
                adv_index = g.op("Mul", indices[adv_idx_indices[i]], multiplier)
                cum_adv_index = g.op("Add", cum_adv_index, adv_index)
                multiplier = g.op(
                    "Mul", multiplier, dim_tensor_list[adv_idx_indices[i]]
                )

            # perform gather
            self = index_select(g, self, 0, cum_adv_index)

            cum_adv_index_shape_tensor = _shape_as_tensor(g, cum_adv_index)
            # check if all advanced indices are consecutive.
            # Refer to https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#combining-advanced-and-basic-indexing
            # to understand how the subarray position is decided.
            if adv_idx_indices == list(
                range(adv_idx_indices[0], adv_idx_indices[-1] + 1)
            ):
                # unfold regular index axes
                folded_adv_idx_shape_list = [
                    g.op("Constant", value_t=torch.LongTensor([-1]))
                ] + [
                    dim_tensor_list[i] for i in range(rank) if i not in adv_idx_indices
                ]
                folded_adv_idx_shape = g.op(
                    "Concat", *folded_adv_idx_shape_list, axis_i=0
                )
                self = symbolic_helper._reshape_helper(g, self, folded_adv_idx_shape)

                # Transpose folded advanced indexed axis to its original location.
                adv_idx_permute = (
                    list(range(1, adv_idx_indices[0] + 1))
                    + [0]
                    + list(range(adv_idx_indices[0] + 1, rank - adv_idx_count + 1))
                )
                self = g.op("Transpose", self, perm_i=adv_idx_permute)

                # unfold advanced index axes
                final_shape_list = (
                    [dim_tensor_list[i] for i in range(adv_idx_indices[0])]
                    + [cum_adv_index_shape_tensor]
                    + [
                        dim_tensor_list[i]
                        for i in range(adv_idx_indices[0], rank)
                        if i not in adv_idx_indices
                    ]
                )
                final_shape = g.op("Concat", *final_shape_list, axis_i=0)
            else:
                # Non-consecutive advanced indices: the indexed shape comes
                # first, followed by the sizes of the remaining axes.
                final_shape = g.op(
                    "Concat",
                    cum_adv_index_shape_tensor,
                    *[
                        dim_tensor_list[i]
                        for i in range(rank)
                        if i not in adv_idx_indices
                    ],
                    axis_i=0,
                )

            return symbolic_helper._reshape_helper(g, self, final_shape)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::linalg_norm")
@symbolic_helper.parse_args("v", "v", "is", "b", "v")
@_beartype.beartype
def linalg_norm(
    g: jit_utils.GraphContext,
    self: torch._C.Value,
    ord: torch._C.Value,
    dim: Optional[Sequence[int]],
    keepdim: bool,
    dtype: torch._C.Value,
):
    """Export ``aten::linalg_norm`` by routing to the vector or matrix norm.

    Args:
        g: Graph context used to emit ONNX nodes.
        self: Input value.
        ord: Order of the norm (may be none).
        dim: Axes to reduce over, or ``None``.
        keepdim: Whether reduced axes are kept.
        dtype: Optional output dtype value, forwarded unchanged.

    Returns:
        The result of ``linalg_vector_norm`` when a numeric vector order was
        resolved, otherwise the result of ``linalg_matrix_norm``; or the
        result of ``_unimplemented`` when the input rank is required but
        unknown.
    """
    # Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.norm.html
    ord_value = None
    if dim is None:
        if symbolic_helper._is_none(ord):
            # No ord and no dim: 2-norm of the flattened input.
            self = symbolic_helper._reshape_helper(g, self, [-1])
            ord = g.op("Constant", value_t=torch.LongTensor([2]))
        self_dim = symbolic_helper._get_tensor_rank(self)
        if self_dim is None:
            return symbolic_helper._unimplemented(
                "dim", "Input rank must be known at export time.", self
            )
        if self_dim == 1:
            ord_value = symbolic_helper._parse_arg(ord, "f")
        else:
            dim = [0, 1]
    else:
        if len(dim) == 1:
            if symbolic_helper._is_none(ord):
                ord = g.op("Constant", value_t=torch.LongTensor([2]))
            ord_value = symbolic_helper._parse_arg(ord, "f")
    # `ord_value` is only set on the vector paths above. Compare against None
    # rather than relying on truthiness, so that ord == 0 is still handled by
    # the vector norm instead of falling through to the matrix norm.
    if ord_value is not None:
        return linalg_vector_norm(g, self, ord_value, dim, keepdim, dtype)
    return linalg_matrix_norm(g, self, ord, dim, keepdim, dtype)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::linalg_vector_norm")
@symbolic_helper.parse_args("v", "f", "is", "b", "v")
@_beartype.beartype
def linalg_vector_norm(
    g: jit_utils.GraphContext,
    self: torch._C.Value,
    ord: float,
    dim: Optional[Sequence[int]],
    keepdim: bool,
    dtype: torch._C.Value,
):
    """Export ``aten::linalg_vector_norm``.

    ``ord`` of +/-inf maps to ReduceMax/ReduceMin of ``Abs``, 1 and 2 map to
    ReduceL1/ReduceL2, 0 is reported as unsupported, and any other value is
    computed as ``sum(|x| ** ord) ** (1 / ord)``. When ``dtype`` is not none
    the result is cast to it.
    """
    # Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.vector_norm.html
    if symbolic_helper._is_none(dim):
        # No axes given: reduce over the flattened input.
        self = symbolic_helper._reshape_helper(g, self, [-1])
        keepdim = False

    if ord == 0:
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "linalg_vector_norm", 9, 11, "ord=0 not supported", self
        )

    lp_reduce_ops = {1: "ReduceL1", 2: "ReduceL2"}
    if ord in (math.inf, -math.inf):
        extremum_op = "ReduceMax" if ord > 0 else "ReduceMin"
        result = g.op(extremum_op, g.op("Abs", self), axes_i=dim, keepdims_i=keepdim)
    elif ord in lp_reduce_ops:
        result = _reduce_op_symbolic(lp_reduce_ops[ord])(
            g, self, dim=dim, keepdim=keepdim
        )
    else:
        exponent = g.op("Constant", value_t=torch.tensor(ord, dtype=torch.float32))
        powered = g.op("Pow", g.op("Abs", self), exponent)
        summed = symbolic_helper._reducesum_helper(
            g, powered, axes_i=dim, keepdims_i=keepdim
        )
        one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.float32))
        result = g.op("Pow", summed, g.op("Div", one, exponent))

    if symbolic_helper._is_none(dtype):
        return result
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    return g.op("Cast", result, to_i=_type_utils.JitScalarType(dtype).onnx_type())  # type: ignore[arg-type]
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::linalg_matrix_norm")
@symbolic_helper.parse_args("v", "v", "is", "b", "v")
@_beartype.beartype
def linalg_matrix_norm(
    g: jit_utils.GraphContext,
    self: torch._C.Value,
    ord: torch._C.Value,
    dim: List[int],
    keepdim: bool,
    dtype: torch._C.Value,
):
    """Export ``aten::linalg_matrix_norm``.

    Supports ``ord`` of ``"fro"`` / none (Frobenius norm), 1/-1 and
    inf/-inf. ``"nuc"`` and 2/-2 are reported as unimplemented. ``dtype`` is
    not used.

    Args:
        g: Graph context used to emit ONNX nodes.
        self: Input value.
        ord: Order of the norm.
        dim: The two axes forming the matrices. The list is not modified.
        keepdim: Whether reduced axes are kept.
        dtype: Unused.

    Returns:
        The graph value holding the norm, or the result of
        ``_unimplemented`` for unsupported cases.
    """
    # Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.matrix_norm.html
    ord_value = symbolic_helper._parse_arg(ord, "s")
    if ord_value == "fro":
        return frobenius_norm(g, self, dim, keepdim)
    elif ord_value == "nuc":
        return symbolic_helper._unimplemented("linalg.matrix_norm", "ord==nuc", self)
    else:
        ord_value = symbolic_helper._parse_arg(ord, "f")
        if ord_value is None:
            return frobenius_norm(g, self, dim, keepdim)
        if ord_value == 2 or ord_value == -2:
            # ord = 2/-2 unimplemented due to lack of operators
            # used to calculate singular values
            return symbolic_helper._unimplemented("linalg.matrix_norm", "ord==2", self)
        # Wrap the dim vector to handle negative dim values
        self_dim = symbolic_helper._get_tensor_rank(self)
        if self_dim is None:
            return symbolic_helper._unimplemented(
                "linalg.matrix_norm", "Input rank must be known at export time.", self
            )
        # Common implementation for cases with
        # ord = 1/-1 and ord = inf/-inf
        # Work on local copies so the caller's `dim` list is left untouched.
        sum_axis, extremum_axis = dim[0], dim[1]
        if sum_axis < 0:
            sum_axis += self_dim
        if extremum_axis < 0:
            extremum_axis += self_dim

        if ord_value == math.inf or ord_value == -math.inf:
            # The inf norms sum over the other axis of the pair.
            sum_axis, extremum_axis = extremum_axis, sum_axis
        if extremum_axis > sum_axis and not keepdim:
            # The summed axis is removed, shifting later axes down by one.
            extremum_axis -= 1
        abs_sum = symbolic_helper._reducesum_helper(
            g, g.op("Abs", self), axes_i=[sum_axis], keepdims_i=keepdim
        )
        # `max` / `min` resolve to this module's symbolic functions, which
        # return a (values, indices) pair; only the values are needed.
        if ord_value > 0:
            result, _ = max(
                g,
                abs_sum,
                dim_or_y=g.op("Constant", value_t=torch.LongTensor([extremum_axis])),
                keepdim=keepdim,
            )
        else:
            result, _ = min(
                g,
                abs_sum,
                dim_or_y=g.op("Constant", value_t=torch.LongTensor([extremum_axis])),
                keepdim=keepdim,
            )
        return result
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::linalg_cross")
@symbolic_helper.parse_args("v", "v", "i")
@_beartype.beartype
def linalg_cross(g: jit_utils.GraphContext, input, other, dim=-1):
    """Export ``aten::linalg_cross`` by forwarding to the ``cross`` symbolic."""
    product = cross(g, input, other, dim)
    return product
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::frobenius_norm")
@symbolic_helper.parse_args("v", "is", "b")
@_beartype.beartype
def frobenius_norm(g: jit_utils.GraphContext, self, dim=None, keepdim=False):
    """Export ``aten::frobenius_norm`` as ``Sqrt(ReduceSum(self * self))``."""
    squared = g.op("Mul", self, self)
    sum_of_squares = symbolic_helper._reducesum_helper(
        g, squared, axes_i=dim, keepdims_i=keepdim
    )
    return g.op("Sqrt", sum_of_squares)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::multinomial")
@symbolic_helper.parse_args("v", "i", "b", "v")
@_beartype.beartype
def multinomial(
    g: jit_utils.GraphContext, input, num_samples, replacement=False, generator=None
):
    """Export ``aten::multinomial`` as ONNX ``Multinomial`` on ``log(input)``.

    Args:
        g: Graph context used to emit ONNX nodes.
        input: Value holding the probabilities; its log is fed to the op.
        num_samples: Number of samples, emitted as ``sample_size``.
        replacement: Sampling without replacement is only supported for a
            single sample.
        generator: Must be absent/none.

    Returns:
        The output of the ``Multinomial`` node (INT64), or the result of
        ``_unimplemented`` for an unsupported configuration.
    """
    # The result of `_unimplemented` is returned (as other symbolics in this
    # file do) so that, should it return instead of raising, an unsupported
    # configuration does not fall through and emit a Multinomial node anyway.
    if generator is not None and not symbolic_helper._is_none(generator):
        return symbolic_helper._unimplemented(
            "Multinomial", "generator is not supported for multinomial", input
        )
    if not replacement and num_samples > 1:
        return symbolic_helper._unimplemented(
            "Multinomial",
            "replacement=False when num_samples > 1 is not supported for multinomial",
            input,
        )

    log_input = log(g, input)
    return g.op(
        "Multinomial",
        log_input,
        dtype_i=_C_onnx.TensorProtoDataType.INT64,
        sample_size_i=num_samples,
    )
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::baddbmm")
@_beartype.beartype
def baddbmm(g: jit_utils.GraphContext, self, batch1, batch2, beta, alpha):
    """Export ``aten::baddbmm`` as ``alpha * (batch1 @ batch2) + beta * self``.

    ``alpha`` and ``beta`` are cast to the scalar type of ``self`` before
    being multiplied in.
    """
    scalar_type = _type_utils.JitScalarType.from_value(self)
    product = matmul(g, batch1, batch2)
    alpha_cast = g.op("Cast", alpha, to_i=scalar_type.onnx_type())
    scaled_product = mul(g, product, alpha_cast)
    beta_cast = g.op("Cast", beta, to_i=scalar_type.onnx_type())
    scaled_self = mul(g, self, beta_cast)
    return add(g, scaled_product, scaled_self)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::meshgrid")
@symbolic_helper.parse_args("v", "s")
@_beartype.beartype
def meshgrid(g: jit_utils.GraphContext, tensor_list, indexing: Optional[str] = None):
    """Export ``aten::meshgrid`` using Reshape + Expand per input tensor.

    Args:
        g: Graph context used to emit ONNX nodes.
        tensor_list: Packed list of input values.
        indexing: ``"ij"`` (default when ``None``) or ``"xy"``. With ``"xy"``
            the first two inputs, and the first two outputs, are swapped;
            with fewer than two tensors there is nothing to swap.

    Returns:
        A ``prim::ListConstruct`` of the expanded grids.

    Raises:
        errors.SymbolicValueError: ``indexing`` is neither ``"ij"`` nor ``"xy"``.
    """
    if indexing is None:
        indexing = "ij"
    elif indexing not in {"ij", "xy"}:
        raise errors.SymbolicValueError(
            f"Unsupported indexing: {indexing}", tensor_list
        )
    unpacked_tensor_list = symbolic_helper._unpack_list(tensor_list)
    if indexing == "xy":
        # Slice assignment is a no-op swap for lists shorter than two.
        unpacked_tensor_list[:2] = unpacked_tensor_list[1::-1]
    # Flatten every input to 1-D.
    tensors = [
        symbolic_helper._reshape_helper(
            g, t, g.op("Constant", value_t=torch.LongTensor([-1]))
        )
        for t in unpacked_tensor_list
    ]
    tensors_shape = [g.op("Shape", t) for t in tensors]
    out_shape = g.op("Concat", *tensors_shape, axis_i=0)
    out = []
    for i, t in enumerate(tensors):
        # Shape that is 1 everywhere except on this tensor's own axis.
        shape_i = [g.op("Constant", value_t=torch.ones(1, dtype=torch.int64))] * len(
            tensors
        )
        shape_i[i] = tensors_shape[i]
        t_reshaped = _reshape_from_tensor(g, t, g.op("Concat", *shape_i, axis_i=0))
        out.append(g.op("Expand", t_reshaped, out_shape))
    # Swap the outputs back; guarded because a single-tensor (or empty) list
    # has no second element and indexing out[1] would raise IndexError.
    if indexing == "xy" and len(out) >= 2:
        out[0], out[1] = out[1], out[0]
    return g.op("prim::ListConstruct", *out)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::remainder")
@_beartype.beartype
def remainder(g: jit_utils.GraphContext, input, other):
    """Export ``aten::remainder`` as ``input - floor_divide(input, other) * other``."""
    quotient = _floor_divide(g, input, other)
    return g.op("Sub", input, g.op("Mul", quotient, other))
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::gelu")
@symbolic_helper.parse_args("v", "s")
@_beartype.beartype
def gelu(g: jit_utils.GraphContext, self: torch._C.Value, approximate: str = "none"):
    """Export ``aten::gelu``.

    With ``approximate == "tanh"`` the tanh approximation
    ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`` is emitted;
    any other value yields the erf form ``x * (erf(x / sqrt(2)) + 1) * 0.5``.
    """
    if approximate != "tanh":
        sqrt_two = 1.4142135623730951
        scaled_input = g.op("Div", self, torch.tensor(sqrt_two, dtype=torch.double))
        erf_value = g.op("Erf", scaled_input)
        one_const = g.op("Constant", value_t=torch.tensor(1, dtype=torch.double))
        erf_plus_one = add(g, erf_value, one_const)
        gated = mul(g, self, erf_plus_one)
        half_const = g.op("Constant", value_t=torch.tensor(0.5, dtype=torch.double))
        return mul(g, gated, half_const)

    beta = torch.tensor(math.sqrt(2 / math.pi), dtype=torch.double)
    kappa = torch.tensor(0.044715, dtype=torch.double)
    one = torch.tensor(1.0, dtype=torch.double)
    half = torch.tensor(0.5, dtype=torch.double)

    squared = mul(g, self, self)
    cubed = mul(g, self, squared)
    scaled_cube = mul(g, kappa, cubed)
    polynomial = add(g, self, scaled_cube)
    inner = mul(g, beta, polynomial)
    tanh_plus_one = add(g, one, g.op("Tanh", inner))
    gated = mul(g, self, tanh_plus_one)
    return mul(g, half, gated)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::group_norm")
@symbolic_helper.quantized_args(True, False, False, False)
@symbolic_helper.parse_args("v", "i", "v", "v", "f", "i")
@_beartype.beartype
def group_norm(
    g: jit_utils.GraphContext, input, num_groups, weight, bias, eps, cudnn_enabled
):
    """Export ``aten::group_norm`` via ``InstanceNormalization``.

    The input is reshaped to ``[N, num_groups, -1]``, normalised with unit
    weight and zero bias, reshaped back, and then the real ``weight`` and
    ``bias`` (defaulting to 1 and 0) are applied.

    Returns:
        The normalised value, or the result of ``_unimplemented`` when the
        input rank is unknown.
    """
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at(
            "group_norm",
            input,
            weight,
            bias,
            num_groups_i=num_groups,
            eps_f=eps,
            cudnn_enabled_i=cudnn_enabled,
        )

    channel_size = symbolic_helper._get_tensor_dim_size(input, 1)
    if channel_size is not None:
        assert channel_size % num_groups == 0
    input_rank = symbolic_helper._get_tensor_rank(input)
    if input_rank is None:
        return symbolic_helper._unimplemented("group_norm", "unknown input rank", input)

    # 0 in the shape list keeps dimension value unchanged.
    grouped_shape = g.op("Constant", value_t=torch.LongTensor([0, num_groups, -1]))
    input_reshaped = symbolic_helper._reshape_helper(g, input, grouped_shape)

    # C is always divisible by num_groups
    # Due to shape difference. we need to apply weight and bias after
    # instance norm computation and reshape
    input_dtype = _type_utils.JitScalarType.from_value(input).dtype()
    unit_weight = g.op(
        "Constant", value_t=torch.tensor([1.0] * num_groups, dtype=input_dtype)
    )
    zero_bias = g.op(
        "Constant", value_t=torch.tensor([0.0] * num_groups, dtype=input_dtype)
    )

    norm_reshaped = g.op(
        "InstanceNormalization", input_reshaped, unit_weight, zero_bias, epsilon_f=eps
    )
    norm = symbolic_helper._reshape_helper(g, norm_reshaped, g.op("Shape", input))

    if weight is None or weight.node().mustBeNone():
        weight = g.op("Constant", value_t=torch.tensor([1.0], dtype=input_dtype))
    if bias is None or bias.node().mustBeNone():
        bias = g.op("Constant", value_t=torch.tensor([0.0], dtype=input_dtype))

    # Norm has shape [N, C, *] so we reshape weight and bias to [C, *]
    axes = list(range(1, input_rank - 1))
    scaled = mul(g, norm, symbolic_helper._unsqueeze_helper(g, weight, axes))
    return add(g, scaled, symbolic_helper._unsqueeze_helper(g, bias, axes))
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::_weight_norm")
@symbolic_helper.parse_args("v", "v", "i")
@_beartype.beartype
def _weight_norm(g: jit_utils.GraphContext, weight_v, weight_g, dim):
    """Export ``aten::_weight_norm`` as ``weight_g * (weight_v / ||weight_v||)``.

    The L2 norm is taken over every axis except ``dim``; ``dim`` of ``None``
    or -1 reduces over all axes.

    Raises:
        errors.SymbolicValueError: The rank of ``weight_v`` is unknown and the
            caffe2 ATen fallback is not active.
    """
    rank = symbolic_helper._get_tensor_rank(weight_v)
    if rank is None:
        if symbolic_helper.is_caffe2_aten_fallback():
            return g.at("_weight_norm", weight_v, weight_g, dim_i=dim)
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of _weight_norm for tensor of unknown rank.",
            weight_v,
        )

    # W = g * ((v) / ||v||)
    # Compute norm_except_dim for l2 norm. dim = None means over all dims
    # torch's weight_norm module sets dim = -1 if it's None.
    # This conflicts the logic for negative axes to access dims backwards
    # TODO: Might need a fix in torch group_norm module
    reduce_axes = list(range(rank))
    if dim is not None:
        if dim < -1:
            dim += rank
        if dim != -1:
            reduce_axes.remove(dim)
    v_norm = norm(g, weight_v, 2, reduce_axes, 1)
    normalized = g.op("Div", weight_v, v_norm)
    return g.op("Mul", normalized, weight_g)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::dim")
@_beartype.beartype
def dim(g: jit_utils.GraphContext, self):
    """Implement the dim functionality available for a pytorch tensor in ONNX"""
    # ONNX does not support dim directly in this opset so we can use 2 ops to get the info
    return g.op("Size", g.op("Shape", self))
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::__contains_")
@_beartype.beartype
def __contains_(g: jit_utils.GraphContext, self, element):
    """Export ``aten::__contains_`` for a constant list and constant element.

    The membership test is evaluated at export time and emitted as a
    ``Constant`` holding the boolean result.

    Raises:
        errors.SymbolicValueError: The list or the element is not constant.
    """
    members = symbolic_helper._unpack_list(self)
    everything_constant = all(
        symbolic_helper._is_constant(member) for member in members
    ) and symbolic_helper._is_constant(element)
    if not everything_constant:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of __contains__ for non-constant list or element.",
            self,
        )

    needle = symbolic_helper._node_get(element.node(), "value")
    found = needle in (
        symbolic_helper._node_get(member.node(), "value") for member in members
    )
    return g.op("Constant", value_t=torch.tensor(found))
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::__getitem_")
@_beartype.beartype
def __getitem_(g: jit_utils.GraphContext, self, i):
    """Export ``aten::__getitem_`` as ``select`` along axis 0 at index ``i``."""
    first_axis = g.op("Constant", value_t=torch.tensor([0]))
    return select(g, self, first_axis, i)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::item")
@_beartype.beartype
def item(g: jit_utils.GraphContext, self):
    """Export ``aten::item`` as an identity: the input value is returned unchanged."""
    return self
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::take")
@_beartype.beartype
def take(g: jit_utils.GraphContext, self, index):
    """Export ``aten::take``: index the flattened input, then reshape like ``index``."""
    flat_shape = g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
    flattened = symbolic_helper._reshape_helper(g, self, flat_shape)
    gathered = index_select(g, flattened, 0, index)
    return reshape_as(g, gathered, index)
|
||
|
|
||
|
|
||
|
@_beartype.beartype
def _kl_div_log_target_impl(g: jit_utils.GraphContext, input, target):
    """Pointwise KL term for a log-space target: ``exp(target) * (target - input)``."""
    log_ratio = sub(g, target, input)
    target_prob = exp(g, target)
    return mul(g, target_prob, log_ratio)
|
||
|
|
||
|
|
||
|
@_beartype.beartype
def _kl_div_non_log_target_impl(g: jit_utils.GraphContext, input, target):
    """Pointwise KL term for a probability-space target.

    Computes ``target * (log(target) - input)`` and replaces it with zero
    wherever ``target`` is not greater than 0.
    """
    log_target = log(g, target)
    log_ratio = sub(g, log_target, input)
    pointwise = mul(g, target, log_ratio)
    zero_fill = zeros_like(g, pointwise)
    zero = g.op("Constant", value_t=torch.tensor(0))
    positive_mask = gt(g, target, zero)
    return where(g, positive_mask, pointwise, zero_fill)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::kl_div")
@symbolic_helper.parse_args("v", "v", "i", "b")
@_beartype.beartype
def kl_div(g: jit_utils.GraphContext, input, target, reduction, log_target):
    """Export ``aten::kl_div``.

    ``reduction`` is an integer code: 0 returns the pointwise result, 1 its
    ``ReduceMean`` and 2 its sum; any other value is reported as unsupported.
    """
    pointwise_impl = (
        _kl_div_log_target_impl if log_target else _kl_div_non_log_target_impl
    )
    output = pointwise_impl(g, input, target)

    if reduction == 0:
        return output
    if reduction == 1:
        return g.op("ReduceMean", output, keepdims_i=0)
    if reduction == 2:
        return symbolic_helper._reducesum_helper(g, output, keepdims_i=0)
    return symbolic_helper._onnx_unsupported(
        "kl_div with reduction other than none, mean, or sum.", input
    )
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::mse_loss")
@symbolic_helper.parse_args("v", "v", "i")
@_beartype.beartype
def mse_loss(g: jit_utils.GraphContext, input, target, reduction):
    """Export ``aten::mse_loss`` as the squared difference of input and target.

    ``reduction`` is an integer code: 0 returns the pointwise result, 1 its
    ``ReduceMean`` and 2 its sum; any other value is reported as unsupported.
    """
    # The difference is emitted twice, once per operand of the product.
    left_diff = sub(g, input, target)
    right_diff = sub(g, input, target)
    squared_error = mul(g, left_diff, right_diff)

    if reduction == 0:
        return squared_error
    if reduction == 1:
        return g.op("ReduceMean", squared_error, keepdims_i=0)
    if reduction == 2:
        return symbolic_helper._reducesum_helper(g, squared_error, keepdims_i=0)
    return symbolic_helper._onnx_unsupported(
        "mse_loss with reduction other than none, mean, or sum.", input
    )
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::as_strided")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "v", "is", "i")
@_beartype.beartype
def as_strided(g: jit_utils.GraphContext, self, sizes, strides, offset=None):
    """Symbolic for ``aten::as_strided``.

    The input is flattened to 1-D and the result is produced with a single
    Gather whose indices are ``offset + sum_i(arange(sizes[i]) * strides[i])``,
    with each ``arange`` reshaped so that it varies only along axis ``i``.

    Args:
        g: Graph context used to emit ops.
        self: Value to view.
        sizes: Output sizes; either a constant list of ints or a graph value.
        strides: List of int strides, one per output dimension.
        offset: Optional int added to every gathered index.

    Returns:
        The Gather output value.
    """
    sizes = symbolic_helper._maybe_get_const(sizes, "is")
    rank = len(strides)
    self_1d = symbolic_helper._reshape_helper(
        g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
    )
    if not symbolic_helper._is_value(sizes):
        # Sizes are known at export time: compute the index tensor eagerly and
        # embed it as a constant.
        const_ind = torch.tensor([0], dtype=torch.long)
        for i, (size, stride) in enumerate(zip(sizes, strides)):
            r_size = [1] * rank
            r_size[i] = -1
            const_ind = const_ind + torch.arange(size).view(r_size) * stride
        if offset:
            const_ind = const_ind + offset
        return g.op("Gather", self_1d, g.op("Constant", value_t=const_ind))

    # Sizes are only known at runtime: build the same index arithmetic as
    # graph ops.
    ind = None
    for i, stride in enumerate(strides):
        r_size = [1] * rank
        r_size[i] = -1
        size = select(
            g,
            sizes,
            g.op("Constant", value_t=torch.tensor([0])),
            g.op("Constant", value_t=torch.tensor(i)),
        )
        tmp_ind = symbolic_helper._reshape_helper(
            g,
            # NOTE(review): 4 is presumably the scalar-type code for int64 - confirm.
            arange(g, size, 4, None, None, None),
            g.op("Constant", value_t=torch.tensor(r_size)),
        )
        tmp_ind = g.op(
            "Mul", tmp_ind, g.op("Constant", value_t=torch.tensor([stride]))
        )
        if ind is None:
            ind = tmp_ind
        else:
            ind = g.op("Add", ind, tmp_ind)
    if offset:
        # The tensor must be passed as the ``value_t`` attribute; passing it
        # positionally would make it an input of the Constant node.
        ind = g.op("Add", ind, g.op("Constant", value_t=torch.tensor([offset])))
    return g.op("Gather", self_1d, ind)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::__derive_index")
@_beartype.beartype
def __derive_index(g: jit_utils.GraphContext, index, start, step):
    """Symbolic for ``aten::__derive_index``: emits ``start + index * step``."""
    scaled_index = g.op("Mul", index, step)
    return g.op("Add", start, scaled_index)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::__range_length")
# Source code for aten op can be found here: pytorch/torch/csrc/jit/runtime/register_prim_ops.cpp
# if (step > 0 && lo < hi) {
#   push(stack, 1 + (hi - 1 - lo) / step);
# } else if (step < 0 && lo > hi) {
#   push(stack, 1 + (lo - 1 - hi) / (0 - step));
# } else {
#   push(stack, 0);
# }
@_beartype.beartype
def __range_length(g: jit_utils.GraphContext, lo, hi, step):
    """Symbolic for ``aten::__range_length``.

    Emits ``Cast(Ceil((hi - lo) / step), INT64)`` using true division.
    """
    span = g.op("Sub", hi, lo)
    quotient = true_divide(g, span, step)
    rounded_up = g.op("Ceil", quotient)
    return g.op("Cast", rounded_up, to_i=_C_onnx.TensorProtoDataType.INT64)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::linear")
@_beartype.beartype
def linear(g: jit_utils.GraphContext, input, weight, bias):
    """Symbolic for ``aten::linear``.

    Transposes ``weight`` and then either calls ``addmm`` (rank-2 input with a
    bias) or ``matmul`` followed by an optional ``add`` of the bias.
    """
    rank = symbolic_helper._get_tensor_rank(input)
    weight_t = t(g, weight)
    has_bias = not bias.node().mustBeNone()

    if has_bias and rank == 2:
        alpha = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
        beta = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
        return addmm(g, bias, input, weight_t, alpha, beta)

    output = matmul(g, input, weight_t)
    if has_bias:
        output = add(g, bias, output)
    return output
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::hann_window")
@symbolic_helper.parse_args("v", "b", "i", "v", "v", "v", "v")
@_beartype.beartype
def hann_window(
    g: jit_utils.GraphContext,
    window_length,
    periodic=True,
    dtype: Optional[int] = None,
    layout=None,
    device=None,
    pin_memory=None,
    requires_grad=False,
):
    """Symbolic for ``aten::hann_window``.

    Emits ``sin(pi * n / N) ** 2`` for ``n = arange(window_length)``, where
    ``N`` is ``window_length`` or ``window_length - 1`` when ``periodic`` is
    exactly ``False``. The result is cast to the requested dtype.

    Args:
        g: Graph context used to emit ops.
        window_length: Value holding the window length.
        periodic: When ``False``, the denominator is reduced by one.
        dtype: Optional integer scalar-type code. When ``None`` the torch
            default dtype is used, falling back to ``torch.float`` if that
            default is not floating point.
        layout, device, pin_memory, requires_grad: Not used by this symbolic.

    Returns:
        The window value.
    """
    if dtype is not None:
        scalar_type = _type_utils.JitScalarType(dtype)
    else:
        default_dtype = torch.get_default_dtype()
        if not default_dtype or not default_dtype.is_floating_point:
            default_dtype = torch.float
        scalar_type = _type_utils.JitScalarType.from_dtype(default_dtype)

    # NOTE(review): 4 is presumably the scalar-type code for int64 - confirm.
    positions = arange(g, window_length, 4, None, None, None)
    positions_f = g.op("Cast", positions, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    pi_const = g.op("Constant", value_t=torch.tensor(math.pi, dtype=torch.float))
    phase = mul(g, pi_const, positions_f)

    if periodic is False:
        one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int))
        window_length = sub(g, window_length, one)
    phase = div(g, phase, window_length)

    squared_sine = square(g, sin(g, phase))
    return g.op("Cast", squared_sine, to_i=scalar_type.onnx_type())
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::mv")
@_beartype.beartype
def mv(g: jit_utils.GraphContext, self, vec):
    """Symbolic for ``aten::mv``; delegates to ``matmul(g, self, vec)``."""
    product = matmul(g, self, vec)
    return product
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::dot")
@_beartype.beartype
def dot(g: jit_utils.GraphContext, self, other):
    """Symbolic for ``aten::dot``; delegates to ``matmul(g, self, other)``."""
    product = matmul(g, self, other)
    return product
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::movedim")
@symbolic_helper.parse_args("v", "t", "t")
@_beartype.beartype
def movedim(g: jit_utils.GraphContext, self, source, destination):
    """Symbolic for ``aten::movedim``, exported as a single ONNX Transpose.

    Args:
        g: Graph context used to emit ops.
        self: Value whose dimensions are moved.
        source: Tensor of dimension indices to move (flattened before use).
        destination: Tensor of target positions, same number of elements as
            ``source``.

    Returns:
        ``self`` unchanged when no dimension actually moves, otherwise the
        Transpose output.
    """
    # This is a pythonic implementation mostly taken from aten/src/ATen/native/TensorShape.cpp::movedim
    source = source.view(-1)
    destination = destination.view(-1)

    assert source.size() == destination.size()

    if (source == destination).all():
        return self

    self_rank = symbolic_helper._get_tensor_rank(self)
    assert self_rank is not None

    # Wrap negative dims into [0, rank). Without this, negative source values
    # would be written into ``perm`` and produce an invalid Transpose.
    src_list = [d + self_rank if d < 0 else d for d in source.tolist()]
    dst_list = [d + self_rank if d < 0 else d for d in destination.tolist()]

    if src_list == dst_list:
        # Only the spelling differed (e.g. -1 vs rank - 1); nothing moves.
        return self

    perm = list(range(self_rank))

    src_dims = perm.copy()
    dst_dims = perm.copy()

    # Place the explicitly moved dims and mark their slots as consumed.
    for src, dst in zip(src_list, dst_list):
        perm[dst] = src
        src_dims[src] = -1
        dst_dims[dst] = -1

    src_dims = [dim for dim in src_dims if dim != -1]
    dst_dims = [dim for dim in dst_dims if dim != -1]

    # Remaining dims fill the remaining slots in their original order.
    for src, dst in zip(src_dims, dst_dims):
        perm[dst] = src

    return g.op("Transpose", self, perm_i=perm)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::fill")
@symbolic_helper.parse_args("v", "v")
@_beartype.beartype
def fill(g: jit_utils.GraphContext, self, value):
    """Symbolic for ``aten::fill``.

    Delegates to ``full_like`` using the scalar type read from ``self``, with
    ``JitScalarType.FLOAT`` passed as the default to ``from_value``.
    """
    fallback_type = _type_utils.JitScalarType.FLOAT
    scalar_type = _type_utils.JitScalarType.from_value(self, fallback_type)
    return full_like(g, self, value, scalar_type)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::index_add")
@_beartype.beartype
def index_add(g: jit_utils.GraphContext, self, dim, index, other, alpha=None):
    """Symbolic for ``aten::index_add``, lowered to ``scatter_add``.

    A warning is always issued about duplicated index values. ``other`` is
    right-padded with unit dims to the rank of ``self`` and expanded, and
    ``index`` is unsqueezed on both sides of ``dim`` and expanded to match.

    Raises:
        errors.SymbolicValueError: If ``dim`` is None after constant
            extraction, if either rank is unknown, or if ``other`` is known to
            be larger than ``self`` along ``dim``.
    """
    warnings.warn(
        "Warning: ONNX export does not support duplicated values in 'index' field, "
        + "this will cause the ONNX model to be incorrect."
    )

    # ONNX does not support "alpha" argument, unlike aten index_add
    # See: https://github.com/pytorch/pytorch/pull/65993#issuecomment-953151102 for more context
    if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1:
        return symbolic_helper._unimplemented("index_add", "alpha != 1", self)

    dim = symbolic_helper._maybe_get_const(dim, "i")
    if dim is None:
        raise errors.SymbolicValueError(
            "ONNX export does NOT support exporting 'index_add_()' function with "
            "unknown 'dim' value.",
            self,
        )

    self_dim_rank = symbolic_helper._get_tensor_rank(self)
    other_dim_rank = symbolic_helper._get_tensor_rank(other)

    if self_dim_rank is None or other_dim_rank is None:
        raise errors.SymbolicValueError(
            "ONNX export does NOT support exporting 'index_add_()' function while "
            "the rank of self tensor or tensor to be added is unknown.",
            self,
        )

    # Append trailing unit dims to `other` until it has the rank of `self`.
    # The range is empty when the ranks already match.
    for _ in range(self_dim_rank - other_dim_rank):
        trailing_axis = symbolic_helper._get_tensor_rank(other)
        other = symbolic_helper._unsqueeze_helper(g, other, [trailing_axis])

    other_dim_size = symbolic_helper._get_tensor_dim_size(other, dim)
    self_dim_size = symbolic_helper._get_tensor_dim_size(self, dim)

    sizes_known = (other_dim_size is not None) and (self_dim_size is not None)
    if sizes_known and other_dim_size > self_dim_size:
        raise errors.SymbolicValueError(
            "ONNX export does not support exporting 'index_add_()' function with "
            "duplicated values in 'index' parameter yet.",
            self,
        )

    # Construct a new shape. It's almost as same as self except the size of the 'dim'
    # dimension is 1, so that we can expand other dimensions as expected.
    slice_axes = list(range(self_dim_rank))
    slice_starts = [0] * self_dim_rank
    slice_ends = [1 if axis == dim else sys.maxsize for axis in range(self_dim_rank)]

    new_shape = symbolic_helper._slice_helper(
        g, self, axes=slice_axes, starts=slice_starts, ends=slice_ends
    )
    other = expand_as(g, other, new_shape)

    # Give `index` one leading unit dim per axis before `dim` ...
    for _ in range(dim):
        index = symbolic_helper._unsqueeze_helper(g, index, [0])

    # ... and one trailing unit dim per axis after `dim`.
    for _ in range(self_dim_rank - dim - 1):
        trailing_axis = symbolic_helper._get_tensor_rank(index)
        index = symbolic_helper._unsqueeze_helper(g, index, [trailing_axis])

    expanded_index = expand_as(g, index, other)
    return scatter_add(g, self, dim, expanded_index, other)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::roll")
@symbolic_helper.parse_args("v", "is", "is")
@_beartype.beartype
def roll(g: jit_utils.GraphContext, self, shifts, dims):
    """Symbolic for ``aten::roll``.

    For each ``(shift, dim)`` pair, slices the current result into the part
    starting at ``-shift`` and the part ending at ``-shift`` along ``dim``,
    then concatenates them in that order.
    """
    assert len(shifts) == len(dims)

    rolled = self
    for shift, axis in zip(shifts, dims):
        tail = symbolic_helper._slice_helper(
            g, rolled, axes=[axis], starts=[-shift], ends=[sys.maxsize]
        )
        head = symbolic_helper._slice_helper(
            g, rolled, axes=[axis], starts=[0], ends=[-shift]
        )
        rolled = g.op("Concat", tail, head, axis_i=axis)

    return rolled
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::cross")
@symbolic_helper.parse_args("v", "v", "i")
@_beartype.beartype
def cross(g: jit_utils.GraphContext, input, other, dim=None):
    """Symbolic for ``aten::cross``, built from two pairs of rolls.

    The axis is resolved with ``symbolic_helper._get_dim_for_cross``.
    """
    dim = symbolic_helper._get_dim_for_cross(input, dim)
    # If we have two tensors such that
    # A = [a, b, c], B = [d, e, f], we permute the tensor such that we have
    # After first roll,
    # A' = [b, c, a], B' = [f, d, e], so that we calculate (b*f, c*d, a*e)
    input_rolled_2 = roll(g, input, [2], [dim])
    other_rolled_1 = roll(g, other, [1], [dim])
    # After second roll,
    # A' = [c, a, b], B' = [e, f, d], so that we calculate (c*e, a*f, b*d)
    input_rolled_1 = roll(g, input, [1], [dim])
    other_rolled_2 = roll(g, other, [2], [dim])
    # cross product is calculated as
    # result = [(b*f - c*e), (c*d - a*f), (a*e - b*d)]
    positive_terms = mul(g, input_rolled_2, other_rolled_1)
    negative_terms = mul(g, input_rolled_1, other_rolled_2)
    return sub(g, positive_terms, negative_terms)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::cdist")
@_beartype.beartype
def cdist(
    g: jit_utils.GraphContext,
    x1,
    x2,
    p=2.0,
    compute_mode="use_mm_for_euclid_dist_if_necessary",
):
    """Symbolic for ``aten::cdist``.

    Unsqueezes ``x1`` at axis ``rank - 1`` and ``x2`` at axis ``rank - 2``
    (``rank`` being the rank of ``x1``) and delegates to
    ``pairwise_distance``. ``compute_mode`` is not used.
    """
    # X1.shape = (B * P * D), X2.shape = (B * R * D)
    # In order to respect numpy style broadcasting as demonstrated in
    # https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md
    # we unsqueeze both input tensors
    # Currently we ignore the 'compute_mode' variable as we use default to
    # using matrix multiplication to calculate the euclidean distance
    x1_rank = symbolic_helper._get_tensor_rank(x1)
    assert x1_rank is not None
    x1_expanded = symbolic_helper._unsqueeze_helper(g, x1, [x1_rank - 1])
    x2_expanded = symbolic_helper._unsqueeze_helper(g, x2, [x1_rank - 2])
    return pairwise_distance(g, x1_expanded, x2_expanded, p, eps=1e-06, keepdim=False)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::lerp")
@_beartype.beartype
def lerp(g: jit_utils.GraphContext, self, end, weight):
    """Symbolic for ``aten::lerp``.

    Selects ``self + weight * (end - self)`` where ``weight < 0.5`` and
    ``end - (end - self) * (1 - weight)`` elsewhere.
    """
    # Conditional for better numeric. This has been discussed in
    # https://github.com/pytorch/pytorch/pull/18871
    delta = g.op("Sub", end, self)

    half = g.op("Constant", value_t=torch.tensor(0.5))
    weight_is_small = g.op("Less", weight, half)

    from_start = g.op("Add", self, g.op("Mul", weight, delta))

    one = g.op("Constant", value_t=torch.tensor(1.0))
    remaining = g.op("Sub", one, weight)
    from_end = g.op("Sub", end, g.op("Mul", delta, remaining))

    return where(g, weight_is_small, from_start, from_end)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::broadcast_tensors")
@_beartype.beartype
def broadcast_tensors(g: jit_utils.GraphContext, self):
    """Symbolic for ``aten::broadcast_tensors``.

    Unpacks the input list, accumulates all tensors onto a zero tensor to get
    a value with the broadcast shape, expands every tensor to it, and returns
    them packed in a ``prim::ListConstruct``.
    """
    tensors = symbolic_helper._unpack_list(self)
    shape_carrier = zeros_like(g, tensors[0])

    # Add operator supports multidirectional broadcasting. So we leverage this function
    # to infer the final shape generated by the broadcast.
    for tensor in tensors:
        shape_carrier = add(g, shape_carrier, tensor)

    expanded = [expand_as(g, tensor, shape_carrier) for tensor in tensors]
    return g.op("prim::ListConstruct", *expanded)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::is_pinned")
def is_pinned(g: jit_utils.GraphContext, self, device=None):
    """Symbolic for ``aten::is_pinned``; emits nothing and returns None."""
    # Unused by ONNX.
    result = None
    return result
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("prim::ConstantSplit")
@_beartype.beartype
def prim_constant_split(g: jit_utils.GraphContext, self, split_size, dim):
    """Symbolic for ``prim::ConstantSplit``.

    Splits ``self`` along ``dim`` into pieces of ``split_size``, plus one
    smaller final piece when the dimension size is not a multiple of it.
    Reports unimplemented when the dimension size is unknown.
    """
    dim_size = symbolic_helper._get_tensor_dim_size(self, dim)
    if dim_size is None:
        return symbolic_helper._unimplemented(
            "prim::ConstantSplit", "unknown dimension size", self
        )
    full_pieces, remainder = divmod(dim_size, split_size)
    piece_sizes = [split_size] * full_pieces
    if remainder:
        piece_sizes.append(remainder)
    return g.op(
        "Split", self, split_i=piece_sizes, axis_i=dim, outputs=len(piece_sizes)
    )
|
||
|
|
||
|
|
||
|
# TODO: It would be better to export this as a chunk directly, as this is
|
||
|
# less sensitive to changes in input size.
|
||
|
# TODO: Once we have proper scoping, stop reimplementing chunk, delete this
|
||
|
# method, and use the desugared version
|
||
|
@_onnx_symbolic("prim::ConstantChunk")
@_beartype.beartype
def prim_constant_chunk(g: jit_utils.GraphContext, self, chunks, dim):
    """Symbolic for ``prim::ConstantChunk``.

    Computes the per-chunk size as ``(dim_size + chunks - 1) // chunks`` and
    delegates to ``prim_constant_split``. Reports unimplemented when the
    dimension size is unknown.
    """
    axis_len = symbolic_helper._get_tensor_dim_size(self, dim)
    if axis_len is None:
        return symbolic_helper._unimplemented(
            "prim::ConstantChunk", "unknown dimension size", self
        )
    per_chunk = (axis_len + chunks - 1) // chunks
    return prim_constant_split(g, self, per_chunk, dim)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("prim::shape")
@_beartype.beartype
def prim_shape(g: jit_utils.GraphContext, self):
    """Symbolic for ``prim::shape``; emits an ONNX ``Shape`` op on ``self``."""
    shape_value = g.op("Shape", self)
    return shape_value
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("prim::max")
@_beartype.beartype
def prim_max(g: jit_utils.GraphContext, self, other):
    """Symbolic for ``prim::max``.

    Emits ``Max`` through ``_op_with_optional_float_cast`` with
    ``opset_before=12``.
    """
    maximum = _op_with_optional_float_cast(g, "Max", self, other, opset_before=12)
    return maximum
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("prim::min")
@_beartype.beartype
def prim_min(g: jit_utils.GraphContext, self, other=None):
    """Symbolic for ``prim::min``.

    With a truthy ``other`` this is the two-argument ``min``. Otherwise a
    packed-list ``self`` is first stacked along axis 0 and the
    single-argument ``min`` is applied.
    """
    if other:
        return min(g, self, other)

    if symbolic_helper._is_packed_list(self):
        stack_axis = g.op("Constant", value_t=torch.tensor([0]))
        self = stack(g, self, stack_axis)
    return min(g, self)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("prim::data")
@_beartype.beartype
def prim_data(g: jit_utils.GraphContext, self):
    """Symbolic for ``prim::data``; returns ``self`` without emitting any op."""
    passthrough = self
    return passthrough
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("prim::layout")
def prim_layout(g: jit_utils.GraphContext, self):
    """Symbolic for ``prim::layout``; emits a Constant holding the integer 0."""
    # Always return 'torch.strided'. Other layout types are not supported by JIT 'TensorType'.
    # Layout class defined in 'c10/core/Layout.h'.
    strided_code = torch.tensor(0)
    return g.op("Constant", value_t=strided_code)
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("prim::ListConstruct")
@_beartype.beartype
def prim_list_construct(g: jit_utils.GraphContext, *inputs, **kwargs):
    """Symbolic for ``prim::ListConstruct``; emits nothing and returns None."""
    result = None
    return result
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("prim::ListUnpack")
@_beartype.beartype
def prim_list_unpack(
    g: jit_utils.GraphContext, *inputs, **kwargs
) -> Optional[List[_C.Value]]:
    """Symbolic for ``prim::ListUnpack``.

    Returns the unpacked elements when the single input is produced by a
    ``prim::ListConstruct`` node, and None in every other case.
    """
    if len(inputs) != 1 or inputs[0].node().kind() != "prim::ListConstruct":
        return None

    # Cancel the previous node if it is ListConstruct by returning its inputs
    # TODO(justinchuby): Use a public method in the helper module
    return symbolic_helper._unpack_list(inputs[0])
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("prim::TupleConstruct")
@_beartype.beartype
def prim_tuple_construct(g: jit_utils.GraphContext, *inputs, **kwargs):
    """Symbolic for ``prim::TupleConstruct``; emits nothing and returns None."""
    result = None
    return result
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("prim::Uninitialized")
@_beartype.beartype
def prim_uninitialized(g: jit_utils.GraphContext, *inputs, **kwargs):
    """Symbolic for ``prim::Uninitialized``; emits nothing and returns None."""
    result = None
    return result
|
||
|
|
||
|
|
||
|
# exists to refine the type of the Value
|
||
|
# if x is an optional Tensor, unchecked_cast will cast
|
||
|
# x to Tensor, so the rest of the graph knows that x is a Tensor
|
||
|
# this doesn't do anything in runtime and is a noop in ONNX
|
||
|
@_onnx_symbolic("prim::unchecked_cast")
@_beartype.beartype
def prim_unchecked_cast(g: jit_utils.GraphContext, self):
    """Symbolic for ``prim::unchecked_cast``; returns ``self`` unchanged."""
    passthrough = self
    return passthrough
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("prim::dtype")
@_beartype.beartype
def prim_dtype(g: jit_utils.GraphContext, self):
    """Symbolic for ``prim::dtype``.

    Emits a Constant holding the scalar type of ``self``, using
    ``JitScalarType.FLOAT`` when the type cannot be determined.
    """
    detected = symbolic_helper._try_get_scalar_type(self)
    scalar_type = _type_utils.JitScalarType.FLOAT if detected is None else detected
    # This node records a torch dtype as int
    return g.op("Constant", value_t=torch.tensor(scalar_type))
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("prim::tolist")
@_beartype.beartype
def prim_tolist(g: jit_utils.GraphContext, input, dim_val, elem_ty_val):
    """Symbolic for ``prim::tolist``; only 1D input tensors are supported.

    ``dim_val`` and ``elem_ty_val`` are the dimension and element-type
    annotations, which need to match the dimension and type of ``input``.
    The input is returned as-is unless the constant read from ``dim_val`` is
    greater than 1, which is reported as unimplemented.
    """
    annotated_dim = symbolic_helper._maybe_get_const(dim_val, "i")
    if annotated_dim > 1:
        return symbolic_helper._unimplemented("prim::tolist", "dim_val > 1", input)
    return input
|
||
|
|
||
|
|
||
|
# -----------------------------------------------------------------------------
|
||
|
# Symbolic functions that need extra context
|
||
|
# -----------------------------------------------------------------------------
|
||
|
@_onnx_symbolic("prim::device")
@_beartype.beartype
def prim_device(g: jit_utils.GraphContext, *inputs, **kwargs) -> None:
    """Symbolic for ``prim::device``.

    Returns None when the original node's output type is a
    ``_C.DeviceObjType``; any other output type is reported as unimplemented.
    """
    node_output = g.original_node.output()
    output_type = node_output.type()
    if not isinstance(output_type, _C.DeviceObjType):
        return symbolic_helper._unimplemented(
            "prim::device",
            f"output type should be 'DeviceObjType', not '{output_type.kind()}'",
            node_output,
        )
    return None
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("prim::Loop")
@_beartype.beartype
def prim_loop(g: jit_utils.GraphContext, *inputs, **attrs) -> List[_C.Value]:
    """Symbolic for ``prim::Loop``.

    Creates an ONNX ``Loop`` node with one sub-block per original block,
    copies input types onto the original block inputs, converts each block,
    runs the control-flow fixup pass, and optionally runs shape/type
    inference on the new node.

    Returns:
        The outputs returned by ``_jit_pass_fixup_onnx_controlflow_node``.
    """
    loop_node = g.original_node
    env = g.env
    params_dict = g.params_dict

    export_type = GLOBALS.operator_export_type
    opset = GLOBALS.export_onnx_opset_version

    source_blocks = tuple(loop_node.blocks())
    _new_outputs, block_contexts, onnx_loop = jit_utils.add_op_with_blocks(
        g,
        "Loop",
        *inputs,
        outputs=loop_node.outputsSize(),
        n_blocks=len(source_blocks),
    )

    n_inputs = len(inputs)
    for source_block, block_context in zip(source_blocks, block_contexts):
        # Copy input metadata to subblock
        #
        #   prim::Loop(iter, cond, input_1, ..., input_n)
        #     block0(iter, input_1, ..., input_n)
        #
        # For `Loop` node, copy metadata for `iter`, `input_1`, ..., `input_n`.
        for pos, block_input in enumerate(source_block.inputs()):
            if pos == 0:
                if n_inputs > 0:
                    block_input.setType(inputs[0].type())
                continue
            # For optional block inputs, they may switch between None not-None inside
            # the loop body, so if the loop input is not optional, the block input may
            # still need to be optional.
            if (pos + 1) < n_inputs and not isinstance(
                block_input.type(), _C.OptionalType
            ):
                block_input.setType(inputs[pos + 1].type())
        torch._C._jit_pass_onnx_block(
            source_block,
            block_context.block,
            export_type,
            env,
            False,
        )

    fixed_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node(onnx_loop, opset)
    # Run shape type inference for Loop after subblock is converted.
    if GLOBALS.onnx_shape_inference:
        torch._C._jit_pass_onnx_node_shape_type_inference(
            onnx_loop, params_dict, opset
        )
    return fixed_outputs
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("prim::If")
@_beartype.beartype
def prim_if(g: jit_utils.GraphContext, *inputs, **attrs) -> List[_C.Value]:
    """Symbolic for ``prim::If``.

    When the condition is produced by an ``onnx::Constant`` node, the taken
    branch is converted directly into the current block and its outputs are
    returned. Otherwise an ONNX ``If`` node with converted sub-blocks is
    created.

    Raises:
        errors.SymbolicValueError: In the static case, if an output of the
            selected sub-block is not present in the returned environment.
    """
    if_node = g.original_node
    block = g.block
    env = g.env
    params_dict = g.params_dict

    export_type = GLOBALS.operator_export_type
    opset = GLOBALS.export_onnx_opset_version

    condition_node = inputs[0].node()
    if condition_node.kind() == "onnx::Constant":
        # Fold static if
        #
        # The torch IR
        # graph(%embedding_matrix.1 : Float(10, 15, strides=[15, 1], requires_grad=0, device=cpu),
        #    %input.1 : Long(6, strides=[1], requires_grad=0, device=cpu), ...
        # %65 : Bool(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        # %21 : Long(device=cpu) = aten::eq(%20, %64)
        # %22 : Long(device=cpu) = prim::If(%21)
        #     block0():
        #     %23 : Long(device=cpu) = aten::is_floating_point(%input.1)
        #     -> (%23)
        #     block1():
        #     -> (%65)
        # %input.53 : Tensor, %weight : Tensor = prim::If(%22)
        #     block0():
        #     -> (%embedding_matrix.1, %input.1)
        #     block1():
        #     -> (%input.1, %embedding_matrix.1)
        # %26 : int[] = aten::size(%input.53)
        #
        # The converted ONNX graph
        # %10 : Bool(device=cpu) = onnx::Constant[value={0}]()
        # %14 : Bool(device=cpu) = onnx::Equal(%13, %8)
        # %15 : Bool(requires_grad=0, device=cpu) = onnx::Constant[value={0}]()
        # %16 : Long(1, strides=[1], device=cpu) = onnx::Shape(%input.1)
        flag = symbolic_helper._node_get(inputs[0].node(), "value").tolist()
        if isinstance(flag, list):
            condition_holds = all(flag)
        else:
            condition_holds = bool(flag)
        # block 0 is the "then" branch, block 1 the "else" branch
        taken_block = list(if_node.blocks())[0 if condition_holds else 1]
        env = torch._C._jit_pass_onnx_block(
            taken_block,
            block,
            export_type,
            env,
            True,
        )
        if_outputs = list(if_node.outputs())
        taken_outputs = list(taken_block.outputs())

        folded_outputs = []
        for idx in range(len(if_outputs)):
            branch_output = taken_outputs[idx]
            if branch_output not in env:
                raise errors.SymbolicValueError(
                    f"The sub block ATen output {branch_output} is not in env.",
                    branch_output,
                )  # type:ignore[operator]
            folded_outputs.append(env[branch_output])
        return folded_outputs

    source_blocks = tuple(if_node.blocks())
    _new_outputs, block_contexts, onnx_if = jit_utils.add_op_with_blocks(
        g, "If", *inputs, outputs=if_node.outputsSize(), n_blocks=len(source_blocks)
    )

    for source_block, block_context in zip(source_blocks, block_contexts):
        torch._C._jit_pass_onnx_block(
            source_block,
            block_context.block,
            export_type,
            env,
            False,
        )
    fixed_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node(onnx_if, opset)
    # Run shape type inference for If after subblock is converted.
    if GLOBALS.onnx_shape_inference:
        torch._C._jit_pass_onnx_node_shape_type_inference(
            onnx_if, params_dict, opset
        )
    return fixed_outputs
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("prim::Constant")
@_beartype.beartype
def prim_constant(g: jit_utils.GraphContext, *inputs, **attrs):
    """Symbolic for ``prim::Constant``.

    Returns None for None constants and device-typed constants. Tensor and
    string values become ONNX ``Constant`` nodes, int/float lists become a
    tensor ``Constant``, and string lists become a ``prim::ListConstruct`` of
    string constants.

    Raises:
        errors.SymbolicValueError: For any other kind of constant value.
    """
    node = g.original_node

    if node.mustBeNone():
        return None

    output_type = node.output().type()
    # This must go before checking for string values, because some device constants
    # have string values, but we want to keep them as unconverted Device types so
    # that eq() can work on them.
    if isinstance(output_type, _C.DeviceObjType):
        return None

    value_kind = node.kindOf("value")
    if value_kind == "t":
        return g.op("Constant", value_t=symbolic_helper._node_get(node, "value"))
    if value_kind == "s":
        return g.op("Constant", value_s=symbolic_helper._node_get(node, "value"))

    is_numeric_list = output_type.isSubtypeOf(
        _C.ListType.ofInts()
    ) or output_type.isSubtypeOf(_C.ListType.ofFloats())
    if is_numeric_list:
        numeric_values = torch.tensor(symbolic_helper._node_get(node, "value"))
        return g.op("Constant", value_t=numeric_values)

    if output_type.isSubtypeOf(_C.ListType.ofStrings()):
        string_nodes = []
        for text in symbolic_helper._node_get(node, "value"):
            string_nodes.append(g.op("Constant", value_s=text))
        return g.op("prim::ListConstruct", *string_nodes)

    raise errors.SymbolicValueError(
        f"Unsupported prim::Constant kind: '{value_kind}'. "
        f"Please send a bug report at {_constants.PYTORCH_GITHUB_ISSUES_URL}.",
        node.output(),
    )
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("prim::type")
@_beartype.beartype
def prim_type(g: jit_utils.GraphContext, device_value: _C.Value, *args, **kwargs):
    """Symbolic for ``prim::type``.

    When ``device_value`` is produced by a ``prim::device`` node and the
    device can be read from that node's input, emits a string Constant with
    ``str(device)``. Otherwise reports unimplemented.
    """
    producer = device_value.node()
    if producer.kind() == "prim::device":
        device = jit_utils.get_device_from_value(producer.input())
        if device is not None:
            device_name = str(device)
            return g.op("Constant", value_s=device_name)

    return symbolic_helper._unimplemented(
        "prim::type",
        "Device type cannot be statically determined.",
        device_value,
    )
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("onnx::Placeholder")
@_beartype.beartype
def onnx_placeholder(g: jit_utils.GraphContext, *inputs, **attrs):
    """Symbolic for ``onnx::Placeholder``.

    Forwards the current block, original node and environment to
    ``torch._C._jit_onnx_convert_pattern_from_subblock`` and returns its result.
    """
    return torch._C._jit_onnx_convert_pattern_from_subblock(
        g.block, g.original_node, g.env
    )
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::resolve_conj")
@_onnx_symbolic("aten::resolve_neg")
@_beartype.beartype
def noop_complex_operators(g: jit_utils.GraphContext, input: _C.Value):
    """Symbolic for ``aten::resolve_conj`` / ``aten::resolve_neg``; returns ``input``."""
    # ONNX does not have operators to *directly* manipulate real/imaginary components
    # However, a few torch APIs (e.g. .tolist()) use complex operations when input is real,
    # which results in failures due to missing operators for complex numbers

    # `aten::resolve_conj` and `aten::resolve_neg` can safely be implemented as no-op
    passthrough = input
    return passthrough
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::_conj")
@_onnx_symbolic("aten::conj_physical")
@_beartype.beartype
def unsupported_complex_operators(g: jit_utils.GraphContext, input: _C.Value):
    """Symbolic for ``aten::_conj`` / ``aten::conj_physical``.

    Complex inputs are reported as unsupported; any other input is passed
    through ``noop_complex_operators``.
    """
    # ONNX does not have operators to *directly* manipulate real/imaginary components
    # However, a few torch APIs (e.g. .tolist()) use complex operations when input is real,
    # which results in failures due to missing operators for complex numbers

    # they can safely be implemented as no-op for real numbers only
    if not symbolic_helper.is_complex_value(input):
        return noop_complex_operators(g, input)

    # While `aten::_conj` and `aten::conj_physical` raise exception when input is complex
    # FIXME(justinchuby): report correct name for symbolic being executed
    return symbolic_helper._onnx_unsupported(
        "aten::_conj, aten::conj_physical",
        input,
    )
|
||
|
|
||
|
|
||
|
@_onnx_symbolic("aten::logit")
@_beartype.beartype
def logit(g: jit_utils.GraphContext, self: torch._C.Value, eps: torch._C.Value):
    """Symbolic for ``aten::logit``: emits ``Log(z / (1 - z))``.

    When ``eps`` is not None it is cast to the scalar type of ``self`` and
    ``z`` is ``self`` limited to ``[eps, 1 - eps]`` with Where ops; otherwise
    ``z`` is ``self``.
    """
    one = g.op("Constant", value_t=torch.tensor(1.0))

    if symbolic_helper._is_none(eps):
        clamped = self
    else:
        eps = g.op(
            "Cast", eps, to_i=_type_utils.JitScalarType.from_value(self).onnx_type()
        )
        upper_bound = g.op("Sub", one, eps)
        # Keep `self` where it lies below the upper bound, else use the bound.
        below_upper = g.op("Greater", upper_bound, self)
        capped = g.op("Where", below_upper, self, upper_bound)

        # Raise anything smaller than eps up to eps.
        below_eps = g.op("Less", capped, eps)
        clamped = g.op("Where", below_eps, eps, capped)

    complement = g.op("Sub", one, clamped)
    odds = g.op("Div", clamped, complement)
    return g.op("Log", odds)
|