898 lines
32 KiB
Python
898 lines
32 KiB
Python
|
import math
|
||
|
from dataclasses import dataclass
|
||
|
from functools import partial
|
||
|
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
|
||
|
|
||
|
import torch
|
||
|
import torch.fx
|
||
|
import torch.nn as nn
|
||
|
|
||
|
from ...ops import MLP, StochasticDepth
|
||
|
from ...transforms._presets import VideoClassification
|
||
|
from ...utils import _log_api_usage_once
|
||
|
from .._api import register_model, Weights, WeightsEnum
|
||
|
from .._meta import _KINETICS400_CATEGORIES
|
||
|
from .._utils import _ovewrite_named_param, handle_legacy_interface
|
||
|
|
||
|
|
||
|
# Names exported by ``from <module> import *``: the model class, the two
# pretrained-weight enums and the two model builder functions defined below.
__all__ = [
    "MViT",
    "MViT_V1_B_Weights",
    "mvit_v1_b",
    "MViT_V2_S_Weights",
    "mvit_v2_s",
]
|
||
|
|
||
|
|
||
|
@dataclass
class MSBlockConfig:
    """Per-block settings consumed by ``MultiscaleBlock`` / ``MultiscaleAttention``.

    The kernel and stride lists are handed to 3D pooling layers. A pooling layer is only
    created when the product of its kernel or of its stride is greater than 1, so an
    empty list (product 1) for both disables that pooling.
    """

    # Number of attention heads; the attention width is split evenly across them.
    num_heads: int
    # Channel width of the tokens entering the block.
    input_channels: int
    # Channel width of the tokens leaving the block.
    output_channels: int
    # Kernel of the depthwise ``nn.Conv3d`` that pools the queries.
    kernel_q: List[int]
    # Kernel of the depthwise ``nn.Conv3d`` layers that pool the keys and the values.
    kernel_kv: List[int]
    # Stride of the query pooling; also used for the skip-connection max pooling and
    # to update the (T, H, W) size seen by the following blocks.
    stride_q: List[int]
    # Stride of the key/value pooling.
    stride_kv: List[int]
|
||
|
|
||
|
|
||
|
def _prod(s: Sequence[int]) -> int:
|
||
|
product = 1
|
||
|
for v in s:
|
||
|
product *= v
|
||
|
return product
|
||
|
|
||
|
|
||
|
def _unsqueeze(x: torch.Tensor, target_dim: int, expand_dim: int) -> Tuple[torch.Tensor, int]:
|
||
|
tensor_dim = x.dim()
|
||
|
if tensor_dim == target_dim - 1:
|
||
|
x = x.unsqueeze(expand_dim)
|
||
|
elif tensor_dim != target_dim:
|
||
|
raise ValueError(f"Unsupported input dimension {x.shape}")
|
||
|
return x, tensor_dim
|
||
|
|
||
|
|
||
|
def _squeeze(x: torch.Tensor, target_dim: int, expand_dim: int, tensor_dim: int) -> torch.Tensor:
|
||
|
if tensor_dim == target_dim - 1:
|
||
|
x = x.squeeze(expand_dim)
|
||
|
return x
|
||
|
|
||
|
|
||
|
# Register both helpers (by name) with torch.fx so that symbolic tracing records a call
# to them instead of tracing through their bodies, which branch on ``x.dim()``.
torch.fx.wrap("_unsqueeze")
torch.fx.wrap("_squeeze")
|
||
|
|
||
|
|
||
|
class Pool(nn.Module):
    """Pools the spatio-temporal tokens with a 3D layer while keeping the class token aside.

    Args:
        pool: The 3D pooling layer applied to the tokens reshaped to ``(…, C, T, H, W)``.
        norm: Optional normalization layer.
        activation: Optional activation applied after ``norm``.
        norm_before_pool: If True, norm/activation run on the reshaped input before
            pooling; otherwise they run on the pooled tokens (class token included).
    """

    def __init__(
        self,
        pool: nn.Module,
        norm: Optional[nn.Module],
        activation: Optional[nn.Module] = None,
        norm_before_pool: bool = False,
    ) -> None:
        super().__init__()
        self.pool = pool
        # Keep only the layers that were actually supplied, norm first.
        post_layers = [layer for layer in (norm, activation) if layer is not None]
        self.norm_act = nn.Sequential(*post_layers) if post_layers else None
        self.norm_before_pool = norm_before_pool

    def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
        """Pool ``x`` and return it together with the new ``(T, H, W)`` size.

        ``x`` may be 3D or 4D; a 3D input gets a temporary axis at position 1 that is
        removed again before returning.
        """
        x, orig_dim = _unsqueeze(x, 4, 1)

        # Split off the first token along dim 2 (the class token) and move the channels
        # in front of the token axis so the rest can be viewed as a (C, T, H, W) volume.
        cls_token, tokens = torch.tensor_split(x, indices=(1,), dim=2)
        tokens = tokens.transpose(2, 3)
        B, N, C = tokens.shape[:3]
        tokens = tokens.reshape((B * N, C) + thw).contiguous()

        # normalizing prior pooling is useful when we use BN which can be absorbed to speed up inference
        if self.norm_before_pool and self.norm_act is not None:
            tokens = self.norm_act(tokens)

        # Pool, remember the reduced size, then flatten back and re-attach the class token.
        tokens = self.pool(tokens)
        T, H, W = tokens.shape[2:]
        tokens = tokens.reshape(B, N, C, -1).transpose(2, 3)
        out = torch.cat((cls_token, tokens), dim=2)

        if not self.norm_before_pool and self.norm_act is not None:
            out = self.norm_act(out)

        out = _squeeze(out, 4, 1, orig_dim)
        return out, (T, H, W)
|
||
|
|
||
|
|
||
|
def _interpolate(embedding: torch.Tensor, d: int) -> torch.Tensor:
|
||
|
if embedding.shape[0] == d:
|
||
|
return embedding
|
||
|
|
||
|
return (
|
||
|
nn.functional.interpolate(
|
||
|
embedding.permute(1, 0).unsqueeze(0),
|
||
|
size=d,
|
||
|
mode="linear",
|
||
|
)
|
||
|
.squeeze(0)
|
||
|
.permute(1, 0)
|
||
|
)
|
||
|
|
||
|
|
||
|
def _rel_pos_distance(q_size: int, k_size: int) -> torch.Tensor:
    """Return the ``(q_size, k_size)`` table of relative offsets between query and key positions.

    The shorter of the two axes is scaled up so both cover the same range, and the key
    coordinates are shifted so that every offset is non-negative.
    """
    # Scale up rel pos if shapes for q and k are different.
    q_ratio = max(k_size / q_size, 1.0)
    k_ratio = max(q_size / k_size, 1.0)
    q_coords = torch.arange(q_size)[:, None] * q_ratio
    k_coords = (torch.arange(k_size)[None, :] + (1.0 - k_size)) * k_ratio
    return q_coords - k_coords


def _add_rel_pos(
    attn: torch.Tensor,
    q: torch.Tensor,
    q_thw: Tuple[int, int, int],
    k_thw: Tuple[int, int, int],
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    rel_pos_t: torch.Tensor,
) -> torch.Tensor:
    """Add decomposed (height, width, time) relative positional terms to ``attn`` in place.

    Args:
        attn: Attention logits; everything except the first row and column (the class
            token) is updated in place.
        q: Queries of shape ``(B, heads, 1 + q_t*q_h*q_w, dim)``.
        q_thw: ``(T, H, W)`` size of the query tokens.
        k_thw: ``(T, H, W)`` size of the key tokens.
        rel_pos_h: Relative position table for the height axis.
        rel_pos_w: Relative position table for the width axis.
        rel_pos_t: Relative position table for the temporal axis.

    Returns:
        The same ``attn`` tensor that was passed in.
    """
    # Modified code from: https://github.com/facebookresearch/SlowFast/commit/1aebd71a2efad823d52b827a3deaf15a56cf4932
    q_t, q_h, q_w = q_thw
    k_t, k_h, k_w = k_thw

    # Interpolate rel pos if needed so each table has 2 * max(q, k) - 1 rows.
    table_h = _interpolate(rel_pos_h, int(2 * max(q_h, k_h) - 1))
    table_w = _interpolate(rel_pos_w, int(2 * max(q_w, k_w) - 1))
    table_t = _interpolate(rel_pos_t, int(2 * max(q_t, k_t) - 1))

    # Look up one embedding per (query position, key position) pair along each axis.
    Rh = table_h[_rel_pos_distance(q_h, k_h).long()]
    Rw = table_w[_rel_pos_distance(q_w, k_w).long()]
    Rt = table_t[_rel_pos_distance(q_t, k_t).long()]

    B, n_head, _, dim = q.shape

    # Drop the class token and unfold the query tokens into their (t, h, w) grid.
    q_grid = q[:, :, 1:].reshape(B, n_head, q_t, q_h, q_w, dim)
    rel_h_q = torch.einsum("bythwc,hkc->bythwk", q_grid, Rh)  # [B, H, q_t, qh, qw, k_h]
    rel_w_q = torch.einsum("bythwc,wkc->bythwk", q_grid, Rw)  # [B, H, q_t, qh, qw, k_w]
    # [B, H, q_t, q_h, q_w, dim] -> [q_t, B, H, q_h, q_w, dim] -> [q_t, B*H*q_h*q_w, dim]
    q_by_time = q_grid.permute(2, 0, 1, 3, 4, 5).reshape(q_t, B * n_head * q_h * q_w, dim)
    # [q_t, B*H*q_h*q_w, dim] * [q_t, dim, k_t] = [q_t, B*H*q_h*q_w, k_t] -> [B*H*q_h*q_w, q_t, k_t]
    rel_q_t = torch.matmul(q_by_time, Rt.transpose(1, 2)).transpose(0, 1)
    # [B*H*q_h*q_w, q_t, k_t] -> [B, H, q_t, q_h, q_w, k_t]
    rel_q_t = rel_q_t.view(B, n_head, q_h, q_w, q_t, k_t).permute(0, 1, 4, 2, 3, 5)

    # Combine rel pos: broadcast the three terms over the (k_t, k_h, k_w) key grid.
    combined = rel_h_q[:, :, :, :, :, None, :, None] + rel_w_q[:, :, :, :, :, None, None, :]
    combined = combined + rel_q_t[:, :, :, :, :, :, None, None]
    rel_pos = combined.reshape(B, n_head, q_t * q_h * q_w, k_t * k_h * k_w)

    # Add it to attention
    attn[:, :, 1:, 1:] += rel_pos

    return attn
|
||
|
|
||
|
|
||
|
def _add_shortcut(x: torch.Tensor, shortcut: torch.Tensor, residual_with_cls_embed: bool):
|
||
|
if residual_with_cls_embed:
|
||
|
x.add_(shortcut)
|
||
|
else:
|
||
|
x[:, :, 1:, :] += shortcut[:, :, 1:, :]
|
||
|
return x
|
||
|
|
||
|
|
||
|
# Register both helpers (by name) with torch.fx so that symbolic tracing records a call
# to them instead of tracing through their bodies; both update their first tensor
# argument in place.
torch.fx.wrap("_add_rel_pos")
torch.fx.wrap("_add_shortcut")
|
||
|
|
||
|
|
||
|
class MultiscaleAttention(nn.Module):
    """Multi-head attention whose queries, keys and values can be pooled by 3D convolutions.

    Args:
        input_size: ``[T, H, W]`` size of the token grid entering the layer.
        embed_dim: Channel width of the input tokens.
        output_dim: Channel width of the output tokens (split across the heads).
        num_heads: Number of attention heads.
        kernel_q: Kernel of the query pooling convolution.
        kernel_kv: Kernel of the key and value pooling convolutions.
        stride_q: Stride of the query pooling convolution.
        stride_kv: Stride of the key and value pooling convolutions.
        residual_pool: If True, the (pooled) queries are added to the attention output.
        residual_with_cls_embed: If True, that residual addition includes the class token.
        rel_pos_embed: If True, learn relative positional tables and add them to the logits.
        dropout: Dropout rate applied after the output projection (skipped when 0).
        norm_layer: Factory of the normalization applied after each pooling convolution.
    """

    def __init__(
        self,
        input_size: List[int],
        embed_dim: int,
        output_dim: int,
        num_heads: int,
        kernel_q: List[int],
        kernel_kv: List[int],
        stride_q: List[int],
        stride_kv: List[int],
        residual_pool: bool,
        residual_with_cls_embed: bool,
        rel_pos_embed: bool,
        dropout: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__()
        head_dim = output_dim // num_heads
        self.embed_dim = embed_dim
        self.output_dim = output_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scaler = 1.0 / math.sqrt(head_dim)
        self.residual_pool = residual_pool
        self.residual_with_cls_embed = residual_with_cls_embed

        # A single linear layer produces queries, keys and values at once.
        self.qkv = nn.Linear(embed_dim, 3 * output_dim)
        projection: List[nn.Module] = [nn.Linear(output_dim, output_dim)]
        if dropout > 0.0:
            projection.append(nn.Dropout(dropout, inplace=True))
        self.project = nn.Sequential(*projection)

        def make_pool(kernel: List[int], stride: List[int]) -> Pool:
            # Depthwise (groups == channels), bias-free 3D convolution with "half kernel"
            # padding, followed by a normalization over the head channels.
            return Pool(
                nn.Conv3d(
                    head_dim,
                    head_dim,
                    kernel,  # type: ignore[arg-type]
                    stride=stride,  # type: ignore[arg-type]
                    padding=[int(k // 2) for k in kernel],  # type: ignore[arg-type]
                    groups=head_dim,
                    bias=False,
                ),
                norm_layer(head_dim),
            )

        # Pooling layers only exist when the kernel or the stride actually reduces something.
        self.pool_q: Optional[nn.Module] = None
        if _prod(kernel_q) > 1 or _prod(stride_q) > 1:
            self.pool_q = make_pool(kernel_q, stride_q)

        self.pool_k: Optional[nn.Module] = None
        self.pool_v: Optional[nn.Module] = None
        if _prod(kernel_kv) > 1 or _prod(stride_kv) > 1:
            self.pool_k = make_pool(kernel_kv, stride_kv)
            self.pool_v = make_pool(kernel_kv, stride_kv)

        self.rel_pos_h: Optional[nn.Parameter] = None
        self.rel_pos_w: Optional[nn.Parameter] = None
        self.rel_pos_t: Optional[nn.Parameter] = None
        if rel_pos_embed:
            # One table row per possible relative offset: 2 * size - 1 entries per axis.
            size = max(input_size[1:])
            q_size = size // stride_q[1] if len(stride_q) > 0 else size
            kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size
            spatial_dim = 2 * max(q_size, kv_size) - 1
            temporal_dim = 2 * input_size[0] - 1
            self.rel_pos_h = nn.Parameter(torch.zeros(spatial_dim, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(spatial_dim, head_dim))
            self.rel_pos_t = nn.Parameter(torch.zeros(temporal_dim, head_dim))
            for table in (self.rel_pos_h, self.rel_pos_w, self.rel_pos_t):
                nn.init.trunc_normal_(table, std=0.02)

    def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
        """Attend over ``x`` (``(B, N, C)``) and return the output with the query ``(T, H, W)`` size."""
        B, N, C = x.shape
        # (B, N, 3 * output_dim) -> (B, N, 3, heads, head_dim) -> three (B, heads, N, head_dim) tensors
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.transpose(1, 3).unbind(dim=2)

        # Keys and values are pooled from the incoming size; only the query pooling
        # changes the size reported to the caller.
        k_thw = thw
        if self.pool_k is not None:
            k, k_thw = self.pool_k(k, thw)
        if self.pool_v is not None:
            v = self.pool_v(v, thw)[0]
        if self.pool_q is not None:
            q, thw = self.pool_q(q, thw)

        attn = torch.matmul(self.scaler * q, k.transpose(2, 3))
        if self.rel_pos_h is not None and self.rel_pos_w is not None and self.rel_pos_t is not None:
            attn = _add_rel_pos(
                attn,
                q,
                thw,
                k_thw,
                self.rel_pos_h,
                self.rel_pos_w,
                self.rel_pos_t,
            )
        attn = attn.softmax(dim=-1)

        x = torch.matmul(attn, v)
        if self.residual_pool:
            # In-place residual with the pooled queries; the return value is ``x`` itself.
            _add_shortcut(x, q, self.residual_with_cls_embed)
        # Merge the heads back into the channel axis before the output projection.
        x = x.transpose(1, 2).reshape(B, -1, self.output_dim)
        x = self.project(x)

        return x, thw
|
||
|
|
||
|
|
||
|
class MultiscaleBlock(nn.Module):
    """One encoder block: pre-normalized multiscale attention followed by an MLP.

    Args:
        input_size: ``[T, H, W]`` size of the token grid entering the block.
        cnf: Heads, channel widths and pooling kernels/strides of the block.
        residual_pool: Forwarded to ``MultiscaleAttention``.
        residual_with_cls_embed: Forwarded to ``MultiscaleAttention``.
        rel_pos_embed: Forwarded to ``MultiscaleAttention``.
        proj_after_attn: If True the attention already works at ``cnf.output_channels``
            and the skip path is projected before the attention residual; otherwise the
            channel change happens in the MLP and the skip is projected before the MLP
            residual.
        dropout: Dropout rate used by the attention projection and the MLP.
        stochastic_depth_prob: Drop probability of the two residual branches.
        norm_layer: Factory of the normalization layers.
    """

    def __init__(
        self,
        input_size: List[int],
        cnf: MSBlockConfig,
        residual_pool: bool,
        residual_with_cls_embed: bool,
        rel_pos_embed: bool,
        proj_after_attn: bool,
        dropout: float = 0.0,
        stochastic_depth_prob: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__()
        self.proj_after_attn = proj_after_attn

        # When the queries are strided, the skip connection is max-pooled with the same
        # stride so both paths end up with the same number of tokens.
        self.pool_skip: Optional[nn.Module] = None
        if _prod(cnf.stride_q) > 1:
            kernel_skip = [s + 1 if s > 1 else s for s in cnf.stride_q]
            padding_skip = [int(k // 2) for k in kernel_skip]
            skip_pool = nn.MaxPool3d(kernel_skip, stride=cnf.stride_q, padding=padding_skip)  # type: ignore[arg-type]
            self.pool_skip = Pool(skip_pool, None)

        if proj_after_attn:
            attn_dim = cnf.output_channels
        else:
            attn_dim = cnf.input_channels

        self.norm1 = norm_layer(cnf.input_channels)
        self.norm2 = norm_layer(attn_dim)
        # BatchNorm1d normalizes dim 1, so the tokens must be transposed around it.
        self.needs_transposal = isinstance(self.norm1, nn.BatchNorm1d)

        self.attn = MultiscaleAttention(
            input_size=input_size,
            embed_dim=cnf.input_channels,
            output_dim=attn_dim,
            num_heads=cnf.num_heads,
            kernel_q=cnf.kernel_q,
            kernel_kv=cnf.kernel_kv,
            stride_q=cnf.stride_q,
            stride_kv=cnf.stride_kv,
            residual_pool=residual_pool,
            residual_with_cls_embed=residual_with_cls_embed,
            rel_pos_embed=rel_pos_embed,
            dropout=dropout,
            norm_layer=norm_layer,
        )
        self.mlp = MLP(
            attn_dim,
            [4 * attn_dim, cnf.output_channels],
            activation_layer=nn.GELU,
            dropout=dropout,
            inplace=None,
        )

        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")

        # A linear projection on the skip path is only needed when the width changes.
        self.project: Optional[nn.Module] = None
        if cnf.input_channels != cnf.output_channels:
            self.project = nn.Linear(cnf.input_channels, cnf.output_channels)

    def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
        """Run the block on ``x`` and return the new tokens with the new ``(T, H, W)`` size."""
        if self.needs_transposal:
            x_norm1 = self.norm1(x.transpose(1, 2)).transpose(1, 2)
        else:
            x_norm1 = self.norm1(x)
        x_attn, thw_new = self.attn(x_norm1, thw)

        # Attention residual: optionally project, then optionally pool the skip path.
        if self.project is not None and self.proj_after_attn:
            x = self.project(x_norm1)
        x_skip = x
        if self.pool_skip is not None:
            x_skip = self.pool_skip(x, thw)[0]
        x = x_skip + self.stochastic_depth(x_attn)

        if self.needs_transposal:
            x_norm2 = self.norm2(x.transpose(1, 2)).transpose(1, 2)
        else:
            x_norm2 = self.norm2(x)

        # MLP residual: the skip is projected here when the MLP is what changes the width.
        x_proj = x
        if self.project is not None and not self.proj_after_attn:
            x_proj = self.project(x_norm2)

        return x_proj + self.stochastic_depth(self.mlp(x_norm2)), thw_new
|
||
|
|
||
|
|
||
|
class PositionalEncoding(nn.Module):
    """Prepends a learnable class token and, unless relative embeddings are used, adds
    separable spatial + temporal absolute positional embeddings.

    Args:
        embed_size: Channel width of the tokens.
        spatial_size: ``(H, W)`` size of the token grid.
        temporal_size: ``T`` size of the token grid.
        rel_pos_embed: If True, no absolute positional parameters are created.
    """

    def __init__(self, embed_size: int, spatial_size: Tuple[int, int], temporal_size: int, rel_pos_embed: bool) -> None:
        super().__init__()
        self.spatial_size = spatial_size
        self.temporal_size = temporal_size

        self.class_token = nn.Parameter(torch.zeros(embed_size))
        self.spatial_pos: Optional[nn.Parameter] = None
        self.temporal_pos: Optional[nn.Parameter] = None
        self.class_pos: Optional[nn.Parameter] = None
        if not rel_pos_embed:
            tokens_per_frame = self.spatial_size[0] * self.spatial_size[1]
            self.spatial_pos = nn.Parameter(torch.zeros(tokens_per_frame, embed_size))
            self.temporal_pos = nn.Parameter(torch.zeros(self.temporal_size, embed_size))
            self.class_pos = nn.Parameter(torch.zeros(embed_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` (``(B, T*H*W, C)``) with the class token prepended along dim 1."""
        cls_tokens = self.class_token.expand(x.size(0), -1).unsqueeze(1)
        x = torch.cat((cls_tokens, x), dim=1)

        if self.spatial_pos is not None and self.temporal_pos is not None and self.class_pos is not None:
            hw_size, embed_size = self.spatial_pos.shape
            # Token ``t * hw_size + s`` receives temporal_pos[t] + spatial_pos[s].
            pos_embedding = torch.repeat_interleave(self.temporal_pos, hw_size, dim=0)
            tiled_spatial = self.spatial_pos.unsqueeze(0).expand(self.temporal_size, -1, -1)
            pos_embedding.add_(tiled_spatial.reshape(-1, embed_size))
            # The class token has its own positional embedding, placed first.
            pos_embedding = torch.cat((self.class_pos.unsqueeze(0), pos_embedding), dim=0).unsqueeze(0)
            x.add_(pos_embedding)

        return x
|
||
|
|
||
|
|
||
|
class MViT(nn.Module):
    def __init__(
        self,
        spatial_size: Tuple[int, int],
        temporal_size: int,
        block_setting: Sequence[MSBlockConfig],
        residual_pool: bool,
        residual_with_cls_embed: bool,
        rel_pos_embed: bool,
        proj_after_attn: bool,
        dropout: float = 0.5,
        attention_dropout: float = 0.0,
        stochastic_depth_prob: float = 0.0,
        num_classes: int = 400,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        patch_embed_kernel: Tuple[int, int, int] = (3, 7, 7),
        patch_embed_stride: Tuple[int, int, int] = (2, 4, 4),
        patch_embed_padding: Tuple[int, int, int] = (1, 3, 3),
    ) -> None:
        """
        MViT main class.

        Args:
            spatial_size (tuple of ints): The spatial size of the input as ``(H, W)``.
            temporal_size (int): The temporal size ``T`` of the input.
            block_setting (sequence of MSBlockConfig): The Network structure.
            residual_pool (bool): If True, use MViTv2 pooling residual connection.
            residual_with_cls_embed (bool): If True, the addition on the residual connection will include
                the class embedding.
            rel_pos_embed (bool): If True, use MViTv2's relative positional embeddings.
            proj_after_attn (bool): If True, apply the projection after the attention.
            dropout (float): Dropout rate of the classification head. Default: 0.5.
            attention_dropout (float): Attention dropout rate. Default: 0.0.
            stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0.
            num_classes (int): The number of classes.
            block (callable, optional): Module specifying the layer which consists of the attention and mlp.
            norm_layer (callable, optional): Module specifying the normalization layer to use.
            patch_embed_kernel (tuple of ints): The kernel of the convolution that patchifies the input.
            patch_embed_stride (tuple of ints): The stride of the convolution that patchifies the input.
            patch_embed_padding (tuple of ints): The padding of the convolution that patchifies the input.

        Raises:
            ValueError: If ``block_setting`` is empty.
        """
        super().__init__()
        # This implementation employs a different parameterization scheme than the one used at PyTorch Video:
        # https://github.com/facebookresearch/pytorchvideo/blob/718d0a4/pytorchvideo/models/vision_transformers.py
        # We remove any experimental configuration that didn't make it to the final variants of the models. To represent
        # the configuration of the architecture we use the simplified form suggested at Table 1 of the paper.
        _log_api_usage_once(self)
        total_stage_blocks = len(block_setting)
        if total_stage_blocks == 0:
            raise ValueError("The configuration parameter can't be empty.")

        if block is None:
            block = MultiscaleBlock

        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)

        # Patch Embedding module
        self.conv_proj = nn.Conv3d(
            in_channels=3,
            out_channels=block_setting[0].input_channels,
            kernel_size=patch_embed_kernel,
            stride=patch_embed_stride,
            padding=patch_embed_padding,
        )

        # Size of the token grid after patchification. Unpacking (instead of tuple
        # concatenation) also accepts ``spatial_size`` given as a list.
        input_size = [size // stride for size, stride in zip((temporal_size, *spatial_size), self.conv_proj.stride)]

        # Spatio-Temporal Class Positional Encoding
        self.pos_encoding = PositionalEncoding(
            embed_size=block_setting[0].input_channels,
            spatial_size=(input_size[1], input_size[2]),
            temporal_size=input_size[0],
            rel_pos_embed=rel_pos_embed,
        )

        # Encoder module
        self.blocks = nn.ModuleList()
        # The drop probability grows linearly from 0 (first block) to
        # ``stochastic_depth_prob`` (last block). With a single block the original
        # denominator ``total_stage_blocks - 1`` is zero, so it is clamped to 1.
        sd_denominator = max(total_stage_blocks - 1.0, 1.0)
        for stage_block_id, cnf in enumerate(block_setting):
            # adjust stochastic depth probability based on the depth of the stage block
            sd_prob = stochastic_depth_prob * stage_block_id / sd_denominator

            self.blocks.append(
                block(
                    input_size=input_size,
                    cnf=cnf,
                    residual_pool=residual_pool,
                    residual_with_cls_embed=residual_with_cls_embed,
                    rel_pos_embed=rel_pos_embed,
                    proj_after_attn=proj_after_attn,
                    dropout=attention_dropout,
                    stochastic_depth_prob=sd_prob,
                    norm_layer=norm_layer,
                )
            )

            # The query stride of this block shrinks the grid seen by the next one.
            if len(cnf.stride_q) > 0:
                input_size = [size // stride for size, stride in zip(input_size, cnf.stride_q)]
        self.norm = norm_layer(block_setting[-1].output_channels)

        # Classifier module
        self.head = nn.Sequential(
            nn.Dropout(dropout, inplace=True),
            nn.Linear(block_setting[-1].output_channels, num_classes),
        )

        # Weight initialization: truncated normal for linear weights and positional
        # parameters, constants for biases and LayerNorm affine parameters.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, nn.LayerNorm):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1.0)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, PositionalEncoding):
                for weights in m.parameters():
                    nn.init.trunc_normal_(weights, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Classify a batch of videos ``(B, C, T, H, W)`` (or images ``(B, C, H, W)``)."""
        # Convert if necessary (B, C, H, W) -> (B, C, 1, H, W)
        x = _unsqueeze(x, 5, 2)[0]
        # patchify and reshape: (B, C, T, H, W) -> (B, embed_channels[0], T', H', W') -> (B, THW', embed_channels[0])
        x = self.conv_proj(x)
        x = x.flatten(2).transpose(1, 2)

        # add positional encoding
        x = self.pos_encoding(x)

        # pass patches through the encoder
        thw = (self.pos_encoding.temporal_size,) + self.pos_encoding.spatial_size
        for block in self.blocks:
            x, thw = block(x, thw)
        x = self.norm(x)

        # classifier "token" as used by standard language architectures
        x = x[:, 0]
        x = self.head(x)

        return x
|
||
|
|
||
|
|
||
|
def _mvit(
    block_setting: List[MSBlockConfig],
    stochastic_depth_prob: float,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> MViT:
    """Build an ``MViT`` and optionally load pretrained weights into it.

    Args:
        block_setting: Per-block configuration of the encoder.
        stochastic_depth_prob: Stochastic depth rate passed to ``MViT``.
        weights: Pretrained weights to load, or None for a freshly initialized model.
            When given, ``num_classes``, ``spatial_size`` and ``temporal_size`` are
            taken from the weights' metadata.
        progress: Whether to show a progress bar while downloading the weights.
        **kwargs: Remaining ``MViT`` constructor arguments. ``spatial_size``,
            ``temporal_size``, ``residual_pool``, ``residual_with_cls_embed``,
            ``rel_pos_embed`` and ``proj_after_attn`` fall back to defaults when absent.
    """
    if weights is not None:
        meta = weights.meta
        _ovewrite_named_param(kwargs, "num_classes", len(meta["categories"]))
        # Only square inputs are expected for the pretrained checkpoints.
        assert meta["min_size"][0] == meta["min_size"][1]
        _ovewrite_named_param(kwargs, "spatial_size", meta["min_size"])
        _ovewrite_named_param(kwargs, "temporal_size", meta["min_temporal_size"])

    # Pull the arguments that have builder-level defaults out of ``kwargs`` so that
    # they are not passed to ``MViT`` twice.
    spatial_size = kwargs.pop("spatial_size", (224, 224))
    temporal_size = kwargs.pop("temporal_size", 16)
    residual_pool = kwargs.pop("residual_pool", False)
    residual_with_cls_embed = kwargs.pop("residual_with_cls_embed", True)
    rel_pos_embed = kwargs.pop("rel_pos_embed", False)
    proj_after_attn = kwargs.pop("proj_after_attn", False)

    model = MViT(
        spatial_size=spatial_size,
        temporal_size=temporal_size,
        block_setting=block_setting,
        residual_pool=residual_pool,
        residual_with_cls_embed=residual_with_cls_embed,
        rel_pos_embed=rel_pos_embed,
        proj_after_attn=proj_after_attn,
        stochastic_depth_prob=stochastic_depth_prob,
        **kwargs,
    )

    if weights is not None:
        state_dict = weights.get_state_dict(progress=progress, check_hash=True)
        model.load_state_dict(state_dict)

    return model
|
||
|
|
||
|
|
||
|
# Pretrained checkpoints available for the ``mvit_v1_b`` builder.
class MViT_V1_B_Weights(WeightsEnum):
    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/mvit_v1_b-dbeb1030.pth",
        # Preprocessing preset bound to the values used with this checkpoint.
        transforms=partial(
            VideoClassification,
            crop_size=(224, 224),
            resize_size=(256,),
            mean=(0.45, 0.45, 0.45),
            std=(0.225, 0.225, 0.225),
        ),
        meta={
            # ``_mvit`` uses "min_size" / "min_temporal_size" as the model's spatial and
            # temporal size and len("categories") as its number of classes.
            "min_size": (224, 224),
            "min_temporal_size": 16,
            "categories": _KINETICS400_CATEGORIES,
            "recipe": "https://github.com/facebookresearch/pytorchvideo/blob/main/docs/source/model_zoo.md",
            "_docs": (
                "The weights were ported from the paper. The accuracies are estimated on video-level "
                "with parameters `frame_rate=7.5`, `clips_per_video=5`, and `clip_len=16`"
            ),
            "num_params": 36610672,
            "_metrics": {
                "Kinetics-400": {
                    "acc@1": 78.477,
                    "acc@5": 93.582,
                }
            },
            "_ops": 70.599,
            "_file_size": 139.764,
        },
    )
    # Alias for the checkpoint used when the default weights are requested.
    DEFAULT = KINETICS400_V1
|
||
|
|
||
|
|
||
|
# Pretrained checkpoints available for the ``mvit_v2_s`` builder.
class MViT_V2_S_Weights(WeightsEnum):
    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/mvit_v2_s-ae3be167.pth",
        # Preprocessing preset bound to the values used with this checkpoint.
        transforms=partial(
            VideoClassification,
            crop_size=(224, 224),
            resize_size=(256,),
            mean=(0.45, 0.45, 0.45),
            std=(0.225, 0.225, 0.225),
        ),
        meta={
            # ``_mvit`` uses "min_size" / "min_temporal_size" as the model's spatial and
            # temporal size and len("categories") as its number of classes.
            "min_size": (224, 224),
            "min_temporal_size": 16,
            "categories": _KINETICS400_CATEGORIES,
            "recipe": "https://github.com/facebookresearch/SlowFast/blob/main/MODEL_ZOO.md",
            "_docs": (
                "The weights were ported from the paper. The accuracies are estimated on video-level "
                "with parameters `frame_rate=7.5`, `clips_per_video=5`, and `clip_len=16`"
            ),
            "num_params": 34537744,
            "_metrics": {
                "Kinetics-400": {
                    "acc@1": 80.757,
                    "acc@5": 94.665,
                }
            },
            "_ops": 64.224,
            "_file_size": 131.884,
        },
    )
    # Alias for the checkpoint used when the default weights are requested.
    DEFAULT = KINETICS400_V1
|
||
|
|
||
|
|
||
|
@register_model()
@handle_legacy_interface(weights=("pretrained", MViT_V1_B_Weights.KINETICS400_V1))
def mvit_v1_b(*, weights: Optional[MViT_V1_B_Weights] = None, progress: bool = True, **kwargs: Any) -> MViT:
    """
    Constructs a base MViTV1 architecture from
    `Multiscale Vision Transformers <https://arxiv.org/abs/2104.11227>`__.

    .. betastatus:: video module

    Args:
        weights (:class:`~torchvision.models.video.MViT_V1_B_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.video.MViT_V1_B_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.video.MViT``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvit.py>`_
            for more details about this class. ``spatial_size`` and ``temporal_size``
            default to ``(224, 224)`` and ``16`` when not given.

    .. autoclass:: torchvision.models.video.MViT_V1_B_Weights
        :members:
    """
    weights = MViT_V1_B_Weights.verify(weights)

    # Column-wise description of the 16 blocks; entry ``i`` of every list configures
    # block ``i``. Empty kernel/stride lists mean that the queries are not pooled.
    config: Dict[str, List] = {
        "num_heads": [1, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 8, 8],
        "input_channels": [96, 192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768, 768],
        "output_channels": [192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768, 768, 768],
        "kernel_q": [[], [3, 3, 3], [], [3, 3, 3], [], [], [], [], [], [], [], [], [], [], [3, 3, 3], []],
        "kernel_kv": [
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
        ],
        "stride_q": [[], [1, 2, 2], [], [1, 2, 2], [], [], [], [], [], [], [], [], [], [], [1, 2, 2], []],
        "stride_kv": [
            [1, 8, 8],
            [1, 4, 4],
            [1, 4, 4],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 1, 1],
            [1, 1, 1],
        ],
    }

    # The keys of ``config`` are exactly the field names of ``MSBlockConfig``, so each
    # block is built by taking row ``i`` of every column.
    block_setting = [
        MSBlockConfig(**{field: column[i] for field, column in config.items()})
        for i in range(len(config["num_heads"]))
    ]

    # ``spatial_size`` / ``temporal_size`` are popped from ``kwargs`` so that a caller
    # may override them; passing them literally here made any override fail with
    # "got multiple values for keyword argument".
    return _mvit(
        spatial_size=kwargs.pop("spatial_size", (224, 224)),
        temporal_size=kwargs.pop("temporal_size", 16),
        block_setting=block_setting,
        residual_pool=False,
        residual_with_cls_embed=False,
        stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.2),
        weights=weights,
        progress=progress,
        **kwargs,
    )
|
||
|
|
||
|
|
||
|
@register_model()
@handle_legacy_interface(weights=("pretrained", MViT_V2_S_Weights.KINETICS400_V1))
def mvit_v2_s(*, weights: Optional[MViT_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> MViT:
    """Constructs a small MViTV2 architecture from
    `Multiscale Vision Transformers <https://arxiv.org/abs/2104.11227>`__ and
    `MViTv2: Improved Multiscale Vision Transformers for Classification
    and Detection <https://arxiv.org/abs/2112.01526>`__.

    .. betastatus:: video module

    Args:
        weights (:class:`~torchvision.models.video.MViT_V2_S_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.video.MViT_V2_S_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.video.MViT``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvit.py>`_
            for more details about this class. ``spatial_size`` and ``temporal_size``
            default to ``(224, 224)`` and ``16`` when not given.

    .. autoclass:: torchvision.models.video.MViT_V2_S_Weights
        :members:
    """
    weights = MViT_V2_S_Weights.verify(weights)

    # Column-wise description of the 16 blocks; entry ``i`` of every list configures
    # block ``i``.
    config: Dict[str, List] = {
        "num_heads": [1, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 8, 8],
        "input_channels": [96, 96, 192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768],
        "output_channels": [96, 192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768, 768],
        "kernel_q": [
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
        ],
        "kernel_kv": [
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
            [3, 3, 3],
        ],
        "stride_q": [
            [1, 1, 1],
            [1, 2, 2],
            [1, 1, 1],
            [1, 2, 2],
            [1, 1, 1],
            [1, 1, 1],
            [1, 1, 1],
            [1, 1, 1],
            [1, 1, 1],
            [1, 1, 1],
            [1, 1, 1],
            [1, 1, 1],
            [1, 1, 1],
            [1, 1, 1],
            [1, 2, 2],
            [1, 1, 1],
        ],
        "stride_kv": [
            [1, 8, 8],
            [1, 4, 4],
            [1, 4, 4],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 2, 2],
            [1, 1, 1],
            [1, 1, 1],
        ],
    }

    # The keys of ``config`` are exactly the field names of ``MSBlockConfig``, so each
    # block is built by taking row ``i`` of every column.
    block_setting = [
        MSBlockConfig(**{field: column[i] for field, column in config.items()})
        for i in range(len(config["num_heads"]))
    ]

    # ``spatial_size`` / ``temporal_size`` are popped from ``kwargs`` so that a caller
    # may override them; passing them literally here made any override fail with
    # "got multiple values for keyword argument".
    return _mvit(
        spatial_size=kwargs.pop("spatial_size", (224, 224)),
        temporal_size=kwargs.pop("temporal_size", 16),
        block_setting=block_setting,
        residual_pool=True,
        residual_with_cls_embed=False,
        rel_pos_embed=True,
        proj_after_attn=True,
        stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.2),
        weights=weights,
        progress=progress,
        **kwargs,
    )
|