Traktor/myenv/Lib/site-packages/torch/onnx/symbolic_opset11.py
2024-05-26 05:12:46 +02:00

1651 lines
58 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""This file exports ONNX ops for opset 11."""
from __future__ import annotations
import functools
import sys
import warnings
from typing import Optional, Sequence
import torch
from torch import _C
from torch._C import _onnx as _C_onnx
from torch.onnx import (
_type_utils,
errors,
symbolic_helper,
symbolic_opset10 as opset10,
symbolic_opset9 as opset9,
utils,
)
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import _beartype, jit_utils, registration
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md
# Public API of this module: the opset-11 symbolic functions that may be
# imported by name. Names mostly mirror the ATen op they export (without the
# "aten::" prefix); a few cover several ops (e.g. "reflection_pad").
__all__ = [
    "add",
    "append",
    "arange",
    "argsort",
    "atleast_1d",
    "atleast_2d",
    "atleast_3d",
    "cat",
    "chunk",
    "clamp_max",
    "clamp_min",
    "clamp",
    "constant_pad_nd",
    "cumsum",
    "Delete",
    "embedding_bag",
    "embedding_renorm",
    "flatten",
    "gather",
    "hardtanh",
    "hstack",
    "im2col",
    "index_fill",
    "index",
    "index_copy",
    "index_put",
    "insert",
    "linalg_det",
    "linalg_vector_norm",
    "logdet",
    "masked_scatter",
    "masked_select",
    "mm",
    "narrow",
    "normal",
    "pad",
    "pixel_shuffle",
    "pop",
    "prim_constant_chunk",
    "reflection_pad",
    "relu6",
    "remainder",
    "replication_pad",
    "round",
    "scatter",
    "select",
    "size",
    "sort",
    "split_with_sizes",
    "split",
    "squeeze",
    "stack",
    "topk",
    "unbind",
    "unique_dim",
    "unsqueeze",
    "vstack",
]
# ``registration.onnx_symbolic`` with ``opset=11`` pre-bound. Used below as
# ``@_onnx_symbolic("aten::<op>", ...)`` on each symbolic function in this file.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=11)
def _apply_params(*args, **kwargs):
"""Returns a decorator that calls the decorated (higher-order) function with the given parameters."""
def _apply(fn):
return fn(*args, **kwargs)
return _apply
@_onnx_symbolic("aten::hardtanh")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "f", "f")
@_beartype.beartype
def hardtanh(g: jit_utils.GraphContext, self: _C.Value, min_val: float, max_val: float):
    """Export ``aten::hardtanh`` as an ONNX ``Clip`` with constant bounds.

    Both bounds become ``Constant`` nodes in ``self``'s scalar type, or float
    when that type cannot be determined. The ``Clip`` is emitted through
    ``opset9._op_with_optional_float_cast`` with ``opset_before=12``.
    """
    bound_dtype = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.FLOAT
    ).dtype()
    lower = g.op("Constant", value_t=torch.tensor(min_val, dtype=bound_dtype))
    upper = g.op("Constant", value_t=torch.tensor(max_val, dtype=bound_dtype))
    return opset9._op_with_optional_float_cast(
        g, "Clip", self, lower, upper, opset_before=12
    )
@_onnx_symbolic("aten::clamp")
@_beartype.beartype
def clamp(g: jit_utils.GraphContext, self, min, max):
    """Export ``aten::clamp``.

    When ``self``'s scalar type is known, each present bound is first ``Cast``
    to it. A missing ``min`` delegates to ``clamp_max``. A missing ``max``
    delegates to ``clamp_min``. Two rank-0 bounds map to a single ``Clip``.
    Otherwise the op is lowered as ``clamp_max(clamp_min(self, min), max)``.
    """

    @_beartype.beartype
    def _cast_if_not_none(tensor, dtype):
        # Absent bounds (Python None or a None-valued graph input) pass through.
        if tensor is None or symbolic_helper._is_none(tensor):
            return tensor
        return g.op("Cast", tensor, to_i=dtype.onnx_type())

    scalar_type = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.UNDEFINED
    )
    if scalar_type != _type_utils.JitScalarType.UNDEFINED:
        min = _cast_if_not_none(min, scalar_type)
        max = _cast_if_not_none(max, scalar_type)

    if symbolic_helper._is_none(min):
        return clamp_max(g, self, max)
    if symbolic_helper._is_none(max):
        return clamp_min(g, self, min)

    both_bounds_scalar = (
        symbolic_helper._get_tensor_rank(min) == 0
        and symbolic_helper._get_tensor_rank(max) == 0
    )
    if both_bounds_scalar:
        return opset9._op_with_optional_float_cast(
            g, "Clip", self, min, max, opset_before=12
        )
    return clamp_max(g, clamp_min(g, self, min), max)
@_onnx_symbolic("aten::clamp_min")
@symbolic_helper.parse_args("v", "v")
@_beartype.beartype
def clamp_min(g: jit_utils.GraphContext, self, min):
    """Export ``aten::clamp_min``: bound ``self`` from below by ``min``.

    ``min`` is cast to ``self``'s scalar type. A rank-0 bound becomes a
    ``Clip`` whose upper input is unused. Any other rank, including an unknown
    rank, becomes an element-wise ``Max``.
    """
    target_onnx_type = _type_utils.JitScalarType.from_value(self).onnx_type()
    min = g.op("Cast", min, to_i=target_onnx_type)
    if symbolic_helper._get_tensor_rank(min) != 0:
        return opset9._op_with_optional_float_cast(g, "Max", self, min, opset_before=12)
    return opset9._op_with_optional_float_cast(
        g, "Clip", self, min, opset9.unused(g), opset_before=12
    )
@_onnx_symbolic("aten::clamp_max")
@symbolic_helper.parse_args("v", "v")
@_beartype.beartype
def clamp_max(g: jit_utils.GraphContext, self, max):
    """Export ``aten::clamp_max``: bound ``self`` from above by ``max``.

    ``max`` is cast to ``self``'s scalar type. A rank-0 bound becomes a
    ``Clip`` whose lower input is unused. Any other rank, including an unknown
    rank, becomes an element-wise ``Min``.
    """
    target_onnx_type = _type_utils.JitScalarType.from_value(self).onnx_type()
    max = g.op("Cast", max, to_i=target_onnx_type)
    if symbolic_helper._get_tensor_rank(max) != 0:
        return opset9._op_with_optional_float_cast(g, "Min", self, max, opset_before=12)
    return opset9._op_with_optional_float_cast(
        g, "Clip", self, opset9.unused(g), max, opset_before=12
    )
@_onnx_symbolic("aten::relu6")
@_beartype.beartype
def relu6(g: jit_utils.GraphContext, input):
    """Export ``aten::relu6`` as ``clamp(input, 0, 6)``.

    The bounds are constants in ``input``'s dtype, or float when unknown.
    """
    bound_dtype = _type_utils.JitScalarType.from_value(
        input, _type_utils.JitScalarType.FLOAT
    ).dtype()
    lower, upper = [
        g.op("Constant", value_t=torch.tensor(bound, dtype=bound_dtype))
        for bound in (0, 6)
    ]
    return clamp(g, input, lower, upper)
@_onnx_symbolic("aten::select")
# Opset 11 gather accepts negative indices
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "i", "v")
@_beartype.beartype
def select(g: jit_utils.GraphContext, self, dim, index):
    """Export ``aten::select`` as ONNX ``Gather`` along axis ``dim``.

    ``index`` is forwarded as-is. Opset 11 ``Gather`` accepts negative
    indices, so no normalization is done here.
    """
    selected = g.op("Gather", self, index, axis_i=dim)
    return selected
@_onnx_symbolic("aten::index_put")
@_beartype.beartype
def index_put(
    g: jit_utils.GraphContext, self, indices_list_value, values, accumulate=False
):
    """Export ``aten::index_put`` for opset 11.

    Paths:
      * Caffe2 ATen fallback: emit the ATen op directly.
      * No indices: return ``values`` unchanged.
      * Several indices: boolean ones are turned into ``NonZero`` coordinates.
        All indices are broadcast to a common shape, stacked on a new last axis
        and used with ``ScatterND``.
      * A single boolean index: rewritten as ``masked_fill`` when ``values`` is
        rank 0, otherwise as ``masked_scatter``.
      * A single non-boolean index: unsqueezed on the last axis for
        ``ScatterND``.

    ``values`` is expanded/reshaped to ``index shape + remaining self dims`` and
    cast to ``self``'s scalar type when they differ. With ``accumulate``, the
    values are scattered into zeros and added to ``self``.

    NOTE(review): opset-11 ``ScatterND`` has no reduction attribute, so with
    ``accumulate=True`` duplicate indices are presumably not summed the way
    PyTorch sums them — confirm.

    Raises:
        errors.SymbolicValueError: ``accumulate`` is set but ``self`` has no
            known scalar type.
    """
    if symbolic_helper._is_packed_list(indices_list_value):
        indices_list = symbolic_helper._unpack_list(indices_list_value)
    else:
        indices_list = [indices_list_value]
    if symbolic_helper.is_caffe2_aten_fallback():
        args = [self] + indices_list + [values, accumulate]
        return g.at("index_put", *args)
    accumulate = symbolic_helper._parse_arg(accumulate, "b")
    if len(indices_list) == 0:
        return values
    if len(indices_list) > 1:
        # Boolean masks become coordinate tensors so they combine with the others.
        for idx_ in range(len(indices_list)):
            if symbolic_helper._is_bool(indices_list[idx_]):
                indices_list[idx_] = g.op("NonZero", indices_list[idx_])
        # Summing the indices is only used to obtain their broadcast shape.
        index = indices_list[0]
        for ind in indices_list[1:]:
            index = opset9.add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        indices_list = [
            symbolic_helper._unsqueeze_helper(
                g, opset9.expand(g, ind, broadcast_index_shape, None), [-1]
            )
            for ind in indices_list
        ]
        # Stack per-dimension indices on the last axis: ScatterND index layout.
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        # Replace index_put node with masked_scatter or masked_fill
        # when inputs to the index_put node contains a single boolean input.
        #
        # index_put -> masked_fill
        # * input index contains single tensor of Bool type (e.g.: %24 <- %23).
        # * input value contains single element (e.g.: %18).
        #
        # Torch IR
        # %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        # %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        # aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
        # %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
        # %23 : Bool(8, strides=[1], device=cpu) = aten::view(%16, %22)
        # %24 : Tensor?[] = prim::ListConstruct(%23)
        # %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        # aten::index_put(%mask, %24, %18, %30)
        # return (%25)
        #
        #
        # index_put -> masked_scatter
        # * input index contains single tensor of Bool type (e.g.: %32 <- %31).
        # * input value contains multiple elements (e.g.: %28).
        #
        # Torch IR
        # %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        # %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
        # = prim::Constant[value= 1 1 1 1 1 1 1 1 [ CPUFloatType{8} ]]()
        # %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        # = aten::ne(%mask, %some_const)
        # %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        # = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
        # %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        # %30 : int[] = prim::Constant[value=[-1]]()
        # %31 : Bool(8, strides=[1], device=cpu) = aten::view(%23, %30)
        # %32 : Tensor?[] = prim::ListConstruct(%31)
        # %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        # = aten::index_put(%mask, %32, %28, %38)
        # return (%33)
        index = indices_list[0]
        bool_inp = index
        if symbolic_helper._is_bool(bool_inp):
            rank = symbolic_helper._get_tensor_rank(values)
            if rank is not None and rank == 0:
                return opset9.masked_fill(g, self, bool_inp, values)
            mask_rank = symbolic_helper._get_tensor_rank(bool_inp)
            self_rank = symbolic_helper._get_tensor_rank(self)
            if (
                mask_rank is not None
                and self_rank is not None
                and self_rank > mask_rank
            ):
                # Unsqueeze 'bool_inp' to be broadcastable to shape of 'self'.
                bool_inp = symbolic_helper._unsqueeze_helper(
                    g, bool_inp, list(range(mask_rank, self_rank))
                )
            return masked_scatter(g, self, bool_inp, values)
        broadcast_index_shape = g.op("Shape", index)
        index = symbolic_helper._unsqueeze_helper(g, index, [-1])
    # Dimensions of ``self`` not covered by the indices.
    sub_data_shape = symbolic_helper._slice_helper(
        g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[sys.maxsize]
    )
    # Shape ScatterND expects for ``values``: index batch shape + trailing self dims.
    values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0)
    # Check if values is a singular value and expand accordingly
    rank = symbolic_helper._get_tensor_rank(values)
    if rank is not None and rank == 0:
        values = opset9.expand(g, values, values_shape, None)
    values = symbolic_helper._reshape_helper(g, values, values_shape)
    self_scalar_type = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.UNDEFINED
    )
    if self_scalar_type != _type_utils.JitScalarType.UNDEFINED:
        values_scalar_type = _type_utils.JitScalarType.from_value(
            values, _type_utils.JitScalarType.UNDEFINED
        )
        if self_scalar_type != values_scalar_type:
            values = g.op("Cast", values, to_i=self_scalar_type.onnx_type())
    elif accumulate:
        # The zeros tensor below needs a concrete dtype.
        raise errors.SymbolicValueError("self does not have a valid scalar type.", self)
    if accumulate:
        # Scatter into zeros of self's shape, then add to self.
        zeros = g.op(
            "ConstantOfShape",
            g.op("Shape", self),
            value_t=torch.tensor([0], dtype=self_scalar_type.dtype()),
        )
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)
    return result
@_onnx_symbolic("aten::pixel_shuffle")
@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def pixel_shuffle(g: jit_utils.GraphContext, self, upscale_factor):
    """Export ``aten::pixel_shuffle`` as ``DepthToSpace`` in ``CRD`` mode.

    If the input rank is known and is not 4, the op is reported as
    unimplemented. If the rank is unknown, the export proceeds anyway.
    """
    rank = symbolic_helper._get_tensor_rank(self)
    if rank is None or rank == 4:
        return g.op("DepthToSpace", self, blocksize_i=upscale_factor, mode_s="CRD")
    return symbolic_helper._unimplemented("pixel_shuffle", "only support 4d input")
@_onnx_symbolic(
    "aten::upsample_nearest1d",
    decorate=[_apply_params("upsample_nearest1d", 3, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest2d",
    decorate=[_apply_params("upsample_nearest2d", 4, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest3d",
    decorate=[_apply_params("upsample_nearest3d", 5, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_linear1d",
    decorate=[_apply_params("upsample_linear1d", 3, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_bilinear2d",
    decorate=[_apply_params("upsample_bilinear2d", 4, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_trilinear3d",
    decorate=[_apply_params("upsample_trilinear3d", 5, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_bicubic2d",
    decorate=[_apply_params("upsample_bicubic2d", 4, "cubic")],
)
@_beartype.beartype
def _interpolate(name: str, dim: int, interpolate_mode: str):
    """Build the symbolic for one ``aten::upsample_*`` variant.

    Each registration above passes ``(name, dim, interpolate_mode)`` through
    ``_apply_params``. The result of ``symbolic_helper._interpolate_helper``
    is returned unchanged.
    """
    variant_symbolic = symbolic_helper._interpolate_helper(
        name, dim, interpolate_mode
    )
    return variant_symbolic
@_onnx_symbolic("aten::__interpolate")
@symbolic_helper.quantized_args(True, False, False, False, False, False, False)
@_beartype.beartype
def __interpolate(
    g: jit_utils.GraphContext,
    input,
    size,
    scale_factor,
    mode,
    align_corners,
    recompute_scale_factor,
    antialias,
):
    """Export ``aten::__interpolate`` via ``symbolic_helper.__interpolate_helper``.

    ``antialias`` is accepted for signature compatibility but is not forwarded
    to the helper. NOTE(review): so ``antialias=True`` is presumably ignored
    in the export — confirm.
    """
    helper_args = (
        input,
        size,
        scale_factor,
        mode,
        align_corners,
        recompute_scale_factor,
    )
    return symbolic_helper.__interpolate_helper(g, *helper_args)
@_onnx_symbolic("aten::gather")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def gather(g: jit_utils.GraphContext, self, dim, index, sparse_grad=False):
    """Export ``aten::gather`` as ``GatherElements`` along ``dim``.

    A truthy constant ``sparse_grad`` is reported as unimplemented. Under the
    Caffe2 ATen fallback, the ATen op is emitted directly.
    """
    if symbolic_helper._maybe_get_const(sparse_grad, "i"):
        return symbolic_helper._unimplemented("gather", "sparse_grad == True")
    if not symbolic_helper.is_caffe2_aten_fallback():
        return g.op("GatherElements", self, index, axis_i=dim)
    return g.at("gather", self, dim, index, sparse_grad)
@_onnx_symbolic("aten::scatter")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def scatter(g: jit_utils.GraphContext, self, dim, index, src):
    """Export ``aten::scatter`` (``src`` overload) as ``ScatterElements``.

    If ``src`` remains a graph value after ``_maybe_get_scalar``, it is
    scattered directly. Otherwise (a scalar ``src``) it is cast to ``self``'s
    scalar type when the original ``src`` type differs, and then expanded to
    ``index``'s shape. PyTorch allows a scalar ``src`` of a different type,
    but not a tensor ``src``.
    """
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("scatter", self, dim, index, src, overload_name="src")

    src_type = _type_utils.JitScalarType.from_value(src)
    src = symbolic_helper._maybe_get_scalar(src)
    if not symbolic_helper._is_value(src):
        self_type = _type_utils.JitScalarType.from_value(self)
        if self_type != src_type:
            src = g.op("Cast", src, to_i=self_type.onnx_type())
        src = opset9.expand_as(g, src, index)
    return g.op("ScatterElements", self, index, src, axis_i=dim)
@_onnx_symbolic("aten::cumsum")
@symbolic_helper.parse_args("v", "i", "none")
@_beartype.beartype
def cumsum(g: jit_utils.GraphContext, self, dim, dtype=None):
    """Export ``aten::cumsum`` as ONNX ``CumSum`` over axis ``dim``.

    ``self`` is cast to the requested dtype first only when ``dtype`` is given
    and its producing node is not a ``prim::Constant``.
    """
    axis = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.int))
    source = self
    if dtype and dtype.node().kind() != "prim::Constant":
        target = symbolic_helper._get_const(dtype, "i", "dtype")
        source = g.op(
            "Cast", self, to_i=_type_utils.JitScalarType(target).onnx_type()
        )
    return g.op("CumSum", source, axis)
@_onnx_symbolic("aten::masked_select")
@_beartype.beartype
def masked_select(g: jit_utils.GraphContext, self, mask):
    """Export ``aten::masked_select``.

    ``mask`` is expanded to ``self``'s shape. Its non-zero coordinates then
    drive a ``GatherND`` over ``self``.
    """
    full_mask = opset9.expand_as(g, mask, self)
    coordinates = opset9.nonzero(g, full_mask)
    return g.op("GatherND", self, coordinates)
@_onnx_symbolic("aten::masked_scatter")
@_beartype.beartype
def masked_scatter(g: jit_utils.GraphContext, self, mask, source):
    """Export ``aten::masked_scatter`` as ``ScatterND``.

    The positions come from the non-zero coordinates of ``mask`` expanded to
    ``self``'s shape. ``source`` may have any shape and more elements than
    needed, which ``ScatterND`` does not accept. It is therefore flattened and
    sliced to as many leading elements as there are selected positions.
    """
    coordinates = opset9.nonzero(g, opset9.expand_as(g, mask, self))
    flat_source = symbolic_helper._reshape_helper(g, source, torch.LongTensor([-1]))
    trimmed_source = symbolic_helper._slice_helper(
        g,
        flat_source,
        axes=torch.LongTensor([0]),
        starts=torch.LongTensor([0]),
        ends=opset9.size(g, coordinates, torch.LongTensor([0])),
    )
    return g.op("ScatterND", self, coordinates, trimmed_source)
@_onnx_symbolic("aten::len")
@_beartype.beartype
def _len(g: jit_utils.GraphContext, self):
    """Export ``aten::len``.

    Tensor lists, and the outputs of ``onnx::SplitToSequence``, use
    ``SequenceLength``. For a plain tensor the result is the size of
    dimension 0, squeezed to a scalar.
    """
    is_sequence = (
        symbolic_helper._is_tensor_list(self)
        or self.node().kind() == "onnx::SplitToSequence"
    )
    if is_sequence:
        return g.op("SequenceLength", self)
    leading_dim = size(g, self, g.op("Constant", value_t=torch.LongTensor([0])))
    return symbolic_helper._squeeze_helper(g, leading_dim, [0])
@_onnx_symbolic("aten::__getitem_")
@_beartype.beartype
def __getitem_(g: jit_utils.GraphContext, self, i):
    """Export ``aten::__getitem_``.

    A tensor list maps to ``SequenceAt``, which requires a list of tensors.
    Anything else is delegated to the opset 9 symbolic.
    """
    if not symbolic_helper._is_tensor_list(self):
        return opset9.__getitem_(g, self, i)
    return g.op("SequenceAt", self, i)
@_onnx_symbolic("aten::_set_item")
@_beartype.beartype
def _set_item(g: jit_utils.GraphContext, tensor_list, i, v):
    """Replace element ``i`` of a tensor sequence with ``v``.

    The element is erased and ``v`` is inserted at the same position.
    """
    shortened = g.op("SequenceErase", tensor_list, i)
    updated = g.op("SequenceInsert", shortened, v, i)
    return updated
@_onnx_symbolic("aten::append")
@_beartype.beartype
def append(g: jit_utils.GraphContext, self, tensor):
    """Export ``aten::append`` as ``SequenceInsert`` with no position input."""
    extended = g.op("SequenceInsert", self, tensor)
    return extended
@_onnx_symbolic("aten::add")
@_beartype.beartype
def add(g: jit_utils.GraphContext, self, other, alpha=None):
    """Export ``aten::add``.

    When ``self`` is a tensor list, ``other`` must be a ``prim::ListConstruct``.
    Its elements are appended one by one with ``SequenceInsert``. A dynamic
    list is reported as unimplemented. Tensor addition is delegated to opset 9.
    """
    self_is_sequence = symbolic_helper._is_value(
        self
    ) and symbolic_helper._is_tensor_list(self)
    if not self_is_sequence:
        return opset9.add(g, self, other, alpha)

    if other.node().kind() != "prim::ListConstruct":
        return symbolic_helper._unimplemented(
            "add", "does not support adding dynamic tensor list to another"
        )
    sequence = self
    for element in symbolic_helper._unpack_list(other):
        sequence = g.op("SequenceInsert", sequence, element)
    return sequence
@_onnx_symbolic("aten::insert")
@_beartype.beartype
def insert(g: jit_utils.GraphContext, self, pos, tensor):
    """Export ``aten::insert`` on a tensor list: insert ``tensor`` at ``pos``."""
    updated = g.op("SequenceInsert", self, tensor, pos)
    return updated
@_onnx_symbolic("aten::pop")
@_beartype.beartype
def pop(g: jit_utils.GraphContext, tensor_list, dim):
    """Export ``aten::pop`` as ``SequenceErase`` of the element at ``dim``.

    Only the shortened sequence is produced. The removed element is not returned.
    """
    remaining = g.op("SequenceErase", tensor_list, dim)
    return remaining
@_onnx_symbolic("aten::Delete")
@_beartype.beartype
def Delete(g: jit_utils.GraphContext, tensor_list, dim):
    """Export ``aten::Delete`` as ``SequenceErase`` of the element at ``dim``."""
    remaining = g.op("SequenceErase", tensor_list, dim)
    return remaining
@_onnx_symbolic("aten::cat")
@symbolic_helper.quantized_args(True)
@_beartype.beartype
def cat(g: jit_utils.GraphContext, tensor_list, dim):
    """Export ``aten::cat``.

    A packed list is delegated to opset 9. Otherwise ``tensor_list`` is joined
    along the constant ``dim`` with ``ConcatFromSequence``.
    """
    if not symbolic_helper._is_packed_list(tensor_list):
        axis = symbolic_helper._get_const(dim, "i", "dim")
        return g.op("ConcatFromSequence", tensor_list, axis_i=axis)
    return opset9.cat(g, tensor_list, dim)
@_onnx_symbolic("aten::stack")
@_beartype.beartype
def stack(g: jit_utils.GraphContext, tensor_list, dim):
    """Export ``aten::stack``.

    A packed list is delegated to opset 9. Otherwise ``tensor_list`` uses
    ``ConcatFromSequence`` with ``new_axis=1`` at the constant ``dim``.
    """
    if not symbolic_helper._is_packed_list(tensor_list):
        axis = symbolic_helper._get_const(dim, "i", "dim")
        return g.op("ConcatFromSequence", tensor_list, axis_i=axis, new_axis_i=1)
    return opset9.stack(g, tensor_list, dim)
@_onnx_symbolic("aten::_unique2")
@symbolic_helper.parse_args("v", "i", "i", "i")
@_beartype.beartype
def _unique2(g: jit_utils.GraphContext, self, sorted, return_inverse, return_counts):
    """Export ``aten::_unique2`` over the flattened input using ONNX ``Unique``.

    ``return_inverse`` and ``return_counts`` are not consulted. The unique
    values, inverse indices and counts are always returned. The first-
    occurrence indices output of ``Unique`` is dropped.
    """
    unique_values, _first_indices, inverse_indices, counts = g.op(
        "Unique", self, sorted_i=sorted, outputs=4
    )
    return unique_values, inverse_indices, counts
@_onnx_symbolic("aten::unique_dim")
@symbolic_helper.parse_args("v", "i", "i", "i", "i")
@_beartype.beartype
def unique_dim(
    g: jit_utils.GraphContext, self, dim, sorted, return_inverse, return_counts
):
    """Export ``aten::unique_dim`` using ONNX ``Unique`` along axis ``dim``.

    ``return_inverse`` and ``return_counts`` are not consulted. The unique
    values, inverse indices and counts are always returned.
    """
    unique_values, _first_indices, inverse_indices, counts = g.op(
        "Unique", self, axis_i=dim, sorted_i=sorted, outputs=4
    )
    return unique_values, inverse_indices, counts
@_onnx_symbolic("aten::topk")
@symbolic_helper.parse_args("v", "v", "i", "i", "i", "none")
@_beartype.beartype
def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
    """Export ``aten::topk`` by delegating to ``symbolic_helper._topk_helper``."""
    helper_options = {"largest": largest, "sorted": sorted, "out": out}
    return symbolic_helper._topk_helper(g, self, k, dim, **helper_options)
@_onnx_symbolic("aten::sort")
@symbolic_helper.parse_args("v", "i", "i", "none")
@_beartype.beartype
def sort(g: jit_utils.GraphContext, self, dim, decending, out=None):
    """Export ``aten::sort`` by delegating to ``symbolic_helper._sort_helper``.

    (``decending`` keeps its historical spelling because it is part of the
    helper's keyword interface.)
    """
    sorted_outputs = symbolic_helper._sort_helper(
        g, self, dim, decending=decending, out=out
    )
    return sorted_outputs
@_onnx_symbolic("aten::argsort")
@symbolic_helper.parse_args("v", "i", "i", "none")
@_beartype.beartype
def argsort(g: jit_utils.GraphContext, self, dim, decending, out=None):
    """Export ``aten::argsort``: the index output of ``_sort_helper``."""
    _sorted_values, sort_indices = symbolic_helper._sort_helper(
        g, self, dim, decending=decending, out=out
    )
    return sort_indices
@_onnx_symbolic("aten::round")
@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def round(g: jit_utils.GraphContext, self, decimals=0):
    """Export ``aten::round`` with an optional number of ``decimals``.

    Non-floating-point inputs are returned unchanged. With ``decimals == 0``,
    a single ``Round`` is emitted. Otherwise the input is multiplied by
    ``10 ** decimals``, rounded, and multiplied by ``10 ** -decimals``.
    """
    if not symbolic_helper._is_fp(self):
        return self
    if decimals == 0:
        return g.op("Round", self)
    scale_up = g.op("Constant", value_t=torch.tensor(pow(10, decimals)))
    rounded = g.op("Round", g.op("Mul", self, scale_up))
    scale_down = g.op("Constant", value_t=torch.tensor(pow(10, -decimals)))
    return g.op("Mul", rounded, scale_down)
@_onnx_symbolic("aten::remainder")
@_beartype.beartype
def remainder(g: jit_utils.GraphContext, input, other):
    """Export ``aten::remainder``.

    If either operand is floating point, the op is delegated to opset 9.
    Integer operands use ONNX ``Mod`` with ``fmod=0``.
    """
    has_float_operand = symbolic_helper._is_fp(input) or symbolic_helper._is_fp(
        other
    )
    if not has_float_operand:
        return g.op("Mod", input, other, fmod_i=0)
    return opset9.remainder(g, input, other)
@_onnx_symbolic("aten::split")
@symbolic_helper.parse_args("v", "v", "i", "i")
@_beartype.beartype
def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None):
    """Export ``aten::split`` for opset 11.

    Static splits are delegated to the opset 9 symbolic. Dynamic splits use
    ``SplitToSequence``. When ``_outputs`` (the number of outputs) is known,
    the result is unpacked into that many tensors. If the split sizes are a
    packed list whose length equals ``_outputs``, one ``Slice`` per output is
    used. Otherwise ``SequenceAt`` is used with a scalar position.

    Args:
        g: graph context.
        self: tensor to split.
        split_size_or_sizes: chunk size, or list of chunk sizes.
        dim: axis to split along.
        _outputs: number of outputs, if statically known.

    Returns:
        The ``SplitToSequence`` sequence when ``_outputs`` is None, otherwise a
        list of ``_outputs`` values.
    """
    if symbolic_helper._is_split_static(split_size_or_sizes, _outputs):
        return opset9.split(g, self, split_size_or_sizes, dim, _outputs)

    split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim)
    if _outputs is None:
        return split_out
    # Convert to multiple slice nodes iff number of splits and number of outputs are statically known.
    if (
        symbolic_helper._is_packed_list(split_size_or_sizes)
        and len(symbolic_helper._unpack_list(split_size_or_sizes)) == _outputs
    ):
        split_sizes = [
            symbolic_helper._unsqueeze_helper(g, v, [0])
            for v in symbolic_helper._unpack_list(split_size_or_sizes)
        ]
        start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
        axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
        res = []
        for i in range(_outputs):
            # split_sizes is a list of same length as _outputs
            end = g.op("Add", start, split_sizes[i])
            res.append(g.op("Slice", self, start, end, axis))
            start = end
        return res
    # ONNX SequenceAt requires ``position`` to be a scalar (empty-shape)
    # tensor, so use a 0-d index rather than a 1-element vector.
    return [
        g.op(
            "SequenceAt",
            split_out,
            g.op("Constant", value_t=torch.tensor(i, dtype=torch.long)),
        )
        for i in range(_outputs)
    ]
@_onnx_symbolic("aten::split_with_sizes")
@symbolic_helper.parse_args("v", "v", "i", "i")
@_beartype.beartype
def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None):
    """Export ``aten::split_with_sizes`` by delegating to this module's ``split``."""
    pieces = split(g, self, split_sizes, dim, _outputs)
    return pieces
@_onnx_symbolic("aten::unbind")
@symbolic_helper.parse_args("v", "i", "i")
@_beartype.beartype
def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None):
    """Export ``aten::unbind``.

    With a known output count, the op is delegated to opset 9. Otherwise a
    ``SplitToSequence`` into size-1 chunks along ``dim`` with ``keepdims=0``
    is emitted.
    """
    if _outputs is not None:
        return opset9.unbind(g, self, dim, _outputs)
    chunk_size = g.op("Constant", value_t=torch.tensor(1, dtype=torch.long))
    return g.op("SplitToSequence", self, chunk_size, axis_i=dim, keepdims_i=0)
@_beartype.beartype
def _prepare_onnx_paddings(g: jit_utils.GraphContext, input, pad):
    """Generate paddings in ONNX order based on pad in pytorch.

    Args:
        input: the input tensor.
        pad: the paddings in pytorch.
            The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ..., dim_m_begin, dim_m_end,
            where m is in range [0, n].

    Returns:
        An int64 1-D value holding all "begin" pads (dim 0 first) followed by
        all "end" pads. Leading dimensions not covered by ``pad`` get 0.
        Example: for a 4-D input and ``pad = [l, r, t, b]`` the result is
        ``[0, 0, t, l, 0, 0, b, r]``.
    """
    # A non-packed list of scalars (e.g. a dynamic list) is first turned into a
    # 1-D tensor.
    if (
        not symbolic_helper._is_packed_list(pad)
        and symbolic_helper._is_list(pad)
        and symbolic_helper._is_scalar_list(pad)
    ):
        pad = g.op("ConcatFromSequence", pad, axis_i=0, new_axis_i=1)
    # The desired order of paddings is
    # dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end.
    # n is the dimension of input.
    # Assume zero-dimensions in the beginning, pad the "pad" sequence with zeros in the beginning
    pad_len = opset9.size(g, pad, g.op("Constant", value_t=torch.tensor([0])))
    # Set extension = [0] * (dim * 2 - len(pad))
    rank = symbolic_helper._get_tensor_rank(input)
    if rank is None:
        # Rank unknown at export time: compute it in the graph.
        rank = g.op("Size", g.op("Shape", input))
    else:
        rank = g.op("Constant", value_t=torch.tensor(rank, dtype=torch.int64))
    extension = g.op(
        "Sub",
        g.op("Mul", rank, g.op("Constant", value_t=torch.tensor(2, dtype=torch.int64))),
        pad_len,
    )
    # Concat pad with extension: paddings = [dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, 0, 0, ... ]
    # Currently ONNX only supports int64 type for Pad
    pad = g.op("Cast", pad, to_i=_C_onnx.TensorProtoDataType.INT64)
    paddings = g.op(
        "Concat",
        pad,
        g.op(
            "ConstantOfShape", extension, value_t=torch.tensor([0], dtype=torch.int64)
        ),
        axis_i=0,
    )
    # Reshape and reverse order and collate first beginnings and then ends
    # paddings = [[..., 0, dim_n-1_begin, dim_n_begin],
    #             [..., 0, dim_n-1_end, dim_n_end]]
    # Reshape back to 1-D paddings = [..., 0, dim_n - 1_begin, dim_n_begin, ..., 0, dim_n - 1_end, dim_n_end]
    paddings = symbolic_helper._reshape_helper(
        g, paddings, g.op("Constant", value_t=torch.tensor([-1, 2]))
    )
    paddings = g.op("Transpose", opset10.flip(g, paddings, [0]), perm_i=[1, 0])
    paddings = symbolic_helper._reshape_helper(
        g, paddings, g.op("Constant", value_t=torch.tensor([-1]))
    )
    # NOTE(review): ``pad`` and the zeros are already int64, so this final Cast
    # looks redundant; kept to preserve the exported graph.
    padding_c = g.op("Cast", paddings, to_i=_C_onnx.TensorProtoDataType.INT64)
    return padding_c
