95 lines
3.3 KiB
Python
95 lines
3.3 KiB
Python
from typing import Tuple
|
|
|
|
import torch
|
|
|
|
from ..utils import _log_api_usage_once
|
|
from ._utils import _loss_inter_union, _upcast_non_float
|
|
|
|
|
|
def distance_box_iou_loss(
|
|
boxes1: torch.Tensor,
|
|
boxes2: torch.Tensor,
|
|
reduction: str = "none",
|
|
eps: float = 1e-7,
|
|
) -> torch.Tensor:
|
|
|
|
"""
|
|
Gradient-friendly IoU loss with an additional penalty that is non-zero when the
|
|
distance between boxes' centers isn't zero. Indeed, for two exactly overlapping
|
|
boxes, the distance IoU is the same as the IoU loss.
|
|
This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable.
|
|
|
|
Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
|
|
``0 <= x1 < x2`` and ``0 <= y1 < y2``, and The two boxes should have the
|
|
same dimensions.
|
|
|
|
Args:
|
|
boxes1 (Tensor[N, 4]): first set of boxes
|
|
boxes2 (Tensor[N, 4]): second set of boxes
|
|
reduction (string, optional): Specifies the reduction to apply to the output:
|
|
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: No reduction will be
|
|
applied to the output. ``'mean'``: The output will be averaged.
|
|
``'sum'``: The output will be summed. Default: ``'none'``
|
|
eps (float, optional): small number to prevent division by zero. Default: 1e-7
|
|
|
|
Returns:
|
|
Tensor: Loss tensor with the reduction option applied.
|
|
|
|
Reference:
|
|
Zhaohui Zheng et al.: Distance Intersection over Union Loss:
|
|
https://arxiv.org/abs/1911.08287
|
|
"""
|
|
|
|
# Original Implementation from https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py
|
|
|
|
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
|
|
_log_api_usage_once(distance_box_iou_loss)
|
|
|
|
boxes1 = _upcast_non_float(boxes1)
|
|
boxes2 = _upcast_non_float(boxes2)
|
|
|
|
loss, _ = _diou_iou_loss(boxes1, boxes2, eps)
|
|
|
|
# Check reduction option and return loss accordingly
|
|
if reduction == "none":
|
|
pass
|
|
elif reduction == "mean":
|
|
loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
|
|
elif reduction == "sum":
|
|
loss = loss.sum()
|
|
else:
|
|
raise ValueError(
|
|
f"Invalid Value for arg 'reduction': '{reduction} \n Supported reduction modes: 'none', 'mean', 'sum'"
|
|
)
|
|
return loss
|
|
|
|
|
|
def _diou_iou_loss(
|
|
boxes1: torch.Tensor,
|
|
boxes2: torch.Tensor,
|
|
eps: float = 1e-7,
|
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
intsct, union = _loss_inter_union(boxes1, boxes2)
|
|
iou = intsct / (union + eps)
|
|
# smallest enclosing box
|
|
x1, y1, x2, y2 = boxes1.unbind(dim=-1)
|
|
x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)
|
|
xc1 = torch.min(x1, x1g)
|
|
yc1 = torch.min(y1, y1g)
|
|
xc2 = torch.max(x2, x2g)
|
|
yc2 = torch.max(y2, y2g)
|
|
# The diagonal distance of the smallest enclosing box squared
|
|
diagonal_distance_squared = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + eps
|
|
# centers of boxes
|
|
x_p = (x2 + x1) / 2
|
|
y_p = (y2 + y1) / 2
|
|
x_g = (x1g + x2g) / 2
|
|
y_g = (y1g + y2g) / 2
|
|
# The distance between boxes' centers squared.
|
|
centers_distance_squared = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2)
|
|
# The distance IoU is the IoU penalized by a normalized
|
|
# distance between boxes' centers squared.
|
|
loss = 1 - iou + (centers_distance_squared / diagonal_distance_squared)
|
|
return loss, iou
|