471 lines
15 KiB
Python
471 lines
15 KiB
Python
"""
|
|
Note [ONNX operators that are added/updated from opset 8 to opset 9]
|
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
New operators:
|
|
Compress
|
|
ConstantOfShape
|
|
EyeLike
|
|
MaxUnpool
|
|
OneHot
|
|
Sinh
|
|
Cosh
|
|
Asinh
|
|
Acosh
|
|
Atanh
|
|
Shrink
|
|
IsNaN
|
|
Sign
|
|
Erf
|
|
Scatter
|
|
Where
|
|
NonZero
|
|
TfIdfVectorizer
|
|
MeanVarianceNormalization
|
|
|
|
Updated operators:
|
|
BatchNormalization: removed spatial attribute.
|
|
Greater, Less, Constant, MatMul, PRelu, Gemm, Flatten: more data types{integers} supported.
|
|
Cast: more data types{string} supported.
|
|
Upsample: moved scales from attribute to input.
|
|
Scan
|
|
"""
|
|
|
|
import functools
|
|
import warnings
|
|
|
|
import torch
|
|
from torch._C import _onnx as _C_onnx
|
|
from torch.onnx import _type_utils, errors, symbolic_helper, symbolic_opset9 as opset9
|
|
from torch.onnx._internal import jit_utils, registration
|
|
|
|
# Registration decorator pre-bound to opset 8; the symbolic functions in this
# file are registered through it.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=8)

# ATen operator names that are registered with a block-list handler instead of
# a real symbolic function for this opset.
# NOTE(review): several of these match the "New operators" introduced in
# opset 9 per the module note; the rationale for the rest is not visible
# here -- confirm before adding or removing entries.
block_listed_operators = (
    "nonzero",
    "where",
    "scatter",
    "scatter_add",
    "erf",
    "sign",
    "isnan",
    "gather",
    "arange",
    "masked_fill",
    "index_fill",
    "index_copy",
    "repeat_interleave",
    "any",
    "all",
)

# Register one block-list handler per operator under its "aten::" name.
for block_listed_op in block_listed_operators:
    _onnx_symbolic(f"aten::{block_listed_op}")(
        symbolic_helper._block_list_in_opset(block_listed_op)
    )