@_onnx_symbolic("aten::constant_pad_nd")
@_beartype.beartype
def constant_pad_nd(g: jit_utils.GraphContext, input, padding, value=None):
    """Export ``aten::constant_pad_nd`` as ``Pad`` in ``constant`` mode.

    ``value`` is converted with ``_maybe_get_scalar`` and matched to
    ``input``'s scalar type with ``_if_scalar_type_as``. ``padding`` is
    reordered into ONNX layout by ``_prepare_onnx_paddings``.
    """
    fill_value = symbolic_helper._if_scalar_type_as(
        symbolic_helper._maybe_get_scalar(value), input
    )
    onnx_pads = _prepare_onnx_paddings(g, input, padding)
    return g.op("Pad", input, onnx_pads, fill_value, mode_s="constant")
@_onnx_symbolic("aten::reflection_pad1d")
@_onnx_symbolic("aten::reflection_pad2d")
@_onnx_symbolic("aten::reflection_pad3d")
@_beartype.beartype
def reflection_pad(g: jit_utils.GraphContext, input, padding):
    """Export ``aten::reflection_pad{1,2,3}d`` as ``Pad`` in ``reflect`` mode."""
    onnx_pads = _prepare_onnx_paddings(g, input, padding)
    return g.op("Pad", input, onnx_pads, mode_s="reflect")
