226 lines
7.0 KiB
Python
226 lines
7.0 KiB
Python
|
import functools
|
||
|
|
||
|
import torch
|
||
|
import torch._custom_ops
|
||
|
import torch.library
|
||
|
|
||
|
# Ensure that torch.ops.torchvision is visible
|
||
|
import torchvision.extension # noqa: F401
|
||
|
|
||
|
|
||
|
@functools.lru_cache(maxsize=None)
def get_meta_lib():
    """Return the shared ``torch.library.Library`` used to register Meta kernels.

    The library is opened on the ``torchvision`` namespace with kind ``"IMPL"``
    and dispatch key ``"Meta"``. Because the function is memoised with an
    unbounded ``lru_cache`` and takes no arguments, every caller (for example
    ``register_meta``) receives the same library object.
    """
    namespace, kind, dispatch_key = "torchvision", "IMPL", "Meta"
    return torch.library.Library(namespace, kind, dispatch_key)
|
||
|
|
||
|
|
||
|
def register_meta(op_name, overload_name="default"):
    """Build a decorator that installs a function as a Meta kernel.

    The decorated function is registered, through ``get_meta_lib()``, as the
    implementation of ``torch.ops.torchvision.<op_name>.<overload_name>``.
    Registration is attempted only when ``torchvision.extension._has_ops()``
    is truthy. In either case the decorator hands back the original function
    unchanged, so the decorated name stays a plain Python function.
    """

    def decorator(fn):
        if not torchvision.extension._has_ops():
            return fn
        packet = getattr(torch.ops.torchvision, op_name)
        overload = getattr(packet, overload_name)
        get_meta_lib().impl(overload, fn)
        return fn

    return decorator
