Source code for torch_ecg.components.inputs
"""
"""
import inspect
import math
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import List, Sequence, Tuple, Union
import numpy as np
import torch
from einops.layers.torch import Rearrange
from torch.nn import functional as F
from ..cfg import CFG, DEFAULTS
from ..utils.misc import ReprMixin, add_docstring
from ..utils.utils_nn import compute_conv_output_shape
from ..utils.utils_signal_t import Spectrogram
# Public API of this module. The abstract bases `BaseInput` and
# `_SpectralInput` are defined below but are not listed for export.
__all__ = [
"InputConfig",
"WaveformInput",
"FFTInput",
"SpectrogramInput",
]
class InputConfig(CFG):
    """Configuration container describing the input of a model.

    Parameters
    ----------
    input_type : {"waveform", "fft", "spectrogram"}
        Type of the input. Matching is case-insensitive;
        the value is stored in lower case.
    n_channels : int
        Number of channels of the input, must be positive.
    n_samples : int, default -1
        Number of samples of the input, must be positive or -1.
    ensure_batch_dim : bool, default True
        Whether to ensure the transformed input has a batch dimension.
    *args, **kwargs
        Forwarded unchanged to the parent configuration class.
        For ``input_type="spectrogram"``, the keys ``n_bins`` and
        one of ``fs`` / ``sample_rate`` must be provided here.

    Examples
    --------
    .. code-block:: python

        input_config = InputConfig(
            input_type="waveform",
            n_channels=12,
            n_samples=5000,
        )

    """

    __name__ = "InputConfig"

    def __init__(
        self,
        *args: Union[CFG, dict],
        input_type: str,
        n_channels: int,
        n_samples: int = -1,
        ensure_batch_dim: bool = True,
        **kwargs: dict,
    ) -> None:
        # Collect the explicitly named fields first, then hand everything
        # over to the parent class in the same keyword order as before.
        named_fields = dict(
            input_type=input_type,
            n_channels=n_channels,
            n_samples=n_samples,
            ensure_batch_dim=ensure_batch_dim,
        )
        super().__init__(*args, **named_fields, **kwargs)

        # --- basic sanity checks on the mandatory fields ---
        has_valid_channels = "n_channels" in self and self.n_channels > 0
        assert has_valid_channels, f"`n_channels` must be positive, got {self.n_channels}"

        has_valid_samples = "n_samples" in self and (self.n_samples == -1 or self.n_samples > 0)
        assert has_valid_samples, f"`n_samples` must be positive or -1, got {self.n_samples}"

        has_valid_type = "input_type" in self and self.input_type.lower() in ("waveform", "fft", "spectrogram")
        assert (
            has_valid_type
        ), f"`input_type` must be one of ['waveform', 'fft', 'spectrogram'], got {self.input_type}"

        # normalize the type name so later comparisons can be exact
        self.input_type = self.input_type.lower()

        # --- additional requirements of spectro-temporal inputs ---
        if self.input_type == "spectrogram":
            assert "n_bins" in self, f"`n_bins` must be specified for {self.input_type} input"
            has_sampling_rate = "fs" in self or "sample_rate" in self
            assert has_sampling_rate, f"`fs` or `sample_rate` must be specified for {self.input_type} input"
class BaseInput(ReprMixin, ABC):
    """Base class for all input classes.

    Subclasses must implement :meth:`_post_init` (called at the end of
    ``__init__``) and :meth:`_from_waveform` (the actual transformation).

    Parameters
    ----------
    config : InputConfig
        The configuration of the input.
        A deep copy is stored, so later changes to `config`
        do not affect this instance.

    """

    __name__ = "BaseInput"

    def __init__(self, config: InputConfig) -> None:
        """Store a copy of `config`, read dtype/device, and call `_post_init`."""
        assert isinstance(config, InputConfig), "`config` must be an instance of `InputConfig`"
        self._config = deepcopy(config)
        # holds the result of the most recent transformation (None until first call)
        self._values = None
        self._dtype = self._config.get("dtype", DEFAULTS.DTYPE.TORCH)
        self._device = self._config.get("device", DEFAULTS.device)
        self._post_init()

    def __call__(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        """Method to transform the waveform to the input tensor.

        Parameters
        ----------
        waveform : numpy.ndarray or torch.Tensor
            The waveform to be transformed.

        Returns
        -------
        torch.Tensor
            The transformed waveform.

        """
        return self.from_waveform(waveform)

    @abstractmethod
    def _from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        """Internal method to convert the waveform to the input tensor."""
        raise NotImplementedError

    def from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        """Transform the waveform to the input tensor.

        Parameters
        ----------
        waveform : numpy.ndarray or torch.Tensor
            The waveform to be transformed, of shape
            ``(batch_size, n_channels, n_samples)`` or ``(n_channels, n_samples)``.

        Returns
        -------
        torch.Tensor
            The transformed waveform. If `waveform` is 2D and
            ``ensure_batch_dim`` is set in the config,
            a leading batch dimension of size 1 is added.

        Raises
        ------
        AssertionError
            If the trailing dimensions of `waveform` do not match the config.

        """
        if self.n_samples > 0:
            shape_ok = tuple(waveform.shape[-2:]) == (self.n_channels, self.n_samples)
            expected_samples = self.n_samples
        else:
            # `InputConfig` allows `n_samples == -1` (its default value); in that
            # case no fixed length can be enforced, so only the channel dimension
            # is validated. Comparing against -1 would reject every waveform.
            shape_ok = waveform.ndim >= 2 and waveform.shape[-2] == self.n_channels
            expected_samples = "n_samples"
        assert shape_ok, (
            f"`waveform` shape must be `(batch_size, {self.n_channels}, {expected_samples})` "
            f"or `({self.n_channels}, {expected_samples})`, got `{waveform.shape}`"
        )
        input_tensors = self._from_waveform(waveform)
        if waveform.ndim == 2 and self._config.ensure_batch_dim:
            input_tensors = input_tensors.unsqueeze(0)
        return input_tensors

    @abstractmethod
    def _post_init(self) -> None:
        """Method to be called after initialization"""
        raise NotImplementedError

    @property
    def values(self) -> torch.Tensor:
        """Result stored by the last `_from_waveform` call, or None if never called."""
        return self._values

    @property
    def n_channels(self) -> int:
        """Number of channels of the waveform, as given in the config."""
        return self._config.n_channels

    @property
    def n_samples(self) -> int:
        """Number of samples of the waveform, as given in the config."""
        return self._config.n_samples

    @property
    def input_channels(self) -> int:
        """Size of the channel dimension of the transformed input.

        Read from the stored values if available, otherwise
        computed from the configured waveform shape.
        """
        channel_dim = {
            "waveform": -2,
            "fft": -2,
            # "spectrogram": -3,  # implemented in `SpectrogramInput`
        }
        if self.values is not None:
            return self.values.shape[channel_dim[self.input_type]]
        return self.compute_input_shape((self.n_channels, self.n_samples))[channel_dim[self.input_type]]

    @property
    def input_samples(self) -> int:
        """Size of the last dimension of the transformed input.

        Read from the stored values if available, otherwise
        computed from the configured waveform shape.
        """
        if self.values is not None:
            return self.values.shape[-1]
        return self.compute_input_shape((self.n_channels, self.n_samples))[-1]

    @property
    def input_type(self) -> str:
        """Type of the input, as given in the config."""
        return self._config.input_type

    @property
    def dtype(self) -> torch.dtype:
        """Data type the waveform is converted to."""
        return self._dtype

    @property
    def device(self) -> torch.device:
        """Device the waveform is moved to."""
        return self._device

    def compute_input_shape(self, waveform_shape: Union[Sequence[int], torch.Size]) -> Tuple[Union[type(None), int], ...]:
        """Computes the input shape of the model based on
        the input type and the waveform shape.

        Parameters
        ----------
        waveform_shape : Sequence[int] or torch.Size
            The shape of the waveform, with or without the batch dimension.

        Returns
        -------
        tuple
            The input shape of the model. If `waveform_shape` has
            2 dimensions and ``ensure_batch_dim`` is set in the config,
            a leading dimension of size 1 is prepended.

        """
        if self.input_type == "waveform":
            input_shape = tuple(waveform_shape)
        elif self.input_type == "fft":
            # `nfft`, `drop_dc` are set by the FFT subclass in `_post_init`
            nfft = self.nfft or waveform_shape[-1]
            seq_len = torch.fft.rfftfreq(nfft).shape[0]
            if self.drop_dc:
                seq_len -= 1
            # amplitudes and phases are concatenated along the channel dimension
            input_shape = (*waveform_shape[:-2], 2 * waveform_shape[-2], seq_len)
        elif self.input_type == "spectrogram":
            # `win_length`, `hop_length`, etc. are provided by the spectral subclass
            n_samples = compute_conv_output_shape(
                waveform_shape if len(waveform_shape) == 3 else [None] + list(waveform_shape),
                kernel_size=self.win_length,
                stride=self.hop_length,
                asymmetric_padding=[self.hop_length, self.win_length - self.hop_length],
            )[-1]
            if self.feature_fs is not None:
                n_samples = math.floor(n_samples * self.feature_fs / self.fs)
            if self.to1d:
                mid_dims = (self.n_channels * self.n_bins,)
            else:
                mid_dims = (self.n_channels, self.n_bins)
            input_shape = (*waveform_shape[:-2], *mid_dims, n_samples)
        if len(waveform_shape) == 2 and self._config.ensure_batch_dim:
            input_shape = (1, *input_shape)
        return input_shape

    def extra_repr_keys(self) -> List[str]:
        """Names of the attributes listed for the representation of the instance."""
        return ["input_type", "n_channels", "n_samples", "dtype", "device"]
class WaveformInput(BaseInput):
    """Input class that feeds the waveform itself, converted to a tensor.

    Examples
    --------
    >>> from torch_ecg.cfg import DEFAULTS
    >>> BATCH_SIZE = 32
    >>> N_CHANNELS = 12
    >>> N_SAMPLES = 5000
    >>> input_config = InputConfig(
    ...     input_type="waveform",
    ...     n_channels=N_CHANNELS,
    ...     n_samples=N_SAMPLES,
    ... )
    >>> inputer = WaveformInput(input_config)
    >>> waveform = torch.randn(BATCH_SIZE, N_CHANNELS, N_SAMPLES)
    >>> inputer(waveform).shape
    torch.Size([32, 12, 5000])
    >>> waveform = DEFAULTS.RNG.uniform(size=(N_CHANNELS, N_SAMPLES))
    >>> inputer(waveform).shape
    torch.Size([1, 12, 5000])
    >>> input_config = InputConfig(
    ...     input_type="waveform",
    ...     n_channels=N_CHANNELS,
    ...     n_samples=N_SAMPLES,
    ...     ensure_batch_dim=False,
    ... )
    >>> inputer = WaveformInput(input_config)
    >>> waveform = DEFAULTS.RNG.uniform(size=(N_CHANNELS, N_SAMPLES))
    >>> inputer(waveform).shape
    torch.Size([12, 5000])

    """

    __name__ = "WaveformInput"

    def _post_init(self) -> None:
        """Check that the configured input type is `waveform`."""
        assert self.input_type == "waveform", "`input_type` must be `waveform`"

    def _from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        """Convert `waveform` to a tensor on the configured device and dtype.

        The result is also stored and exposed via the ``values`` property.
        """
        as_tensor = torch.as_tensor(waveform)
        self._values = as_tensor.to(self.device, self.dtype)
        return self._values

    def from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        """Converts the input :class:`~numpy.ndarray` or
        :class:`~torch.Tensor` waveform to a :class:`~torch.Tensor`.

        Parameters
        ----------
        waveform : numpy.ndarray or torch.Tensor
            The waveform to be transformed,
            of shape ``(batch_size, n_channels, n_samples)``
            or ``(n_channels, n_samples)``.

        Returns
        -------
        torch.Tensor
            The transformed waveform,
            of shape ``(batch_size, n_channels, n_samples)``.

        NOTE
        ----
        If the input is a 2D tensor and ``ensure_batch_dim`` is set
        in the config, then the batch dimension is added (batch_size = 1);
        otherwise the 2D shape is kept.

        """
        # validation and batch-dimension handling live in the base class
        return super().from_waveform(waveform)

    @add_docstring(from_waveform.__doc__)
    def __call__(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        """ """
        return self.from_waveform(waveform)
class FFTInput(BaseInput):
    """Inputs from the FFT, via concatenating the amplitudes and the phases.

    One can set the following optional parameters for initialization:

        - nfft (alias: n_fft): int
          the number of FFT bins.
          If not given (or None), the configured `n_samples` is used when it is
          positive; otherwise the number of FFT bins is computed from the input shape.
          If both `nfft` and `n_fft` are given, `nfft` takes precedence.
        - drop_dc: bool, default True
          Whether to drop the zero frequency bin (the DC component).
        - norm: str, optional
          The normalization of the FFT, can be

            - "forward"
            - "backward"
            - "ortho"

    Examples
    --------
    >>> from torch_ecg.cfg import DEFAULTS
    >>> BATCH_SIZE = 32
    >>> N_CHANNELS = 12
    >>> N_SAMPLES = 5000
    >>> input_config = InputConfig(
    ...     input_type="fft",
    ...     n_channels=N_CHANNELS,
    ...     n_samples=N_SAMPLES,
    ...     n_fft=200,
    ...     drop_dc=True,
    ...     norm="ortho",
    ... )
    >>> inputer = FFTInput(input_config)
    >>> waveform = torch.randn(BATCH_SIZE, N_CHANNELS, N_SAMPLES)
    >>> inputer(waveform).ndim
    3
    >>> inputer(waveform).shape == inputer.compute_input_shape(waveform.shape)
    True
    >>> waveform = DEFAULTS.RNG.uniform(size=(N_CHANNELS, N_SAMPLES))
    >>> inputer(waveform).ndim
    3
    >>> inputer(waveform).shape == inputer.compute_input_shape(waveform.shape)
    True
    >>> input_config = InputConfig(
    ...     input_type="fft",
    ...     n_channels=N_CHANNELS,
    ...     n_samples=N_SAMPLES,
    ...     n_fft=None,
    ...     drop_dc=False,
    ...     norm="forward",
    ...     ensure_batch_dim=False,
    ... )
    >>> inputer = FFTInput(input_config)
    >>> waveform = DEFAULTS.RNG.uniform(size=(N_CHANNELS, N_SAMPLES))
    >>> inputer(waveform).ndim
    2
    >>> inputer(waveform).shape == inputer.compute_input_shape(waveform.shape)
    True

    """

    __name__ = "FFTInput"

    def _post_init(self) -> None:
        """Make sure the input type is `fft` and set the parameters."""
        assert self.input_type == "fft", "`input_type` must be `fft`"
        # The examples of this class configure the FFT size via `n_fft`, while
        # the option is documented as `nfft`. Accept both spellings so that
        # `n_fft` is no longer silently ignored; `nfft` wins if both are set.
        self.nfft = self._config.get("nfft", None)
        if self.nfft is None:
            self.nfft = self._config.get("n_fft", None)
        if self.nfft is None and self.n_samples > 0:
            self.nfft = self.n_samples
        self.drop_dc = self._config.get("drop_dc", True)
        self.norm = self._config.get("norm", None)
        if self.norm is not None:
            assert self.norm in [
                "forward",
                "backward",
                "ortho",
            ], f"`norm` must be one of [`forward`, `backward`, `ortho`], got {self.norm}"

    def _from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        """Internal method to convert the waveform to the input tensor.

        The real FFT is taken along the last dimension; amplitudes and
        phases are then concatenated along the channel dimension (``dim=-2``).
        The result is also stored and exposed via the ``values`` property.
        """
        self._values = torch.fft.rfft(
            torch.as_tensor(waveform).to(self.device, self.dtype),
            n=self.nfft,
            dim=-1,
            norm=self.norm,
        )
        if self.drop_dc:
            # index 0 along the frequency axis is the zero-frequency (DC) bin
            self._values = self._values[..., 1:]
        self._values = torch.cat([torch.abs(self._values), torch.angle(self._values)], dim=-2)
        return self._values

    def from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        """Converts the input :class:`~numpy.ndarray` or
        :class:`~torch.Tensor` waveform to a :class:`~torch.Tensor` of FFTs.

        Parameters
        ----------
        waveform : numpy.ndarray or torch.Tensor
            The waveform to be transformed,
            of shape ``(batch_size, n_channels, n_samples)``
            or ``(n_channels, n_samples)``.

        Returns
        -------
        torch.Tensor
            The transformed waveform,
            of shape ``(batch_size, 2 * n_channels, seq_len)``,
            where `seq_len` is computed via :code:`torch.fft.rfftfreq(nfft).shape[0]`,
            if `drop_dc` is True, then seq_len is reduced by 1

        NOTE
        ----
        If the input is a 2D tensor and ``ensure_batch_dim`` is set
        in the config, then the batch dimension is added (batch_size = 1);
        otherwise the output is 2D as well.

        """
        return super().from_waveform(waveform)

    # the docstring of `from_waveform` is attached by the decorator
    @add_docstring(from_waveform.__doc__)
    def __call__(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        return self.from_waveform(waveform)

    def extra_repr_keys(self) -> List[str]:
        """Extend the base representation keys with the FFT parameters."""
        return super().extra_repr_keys() + ["nfft", "drop_dc", "norm"]
class _SpectralInput(BaseInput):
    """Inputs from the spectro-temporal domain.

    One has to set the following parameters for initialization:

        - n_bins : int
          The number of frequency bins.
        - fs (or sample_rate) : int
          The sample rate of the waveform.

    with the following optional parameters with default values:

        - window_size : float, default: 1 / 20
          The size of the window in seconds.
        - overlap_size : float, default: 1 / 40
          The overlap of the windows in seconds.
        - feature_fs : None or float,
          The sample rate of the features.
          If specified, the features will be resampled
          against `fs` to this sample rate.
        - to1d : bool, default False
          Whether to convert the features to 1D.
          NOTE that if `to1d` is True,
          then if the convolutions with ``groups=1`` applied to the `input`
          acts on all the bins, which is "global"
          w.r.t. the `bins` dimension of the corresponding 2d input.

    """

    __name__ = "_SpectralInput"

    def _post_init(self) -> None:
        """Read the spectral parameters from the config and validate them.

        Missing `window_size` / `overlap_size` entries are filled with their
        defaults in the (copied) config held by this instance.
        """
        self.to1d = self._config.get("to1d", False)
        # `fs` takes precedence over `sample_rate` when both are given
        self.fs = self._config.get("fs", self._config.get("sample_rate"))
        self.feature_fs = self._config.get("feature_fs", None)
        if "window_size" not in self._config:
            self._config.window_size = 1 / 20
        assert 0 < self._config.window_size < 0.2, f"`window_size` must be in (0, 0.2), got {self._config.window_size}"
        if "overlap_size" not in self._config:
            self._config.overlap_size = 1 / 40
        # TODO: consider negative overlap_size, i.e. positive gaps between windows
        assert 0 < self._config.overlap_size < self._config.window_size, (
            f"`overlap_size` must be in `(0, window_size)` = {(0, self._config.window_size)}, "
            f"got {self._config.overlap_size}"
        )

    @property
    def n_bins(self) -> int:
        """Number of frequency bins, as given in the config."""
        return self._config.n_bins

    @property
    def window_size(self) -> int:
        """Window size in samples (configured seconds times `fs`, rounded)."""
        return round(self._config.window_size * self.fs)

    @property
    def win_length(self) -> int:
        """Alias of `window_size` (in samples)."""
        return self.window_size

    @property
    def overlap_size(self) -> int:
        """Overlap of consecutive windows in samples (seconds times `fs`, rounded)."""
        return round(self._config.overlap_size * self.fs)

    @property
    def hop_length(self) -> int:
        """Step between consecutive windows in samples."""
        return self.window_size - self.overlap_size

    @property
    def input_channels(self) -> int:
        """Size of the channel dimension of the transformed input.

        The channel dimension is at index -2 if `to1d` is True
        (channels and bins are merged), otherwise at index -3.
        """
        channel_dim = -2 if self.to1d else -3
        if self.values is not None:
            return self.values.shape[channel_dim]
        return self.compute_input_shape((self.n_channels, self.n_samples))[channel_dim]

    @property
    def input_samples(self) -> Tuple[int, ...]:
        """Sizes of the non-channel dimensions of the transformed input.

        A 1-tuple (last dimension) if `to1d` is True, otherwise a
        2-tuple (last two dimensions). Read from the stored values if
        available, otherwise computed from the configured waveform shape.
        """
        sample_dim = (-1,) if self.to1d else (-2, -1)
        if self.values is not None:
            input_shape = self.values.shape
        else:
            # Only fall back to the configured shape when nothing has been
            # transformed yet; previously this assignment ran unconditionally
            # and discarded the shape of the stored values.
            input_shape = self.compute_input_shape((self.n_channels, self.n_samples))
        return tuple(input_shape[dim] for dim in sample_dim)

    def extra_repr_keys(self) -> List[str]:
        """Extend the base representation keys with the spectral parameters."""
        return super().extra_repr_keys() + [
            "n_bins",
            "win_length",
            "hop_length",
            "fs",
            "feature_fs",
            "to1d",
        ]
class SpectrogramInput(_SpectralInput):
    # The class docstring extends the parent's parameter description with an
    # example. The example uses the keyword `input_type`, which `InputConfig`
    # requires (it was previously written as `name`, which raises a TypeError).
    __doc__ = (
        _SpectralInput.__doc__
        + """
    Examples
    --------
    >>> from torch_ecg.cfg import DEFAULTS
    >>> BATCH_SIZE = 32
    >>> N_CHANNELS = 12
    >>> N_SAMPLES = 5000
    >>> input_config = InputConfig(
    ...     input_type="spectrogram",
    ...     n_channels=N_CHANNELS,
    ...     n_samples=N_SAMPLES,
    ...     n_bins=128,
    ...     fs=500,
    ...     window_size=1 / 20,
    ...     overlap_size=1 / 40,
    ...     feature_fs=100,
    ...     to1d=True,
    ... )
    >>> inputer = SpectrogramInput(input_config)
    >>> waveform = torch.randn(BATCH_SIZE, N_CHANNELS, N_SAMPLES)
    >>> spectrogram = inputer(waveform)
    >>> spectrogram.shape == inputer.compute_input_shape(waveform.shape)
    True

    """
    )

    __name__ = "SpectrogramInput"

    def _post_init(self) -> None:
        """Make sure the input type is `spectrogram` and set the parameters."""
        super()._post_init()
        assert self.input_type in ["spectrogram"], f"`input_type` must be one of [`spectrogram`], got {self.input_type}"
        # Forward every config entry that matches a parameter of `Spectrogram`,
        # except the ones that are derived from this instance below.
        args = inspect.getfullargspec(Spectrogram.__init__).args
        for k in ["self", "n_fft", "win_length", "hop_length"]:
            args.remove(k)
        kwargs = {k: self._config[k] for k in args if k in self._config}
        kwargs["n_fft"] = (self.n_bins - 1) * 2
        kwargs["win_length"] = self.win_length
        kwargs["hop_length"] = self.hop_length
        self._transform = torch.nn.Sequential()
        self._transform.add_module("spectrogram", Spectrogram(**kwargs).to(self.device, self.dtype))
        if self.to1d:
            # merge the channel and frequency-bin dimensions into one
            self._transform.add_module(
                "to1d",
                Rearrange("... channel n_bins time -> ... (channel n_bins) time"),
            )

    def _from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        """Internal method to convert the waveform to the input tensor.

        Applies the spectrogram transform and, if `feature_fs` is set,
        rescales the last (time) dimension by ``feature_fs / fs``.
        The result is also stored and exposed via the ``values`` property.
        """
        self._values = self._transform(torch.as_tensor(waveform).to(self.device, self.dtype))
        if self.feature_fs is not None:
            # self.values.ndim can be 2, 3, or 4
            # one scale factor per dimension after the first two; only the
            # last one is actually rescaled (for ndim == 2 the list product
            # with a negative count is empty, leaving a single factor)
            scale_factor = [1] * (self.values.ndim - 3) + [self.feature_fs / self.fs]
            if self.values.ndim == 2:
                # `F.interpolate` is given a 3D tensor here: add a temporary
                # leading dimension and remove it again afterwards
                self._values = F.interpolate(
                    self._values.unsqueeze(0),
                    scale_factor=scale_factor,
                    recompute_scale_factor=True,
                ).squeeze(0)
            else:
                self._values = F.interpolate(self._values, scale_factor=scale_factor, recompute_scale_factor=True)
        return self._values

    def from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        r"""Converts the input :class:`~numpy.ndarray` or
        :class:`~torch.Tensor` waveform to a :class:`~torch.Tensor` of spectrograms.

        Parameters
        ----------
        waveform : numpy.ndarray or torch.Tensor
            The waveform to be transformed,
            of shape ``(batch_size, n_channels, n_samples)``
            or ``(n_channels, n_samples)``.

        Returns
        -------
        torch.Tensor
            The transformed waveform,
            of shape ``(batch_size, n_channels, n_bins, n_frames)``, where

            .. math::

                n\_frames = (n\_samples - win\_length) // hop\_length + 1

            If `to1d` is True, the channel and bin dimensions are merged,
            giving shape ``(batch_size, n_channels * n_bins, n_frames)``.
            If `feature_fs` is set, the last dimension is further
            rescaled by ``feature_fs / fs``.

        NOTE
        ----
        If the input is a 2D tensor and ``ensure_batch_dim`` is set
        in the config, then the batch dimension is added (batch_size = 1).

        """
        return super().from_waveform(waveform)

    # the docstring of `from_waveform` is attached by the decorator
    @add_docstring(from_waveform.__doc__)
    def __call__(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        return self.from_waveform(waveform)