@_onnx_symbolic("aten::replication_pad1d")
@_onnx_symbolic("aten::replication_pad2d")
@_onnx_symbolic("aten::replication_pad3d")
@_beartype.beartype
def replication_pad(g: jit_utils.GraphContext, input, padding):
    """Export ``aten::replication_pad{1,2,3}d`` as ``Pad`` in ``edge`` mode."""
    onnx_pads = _prepare_onnx_paddings(g, input, padding)
    return g.op("Pad", input, onnx_pads, mode_s="edge")
@_onnx_symbolic("aten::pad")
@_beartype.beartype
def pad(
    g: jit_utils.GraphContext,
    input: _C.Value,
    pad: _C.Value,
    mode: _C.Value,
    value: _C.Value,
):
    """Export ``aten::pad`` by dispatching on the constant string ``mode``.

    Supported modes: ``replicate``, ``reflect``, ``constant`` (uses ``value``)
    and ``circular`` (delegated to opset 9).

    Raises:
        errors.SymbolicValueError: for any other mode.
    """
    mode = symbolic_helper._parse_arg(mode, "s")
    dispatch = {
        "replicate": lambda: replication_pad(g, input, pad),
        "reflect": lambda: reflection_pad(g, input, pad),
        "constant": lambda: constant_pad_nd(g, input, pad, value),
        "circular": lambda: opset9._pad_circular(g, input, pad),
    }
    exporter = dispatch.get(mode)
    if exporter is None:
        raise errors.SymbolicValueError(f"Unrecognized padding mode {mode}", input)
    return exporter()
@_onnx_symbolic("aten::linalg_det")
@_beartype.beartype
def linalg_det(g: jit_utils.GraphContext, self):
    """Export ``aten::linalg_det`` as ONNX ``Det``."""
    determinant = g.op("Det", self)
    return determinant
@_onnx_symbolic("aten::logdet")
@_beartype.beartype
def logdet(g: jit_utils.GraphContext, input):
    """Export ``aten::logdet`` as ``log(det(input))`` (opset 9 ``log``)."""
    determinant = linalg_det(g, input)
    return opset9.log(g, determinant)
