616 lines
28 KiB
Python
616 lines
28 KiB
Python
import math
|
|
from enum import Enum
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
import torch
|
|
from torch import Tensor
|
|
|
|
from . import functional as F, InterpolationMode
|
|
|
|
__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide", "AugMix"]
|
|
|
|
|
|
def _apply_op(
|
|
img: Tensor, op_name: str, magnitude: float, interpolation: InterpolationMode, fill: Optional[List[float]]
|
|
):
|
|
if op_name == "ShearX":
|
|
# magnitude should be arctan(magnitude)
|
|
# official autoaug: (1, level, 0, 0, 1, 0)
|
|
# https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290
|
|
# compared to
|
|
# torchvision: (1, tan(level), 0, 0, 1, 0)
|
|
# https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976
|
|
img = F.affine(
|
|
img,
|
|
angle=0.0,
|
|
translate=[0, 0],
|
|
scale=1.0,
|
|
shear=[math.degrees(math.atan(magnitude)), 0.0],
|
|
interpolation=interpolation,
|
|
fill=fill,
|
|
center=[0, 0],
|
|
)
|
|
elif op_name == "ShearY":
|
|
# magnitude should be arctan(magnitude)
|
|
# See above
|
|
img = F.affine(
|
|
img,
|
|
angle=0.0,
|
|
translate=[0, 0],
|
|
scale=1.0,
|
|
shear=[0.0, math.degrees(math.atan(magnitude))],
|
|
interpolation=interpolation,
|
|
fill=fill,
|
|
center=[0, 0],
|
|
)
|
|
elif op_name == "TranslateX":
|
|
img = F.affine(
|
|
img,
|
|
angle=0.0,
|
|
translate=[int(magnitude), 0],
|
|
scale=1.0,
|
|
interpolation=interpolation,
|
|
shear=[0.0, 0.0],
|
|
fill=fill,
|
|
)
|
|
elif op_name == "TranslateY":
|
|
img = F.affine(
|
|
img,
|
|
angle=0.0,
|
|
translate=[0, int(magnitude)],
|
|
scale=1.0,
|
|
interpolation=interpolation,
|
|
shear=[0.0, 0.0],
|
|
fill=fill,
|
|
)
|
|
elif op_name == "Rotate":
|
|
img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill)
|
|
elif op_name == "Brightness":
|
|
img = F.adjust_brightness(img, 1.0 + magnitude)
|
|
elif op_name == "Color":
|
|
img = F.adjust_saturation(img, 1.0 + magnitude)
|
|
elif op_name == "Contrast":
|
|
img = F.adjust_contrast(img, 1.0 + magnitude)
|
|
elif op_name == "Sharpness":
|
|
img = F.adjust_sharpness(img, 1.0 + magnitude)
|
|
elif op_name == "Posterize":
|
|
img = F.posterize(img, int(magnitude))
|
|
elif op_name == "Solarize":
|
|
img = F.solarize(img, magnitude)
|
|
elif op_name == "AutoContrast":
|
|
img = F.autocontrast(img)
|
|
elif op_name == "Equalize":
|
|
img = F.equalize(img)
|
|
elif op_name == "Invert":
|
|
img = F.invert(img)
|
|
elif op_name == "Identity":
|
|
pass
|
|
else:
|
|
raise ValueError(f"The provided operator {op_name} is not recognized.")
|
|
return img
|
|
|
|
|
|
class AutoAugmentPolicy(Enum):
    """Identifiers of the AutoAugment policies, one per dataset the policy was learned on.

    The available members are IMAGENET, CIFAR10 and SVHN.
    """

    # Each value is the lower-case name of the dataset.
    IMAGENET = "imagenet"
    CIFAR10 = "cifar10"
    SVHN = "svhn"