|
||
|
|
||
|
|
||
|
@register_meta("roi_align")
|
||
|
def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
|
||
|
torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
|
||
|
torch._check(
|
||
|
input.dtype == rois.dtype,
|
||
|
lambda: (
|
||
|
"Expected tensor for input to have the same type as tensor for rois; "
|
||
|
f"but type {input.dtype} does not equal {rois.dtype}"
|
||
|
),
|
||
|
)
|
||
|
num_rois = rois.size(0)
|
||
|
channels = input.size(1)
|
||
|
return input.new_empty((num_rois, channels, pooled_height, pooled_width))
|
||
|
|
||
|
|
||
|
@register_meta("_roi_align_backward")
|
||
|
def meta_roi_align_backward(
|
||
|
grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio, aligned
|
||
|
):
|
||
|
torch._check(
|
||
|
grad.dtype == rois.dtype,
|
||
|
lambda: (
|
||
|
"Expected tensor for grad to have the same type as tensor for rois; "
|
||
|
f"but type {grad.dtype} does not equal {rois.dtype}"
|
||
|
),
|
||
|
)
|
||
|
return grad.new_empty((batch_size, channels, height, width))
|
||
|
|
||
|
|
||
|
@register_meta("ps_roi_align")
|
||
|
def meta_ps_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio):
|
||
|
torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
|
||
|
torch._check(
|
||
|
input.dtype == rois.dtype,
|
||
|
lambda: (
|
||
|
"Expected tensor for input to have the same type as tensor for rois; "
|
||
|
f"but type {input.dtype} does not equal {rois.dtype}"
|
||
|
),
|
||
|
)
|
||
|
channels = input.size(1)
|
||
|
torch._check(
|
||
|
channels % (pooled_height * pooled_width) == 0,
|
||
|
"input channels must be a multiple of pooling height * pooling width",
|
||
|
)
|
||
|
|
||
|
num_rois = rois.size(0)
|
||
|
out_size = (num_rois, channels // (pooled_height * pooled_width), pooled_height, pooled_width)
|
||
|
return input.new_empty(out_size), torch.empty(out_size, dtype=torch.int32, device="meta")
|
||
|
|
||
|
|
||
|
@register_meta("_ps_roi_align_backward")
|
||
|
def meta_ps_roi_align_backward(
|
||
|
grad,
|
||
|
rois,
|
||
|
channel_mapping,
|
||
|
spatial_scale,
|
||
|
pooled_height,
|
||
|
pooled_width,
|
||
|
sampling_ratio,
|
||
|
batch_size,
|
||
|
channels,
|
||
|
height,
|
||
|
width,
|
||
|
):
|
||
|
torch._check(
|
||
|
grad.dtype == rois.dtype,
|
||
|
lambda: (
|
||
|
"Expected tensor for grad to have the same type as tensor for rois; "
|
||
|
f"but type {grad.dtype} does not equal {rois.dtype}"
|
||
|
),
|
||
|
)
|
||
|
return grad.new_empty((batch_size, channels, height, width))
|
||
|
|
||
|
|
||
|
@register_meta("roi_pool")
|
||
|
def meta_roi_pool(input, rois, spatial_scale, pooled_height, pooled_width):
|
||
|
torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
|
||
|
torch._check(
|
||
|
input.dtype == rois.dtype,
|
||
|
lambda: (
|
||
|
"Expected tensor for input to have the same type as tensor for rois; "
|
||
|
f"but type {input.dtype} does not equal {rois.dtype}"
|
||
|
),
|
||
|
)
|
||
|
num_rois = rois.size(0)
|
||
|
channels = input.size(1)
|
||
|
out_size = (num_rois, channels, pooled_height, pooled_width)
|
||
|
return input.new_empty(out_size), torch.empty(out_size, device="meta", dtype=torch.int32)
|
||
|
|
||
|
|
||
|
@register_meta("_roi_pool_backward")
|
||
|
def meta_roi_pool_backward(
|
||
|
grad, rois, argmax, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width
|
||
|
):
|
||
|
torch._check(
|
||
|
grad.dtype == rois.dtype,
|
||
|
lambda: (
|
||
|
"Expected tensor for grad to have the same type as tensor for rois; "
|
||
|
f"but type {grad.dtype} does not equal {rois.dtype}"
|
||
|
),
|
||
|
)
|
||
|
return grad.new_empty((batch_size, channels, height, width))
|
||
|
|
||
|
|
||
|
@register_meta("ps_roi_pool")
|
||
|
def meta_ps_roi_pool(input, rois, spatial_scale, pooled_height, pooled_width):
|
||
|
torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
|
||
|
torch._check(
|
||
|
input.dtype == rois.dtype,
|
||
|
lambda: (
|
||
|
"Expected tensor for input to have the same type as tensor for rois; "
|
||
|
f"but type {input.dtype} does not equal {rois.dtype}"
|
||
|
),
|
||
|
)
|
||
|
channels = input.size(1)
|
||
|
torch._check(
|
||
|
channels % (pooled_height * pooled_width) == 0,
|
||
|
"input channels must be a multiple of pooling height * pooling width",
|
||
|
)
|
||
|
num_rois = rois.size(0)
|
||
|
out_size = (num_rois, channels // (pooled_height * pooled_width), pooled_height, pooled_width)
|
||
|
return input.new_empty(out_size), torch.empty(out_size, device="meta", dtype=torch.int32)
|
||
|
|
||
|
|
||
|
@register_meta("_ps_roi_pool_backward")
|
||
|
def meta_ps_roi_pool_backward(
|
||
|
grad, rois, channel_mapping, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width
|
||
|
):
|
||
|
torch._check(
|
||
|
grad.dtype == rois.dtype,
|
||
|
lambda: (
|
||
|
"Expected tensor for grad to have the same type as tensor for rois; "
|
||
|
f"but type {grad.dtype} does not equal {rois.dtype}"
|
||
|
),
|
||
|
)
|
||
|
return grad.new_empty((batch_size, channels, height, width))
|
||
|
|
||
|
|
||
|
@torch._custom_ops.impl_abstract("torchvision::nms")
|
||
|
def meta_nms(dets, scores, iou_threshold):
|
||
|
torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D")
|
||
|
torch._check(dets.size(1) == 4, lambda: f"boxes should have 4 elements in dimension 1, got {dets.size(1)}")
|
||
|
torch._check(scores.dim() == 1, lambda: f"scores should be a 1d tensor, got {scores.dim()}")
|
||
|
torch._check(
|
||
|
dets.size(0) == scores.size(0),
|
||
|
lambda: f"boxes and scores should have same number of elements in dimension 0, got {dets.size(0)} and {scores.size(0)}",
|
||
|
)
|
||
|
ctx = torch._custom_ops.get_ctx()
|
||
|
num_to_keep = ctx.create_unbacked_symint()
|
||
|
return dets.new_empty(num_to_keep, dtype=torch.long)
|
||
|
|
||
|
|
||
|
@register_meta("deform_conv2d")
|
||
|
def meta_deform_conv2d(
    input,
    weight,
    offset,
    mask,
    bias,
    stride_h,
    stride_w,
    pad_h,
    pad_w,
    dil_h,
    dil_w,
    n_weight_grps,
    n_offset_grps,
    use_mask,
):
    """Meta kernel for ``torchvision::deform_conv2d``.

    Returns an uninitialised tensor like ``input`` with shape
    ``(batch_size, out_channels, out_height, out_width)``:

    - ``batch_size`` is ``input.shape[0]``;
    - ``out_channels`` is ``weight.shape[0]``;
    - ``(out_height, out_width)`` are the last two dims of ``offset``.

    No validation is done, and the stride/padding/dilation/group/mask
    arguments are not consulted.
    """
    offset_spatial = offset.shape[-2:]
    out_height, out_width = offset_spatial
    result_shape = (input.shape[0], weight.shape[0], out_height, out_width)
    return input.new_empty(result_shape)
|
||
|
|
||
|
|
||
|
@register_meta("_deform_conv2d_backward")
|
||
|
def meta_deform_conv2d_backward(
    grad,
    input,
    weight,
    offset,
    mask,
    bias,
    stride_h,
    stride_w,
    pad_h,
    pad_w,
    dilation_h,
    dilation_w,
    groups,
    offset_groups,
    use_mask,
):
    """Meta kernel for ``torchvision::_deform_conv2d_backward``.

    Produces one uninitialised gradient for each of ``input``, ``weight``,
    ``offset``, ``mask`` and ``bias`` (in that order). Each is created with
    ``new_empty`` from its source tensor using that tensor's shape. ``grad``
    and the convolution hyper-parameters are not consulted.
    """
    sources = (input, weight, offset, mask, bias)
    return tuple(src.new_empty(src.shape) for src in sources)
|