59 lines
2.2 KiB
Python
59 lines
2.2 KiB
Python
import torch
|
|
import torch.nn.functional as F
|
|
|
|
from ..utils import _log_api_usage_once
|
|
|
|
|
|
def sigmoid_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = 0.25,
    gamma: float = 2,
    reduction: str = "none",
) -> torch.Tensor:
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        inputs (Tensor): A float tensor of arbitrary shape holding the raw,
                unnormalized scores (logits) for each example. A sigmoid is
                applied internally, so do not pass probabilities.
        targets (Tensor): A float tensor with the same shape as inputs. Stores the binary
                classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        alpha (float): Weighting factor in range [0, 1] to balance
                positive vs negative examples. Any negative value (e.g. ``-1``)
                disables this weighting. Default: ``0.25``.
        gamma (float): Exponent of the modulating factor (1 - p_t) to
                balance easy vs hard examples. Default: ``2``.
        reduction (str): ``'none'`` | ``'mean'`` | ``'sum'``
                ``'none'``: No reduction will be applied to the output.
                ``'mean'``: The output will be averaged over all elements.
                ``'sum'``: The output will be summed. Default: ``'none'``.

    Returns:
        Loss tensor with the reduction option applied. With ``'none'`` it has
        the same shape as ``inputs``; otherwise it is a scalar tensor.

    Raises:
        ValueError: If ``reduction`` is not one of ``'none'``, ``'mean'``, ``'sum'``.
    """
    # Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py

    # Record API usage only in eager mode; skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(sigmoid_focal_loss)

    p = torch.sigmoid(inputs)
    # Per-element cross-entropy computed directly from logits (the fused
    # op is more numerically stable than sigmoid followed by BCE).
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # p_t is the predicted probability of the target class; the modulating
    # factor (1 - p_t) ** gamma shrinks the loss of well-classified elements.
    p_t = p * targets + (1 - p) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    # A negative alpha disables class-balance weighting entirely.
    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    # Check reduction option and return loss accordingly
    if reduction == "none":
        pass
    elif reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()
    else:
        raise ValueError(
            f"Invalid Value for arg 'reduction': '{reduction}' \n Supported reduction modes: 'none', 'mean', 'sum'"
        )
    return loss
|