BCEWithLogitsWithClassWeightLoss
- class torch_ecg.models.loss.BCEWithLogitsWithClassWeightLoss(class_weight: Tensor)
Bases: BCEWithLogitsLoss
Class-weighted binary cross-entropy loss, computed on logits (values before the sigmoid).
- Parameters:
class_weight (torch.Tensor) – Class weight, of shape (1, n_classes).
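For intuition, the following minimal sketch reproduces the weighting idea in plain PyTorch. It assumes the class computes the unreduced BCE-with-logits loss, multiplies it by `class_weight` broadcast along the last (class) dimension, and averages the result. The helper name `class_weighted_bce_with_logits` is hypothetical, and the code is an illustration under that assumption, not the torch_ecg implementation.

```python
import torch
import torch.nn.functional as F


def class_weighted_bce_with_logits(
    logits: torch.Tensor, target: torch.Tensor, class_weight: torch.Tensor
) -> torch.Tensor:
    """Hypothetical helper, not part of torch_ecg.

    Assumed scheme: element-wise BCE on logits, scaled by the weight of each
    element's class, then averaged over all elements.
    """
    # reduction="none" keeps the per-element loss, shape (batch_size, ..., n_classes)
    per_element = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
    # class_weight of shape (1, n_classes) broadcasts over the batch and any
    # intermediate dimensions, because broadcasting aligns trailing dimensions
    return (per_element * class_weight).mean()


if __name__ == "__main__":
    n_classes = 3
    class_weight = torch.tensor([[1.0, 2.0, 0.5]])  # shape (1, n_classes)
    logits = torch.randn(8, 100, n_classes)  # e.g. (batch_size, seq_len, n_classes)
    target = torch.randint(0, 2, logits.shape).float()  # binary multi-label targets
    loss = class_weighted_bce_with_logits(logits, target, class_weight)
    print(loss.shape)  # torch.Size([]) -- a scalar
```

With torch_ecg installed, the documented constructor would be used as `criterion = BCEWithLogitsWithClassWeightLoss(class_weight=class_weight)`, followed by `loss = criterion(logits, target)`. Whether its reduction matches this sketch exactly should be checked against the torch_ecg source.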
- forward(input: Tensor, target: Tensor) → Tensor
Forward pass.
- Parameters:
input (torch.Tensor) – The predicted value tensor (logits, before the sigmoid), of shape (batch_size, ..., n_classes).
target (torch.Tensor) – The target tensor, of shape (batch_size, ..., n_classes).
- Returns:
The class-weighted binary cross-entropy loss.
- Return type: