import collections
import math
import pathlib
import warnings
from itertools import repeat
from types import FunctionType
from typing import Any, BinaryIO, List, Optional, Tuple, Union

import numpy as np
import torch
from PIL import Image, ImageColor, ImageDraw, ImageFont


__all__ = [
    "make_grid",
    "save_image",
    "draw_bounding_boxes",
    "draw_segmentation_masks",
    "draw_keypoints",
    "flow_to_image",
]


@torch.no_grad()
def make_grid(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    nrow: int = 8,
    padding: int = 2,
    normalize: bool = False,
    value_range: Optional[Tuple[int, int]] = None,
    scale_each: bool = False,
    pad_value: float = 0.0,
) -> torch.Tensor:
    """
    Make a grid of images.

    Args:
        tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
            or a list of images all of the same size.
        nrow (int, optional): Number of images displayed in each row of the grid.
            The final grid size is ``(B / nrow, nrow)``. Default: ``8``.
        padding (int, optional): amount of padding. Default: ``2``.
        normalize (bool, optional): If True, shift the image to the range (0, 1),
            by the min and max values specified by ``value_range``. Default: ``False``.
        value_range (tuple, optional): tuple (min, max) where min and max are numbers.
            These values are used to normalize the image. By default, min and max
            are computed from the tensor.
        scale_each (bool, optional): If ``True``, scale each image in the batch of
            images separately rather than using the (min, max) over all images. Default: ``False``.
        pad_value (float, optional): Value for the padded pixels. Default: ``0``.

    Returns:
        grid (Tensor): the tensor containing grid of images.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(make_grid)
    if not torch.is_tensor(tensor):
        if isinstance(tensor, list):
            for t in tensor:
                if not torch.is_tensor(t):
                    raise TypeError(f"tensor or list of tensors expected, got a list containing {type(t)}")
        else:
            raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}")

    # if list of tensors, convert to a 4D mini-batch Tensor
    if isinstance(tensor, list):
        tensor = torch.stack(tensor, dim=0)

    if tensor.dim() == 2:  # single image H x W
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 3:  # single image
        if tensor.size(0) == 1:  # if single-channel, convert to 3-channel
            tensor = torch.cat((tensor, tensor, tensor), 0)
        tensor = tensor.unsqueeze(0)

    if tensor.dim() == 4 and tensor.size(1) == 1:  # single-channel images
        tensor = torch.cat((tensor, tensor, tensor), 1)

    if normalize is True:
        tensor = tensor.clone()  # avoid modifying tensor in-place
        if value_range is not None and not isinstance(value_range, tuple):
            raise TypeError("value_range has to be a tuple (min, max) if specified. min and max are numbers")

        def norm_ip(img, low, high):
            img.clamp_(min=low, max=high)
            img.sub_(low).div_(max(high - low, 1e-5))

        def norm_range(t, value_range):
            if value_range is not None:
                norm_ip(t, value_range[0], value_range[1])
            else:
                norm_ip(t, float(t.min()), float(t.max()))

        if scale_each is True:
            for t in tensor:  # loop over mini-batch dimension
                norm_range(t, value_range)
        else:
            norm_range(tensor, value_range)

    if not isinstance(tensor, torch.Tensor):
        raise TypeError("tensor should be of type torch.Tensor")
    if tensor.size(0) == 1:
        return tensor.squeeze(0)

    # make the mini-batch of images into a grid
    nmaps = tensor.size(0)
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
    num_channels = tensor.size(1)
    grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
    k = 0
    for y in range(ymaps):
        for x in range(xmaps):
            if k >= nmaps:
                break
            # Tensor.copy_() is a valid method but seems to be missing from the stubs
            # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_
            grid.narrow(1, y * height + padding, height - padding).narrow(  # type: ignore[attr-defined]
                2, x * width + padding, width - padding
            ).copy_(tensor[k])
            k = k + 1
    return grid
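

# Usage sketch (illustrative, not part of the original module): shows how a
# random mini-batch can be tiled with make_grid. Shapes and values are
# hypothetical, chosen only to demonstrate the call.
def _demo_make_grid() -> torch.Tensor:
    batch = torch.rand(16, 3, 64, 64)  # 16 RGB images, values already in [0, 1]
    # 4 images per row -> a 4x4 grid; each 64px tile gains 2px padding,
    # giving an output of shape (3, 4 * 66 + 2, 4 * 66 + 2) = (3, 266, 266).
    return make_grid(batch, nrow=4, padding=2)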


@torch.no_grad()
def save_image(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    fp: Union[str, pathlib.Path, BinaryIO],
    format: Optional[str] = None,
    **kwargs,
) -> None:
    """
    Save a given Tensor into an image file.

    Args:
        tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
            saves the tensor as a grid of images by calling ``make_grid``.
        fp (string or file object): A filename or a file object.
        format (Optional): If omitted, the format to use is determined from the filename extension.
            If a file object was used instead of a filename, this parameter should always be used.
        **kwargs: Other arguments are documented in ``make_grid``.
    """

    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(save_image)
    grid = make_grid(tensor, **kwargs)
    # Add 0.5 after unnormalizing to [0, 255] to round to the nearest integer
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    im = Image.fromarray(ndarr)
    im.save(fp, format=format)
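

# Usage sketch (illustrative, not part of the original module): save_image
# forwards grid options to make_grid, so a whole batch can be written as one
# tiled image. The file name is hypothetical.
def _demo_save_image() -> None:
    batch = torch.rand(8, 3, 32, 32)  # float images in [0, 1]
    save_image(batch, "grid.png", nrow=4)  # values are scaled to [0, 255] internally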


@torch.no_grad()
def draw_bounding_boxes(
    image: torch.Tensor,
    boxes: torch.Tensor,
    labels: Optional[List[str]] = None,
    colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
    fill: Optional[bool] = False,
    width: int = 1,
    font: Optional[str] = None,
    font_size: Optional[int] = None,
) -> torch.Tensor:
    """
    Draws bounding boxes on given image.
    The values of the input image should be uint8 between 0 and 255.
    If fill is True, the resulting Tensor should be saved as a PNG image.

    Args:
        image (Tensor): Tensor of shape (C x H x W) and dtype uint8.
        boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that
            the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
            `0 <= ymin < ymax < H`.
        labels (List[str]): List containing the labels of bounding boxes.
        colors (color or list of colors, optional): List containing the colors
            of the boxes or single color for all boxes. The color can be represented as
            PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
            By default, random colors are generated for boxes.
        fill (bool): If `True` fills the bounding box with specified color.
        width (int): Width of bounding box.
        font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may
            also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`,
            `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
        font_size (int): The requested font size in points.

    Returns:
        img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted.
    """

    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(draw_bounding_boxes)
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Tensor expected, got {type(image)}")
    elif image.dtype != torch.uint8:
        raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
    elif image.dim() != 3:
        raise ValueError("Pass individual images, not batches")
    elif image.size(0) not in {1, 3}:
        raise ValueError("Only grayscale and RGB images are supported")
    elif (boxes[:, 0] > boxes[:, 2]).any() or (boxes[:, 1] > boxes[:, 3]).any():
        raise ValueError(
            "Boxes need to be in (xmin, ymin, xmax, ymax) format. Use torchvision.ops.box_convert to convert them"
        )

    num_boxes = boxes.shape[0]

    if num_boxes == 0:
        warnings.warn("boxes doesn't contain any box. No box was drawn")
        return image

    if labels is None:
        labels: Union[List[str], List[None]] = [None] * num_boxes  # type: ignore[no-redef]
    elif len(labels) != num_boxes:
        raise ValueError(
            f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. Please specify labels for each box."
        )

    colors = _parse_colors(colors, num_objects=num_boxes)

    if font is None:
        if font_size is not None:
            warnings.warn("Argument 'font_size' will be ignored since 'font' is not set.")
        txt_font = ImageFont.load_default()
    else:
        txt_font = ImageFont.truetype(font=font, size=font_size or 10)

    # Handle Grayscale images
    if image.size(0) == 1:
        image = torch.tile(image, (3, 1, 1))

    ndarr = image.permute(1, 2, 0).cpu().numpy()
    img_to_draw = Image.fromarray(ndarr)
    img_boxes = boxes.to(torch.int64).tolist()

    if fill:
        draw = ImageDraw.Draw(img_to_draw, "RGBA")
    else:
        draw = ImageDraw.Draw(img_to_draw)

    for bbox, color, label in zip(img_boxes, colors, labels):  # type: ignore[arg-type]
        if fill:
            fill_color = color + (100,)
            draw.rectangle(bbox, width=width, outline=color, fill=fill_color)
        else:
            draw.rectangle(bbox, width=width, outline=color)

        if label is not None:
            margin = width + 1
            draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font)

    return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)
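

# Usage sketch (illustrative, not part of the original module): draw two
# labeled boxes on a blank uint8 image. Coordinates, labels, and colors are
# hypothetical.
def _demo_draw_bounding_boxes() -> torch.Tensor:
    image = torch.zeros(3, 100, 100, dtype=torch.uint8)
    boxes = torch.tensor([[10, 10, 50, 50], [60, 20, 90, 80]], dtype=torch.float)
    return draw_bounding_boxes(image, boxes, labels=["cat", "dog"], colors=["red", "#00FF00"], width=2)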


@torch.no_grad()
def draw_segmentation_masks(
    image: torch.Tensor,
    masks: torch.Tensor,
    alpha: float = 0.8,
    colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
) -> torch.Tensor:
    """
    Draws segmentation masks on given RGB image.
    The image values should be uint8 in [0, 255] or float in [0, 1].

    Args:
        image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float.
        masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool.
        alpha (float): Float number between 0 and 1 denoting the transparency of the masks.
            0 means full transparency, 1 means no transparency.
        colors (color or list of colors, optional): List containing the colors
            of the masks or single color for all masks. The color can be represented as
            PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
            By default, random colors are generated for each mask.

    Returns:
        img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top.
    """

    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(draw_segmentation_masks)
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"The image must be a tensor, got {type(image)}")
    elif not (image.dtype == torch.uint8 or image.is_floating_point()):
        raise ValueError(f"The image dtype must be uint8 or float, got {image.dtype}")
    elif image.dim() != 3:
        raise ValueError("Pass individual images, not batches")
    elif image.size()[0] != 3:
        raise ValueError("Pass an RGB image. Other Image formats are not supported")
    if masks.ndim == 2:
        masks = masks[None, :, :]
    if masks.ndim != 3:
        raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)")
    if masks.dtype != torch.bool:
        raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}")
    if masks.shape[-2:] != image.shape[-2:]:
        raise ValueError("The image and the masks must have the same height and width")

    num_masks = masks.size()[0]
    overlapping_masks = masks.sum(dim=0) > 1

    if num_masks == 0:
        warnings.warn("masks doesn't contain any mask. No mask was drawn")
        return image

    original_dtype = image.dtype
    colors = [
        torch.tensor(color, dtype=original_dtype, device=image.device)
        for color in _parse_colors(colors, num_objects=num_masks, dtype=original_dtype)
    ]

    img_to_draw = image.detach().clone()
    # TODO: There might be a way to vectorize this
    for mask, color in zip(masks, colors):
        img_to_draw[:, mask] = color[:, None]

    img_to_draw[:, overlapping_masks] = 0

    out = image * (1 - alpha) + img_to_draw * alpha
    # Note: at this point, out is a float tensor in [0, 1] or [0, 255] depending on original_dtype
    return out.to(original_dtype)
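

# Usage sketch (illustrative, not part of the original module): overlay one
# boolean mask on a gray uint8 image with 50% opacity. The mask geometry is
# hypothetical.
def _demo_draw_segmentation_masks() -> torch.Tensor:
    image = torch.full((3, 64, 64), 128, dtype=torch.uint8)
    masks = torch.zeros(1, 64, 64, dtype=torch.bool)
    masks[0, 16:48, 16:48] = True  # a square region to highlight
    return draw_segmentation_masks(image, masks, alpha=0.5, colors="blue")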


@torch.no_grad()
def draw_keypoints(
    image: torch.Tensor,
    keypoints: torch.Tensor,
    connectivity: Optional[List[Tuple[int, int]]] = None,
    colors: Optional[Union[str, Tuple[int, int, int]]] = None,
    radius: int = 2,
    width: int = 3,
    visibility: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Draws Keypoints on given RGB image.
    The image values should be uint8 in [0, 255] or float in [0, 1].
    Keypoints can be drawn for multiple instances at a time.

    This method draws keypoints and their connections based on the visibility of each keypoint.

    Args:
        image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float.
        keypoints (Tensor): Tensor of shape (num_instances, K, 2), the K keypoint locations for each of the
            N instances, in the format [x, y].
        connectivity (List[Tuple[int, int]]): A list of tuples, where each tuple contains a pair of keypoints
            to be connected.
            If at least one of the two connected keypoints has a ``visibility`` of False,
            this specific connection is not drawn.
            Exclusions due to invisibility are computed per-instance.
        colors (str, Tuple): The color can be represented as
            PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
        radius (int): Integer denoting radius of keypoint.
        width (int): Integer denoting width of line connecting keypoints.
        visibility (Tensor): Tensor of shape (num_instances, K) specifying the visibility of the K
            keypoints for each of the N instances.
            True means that the respective keypoint is visible and should be drawn.
            False means invisible, so neither the point nor possible connections containing it are drawn.
            The input tensor will be cast to bool.
            Default ``None`` means that all the keypoints are visible.
            For more details, see :ref:`draw_keypoints_with_visibility`.

    Returns:
        img (Tensor[C, H, W]): Image Tensor with keypoints drawn.
    """

    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(draw_keypoints)
    # validate image
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"The image must be a tensor, got {type(image)}")
    elif not (image.dtype == torch.uint8 or image.is_floating_point()):
        raise ValueError(f"The image dtype must be uint8 or float, got {image.dtype}")
    elif image.dim() != 3:
        raise ValueError("Pass individual images, not batches")
    elif image.size()[0] != 3:
        raise ValueError("Pass an RGB image. Other Image formats are not supported")

    # validate keypoints
    if keypoints.ndim != 3:
        raise ValueError("keypoints must be of shape (num_instances, K, 2)")

    # validate visibility
    if visibility is None:  # set default
        visibility = torch.ones(keypoints.shape[:-1], dtype=torch.bool)
    # If the last dimension is 1, e.g., after calling split([2, 1], dim=-1) on the output of a keypoint-prediction
    # model, make sure visibility has shape (num_instances, K).
    # Iff K = 1, this has unwanted behavior, but K=1 does not really make sense in the first place.
    visibility = visibility.squeeze(-1)
    if visibility.ndim != 2:
        raise ValueError(f"visibility must be of shape (num_instances, K). Got ndim={visibility.ndim}")
    if visibility.shape != keypoints.shape[:-1]:
        raise ValueError(
            "keypoints and visibility must have the same dimensionality for num_instances and K. "
            f"Got {visibility.shape = } and {keypoints.shape = }"
        )

    original_dtype = image.dtype
    if original_dtype.is_floating_point:
        from torchvision.transforms.v2.functional import to_dtype  # noqa

        image = to_dtype(image, dtype=torch.uint8, scale=True)

    ndarr = image.permute(1, 2, 0).cpu().numpy()
    img_to_draw = Image.fromarray(ndarr)
    draw = ImageDraw.Draw(img_to_draw)
    img_kpts = keypoints.to(torch.int64).tolist()
    img_vis = visibility.cpu().bool().tolist()

    for kpt_inst, vis_inst in zip(img_kpts, img_vis):
        for kpt_coord, kp_vis in zip(kpt_inst, vis_inst):
            if not kp_vis:
                continue
            x1 = kpt_coord[0] - radius
            x2 = kpt_coord[0] + radius
            y1 = kpt_coord[1] - radius
            y2 = kpt_coord[1] + radius
            draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0)

        if connectivity:
            for connection in connectivity:
                if (not vis_inst[connection[0]]) or (not vis_inst[connection[1]]):
                    continue
                start_pt_x = kpt_inst[connection[0]][0]
                start_pt_y = kpt_inst[connection[0]][1]

                end_pt_x = kpt_inst[connection[1]][0]
                end_pt_y = kpt_inst[connection[1]][1]

                draw.line(
                    ((start_pt_x, start_pt_y), (end_pt_x, end_pt_y)),
                    width=width,
                )

    out = torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1)
    if original_dtype.is_floating_point:
        out = to_dtype(out, dtype=original_dtype, scale=True)
    return out
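

# Usage sketch (illustrative, not part of the original module): draw a
# three-point skeleton for one instance and connect the points. Coordinates
# and connectivity are hypothetical.
def _demo_draw_keypoints() -> torch.Tensor:
    image = torch.zeros(3, 100, 100, dtype=torch.uint8)
    keypoints = torch.tensor([[[20.0, 20.0], [50.0, 60.0], [80.0, 20.0]]])  # (num_instances=1, K=3, 2)
    skeleton = [(0, 1), (1, 2)]  # connect keypoint 0 to 1, and 1 to 2
    return draw_keypoints(image, keypoints, connectivity=skeleton, colors="yellow", radius=3)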


# Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization
@torch.no_grad()
def flow_to_image(flow: torch.Tensor) -> torch.Tensor:
    """
    Converts a flow to an RGB image.

    Args:
        flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float.

    Returns:
        img (Tensor): Image Tensor of dtype uint8 where each color corresponds
            to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input.
    """

    if flow.dtype != torch.float:
        raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.")

    orig_shape = flow.shape
    if flow.ndim == 3:
        flow = flow[None]  # Add batch dim

    if flow.ndim != 4 or flow.shape[1] != 2:
        raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.")

    max_norm = torch.sum(flow**2, dim=1).sqrt().max()
    epsilon = torch.finfo((flow).dtype).eps
    normalized_flow = flow / (max_norm + epsilon)
    img = _normalized_flow_to_image(normalized_flow)

    if len(orig_shape) == 3:
        img = img[0]  # Remove batch dim
    return img
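

# Usage sketch (illustrative, not part of the original module): visualize a
# synthetic constant flow field. Every pixel moves right and slightly down,
# so the whole output image is a single saturated hue.
def _demo_flow_to_image() -> torch.Tensor:
    flow = torch.zeros(2, 64, 64, dtype=torch.float)
    flow[0] = 1.0  # horizontal displacement
    flow[1] = 0.5  # vertical displacement
    return flow_to_image(flow)  # -> uint8 tensor of shape (3, 64, 64)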


@torch.no_grad()
def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
    """
    Converts a batch of normalized flow to an RGB image.

    Args:
        normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W)
    Returns:
        img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8.
    """

    N, _, H, W = normalized_flow.shape
    device = normalized_flow.device
    flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device)
    colorwheel = _make_colorwheel().to(device)  # shape [55x3]
    num_cols = colorwheel.shape[0]
    norm = torch.sum(normalized_flow**2, dim=1).sqrt()
    a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi
    fk = (a + 1) / 2 * (num_cols - 1)
    k0 = torch.floor(fk).to(torch.long)
    k1 = k0 + 1
    k1[k1 == num_cols] = 0
    f = fk - k0

    for c in range(colorwheel.shape[1]):
        tmp = colorwheel[:, c]
        col0 = tmp[k0] / 255.0
        col1 = tmp[k1] / 255.0
        col = (1 - f) * col0 + f * col1
        col = 1 - norm * (1 - col)
        flow_image[:, c, :, :] = torch.floor(255 * col)
    return flow_image


def _make_colorwheel() -> torch.Tensor:
    """
    Generates a color wheel for optical flow visualization as presented in:
    Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
    URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf.

    Returns:
        colorwheel (Tensor[55, 3]): Colorwheel Tensor.
    """

    RY = 15
    YG = 6
    GC = 4
    CB = 11
    BM = 13
    MR = 6

    ncols = RY + YG + GC + CB + BM + MR
    colorwheel = torch.zeros((ncols, 3))
    col = 0

    # RY
    colorwheel[0:RY, 0] = 255
    colorwheel[0:RY, 1] = torch.floor(255 * torch.arange(0, RY) / RY)
    col = col + RY
    # YG
    colorwheel[col : col + YG, 0] = 255 - torch.floor(255 * torch.arange(0, YG) / YG)
    colorwheel[col : col + YG, 1] = 255
    col = col + YG
    # GC
    colorwheel[col : col + GC, 1] = 255
    colorwheel[col : col + GC, 2] = torch.floor(255 * torch.arange(0, GC) / GC)
    col = col + GC
    # CB
    colorwheel[col : col + CB, 1] = 255 - torch.floor(255 * torch.arange(CB) / CB)
    colorwheel[col : col + CB, 2] = 255
    col = col + CB
    # BM
    colorwheel[col : col + BM, 2] = 255
    colorwheel[col : col + BM, 0] = torch.floor(255 * torch.arange(0, BM) / BM)
    col = col + BM
    # MR
    colorwheel[col : col + MR, 2] = 255 - torch.floor(255 * torch.arange(MR) / MR)
    colorwheel[col : col + MR, 0] = 255
    return colorwheel


def _generate_color_palette(num_objects: int):
    # Deterministic pseudo-random palette: scaling the object index by three
    # large, distinct constants and taking the result mod 255 spreads the
    # R, G, and B channels across the color space.
    palette = torch.tensor([2**25 - 1, 2**15 - 1, 2**21 - 1])
    return [tuple((i * palette) % 255) for i in range(num_objects)]


def _parse_colors(
    colors: Union[None, str, Tuple[int, int, int], List[Union[str, Tuple[int, int, int]]]],
    *,
    num_objects: int,
    dtype: torch.dtype = torch.uint8,
) -> List[Tuple[int, int, int]]:
    """
    Parses a specification of colors for a set of objects.

    Args:
        colors: A specification of colors for the objects. This can be one of the following:
            - None: to generate a color palette automatically.
            - A list of colors: where each color is either a string (specifying a named color) or an RGB tuple.
            - A string or an RGB tuple: to use the same color for all objects.

            If `colors` is a tuple, it should be a 3-tuple specifying the RGB values of the color.
            If `colors` is a list, it should have at least as many elements as the number of objects to color.

        num_objects (int): The number of objects to color.

    Returns:
        A list of 3-tuples, specifying the RGB values of the colors.

    Raises:
        ValueError: If the number of colors in the list is less than the number of objects to color.
            If `colors` is not a list, tuple, string or None.
    """
    if colors is None:
        colors = _generate_color_palette(num_objects)
    elif isinstance(colors, list):
        if len(colors) < num_objects:
            raise ValueError(
                f"Number of colors must be equal or larger than the number of objects, but got {len(colors)} < {num_objects}."
            )
    elif not isinstance(colors, (tuple, str)):
        raise ValueError(f"`colors` must be a tuple or a string, or a list thereof, but got {colors}.")
    elif isinstance(colors, tuple) and len(colors) != 3:
        raise ValueError(f"If passed as tuple, colors should be an RGB triplet, but got {colors}.")
    else:  # colors specifies a single color for all objects
        colors = [colors] * num_objects

    colors = [ImageColor.getrgb(color) if isinstance(color, str) else color for color in colors]
    if dtype.is_floating_point:  # [0, 255] -> [0, 1]
        colors = [tuple(v / 255 for v in color) for color in colors]
    return colors


def _log_api_usage_once(obj: Any) -> None:
    """
    Logs API usage (module and name) within an organization.
    In a large ecosystem, it's often useful to track the PyTorch and
    TorchVision APIs usage. This API provides similar functionality to the
    logging module in the Python stdlib. It can be used for debugging purposes
    to log which methods are used, and by default it is inactive, unless the user
    manually subscribes a logger via the `SetAPIUsageLogger method <https://github.com/pytorch/pytorch/blob/eb3b9fe719b21fae13c7a7cf3253f970290a573e/c10/util/Logging.cpp#L114>`_.
    Please note it is triggered only once for the same API call within a process.
    It does not collect any data from open-source users since it is no-op by default.
    For more information, please refer to
    * PyTorch note: https://pytorch.org/docs/stable/notes/large_scale_deployments.html#api-usage-logging;
    * Logging policy: https://github.com/pytorch/vision/issues/5052;

    Args:
        obj (class instance or method): an object to extract info from.
    """
    module = obj.__module__
    if not module.startswith("torchvision"):
        module = f"torchvision.internal.{module}"
    name = obj.__class__.__name__
    if isinstance(obj, FunctionType):
        name = obj.__name__
    torch._C._log_api_usage_once(f"{module}.{name}")


def _make_ntuple(x: Any, n: int) -> Tuple[Any, ...]:
    """
    Make n-tuple from input x. If x is an iterable, then we just convert it to a tuple.
    Otherwise, we will make a tuple of length n, all with the value of x.
    reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8

    Args:
        x (Any): input value
        n (int): length of the resulting tuple
    """
    if isinstance(x, collections.abc.Iterable):
        return tuple(x)
    return tuple(repeat(x, n))