@_onnx_symbolic("aten::arange")
@_beartype.beartype
def arange(g: jit_utils.GraphContext, *args):
    """Export ``aten::arange`` as ONNX ``Range``.

    The ATen overload is identified by the number of positional ``args``:
      * 2 Python ints: ``(start, end)`` as int64 constants, step 1.
      * 2: ``(end, out)``; 5: ``(end, dtype, layout, device, pin_memory)``.
        Start 0, step 1.
      * 4: ``(start, end, step, out)``;
        7: ``(start, end, step, dtype, layout, device, pin_memory)``.
      * 6: ``(start, end, dtype, layout, device, pin_memory)``. Step 1.
    Any other arity is reported as unimplemented. Casting of start/end/step
    to the target dtype is done by ``symbolic_helper._arange_cast_helper``.
    """

    def _get_arange_dtype(dtype):
        # Constant-fold the dtype argument to an int when possible.
        dtype = symbolic_helper._maybe_get_const(dtype, "i")
        return dtype

    if len(args) == 2 and all(isinstance(val, int) for val in args):
        # aten::arange(Scalar start, Scalar end)
        dtype = torch.int64
        # Start index.
        start = g.op(
            "Constant",
            value_t=torch.tensor(args[0], dtype=dtype),
        )
        # End (exclusive) index.
        end = g.op(
            "Constant",
            value_t=torch.tensor(args[1], dtype=dtype),
        )
        # Step size from start to end indexes.
        delta_default = g.op(
            "Constant",
            value_t=torch.tensor(1, dtype=dtype),
        )
        return g.op("Range", start, end, delta_default)
    elif len(args) == 2 or len(args) == 5:
        if len(args) == 2:
            # aten::arange(Scalar end, Tensor out)
            dtype = None
        else:
            # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
            dtype = _get_arange_dtype(args[1])
        type_, end, start, step = symbolic_helper._arange_cast_helper(
            g, end=args[0], dtype=dtype
        )
        # Start and step defaults are built in the dtype chosen by the helper.
        start_default = g.op(
            "Constant",
            value_t=torch.tensor(0, dtype=type_.dtype()),
        )
        delta_default = g.op(
            "Constant",
            value_t=torch.tensor(1, dtype=type_.dtype()),
        )
        return g.op("Range", start_default, end, delta_default)
    elif len(args) == 4 or len(args) == 7:
        if len(args) == 4:
            # aten::arange(Scalar start, Scalar end, Scalar step, Tensor out)
            dtype = None
        else:
            # aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory)
            dtype = _get_arange_dtype(args[3])
        _, end, start, step = symbolic_helper._arange_cast_helper(
            g, start=args[0], end=args[1], step=args[2], dtype=dtype
        )
        return g.op("Range", start, end, step)
    elif len(args) == 6:
        # aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
        dtype = _get_arange_dtype(args[2])
        type_, end, start, step = symbolic_helper._arange_cast_helper(
            g, start=args[0], end=args[1], dtype=dtype
        )
        delta_default = g.op(
            "Constant",
            value_t=torch.tensor(1, dtype=type_.dtype()),
        )
        return g.op("Range", start, end, delta_default)
    else:
        return symbolic_helper._unimplemented(
            "aten::arange", f"with {len(args)} arguments"
        )
@_onnx_symbolic("aten::_dim_arange")
@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def _dim_arange(g: jit_utils.GraphContext, like, dim):
    """Export ``aten::_dim_arange``: ``arange(like.shape[dim])``.

    Under the Caffe2 ATen fallback, ``_caffe2::Range`` is emitted. Otherwise
    the 5-argument form of ``arange`` is used:
    ``(end, dtype=4, layout, device, pin_memory)``, where scalar type 4 is int64.
    """
    like_shape = g.op("Shape", like)
    stop = g.op(
        "Gather",
        like_shape,
        g.op("Constant", value_t=torch.tensor(dim)),
        axis_i=0,
    )
    if not symbolic_helper.is_caffe2_aten_fallback():
        return arange(g, stop, 4, None, None, None)
    return g.op("_caffe2::Range", stop)
@_onnx_symbolic("aten::size")
@symbolic_helper.quantized_args(True, quantize_output=False)
@_beartype.beartype
def size(g: jit_utils.GraphContext, self, dim=None):
    """Export ``aten::size``.

    Returns the full ``Shape`` when ``dim`` is None. Otherwise the size along
    ``dim`` comes from ``symbolic_helper._size_helper``.
    """
    if dim is not None:
        return symbolic_helper._size_helper(g, self, dim)
    return g.op("Shape", self)
@_onnx_symbolic("aten::squeeze")
@_beartype.beartype
def squeeze(g: jit_utils.GraphContext, self, dim=None):
    """Export ``aten::squeeze``.

    * ``dim`` None: plain ``Squeeze`` of all size-1 dims.
    * ``dim`` non-constant: squeeze via ``_squeeze_helper`` with the tensor dim.
    * Size of ``dim`` unknown (or negative ``dim`` with unknown rank): emit an
      ``If`` that squeezes when ``shape[dim] == 1`` and is ``Identity``
      otherwise.
    * Size statically known and > 1: warn and return ``self`` unchanged, with
      no squeeze node.
    * Otherwise: squeeze the (rank-normalized) ``dim``.
    """
    if dim is None:
        return g.op("Squeeze", self)
    # dim as a tensor
    if not symbolic_helper._is_constant(dim):
        return symbolic_helper._squeeze_helper(g, self, [dim])
    dim = symbolic_helper._get_const(dim, "i", "dim")
    input_rank = symbolic_helper._get_tensor_rank(self)
    # Normalize a negative dim when the rank is known.
    adjusted_dim = dim
    if input_rank is not None and dim < 0:
        adjusted_dim += input_rank
    dim_size = symbolic_helper._get_tensor_dim_size(self, adjusted_dim)
    if (dim < 0 and input_rank is None) or dim_size is None:
        # If onnx shape inference is not on, export always as dynamic.
        # Because we cannot tell if observed static shape is also static at runtime.
        # create "cond" node (condition is shape[i]==1)
        dim_constant = g.op("Constant", value_t=torch.tensor([dim]))
        # Local name; shadows the module-level ``size`` symbolic in this scope.
        size = symbolic_helper._size_helper(g, self, dim_constant)
        const_one = g.op("Constant", value_t=torch.ones(1, dtype=torch.int64))
        cond = g.op("Equal", size, const_one)
        # create the "If" node and add the "then" and "else" blocks to it.
        if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks(
            g, "If", cond, n_blocks=2
        )
        squeeze_ = symbolic_helper._squeeze_helper(if_context, self, [dim])
        utils._add_output_to_block(if_context.block, squeeze_)
        identity_ = else_context.op("Identity", self)
        utils._add_output_to_block(else_context.block, identity_)
        return if_op
    # For static input shape
    dim = adjusted_dim
    if dim_size > 1:
        warnings.warn(
            "This model contains a squeeze operation on dimension "
            + str(dim)
            + ". The size of "
            + "this dimension in the given input is "
            + str(dim_size)
            + ". The model will "
            + "be exported without the squeeze node. If the model is intended to be used with dynamic "
            + "input shapes, please export with dynamic_axes argument."
        )
        return self
    return symbolic_helper._squeeze_helper(g, self, [dim])
@_onnx_symbolic("aten::unsqueeze")
@_beartype.beartype
def unsqueeze(g: jit_utils.GraphContext, self, dim):
    """Export ``aten::unsqueeze``.

    A constant ``dim`` is folded to a Python int. A non-constant ``dim`` is
    passed through to ``_unsqueeze_helper`` as a graph value.
    """
    axis = (
        symbolic_helper._get_const(dim, "i", "dim")
        if symbolic_helper._is_constant(dim)
        else dim
    )
    return symbolic_helper._unsqueeze_helper(g, self, [axis])
@_onnx_symbolic("aten::mm")
@_beartype.beartype
def mm(g: jit_utils.GraphContext, self, other):
    """Export ``aten::mm`` as ``Gemm``.

    No C input is passed and ``beta`` is 0, with ``alpha`` 1, so only the
    matrix product contributes.
    """
    product = g.op("Gemm", self, other, beta_f=0.0, alpha_f=1.0)
    return product
@_onnx_symbolic("aten::index")
@_beartype.beartype
def index(g: jit_utils.GraphContext, self, index):
    """Export ``aten::index`` (``Tensor`` overload).

    A single non-None bool or uint8 index is treated as a mask. Its
    ``NonZero`` coordinates drive a ``GatherND``. Every other case is
    delegated to the opset 9 symbolic. A single index is forwarded unwrapped.
    """
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("index", self, index, overload_name="Tensor")

    indices = (
        symbolic_helper._unpack_list(index)
        if symbolic_helper._is_packed_list(index)
        else [index]
    )
    if len(indices) != 1:
        return opset9.index(g, self, index)

    index = indices[0]
    is_mask = not symbolic_helper._is_none(index) and (
        symbolic_helper._is_bool(index)
        or _type_utils.JitScalarType.from_value(index)
        == _type_utils.JitScalarType.UINT8
    )
    if is_mask:
        return g.op("GatherND", self, opset9.nonzero(g, index))
    return opset9.index(g, self, index)
@_onnx_symbolic("aten::index_fill")
@_beartype.beartype
def index_fill(g: jit_utils.GraphContext, self, dim, index, value):
    """Export ``aten::index_fill`` (``int_Scalar`` overload).

    ``index`` is reshaped/expanded by ``_index_fill_reshape_helper``. ``value``
    is matched to ``self``'s type and expanded to the index shape, and the
    result is written with this module's ``scatter``.
    """
    # Parsed up front (as in the ATen fallback) so a non-constant dim fails early.
    dim_value = symbolic_helper._parse_arg(dim, "i")
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at(
            "index_fill",
            self,
            index,
            value,
            overload_name="int_Scalar",
            dim_i=dim_value,
        )

    index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
        g, self, dim, index
    )
    fill_value = symbolic_helper._if_scalar_type_as(
        symbolic_helper._maybe_get_scalar(value), self
    )
    fill_tensor = opset9.expand(g, fill_value, index_shape, None)
    return scatter(g, self, dim, expanded_index, fill_tensor)
@_onnx_symbolic("aten::index_copy")
@_beartype.beartype
def index_copy(g: jit_utils.GraphContext, self, dim, index, source):
    """Export ``aten::index_copy`` by scattering ``source`` at the expanded ``index``."""
    dim_value = symbolic_helper._parse_arg(dim, "i")
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("index_copy", self, index, source, dim_i=dim_value)
    _index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
        g, self, dim, index
    )
    return scatter(g, self, dim, expanded_index, source)
@_onnx_symbolic("aten::__rshift_")
@_beartype.beartype
def __rshift_(g: jit_utils.GraphContext, self, other):
    """Export ``aten::__rshift_`` (``self >> other``).

    ``other`` is cast to ``self``'s scalar type when the types differ, so a
    float shift amount is not paired with a long ``self``. ``uint8`` inputs
    use ONNX ``BitShift``. Everything else is lowered to ``self / 2 ** other``:
    the power is computed in float and cast back to ``self``'s type.

    NOTE(review): for negative integer ``self``, an integer ``Div`` presumably
    truncates toward zero while an arithmetic right shift floors — confirm
    whether results differ for negative inputs.
    """
    other_type = _type_utils.JitScalarType.from_value(
        other, _type_utils.JitScalarType.UNDEFINED
    )
    self_type = _type_utils.JitScalarType.from_value(self)
    if other_type != self_type:
        other = g.op("Cast", other, to_i=self_type.onnx_type())
    if self_type == _type_utils.JitScalarType.UINT8:
        return g.op("BitShift", self, other, direction_s="RIGHT")

    two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32))
    # exponent (same type as self) has to be float or double in onnx::Pow
    exponent = (
        other
        if symbolic_helper._is_fp(self)
        else g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    )
    divisor = g.op("Cast", g.op("Pow", two, exponent), to_i=self_type.onnx_type())
    return g.op("Div", self, divisor)
@_onnx_symbolic("aten::__lshift_")
@_beartype.beartype
def __lshift_(g: jit_utils.GraphContext, self, other):
    """Export ``aten::__lshift_`` (``self << other``).

    ``other`` is first cast to ``self``'s scalar type. uint8 inputs map directly
    onto ONNX ``BitShift``. Every other dtype is emulated as
    ``self * Cast(2 ** other, self_type)`` using ``Pow`` and ``Mul``.
    """
    scalar_types = _type_utils.JitScalarType
    other_type = scalar_types.from_value(other, scalar_types.UNDEFINED)
    self_type = scalar_types.from_value(self)
    if other_type != self_type:
        # Align the shift amount with self (e.g. no float amount for a long tensor).
        other = g.op("Cast", other, to_i=self_type.onnx_type())
    if self_type == scalar_types.UINT8:
        return g.op("BitShift", self, other, direction_s="LEFT")

    base = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32))
    # onnx::Pow needs a floating-point exponent, so integer amounts become float32.
    exponent = other
    if not symbolic_helper._is_fp(self):
        exponent = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    multiplier = g.op("Cast", g.op("Pow", base, exponent), to_i=self_type.onnx_type())
    return g.op("Mul", self, multiplier)
@_beartype.beartype
def _get_im2col_indices_along_dim(
    g: jit_utils.GraphContext, input_d, kernel_size_d, dilation_d, padding_d, stride_d
):
    """Return the im2col gather indices for one spatial dimension.

    ``input_d`` is the graph value holding the input's extent along this
    dimension. The kernel size, dilation, padding and stride are expected to be
    plain ints, as ``im2col`` passes them.

    The result has shape ``[kernel_size_d, num_blocks]``. Entry ``[k, b]`` equals
    ``b * stride_d + k * dilation_d``: the padded-input position read by kernel
    tap ``k`` of sliding block ``b``.
    """

    def _const(value):
        return g.op("Constant", value_t=torch.tensor(value))

    # Block starts lie in [0, input + 2*padding - dilation*(kernel - 1)), stepping by stride.
    padded_extent = g.op("Add", input_d, _const(padding_d * 2))
    start_limit = g.op("Sub", padded_extent, _const(dilation_d * (kernel_size_d - 1)))
    block_starts = g.op("Range", _const(0), start_limit, _const(stride_d))

    # Offsets of each (dilated) kernel tap relative to a block start.
    tap_offsets = torch.arange(0, kernel_size_d * dilation_d, dilation_d)
    tap_offsets = g.op("Constant", value_t=tap_offsets.unsqueeze(0))

    # [1, num_blocks] + [kernel_size_d, 1] broadcasts to [kernel_size_d, num_blocks].
    block_starts = symbolic_helper._unsqueeze_helper(g, block_starts, [0])
    tap_offsets = symbolic_helper._reshape_helper(
        g, tap_offsets, g.op("Constant", value_t=torch.tensor([-1, 1]))
    )
    return g.op("Add", block_starts, tap_offsets)
@_beartype.beartype
def _get_im2col_padded_input(g: jit_utils.GraphContext, input, padding_h, padding_w):
    """Pad the two trailing (spatial) dims of a 4-D ``(N, C, H, W)`` input for im2col.

    ONNX ``Pad`` expects ``[d0_begin, d1_begin, ..., d0_end, d1_end, ...]``. The
    batch and channel dims get no padding; H and W get the same amount at both ends.
    """
    begin_pads = [0, 0, padding_h, padding_w]
    end_pads = begin_pads
    pads = g.op("Constant", value_t=torch.LongTensor(begin_pads + end_pads))
    return g.op("Pad", input, pads)
@_beartype.beartype
def _get_im2col_output_shape(g: jit_utils.GraphContext, input, kernel_h, kernel_w):
    """Build the 1-D target shape ``[N, C * kernel_h * kernel_w, -1]`` for im2col.

    ``N`` and ``C`` are read at run time from dims 0 and 1 of ``input``. The last
    entry is ``-1`` so that ``Reshape`` infers it.
    """
    batch = size(g, input, g.op("Constant", value_t=torch.tensor(0)))
    channels = size(g, input, g.op("Constant", value_t=torch.tensor(1)))
    kernel_area = g.op("Constant", value_t=torch.tensor(kernel_h * kernel_w))
    unfolded_channels = g.op("Mul", channels, kernel_area)
    shape_parts = [
        symbolic_helper._unsqueeze_helper(g, batch, [0]),
        symbolic_helper._unsqueeze_helper(g, unfolded_channels, [0]),
        g.op("Constant", value_t=torch.tensor([-1])),
    ]
    return g.op("Concat", *shape_parts, axis_i=0)
@_onnx_symbolic("aten::im2col")
@symbolic_helper.parse_args("v", "is", "is", "is", "is")
@_beartype.beartype
def im2col(g: jit_utils.GraphContext, input, kernel_size, dilation, padding, stride):
    """Export ``aten::im2col`` (the core of ``torch.nn.functional.unfold``).

    The input is assumed to be a 4-D ``(N, C, H, W)`` tensor; all other
    arguments are ``[h, w]`` int pairs. The export proceeds as follows:

    1. Pad H and W.
    2. Gather the sliding-block rows (axis 2), then the columns (axis 4).
    3. Transpose to ``(N, C, kh, kw, blocks_h, blocks_w)``.
    4. Reshape to ``(N, C * kh * kw, blocks_h * blocks_w)``.
    """
    input_h = size(g, input, g.op("Constant", value_t=torch.tensor(2)))
    input_w = size(g, input, g.op("Constant", value_t=torch.tensor(3)))
    stride_h, stride_w = stride[0], stride[1]
    padding_h, padding_w = padding[0], padding[1]
    dilation_h, dilation_w = dilation[0], dilation[1]
    kernel_h, kernel_w = kernel_size[0], kernel_size[1]
    # Each index matrix is [kernel, blocks]; entry [k, b] = b * stride + k * dilation.
    blocks_row_indices = _get_im2col_indices_along_dim(
        g, input_h, kernel_h, dilation_h, padding_h, stride_h
    )
    blocks_col_indices = _get_im2col_indices_along_dim(
        g, input_w, kernel_w, dilation_w, padding_w, stride_w
    )
    output_shape = _get_im2col_output_shape(g, input, kernel_h, kernel_w)
    padded_input = _get_im2col_padded_input(g, input, padding_h, padding_w)
    # Example: a (1, 1, 3, 3) input with kernel_size=2, stride=1, dilation=1, padding=0:
    # [[[[1., 2., 3.],
    #    [4., 5., 6.],
    #    [7., 8., 9.]]]]
    # Both index matrices are [[0, 1], [1, 2]].
    # Gathering rows (axis 2) with blocks_row_indices gives shape (1, 1, kh=2, bh=2, 3):
    # [[[[[1., 2., 3.],
    #     [4., 5., 6.]],
    #    [[4., 5., 6.],
    #     [7., 8., 9.]]]]]
    # Gathering columns (axis 4) with blocks_col_indices gives shape
    # (1, 1, kh=2, bh=2, kw=2, bw=2), where element [kh, bh, kw, bw] = x[kh + bh][kw + bw]:
    # [[[[[[1., 2.], [2., 3.]],
    #      [[4., 5.], [5., 6.]]],
    #     [[[4., 5.], [5., 6.]],
    #      [[7., 8.], [8., 9.]]]]]]
    # Swapping axes 3 (bh) and 4 (kw) gives (N, C, kh, kw, bh, bw). Reshaping to
    # output_shape = (N, C * kh * kw, -1) = (1, 4, 4) then yields:
    # [[[1., 2., 4., 5.],
    #   [2., 3., 5., 6.],
    #   [4., 5., 7., 8.],
    #   [5., 6., 8., 9.]]]
    output = g.op("Gather", padded_input, blocks_row_indices, axis_i=2)
    output = g.op("Gather", output, blocks_col_indices, axis_i=4)
    output = g.op("Transpose", output, perm_i=[0, 1, 2, 4, 3, 5])
    return symbolic_helper._reshape_helper(g, output, output_shape)
@_onnx_symbolic("aten::narrow")
@_beartype.beartype
def narrow(g: jit_utils.GraphContext, input, dim, start, length):
    """Export ``aten::narrow`` as a slice ``[start, start + length)`` along ``dim``."""
    stop = g.op("Add", start, length)
    return symbolic_helper._slice_helper(
        g, input, axes=dim, starts=start, ends=stop
    )