|
|
|
|
|
|
def _apply_params(*args, **kwargs):
|
|
"""Returns a decorator that calls the decorated (higher-order) function with the given parameters."""
|
|
|
|
def _apply(fn):
|
|
return fn(*args, **kwargs)
|
|
|
|
return _apply
|
|
|
|
|
|
@_onnx_symbolic(
    "aten::upsample_nearest1d",
    decorate=[_apply_params("upsample_nearest1d", 3, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest2d",
    decorate=[_apply_params("upsample_nearest2d", 4, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest3d",
    decorate=[_apply_params("upsample_nearest3d", 5, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_linear1d",
    decorate=[_apply_params("upsample_linear1d", 3, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_bilinear2d",
    decorate=[_apply_params("upsample_bilinear2d", 4, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_trilinear3d",
    decorate=[_apply_params("upsample_trilinear3d", 5, "linear")],
)
def _interpolate(name, dim, interpolate_mode):
    """Create the opset 8 symbolic function for one ``aten::upsample_*`` variant.

    Args:
        name: Operator name reported through ``_unimplemented`` when a case
            cannot be exported.
        dim: Number of scale entries to emit. The first two entries are fixed
            to 1.0 (presumably the batch and channel dimensions -- confirm).
        interpolate_mode: Value emitted as the ``mode`` attribute of the
            ``Upsample`` node.

    Returns:
        A symbolic function taking ``(g, input, output_size, *args)`` that
        emits a static-scale ``Upsample`` node, or the result of
        ``_unimplemented`` for ``align_corners`` / non-constant ``output_size``.
    """

    def symbolic_fn(g, input, output_size, *args):
        scales, align_corners = symbolic_helper._get_interpolate_attributes(
            g, interpolate_mode, args
        )
        symbolic_helper._interpolate_warning(interpolate_mode)

        align_corners = symbolic_helper._maybe_get_scalar(align_corners)
        if align_corners:
            return symbolic_helper._unimplemented(name, "align_corners == True", input)

        output_size = symbolic_helper._maybe_get_const(output_size, "is")
        if symbolic_helper._is_value(output_size):
            # Still a graph value, i.e. it could not be folded to a constant.
            return symbolic_helper._unimplemented(
                name, "torch._C.Value (output_size) indexing"
            )

        if scales is None:
            # Derive the scales from the ratio of requested to actual size,
            # indexing both from the end so trailing dimensions line up.
            scales = []
            for axis in range(0, dim):
                if axis < 2:
                    scales.append(1.0)
                else:
                    offset = -(dim - axis)
                    scales.append(
                        float(output_size[offset])
                        / float(input.type().sizes()[offset])
                    )

        return g.op("Upsample", input, mode_s=interpolate_mode, scales_f=scales)

    return symbolic_fn
|
|
|
|
|
|
@_onnx_symbolic("aten::__interpolate")
def __interpolate(
    g: jit_utils.GraphContext,
    input,
    size,
    scale_factor,
    mode,
    align_corners,
    recompute_scale_factor,
    antialias,
):
    """Export ``aten::__interpolate`` as an ``Upsample`` node with static scales.

    ``align_corners == True`` and non-constant ``size`` / ``scale_factor``
    are reported through ``symbolic_helper._unimplemented``.
    ``recompute_scale_factor`` and ``antialias`` are accepted but not used.
    """
    align_corners = symbolic_helper._maybe_get_const(align_corners, "b")
    if not symbolic_helper._is_none(align_corners) and align_corners:
        return symbolic_helper._unimplemented("interpolate", "align_corners == True")

    # Both inputs must be compile-time constants; a remaining graph value means
    # the quantity is dynamic.
    dynamic_checks = (
        (scale_factor, "dynamic scales in opset 8"),
        (size, "dynamic size in opset 8"),
    )
    for candidate, reason in dynamic_checks:
        if symbolic_helper._is_none(candidate):
            continue
        if symbolic_helper._is_value(candidate):
            return symbolic_helper._unimplemented("interpolate", reason)

    scales, mode = symbolic_helper._interpolate_get_scales_and_mode(
        g, input, size, scale_factor, mode, align_corners
    )
    return g.op("Upsample", input, mode_s=mode, scales_f=scales)
|
|
|
|
|
|
# NOTE: We should create a wrapper for this kind of operation, after resolving the shape/type propagation
#       issue for "cast" operators. Some symbolic functions depend on shape information of input tensor, which
#       is lost after casting.
def _try_cast_integer_to_float(g: jit_utils.GraphContext, *args):
    """Cast all ``args`` to FLOAT when the first one has a known non-floating type.

    The decision is based solely on the scalar type of ``args[0]``.

    Returns:
        A tuple ``(old_type, *args)``. ``old_type`` is the scalar name of the
        original type when a cast was inserted, otherwise ``None`` (already
        floating point, or type unknown, in which case a warning is issued and
        the arguments are returned unchanged).
    """
    scalar_types = _type_utils.JitScalarType
    arg0_type = scalar_types.from_value(args[0], scalar_types.UNDEFINED)

    if arg0_type == scalar_types.UNDEFINED:
        # Type is unknown: nothing can be cast safely, so only warn.
        warnings.warn(
            "Only floating datatype is supported for these operators: "
            "{Greater, Less, MatMul, PRelu, Gemm, Flatten}. This might cause "
            "the onnx model to be incorrect, if inputs have integer datatypes."
        )
        return (None,) + args

    if arg0_type in {scalar_types.HALF, scalar_types.FLOAT, scalar_types.DOUBLE}:
        # Already floating point: no cast, hence nothing to cast back to.
        return (None,) + args

    original_name = arg0_type.scalar_name()
    float_args = tuple(
        g.op("Cast", operand, to_i=_C_onnx.TensorProtoDataType.FLOAT)
        for operand in args
    )
    return (original_name,) + float_args
|
|
|
|
|
|
def _cast_to_type(g: jit_utils.GraphContext, input, to_type):
    """Cast ``input`` via ``opset9._cast_<to_type>``, or return it unchanged.

    Args:
        g: Graph context forwarded to the opset 9 cast function.
        input: Value to cast.
        to_type: Suffix selecting the ``opset9._cast_*`` function, or ``None``
            to skip the cast.
    """
    if to_type is not None:
        cast_fn = getattr(opset9, f"_cast_{to_type}")
        return cast_fn(g, input, False)
    return input
|
|
|
|
|
|
def _comparison_operator(g: jit_utils.GraphContext, input, other, op_name):
    """Emit the comparison node ``op_name`` after casting operands to float if needed.

    ``other`` is first passed through ``_maybe_get_scalar`` and
    ``_if_scalar_type_as`` (with ``input`` as reference). The original type
    reported by the cast helper is discarded, so the result is never cast back.
    """
    rhs = symbolic_helper._maybe_get_scalar(other)
    rhs = symbolic_helper._if_scalar_type_as(rhs, input)
    _unused_old_type, lhs, rhs = _try_cast_integer_to_float(g, input, rhs)
    return g.op(op_name, lhs, rhs)
|
|
|
|
|
|
# NOTE: For symbolics {gt, lt, bmm, matmul, prelu, mm, addmm, view, flatten},
#       integer input type not supported in opset8. Cast to float if possible.
@_onnx_symbolic("aten::gt")
def gt(g: jit_utils.GraphContext, input, other):
    """Export ``aten::gt`` as an ONNX ``Greater`` node."""
    greater = _comparison_operator(g, input, other, "Greater")
    return greater
|
|
|
|
|
|
@_onnx_symbolic("aten::lt")
def lt(g: jit_utils.GraphContext, input, other):
    """Export ``aten::lt`` as an ONNX ``Less`` node."""
    less = _comparison_operator(g, input, other, "Less")
    return less
|
|
|
|
|
|
@_onnx_symbolic("aten::bmm")
def bmm(g: jit_utils.GraphContext, self, other):
    """Export ``aten::bmm`` as an ONNX ``MatMul`` node.

    When the scalar type of ``self`` is known, the operands go through
    ``_try_cast_integer_to_float`` and the product is cast back to the
    original type if a cast was inserted.
    """
    if not symbolic_helper._try_get_scalar_type(self):
        # Unknown type: emit MatMul on the operands as they are.
        return g.op("MatMul", self, other)

    old_type, lhs, rhs = _try_cast_integer_to_float(g, self, other)
    product = g.op("MatMul", lhs, rhs)
    return _cast_to_type(g, product, old_type)
|
|
|
|
|
|
@_onnx_symbolic("aten::matmul")
def matmul(g: jit_utils.GraphContext, self, other):
    """Export ``aten::matmul``; delegates to :func:`bmm`."""
    product = bmm(g, self, other)
    return product
|
|
|
|
|
|
@_onnx_symbolic("aten::prelu")
def prelu(g: jit_utils.GraphContext, self, weight):
    """Export ``aten::prelu`` as an ONNX ``PRelu`` node.

    ``weight`` is reshaped first: unsqueezed on axes ``1 .. rank-2`` when the
    input rank exceeds 2, or squeezed when the input is rank 0 and the weight
    has sizes ``[1]``. Operands are cast to float (and the result cast back)
    when the scalar type of ``self`` is known and not floating point.
    """
    input_rank = symbolic_helper._get_tensor_rank(self)
    weight_shape = symbolic_helper._get_tensor_sizes(weight)

    if input_rank is not None and input_rank > 2:
        trailing_axes = list(range(1, input_rank - 1))
        weight = g.op("Unsqueeze", weight, axes_i=trailing_axes)
    elif input_rank == 0 and weight_shape == [1]:
        # self and weight are both scalar but weight has rank == 1, squeeze weight.
        weight = symbolic_helper._squeeze_helper(g, weight, [0])

    if not symbolic_helper._try_get_scalar_type(self):
        return g.op("PRelu", self, weight)

    old_type, data, slope = _try_cast_integer_to_float(g, self, weight)
    activated = g.op("PRelu", data, slope)
    return _cast_to_type(g, activated, old_type)
|
|
|
|
|
|
@_onnx_symbolic("aten::mm")
def mm(g: jit_utils.GraphContext, self, other):
    """Export ``aten::mm`` as an ONNX ``Gemm`` node with ``beta = 0``.

    Raises:
        errors.SymbolicValueError: If no scalar type can be determined from
            ``self`` or ``other``.
    """
    scalar_type = symbolic_helper._try_get_scalar_type(self, other)
    if scalar_type is None:
        raise errors.SymbolicValueError(
            "mm can only operate on tensors with known types", self
        )

    # Dummy C operand: only needed to satisfy the Gemm signature; its value
    # does not matter since beta = 0.
    dummy_c = g.op(
        "Constant",
        value_t=torch.tensor([0], dtype=scalar_type.dtype()),
    )

    if not symbolic_helper._try_get_scalar_type(self):
        return g.op("Gemm", self, other, dummy_c, beta_f=0.0, alpha_f=1.0)

    old_type, lhs, rhs, bias = _try_cast_integer_to_float(g, self, other, dummy_c)
    product = g.op("Gemm", lhs, rhs, bias, beta_f=0.0, alpha_f=1.0)
    return _cast_to_type(g, product, old_type)
|
|
|
|
|
|
@_onnx_symbolic("aten::addmm")
@symbolic_helper.parse_args("v", "v", "v", "t", "t")
def addmm(g: jit_utils.GraphContext, self, mat1, mat2, beta, alpha):
    """Export ``aten::addmm`` as ONNX ``Gemm(mat1, mat2, self)``.

    ``beta`` and ``alpha`` are converted with ``symbolic_helper._scalar`` and
    emitted as the ``beta`` / ``alpha`` attributes. When the scalar type of
    ``self`` is known, the tensor operands go through
    ``_try_cast_integer_to_float`` and the result is cast back if needed.
    """
    type_is_known = symbolic_helper._try_get_scalar_type(self)
    old_type = None
    if type_is_known:
        old_type, self, mat1, mat2 = _try_cast_integer_to_float(g, self, mat1, mat2)

    gemm = g.op(
        "Gemm",
        mat1,
        mat2,
        self,
        beta_f=symbolic_helper._scalar(beta),
        alpha_f=symbolic_helper._scalar(alpha),
    )
    if type_is_known:
        return _cast_to_type(g, gemm, old_type)
    return gemm
|
|
|
|
|
|
@_onnx_symbolic("aten::flatten")
def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim):
    """Export ``aten::flatten``, using ONNX ``Flatten`` when the result is 2D.

    Two dimension ranges map onto ONNX ``Flatten``: ``[1, dim-1]`` (axis =
    ``start_dim``) and ``[0, dim-2]`` (axis = ``end_dim + 1``). Every other
    case is delegated to ``opset9.flatten``.
    """
    start_dim_i = symbolic_helper._get_const(start_dim, "i", "start_dim")
    end_dim_i = symbolic_helper._get_const(end_dim, "i", "end_dim")

    dim = input.type().dim()
    if end_dim_i < 0:
        # Normalize a negative end index against the input rank.
        end_dim_i = dim + end_dim_i

    # use ONNX's Flatten operator for cases where the output shape is 2D
    if start_dim_i == 1 and end_dim_i == dim - 1:
        axis = start_dim_i
    elif start_dim_i == 0 and end_dim_i == dim - 2:
        axis = end_dim_i + 1
    else:
        return opset9.flatten(g, input, start_dim, end_dim)

    if not symbolic_helper._try_get_scalar_type(input):
        return g.op("Flatten", input, axis_i=axis)

    old_type, input = _try_cast_integer_to_float(g, input)
    flattened = g.op("Flatten", input, axis_i=axis)
    return _cast_to_type(g, flattened, old_type)
|
|
|
|
|
|
def _constant_fill(g: jit_utils.GraphContext, sizes, dtype: int, const_value):
    """Emit a ``ConstantFill`` node of shape ``sizes`` filled with ``const_value``.

    Args:
        g: Graph context used to emit the nodes.
        sizes: Shape input of the node (``input_as_shape`` is set to 1).
        dtype: Integer code accepted by ``JitScalarType``; ``None`` selects FLOAT.
        const_value: Fill value, passed as the float attribute ``value``.

    Returns:
        The filled tensor. For non-floating target types the fill is emitted
        as FLOAT and followed by a ``Cast`` to the requested type.
    """
    float_type = _type_utils.JitScalarType.FLOAT
    scalar_type = float_type if dtype is None else _type_utils.JitScalarType(dtype)

    wants_float = scalar_type.dtype().is_floating_point
    fill_type = scalar_type if wants_float else float_type
    filled = g.op(
        "ConstantFill",
        sizes,
        dtype_i=fill_type.onnx_type(),
        input_as_shape_i=1,
        value_f=const_value,
    )
    if wants_float:
        return filled
    return g.op("Cast", filled, to_i=scalar_type.onnx_type())
|
|
|
|
|
|
@_onnx_symbolic("aten::empty")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def empty(
    g: jit_utils.GraphContext,
    sizes,
    dtype,
    layout,
    device,
    pin_memory=False,
    memory_format=None,
):
    """Export ``aten::empty`` by delegating to :func:`zeros`.

    ``memory_format`` is accepted but not forwarded.
    """
    zero_filled = zeros(g, sizes, dtype, layout, device, pin_memory)
    return zero_filled
|
|
|
|
|
|
@_onnx_symbolic("aten::empty_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def empty_like(
    g: jit_utils.GraphContext,
    input,
    dtype,
    layout,
    device,
    pin_memory=False,
    memory_format=None,
):
    """Export ``aten::empty_like`` by delegating to :func:`zeros_like`.

    ``memory_format`` is accepted but not forwarded.
    """
    zero_filled = zeros_like(g, input, dtype, layout, device, pin_memory)
    return zero_filled
|
|
|
|
|
|
@_onnx_symbolic("aten::zeros")
@symbolic_helper.parse_args("v", "i", "v", "v", "v")
def zeros(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False):
    """Export ``aten::zeros`` as a zero-valued constant fill of shape ``sizes``."""
    # NOTE: no way to set device and layout in ONNX, so we ignore it
    fill_value = 0
    return _constant_fill(g, sizes, dtype, fill_value)
|
|
|
|
|
|
@_onnx_symbolic("aten::zeros_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def zeros_like(
    g: jit_utils.GraphContext,
    input,
    dtype,
    layout,
    device,
    pin_memory=False,
    memory_format=None,
):
    """Export ``aten::zeros_like``: zero fill using the ``Shape`` of ``input``.

    ``layout``, ``device``, ``pin_memory`` and ``memory_format`` are not used.
    """
    input_shape = g.op("Shape", input)
    fill_value = 0
    return _constant_fill(g, input_shape, dtype, fill_value)
|
|
|
|
|
|
@_onnx_symbolic("aten::ones")
@symbolic_helper.parse_args("v", "i", "v", "v", "v")
def ones(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False):
    """Export ``aten::ones`` as a one-valued constant fill of shape ``sizes``.

    ``layout``, ``device`` and ``pin_memory`` are not used.
    """
    fill_value = 1
    return _constant_fill(g, sizes, dtype, fill_value)
|
|
|
|
|
|
@_onnx_symbolic("aten::ones_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def ones_like(
    g: jit_utils.GraphContext,
    input,
    dtype,
    layout,
    device,
    pin_memory=False,
    memory_format=None,
):
    """Export ``aten::ones_like``: one fill using the ``Shape`` of ``input``.

    ``layout``, ``device``, ``pin_memory`` and ``memory_format`` are not used.
    """
    input_shape = g.op("Shape", input)
    fill_value = 1
    return _constant_fill(g, input_shape, dtype, fill_value)
|
|
|
|
|
|
@_onnx_symbolic("aten::full")
def full(
    g: jit_utils.GraphContext, sizes, value, dtype, layout, device, pin_memory=False
):
    """Export ``aten::full``.

    A constant ``value`` becomes a single constant fill. A non-constant
    ``value`` is handled by adding it to a zero-filled tensor via
    ``opset9.add``.
    """
    const_value = symbolic_helper._maybe_get_const(value, "t")

    if not symbolic_helper._is_value(const_value):
        # The fill value folded to a constant: emit it directly.
        dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        return _constant_fill(g, sizes, dtype, const_value)

    # The fill value is only known at runtime: zeros + value.
    zero_base = zeros(g, sizes, dtype, layout, device)
    unit_alpha = g.op("Constant", value_t=torch.tensor(1))
    return opset9.add(g, zero_base, value, unit_alpha)
|
|
|
|
|
|
@_onnx_symbolic("aten::full_like")
@symbolic_helper.parse_args("v", "f", "i", "v", "v", "v", "v")
def full_like(
    g: jit_utils.GraphContext,
    input,
    fill_value,
    dtype,
    layout,
    device,
    pin_memory=False,
    memory_format=None,
):
    """Export ``aten::full_like``: fill with ``fill_value`` using the ``Shape`` of ``input``.

    ``layout``, ``device``, ``pin_memory`` and ``memory_format`` are not used.
    """
    input_shape = g.op("Shape", input)
    filled = _constant_fill(g, input_shape, dtype, fill_value)
    return filled
|
|
|
|
|
|
@_onnx_symbolic("aten::repeat")
def repeat(g: jit_utils.GraphContext, self, repeats):
    """Export ``aten::repeat`` as an ONNX ``Tile`` node.

    When ``self`` is a complete tensor with fewer dimensions than there are
    repeat entries, it is first viewed with leading size-1 dimensions so the
    ranks match.
    """
    if not symbolic_helper._is_value(repeats):
        # Plain Python sequence: materialize it as an int64 constant.
        repeats = g.op("Constant", value_t=torch.LongTensor(repeats))

    if symbolic_helper._is_packed_list(repeats):
        repeat_rank = len(symbolic_helper._unpack_list(repeats))
    else:
        repeat_rank = len(symbolic_helper._maybe_get_const(repeats, "is"))

    if self.isCompleteTensor():
        input_sizes = self.type().sizes()
        missing_dims = repeat_rank - len(input_sizes)
        if missing_dims > 0:
            padded_shape = g.op(
                "Constant", value_t=torch.tensor([1] * missing_dims + input_sizes)
            )
            self = opset9.view(g, self, padded_shape)

    return g.op("Tile", self, repeats)
|