Source code for torch_ecg.models.rr_lstm

"""
AF (and perhaps other arrhythmias like premature beats) detection
using RR time series as input and an LSTM as the model.

References
----------
1. https://github.com/al3xsh/rnn-based-af-detection

"""

import warnings
from copy import deepcopy
from typing import Any, List, Optional, Sequence, Union

import torch
from einops import rearrange
from einops.layers.torch import Rearrange
from torch import Tensor, nn

from ..cfg import CFG
from ..components.outputs import BaseOutput
from ..model_configs.rr_lstm import RR_LSTM_CONFIG
from ..models._nets import MLP, ExtendedCRF, GlobalContextBlock, NonLocalBlock, SEBlock, SelfAttention, StackedLSTM
from ..utils.misc import CitationMixin
from ..utils.utils_nn import CkptMixin, SizeMixin

__all__ = [
    "RR_LSTM",
]


class RR_LSTM(nn.Module, CkptMixin, SizeMixin, CitationMixin):
    """LSTM model for RR time series classification or sequence labeling.

    LSTM model using RR time series as input is studied in
    [:footcite:ct:`faust2018automated`] for atrial fibrillation detection.
    It is further improved in [:footcite:ct:`af_detection`] via incorporating
    attention mechanism and conditional random fields.

    Parameters
    ----------
    classes : List[str]
        List of the names of the classes.
    config : dict
        Other hyper-parameters, including kernel sizes, etc.
        Refer to the corresponding config file for details.


    .. footbibliography::

    """

    __name__ = "RR_LSTM"

    def __init__(self, classes: Sequence[str], config: Optional[CFG] = None, **kwargs: Any) -> None:
        super().__init__()
        self.classes = list(classes)
        self.n_classes = len(classes)
        self.config = deepcopy(RR_LSTM_CONFIG)
        if not config:
            warnings.warn("No config is provided, using default config.", RuntimeWarning)
        self.config.update(deepcopy(config) or {})

        self.config.batch_first = self.config.get("batch_first", False)
        if self.config.batch_first:
            self.in_rearrange = Rearrange("batch_size n_channels seq_len -> seq_len batch_size n_channels")
        else:
            self.in_rearrange = nn.Identity()

        self.lstm = StackedLSTM(
            input_size=1,
            hidden_sizes=self.config.lstm.hidden_sizes,
            bias=self.config.lstm.bias,
            dropouts=self.config.lstm.dropouts,
            bidirectional=self.config.lstm.bidirectional,
            return_sequences=self.config.lstm.retseq,
        )

        attn_input_size = self.lstm.compute_output_shape(None, None)[-1]

        if not self.config.lstm.retseq:
            if self.config.attn.name.lower() != "none":
                warnings.warn(
                    "Attention is not supported when lstm is not returning sequences",
                    RuntimeWarning,
                )
                self.config.attn.name = "none"
            self.attn_in_rearrange = nn.Identity()
            self.attn = nn.Identity()
            self.attn_out_rearrange = nn.Identity()
            clf_input_size = attn_input_size
        elif self.config.attn.name.lower() == "none":
            self.attn_in_rearrange = Rearrange("seq_len batch_size n_channels -> batch_size n_channels seq_len")
            self.attn = nn.Identity()
            self.attn_out_rearrange = nn.Identity()
            clf_input_size = attn_input_size
        elif self.config.attn.name.lower() == "nl":  # non_local
            self.attn_in_rearrange = Rearrange("seq_len batch_size n_channels -> batch_size n_channels seq_len")
            self.attn = NonLocalBlock(
                in_channels=attn_input_size,
                filter_lengths=self.config.attn.nl.filter_lengths,
                subsample_length=self.config.attn.nl.subsample_length,
                batch_norm=self.config.attn.nl.batch_norm,
            )
            self.attn_out_rearrange = nn.Identity()
            clf_input_size = self.attn.compute_output_shape(None, None)[1]
        elif self.config.attn.name.lower() == "se":  # squeeze_excitation
            self.attn_in_rearrange = Rearrange("seq_len batch_size n_channels -> batch_size n_channels seq_len")
            self.attn = SEBlock(
                in_channels=attn_input_size,
                reduction=self.config.attn.se.reduction,
                activation=self.config.attn.se.activation,
                kw_activation=self.config.attn.se.kw_activation,
                bias=self.config.attn.se.bias,
            )
            self.attn_out_rearrange = nn.Identity()
            clf_input_size = self.attn.compute_output_shape(None, None)[1]
        elif self.config.attn.name.lower() == "gc":  # global_context
            self.attn_in_rearrange = Rearrange("seq_len batch_size n_channels -> batch_size n_channels seq_len")
            self.attn = GlobalContextBlock(
                in_channels=attn_input_size,
                ratio=self.config.attn.gc.ratio,
                reduction=self.config.attn.gc.reduction,
                pooling_type=self.config.attn.gc.pooling_type,
                fusion_types=self.config.attn.gc.fusion_types,
            )
            self.attn_out_rearrange = nn.Identity()
            clf_input_size = self.attn.compute_output_shape(None, None)[1]
        elif self.config.attn.name.lower() == "sa":  # self_attention
            # NOTE: this branch NOT tested
            self.attn_in_rearrange = nn.Identity()
            self.attn = SelfAttention(
                embed_dim=attn_input_size,
                num_heads=self.config.attn.sa.get("num_heads", self.config.attn.sa.get("head_num")),
                dropout=self.config.attn.sa.dropout,
                bias=self.config.attn.sa.bias,
            )
            self.attn_out_rearrange = Rearrange("seq_len batch_size channels -> batch_size channels seq_len")
            clf_input_size = self.attn.compute_output_shape(None, None)[-1]
        else:
            raise NotImplementedError(f"Attn module \042{self.config.attn.name}\042 not implemented yet.")

        if not self.config.lstm.retseq:
            if self.config.global_pool.lower() != "none":
                warnings.warn(
                    f"Global pooling \042{self.config.global_pool}\042 is ignored in non-sequence mode",
                    RuntimeWarning,
                )
                self.config.global_pool = "none"
            self.pool = nn.Identity()
            self.pool_rearrange = nn.Identity()
            if self.config.clf.name.lower() == "crf":
                warnings.warn(
                    "CRF layer is not supported in non-sequence mode, using linear instead.",
                    RuntimeWarning,
                )
                self.config.clf.name = "linear"
            assert "linear" in self.config.clf, "Linear layer not defined in `config`."
            self.clf = MLP(
                in_channels=clf_input_size,
                out_channels=self.config.clf.linear.out_channels + [self.n_classes],
                activation=self.config.clf.linear.activation,
                bias=self.config.clf.linear.bias,
                dropouts=self.config.clf.linear.dropouts,
                skip_last_activation=True,
            )
        elif self.config.clf.name.lower() == "linear":
            if self.config.global_pool.lower() == "max":
                self.pool = nn.AdaptiveMaxPool1d((self.config.global_pool_size,), return_indices=False)
                self.pool_rearrange = Rearrange("batch_size channels pool_size -> batch_size (channels pool_size)")
                clf_input_size *= self.config.global_pool_size
            elif self.config.global_pool.lower() == "avg":
                self.pool = nn.AdaptiveAvgPool1d((self.config.global_pool_size,))
                self.pool_rearrange = Rearrange("batch_size channels pool_size -> batch_size (channels pool_size)")
                clf_input_size *= self.config.global_pool_size
            elif self.config.global_pool.lower() == "none":
                self.pool = nn.Identity()
                self.pool_rearrange = Rearrange("batch_size n_channels seq_len -> batch_size seq_len n_channels")
            else:
                raise NotImplementedError(f"Pooling type \042{self.config.global_pool}\042 not supported")
            self.clf = MLP(
                in_channels=clf_input_size,
                out_channels=self.config.clf.linear.out_channels + [self.n_classes],
                activation=self.config.clf.linear.activation,
                bias=self.config.clf.linear.bias,
                dropouts=self.config.clf.linear.dropouts,
                skip_last_activation=True,
            )
        elif self.config.clf.name.lower() == "crf":
            if self.config.global_pool.lower() != "none":
                warnings.warn(
                    f"Global pooling \042{self.config.global_pool}\042 is ignored for CRF prediction head.",
                    RuntimeWarning,
                )
                self.config.global_pool = "none"
            self.pool = nn.Identity()
            self.pool_rearrange = Rearrange("batch_size n_channels seq_len -> batch_size seq_len n_channels")
            self.clf = ExtendedCRF(
                in_channels=clf_input_size,
                num_tags=self.n_classes,
                bias=self.config.clf.crf.proj_bias,
            )

        # for inference, except for crf
        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input: Tensor) -> Tensor:
        """Forward pass of the model.

        Parameters
        ----------
        input : torch.Tensor
            Input RR series tensor of shape ``(seq_len, batch_size, n_channels)``,
            or ``(batch_size, n_channels, seq_len)`` if `config.batch_first` is True.

        Returns
        -------
        torch.Tensor
            Output tensor, of shape ``(batch_size, seq_len, n_classes)``
            or ``(batch_size, n_classes)``.

        """
        x = self.in_rearrange(input)
        x = self.lstm(x)
        x = self.attn_in_rearrange(x)
        x = self.attn(x)
        x = self.attn_out_rearrange(x)
        x = self.pool(x)
        x = self.pool_rearrange(x)
        output = self.clf(x)
        return output

    @torch.no_grad()
    def inference(self, input: Tensor, bin_pred_thr: float = 0.5) -> BaseOutput:
        """Inference method for the model."""
        raise NotImplementedError("implement a task specific inference method")

    def compute_output_shape(
        self, seq_len: Optional[int] = None, batch_size: Optional[int] = None
    ) -> Sequence[Union[int, None]]:
        """Compute the output shape of the model.

        Parameters
        ----------
        seq_len : int, optional
            Length of the input series tensor.
        batch_size : int, optional
            Batch size of the input series tensor.

        Returns
        -------
        output_shape : sequence
            Output shape of the model.

        """
        if (not self.config.lstm.retseq) or (self.config.global_pool.lower() != "none"):
            output_shape = (batch_size, self.n_classes)
        else:
            output_shape = (batch_size, seq_len, self.n_classes)
        return output_shape

    @property
    def doi(self) -> List[str]:
        return list(
            set(
                self.config.get("doi", [])
                + [
                    "10.1162/neco.1997.9.8.1735",
                    "10.1016/j.compbiomed.2018.07.001",
                    "10.1142/s0219519421400212",
                ]
            )
        )

    @classmethod
    def from_v1(cls, v1_ckpt: str, device: Optional[torch.device] = None) -> "RR_LSTM":
        """Restore an instance of the model from a v1 checkpoint.

        Parameters
        ----------
        v1_ckpt : str
            Path to the v1 checkpoint file.
        device : torch.device, optional
            Device to load the model onto.

        Returns
        -------
        model : RR_LSTM
            The model instance restored from the v1 checkpoint.

        """
        v1_model, _ = RR_LSTM_v1.from_checkpoint(v1_ckpt, device=device)
        model = cls(classes=v1_model.classes, config=v1_model.config)
        model = model.to(v1_model.device)
        model.lstm.load_state_dict(v1_model.lstm.state_dict())
        if model.attn.__class__.__name__ != "Identity":
            model.attn.load_state_dict(v1_model.attn.state_dict())
        model.clf.load_state_dict(v1_model.clf.state_dict())
        del v1_model
        return model
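

# Usage sketch (editorial example, not part of the library source): with the
# default ``RR_LSTM_CONFIG``, ``batch_first`` defaults to ``False``, so the
# input is expected as ``(seq_len, batch_size, n_channels)``.  The class names
# and tensor sizes below are hypothetical.
#
#     import torch
#     from torch_ecg.models.rr_lstm import RR_LSTM
#
#     model = RR_LSTM(classes=["N", "AF"])   # warns that no config was provided
#     rr = torch.rand(640, 2, 1)             # 640 RR intervals, batch of 2, 1 channel
#     out = model(rr)                        # (batch_size, seq_len, n_classes) in
#                                            # sequence-labeling mode, otherwise
#                                            # (batch_size, n_classes)
#     print(model.compute_output_shape(seq_len=640, batch_size=2))
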
class RR_LSTM_v1(nn.Module, CkptMixin, SizeMixin, CitationMixin):
    """LSTM model for RR time series classification or sequence labeling.

    LSTM model using RR time series as input is studied in [1]_
    for atrial fibrillation detection. It is further improved in [2]_
    via incorporating attention mechanism and conditional random fields.

    Parameters
    ----------
    classes : List[str]
        List of the names of the classes.
    config : dict
        Other hyper-parameters, including kernel sizes, etc.
        Refer to the corresponding config file for details.

    References
    ----------
    .. [1] Faust, Oliver, et al. "Automated detection of atrial fibrillation
           using long short-term memory network with RR interval signals."
           Computers in Biology and Medicine 102 (2018): 327-335.
    .. [2] Wen, Hao, et al. "A Scalable Hybrid Model for Atrial Fibrillation
           Detection." Journal of Mechanics in Medicine and Biology 21.05
           (2021): 2140021.

    """

    __name__ = "RR_LSTM_v1"

    def __init__(self, classes: Sequence[str], config: Optional[CFG] = None, **kwargs: Any) -> None:
        """
        Parameters
        ----------
        classes : list
            List of the classes for classification.
        config : dict, optional
            Other hyper-parameters, including kernel sizes, etc.
            Refer to the corresponding config file for details.

        """
        super().__init__()
        self.classes = list(classes)
        self.n_classes = len(classes)
        self.config = deepcopy(RR_LSTM_CONFIG)
        if not config:
            warnings.warn("No config is provided, using default config.", RuntimeWarning)
        self.config.update(deepcopy(config) or {})

        self.lstm = StackedLSTM(
            input_size=1,
            hidden_sizes=self.config.lstm.hidden_sizes,
            bias=self.config.lstm.bias,
            dropouts=self.config.lstm.dropouts,
            bidirectional=self.config.lstm.bidirectional,
            return_sequences=self.config.lstm.retseq,
        )

        attn_input_size = self.lstm.compute_output_shape(None, None)[-1]

        if not self.config.lstm.retseq:
            if self.config.attn.name.lower() != "none":
                warnings.warn(
                    "Attention is not supported when lstm is not returning sequences",
                    RuntimeWarning,
                )
                self.config.attn.name = "none"
            self.attn = None
            clf_input_size = attn_input_size
        elif self.config.attn.name.lower() == "none":
            self.attn = None
            clf_input_size = attn_input_size
        elif self.config.attn.name.lower() == "nl":  # non_local
            self.attn = NonLocalBlock(
                in_channels=attn_input_size,
                filter_lengths=self.config.attn.nl.filter_lengths,
                subsample_length=self.config.attn.nl.subsample_length,
                batch_norm=self.config.attn.nl.batch_norm,
            )
            clf_input_size = self.attn.compute_output_shape(None, None)[1]
        elif self.config.attn.name.lower() == "se":  # squeeze_excitation
            self.attn = SEBlock(
                in_channels=attn_input_size,
                reduction=self.config.attn.se.reduction,
                activation=self.config.attn.se.activation,
                kw_activation=self.config.attn.se.kw_activation,
                bias=self.config.attn.se.bias,
            )
            clf_input_size = self.attn.compute_output_shape(None, None)[1]
        elif self.config.attn.name.lower() == "gc":  # global_context
            self.attn = GlobalContextBlock(
                in_channels=attn_input_size,
                ratio=self.config.attn.gc.ratio,
                reduction=self.config.attn.gc.reduction,
                pooling_type=self.config.attn.gc.pooling_type,
                fusion_types=self.config.attn.gc.fusion_types,
            )
            clf_input_size = self.attn.compute_output_shape(None, None)[1]
        elif self.config.attn.name.lower() == "sa":  # self_attention
            # NOTE: this branch NOT tested
            self.attn = SelfAttention(
                embed_dim=attn_input_size,
                num_heads=self.config.attn.sa.get("num_heads", self.config.attn.sa.get("head_num")),
                dropout=self.config.attn.sa.dropout,
                bias=self.config.attn.sa.bias,
            )
            clf_input_size = self.attn.compute_output_shape(None, None)[-1]
        else:
            raise NotImplementedError(f"Attn module \042{self.config.attn.name}\042 not implemented yet.")

        if not self.config.lstm.retseq:
            if self.config.global_pool.lower() != "none":
                warnings.warn(
                    f"Global pooling \042{self.config.global_pool}\042 is ignored in non-sequence mode",
                    RuntimeWarning,
                )
                self.config.global_pool = "none"
            self.pool = None
            if self.config.clf.name.lower() == "crf":
                warnings.warn(
                    "CRF layer is not supported in non-sequence mode, using linear instead.",
                    RuntimeWarning,
                )
                self.config.clf.name = "linear"
            assert "linear" in self.config.clf, "Linear layer not defined in `config`."
            self.clf = MLP(
                in_channels=clf_input_size,
                out_channels=self.config.clf.linear.out_channels + [self.n_classes],
                activation=self.config.clf.linear.activation,
                bias=self.config.clf.linear.bias,
                dropouts=self.config.clf.linear.dropouts,
                skip_last_activation=True,
            )
        elif self.config.clf.name.lower() == "linear":
            if self.config.global_pool.lower() == "max":
                self.pool = nn.AdaptiveMaxPool1d((self.config.global_pool_size,), return_indices=False)
                clf_input_size *= self.config.global_pool_size
            elif self.config.global_pool.lower() == "avg":
                self.pool = nn.AdaptiveAvgPool1d((self.config.global_pool_size,))
                clf_input_size *= self.config.global_pool_size
            elif self.config.global_pool.lower() == "none":
                self.pool = None
            else:
                raise NotImplementedError(f"Pooling type \042{self.config.global_pool}\042 not supported")
            self.clf = MLP(
                in_channels=clf_input_size,
                out_channels=self.config.clf.linear.out_channels + [self.n_classes],
                activation=self.config.clf.linear.activation,
                bias=self.config.clf.linear.bias,
                dropouts=self.config.clf.linear.dropouts,
                skip_last_activation=True,
            )
        elif self.config.clf.name.lower() == "crf":
            if self.config.global_pool.lower() != "none":
                warnings.warn(
                    f"Global pooling \042{self.config.global_pool}\042 is ignored for CRF prediction head.",
                    RuntimeWarning,
                )
                self.config.global_pool = "none"
            self.pool = None
            self.clf = ExtendedCRF(
                in_channels=clf_input_size,
                num_tags=self.n_classes,
                bias=self.config.clf.crf.proj_bias,
            )

        # for inference, except for crf
        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input: Tensor) -> Tensor:
        """Forward pass of the model.

        Parameters
        ----------
        input : torch.Tensor
            Input RR series tensor of shape ``(seq_len, batch_size, n_channels)``,
            or ``(batch_size, n_channels, seq_len)`` if `config.batch_first` is True.

        Returns
        -------
        torch.Tensor
            Output tensor, of shape ``(batch_size, seq_len, n_classes)``
            or ``(batch_size, n_classes)``.
""" # (batch_size, n_channels, seq_len) --> (seq_len, batch_size, n_channels) # x = input.permute(1,2,0) x = self.lstm(input) # (seq_len, batch_size, n_channels) or (batch_size, n_channels) if self.attn: # (seq_len, batch_size, n_channels) --> (batch_size, n_channels, seq_len) x = x.permute(1, 2, 0) x = self.attn(x) # (batch_size, n_channels, seq_len) elif x.ndim == 3: # (seq_len, batch_size, n_channels) --> (batch_size, n_channels, seq_len) x = x.permute(1, 2, 0) if self.pool: x = self.pool(x) # (batch_size, n_channels, 1) # x = x.squeeze(dim=-1) # (batch_size, n_channels) x = rearrange( x, "batch_size n_channels pool_size -> batch_size (n_channels pool_size)", ) elif x.ndim == 3: x = x.permute(0, 2, 1) # (batch_size, n_channels, seq_len) --> (batch_size, seq_len, n_channels) else: # x of shape (batch_size, n_channels), # in the case where config.lstm.retseq = False pass if self.config.clf.name.lower() == "linear": x = self.clf(x) # (batch_size, seq_len, n_classes) or (batch_size, n_classes) elif self.config.clf.name.lower() == "crf": x = self.clf(x) # (batch_size, seq_len, n_classes) output = x return output @torch.no_grad() def inference(self, input: Tensor, bin_pred_thr: float = 0.5) -> BaseOutput: """Inference method for the model.""" raise NotImplementedError("implement a task specific inference method") def compute_output_shape( self, seq_len: Optional[int] = None, batch_size: Optional[int] = None ) -> Sequence[Union[int, None]]: """Compute the output shape of the model. Parameters ---------- seq_len : int, optional Length of the input series tensor. batch_size : int, optional Batch size of the input series tensor. Returns ------- output_shape : sequence Output shape of the model. """ if (not self.config.lstm.retseq) or (self.config.global_pool.lower() != "none"): output_shape = (batch_size, self.n_classes) else: output_shape = (batch_size, seq_len, self.n_classes) return output_shape @property def doi(self) -> List[str]: return list( set( self.config.get("doi", []) + [ "10.1162/neco.1997.9.8.1735", "10.1016/j.compbiomed.2018.07.001", "10.1142/s0219519421400212", ] ) )