from typing import List, Optional, Tuple, Union

import torch
from torch import nn, Tensor


def _cat(tensors: List[Tensor], dim: int = 0) -> Tensor:
    """
    Efficient version of torch.cat that avoids a copy if there is only a single element in the list.
    """
    # TODO add back the assert
    # assert isinstance(tensors, (list, tuple))
    if len(tensors) == 1:
        return tensors[0]
    return torch.cat(tensors, dim)


def convert_boxes_to_roi_format(boxes: List[Tensor]) -> Tensor:
    """
    Convert a per-image list of Tensor[L, 4] boxes into the Tensor[K, 5] ROI
    format, where the first column holds the index of the image each box
    belongs to.
    """
    concat_boxes = _cat([b for b in boxes], dim=0)
    temp = []
    for i, b in enumerate(boxes):
        # One id column per box, filled with the index of its source image.
        temp.append(torch.full_like(b[:, :1], i))
    ids = _cat(temp, dim=0)
    rois = torch.cat([ids, concat_boxes], dim=1)
    return rois


def check_roi_boxes_shape(boxes: Union[Tensor, List[Tensor]]):
    """
    Validate that boxes is either a list of Tensor[L, 4] or a single
    Tensor[K, 5] already in ROI format.
    """
    if isinstance(boxes, (list, tuple)):
        for _tensor in boxes:
            torch._assert(
                _tensor.size(1) == 4, "The shape of the tensor in the boxes list is not correct as List[Tensor[L, 4]]"
            )
    elif isinstance(boxes, torch.Tensor):
        torch._assert(boxes.size(1) == 5, "The boxes tensor shape is not correct as Tensor[K, 5]")
    else:
        torch._assert(False, "boxes is expected to be a Tensor[K, 5] or a List[Tensor[L, 4]]")


def split_normalization_params(
    model: nn.Module, norm_classes: Optional[List[type]] = None
) -> Tuple[List[Tensor], List[Tensor]]:
    """
    Split the trainable parameters of a model into the parameters of
    normalization layers and all remaining parameters, e.g. so that weight
    decay can be applied only to the latter.
    """
    # Adapted from https://github.com/facebookresearch/ClassyVision/blob/659d7f78/classy_vision/generic/util.py#L501
    if not norm_classes:
        norm_classes = [
            nn.modules.batchnorm._BatchNorm,
            nn.LayerNorm,
            nn.GroupNorm,
            nn.modules.instancenorm._InstanceNorm,
            nn.LocalResponseNorm,
        ]

    for t in norm_classes:
        if not issubclass(t, nn.Module):
            raise ValueError(f"Class {t} is not a subclass of nn.Module.")

    classes = tuple(norm_classes)

    norm_params = []
    other_params = []
    for module in model.modules():
        if next(module.children(), None):
            # Container modules: only take parameters registered directly on
            # them; child modules are visited separately by model.modules().
            other_params.extend(p for p in module.parameters(recurse=False) if p.requires_grad)
        elif isinstance(module, classes):
            norm_params.extend(p for p in module.parameters() if p.requires_grad)
        else:
            other_params.extend(p for p in module.parameters() if p.requires_grad)
    return norm_params, other_params


def _upcast(t: Tensor) -> Tensor:
    # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
    if t.is_floating_point():
        return t if t.dtype in (torch.float32, torch.float64) else t.float()
    else:
        return t if t.dtype in (torch.int32, torch.int64) else t.int()


def _upcast_non_float(t: Tensor) -> Tensor:
    # Protects from numerical overflows in multiplications by casting
    # low-precision floating point and integer tensors to float32.
    if t.dtype not in (torch.float32, torch.float64):
        return t.float()
    return t


def _loss_inter_union(
    boxes1: torch.Tensor,
    boxes2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the elementwise intersection and union areas of two aligned sets
    of boxes in (x1, y1, x2, y2) format.
    """
    x1, y1, x2, y2 = boxes1.unbind(dim=-1)
    x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)

    # Intersection keypoints
    xkis1 = torch.max(x1, x1g)
    ykis1 = torch.max(y1, y1g)
    xkis2 = torch.min(x2, x2g)
    ykis2 = torch.min(y2, y2g)

    # Intersection area is nonzero only where the boxes actually overlap.
    intsctk = torch.zeros_like(x1)
    mask = (ykis2 > ykis1) & (xkis2 > xkis1)
    intsctk[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
    unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsctk

    return intsctk, unionk
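

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): converting a
# per-image list of boxes into the ROI format consumed by ops such as
# roi_align. The _demo_roi_format name and the box values are made up purely
# for this example.
def _demo_roi_format() -> Tensor:
    boxes_per_image = [
        torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 15.0, 15.0]]),  # image 0: two boxes
        torch.tensor([[2.0, 2.0, 8.0, 8.0]]),  # image 1: one box
    ]
    check_roi_boxes_shape(boxes_per_image)  # passes: every tensor is [L, 4]
    rois = convert_boxes_to_roi_format(boxes_per_image)
    # rois is Tensor[K, 5]; column 0 holds the source image index of each box.
    assert rois.shape == (3, 5)
    assert rois[:, 0].tolist() == [0.0, 0.0, 1.0]
    return rois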
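

# Usage sketch (illustrative): the usual reason to split normalization
# parameters is to exempt them from weight decay. The toy model, learning
# rate, and weight-decay value below are placeholder assumptions, not
# recommendations from this module.
def _demo_split_normalization_params() -> torch.optim.Optimizer:
    model = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3),
        nn.BatchNorm2d(8),
        nn.ReLU(),
    )
    norm_params, other_params = split_normalization_params(model)
    return torch.optim.SGD(
        [
            {"params": norm_params, "weight_decay": 0.0},  # no decay on norm weights/biases
            {"params": other_params, "weight_decay": 1e-4},
        ],
        lr=0.1,
    )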
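

# Usage sketch (illustrative): _loss_inter_union returns elementwise
# intersection and union areas for aligned box pairs, from which an IoU term
# for IoU-based losses can be derived. The epsilon guarding the division is
# an assumption for this sketch, not a value taken from this module.
def _demo_elementwise_iou(eps: float = 1e-7) -> Tensor:
    boxes1 = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
    boxes2 = torch.tensor([[5.0, 5.0, 15.0, 15.0]])
    inter, union = _loss_inter_union(boxes1, boxes2)
    iou = inter / (union + eps)
    # The overlap is a 5x5 square (area 25); the union is 100 + 100 - 25 = 175.
    assert torch.allclose(iou, torch.tensor([25.0 / 175.0]), atol=1e-5)
    return iou


if __name__ == "__main__":  # lightweight smoke test for the sketches above
    _demo_roi_format()
    _demo_split_normalization_params()
    _demo_elementwise_iou()
    print("all demo sketches ran successfully")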