FocalLoss
- class torch_ecg.models.loss.FocalLoss(gamma: float = 2.0, weight: Tensor | None = None, class_weight: Tensor | None = None, size_average: bool | None = None, reduce: bool | None = None, reduction: str = 'mean', multi_label: bool = True, **kwargs: Any)
Bases: _WeightedLoss
Focal loss class.
The focal loss was proposed in [1]; this implementation is based on [2], [3], and [4]. The focal loss is computed as follows:
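\[\operatorname{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \log(p_t)\]
Where:
\(p_t\) is the model's estimated probability of the ground-truth value for each class, i.e. \(p_t = p\) if the target is 1 and \(p_t = 1 - p\) otherwise, where \(p\) is the predicted probability.
\(\alpha_t\) is a per-class balancing weight.
\(\gamma \ge 0\) is the focusing parameter (the gamma argument below).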
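As a quick numerical illustration of the formula above, the sketch below evaluates \(\operatorname{FL}(p_t)\) for a single probability in plain Python (standard library only; the helper name focal_term is hypothetical and not part of torch_ecg). Setting gamma to 0 gives the cross-entropy term \(-\alpha_t \log(p_t)\), so the ratio FL/CE is exactly \((1 - p_t)^{\gamma}\): with \(\gamma = 2\), a well-classified example with \(p_t = 0.9\) keeps 1% of its cross-entropy loss, while a hard example with \(p_t = 0.1\) keeps 81%.

```python
import math


def focal_term(p_t: float, gamma: float = 2.0, alpha_t: float = 1.0) -> float:
    """FL(p_t) = -alpha_t * (1 - p_t) ** gamma * log(p_t) for one probability."""
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


for p_t in (0.9, 0.1):  # an "easy" and a "hard" example
    ce = focal_term(p_t, gamma=0.0)  # gamma = 0: cross-entropy term -log(p_t)
    fl = focal_term(p_t, gamma=2.0)  # default focusing parameter
    print(f"p_t={p_t}: CE={ce:.4f}, FL={fl:.4f}, FL/CE={fl / ce:.2f}")
```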
- Parameters:
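gamma (float, default 2.0) – The focusing parameter \(\gamma\) of focal loss. Larger values down-weight well-classified examples more strongly; \(\gamma = 0\) recovers the (\(\alpha\)-weighted) cross-entropy loss.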
weight (torch.Tensor, optional) – If multi_label is True, a manual rescaling weight applied to the loss of each batch element, of size batch_size; if multi_label is False, a weight for each class, of size n_classes.
class_weight (torch.Tensor, optional) – The class weight, of shape (1, n_classes).
size_average (bool, optional) – Not used; kept for compatibility with the PyTorch native loss API.
reduce (bool, optional) – Not used; kept for compatibility with the PyTorch native loss API.
reduction ({"none", "mean", "sum"}, optional) – Specifies the reduction to apply to the output, by default “mean”.
multi_label (bool, default True) – If True, the loss is computed for multi-label classification; otherwise, for single-label classification.
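The following is a minimal, dependency-free sketch of what the multi-label case (multi_label=True) might compute from logits, written only to make the parameters concrete; it is not torch_ecg's implementation. Assumptions: the per-class factor \(\alpha_t\) is supplied through a class_weight-like list (a hypothetical mapping), \(p_t\) is obtained from a per-class sigmoid, and reduction="mean" averages over all batch_size × n_classes elements. The helper names log_sigmoid and multilabel_focal_loss are hypothetical.

```python
import math
from typing import List, Optional, Sequence, Union


def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def multilabel_focal_loss(
    logits: Sequence[Sequence[float]],  # (batch_size, n_classes), before sigmoid
    targets: Sequence[Sequence[int]],  # binarized 0/1 labels, same shape
    gamma: float = 2.0,
    class_weight: Optional[Sequence[float]] = None,  # hypothetical per-class alpha
    reduction: str = "mean",
) -> Union[float, List[List[float]]]:
    n_classes = len(logits[0])
    alpha = list(class_weight) if class_weight is not None else [1.0] * n_classes
    losses: List[List[float]] = []
    for row_logits, row_targets in zip(logits, targets):
        row: List[float] = []
        for c, (x, y) in enumerate(zip(row_logits, row_targets)):
            # p_t = sigmoid(x) if y == 1, else 1 - sigmoid(x) = sigmoid(-x)
            log_p_t = log_sigmoid(x if y == 1 else -x)
            p_t = math.exp(log_p_t)
            row.append(-alpha[c] * (1.0 - p_t) ** gamma * log_p_t)
        losses.append(row)
    if reduction == "none":
        return losses  # element-wise losses, same shape as logits
    flat = [v for row in losses for v in row]
    if reduction == "sum":
        return sum(flat)
    if reduction == "mean":
        return sum(flat) / len(flat)  # assumed: mean over all elements
    raise ValueError(f"unsupported reduction: {reduction!r}")


logits = [[2.0, -1.0, 0.5], [-0.5, 3.0, -2.0]]
targets = [[1, 0, 1], [0, 1, 0]]
print(multilabel_focal_loss(logits, targets))  # focal loss with gamma = 2
print(multilabel_focal_loss(logits, targets, gamma=0.0))  # plain binary cross-entropy
```

Working in log space through a stable log-sigmoid avoids evaluating log(0) for very confident logits. The loop-based form is meant for reading; the PyTorch module itself operates on tensors.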
References
- forward(input: Tensor, target: Tensor) → Tensor
Forward pass.
- Parameters:
input (torch.Tensor) – The predicted value tensor (before sigmoid), of shape (batch_size, n_classes).
target (torch.Tensor) – Multi-label binarized vector of shape (batch_size, n_classes), or single-label binarized vector of shape (batch_size,).
- Returns:
The focal loss.
- Return type: