MaskedBCEWithLogitsLoss
- class torch_ecg.models.loss.MaskedBCEWithLogitsLoss
Bases: BCEWithLogitsLoss
Masked binary cross-entropy loss class.
This loss is used mainly for segmentation tasks in which some regions matter much more than others, for example the onsets and offsets of particular events such as paroxysmal atrial fibrillation (AF) episodes.
It was proposed in [1], building on the weighted loss function used in the U-Net paper [2].
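The idea of a masked loss can be illustrated with a minimal functional sketch in PyTorch. It assumes that the loss is the element-wise binary cross-entropy on logits, multiplied by `weight_mask` and then averaged over all elements; the exact reduction used by torch_ecg may differ, and the function name `masked_bce_with_logits` is hypothetical, not part of the library.

```python
import torch
import torch.nn.functional as F


def masked_bce_with_logits(
    input: torch.Tensor, target: torch.Tensor, weight_mask: torch.Tensor
) -> torch.Tensor:
    """Hypothetical functional sketch of a masked BCE-with-logits loss.

    Assumption: the weighted element-wise loss is reduced by a plain mean.
    """
    # All three tensors are expected to share the shape (batch_size, sig_len, n_classes).
    if not (input.shape == target.shape == weight_mask.shape):
        raise ValueError("input, target and weight_mask must have the same shape")
    # Element-wise loss with no reduction, so every position can be reweighted.
    loss = F.binary_cross_entropy_with_logits(input, target, reduction="none")
    # Scale each element by its weight, then average over batch, time and classes.
    return (loss * weight_mask).mean()


torch.manual_seed(0)
logits = torch.randn(2, 500, 1)  # (batch_size, sig_len, n_classes)
labels = (torch.rand(2, 500, 1) > 0.5).float()
mask = torch.ones(2, 500, 1)
mask[:, 240:260, :] = 5.0  # illustrative: emphasise a window around an event boundary
print(masked_bce_with_logits(logits, labels, mask))
```

Under this assumption, an all-ones mask reproduces the unweighted `F.binary_cross_entropy_with_logits`, so the mask only changes how much each sample and class contributes to the average.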
References
- forward(input: Tensor, target: Tensor, weight_mask: Tensor) → Tensor
Forward pass.
- Parameters:
input (torch.Tensor) – The predicted value tensor (logits, before sigmoid), of shape (batch_size, sig_len, n_classes).
target (torch.Tensor) – The target tensor, of shape (batch_size, sig_len, n_classes).
weight_mask (torch.Tensor) – The weight mask tensor, of shape (batch_size, sig_len, n_classes).
- Returns:
The masked binary cross-entropy loss.
- Return type:
torch.Tensor
Note
input, target, and weight_mask should be 3-D tensors of the same shape.
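One way to obtain a weight mask that emphasises onsets and offsets is to upweight every sample within a fixed distance of a label transition. The sketch below assumes binary targets of shape `(batch_size, sig_len, n_classes)`. The helper `boundary_weight_mask` and its `radius` and `boundary_weight` parameters are hypothetical, not part of torch_ecg, and the sketch uses a constant-weight window rather than the distance-based weights of the U-Net paper.

```python
import torch
import torch.nn.functional as F


def boundary_weight_mask(
    target: torch.Tensor, radius: int = 10, boundary_weight: float = 5.0
) -> torch.Tensor:
    """Hypothetical helper: upweight samples within `radius` of a label change.

    target: binary tensor of shape (batch_size, sig_len, n_classes).
    Returns a float mask of the same shape, 1.0 everywhere except
    `boundary_weight` near onsets/offsets along the time axis.
    """
    # A transition at t means target[:, t] != target[:, t - 1] (onset or offset).
    change = torch.zeros_like(target, dtype=torch.bool)
    change[:, 1:, :] = target[:, 1:, :] != target[:, :-1, :]
    # Dilate each transition by `radius` samples on both sides with a max-pool
    # over time; max_pool1d expects (N, C, L), hence the transposes.
    near = F.max_pool1d(
        change.float().transpose(1, 2),
        kernel_size=2 * radius + 1,
        stride=1,
        padding=radius,
    ).transpose(1, 2)
    mask = torch.ones_like(target, dtype=torch.float32)
    mask[near > 0] = boundary_weight
    return mask


target = torch.zeros(1, 100, 1)
target[:, 40:60, :] = 1.0  # one event spanning samples 40-59
mask = boundary_weight_mask(target, radius=3, boundary_weight=5.0)
print(int((mask == 5.0).sum()))  # → 14 (samples 37-43 and 57-63)
```

Because the mask is derived from the target, it has the same shape by construction and can be passed as the `weight_mask` argument of `forward` together with logits of that shape.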