import torch import torch.nn.functional as F from ..utils import _log_api_usage_once def sigmoid_focal_loss( inputs: torch.Tensor, targets: torch.Tensor, alpha: float = 0.25, gamma: float = 2, reduction: str = "none", ) -> torch.Tensor: """ Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. Args: inputs (Tensor): A float tensor of arbitrary shape. The predictions for each example. targets (Tensor): A float tensor with the same shape as inputs. Stores the binary classification label for each element in inputs (0 for the negative class and 1 for the positive class). alpha (float): Weighting factor in range (0,1) to balance positive vs negative examples or -1 for ignore. Default: ``0.25``. gamma (float): Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples. Default: ``2``. reduction (string): ``'none'`` | ``'mean'`` | ``'sum'`` ``'none'``: No reduction will be applied to the output. ``'mean'``: The output will be averaged. ``'sum'``: The output will be summed. Default: ``'none'``. Returns: Loss tensor with the reduction option applied. """ # Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(sigmoid_focal_loss) p = torch.sigmoid(inputs) ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") p_t = p * targets + (1 - p) * (1 - targets) loss = ce_loss * ((1 - p_t) ** gamma) if alpha >= 0: alpha_t = alpha * targets + (1 - alpha) * (1 - targets) loss = alpha_t * loss # Check reduction option and return loss accordingly if reduction == "none": pass elif reduction == "mean": loss = loss.mean() elif reduction == "sum": loss = loss.sum() else: raise ValueError( f"Invalid Value for arg 'reduction': '{reduction} \n Supported reduction modes: 'none', 'mean', 'sum'" ) return loss