@_onnx_symbolic("aten::flatten")
@symbolic_helper.quantized_args(True, False, False)
@symbolic_helper.parse_args("v", "i", "i")
@_beartype.beartype
def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim):
    """Export ``aten::flatten``.

    ONNX ``Flatten`` is used when the PyTorch result is 2-D, i.e. flattening
    ``[1, rank-1]`` or ``[0, rank-2]``. Other cases go through
    ``_flatten_helper``, which requires the input rank to be known at export time.
    """
    rank = symbolic_helper._get_tensor_rank(input)
    if rank == 1:
        return input

    rank_known = rank is not None
    merges_all_but_first = start_dim == 1 and (
        end_dim == -1 or (rank_known and end_dim == rank - 1)
    )
    merges_all_but_last = start_dim == 0 and (
        end_dim == -2 or (rank_known and end_dim == rank - 2)
    )
    if merges_all_but_first:
        return g.op("Flatten", input, axis_i=start_dim)
    if merges_all_but_last:
        return g.op("Flatten", input, axis_i=end_dim + 1)

    if not rank_known:
        return symbolic_helper._unimplemented(
            "dim",
            "ONNX and PyTorch use different strategies to split the input. "
            "Input rank must be known at export time.",
        )
    # Normalize a negative end_dim against the known rank.
    if end_dim < 0:
        end_dim += rank
    return symbolic_helper._flatten_helper(g, input, start_dim, end_dim, rank)
@_onnx_symbolic("aten::linalg_vector_norm")
@symbolic_helper.parse_args("v", "f", "is", "b", "v")
@_beartype.beartype
def linalg_vector_norm(
    g: jit_utils.GraphContext,
    self,
    ord,
    dim: Optional[Sequence[int]],
    keepdim: bool,
    dtype,
):
    """Export ``aten::linalg_vector_norm``.

    ``ord == 0`` (the number of non-zero entries) is built here from
    ``Equal``/``Not``/``Cast``/``ReduceSum``; the count is cast back to the
    input's dtype. All other ``ord`` values are delegated to opset 9.

    When ``dim`` is None, the input is flattened to 1-D and ``keepdim`` is
    forced to False, so the result reduces over every element.
    """
    if ord != 0:
        return opset9.linalg_vector_norm(g, self, ord, dim, keepdim, dtype)

    scalar_type = _type_utils.JitScalarType.from_value(self)
    if dim is None:
        self = symbolic_helper._reshape_helper(
            g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
        )
        keepdim = False
    # ONNX Equal requires both operands to share one type. Build the zero in the
    # input's own dtype instead of int64, so float inputs produce a valid graph.
    zero = g.op("Constant", value_t=torch.zeros(1, dtype=scalar_type.dtype()))
    is_nonzero = g.op("Not", g.op("Equal", self, zero))
    nonzero_count = g.op("Cast", is_nonzero, to_i=scalar_type.onnx_type())
    return symbolic_helper._reducesum_helper(
        g, nonzero_count, axes_i=dim, keepdims_i=keepdim
    )
@_onnx_symbolic("aten::embedding_bag")
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
@_beartype.beartype
def embedding_bag(
    g: jit_utils.GraphContext,
    embedding_matrix,
    indices,
    offsets,
    scale_grad_by_freq,
    mode,
    sparse,
    per_sample_weights,
    include_last_offset,
    padding_idx,
):
    """Export ``aten::embedding_bag`` as an ONNX ``Loop`` over bags.

    Each iteration:

    1. Slices ``indices[offsets[i]:offsets[i + 1]]``.
    2. Gathers those rows of ``embedding_matrix``.
    3. Optionally multiplies them by the matching ``per_sample_weights``.
    4. Reduces over the bag: ``mode == 0`` gives a sum, ``mode == 1`` a mean,
       and any other mode a max.

    ``sparse`` is not used by this export.

    Returns:
        A 4-tuple ``(output, None, None, None)``. ``output`` is the Loop's scan
        output, one reduced embedding per bag.

    Raises:
        RuntimeError: if ``padding_idx`` is set to a non-negative value.
        ``scale_grad_by_freq`` during a training-mode export is reported through
        ``symbolic_helper._onnx_unsupported``.
    """
    if scale_grad_by_freq and GLOBALS.export_training:
        return symbolic_helper._onnx_unsupported(
            "embedding_bag with scale_grad_by_freq for training mode"
        )
    if padding_idx is not None and padding_idx >= 0:
        raise RuntimeError("embedding_bag with padding_idx")
    # Constant-true condition: the Loop runs for exactly `loop_len` iterations.
    loop_condition = g.op("Constant", value_t=torch.tensor(1))
    loop_condition = g.op("Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL)
    # Used below as the `axes` input ([0]) of the per-bag Slice ops.
    zero = g.op("Constant", value_t=torch.tensor([0]))
    # len(indices) as a 1-element tensor, so it can be appended to `offsets`.
    indices_len = symbolic_helper._unsqueeze_helper(
        g,
        symbolic_helper._size_helper(
            g, indices, g.op("Constant", value_t=torch.tensor(0))
        ),
        [0],
    )
    # Without include_last_offset, append len(indices) as the end of the last bag.
    if not include_last_offset:
        offsets = [offsets, indices_len]
        offsets = g.op("Concat", *offsets, axis_i=0)
    # Offsets holds the starting index position of each bag. So we create a list of the indices slices (determined by
    # offsets) and gather those indices in indices_row. Then we use this subset of indices to gather from embeddings.
    # The embeddings output is a loop scan output, so we can avoid creating a sequence and inserting elements in.
    offsets_starts = symbolic_helper._slice_helper(
        g, offsets, axes=[0], starts=[0], ends=[sys.maxsize], steps=[1]
    )
    offsets_ends = symbolic_helper._slice_helper(
        g, offsets, axes=[0], starts=[1], ends=[sys.maxsize], steps=[1]
    )
    # One Loop iteration per bag: the number of (start, end) pairs.
    loop_len = symbolic_helper._size_helper(
        g, offsets_ends, g.op("Constant", value_t=torch.tensor(0))
    )
    loop, (loop_context,), _ = jit_utils.add_op_with_blocks(
        g, "Loop", loop_len, loop_condition, n_blocks=1
    )
    loop_block = loop_context.block
    # FIXME(justinchuby): We need to handle what happens when we call b.op on a node return
    # Loop body inputs: iteration counter, then the incoming condition. `cond` is
    # never read, but the block input must still be added (order matters).
    block_input_iter = utils._add_input_to_block(loop_block)
    cond = utils._add_input_to_block(loop_block)
    # Bounds [start, end) of the current bag, unsqueezed to 1-D for Slice.
    indices_start = loop_context.op(
        "Gather", offsets_starts, block_input_iter, axis_i=0
    )
    indices_end = loop_context.op("Gather", offsets_ends, block_input_iter, axis_i=0)
    indices_start = symbolic_helper._unsqueeze_helper(loop_context, indices_start, [0])
    indices_end = symbolic_helper._unsqueeze_helper(loop_context, indices_end, [0])
    indices_row = loop_context.op("Slice", indices, indices_start, indices_end, zero)
    embeddings = loop_context.op("Gather", embedding_matrix, indices_row, axis_i=0)
    if not symbolic_helper._is_none(per_sample_weights):
        # Slice the matching weights and unsqueeze on axis 1 so they broadcast across each gathered row.
        per_sample_weights_row = loop_context.op(
            "Slice", per_sample_weights, indices_start, indices_end, zero
        )
        per_sample_weights_row = symbolic_helper._unsqueeze_helper(
            loop_context, per_sample_weights_row, [1]
        )
        embeddings = loop_context.op("Mul", embeddings, per_sample_weights_row)
    # Reduce over the bag (axis 0): mode 0 -> sum, 1 -> mean, otherwise -> max.
    if mode == 0:
        embeddings = symbolic_helper._reducesum_helper(
            loop_context, embeddings, axes_i=[0], keepdims_i=0
        )
    elif mode == 1:
        embeddings = loop_context.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0)
    else:
        embeddings = loop_context.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0)
    # Body outputs: the (still true) condition first, then the per-bag embedding.
    cond_out = loop_context.op(
        "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
    )
    utils._add_output_to_block(loop_block, cond_out)
    utils._add_output_to_block(loop_block, embeddings)
    # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
    # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
    return loop.node().output(), None, None, None
@_onnx_symbolic("aten::embedding_renorm")
@symbolic_helper.parse_args("v", "v", "f", "f")
@_beartype.beartype
def embedding_renorm(g: jit_utils.GraphContext, weight, indices, max_norm, norm_type):
    """Export ``aten::embedding_renorm`` (max-norm clipping of looked-up rows).

    Rows of ``weight`` referenced by ``indices`` (deduplicated with ``Unique``)
    are handled by their L1 or L2 norm over axis 1:

    - rows whose norm exceeds ``max_norm`` are scaled by
      ``max_norm / (norm + 1e-7)``;
    - all other rows are left unchanged.

    The result is written back into ``weight`` with ``ScatterND``. Any
    ``norm_type`` other than 1 or 2 raises ``errors.SymbolicValueError``.
    """
    reduce_ops = {1: "ReduceL1", 2: "ReduceL2"}
    unique_indices = g.op("Unique", indices)
    rows = g.op("Gather", weight, unique_indices)
    norm_i = int(norm_type)
    if norm_i not in reduce_ops:
        raise errors.SymbolicValueError(
            f"Unsupported: ONNX export of embedding_renorm with norm: {norm_i}. "
            "Only 1. and 2. are supported.",
            weight,
        )
    row_norms = g.op(reduce_ops[norm_i], rows, axes_i=[1], keepdims_i=1)
    # Mirrors ATen's Embedding.cpp renorm:
    # https://github.com/pytorch/pytorch/blob/0a07488ed2c47765e337e290bd138c0e6e459cbd/aten/src/ATen/native/Embedding.cpp#L177
    # The 1e-7 keeps the division away from zero.
    safe_norms = g.op(
        "Add", row_norms, g.op("Constant", value_t=torch.tensor(1e-7))
    )
    limit = torch.tensor(max_norm)
    rescaled_rows = g.op("Mul", rows, g.op("Div", limit, safe_norms))
    clipped_rows = g.op(
        "Where", g.op("Greater", row_norms, limit), rescaled_rows, rows
    )
    scatter_indices = symbolic_helper._unsqueeze_helper(g, unique_indices, [1])
    return g.op("ScatterND", weight, scatter_indices, clipped_rows)
