import math from collections import OrderedDict from functools import partial from typing import Any, Callable, List, Optional, Sequence, Tuple import numpy as np import torch import torch.nn.functional as F from torch import nn, Tensor from torchvision.models._api import register_model, Weights, WeightsEnum from torchvision.models._meta import _IMAGENET_CATEGORIES from torchvision.models._utils import _ovewrite_named_param, handle_legacy_interface from torchvision.ops.misc import Conv2dNormActivation, SqueezeExcitation from torchvision.ops.stochastic_depth import StochasticDepth from torchvision.transforms._presets import ImageClassification, InterpolationMode from torchvision.utils import _log_api_usage_once __all__ = [ "MaxVit", "MaxVit_T_Weights", "maxvit_t", ] def _get_conv_output_shape(input_size: Tuple[int, int], kernel_size: int, stride: int, padding: int) -> Tuple[int, int]: return ( (input_size[0] - kernel_size + 2 * padding) // stride + 1, (input_size[1] - kernel_size + 2 * padding) // stride + 1, ) def _make_block_input_shapes(input_size: Tuple[int, int], n_blocks: int) -> List[Tuple[int, int]]: """Util function to check that the input size is correct for a MaxVit configuration.""" shapes = [] block_input_shape = _get_conv_output_shape(input_size, 3, 2, 1) for _ in range(n_blocks): block_input_shape = _get_conv_output_shape(block_input_shape, 3, 2, 1) shapes.append(block_input_shape) return shapes def _get_relative_position_index(height: int, width: int) -> torch.Tensor: coords = torch.stack(torch.meshgrid([torch.arange(height), torch.arange(width)])) coords_flat = torch.flatten(coords, 1) relative_coords = coords_flat[:, :, None] - coords_flat[:, None, :] relative_coords = relative_coords.permute(1, 2, 0).contiguous() relative_coords[:, :, 0] += height - 1 relative_coords[:, :, 1] += width - 1 relative_coords[:, :, 0] *= 2 * width - 1 return relative_coords.sum(-1) class MBConv(nn.Module): """MBConv: Mobile Inverted Residual Bottleneck. Args: in_channels (int): Number of input channels. out_channels (int): Number of output channels. expansion_ratio (float): Expansion ratio in the bottleneck. squeeze_ratio (float): Squeeze ratio in the SE Layer. stride (int): Stride of the depthwise convolution. activation_layer (Callable[..., nn.Module]): Activation function. norm_layer (Callable[..., nn.Module]): Normalization function. p_stochastic_dropout (float): Probability of stochastic depth. """ def __init__( self, in_channels: int, out_channels: int, expansion_ratio: float, squeeze_ratio: float, stride: int, activation_layer: Callable[..., nn.Module], norm_layer: Callable[..., nn.Module], p_stochastic_dropout: float = 0.0, ) -> None: super().__init__() proj: Sequence[nn.Module] self.proj: nn.Module should_proj = stride != 1 or in_channels != out_channels if should_proj: proj = [nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=True)] if stride == 2: proj = [nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)] + proj # type: ignore self.proj = nn.Sequential(*proj) else: self.proj = nn.Identity() # type: ignore mid_channels = int(out_channels * expansion_ratio) sqz_channels = int(out_channels * squeeze_ratio) if p_stochastic_dropout: self.stochastic_depth = StochasticDepth(p_stochastic_dropout, mode="row") # type: ignore else: self.stochastic_depth = nn.Identity() # type: ignore _layers = OrderedDict() _layers["pre_norm"] = norm_layer(in_channels) _layers["conv_a"] = Conv2dNormActivation( in_channels, mid_channels, kernel_size=1, stride=1, padding=0, activation_layer=activation_layer, norm_layer=norm_layer, inplace=None, ) _layers["conv_b"] = Conv2dNormActivation( mid_channels, mid_channels, kernel_size=3, stride=stride, padding=1, activation_layer=activation_layer, norm_layer=norm_layer, groups=mid_channels, inplace=None, ) _layers["squeeze_excitation"] = SqueezeExcitation(mid_channels, sqz_channels, activation=nn.SiLU) _layers["conv_c"] = nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1, bias=True) self.layers = nn.Sequential(_layers) def forward(self, x: Tensor) -> Tensor: """ Args: x (Tensor): Input tensor with expected layout of [B, C, H, W]. Returns: Tensor: Output tensor with expected layout of [B, C, H / stride, W / stride]. """ res = self.proj(x) x = self.stochastic_depth(self.layers(x)) return res + x class RelativePositionalMultiHeadAttention(nn.Module): """Relative Positional Multi-Head Attention. Args: feat_dim (int): Number of input features. head_dim (int): Number of features per head. max_seq_len (int): Maximum sequence length. """ def __init__( self, feat_dim: int, head_dim: int, max_seq_len: int, ) -> None: super().__init__() if feat_dim % head_dim != 0: raise ValueError(f"feat_dim: {feat_dim} must be divisible by head_dim: {head_dim}") self.n_heads = feat_dim // head_dim self.head_dim = head_dim self.size = int(math.sqrt(max_seq_len)) self.max_seq_len = max_seq_len self.to_qkv = nn.Linear(feat_dim, self.n_heads * self.head_dim * 3) self.scale_factor = feat_dim**-0.5 self.merge = nn.Linear(self.head_dim * self.n_heads, feat_dim) self.relative_position_bias_table = nn.parameter.Parameter( torch.empty(((2 * self.size - 1) * (2 * self.size - 1), self.n_heads), dtype=torch.float32), ) self.register_buffer("relative_position_index", _get_relative_position_index(self.size, self.size)) # initialize with truncated normal the bias torch.nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) def get_relative_positional_bias(self) -> torch.Tensor: bias_index = self.relative_position_index.view(-1) # type: ignore relative_bias = self.relative_position_bias_table[bias_index].view(self.max_seq_len, self.max_seq_len, -1) # type: ignore relative_bias = relative_bias.permute(2, 0, 1).contiguous() return relative_bias.unsqueeze(0) def forward(self, x: Tensor) -> Tensor: """ Args: x (Tensor): Input tensor with expected layout of [B, G, P, D]. Returns: Tensor: Output tensor with expected layout of [B, G, P, D]. """ B, G, P, D = x.shape H, DH = self.n_heads, self.head_dim qkv = self.to_qkv(x) q, k, v = torch.chunk(qkv, 3, dim=-1) q = q.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4) k = k.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4) v = v.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4) k = k * self.scale_factor dot_prod = torch.einsum("B G H I D, B G H J D -> B G H I J", q, k) pos_bias = self.get_relative_positional_bias() dot_prod = F.softmax(dot_prod + pos_bias, dim=-1) out = torch.einsum("B G H I J, B G H J D -> B G H I D", dot_prod, v) out = out.permute(0, 1, 3, 2, 4).reshape(B, G, P, D) out = self.merge(out) return out class SwapAxes(nn.Module): """Permute the axes of a tensor.""" def __init__(self, a: int, b: int) -> None: super().__init__() self.a = a self.b = b def forward(self, x: torch.Tensor) -> torch.Tensor: res = torch.swapaxes(x, self.a, self.b) return res class WindowPartition(nn.Module): """ Partition the input tensor into non-overlapping windows. """ def __init__(self) -> None: super().__init__() def forward(self, x: Tensor, p: int) -> Tensor: """ Args: x (Tensor): Input tensor with expected layout of [B, C, H, W]. p (int): Number of partitions. Returns: Tensor: Output tensor with expected layout of [B, H/P, W/P, P*P, C]. """ B, C, H, W = x.shape P = p # chunk up H and W dimensions x = x.reshape(B, C, H // P, P, W // P, P) x = x.permute(0, 2, 4, 3, 5, 1) # colapse P * P dimension x = x.reshape(B, (H // P) * (W // P), P * P, C) return x class WindowDepartition(nn.Module): """ Departition the input tensor of non-overlapping windows into a feature volume of layout [B, C, H, W]. """ def __init__(self) -> None: super().__init__() def forward(self, x: Tensor, p: int, h_partitions: int, w_partitions: int) -> Tensor: """ Args: x (Tensor): Input tensor with expected layout of [B, (H/P * W/P), P*P, C]. p (int): Number of partitions. h_partitions (int): Number of vertical partitions. w_partitions (int): Number of horizontal partitions. Returns: Tensor: Output tensor with expected layout of [B, C, H, W]. """ B, G, PP, C = x.shape P = p HP, WP = h_partitions, w_partitions # split P * P dimension into 2 P tile dimensionsa x = x.reshape(B, HP, WP, P, P, C) # permute into B, C, HP, P, WP, P x = x.permute(0, 5, 1, 3, 2, 4) # reshape into B, C, H, W x = x.reshape(B, C, HP * P, WP * P) return x class PartitionAttentionLayer(nn.Module): """ Layer for partitioning the input tensor into non-overlapping windows and applying attention to each window. Args: in_channels (int): Number of input channels. head_dim (int): Dimension of each attention head. partition_size (int): Size of the partitions. partition_type (str): Type of partitioning to use. Can be either "grid" or "window". grid_size (Tuple[int, int]): Size of the grid to partition the input tensor into. mlp_ratio (int): Ratio of the feature size expansion in the MLP layer. activation_layer (Callable[..., nn.Module]): Activation function to use. norm_layer (Callable[..., nn.Module]): Normalization function to use. attention_dropout (float): Dropout probability for the attention layer. mlp_dropout (float): Dropout probability for the MLP layer. p_stochastic_dropout (float): Probability of dropping out a partition. """ def __init__( self, in_channels: int, head_dim: int, # partitioning parameters partition_size: int, partition_type: str, # grid size needs to be known at initialization time # because we need to know hamy relative offsets there are in the grid grid_size: Tuple[int, int], mlp_ratio: int, activation_layer: Callable[..., nn.Module], norm_layer: Callable[..., nn.Module], attention_dropout: float, mlp_dropout: float, p_stochastic_dropout: float, ) -> None: super().__init__() self.n_heads = in_channels // head_dim self.head_dim = head_dim self.n_partitions = grid_size[0] // partition_size self.partition_type = partition_type self.grid_size = grid_size if partition_type not in ["grid", "window"]: raise ValueError("partition_type must be either 'grid' or 'window'") if partition_type == "window": self.p, self.g = partition_size, self.n_partitions else: self.p, self.g = self.n_partitions, partition_size self.partition_op = WindowPartition() self.departition_op = WindowDepartition() self.partition_swap = SwapAxes(-2, -3) if partition_type == "grid" else nn.Identity() self.departition_swap = SwapAxes(-2, -3) if partition_type == "grid" else nn.Identity() self.attn_layer = nn.Sequential( norm_layer(in_channels), # it's always going to be partition_size ** 2 because # of the axis swap in the case of grid partitioning RelativePositionalMultiHeadAttention(in_channels, head_dim, partition_size**2), nn.Dropout(attention_dropout), ) # pre-normalization similar to transformer layers self.mlp_layer = nn.Sequential( nn.LayerNorm(in_channels), nn.Linear(in_channels, in_channels * mlp_ratio), activation_layer(), nn.Linear(in_channels * mlp_ratio, in_channels), nn.Dropout(mlp_dropout), ) # layer scale factors self.stochastic_dropout = StochasticDepth(p_stochastic_dropout, mode="row") def forward(self, x: Tensor) -> Tensor: """ Args: x (Tensor): Input tensor with expected layout of [B, C, H, W]. Returns: Tensor: Output tensor with expected layout of [B, C, H, W]. """ # Undefined behavior if H or W are not divisible by p # https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L766 gh, gw = self.grid_size[0] // self.p, self.grid_size[1] // self.p torch._assert( self.grid_size[0] % self.p == 0 and self.grid_size[1] % self.p == 0, "Grid size must be divisible by partition size. Got grid size of {} and partition size of {}".format( self.grid_size, self.p ), ) x = self.partition_op(x, self.p) x = self.partition_swap(x) x = x + self.stochastic_dropout(self.attn_layer(x)) x = x + self.stochastic_dropout(self.mlp_layer(x)) x = self.departition_swap(x) x = self.departition_op(x, self.p, gh, gw) return x class MaxVitLayer(nn.Module): """ MaxVit layer consisting of a MBConv layer followed by a PartitionAttentionLayer with `window` and a PartitionAttentionLayer with `grid`. Args: in_channels (int): Number of input channels. out_channels (int): Number of output channels. expansion_ratio (float): Expansion ratio in the bottleneck. squeeze_ratio (float): Squeeze ratio in the SE Layer. stride (int): Stride of the depthwise convolution. activation_layer (Callable[..., nn.Module]): Activation function. norm_layer (Callable[..., nn.Module]): Normalization function. head_dim (int): Dimension of the attention heads. mlp_ratio (int): Ratio of the MLP layer. mlp_dropout (float): Dropout probability for the MLP layer. attention_dropout (float): Dropout probability for the attention layer. p_stochastic_dropout (float): Probability of stochastic depth. partition_size (int): Size of the partitions. grid_size (Tuple[int, int]): Size of the input feature grid. """ def __init__( self, # conv parameters in_channels: int, out_channels: int, squeeze_ratio: float, expansion_ratio: float, stride: int, # conv + transformer parameters norm_layer: Callable[..., nn.Module], activation_layer: Callable[..., nn.Module], # transformer parameters head_dim: int, mlp_ratio: int, mlp_dropout: float, attention_dropout: float, p_stochastic_dropout: float, # partitioning parameters partition_size: int, grid_size: Tuple[int, int], ) -> None: super().__init__() layers: OrderedDict = OrderedDict() # convolutional layer layers["MBconv"] = MBConv( in_channels=in_channels, out_channels=out_channels, expansion_ratio=expansion_ratio, squeeze_ratio=squeeze_ratio, stride=stride, activation_layer=activation_layer, norm_layer=norm_layer, p_stochastic_dropout=p_stochastic_dropout, ) # attention layers, block -> grid layers["window_attention"] = PartitionAttentionLayer( in_channels=out_channels, head_dim=head_dim, partition_size=partition_size, partition_type="window", grid_size=grid_size, mlp_ratio=mlp_ratio, activation_layer=activation_layer, norm_layer=nn.LayerNorm, attention_dropout=attention_dropout, mlp_dropout=mlp_dropout, p_stochastic_dropout=p_stochastic_dropout, ) layers["grid_attention"] = PartitionAttentionLayer( in_channels=out_channels, head_dim=head_dim, partition_size=partition_size, partition_type="grid", grid_size=grid_size, mlp_ratio=mlp_ratio, activation_layer=activation_layer, norm_layer=nn.LayerNorm, attention_dropout=attention_dropout, mlp_dropout=mlp_dropout, p_stochastic_dropout=p_stochastic_dropout, ) self.layers = nn.Sequential(layers) def forward(self, x: Tensor) -> Tensor: """ Args: x (Tensor): Input tensor of shape (B, C, H, W). Returns: Tensor: Output tensor of shape (B, C, H, W). """ x = self.layers(x) return x class MaxVitBlock(nn.Module): """ A MaxVit block consisting of `n_layers` MaxVit layers. Args: in_channels (int): Number of input channels. out_channels (int): Number of output channels. expansion_ratio (float): Expansion ratio in the bottleneck. squeeze_ratio (float): Squeeze ratio in the SE Layer. activation_layer (Callable[..., nn.Module]): Activation function. norm_layer (Callable[..., nn.Module]): Normalization function. head_dim (int): Dimension of the attention heads. mlp_ratio (int): Ratio of the MLP layer. mlp_dropout (float): Dropout probability for the MLP layer. attention_dropout (float): Dropout probability for the attention layer. p_stochastic_dropout (float): Probability of stochastic depth. partition_size (int): Size of the partitions. input_grid_size (Tuple[int, int]): Size of the input feature grid. n_layers (int): Number of layers in the block. p_stochastic (List[float]): List of probabilities for stochastic depth for each layer. """ def __init__( self, # conv parameters in_channels: int, out_channels: int, squeeze_ratio: float, expansion_ratio: float, # conv + transformer parameters norm_layer: Callable[..., nn.Module], activation_layer: Callable[..., nn.Module], # transformer parameters head_dim: int, mlp_ratio: int, mlp_dropout: float, attention_dropout: float, # partitioning parameters partition_size: int, input_grid_size: Tuple[int, int], # number of layers n_layers: int, p_stochastic: List[float], ) -> None: super().__init__() if not len(p_stochastic) == n_layers: raise ValueError(f"p_stochastic must have length n_layers={n_layers}, got p_stochastic={p_stochastic}.") self.layers = nn.ModuleList() # account for the first stride of the first layer self.grid_size = _get_conv_output_shape(input_grid_size, kernel_size=3, stride=2, padding=1) for idx, p in enumerate(p_stochastic): stride = 2 if idx == 0 else 1 self.layers += [ MaxVitLayer( in_channels=in_channels if idx == 0 else out_channels, out_channels=out_channels, squeeze_ratio=squeeze_ratio, expansion_ratio=expansion_ratio, stride=stride, norm_layer=norm_layer, activation_layer=activation_layer, head_dim=head_dim, mlp_ratio=mlp_ratio, mlp_dropout=mlp_dropout, attention_dropout=attention_dropout, partition_size=partition_size, grid_size=self.grid_size, p_stochastic_dropout=p, ), ] def forward(self, x: Tensor) -> Tensor: """ Args: x (Tensor): Input tensor of shape (B, C, H, W). Returns: Tensor: Output tensor of shape (B, C, H, W). """ for layer in self.layers: x = layer(x) return x class MaxVit(nn.Module): """ Implements MaxVit Transformer from the `MaxViT: Multi-Axis Vision Transformer `_ paper. Args: input_size (Tuple[int, int]): Size of the input image. stem_channels (int): Number of channels in the stem. partition_size (int): Size of the partitions. block_channels (List[int]): Number of channels in each block. block_layers (List[int]): Number of layers in each block. stochastic_depth_prob (float): Probability of stochastic depth. Expands to a list of probabilities for each layer that scales linearly to the specified value. squeeze_ratio (float): Squeeze ratio in the SE Layer. Default: 0.25. expansion_ratio (float): Expansion ratio in the MBConv bottleneck. Default: 4. norm_layer (Callable[..., nn.Module]): Normalization function. Default: None (setting to None will produce a `BatchNorm2d(eps=1e-3, momentum=0.01)`). activation_layer (Callable[..., nn.Module]): Activation function Default: nn.GELU. head_dim (int): Dimension of the attention heads. mlp_ratio (int): Expansion ratio of the MLP layer. Default: 4. mlp_dropout (float): Dropout probability for the MLP layer. Default: 0.0. attention_dropout (float): Dropout probability for the attention layer. Default: 0.0. num_classes (int): Number of classes. Default: 1000. """ def __init__( self, # input size parameters input_size: Tuple[int, int], # stem and task parameters stem_channels: int, # partitioning parameters partition_size: int, # block parameters block_channels: List[int], block_layers: List[int], # attention head dimensions head_dim: int, stochastic_depth_prob: float, # conv + transformer parameters # norm_layer is applied only to the conv layers # activation_layer is applied both to conv and transformer layers norm_layer: Optional[Callable[..., nn.Module]] = None, activation_layer: Callable[..., nn.Module] = nn.GELU, # conv parameters squeeze_ratio: float = 0.25, expansion_ratio: float = 4, # transformer parameters mlp_ratio: int = 4, mlp_dropout: float = 0.0, attention_dropout: float = 0.0, # task parameters num_classes: int = 1000, ) -> None: super().__init__() _log_api_usage_once(self) input_channels = 3 # https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L1029-L1030 # for the exact parameters used in batchnorm if norm_layer is None: norm_layer = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.01) # Make sure input size will be divisible by the partition size in all blocks # Undefined behavior if H or W are not divisible by p # https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L766 block_input_sizes = _make_block_input_shapes(input_size, len(block_channels)) for idx, block_input_size in enumerate(block_input_sizes): if block_input_size[0] % partition_size != 0 or block_input_size[1] % partition_size != 0: raise ValueError( f"Input size {block_input_size} of block {idx} is not divisible by partition size {partition_size}. " f"Consider changing the partition size or the input size.\n" f"Current configuration yields the following block input sizes: {block_input_sizes}." ) # stem self.stem = nn.Sequential( Conv2dNormActivation( input_channels, stem_channels, 3, stride=2, norm_layer=norm_layer, activation_layer=activation_layer, bias=False, inplace=None, ), Conv2dNormActivation( stem_channels, stem_channels, 3, stride=1, norm_layer=None, activation_layer=None, bias=True ), ) # account for stem stride input_size = _get_conv_output_shape(input_size, kernel_size=3, stride=2, padding=1) self.partition_size = partition_size # blocks self.blocks = nn.ModuleList() in_channels = [stem_channels] + block_channels[:-1] out_channels = block_channels # precompute the stochastich depth probabilities from 0 to stochastic_depth_prob # since we have N blocks with L layers, we will have N * L probabilities uniformly distributed # over the range [0, stochastic_depth_prob] p_stochastic = np.linspace(0, stochastic_depth_prob, sum(block_layers)).tolist() p_idx = 0 for in_channel, out_channel, num_layers in zip(in_channels, out_channels, block_layers): self.blocks.append( MaxVitBlock( in_channels=in_channel, out_channels=out_channel, squeeze_ratio=squeeze_ratio, expansion_ratio=expansion_ratio, norm_layer=norm_layer, activation_layer=activation_layer, head_dim=head_dim, mlp_ratio=mlp_ratio, mlp_dropout=mlp_dropout, attention_dropout=attention_dropout, partition_size=partition_size, input_grid_size=input_size, n_layers=num_layers, p_stochastic=p_stochastic[p_idx : p_idx + num_layers], ), ) input_size = self.blocks[-1].grid_size p_idx += num_layers # see https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L1137-L1158 # for why there is Linear -> Tanh -> Linear self.classifier = nn.Sequential( nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.LayerNorm(block_channels[-1]), nn.Linear(block_channels[-1], block_channels[-1]), nn.Tanh(), nn.Linear(block_channels[-1], num_classes, bias=False), ) self._init_weights() def forward(self, x: Tensor) -> Tensor: x = self.stem(x) for block in self.blocks: x = block(x) x = self.classifier(x) return x def _init_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.normal_(m.weight, std=0.02) if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Linear): nn.init.normal_(m.weight, std=0.02) if m.bias is not None: nn.init.zeros_(m.bias) def _maxvit( # stem parameters stem_channels: int, # block parameters block_channels: List[int], block_layers: List[int], stochastic_depth_prob: float, # partitioning parameters partition_size: int, # transformer parameters head_dim: int, # Weights API weights: Optional[WeightsEnum] = None, progress: bool = False, # kwargs, **kwargs: Any, ) -> MaxVit: if weights is not None: _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) assert weights.meta["min_size"][0] == weights.meta["min_size"][1] _ovewrite_named_param(kwargs, "input_size", weights.meta["min_size"]) input_size = kwargs.pop("input_size", (224, 224)) model = MaxVit( stem_channels=stem_channels, block_channels=block_channels, block_layers=block_layers, stochastic_depth_prob=stochastic_depth_prob, head_dim=head_dim, partition_size=partition_size, input_size=input_size, **kwargs, ) if weights is not None: model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) return model class MaxVit_T_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( # URL empty until official release url="https://download.pytorch.org/models/maxvit_t-bc5ab103.pth", transforms=partial( ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC ), meta={ "categories": _IMAGENET_CATEGORIES, "num_params": 30919624, "min_size": (224, 224), "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#maxvit", "_metrics": { "ImageNet-1K": { "acc@1": 83.700, "acc@5": 96.722, } }, "_ops": 5.558, "_file_size": 118.769, "_docs": """These weights reproduce closely the results of the paper using a similar training recipe. They were trained with a BatchNorm2D momentum of 0.99 instead of the more correct 0.01.""", }, ) DEFAULT = IMAGENET1K_V1 @register_model() @handle_legacy_interface(weights=("pretrained", MaxVit_T_Weights.IMAGENET1K_V1)) def maxvit_t(*, weights: Optional[MaxVit_T_Weights] = None, progress: bool = True, **kwargs: Any) -> MaxVit: """ Constructs a maxvit_t architecture from `MaxViT: Multi-Axis Vision Transformer `_. Args: weights (:class:`~torchvision.models.MaxVit_T_Weights`, optional): The pretrained weights to use. See :class:`~torchvision.models.MaxVit_T_Weights` below for more details, and possible values. By default, no pre-trained weights are used. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. **kwargs: parameters passed to the ``torchvision.models.maxvit.MaxVit`` base class. Please refer to the `source code `_ for more details about this class. .. autoclass:: torchvision.models.MaxVit_T_Weights :members: """ weights = MaxVit_T_Weights.verify(weights) return _maxvit( stem_channels=64, block_channels=[64, 128, 256, 512], block_layers=[2, 2, 5, 2], head_dim=32, stochastic_depth_prob=0.2, partition_size=7, weights=weights, progress=progress, **kwargs, )