RR_LSTM
- class torch_ecg.models.RR_LSTM(classes: Sequence[str], config: CFG | None = None, **kwargs: Any)
Bases: Module, CkptMixin, SizeMixin, CitationMixin
LSTM model for RR time series classification or sequence labeling.
An LSTM model that takes RR time series as input was studied in [Faust et al.[1]] for atrial fibrillation detection. It was further improved in [Wen et al.[2]] by incorporating an attention mechanism and conditional random fields.
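RR_LSTM consumes RR intervals rather than raw ECG samples. The sketch below shows one way such input could be prepared. It assumes that R-peak sample indices are already available from a QRS detector and that the signal was sampled at 250 Hz. The helpers `rr_series_from_rpeaks` and `stack_rr_batch` are hypothetical and not part of torch_ecg. The result uses the default `(seq_len, batch_size, n_channels)` layout with a single channel.

```python
import numpy as np


def rr_series_from_rpeaks(rpeaks, fs):
    """Hypothetical helper: R-peak sample indices -> RR intervals in seconds."""
    rpeaks = np.asarray(rpeaks, dtype=float)
    # Each RR interval is the gap between two consecutive R peaks.
    return np.diff(rpeaks) / fs


def stack_rr_batch(rr_list):
    """Hypothetical helper: stack equal-length RR series into
    shape (seq_len, batch_size, n_channels) with n_channels == 1."""
    batch = np.stack(rr_list, axis=1)  # (seq_len, batch_size)
    return batch[..., np.newaxis]      # add the channel axis


fs = 250  # sampling frequency in Hz (assumed)
rpeaks_a = [100, 300, 510, 700, 905]  # made-up R-peak indices
rpeaks_b = [50, 240, 440, 650, 840]
rr_a = rr_series_from_rpeaks(rpeaks_a, fs)
rr_b = rr_series_from_rpeaks(rpeaks_b, fs)
x = stack_rr_batch([rr_a, rr_b])
print(x.shape)  # (4, 2, 1)
```

The array could then be turned into a tensor with `torch.from_numpy(x).float()`. Stacking assumes every RR series in the batch has the same length, so variable-length records would need cropping or padding first.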
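- Parameters:
classes (Sequence[str]) – Names of the classes to predict.
config (CFG, optional) – Hyper-parameters of the model, e.g. batch_first, which determines the input layout of forward().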
References
- compute_output_shape(seq_len: int | None = None, batch_size: int | None = None) → Sequence[int | None]
Compute the output shape of the model.
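The two output shapes documented for forward() suggest the logic behind compute_output_shape. What follows is a hypothetical, pure-Python mirror of that logic, not the library's implementation. The `sequence_labeling` flag stands in for whatever the model config actually uses to choose between sequence labeling and classification. Unspecified dimensions stay None, matching the `Sequence[int | None]` return annotation.

```python
from typing import Optional, Sequence


def output_shape_sketch(
    n_classes: int,
    sequence_labeling: bool,
    seq_len: Optional[int] = None,
    batch_size: Optional[int] = None,
) -> Sequence[Optional[int]]:
    """Hypothetical mirror of the shapes documented for RR_LSTM.forward()."""
    if sequence_labeling:
        # One prediction per RR interval: (batch_size, seq_len, n_classes).
        return (batch_size, seq_len, n_classes)
    # One prediction per RR series: (batch_size, n_classes).
    return (batch_size, n_classes)


print(output_shape_sketch(2, True, seq_len=100))   # (None, 100, 2)
print(output_shape_sketch(2, False, batch_size=8))  # (8, 2)
```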
- forward(input: Tensor) → Tensor
Forward pass of the model.
- Parameters:
input (torch.Tensor) – Input RR series tensor of shape (seq_len, batch_size, n_channels), or (batch_size, n_channels, seq_len) if config.batch_first is True.
- Returns:
Output tensor of shape (batch_size, seq_len, n_classes) for sequence labeling, or (batch_size, n_classes) for classification.
- Return type:
Tensor
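forward() accepts two input layouts, and converting between them is a single axis permutation. The NumPy sketch below uses hypothetical helper names. For tensors, `torch.Tensor.permute` takes the same axis orders. The documented batch-first layout places n_channels before seq_len. That differs from the `(batch, seq, feature)` layout of `torch.nn.LSTM(batch_first=True)`, so it is worth checking which layout the installed version expects.

```python
import numpy as np


def to_batch_first(x_seq_first: np.ndarray) -> np.ndarray:
    """(seq_len, batch_size, n_channels) -> (batch_size, n_channels, seq_len)."""
    return np.transpose(x_seq_first, (1, 2, 0))


def to_seq_first(x_batch_first: np.ndarray) -> np.ndarray:
    """(batch_size, n_channels, seq_len) -> (seq_len, batch_size, n_channels)."""
    return np.transpose(x_batch_first, (2, 0, 1))


seq_len, batch_size, n_channels = 100, 8, 1
rng = np.random.default_rng(0)
x = rng.random((seq_len, batch_size, n_channels))  # default layout
xb = to_batch_first(x)
print(xb.shape)  # (8, 1, 100)
# The two permutations are inverses of each other.
assert np.array_equal(to_seq_first(xb), x)
```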
- classmethod from_v1(v1_ckpt: str, device: device | None = None, return_config: bool = False) → RR_LSTM | Tuple[RR_LSTM, dict]
Restore an instance of the model from a v1 checkpoint.
- Parameters:
v1_ckpt (str) – Path to the v1 checkpoint file.
device (torch.device, optional) – The device to load the model to. Defaults to “cuda” if available, otherwise “cpu”.
return_config (bool, default False) – Whether to return the config dict.
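- Returns:
model – The model instance restored from the v1 checkpoint.
config (dict) – The config dict, returned only if return_config is True.
- Return type:
RR_LSTM or Tuple[RR_LSTM, dict]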
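from_v1 returns either a bare model or a `(model, config)` tuple, depending on return_config, so calling code may want to normalize the result. Below is a minimal sketch built around a hypothetical helper `split_from_v1_result`. It uses a stand-in object in place of a real RR_LSTM so that it runs without a checkpoint.

```python
from typing import Any, Optional, Tuple


def split_from_v1_result(result: Any) -> Tuple[Any, Optional[dict]]:
    """Hypothetical helper: normalize a from_v1 return value to (model, config).

    config is None when from_v1 was called with return_config=False.
    """
    if isinstance(result, tuple):
        model, config = result
        return model, config
    return result, None


class FakeModel:
    """Stand-in for an RR_LSTM instance (illustration only)."""


fake = FakeModel()
print(split_from_v1_result(fake)[1])  # None
# {"example": 1} is a made-up config dict, not real RR_LSTM settings.
print(split_from_v1_result((fake, {"example": 1}))[1])  # {'example': 1}
```

With the real class this would read `model, config = split_from_v1_result(RR_LSTM.from_v1(ckpt_path, return_config=True))`, where `ckpt_path` is a placeholder for the path to an actual v1 checkpoint.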