|
|
|
|
|
|
# FIXME: Eliminate the copy-pasted code for fill standardization and _augmentation_space() by moving it into a base class
|
|
class AutoAugment(torch.nn.Module):
    r"""Apply one randomly selected AutoAugment sub-policy to an image, following
    `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.

    A torch Tensor image should be of type torch.uint8 and is expected to have [..., 1 or 3, H, W] shape,
    where ... means an arbitrary number of leading dimensions.
    A PIL Image is expected to be in mode "L" or "RGB".

    Args:
        policy (AutoAugmentPolicy): Desired policy enum defined by
            :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    def __init__(
        self,
        policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
        interpolation: InterpolationMode = InterpolationMode.NEAREST,
        fill: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        self.policy = policy
        self.interpolation = interpolation
        self.fill = fill
        # Resolve the enum into its concrete list of sub-policies once, at construction time.
        self.policies = self._get_policies(policy)

    def _get_policies(
        self, policy: AutoAugmentPolicy
    ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
        """Return the sub-policies belonging to ``policy``.

        Every sub-policy is a pair of ``(op_name, probability, magnitude_bin)`` entries; the bin is
        ``None`` for operations that take no magnitude.

        Raises:
            ValueError: If ``policy`` is not IMAGENET, CIFAR10 or SVHN.
        """
        if policy == AutoAugmentPolicy.IMAGENET:
            return [
                (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
                (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
                (("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
                (("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
                (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
                (("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
                (("Rotate", 0.8, 8), ("Color", 0.4, 0)),
                (("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
                (("Equalize", 0.0, None), ("Equalize", 0.8, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Rotate", 0.8, 8), ("Color", 1.0, 2)),
                (("Color", 0.8, 8), ("Solarize", 0.8, 7)),
                (("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
                (("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
                (("Color", 0.4, 0), ("Equalize", 0.6, None)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
            ]

        if policy == AutoAugmentPolicy.CIFAR10:
            return [
                (("Invert", 0.1, None), ("Contrast", 0.2, 6)),
                (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
                (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
                (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
                (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
                (("Color", 0.4, 3), ("Brightness", 0.6, 7)),
                (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
                (("Equalize", 0.6, None), ("Equalize", 0.5, None)),
                (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
                (("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
                (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
                (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
                (("Brightness", 0.9, 6), ("Color", 0.2, 8)),
                (("Solarize", 0.5, 2), ("Invert", 0.0, None)),
                (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.2, None), ("Equalize", 0.6, None)),
                (("Color", 0.9, 9), ("Equalize", 0.6, None)),
                (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
                (("Brightness", 0.1, 3), ("Color", 0.7, 0)),
                (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
                (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
                (("Equalize", 0.8, None), ("Invert", 0.1, None)),
                (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
            ]

        if policy == AutoAugmentPolicy.SVHN:
            return [
                (("ShearX", 0.9, 4), ("Invert", 0.2, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.4, None)),
                (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
                (("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
                (("ShearY", 0.8, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
                (("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
                (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
                (("Invert", 0.6, None), ("Rotate", 0.8, 4)),
                (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
                (("ShearX", 0.1, 6), ("Invert", 0.6, None)),
                (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
                (("ShearY", 0.8, 4), ("Invert", 0.8, None)),
                (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
                (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
                (("ShearX", 0.7, 2), ("Invert", 0.1, None)),
            ]

        raise ValueError(f"The provided policy {policy} is not recognized.")

    def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]:
        """Build the table of candidate magnitudes for every supported operation.

        Args:
            num_bins (int): Number of magnitude values generated per operation.
            image_size (tuple of int): ``(height, width)`` of the image; scales the translations.

        Returns:
            Mapping from operation name to ``(magnitudes, signed)``. ``signed`` tells the caller that
            the magnitude may be negated. Operations without a magnitude carry a 0-dim placeholder.
        """
        # Largest translation along each axis, proportional to the image extent on that axis.
        max_shift_x = 150.0 / 331.0 * image_size[1]
        max_shift_y = 150.0 / 331.0 * image_size[0]
        # Number of bits removed per bin: grows from 0 to 4, so the kept bits go from 8 down to 4.
        dropped_bits = (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int()
        return {
            # op_name: (magnitudes, signed)
            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
            "TranslateX": (torch.linspace(0.0, max_shift_x, num_bins), True),
            "TranslateY": (torch.linspace(0.0, max_shift_y, num_bins), True),
            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
            "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
            "Color": (torch.linspace(0.0, 0.9, num_bins), True),
            "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
            "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
            "Posterize": (8 - dropped_bits, False),
            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
            "Invert": (torch.tensor(0.0), False),
        }

    @staticmethod
    def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]:
        """Draw the random parameters for one autoaugment transformation.

        Args:
            transform_num (int): Number of sub-policies to choose from.

        Returns:
            Tuple ``(policy_id, probs, signs)``: the index of the chosen sub-policy, two samples from
            ``torch.rand`` (one per operation of the sub-policy) and two random integers in {0, 1}.
        """
        chosen_policy = int(torch.randint(transform_num, (1,)).item())
        apply_probs = torch.rand((2,))
        sign_bits = torch.randint(2, (2,))

        return chosen_policy, apply_probs, sign_bits

    def forward(self, img: Tensor) -> Tensor:
        """
            img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: AutoAugmented image.
        """
        channels, height, width = F.get_dimensions(img)
        fill = self.fill
        if isinstance(img, Tensor):
            # Tensor inputs get the fill expanded to a list of floats.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            elif fill is not None:
                fill = [float(value) for value in fill]

        transform_id, probs, signs = self.get_params(len(self.policies))

        # AutoAugment policies index into 10 magnitude bins.
        op_meta = self._augmentation_space(10, (height, width))
        for i, (op_name, p, magnitude_id) in enumerate(self.policies[transform_id]):
            # Each operation of the sub-policy is applied only when its random draw is within p.
            if probs[i] <= p:
                magnitudes, signed = op_meta[op_name]
                magnitude = 0.0
                if magnitude_id is not None:
                    magnitude = float(magnitudes[magnitude_id].item())
                # A zero sign bit flips the direction of signed operations.
                if signed and signs[i] == 0:
                    magnitude *= -1.0
                img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)

        return img

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(policy={self.policy}, fill={self.fill})"
|
|
|
|
|
|
class RandAugment(torch.nn.Module):
    r"""Apply ``num_ops`` randomly chosen operations at a fixed magnitude, following
    `"RandAugment: Practical automated data augmentation with a reduced search space"
    <https://arxiv.org/abs/1909.13719>`_.

    A torch Tensor image should be of type torch.uint8 and is expected to have [..., 1 or 3, H, W] shape,
    where ... means an arbitrary number of leading dimensions.
    A PIL Image is expected to be in mode "L" or "RGB".

    Args:
        num_ops (int): Number of augmentation transformations to apply sequentially.
        magnitude (int): Magnitude for all the transformations. It is used as an index into the
            ``num_magnitude_bins`` candidate values of each operation.
        num_magnitude_bins (int): The number of different magnitude values.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    def __init__(
        self,
        num_ops: int = 2,
        magnitude: int = 9,
        num_magnitude_bins: int = 31,
        interpolation: InterpolationMode = InterpolationMode.NEAREST,
        fill: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        self.num_ops = num_ops
        self.magnitude = magnitude
        self.num_magnitude_bins = num_magnitude_bins
        self.interpolation = interpolation
        self.fill = fill

    def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]:
        """Build the table of candidate magnitudes for every supported operation.

        Args:
            num_bins (int): Number of magnitude values generated per operation.
            image_size (tuple of int): ``(height, width)`` of the image; scales the translations.

        Returns:
            Mapping from operation name to ``(magnitudes, signed)``. ``signed`` tells the caller that
            the magnitude may be negated. Operations without a magnitude carry a 0-dim placeholder.
        """
        # Largest translation along each axis, proportional to the image extent on that axis.
        max_shift_x = 150.0 / 331.0 * image_size[1]
        max_shift_y = 150.0 / 331.0 * image_size[0]
        # Number of bits removed per bin: grows from 0 to 4, so the kept bits go from 8 down to 4.
        dropped_bits = (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int()
        return {
            # op_name: (magnitudes, signed)
            "Identity": (torch.tensor(0.0), False),
            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
            "TranslateX": (torch.linspace(0.0, max_shift_x, num_bins), True),
            "TranslateY": (torch.linspace(0.0, max_shift_y, num_bins), True),
            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
            "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
            "Color": (torch.linspace(0.0, 0.9, num_bins), True),
            "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
            "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
            "Posterize": (8 - dropped_bits, False),
            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
        }

    def forward(self, img: Tensor) -> Tensor:
        """
            img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        channels, height, width = F.get_dimensions(img)
        fill = self.fill
        if isinstance(img, Tensor):
            # Tensor inputs get the fill expanded to a list of floats.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            elif fill is not None:
                fill = [float(value) for value in fill]

        op_meta = self._augmentation_space(self.num_magnitude_bins, (height, width))
        # Operations are picked by position, so the key order of the table matters.
        op_names = list(op_meta.keys())
        for _ in range(self.num_ops):
            picked = int(torch.randint(len(op_meta), (1,)).item())
            op_name = op_names[picked]
            magnitudes, signed = op_meta[op_name]
            # A 0-dim entry marks an operation without a magnitude.
            if magnitudes.ndim > 0:
                magnitude = float(magnitudes[self.magnitude].item())
            else:
                magnitude = 0.0
            # The sign is only sampled for signed operations.
            if signed and torch.randint(2, (1,)):
                magnitude *= -1.0
            img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)

        return img

    def __repr__(self) -> str:
        settings = ", ".join(
            [
                f"num_ops={self.num_ops}",
                f"magnitude={self.magnitude}",
                f"num_magnitude_bins={self.num_magnitude_bins}",
                f"interpolation={self.interpolation}",
                f"fill={self.fill}",
            ]
        )
        return f"{self.__class__.__name__}({settings})"
|
|
|
|
|
|
class TrivialAugmentWide(torch.nn.Module):
    r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in
    `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`_.

    Each call applies a single randomly chosen operation with a randomly chosen magnitude.
    A torch Tensor image should be of type torch.uint8 and is expected to have [..., 1 or 3, H, W] shape,
    where ... means an arbitrary number of leading dimensions.
    A PIL Image is expected to be in mode "L" or "RGB".

    Args:
        num_magnitude_bins (int): The number of different magnitude values.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    def __init__(
        self,
        num_magnitude_bins: int = 31,
        interpolation: InterpolationMode = InterpolationMode.NEAREST,
        fill: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        self.num_magnitude_bins = num_magnitude_bins
        self.interpolation = interpolation
        self.fill = fill

    def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
        """Build the table of candidate magnitudes for every supported operation.

        Args:
            num_bins (int): Number of magnitude values generated per operation.

        Returns:
            Mapping from operation name to ``(magnitudes, signed)``. ``signed`` tells the caller that
            the magnitude may be negated. Operations without a magnitude carry a 0-dim placeholder.
        """
        # Number of bits removed per bin: grows from 0 to 6, so the kept bits go from 8 down to 2.
        dropped_bits = (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int()
        return {
            # op_name: (magnitudes, signed)
            "Identity": (torch.tensor(0.0), False),
            "ShearX": (torch.linspace(0.0, 0.99, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.99, num_bins), True),
            "TranslateX": (torch.linspace(0.0, 32.0, num_bins), True),
            "TranslateY": (torch.linspace(0.0, 32.0, num_bins), True),
            "Rotate": (torch.linspace(0.0, 135.0, num_bins), True),
            "Brightness": (torch.linspace(0.0, 0.99, num_bins), True),
            "Color": (torch.linspace(0.0, 0.99, num_bins), True),
            "Contrast": (torch.linspace(0.0, 0.99, num_bins), True),
            "Sharpness": (torch.linspace(0.0, 0.99, num_bins), True),
            "Posterize": (8 - dropped_bits, False),
            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
        }

    def forward(self, img: Tensor) -> Tensor:
        """
            img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        channels, height, width = F.get_dimensions(img)
        fill = self.fill
        if isinstance(img, Tensor):
            # Tensor inputs get the fill expanded to a list of floats.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            elif fill is not None:
                fill = [float(value) for value in fill]

        op_meta = self._augmentation_space(self.num_magnitude_bins)
        # The operation is picked by position, so the key order of the table matters.
        picked = int(torch.randint(len(op_meta), (1,)).item())
        op_name = list(op_meta.keys())[picked]
        magnitudes, signed = op_meta[op_name]
        # A 0-dim entry marks an operation without a magnitude; otherwise a random bin is drawn.
        if magnitudes.ndim > 0:
            bin_index = torch.randint(len(magnitudes), (1,), dtype=torch.long)
            magnitude = float(magnitudes[bin_index].item())
        else:
            magnitude = 0.0
        # The sign is only sampled for signed operations.
        if signed and torch.randint(2, (1,)):
            magnitude *= -1.0

        return _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)

    def __repr__(self) -> str:
        settings = ", ".join(
            [
                f"num_magnitude_bins={self.num_magnitude_bins}",
                f"interpolation={self.interpolation}",
                f"fill={self.fill}",
            ]
        )
        return f"{self.__class__.__name__}({settings})"
