Traktor/myenv/Lib/site-packages/torchvision/models/segmentation/deeplabv3.py
2024-05-23 01:57:24 +02:00

391 lines
15 KiB
Python

from functools import partial
from typing import Any, Optional, Sequence
import torch
from torch import nn
from torch.nn import functional as F
from ...transforms._presets import SemanticSegmentation
from .._api import register_model, Weights, WeightsEnum
from .._meta import _VOC_CATEGORIES
from .._utils import _ovewrite_value_param, handle_legacy_interface, IntermediateLayerGetter
from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights, MobileNetV3
from ..resnet import ResNet, resnet101, ResNet101_Weights, resnet50, ResNet50_Weights
from ._utils import _SimpleSegmentationModel
from .fcn import FCNHead
__all__ = [
"DeepLabV3",
"DeepLabV3_ResNet50_Weights",
"DeepLabV3_ResNet101_Weights",
"DeepLabV3_MobileNet_V3_Large_Weights",
"deeplabv3_mobilenet_v3_large",
"deeplabv3_resnet50",
"deeplabv3_resnet101",
]
class DeepLabV3(_SimpleSegmentationModel):
"""
Implements DeepLabV3 model from
`"Rethinking Atrous Convolution for Semantic Image Segmentation"
<https://arxiv.org/abs/1706.05587>`_.
Args:
backbone (nn.Module): the network used to compute the features for the model.
The backbone should return an OrderedDict[Tensor], with the key being
"out" for the last feature map used, and "aux" if an auxiliary classifier
is used.
classifier (nn.Module): module that takes the "out" element returned from
the backbone and returns a dense prediction.
aux_classifier (nn.Module, optional): auxiliary classifier used during training
"""
pass
class DeepLabHead(nn.Sequential):
def __init__(self, in_channels: int, num_classes: int, atrous_rates: Sequence[int] = (12, 24, 36)) -> None:
super().__init__(
ASPP(in_channels, atrous_rates),
nn.Conv2d(256, 256, 3, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(),
nn.Conv2d(256, num_classes, 1),
)
class ASPPConv(nn.Sequential):
def __init__(self, in_channels: int, out_channels: int, dilation: int) -> None:
modules = [
nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
]
super().__init__(*modules)
class ASPPPooling(nn.Sequential):
def __init__(self, in_channels: int, out_channels: int) -> None:
super().__init__(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
size = x.shape[-2:]
for mod in self:
x = mod(x)
return F.interpolate(x, size=size, mode="bilinear", align_corners=False)
class ASPP(nn.Module):
def __init__(self, in_channels: int, atrous_rates: Sequence[int], out_channels: int = 256) -> None:
super().__init__()
modules = []
modules.append(
nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU())
)
rates = tuple(atrous_rates)
for rate in rates:
modules.append(ASPPConv(in_channels, out_channels, rate))
modules.append(ASPPPooling(in_channels, out_channels))
self.convs = nn.ModuleList(modules)
self.project = nn.Sequential(
nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
nn.Dropout(0.5),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
_res = []
for conv in self.convs:
_res.append(conv(x))
res = torch.cat(_res, dim=1)
return self.project(res)
def _deeplabv3_resnet(
backbone: ResNet,
num_classes: int,
aux: Optional[bool],
) -> DeepLabV3:
return_layers = {"layer4": "out"}
if aux:
return_layers["layer3"] = "aux"
backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
aux_classifier = FCNHead(1024, num_classes) if aux else None
classifier = DeepLabHead(2048, num_classes)
return DeepLabV3(backbone, classifier, aux_classifier)
_COMMON_META = {
"categories": _VOC_CATEGORIES,
"min_size": (1, 1),
"_docs": """
These weights were trained on a subset of COCO, using only the 20 categories that are present in the Pascal VOC
dataset.
""",
}
class DeepLabV3_ResNet50_Weights(WeightsEnum):
COCO_WITH_VOC_LABELS_V1 = Weights(
url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth",
transforms=partial(SemanticSegmentation, resize_size=520),
meta={
**_COMMON_META,
"num_params": 42004074,
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet50",
"_metrics": {
"COCO-val2017-VOC-labels": {
"miou": 66.4,
"pixel_acc": 92.4,
}
},
"_ops": 178.722,
"_file_size": 160.515,
},
)
DEFAULT = COCO_WITH_VOC_LABELS_V1
class DeepLabV3_ResNet101_Weights(WeightsEnum):
COCO_WITH_VOC_LABELS_V1 = Weights(
url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth",
transforms=partial(SemanticSegmentation, resize_size=520),
meta={
**_COMMON_META,
"num_params": 60996202,
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet101",
"_metrics": {
"COCO-val2017-VOC-labels": {
"miou": 67.4,
"pixel_acc": 92.4,
}
},
"_ops": 258.743,
"_file_size": 233.217,
},
)
DEFAULT = COCO_WITH_VOC_LABELS_V1
class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum):
COCO_WITH_VOC_LABELS_V1 = Weights(
url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth",
transforms=partial(SemanticSegmentation, resize_size=520),
meta={
**_COMMON_META,
"num_params": 11029328,
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_mobilenet_v3_large",
"_metrics": {
"COCO-val2017-VOC-labels": {
"miou": 60.3,
"pixel_acc": 91.2,
}
},
"_ops": 10.452,
"_file_size": 42.301,
},
)
DEFAULT = COCO_WITH_VOC_LABELS_V1
def _deeplabv3_mobilenetv3(
backbone: MobileNetV3,
num_classes: int,
aux: Optional[bool],
) -> DeepLabV3:
backbone = backbone.features
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
out_pos = stage_indices[-1] # use C5 which has output_stride = 16
out_inplanes = backbone[out_pos].out_channels
aux_pos = stage_indices[-4] # use C2 here which has output_stride = 8
aux_inplanes = backbone[aux_pos].out_channels
return_layers = {str(out_pos): "out"}
if aux:
return_layers[str(aux_pos)] = "aux"
backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
aux_classifier = FCNHead(aux_inplanes, num_classes) if aux else None
classifier = DeepLabHead(out_inplanes, num_classes)
return DeepLabV3(backbone, classifier, aux_classifier)
@register_model()
@handle_legacy_interface(
weights=("pretrained", DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def deeplabv3_resnet50(
*,
weights: Optional[DeepLabV3_ResNet50_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
aux_loss: Optional[bool] = None,
weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
**kwargs: Any,
) -> DeepLabV3:
"""Constructs a DeepLabV3 model with a ResNet-50 backbone.
.. betastatus:: segmentation module
Reference: `Rethinking Atrous Convolution for Semantic Image Segmentation <https://arxiv.org/abs/1706.05587>`__.
Args:
weights (:class:`~torchvision.models.segmentation.DeepLabV3_ResNet50_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.segmentation.DeepLabV3_ResNet50_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
num_classes (int, optional): number of output classes of the model (including the background)
aux_loss (bool, optional): If True, it uses an auxiliary loss
weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for the
backbone
**kwargs: unused
.. autoclass:: torchvision.models.segmentation.DeepLabV3_ResNet50_Weights
:members:
"""
weights = DeepLabV3_ResNet50_Weights.verify(weights)
weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
aux_loss = _ovewrite_value_param("aux_loss", aux_loss, True)
elif num_classes is None:
num_classes = 21
backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
@register_model()
@handle_legacy_interface(
weights=("pretrained", DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1),
weights_backbone=("pretrained_backbone", ResNet101_Weights.IMAGENET1K_V1),
)
def deeplabv3_resnet101(
*,
weights: Optional[DeepLabV3_ResNet101_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
aux_loss: Optional[bool] = None,
weights_backbone: Optional[ResNet101_Weights] = ResNet101_Weights.IMAGENET1K_V1,
**kwargs: Any,
) -> DeepLabV3:
"""Constructs a DeepLabV3 model with a ResNet-101 backbone.
.. betastatus:: segmentation module
Reference: `Rethinking Atrous Convolution for Semantic Image Segmentation <https://arxiv.org/abs/1706.05587>`__.
Args:
weights (:class:`~torchvision.models.segmentation.DeepLabV3_ResNet101_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.segmentation.DeepLabV3_ResNet101_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
num_classes (int, optional): number of output classes of the model (including the background)
aux_loss (bool, optional): If True, it uses an auxiliary loss
weights_backbone (:class:`~torchvision.models.ResNet101_Weights`, optional): The pretrained weights for the
backbone
**kwargs: unused
.. autoclass:: torchvision.models.segmentation.DeepLabV3_ResNet101_Weights
:members:
"""
weights = DeepLabV3_ResNet101_Weights.verify(weights)
weights_backbone = ResNet101_Weights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
aux_loss = _ovewrite_value_param("aux_loss", aux_loss, True)
elif num_classes is None:
num_classes = 21
backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
@register_model()
@handle_legacy_interface(
weights=("pretrained", DeepLabV3_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def deeplabv3_mobilenet_v3_large(
*,
weights: Optional[DeepLabV3_MobileNet_V3_Large_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
aux_loss: Optional[bool] = None,
weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
**kwargs: Any,
) -> DeepLabV3:
"""Constructs a DeepLabV3 model with a MobileNetV3-Large backbone.
Reference: `Rethinking Atrous Convolution for Semantic Image Segmentation <https://arxiv.org/abs/1706.05587>`__.
Args:
weights (:class:`~torchvision.models.segmentation.DeepLabV3_MobileNet_V3_Large_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.segmentation.DeepLabV3_MobileNet_V3_Large_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
num_classes (int, optional): number of output classes of the model (including the background)
aux_loss (bool, optional): If True, it uses an auxiliary loss
weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The pretrained weights
for the backbone
**kwargs: unused
.. autoclass:: torchvision.models.segmentation.DeepLabV3_MobileNet_V3_Large_Weights
:members:
"""
weights = DeepLabV3_MobileNet_V3_Large_Weights.verify(weights)
weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
aux_loss = _ovewrite_value_param("aux_loss", aux_loss, True)
elif num_classes is None:
num_classes = 21
backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True)
model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model