100 lines
3.2 KiB
Python
100 lines
3.2 KiB
Python
"""
|
|
helper class that supports empty tensors on some nn functions.
|
|
|
|
Ideally, add support directly in PyTorch to empty tensors in
|
|
those functions.
|
|
|
|
This can be removed once https://github.com/pytorch/pytorch/issues/12013
|
|
is implemented
|
|
"""
|
|
|
|
import warnings
|
|
import torch
|
|
from torch import Tensor
|
|
from typing import List, Optional
|
|
|
|
|
|
class Conv2d(torch.nn.Conv2d):
    """Deprecated subclass of ``torch.nn.Conv2d``.

    Adds no behavior of its own: construction forwards every argument to
    ``torch.nn.Conv2d`` and then emits a ``FutureWarning`` that points users
    at the replacement class.
    """

    def __init__(self, *args, **kwargs):
        """Build the underlying ``torch.nn.Conv2d`` and warn about the deprecation.

        Args:
            *args: positional arguments forwarded unchanged to ``torch.nn.Conv2d``.
            **kwargs: keyword arguments forwarded unchanged to ``torch.nn.Conv2d``.
        """
        super().__init__(*args, **kwargs)
        warnings.warn(
            "torchvision.ops.misc.Conv2d is deprecated and will be "
            "removed in future versions, use torch.nn.Conv2d instead.",
            FutureWarning,
            # Attribute the warning to the code that instantiates the class
            # instead of to this module, so users can locate the call site.
            stacklevel=2,
        )
|
|
|
|
|
|
class ConvTranspose2d(torch.nn.ConvTranspose2d):
    """Deprecated subclass of ``torch.nn.ConvTranspose2d``.

    Adds no behavior of its own: construction forwards every argument to
    ``torch.nn.ConvTranspose2d`` and then emits a ``FutureWarning`` that
    points users at the replacement class.
    """

    def __init__(self, *args, **kwargs):
        """Build the underlying ``torch.nn.ConvTranspose2d`` and warn about the deprecation.

        Args:
            *args: positional arguments forwarded unchanged to
                ``torch.nn.ConvTranspose2d``.
            **kwargs: keyword arguments forwarded unchanged to
                ``torch.nn.ConvTranspose2d``.
        """
        super().__init__(*args, **kwargs)
        warnings.warn(
            "torchvision.ops.misc.ConvTranspose2d is deprecated and will be "
            "removed in future versions, use torch.nn.ConvTranspose2d instead.",
            FutureWarning,
            # Attribute the warning to the code that instantiates the class
            # instead of to this module, so users can locate the call site.
            stacklevel=2,
        )
|
|
|
|
|
|
class BatchNorm2d(torch.nn.BatchNorm2d):
    """Deprecated subclass of ``torch.nn.BatchNorm2d``.

    Adds no behavior of its own: construction forwards every argument to
    ``torch.nn.BatchNorm2d`` and then emits a ``FutureWarning`` that points
    users at the replacement class.
    """

    def __init__(self, *args, **kwargs):
        """Build the underlying ``torch.nn.BatchNorm2d`` and warn about the deprecation.

        Args:
            *args: positional arguments forwarded unchanged to
                ``torch.nn.BatchNorm2d``.
            **kwargs: keyword arguments forwarded unchanged to
                ``torch.nn.BatchNorm2d``.
        """
        super().__init__(*args, **kwargs)
        warnings.warn(
            "torchvision.ops.misc.BatchNorm2d is deprecated and will be "
            "removed in future versions, use torch.nn.BatchNorm2d instead.",
            FutureWarning,
            # Attribute the warning to the code that instantiates the class
            # instead of to this module, so users can locate the call site.
            stacklevel=2,
        )
|
|
|
|
|
|
# Expose torch.nn.functional.interpolate under this module's namespace.
# NOTE(review): presumably kept so existing imports of `interpolate` from this
# module keep working -- confirm before removing.
interpolate = torch.nn.functional.interpolate
|
|
|
|
|
|
# This is not in nn
|
|
class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters
    are fixed.

    ``weight``, ``bias``, ``running_mean`` and ``running_var`` are registered
    as buffers rather than parameters, and ``forward`` only reads them.

    Args:
        num_features: length of each of the four per-channel buffers.
        eps: value added to ``running_var`` before the reciprocal square
            root is taken in ``forward``.
        n: deprecated name for ``num_features``; when given it overrides
            ``num_features`` and a ``DeprecationWarning`` is emitted.
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        n: Optional[int] = None,
    ):
        # n=None for backward-compatibility
        if n is not None:
            warnings.warn(
                "`n` argument is deprecated and has been renamed `num_features`",
                DeprecationWarning,
                # Attribute the warning to the caller: the default filters
                # hide DeprecationWarning unless it is attributed to
                # __main__, so pointing at this module would keep it silent.
                stacklevel=2,
            )
            num_features = n
        super().__init__()
        self.eps = eps
        # Identity affine transform and zero-mean / unit-variance statistics
        # until real values are loaded from a state dict.
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def _load_from_state_dict(
        self,
        state_dict: dict,
        prefix: str,
        local_metadata: dict,
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ):
        """Load this module's buffers, ignoring any ``num_batches_tracked`` entry.

        The ``<prefix>num_batches_tracked`` key is removed from ``state_dict``
        (in place) before delegating to ``torch.nn.Module``. This module
        registers no buffer of that name, so the entry (presumably coming
        from a regular BatchNorm2d checkpoint) would otherwise be left over
        as an unexpected key.
        """
        state_dict.pop(prefix + 'num_batches_tracked', None)

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the fixed normalization and affine transform to ``x``.

        Computes ``(x - running_mean) / sqrt(running_var + eps) * weight + bias``,
        folded into a single multiply-add. The buffers are reshaped to
        ``(1, C, 1, 1)``, so ``x`` must broadcast against that shape
        (assumes an ``(N, C, H, W)`` input -- TODO confirm with callers).
        """
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        scale = w * (rv + self.eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias

    def __repr__(self) -> str:
        """Return ``ClassName(num_features, eps=...)``, reading the size from ``weight``."""
        return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps})"
|