# Source code for torch_ecg.models.ecg_crnn

"""
C(R)NN structure models, for classifying ECG arrhythmias, and other tasks.
"""

import warnings
from copy import deepcopy
from typing import Any, List, Optional, Sequence, Union

import numpy as np
import torch
from einops import rearrange
from einops.layers.torch import Rearrange
from torch import Tensor, nn

from ..cfg import CFG
from ..components.outputs import BaseOutput
from ..model_configs.ecg_crnn import ECG_CRNN_CONFIG
from ..utils.misc import CitationMixin
from ..utils.utils_nn import CkptMixin, SizeMixin
from ._nets import MLP, GlobalContextBlock, NonLocalBlock, SEBlock, SelfAttention, StackedLSTM
from .cnn.densenet import DenseNet
from .cnn.mobilenet import MobileNetV1, MobileNetV2, MobileNetV3
from .cnn.multi_scopic import MultiScopicCNN
from .cnn.regnet import RegNet
from .cnn.resnet import ResNet
from .cnn.vgg import VGG16
from .cnn.xception import Xception
from .transformers import Transformer

# Public API of this module. `ECG_CRNN_v1` is deliberately not exported;
# it is reached through `ECG_CRNN.from_v1`.
__all__ = ["ECG_CRNN"]


class ECG_CRNN(nn.Module, CkptMixin, SizeMixin, CitationMixin):
    """Convolutional (Recurrent) Neural Network for ECG tasks.

    This C(R)NN architecture is adapted from
    [:footcite:ct:`yao2018ti_cnn,yao2020ati_cnn`] in the first place,
    and then modified to be more general, and more flexible.
    The most famous model is perhaps [:footcite:ct:`awni2019stanford_ecg`],
    which is a modified 1D-ResNet34 model.
    The website of this model is
    `<https://stanfordmlgroup.github.io/projects/ecg2/>`_,
    and the code is hosted on `<https://github.com/awni/ecg>`_.
    The C(R)NN models have long been competitive in various ECG tasks,
    e.g. CPSC2018 entry 0236, CPSC2019 entry 0416.
    The models are also used in the PhysioNet/CinC Challenges.

    The network is a chain of stages selected by ``config``:
    a CNN backbone (``config.cnn``), an optional recurrent/linear stage
    (``config.rnn``), an optional attention stage (``config.attn``),
    an optional global pooling (``config.global_pool``), and an MLP
    classifier (``config.clf``). Every optional stage is wrapped by
    rearrange (or identity) layers so that the forward pass is a plain
    sequence of calls.

    Parameters
    ----------
    classes : List[str]
        List of the names of the classes.
    n_leads : int
        Number of leads (number of input channels).
    config : dict, optional
        Other hyper-parameters, including kernel sizes, etc.
        Merged on top of a copy of the default ``ECG_CRNN_CONFIG``.
        Refer to corresponding config files.

    .. footbibliography::

    """

    __name__ = "ECG_CRNN"

    def __init__(
        self,
        classes: Sequence[str],
        n_leads: int,
        config: Optional[CFG] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        self.classes = list(classes)
        self.n_classes = len(classes)
        self.n_leads = n_leads
        # work on a private copy of the defaults, then overlay a copy of the
        # user config, so neither the module default nor `config` is mutated
        self.config = deepcopy(ECG_CRNN_CONFIG)
        if not config:
            warnings.warn("No config is provided, using default config.", RuntimeWarning)
        self.config.update(deepcopy(config) or {})

        # --------------------------------------------------------------
        # CNN backbone, output: (batch_size, channels, seq_len)
        cnn_choice = self.config.cnn.name.lower()
        cnn_config = self.config.cnn[self.config.cnn.name]
        if "resnet" in cnn_choice or "resnext" in cnn_choice:
            self.cnn = ResNet(self.n_leads, **cnn_config)
        elif "regnet" in cnn_choice:
            self.cnn = RegNet(self.n_leads, **cnn_config)
        elif "multi_scopic" in cnn_choice:
            self.cnn = MultiScopicCNN(self.n_leads, **cnn_config)
        elif "mobile_net" in cnn_choice or "mobilenet" in cnn_choice:
            if "v1" in cnn_choice:
                self.cnn = MobileNetV1(self.n_leads, **cnn_config)
            elif "v2" in cnn_choice:
                self.cnn = MobileNetV2(self.n_leads, **cnn_config)
            elif "v3" in cnn_choice:
                self.cnn = MobileNetV3(self.n_leads, **cnn_config)
            else:
                raise ValueError(f"CNN \042{cnn_choice}\042 is not supported for {self.__name__}")
        elif "densenet" in cnn_choice or "dense_net" in cnn_choice:
            self.cnn = DenseNet(self.n_leads, **cnn_config)
        elif "vgg16" in cnn_choice:
            self.cnn = VGG16(self.n_leads, **cnn_config)
        elif "xception" in cnn_choice:
            self.cnn = Xception(self.n_leads, **cnn_config)
        else:
            raise NotImplementedError(f"CNN \042{cnn_choice}\042 not implemented yet")
        rnn_input_size = self.cnn.compute_output_shape(None, None)[1]

        # --------------------------------------------------------------
        # RNN stage (optional).
        # After `rnn_out_rearrange`, features are (seq_len, batch_size, channels),
        # or (batch_size, channels) for an LSTM with `retseq=False`.
        # `__rnn_seqlen_dim` is the index of the sequence axis in
        # `self.rnn.compute_output_shape(...)`, used by `compute_output_shape`.
        rnn_name = self.config.rnn.name.lower()
        if rnn_name == "none":
            self.rnn_in_rearrange = Rearrange("batch_size channels seq_len -> seq_len batch_size channels")
            self.rnn = nn.Identity()
            self.__rnn_seqlen_dim = 0
            self.rnn_out_rearrange = nn.Identity()
            attn_input_size = rnn_input_size
        elif rnn_name == "lstm":
            self.rnn_in_rearrange = Rearrange("batch_size channels seq_len -> seq_len batch_size channels")
            self.rnn = StackedLSTM(
                input_size=rnn_input_size,
                hidden_sizes=self.config.rnn.lstm.hidden_sizes,
                bias=self.config.rnn.lstm.bias,
                dropouts=self.config.rnn.lstm.dropouts,
                bidirectional=self.config.rnn.lstm.bidirectional,
                return_sequences=self.config.rnn.lstm.retseq,
            )
            self.__rnn_seqlen_dim = 0
            self.rnn_out_rearrange = nn.Identity()
            attn_input_size = self.rnn.compute_output_shape(None, None)[-1]
        elif rnn_name == "linear":
            # abuse of notation, to put before the global attention module
            self.rnn_in_rearrange = Rearrange("batch_size channels seq_len -> batch_size seq_len channels")
            self.rnn = MLP(
                in_channels=rnn_input_size,
                out_channels=self.config.rnn.linear.out_channels,
                activation=self.config.rnn.linear.activation,
                bias=self.config.rnn.linear.bias,
                dropouts=self.config.rnn.linear.dropouts,
            )
            self.__rnn_seqlen_dim = 1
            self.rnn_out_rearrange = Rearrange("batch_size seq_len channels -> seq_len batch_size channels")
            attn_input_size = self.rnn.compute_output_shape(None, None)[-1]
        else:
            raise NotImplementedError(f"RNN \042{self.config.rnn.name}\042 not implemented yet")

        # --------------------------------------------------------------
        # Attention stage (optional).
        # Input is the (seq_len, batch_size, channels) output of the RNN stage;
        # after `attn_out_rearrange`, features are (batch_size, channels, seq_len),
        # except for an LSTM with `retseq=False` (kept as (batch_size, channels)).
        # `__attn_seqlen_dim` is the index of the sequence axis in
        # `self.attn.compute_output_shape(...)`.
        lstm_without_seq = rnn_name == "lstm" and not self.config.rnn.lstm.retseq
        attn_name = self.config.attn.name.lower()
        if lstm_without_seq:
            self.attn_in_rearrange = nn.Identity()
            self.attn = nn.Identity()
            self.__attn_seqlen_dim = 0
            self.attn_out_rearrange = nn.Identity()
            clf_input_size = attn_input_size
            if attn_name != "none":
                warnings.warn(
                    f"since `retseq` of rnn is False, hence attention `{self.config.attn.name}` is ignored",
                    RuntimeWarning,
                )
        elif attn_name == "none":
            self.attn_in_rearrange = Rearrange("seq_len batch_size channels -> batch_size channels seq_len")
            self.attn = nn.Identity()
            self.__attn_seqlen_dim = -1
            self.attn_out_rearrange = nn.Identity()
            clf_input_size = attn_input_size
        elif attn_name == "nl":  # non_local
            self.attn_in_rearrange = Rearrange("seq_len batch_size channels -> batch_size channels seq_len")
            self.attn = NonLocalBlock(
                in_channels=attn_input_size,
                filter_lengths=self.config.attn.nl.filter_lengths,
                subsample_length=self.config.attn.nl.subsample_length,
                batch_norm=self.config.attn.nl.batch_norm,
            )
            self.__attn_seqlen_dim = -1
            self.attn_out_rearrange = nn.Identity()
            clf_input_size = self.attn.compute_output_shape(None, None)[1]
        elif attn_name == "se":  # squeeze_exitation
            self.attn_in_rearrange = Rearrange("seq_len batch_size channels -> batch_size channels seq_len")
            self.attn = SEBlock(
                in_channels=attn_input_size,
                reduction=self.config.attn.se.reduction,
                activation=self.config.attn.se.activation,
                kw_activation=self.config.attn.se.kw_activation,
                bias=self.config.attn.se.bias,
            )
            self.__attn_seqlen_dim = -1
            self.attn_out_rearrange = nn.Identity()
            clf_input_size = self.attn.compute_output_shape(None, None)[1]
        elif attn_name == "gc":  # global_context
            self.attn_in_rearrange = Rearrange("seq_len batch_size channels -> batch_size channels seq_len")
            self.attn = GlobalContextBlock(
                in_channels=attn_input_size,
                ratio=self.config.attn.gc.ratio,
                reduction=self.config.attn.gc.reduction,
                pooling_type=self.config.attn.gc.pooling_type,
                fusion_types=self.config.attn.gc.fusion_types,
            )
            self.__attn_seqlen_dim = -1
            self.attn_out_rearrange = nn.Identity()
            clf_input_size = self.attn.compute_output_shape(None, None)[1]
        elif attn_name == "sa":  # self_attention
            # NOTE: this branch NOT tested
            self.attn_in_rearrange = nn.Identity()
            self.attn = SelfAttention(
                embed_dim=attn_input_size,
                num_heads=self.config.attn.sa.get("num_heads", self.config.attn.sa.get("head_num")),
                dropout=self.config.attn.sa.dropout,
                bias=self.config.attn.sa.bias,
            )
            self.__attn_seqlen_dim = 0
            self.attn_out_rearrange = Rearrange("seq_len batch_size channels -> batch_size channels seq_len")
            clf_input_size = self.attn.compute_output_shape(None, None)[-1]
        elif attn_name == "transformer":
            self.attn = Transformer(
                input_size=attn_input_size,
                hidden_size=self.config.attn.transformer.hidden_size,
                num_layers=self.config.attn.transformer.num_layers,
                num_heads=self.config.attn.transformer.num_heads,
                dropout=self.config.attn.transformer.dropout,
                activation=self.config.attn.transformer.activation,
            )
            # the layout expected by the transformer depends on its `batch_first` flag
            if self.attn.batch_first:
                self.attn_in_rearrange = Rearrange("seq_len batch_size channels -> batch_size seq_len channels")
                self.attn_out_rearrange = Rearrange("batch_size seq_len channels -> batch_size channels seq_len")
                self.__attn_seqlen_dim = 1
            else:
                self.attn_in_rearrange = nn.Identity()
                self.attn_out_rearrange = Rearrange("seq_len batch_size channels -> batch_size channels seq_len")
                self.__attn_seqlen_dim = 0
            clf_input_size = self.attn.compute_output_shape(None, None)[-1]
        else:
            raise NotImplementedError(f"Attention \042{self.config.attn.name}\042 not implemented yet")

        # --------------------------------------------------------------
        # Global pooling (optional).
        # `__clf_input_seq` tells `compute_output_shape` whether the classifier
        # receives a sequence (batch_size, seq_len, channels) or a flat
        # (batch_size, channels) input.
        pool_name = self.config.global_pool.lower()
        if lstm_without_seq:
            self.pool = nn.Identity()
            if pool_name != "none":
                warnings.warn(
                    f"since `retseq` of rnn is False, hence global pooling `{self.config.global_pool}` is ignored",
                    RuntimeWarning,
                )
            self.pool_rearrange = nn.Identity()
            self.__clf_input_seq = False
        elif pool_name == "max":
            self.pool = nn.AdaptiveMaxPool1d((self.config.global_pool_size,), return_indices=False)
            clf_input_size *= self.config.global_pool_size
            self.pool_rearrange = Rearrange("batch_size channels pool_size -> batch_size (channels pool_size)")
            self.__clf_input_seq = False
        elif pool_name == "avg":
            self.pool = nn.AdaptiveAvgPool1d((self.config.global_pool_size,))
            clf_input_size *= self.config.global_pool_size
            self.pool_rearrange = Rearrange("batch_size channels pool_size -> batch_size (channels pool_size)")
            self.__clf_input_seq = False
        elif pool_name == "attn":
            raise NotImplementedError("Attentive pooling not implemented yet!")
        elif pool_name == "none":
            self.pool = nn.Identity()
            self.pool_rearrange = Rearrange("batch_size channels seq_len -> batch_size seq_len channels")
            self.__clf_input_seq = True
        else:
            raise NotImplementedError(f"Global Pooling \042{self.config.global_pool}\042 not implemented yet!")

        # --------------------------------------------------------------
        # Classifier; input of `self.clf` has shape: batch_size, channels
        # (or batch_size, seq_len, channels without global pooling)
        self.clf = MLP(
            in_channels=clf_input_size,
            out_channels=self.config.clf.out_channels + [self.n_classes],
            activation=self.config.clf.activation,
            bias=self.config.clf.bias,
            dropouts=self.config.clf.dropouts,
            skip_last_activation=True,
        )

        # for inference
        # classification: if single-label, use softmax; otherwise (multi-label) use sigmoid
        # sequence tagging: if background counted in `classes`, use softmax; otherwise use sigmoid
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(-1)

    def extract_features(self, input: Tensor) -> Tensor:
        """Extract feature map before the dense (linear) classifying layer(s).

        Parameters
        ----------
        input : torch.Tensor
            Input signal tensor,
            of shape ``(batch_size, channels, seq_len)``.

        Returns
        -------
        features : torch.Tensor
            Feature map tensor,
            of shape ``(batch_size, channels, seq_len)`` or ``(batch_size, channels)``.

        """
        # CNN
        features = self.cnn(input)  # batch_size, channels, seq_len

        # RNN (optional)
        features = self.rnn_in_rearrange(features)
        features = self.rnn(features)
        features = self.rnn_out_rearrange(features)

        # Attention (optional)
        features = self.attn_in_rearrange(features)
        features = self.attn(features)
        features = self.attn_out_rearrange(features)

        return features

    def forward(self, input: Tensor) -> Tensor:
        """Forward pass of the model.

        Parameters
        ----------
        input : torch.Tensor
            Input signal tensor,
            of shape ``(batch_size, channels, seq_len)``.

        Returns
        -------
        pred : torch.Tensor
            Predictions tensor,
            of shape ``(batch_size, seq_len, channels)`` or ``(batch_size, channels)``.

        """
        features = self.extract_features(input)

        # global pooling (optional)
        features = self.pool(features)
        features = self.pool_rearrange(features)

        pred = self.clf(features)
        return pred

    @torch.no_grad()
    def inference(
        self,
        input: Union[np.ndarray, Tensor],
        class_names: bool = False,
        bin_pred_thr: float = 0.5,
    ) -> BaseOutput:
        """Inference method for the model.

        Task-specific subclasses must override this method;
        the base implementation always raises.

        Parameters
        ----------
        input : numpy.ndarray or torch.Tensor
            Input tensor, of shape ``(batch_size, channels, seq_len)``.
        class_names : bool, default False
            If True, the returned scalar predictions will be a :class:`~pandas.DataFrame`,
            with class names for each scalar prediction.
        bin_pred_thr : float, default 0.5
            Threshold for making binary predictions from scalar predictions.

        Returns
        -------
        output : BaseOutput
            The output of the inference method, including the following items:

                - prob: numpy.ndarray or torch.Tensor,
                  scalar predictions, (and binary predictions if `class_names` is True).
                - pred: numpy.ndarray or torch.Tensor,
                  the array (with values 0, 1 for each class) of binary prediction.

        Raises
        ------
        NotImplementedError
            Always, in this base class.

        """
        raise NotImplementedError("implement a task specific inference method")

    def compute_output_shape(
        self, seq_len: Optional[int] = None, batch_size: Optional[int] = None
    ) -> Sequence[Union[int, None]]:
        """Compute the output shape of the model.

        Parameters
        ----------
        seq_len : int, optional
            Length of the input signal tensor.
        batch_size : int, optional
            Batch size of the input signal tensor.

        Returns
        -------
        output_shape : sequence
            Output shape of the model.

        """
        output_shape = self.cnn.compute_output_shape(seq_len, batch_size)
        _, _, _seq_len = output_shape
        # each optional stage may change the sequence length; the index of the
        # sequence axis in each stage's output shape was recorded in `__init__`
        if self.rnn.__class__.__name__ != "Identity":
            output_shape = self.rnn.compute_output_shape(_seq_len, batch_size)
            _seq_len = output_shape[self.__rnn_seqlen_dim]
        if self.attn.__class__.__name__ != "Identity":
            output_shape = self.attn.compute_output_shape(_seq_len, batch_size)
            _seq_len = output_shape[self.__attn_seqlen_dim]
        if self.clf.__class__.__name__ != "Identity":
            output_shape = self.clf.compute_output_shape(_seq_len, batch_size, input_seq=self.__clf_input_seq)
        return output_shape

    @property
    def doi(self) -> List[str]:
        """DOIs collected from ``self.config`` (and all nested ``CFG`` sections),
        plus the DOIs of the papers this model is based on.

        Duplicates are removed while keeping first-seen order, so the result
        is deterministic across interpreter runs.
        """
        doi = []
        # breadth-first walk over the (nested) config sections
        candidates = [self.config]
        while len(candidates) > 0:
            new_candidates = []
            for candidate in candidates:
                if hasattr(candidate, "doi"):
                    if isinstance(candidate.doi, str):
                        doi.append(candidate.doi)
                    else:
                        doi.extend(list(candidate.doi))
                for _, v in candidate.items():
                    if isinstance(v, CFG):
                        new_candidates.append(v)
            candidates = new_candidates
        # `dict.fromkeys` dedups while preserving order; `set` iteration order
        # of strings varies between runs because of hash randomization
        return list(dict.fromkeys(doi + ["10.1016/j.inffus.2019.06.024", "10.1088/1361-6579/ac6aa3"]))

    @classmethod
    def from_v1(cls, v1_ckpt: str, device: Optional[torch.device] = None) -> "ECG_CRNN":
        """Restore an instance of the model from a v1 checkpoint.

        Parameters
        ----------
        v1_ckpt : str
            Path to the v1 checkpoint file.
        device : torch.device, optional
            Device passed to ``ECG_CRNN_v1.from_checkpoint`` when loading the
            v1 checkpoint; the returned model is moved to the device of the
            loaded v1 model.

        Returns
        -------
        model : ECG_CRNN
            The model instance restored from the v1 checkpoint.

        """
        v1_model, _ = ECG_CRNN_v1.from_checkpoint(v1_ckpt, device=device)
        model = cls(classes=v1_model.classes, n_leads=v1_model.n_leads, config=v1_model.config)
        model = model.to(v1_model.device)
        # copy weights stage by stage; optional stages that are `nn.Identity`
        # here correspond to `None` in the v1 model and carry no weights
        model.cnn.load_state_dict(v1_model.cnn.state_dict())
        if model.rnn.__class__.__name__ != "Identity":
            model.rnn.load_state_dict(v1_model.rnn.state_dict())
        if model.attn.__class__.__name__ != "Identity":
            model.attn.load_state_dict(v1_model.attn.state_dict())
        model.clf.load_state_dict(v1_model.clf.state_dict())
        del v1_model
        return model
class ECG_CRNN_v1(nn.Module, CkptMixin, SizeMixin, CitationMixin):
    """Convolutional (Recurrent) Neural Network for ECG tasks.

    Version 1 of the C(R)NN model: optional stages are stored as ``None``
    and tensor layouts are permuted explicitly in the forward pass.
    ``ECG_CRNN.from_v1`` restores an ``ECG_CRNN`` from a checkpoint of this class.

    This C(R)NN architecture is adapted from [1]_, [2]_ in the first place,
    and then modified to be more general, and more flexible.
    The most famous model is perhaps [3]_, which is a modified 1D-ResNet34 model.
    The website of this model is [4]_, and the code is hosted on [5]_.
    The C(R)NN models have long been competitive in various ECG tasks,
    e.g. [6]_, [7]_. The models are also used in the PhysioNet/CinC Challenges.

    Parameters
    ----------
    classes : List[str]
        List of the names of the classes.
    n_leads : int
        Number of leads (number of input channels).
    config : dict, optional
        Other hyper-parameters, including kernel sizes, etc.
        Merged on top of a copy of the default ``ECG_CRNN_CONFIG``.
        Refer to corresponding config files.

    References
    ----------
    .. [1] Yao, Qihang, et al. "Time-Incremental Convolutional Neural Network for Arrhythmia
           Detection in Varied-Length Electrocardiogram." 2018 IEEE 16th Intl Conf on Dependable,
           Autonomic and Secure Computing, 16th Intl Conf on Pervasive Intelligence and Computing,
           4th Intl Conf on Big Data Intelligence and Computing and Cyber Science and Technology
           Congress (DASC/PiCom/DataCom/CyberSciTech). IEEE, 2018.
    .. [2] Yao, Qihang, et al. "Multi-class Arrhythmia detection from 12-lead varied-length ECG
           using Attention-based Time-Incremental Convolutional Neural Network."
           Information Fusion 53 (2020): 174-182.
    .. [3] Hannun, Awni Y., et al. "Cardiologist-level arrhythmia detection and classification
           in ambulatory electrocardiograms using a deep neural network."
           Nature medicine 25.1 (2019): 65.
    .. [4] https://stanfordmlgroup.github.io/projects/ecg2/
    .. [5] https://github.com/awni/ecg
    .. [6] CPSC2018 entry 0236
    .. [7] CPSC2019 entry 0416

    """

    __name__ = "ECG_CRNN_v1"

    def __init__(
        self,
        classes: Sequence[str],
        n_leads: int,
        config: Optional[CFG] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        self.classes = list(classes)
        self.n_classes = len(classes)
        self.n_leads = n_leads
        # work on a private copy of the defaults, then overlay a copy of the user config
        self.config = deepcopy(ECG_CRNN_CONFIG)
        if not config:
            warnings.warn("No config is provided, using default config.", RuntimeWarning)
        self.config.update(deepcopy(config) or {})

        # CNN backbone, output: (batch_size, channels, seq_len)
        cnn_choice = self.config.cnn.name.lower()
        cnn_config = self.config.cnn[self.config.cnn.name]
        if "resnet" in cnn_choice or "resnext" in cnn_choice:
            self.cnn = ResNet(self.n_leads, **cnn_config)
        elif "regnet" in cnn_choice:
            self.cnn = RegNet(self.n_leads, **cnn_config)
        elif "multi_scopic" in cnn_choice:
            self.cnn = MultiScopicCNN(self.n_leads, **cnn_config)
        elif "mobile_net" in cnn_choice or "mobilenet" in cnn_choice:
            if "v1" in cnn_choice:
                self.cnn = MobileNetV1(self.n_leads, **cnn_config)
            elif "v2" in cnn_choice:
                self.cnn = MobileNetV2(self.n_leads, **cnn_config)
            elif "v3" in cnn_choice:
                self.cnn = MobileNetV3(self.n_leads, **cnn_config)
            else:
                raise ValueError(f"CNN \042{cnn_choice}\042 is not supported for {self.__name__}")
        elif "densenet" in cnn_choice or "dense_net" in cnn_choice:
            self.cnn = DenseNet(self.n_leads, **cnn_config)
        elif "vgg16" in cnn_choice:
            self.cnn = VGG16(self.n_leads, **cnn_config)
        elif "xception" in cnn_choice:
            self.cnn = Xception(self.n_leads, **cnn_config)
        else:
            raise NotImplementedError(f"CNN \042{cnn_choice}\042 not implemented yet")
        rnn_input_size = self.cnn.compute_output_shape(None, None)[1]

        # RNN stage (optional; `None` when disabled)
        if self.config.rnn.name.lower() == "none":
            self.rnn = None
            attn_input_size = rnn_input_size
        elif self.config.rnn.name.lower() == "lstm":
            self.rnn = StackedLSTM(
                input_size=rnn_input_size,
                hidden_sizes=self.config.rnn.lstm.hidden_sizes,
                bias=self.config.rnn.lstm.bias,
                dropouts=self.config.rnn.lstm.dropouts,
                bidirectional=self.config.rnn.lstm.bidirectional,
                return_sequences=self.config.rnn.lstm.retseq,
            )
            attn_input_size = self.rnn.compute_output_shape(None, None)[-1]
        elif self.config.rnn.name.lower() == "linear":
            # abuse of notation, to put before the global attention module
            self.rnn = MLP(
                in_channels=rnn_input_size,
                out_channels=self.config.rnn.linear.out_channels,
                activation=self.config.rnn.linear.activation,
                bias=self.config.rnn.linear.bias,
                dropouts=self.config.rnn.linear.dropouts,
            )
            attn_input_size = self.rnn.compute_output_shape(None, None)[-1]
        else:
            raise NotImplementedError(f"RNN \042{self.config.rnn.name}\042 not implemented yet")

        # attention stage (optional; `None` when disabled or ignored)
        if self.config.rnn.name.lower() == "lstm" and not self.config.rnn.lstm.retseq:
            self.attn = None
            clf_input_size = attn_input_size
            if self.config.attn.name.lower() != "none":
                warnings.warn(
                    f"since `retseq` of rnn is False, hence attention `{self.config.attn.name}` is ignored",
                    RuntimeWarning,
                )
        elif self.config.attn.name.lower() == "none":
            self.attn = None
            clf_input_size = attn_input_size
        elif self.config.attn.name.lower() == "nl":  # non_local
            self.attn = NonLocalBlock(
                in_channels=attn_input_size,
                filter_lengths=self.config.attn.nl.filter_lengths,
                subsample_length=self.config.attn.nl.subsample_length,
                batch_norm=self.config.attn.nl.batch_norm,
            )
            clf_input_size = self.attn.compute_output_shape(None, None)[1]
        elif self.config.attn.name.lower() == "se":  # squeeze_exitation
            self.attn = SEBlock(
                in_channels=attn_input_size,
                reduction=self.config.attn.se.reduction,
                activation=self.config.attn.se.activation,
                kw_activation=self.config.attn.se.kw_activation,
                bias=self.config.attn.se.bias,
            )
            clf_input_size = self.attn.compute_output_shape(None, None)[1]
        elif self.config.attn.name.lower() == "gc":  # global_context
            self.attn = GlobalContextBlock(
                in_channels=attn_input_size,
                ratio=self.config.attn.gc.ratio,
                reduction=self.config.attn.gc.reduction,
                pooling_type=self.config.attn.gc.pooling_type,
                fusion_types=self.config.attn.gc.fusion_types,
            )
            clf_input_size = self.attn.compute_output_shape(None, None)[1]
        elif self.config.attn.name.lower() == "sa":  # self_attention
            # NOTE: this branch NOT tested
            self.attn = SelfAttention(
                embed_dim=attn_input_size,
                num_heads=self.config.attn.sa.get("num_heads", self.config.attn.sa.get("head_num")),
                dropout=self.config.attn.sa.dropout,
                bias=self.config.attn.sa.bias,
            )
            clf_input_size = self.attn.compute_output_shape(None, None)[-1]
        elif self.config.attn.name.lower() == "transformer":
            self.attn = Transformer(
                input_size=attn_input_size,
                hidden_size=self.config.attn.transformer.hidden_size,
                num_layers=self.config.attn.transformer.num_layers,
                num_heads=self.config.attn.transformer.num_heads,
                dropout=self.config.attn.transformer.dropout,
                activation=self.config.attn.transformer.activation,
            )
            clf_input_size = self.attn.compute_output_shape(None, None)[-1]
        else:
            raise NotImplementedError(f"Attention \042{self.config.attn.name}\042 not implemented yet")

        # global pooling (optional; `None` when disabled or ignored)
        if self.config.rnn.name.lower() == "lstm" and not self.config.rnn.lstm.retseq:
            self.pool = None
            if self.config.global_pool.lower() != "none":
                warnings.warn(
                    f"since `retseq` of rnn is False, hence global pooling `{self.config.global_pool}` is ignored",
                    RuntimeWarning,
                )
        elif self.config.global_pool.lower() == "max":
            self.pool = nn.AdaptiveMaxPool1d((self.config.global_pool_size,), return_indices=False)
            clf_input_size *= self.config.global_pool_size
        elif self.config.global_pool.lower() == "avg":
            self.pool = nn.AdaptiveAvgPool1d((self.config.global_pool_size,))
            clf_input_size *= self.config.global_pool_size
        elif self.config.global_pool.lower() == "attn":
            raise NotImplementedError("Attentive pooling not implemented yet!")
        elif self.config.global_pool.lower() == "none":
            self.pool = None
        else:
            raise NotImplementedError(f"Global Pooling \042{self.config.global_pool}\042 not implemented yet!")

        # input of `self.clf` has shape: batch_size, channels
        self.clf = MLP(
            in_channels=clf_input_size,
            out_channels=self.config.clf.out_channels + [self.n_classes],
            activation=self.config.clf.activation,
            bias=self.config.clf.bias,
            dropouts=self.config.clf.dropouts,
            skip_last_activation=True,
        )

        # for inference
        # classification: if single-label, use softmax; otherwise (multi-label) use sigmoid
        # sequence tagging: if background counted in `classes`, use softmax; otherwise use sigmoid
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(-1)

    def extract_features(self, input: Tensor) -> Tensor:
        """Extract feature map before the dense (linear) classifying layer(s).

        Parameters
        ----------
        input : torch.Tensor
            Input signal tensor,
            of shape ``(batch_size, channels, seq_len)``.

        Returns
        -------
        features : torch.Tensor
            Feature map tensor,
            of shape ``(batch_size, channels, seq_len)`` or ``(batch_size, channels)``.

        """
        # CNN
        features = self.cnn(input)  # batch_size, channels, seq_len

        # RNN (optional); afterwards features are (seq_len, batch_size, channels),
        # or (batch_size, channels) for an LSTM with `retseq=False`
        rnn_name = self.config.rnn.name.lower()
        if rnn_name == "lstm":
            # (batch_size, channels, seq_len) --> (seq_len, batch_size, channels)
            features = features.permute(2, 0, 1)
            features = self.rnn(features)  # (seq_len, batch_size, channels) or (batch_size, channels)
        elif rnn_name == "linear":
            # (batch_size, channels, seq_len) --> (batch_size, seq_len, channels)
            features = features.permute(0, 2, 1)
            features = self.rnn(features)  # (batch_size, seq_len, channels)
            # (batch_size, seq_len, channels) --> (seq_len, batch_size, channels)
            features = features.permute(1, 0, 2)
        else:
            # (batch_size, channels, seq_len) --> (seq_len, batch_size, channels)
            features = features.permute(2, 0, 1)

        # Attention (optional); afterwards features are (batch_size, channels, seq_len)
        # or (batch_size, channels)
        if self.attn is None:
            # covers both "no attention" and "attention configured but ignored
            # because `retseq` is False" (features are then 2-D and left as is)
            if features.ndim == 3:
                # (seq_len, batch_size, channels) --> (batch_size, channels, seq_len)
                features = features.permute(1, 2, 0)
            return features

        attn_name = self.config.attn.name.lower()
        if attn_name in ["nl", "se", "gc"]:
            # (seq_len, batch_size, channels) --> (batch_size, channels, seq_len)
            features = features.permute(1, 2, 0)
            features = self.attn(features)  # (batch_size, channels, seq_len)
        elif attn_name == "sa":
            features = self.attn(features)  # (seq_len, batch_size, channels)
            # (seq_len, batch_size, channels) -> (batch_size, channels, seq_len)
            features = features.permute(1, 2, 0)
        elif attn_name == "transformer":
            if getattr(self.attn, "batch_first", False):
                # (seq_len, batch_size, channels) -> (batch_size, seq_len, channels)
                features = self.attn(features.permute(1, 0, 2))
                # (batch_size, seq_len, channels) -> (batch_size, channels, seq_len)
                features = features.permute(0, 2, 1)
            else:
                features = self.attn(features)
                # (seq_len, batch_size, channels) -> (batch_size, channels, seq_len)
                features = features.permute(1, 2, 0)
        return features

    def forward(self, input: Tensor) -> Tensor:
        """Forward pass of the model.

        Parameters
        ----------
        input : torch.Tensor
            Input signal tensor,
            of shape ``(batch_size, channels, seq_len)``.

        Returns
        -------
        pred : torch.Tensor
            Predictions tensor,
            of shape ``(batch_size, seq_len, channels)`` or ``(batch_size, channels)``.

        """
        features = self.extract_features(input)

        if self.pool is not None:
            features = self.pool(features)  # (batch_size, channels, pool_size)
            features = rearrange(
                features,
                "batch_size channels pool_size -> batch_size (channels pool_size)",
            )
        elif features.ndim == 3:
            # (batch_size, channels, seq_len) --> (batch_size, seq_len, channels)
            features = features.permute(0, 2, 1)

        pred = self.clf(features)  # batch_size, n_classes
        return pred

    @torch.no_grad()
    def inference(
        self,
        input: Union[np.ndarray, Tensor],
        class_names: bool = False,
        bin_pred_thr: float = 0.5,
    ) -> BaseOutput:
        """Inference method for the model.

        Task-specific subclasses must override this method;
        the base implementation always raises.

        Parameters
        ----------
        input : numpy.ndarray or torch.Tensor
            Input tensor, of shape ``(batch_size, channels, seq_len)``.
        class_names : bool, default False
            If True, the returned scalar predictions will be a :class:`~pandas.DataFrame`,
            with class names for each scalar prediction.
        bin_pred_thr : float, default 0.5
            Threshold for making binary predictions from scalar predictions.

        Returns
        -------
        output : BaseOutput
            The output of the inference method, including the following items:

                - prob: numpy.ndarray or torch.Tensor,
                  scalar predictions, (and binary predictions if `class_names` is True).
                - pred: numpy.ndarray or torch.Tensor,
                  the array (with values 0, 1 for each class) of binary prediction.

        Raises
        ------
        NotImplementedError
            Always, in this base class.

        """
        raise NotImplementedError("implement a task specific inference method")

    def compute_output_shape(
        self, seq_len: Optional[int] = None, batch_size: Optional[int] = None
    ) -> Sequence[Union[int, None]]:
        """Compute the output shape of the model.

        Parameters
        ----------
        seq_len : int, optional
            Length of the input signal tensor.
        batch_size : int, optional
            Batch size of the input signal tensor.

        Returns
        -------
        output_shape : sequence
            Output shape of the model.

        """
        if self.pool is not None:
            return (batch_size, len(self.classes))

        rnn_name = self.config.rnn.name.lower()
        if rnn_name == "lstm" and not self.config.rnn.lstm.retseq:
            # the LSTM drops the sequence axis, so the classifier receives
            # (batch_size, channels) and produces a non-sequence output
            return self.clf.compute_output_shape(None, batch_size, input_seq=False)

        output_shape = self.cnn.compute_output_shape(seq_len, batch_size)
        _, _, _seq_len = output_shape
        if self.rnn is not None:
            output_shape = self.rnn.compute_output_shape(_seq_len, batch_size)
            # StackedLSTM output: (seq_len, batch_size, channels);
            # MLP ("linear") output: (batch_size, seq_len, channels)
            _seq_len = output_shape[1 if rnn_name == "linear" else 0]
        if self.attn is not None:
            output_shape = self.attn.compute_output_shape(_seq_len, batch_size)
            attn_name = self.config.attn.name.lower()
            if attn_name == "sa":
                # SelfAttention output: (seq_len, batch_size, channels)
                attn_seqlen_dim = 0
            elif attn_name == "transformer":
                attn_seqlen_dim = 1 if getattr(self.attn, "batch_first", False) else 0
            else:
                # nl / se / gc output: (batch_size, channels, seq_len)
                attn_seqlen_dim = -1
            _seq_len = output_shape[attn_seqlen_dim]
        return self.clf.compute_output_shape(_seq_len, batch_size)

    @property
    def doi(self) -> List[str]:
        """DOIs collected from ``self.config`` (and all nested ``CFG`` sections),
        plus the DOIs of the papers this model is based on.

        Duplicates are removed while keeping first-seen order, so the result
        is deterministic across interpreter runs.
        """
        doi = []
        # breadth-first walk over the (nested) config sections
        candidates = [self.config]
        while len(candidates) > 0:
            new_candidates = []
            for candidate in candidates:
                if hasattr(candidate, "doi"):
                    if isinstance(candidate.doi, str):
                        doi.append(candidate.doi)
                    else:
                        doi.extend(list(candidate.doi))
                for _, v in candidate.items():
                    if isinstance(v, CFG):
                        new_candidates.append(v)
            candidates = new_candidates
        # `dict.fromkeys` dedups while preserving order; `set` iteration order
        # of strings varies between runs because of hash randomization
        return list(dict.fromkeys(doi + ["10.1016/j.inffus.2019.06.024", "10.1088/1361-6579/ac6aa3"]))