@_onnx_symbolic("aten::chunk")
@_beartype.beartype
def chunk(g: jit_utils.GraphContext, self, chunks, dim):
    """Export ``aten::chunk`` with a dynamic number of chunks.

    PyTorch uses ``chunk_size = ceil(dim_size / chunks)`` and splits ``self``
    into pieces of that size along ``dim``. It may therefore return FEWER than
    ``chunks`` pieces; for example, size 5 in 4 chunks gives ``[2, 2, 1]``. The
    exception is an empty dimension, which yields ``chunks`` empty pieces.

    This builds the matching split-size vector and hands it to ``split``:
    ``num_chunks - 1`` copies of ``chunk_size`` followed by the remainder.
    """
    one = g.op("Constant", value_t=torch.tensor([1], dtype=torch.long))
    zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.long))
    dim_size = g.op("Gather", g.op("Shape", self), dim, axis_i=0)
    # chunk_size = ceil(dim_size / chunks) = (dim_size + chunks - 1) / chunks
    chunks_minus_one = g.op("Sub", chunks, one)
    chunk_size = g.op("Div", g.op("Add", dim_size, chunks_minus_one), chunks)
    # chunk_size is 0 exactly when dim_size is 0; guard the division below.
    is_empty = g.op("Equal", dim_size, zero)
    divisor = g.op("Where", is_empty, one, chunk_size)
    # Number of pieces torch.chunk actually returns: ceil(dim_size / chunk_size),
    # or `chunks` (all empty) for an empty dimension. Always using `chunks` here
    # produced a negative last size (e.g. [2, 2, 2, -1]) or spurious empty pieces.
    num_chunks = g.op(
        "Div", g.op("Sub", g.op("Add", dim_size, chunk_size), one), divisor
    )
    num_chunks = g.op("Where", is_empty, chunks, num_chunks)
    num_full_chunks = g.op("Sub", num_chunks, one)
    # Split sizes: num_full_chunks copies of chunk_size, then the remainder.
    chunk_vec = [
        opset9.expand(g, chunk_size, num_full_chunks, None),
        g.op("Sub", dim_size, g.op("Mul", chunk_size, num_full_chunks)),
    ]
    chunk_vec = g.op("Concat", *chunk_vec, axis_i=0)
    return split(g, self, chunk_vec, dim)
@_onnx_symbolic("aten::normal")
@_beartype.beartype
def normal(
    g: jit_utils.GraphContext,
    mean,
    std,
    sizes=None,
    generator=None,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
):
    """Export ``aten::normal`` via a location-scale transform of standard noise.

    If ``x`` has mean 0 and variance 1, then ``std * x + mean`` has mean
    ``mean`` and variance ``std ** 2``. ``x`` comes from ``RandomNormalLike``,
    shaped like ``mean``. When ``sizes`` is given, ``mean`` is expanded to it first.
    ``generator``, ``dtype``, ``layout``, ``device`` and ``pin_memory`` are not
    used by this export.
    """
    wants_explicit_size = sizes is not None and not symbolic_helper._is_none(sizes)
    if wants_explicit_size:
        mean = opset9.expand(g, mean, sizes, None)
    standard_noise = g.op("RandomNormalLike", mean)
    scaled_noise = opset9.mul(g, std, standard_noise)
    return add(g, scaled_noise, mean)
@_onnx_symbolic("aten::atleast_1d")
@_beartype.beartype
def atleast_1d(g: jit_utils.GraphContext, self: torch._C.Value):
    """Export ``aten::atleast_1d``.

    ``self`` may be a single tensor or a packed list of tensors. 0-D tensors
    are reshaped to shape ``[1]``; all others are returned unchanged. A list
    input produces an ONNX sequence (``SequenceConstruct``).
    """

    def _promote(value):
        if symbolic_helper._get_tensor_rank(value) == 0:
            return symbolic_helper._reshape_helper(
                g, value, g.op("Constant", value_t=torch.tensor([1]))
            )
        return value

    if symbolic_helper._is_value(self) and symbolic_helper._is_packed_list(self):
        promoted = [_promote(t) for t in symbolic_helper._unpack_list(self)]
        return g.op("SequenceConstruct", *promoted)
    return _promote(self)
@_onnx_symbolic("aten::atleast_2d")
@_beartype.beartype
def atleast_2d(g: jit_utils.GraphContext, self: torch._C.Value):
    """Export ``aten::atleast_2d``.

    ``self`` may be a single tensor or a packed list of tensors:

    - 0-D tensors are reshaped to ``[1, 1]``;
    - 1-D tensors get a new leading axis;
    - everything else is returned unchanged.

    A list input produces an ONNX sequence (``SequenceConstruct``).
    """

    def _promote(value):
        rank = symbolic_helper._get_tensor_rank(value)
        if rank == 0:
            return symbolic_helper._reshape_helper(
                g, value, g.op("Constant", value_t=torch.tensor([1, 1]))
            )
        if rank == 1:
            return symbolic_helper._unsqueeze_helper(g, value, axes_i=[0])
        return value

    if symbolic_helper._is_value(self) and symbolic_helper._is_packed_list(self):
        promoted = [_promote(t) for t in symbolic_helper._unpack_list(self)]
        return g.op("SequenceConstruct", *promoted)
    return _promote(self)
@_onnx_symbolic("aten::atleast_3d")
@_beartype.beartype
def atleast_3d(g: jit_utils.GraphContext, self: torch._C.Value):
    """Export ``aten::atleast_3d``.

    ``self`` may be a single tensor or a packed list of tensors:

    - 0-D tensors are reshaped to ``[1, 1, 1]``;
    - 1-D tensors gain a leading and a trailing axis;
    - 2-D tensors gain a trailing axis;
    - everything else is returned unchanged.

    A list input produces an ONNX sequence (``SequenceConstruct``).
    """

    def _promote(value):
        rank = symbolic_helper._get_tensor_rank(value)
        if rank == 0:
            return symbolic_helper._reshape_helper(
                g, value, g.op("Constant", value_t=torch.tensor([1, 1, 1]))
            )
        if rank == 1:
            value = symbolic_helper._unsqueeze_helper(g, value, axes_i=[0])
            return symbolic_helper._unsqueeze_helper(g, value, axes_i=[-1])
        if rank == 2:
            return symbolic_helper._unsqueeze_helper(g, value, axes_i=[-1])
        return value

    if symbolic_helper._is_value(self) and symbolic_helper._is_packed_list(self):
        promoted = [_promote(t) for t in symbolic_helper._unpack_list(self)]
        return g.op("SequenceConstruct", *promoted)
    return _promote(self)
@_onnx_symbolic("prim::ConstantChunk")
@_beartype.beartype
def prim_constant_chunk(g: jit_utils.GraphContext, self, chunks, dim):
    """Export ``prim::ConstantChunk`` as ``chunks`` consecutive ``Slice`` nodes.

    Each slice spans ``ceil(size(dim) / chunks)`` elements along ``dim``.
    ``chunks`` and ``dim`` are used as Python ints, so exactly ``chunks``
    outputs are returned as a list.
    """
    full_shape = g.op("Shape", self)
    axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
    dim_len = g.op("Gather", full_shape, axis, axis_i=0)
    begin = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
    num_chunks = g.op("Constant", value_t=torch.tensor([chunks], dtype=torch.long))
    num_chunks_minus_one = g.op(
        "Constant", value_t=torch.tensor([chunks - 1], dtype=torch.long)
    )
    # Piece length = ceil(dim_len / chunks) = (dim_len + chunks - 1) / chunks.
    piece_len = g.op("Div", g.op("Add", dim_len, num_chunks_minus_one), num_chunks)

    pieces = []
    for k in range(1, chunks + 1):
        multiple = g.op("Constant", value_t=torch.tensor([k], dtype=torch.long))
        stop = g.op("Mul", piece_len, multiple)
        pieces.append(g.op("Slice", self, begin, stop, axis))
        begin = stop
    return pieces
@_onnx_symbolic("aten::hstack")
@_beartype.beartype
def hstack(g: jit_utils.GraphContext, tensor_list: _C.Value):
    """Export ``aten::hstack``.

    The inputs are first promoted with ``atleast_1d``. An ONNX ``If`` then
    inspects the rank of the first element: 1-D inputs are concatenated along
    axis 0, everything else along axis 1.
    """
    tensor_list = atleast_1d(g, tensor_list)
    head = g.op(
        "SequenceAt",
        tensor_list,
        g.op("Constant", value_t=torch.tensor(0, dtype=torch.long)),
    )
    head_rank = g.op("Size", g.op("Shape", head))
    is_1d = g.op(
        "Equal", head_rank, g.op("Constant", value_t=torch.tensor(1, dtype=torch.long))
    )
    if_node, (then_ctx, else_ctx), _ = jit_utils.add_op_with_blocks(
        g, "If", is_1d, n_blocks=2, outputs=1
    )
    for branch_ctx, concat_axis in ((then_ctx, 0), (else_ctx, 1)):
        concatenated = branch_ctx.op(
            "ConcatFromSequence", tensor_list, axis_i=concat_axis, new_axis_i=0
        )
        utils._add_output_to_block(branch_ctx.block, concatenated)
    return if_node.node().output()
@_onnx_symbolic("aten::vstack")
@_beartype.beartype
def vstack(g: jit_utils.GraphContext, tensor_list: _C.Value):
    """Export ``aten::vstack``: promote to at least 2-D, then concatenate along axis 0."""
    promoted = atleast_2d(g, tensor_list)
    stacked = g.op("ConcatFromSequence", promoted, axis_i=0, new_axis_i=0)
    return stacked