FFTInput

class torch_ecg.components.FFTInput(config: InputConfig)

Bases: BaseInput

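Input transform that computes the FFT of the waveform and concatenates the amplitudes and the phases along the channel dimension, so the output has 2 * n_channels channels.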

The following optional parameters can be set in the config at initialization:
  • nfft: int

    The number of FFT bins. If nfft is None, it is computed from the input shape.

  • drop_dc: bool, default True

    Whether to drop the zero frequency bin (the DC component).

  • norm: str, optional
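    The normalization mode of the FFT, using the same conventions as torch.fft; one of
    • “forward”: the forward transform is scaled by 1/n

    • “backward”: no scaling of the forward transform (the torch.fft default)

    • “ortho”: the forward transform is scaled by 1/sqrt(n)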
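For intuition, here is a minimal plain-PyTorch sketch of the transform described above. It is not torch_ecg's implementation: the helper name `fft_amplitude_phase`, the fallback to the signal length when `nfft` is None, and the amplitudes-then-phases ordering are assumptions. `norm` is forwarded unchanged to `torch.fft.rfft`.

```python
from typing import Optional

import torch


def fft_amplitude_phase(
    waveform: torch.Tensor,
    nfft: Optional[int] = None,
    drop_dc: bool = True,
    norm: Optional[str] = None,
) -> torch.Tensor:
    """Hypothetical helper: rFFT amplitudes and phases stacked along the channel axis.

    waveform: (batch_size, n_channels, n_samples)
    returns:  (batch_size, 2 * n_channels, seq_len)
    """
    # Assumption: fall back to the signal length when nfft is not given.
    if nfft is None:
        nfft = waveform.shape[-1]
    # Real FFT along time; the signal is zero-padded or truncated to nfft samples.
    spectrum = torch.fft.rfft(waveform, n=nfft, dim=-1, norm=norm)
    if drop_dc:
        # Bin 0 is the zero-frequency (DC) component.
        spectrum = spectrum[..., 1:]
    amplitude = spectrum.abs()
    phase = spectrum.angle()
    # Assumption: amplitudes first, then phases, concatenated along channels.
    return torch.cat([amplitude, phase], dim=-2)


x = torch.randn(32, 12, 5000)
out = fft_amplitude_phase(x, nfft=200, drop_dc=True, norm="ortho")
print(out.shape)  # torch.Size([32, 24, 100])
```

With nfft=200 the real FFT yields 101 bins; dropping the DC bin leaves 100, and the 12 amplitude channels plus 12 phase channels give 24.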

Examples

>>> import torch
>>> from torch_ecg.cfg import DEFAULTS
>>> from torch_ecg.components.inputs import FFTInput, InputConfig
>>> BATCH_SIZE = 32
>>> N_CHANNELS = 12
>>> N_SAMPLES = 5000
>>> input_config = InputConfig(
...     input_type="fft",
...     n_channels=N_CHANNELS,
...     n_samples=N_SAMPLES,
...     n_fft=200,
...     drop_dc=True,
...     norm="ortho",
... )
>>> inputer = FFTInput(input_config)
>>> waveform = torch.randn(BATCH_SIZE, N_CHANNELS, N_SAMPLES)
>>> inputer(waveform).ndim
3
>>> inputer(waveform).shape == inputer.compute_input_shape(waveform.shape)
True
>>> waveform = DEFAULTS.RNG.uniform(size=(N_CHANNELS, N_SAMPLES))
>>> inputer(waveform).ndim
3
>>> inputer(waveform).shape == inputer.compute_input_shape(waveform.shape)
True
>>> input_config = InputConfig(
...     input_type="fft",
...     n_channels=N_CHANNELS,
...     n_samples=N_SAMPLES,
...     n_fft=None,
...     drop_dc=False,
...     norm="forward",
...     ensure_batch_dim=False,
... )
>>> inputer = FFTInput(input_config)
>>> waveform = DEFAULTS.RNG.uniform(size=(N_CHANNELS, N_SAMPLES))
>>> inputer(waveform).ndim
2
>>> inputer(waveform).shape == inputer.compute_input_shape(waveform.shape)
True

extra_repr_keys() → List[str]

Extra keys for __repr__() and __str__().

from_waveform(waveform: ndarray | Tensor) → Tensor

Converts the input waveform (ndarray or Tensor) to a Tensor of FFT amplitudes and phases.

Parameters:

waveform (numpy.ndarray or torch.Tensor) – The waveform to be transformed, of shape (batch_size, n_channels, n_samples) or (n_channels, n_samples).

Returns:

The transformed waveform, of shape (batch_size, 2 * n_channels, seq_len), where seq_len = torch.fft.rfftfreq(nfft).shape[0], i.e. nfft // 2 + 1. If drop_dc is True, seq_len is reduced by 1.

Return type:

torch.Tensor
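The length rule for seq_len can be checked numerically. This is a minimal sketch; `fft_seq_len` is a hypothetical helper, not part of torch_ecg, and `nfft` here means the resolved transform length (how nfft is derived from the input shape when it is None is not modeled).

```python
import torch


def fft_seq_len(nfft: int, drop_dc: bool = True) -> int:
    """Hypothetical helper: size of the last output axis described under Returns."""
    # A real FFT of length nfft has nfft // 2 + 1 non-negative frequency bins,
    # which is exactly the length of torch.fft.rfftfreq(nfft).
    seq_len = torch.fft.rfftfreq(nfft).shape[0]
    # Dropping the DC (zero-frequency) bin removes one bin.
    return seq_len - 1 if drop_dc else seq_len


print(fft_seq_len(200))                  # 100
print(fft_seq_len(200, drop_dc=False))   # 101
print(fft_seq_len(5000, drop_dc=False))  # 2501
```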

Note

If the input is 2D, i.e. of shape (n_channels, n_samples), a batch dimension of size 1 is added, unless ensure_batch_dim is set to False in the config.
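A minimal sketch of the input handling that the note describes, assuming the conversion goes through `torch.as_tensor`. `to_batched_tensor` is hypothetical; tying the extra batch dimension to `ensure_batch_dim` follows the last example above, where ensure_batch_dim=False yields a 2D output. Any dtype casting done by the library is not modeled.

```python
from typing import Union

import numpy as np
import torch


def to_batched_tensor(
    waveform: Union[np.ndarray, torch.Tensor],
    ensure_batch_dim: bool = True,
) -> torch.Tensor:
    """Hypothetical helper mirroring the note on 2D inputs."""
    # torch.as_tensor shares memory with a NumPy array and returns a Tensor
    # unchanged when no dtype/device conversion is needed.
    tensor = torch.as_tensor(waveform)
    if ensure_batch_dim and tensor.ndim == 2:
        # (n_channels, n_samples) -> (1, n_channels, n_samples)
        tensor = tensor.unsqueeze(0)
    return tensor


x = np.random.default_rng(0).uniform(size=(12, 5000))
print(to_batched_tensor(x).shape)                          # torch.Size([1, 12, 5000])
print(to_batched_tensor(x, ensure_batch_dim=False).shape)  # torch.Size([12, 5000])
```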