|
|
|
|
|
|
class AugMix(torch.nn.Module):
    r"""AugMix data augmentation method based on
    `"AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty" <https://arxiv.org/abs/1912.02781>`_.

    Several chains of random operations are applied to the image and blended with the original using
    randomly sampled weights.
    A torch Tensor image should be of type torch.uint8 and is expected to have [..., 1 or 3, H, W] shape,
    where ... means an arbitrary number of leading dimensions.
    A PIL Image is expected to be in mode "L" or "RGB".

    Args:
        severity (int): The severity of base augmentation operators; must lie in [1, 10]. Default is ``3``.
        mixture_width (int): The number of augmentation chains. Default is ``3``.
        chain_depth (int): The depth of augmentation chains. Any value that is not positive denotes a
            stochastic depth sampled from the interval [1, 3]. Default is ``-1``.
        alpha (float): The hyperparameter for the probability distributions. Default is ``1.0``.
        all_ops (bool): Use all operations (including brightness, contrast, color and sharpness). Default is ``True``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.

    Raises:
        ValueError: If ``severity`` is outside [1, 10].
    """

    def __init__(
        self,
        severity: int = 3,
        mixture_width: int = 3,
        chain_depth: int = -1,
        alpha: float = 1.0,
        all_ops: bool = True,
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
        fill: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        # Upper bound for the severity; also the number of magnitude bins used by forward().
        self._PARAMETER_MAX = 10
        if not (1 <= severity <= self._PARAMETER_MAX):
            raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.")
        self.severity = severity
        self.mixture_width = mixture_width
        self.chain_depth = chain_depth
        self.alpha = alpha
        self.all_ops = all_ops
        self.interpolation = interpolation
        self.fill = fill

    def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]:
        """Build the table of candidate magnitudes for every enabled operation.

        Args:
            num_bins (int): Number of magnitude values generated per operation.
            image_size (tuple of int): ``(height, width)`` of the image; scales the translations.

        Returns:
            Mapping from operation name to ``(magnitudes, signed)``. ``signed`` tells the caller that
            the magnitude may be negated. Operations without a magnitude carry a 0-dim placeholder.
            The brightness, color, contrast and sharpness entries are present only when ``all_ops`` is set.
        """
        # Largest translation along each axis: a third of the image extent on that axis.
        max_shift_x = image_size[1] / 3.0
        max_shift_y = image_size[0] / 3.0
        # Number of bits removed per bin: grows from 0 to 4, so the kept bits go from 4 down to 0.
        dropped_bits = (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int()
        space = {
            # op_name: (magnitudes, signed)
            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
            "TranslateX": (torch.linspace(0.0, max_shift_x, num_bins), True),
            "TranslateY": (torch.linspace(0.0, max_shift_y, num_bins), True),
            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
            "Posterize": (4 - dropped_bits, False),
            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
        }
        if self.all_ops:
            space.update(
                {
                    "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
                    "Color": (torch.linspace(0.0, 0.9, num_bins), True),
                    "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
                    "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
                }
            )
        return space

    @torch.jit.unused
    def _pil_to_tensor(self, img) -> Tensor:
        """Convert a PIL image to a tensor via ``F.pil_to_tensor``."""
        return F.pil_to_tensor(img)

    @torch.jit.unused
    def _tensor_to_pil(self, img: Tensor):
        """Convert a tensor back to a PIL image via ``F.to_pil_image``."""
        return F.to_pil_image(img)

    def _sample_dirichlet(self, params: Tensor) -> Tensor:
        """Draw a sample from the Dirichlet distribution described by ``params``."""
        # Must be on a separate method so that we can overwrite it in tests.
        return torch._sample_dirichlet(params)

    def forward(self, orig_img: Tensor) -> Tensor:
        """
            orig_img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        channels, height, width = F.get_dimensions(orig_img)
        fill = self.fill
        if isinstance(orig_img, Tensor):
            img = orig_img
            # Tensor inputs get the fill expanded to a list of floats.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            elif fill is not None:
                fill = [float(value) for value in fill]
        else:
            # PIL inputs are processed as tensors and converted back before returning.
            img = self._pil_to_tensor(orig_img)

        op_meta = self._augmentation_space(self._PARAMETER_MAX, (height, width))
        # Operations are picked by position, so the key order of the table matters.
        op_names = list(op_meta.keys())

        # View the image with at least 4 dimensions by prepending singleton ones.
        orig_dims = list(img.shape)
        batch = img.view([1] * max(4 - img.ndim, 0) + orig_dims)
        # Shape used to broadcast one weight per leading-dimension entry over the remaining dimensions.
        batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)
        num_samples = batch_dims[0]

        # Sample the beta weights for combining the original and augmented image. To get Beta, we use a Dirichlet
        # with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of augmented image.
        beta_params = torch.tensor([self.alpha, self.alpha], device=batch.device).expand(num_samples, -1)
        m = self._sample_dirichlet(beta_params)

        # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images.
        chain_params = torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(num_samples, -1)
        combined_weights = self._sample_dirichlet(chain_params) * m[:, 1].view([num_samples, -1])

        mix = m[:, 0].view(batch_dims) * batch
        for i in range(self.mixture_width):
            aug = batch
            # A non-positive chain_depth means the depth is drawn at random from {1, 2, 3}.
            if self.chain_depth > 0:
                depth = self.chain_depth
            else:
                depth = int(torch.randint(low=1, high=4, size=(1,)).item())
            for _ in range(depth):
                picked = int(torch.randint(len(op_meta), (1,)).item())
                op_name = op_names[picked]
                magnitudes, signed = op_meta[op_name]
                # A 0-dim entry marks an operation without a magnitude; otherwise one of the first
                # `severity` bins is drawn.
                if magnitudes.ndim > 0:
                    bin_index = torch.randint(self.severity, (1,), dtype=torch.long)
                    magnitude = float(magnitudes[bin_index].item())
                else:
                    magnitude = 0.0
                # The sign is only sampled for signed operations.
                if signed and torch.randint(2, (1,)):
                    magnitude *= -1.0
                aug = _apply_op(aug, op_name, magnitude, interpolation=self.interpolation, fill=fill)
            mix.add_(combined_weights[:, i].view(batch_dims) * aug)
        # Restore the caller's shape and dtype.
        mix = mix.view(orig_dims).to(dtype=img.dtype)

        if not isinstance(orig_img, Tensor):
            return self._tensor_to_pil(mix)
        return mix

    def __repr__(self) -> str:
        settings = ", ".join(
            [
                f"severity={self.severity}",
                f"mixture_width={self.mixture_width}",
                f"chain_depth={self.chain_depth}",
                f"alpha={self.alpha}",
                f"all_ops={self.all_ops}",
                f"interpolation={self.interpolation}",
                f"fill={self.fill}",
            ]
        )
        return f"{self.__class__.__name__}({settings})"
|