"""
Utilities for signal processing,
including spatial, temporal, spatio-temporal domains.
"""
import warnings
from copy import deepcopy
from numbers import Real
from typing import Iterable, Optional, Sequence, Tuple, Union
import numpy as np
from scipy import interpolate
from scipy.signal import butter, filtfilt, peak_prominences
from .utils_data import ensure_siglen
# Public names exported by ``from <module> import *``.
# NOTE: `butter_bandpass` is defined in this file but is not listed here.
__all__ = [
"smooth",
"resample_irregular_timeseries",
"detect_peaks",
"remove_spikes_naive",
"butter_bandpass_filter",
"get_ampl",
"normalize",
]
def smooth(
    x: np.ndarray,
    window_len: int = 11,
    window: str = "hanning",
    mode: str = "valid",
    keep_dtype: bool = True,
) -> np.ndarray:
    """Smooth 1d data by convolving it with a normalized window.

    This method is originally from [#smooth]_.
    The signal is extended at both ends with reflected copies of itself
    (``radius - 1`` samples each) so that transients at the beginning and
    at the end of the output are minimized. The convolution result is then
    cropped symmetrically, so the output has the same length as `x` and
    is aligned with it sample by sample.

    Parameters
    ----------
    x : numpy.ndarray
        The input signal, must be 1d.
    window_len : int, default 11
        Length of the smoothing window. The effective length is
        ``min(len(x), window_len)``, decreased by one if that is even.
        If the effective length is smaller than 3, `x` itself is returned.
    window : {"flat", "hanning", "hamming", "bartlett", "blackman"}, optional
        Type of the window, by default "hanning".
        See also :func:`numpy.hanning`, :func:`numpy.hamming`, etc.
        A "flat" window produces a moving average.
    mode : str, default "valid"
        Mode passed to :func:`numpy.convolve`. Since the result is cropped
        to the samples aligned with `x`, every mode yields the same output.
    keep_dtype : bool, default True
        Whether the returned array is cast back to the `dtype` of `x`.

    Returns
    -------
    y : numpy.ndarray
        The smoothed signal, of the same length as `x`.

    Raises
    ------
    ValueError
        If `x` is not 1d, or if `window` is not one of the supported names.

    Examples
    --------
    .. code-block:: python

        t = np.linspace(-2, 2, 50)
        x = np.sin(t) + np.random.randn(len(t)) * 0.1
        y = smooth(x)

    See also
    --------
    :func:`numpy.hanning`, :func:`numpy.hamming`,
    :func:`numpy.bartlett`, :func:`numpy.blackman`, :func:`numpy.convolve`,
    :func:`scipy.signal.lfilter`.

    TODO
    ----
    The window parameter could be the window itself
    if an array instead of a string.

    References
    ----------
    .. [#smooth] https://scipy-cookbook.readthedocs.io/items/SignalSmooth.html

    """
    if x.ndim != 1:
        raise ValueError("function `smooth` only accepts 1 dimension arrays.")

    # the window can not be longer than the signal,
    # and is forced to be odd so that it has a well-defined centre
    radius = min(len(x), window_len)
    radius = radius if radius % 2 == 1 else radius - 1
    if radius < 3:
        return x

    if window not in ["flat", "hanning", "hamming", "bartlett", "blackman"]:
        raise ValueError(""" `window` should be of "flat", "hanning", "hamming", "bartlett", "blackman" """)

    # pad with `radius - 1` reflected samples on each side
    s = np.r_[x[radius - 1 : 0 : -1], x, x[-2 : -radius - 1 : -1]]

    if window == "flat":  # moving average
        w = np.ones(radius, "d")
    else:
        # `window` has been validated against the whitelist above
        w = getattr(np, window)(radius)

    y = np.convolve(w / w.sum(), s, mode=mode)

    # the padding and the (odd-length) window are symmetric, hence the samples
    # aligned with `x` sit exactly in the middle of `y`, whatever `mode` is
    start = (len(y) - len(x)) // 2
    y = y[start : start + len(x)]

    if keep_dtype:
        y = y.astype(x.dtype)
    return y
def resample_irregular_timeseries(
    sig: np.ndarray,
    output_fs: Optional[Real] = None,
    method: str = "spline",
    return_with_time: bool = False,
    tnew: Optional[np.ndarray] = None,
    interp_kw: Optional[dict] = None,
    verbose: int = 0,
) -> np.ndarray:
    """Resample an irregular time series onto a new time grid.

    Resample the 2d irregular timeseries `sig` into a 1d or 2d
    regular time series with frequency `output_fs`
    (or onto the time points `tnew`).
    Elements of `sig` are in the form ``[time, value]``,
    where the unit of `time` is ms.

    Parameters
    ----------
    sig : numpy.ndarray
        The 2d irregular timeseries.
        Each row is ``[time, value]``.
    output_fs : numbers.Real, optional
        The frequency (in Hz) of the output regular timeseries.
        One and only one of `output_fs` and `tnew` should be specified.
    method : str, default "spline"
        Interpolation method, can be "spline" or "interp1d",
        case insensitive.
    return_with_time : bool, default False
        If True, return a 2d array whose first column is the time.
    tnew : array_like, optional
        1d array of time points of the output array.
        One and only one of `output_fs` and `tnew` should be specified.
    interp_kw : dict, optional
        Additional options passed to :func:`scipy.interpolate.splrep`
        or :class:`scipy.interpolate.interp1d`. It is not modified.
        For "spline", the entries ``w`` and ``s`` default to all-ones
        weights and ``m - sqrt(2 * m)``, with ``m = len(sig)``.
    verbose : int, default 0
        Verbosity level for printing.

    Returns
    -------
    numpy.ndarray
        A 1d or 2d time series sampled at `output_fs` (or at `tnew`),
        cast to the dtype of `sig`.
        An empty array is returned if `sig` has no rows.

    Examples
    --------
    .. code-block:: python

        fs = 100
        t_irr = np.sort(np.random.rand(fs)) * 1000
        vals = np.random.randn(fs)
        sig = np.stack([t_irr, vals], axis=1)
        sig_reg = resample_irregular_timeseries(sig, output_fs=fs * 2, return_with_time=True)
        sig_reg = resample_irregular_timeseries(sig, output_fs=fs, method="interp1d")
        t_irr_2 = np.sort(np.random.rand(2 * fs)) * 1000
        sig_reg = resample_irregular_timeseries(sig, tnew=t_irr_2, return_with_time=True)

    NOTE
    ----
    ``pandas`` also has the function to regularly resample irregular timeseries.

    """
    assert sig.ndim == 2, "`sig` should be a 2D array"
    _method = method.lower()
    assert _method in [
        "spline",
        "interp1d",
    ], f"method `{method}` not supported"
    assert sum([output_fs is None, tnew is None]) == 1, "one and only one of `output_fs` and `tnew` should be specified"

    # work on a private copy so that the caller's dict is never touched
    _interp_kw = {} if interp_kw is None else deepcopy(interp_kw)

    if verbose >= 1:
        print(f"len(sig) = {len(sig)}")
    if len(sig) == 0:
        return np.array([])

    dtype = sig.dtype
    time_series = sig
    t_start, t_end = time_series[0][0], time_series[-1][0]

    if tnew is None:
        step_ts = 1000 / output_fs  # sampling period in ms
        tot_len = int((t_end - t_start) / step_ts) + 1
        # build the grid from integer indices: `np.arange` with a float step
        # may yield one extra point, which could fall outside the data range
        xnew = t_start + np.arange(tot_len) * step_ts
    else:
        xnew = np.array(tnew)
        assert xnew.ndim == 1, "`tnew` should be a 1D array"
        tot_len = len(xnew)

    if verbose >= 1:
        print(f"time_series start ts = {t_start}, end ts = {t_end}")
        print(f"tot_len = {tot_len}")
        print(f"xnew start = {xnew[0]}, end = {xnew[-1]}")

    if _method == "spline":
        m = len(time_series)
        # defaults are only used when the caller did not provide them
        _interp_kw.setdefault("w", np.ones(shape=(m,)))
        _interp_kw.setdefault("s", m - np.sqrt(2 * m))
        tck = interpolate.splrep(time_series[:, 0], time_series[:, 1], **_interp_kw)
        regular_timeseries = interpolate.splev(xnew, tck)
    else:  # "interp1d"
        f = interpolate.interp1d(time_series[:, 0], time_series[:, 1], **_interp_kw)
        regular_timeseries = f(xnew)

    if return_with_time:
        return np.column_stack((xnew, regular_timeseries)).astype(dtype)
    return regular_timeseries.astype(dtype)
def detect_peaks(
    x: Sequence,
    mph: Optional[Real] = None,
    mpd: int = 1,
    threshold: Real = 0,
    left_threshold: Real = 0,
    right_threshold: Real = 0,
    prominence: Optional[Real] = None,
    prominence_wlen: Optional[int] = None,
    edge: Union[str, None] = "rising",
    kpsh: bool = False,
    valley: bool = False,
    show: bool = False,
    ax=None,
    verbose: int = 0,
) -> np.ndarray:
    """Detect peaks in data based on their amplitude and other features.

    Parameters
    ----------
    x : array_like
        1D array of data. It is not modified.
    mph : number, optional
        Abbr. for minimum (maximum) peak height.
        Detect peaks that are not lower than `mph` (if `valley` is False),
        or valleys that are not higher than `mph` (if `valley` is True).
    mpd : positive integer, default 1
        Abbr. for minimum peak distance.
        Detect peaks that are at least separated by `mpd` samples.
        Candidates closer than `mpd` samples to either border are discarded.
    threshold : positive number, default 0
        Detect peaks (valleys) that are greater (smaller) than `threshold`,
        in relation to their neighbors within the range of `mpd`.
    left_threshold : positive number, default 0
        `threshold` that is restricted to the left.
        Falls back to `threshold` if not positive.
    right_threshold : positive number, default 0
        `threshold` that is restricted to the right.
        Falls back to `threshold` if not positive.
    prominence : positive number, optional
        Threshold of prominence of the detected peaks (valleys).
    prominence_wlen : positive int, optional
        The `wlen` parameter of :func:`scipy.signal.peak_prominences`.
    edge : str or None, default "rising"
        Can also be "falling", "both".
        For a flat peak, keep only the rising edge ("rising"),
        only the falling edge ("falling"), both edges ("both"),
        or don't detect a flat peak (None).
    kpsh : bool, default False
        Keep peaks with same height even if they are closer than `mpd`.
    valley : bool, default False
        If True, detect valleys (local minima) instead of peaks.
    show : bool, default False
        Not used by this implementation;
        kept in the signature for backward compatibility.
    ax : matplotlib.axes.Axes, optional
        Not used by this implementation;
        kept in the signature for backward compatibility.
    verbose : int, default 0
        Verbosity level for printing intermediate results.

    Returns
    -------
    ind : numpy.ndarray
        Indices (integers, in increasing order) of the peaks in `x`.

    NOTE
    ----
    The detection of valleys instead of peaks is performed internally by simply
    negating the data: ``ind_valleys = detect_peaks(-x)``.

    The function can handle NaN's:
    NaN's and their immediate neighbours are never reported as peaks.

    See this IPython Notebook [#peak]_.

    References
    ----------
    .. [#peak] https://nbviewer.org/github/demotu/BMC/blob/master/notebooks/DetectPeaks.ipynb

    Examples
    --------
    .. code-block:: python

        x = np.random.randn(100)
        x[60:81] = np.nan
        # detect all peaks
        ind = detect_peaks(x)
        print(ind)

        x = np.sin(2*np.pi*5*np.linspace(0, 1, 200)) + np.random.randn(200)/5
        # set minimum peak height = 0 and minimum peak distance = 20
        detect_peaks(x, mph=0, mpd=20)

        x = [0, 1, 0, 2, 0, 3, 0, 2, 0, 1, 0]
        # set minimum peak distance = 2
        detect_peaks(x, mpd=2)

        x = np.sin(2*np.pi*5*np.linspace(0, 1, 200)) + np.random.randn(200)/5
        # detection of valleys instead of peaks
        detect_peaks(x, mph=-1.2, mpd=20, valley=True)

        x = [0, 1, 1, 0, 1, 1, 0]
        # detect both edges
        detect_peaks(x, edge="both")

        x = [-2, 1, -2, 2, 1, 1, 3, 0]
        # set threshold = 2
        detect_peaks(x, threshold=2)

    Version history
    ---------------
    "1.0.5":
        The sign of `mph` is inverted if parameter `valley` is True

    """
    # `astype` always copies, so the caller's data is never modified
    data = np.atleast_1d(x).astype("float64")
    if data.size < 3:
        return np.array([], dtype=int)
    if valley:
        data = -data
        if mph is not None:
            mph = -mph

    # find indices of all peaks
    dx = data[1:] - data[:-1]  # equiv to np.diff()
    # handle NaN's
    indnan = np.where(np.isnan(data))[0]
    if indnan.size:
        data[indnan] = np.inf
        dx[np.where(np.isnan(dx))[0]] = np.inf
    ine, ire, ife = np.array([[], [], []], dtype=int)
    if not edge:
        ine = np.where((np.hstack((dx, 0)) < 0) & (np.hstack((0, dx)) > 0))[0]
    else:
        if edge.lower() in ["rising", "both"]:
            ire = np.where((np.hstack((dx, 0)) <= 0) & (np.hstack((0, dx)) > 0))[0]
        if edge.lower() in ["falling", "both"]:
            ife = np.where((np.hstack((dx, 0)) < 0) & (np.hstack((0, dx)) >= 0))[0]
    ind = np.unique(np.hstack((ine, ire, ife)))

    if verbose >= 1:
        print(f"before filtering by mpd = {mpd}, and threshold = {threshold}, ind = {ind.tolist()}")
        print(
            f"additionally, left_threshold = {left_threshold}, "
            f"right_threshold = {right_threshold}, length of data = {len(data)}"
        )

    # handle NaN's
    if ind.size and indnan.size:
        # NaN's and values close to NaN's cannot be peaks
        ind = ind[np.isin(ind, np.unique(np.hstack((indnan, indnan - 1, indnan + 1))), invert=True)]

    if verbose >= 1:
        print(f"after handling nan values, ind = {ind.tolist()}")

    # peaks are only valid within [mpd, len(data) - mpd);
    # boolean masking keeps the integer dtype even when nothing is left
    ind = ind[(ind >= mpd) & (ind < len(data) - mpd)]

    if verbose >= 1:
        print(f"after fitering out elements too close to border by mpd = {mpd}, ind = {ind.tolist()}")

    # remove peaks < minimum peak height
    if ind.size and mph is not None:
        ind = ind[data[ind] >= mph]

    if verbose >= 1:
        print(f"after filtering by mph = {mph}, ind = {ind.tolist()}")

    # remove peaks - neighbors < threshold
    _left_threshold = left_threshold if left_threshold > 0 else threshold
    _right_threshold = right_threshold if right_threshold > 0 else threshold
    # each side is filtered on its own, so that specifying only
    # `left_threshold` (or only `right_threshold`) is not silently ignored
    if ind.size and _left_threshold > 0:
        dx = np.max(np.vstack([data[ind] - data[ind + idx] for idx in range(-mpd, 0)]), axis=0)
        ind = np.delete(ind, np.where(dx < _left_threshold)[0])
        if verbose >= 2:
            print(f"from left, dx = {dx.tolist()}")
            print(f"after deleting those dx < _left_threshold = {_left_threshold}, ind = {ind.tolist()}")
    if ind.size and _right_threshold > 0:
        dx = np.max(
            np.vstack([data[ind] - data[ind + idx] for idx in range(1, mpd + 1)]),
            axis=0,
        )
        ind = np.delete(ind, np.where(dx < _right_threshold)[0])
        if verbose >= 2:
            print(f"from right, dx = {dx.tolist()}")
            print(f"after deleting those dx < _right_threshold = {_right_threshold}, ind = {ind.tolist()}")

    if verbose >= 1:
        print(f"after filtering by threshold, ind = {ind.tolist()}")

    # detect small peaks closer than minimum peak distance
    if ind.size and mpd > 1:
        ind = ind[np.argsort(data[ind])][::-1]  # sort ind by peak height
        idel = np.zeros(ind.size, dtype=bool)
        for i in range(ind.size):
            if not idel[i]:
                # keep peaks with the same height if kpsh is True
                idel = idel | (ind >= ind[i] - mpd) & (ind <= ind[i] + mpd) & (data[ind[i]] > data[ind] if kpsh else True)
                idel[i] = 0  # Keep current peak
        # remove the small peaks and sort back the indices by their occurrence
        ind = np.sort(ind[~idel])

    # a peak has to be the maximum of the window of radius `mpd` around it
    is_window_max = np.array(
        [data[item] == np.max(data[item - mpd : item + mpd + 1]) for item in ind],
        dtype=bool,
    )
    ind = ind[is_window_max]

    if verbose >= 1:
        print(f"after filtering by mpd, ind = {ind.tolist()}")

    if prominence:
        _p = peak_prominences(data, ind, prominence_wlen)[0]
        ind = ind[np.where(_p >= prominence)[0]]
        if verbose >= 1:
            print(f"after filtering by prominence, ind = {ind.tolist()}")
        if verbose >= 2:
            print(f"with detailed prominence = {_p.tolist()}")

    return ind
def remove_spikes_naive(sig: np.ndarray, threshold: Real = 20, inplace: bool = True) -> np.ndarray:
    """Remove signal spikes using a naive method.

    This is a method proposed in entry 0416 of CPSC2019.
    `spikes` here refers to abrupt large bumps with (abs) value
    larger than the given threshold,
    or nan values (read by `wfdb`).
    Do **NOT** confuse with `spikes` in paced rhythm.

    Every spike is replaced by the closest preceding non-spike sample;
    a spike at the very first position is replaced by 0.

    Parameters
    ----------
    sig : numpy.ndarray
        1D signal with potential spikes.
    threshold : numbers.Real, default 20
        Values of `sig` whose absolute values are larger than
        `threshold` will be removed.
    inplace : bool, default True
        Whether to modify `sig` in place or not.

    Returns
    -------
    numpy.ndarray
        Signal with `spikes` removed, of the same dtype as `sig`.
        An empty signal is returned unchanged.

    Examples
    --------
    .. code-block:: python

        sig = np.random.randn(1000)
        pos = np.random.randint(0, 1000, 10)
        sig[pos] = 100
        sig = remove_spikes_naive(sig)
        pos = np.random.randint(0, 1000, 1)
        sig[pos] = np.nan
        sig = remove_spikes_naive(sig)

    """
    dtype = sig.dtype
    if not inplace:
        sig = sig.copy()
    if sig.size == 0:
        # nothing to clean, and `sig[0]` below would be out of range
        return sig.astype(dtype)

    is_spike = np.logical_or(np.abs(sig) > threshold, np.isnan(sig))

    # there is nothing to copy from before the first sample
    if is_spike[0]:
        sig[0] = 0
        is_spike[0] = False

    # forward fill: for each position, index of the latest non-spike sample
    # at or before it (index 0 is always clean at this point)
    last_clean = np.maximum.accumulate(np.where(is_spike, 0, np.arange(sig.size)))
    sig[is_spike] = sig[last_clean[is_spike]]

    return sig.astype(dtype)
def butter_bandpass(lowcut: Real, highcut: Real, fs: Real, order: int, verbose: int = 0) -> Tuple[np.ndarray, np.ndarray]:
    """Design a Butterworth filter from a pair of cutoff frequencies.

    Parameters
    ----------
    lowcut : numbers.Real
        Low cutoff frequency.
    highcut : numbers.Real
        High cutoff frequency.
    fs : numbers.Real
        Sampling frequency of the data to be filtered.
    order : int
        Order of the filter.
    verbose : int, default 0
        Verbosity level for debugging.

    Returns
    -------
    b, a : numpy.ndarray
        Coefficients of numerator and denominator of the filter.

    Raises
    ------
    ValueError
        If `lowcut` is not below the Nyquist frequency, or if
        `lowcut` is not positive while `highcut` is not below
        the Nyquist frequency.

    NOTE
    ----
    Both cutoffs are normalized by the Nyquist frequency ``fs / 2``.
    Depending on them, the filter degenerates:

    - lowpass at `highcut`, if the normalized `lowcut` is not positive,
      or if `lowcut` equals `highcut`;
    - highpass at `lowcut`, if the normalized `highcut` is not below 1.

    References
    ----------
    1. :func:`scipy.signal.butter`
    2. https://scipy-cookbook.readthedocs.io/items/ButterworthBandpass.html

    """
    nyquist = 0.5 * fs

    low = lowcut / nyquist
    if low >= 1:
        raise ValueError("frequency out of range!")

    high = highcut / nyquist
    if low <= 0 and high >= 1:
        raise ValueError("frequency out of range!")

    if low <= 0 or (high < 1 and lowcut == highcut):
        Wn, btype = high, "low"
    elif high >= 1:
        Wn, btype = low, "high"
    else:
        Wn, btype = [low, high], "band"

    if verbose >= 1:
        print(f"by the setup of lowcut and highcut, the filter type falls to {btype}, with Wn = {Wn}")

    return butter(order, Wn, btype=btype)
def butter_bandpass_filter(
    data: np.ndarray,
    lowcut: Real,
    highcut: Real,
    fs: Real,
    order: int,
    btype: Optional[str] = None,
    verbose: int = 0,
) -> np.ndarray:
    """Apply a zero-phase Butterworth bandpass filter to the signal.

    The filter coefficients are designed by `butter_bandpass` and
    applied with :func:`scipy.signal.filtfilt`.
    For references, see [#bp1]_ and [#bp2]_.

    Parameters
    ----------
    data : numpy.ndarray
        Signal to be filtered.
    lowcut : numbers.Real
        Low cutoff frequency.
    highcut : numbers.Real
        High cutoff frequency.
    fs : numbers.Real
        Sampling frequency of the signal.
    order : int
        Order of the filter.
    btype : {"lohi", "hilo"}, optional
        (Special) type of the filter, case insensitive.
        If not given, one single filter is designed
        from `lowcut` and `highcut`.
        "lohi" applies a lowpass filter at `highcut`
        followed by a highpass filter at `lowcut`;
        "hilo" applies the same two filters in the reverse order.
    verbose : int, default 0
        Verbosity level for printing.

    Returns
    -------
    y : numpy.ndarray
        The filtered signal, cast to the dtype of `data`.

    Raises
    ------
    ValueError
        If `btype` is given but is neither "lohi" nor "hilo".

    References
    ----------
    .. [#bp1] https://scipy-cookbook.readthedocs.io/items/ButterworthBandpass.html
    .. [#bp2] https://dsp.stackexchange.com/questions/19084/applying-filter-in-scipy-signal-use-lfilter-or-filtfilt

    """
    # each stage is a (lowcut, highcut) pair handed over to `butter_bandpass`;
    # a low cutoff of 0 yields a lowpass, a high cutoff of `fs` a highpass
    if btype is None:
        stages = [(lowcut, highcut)]
    else:
        lowpass_stage = (0, highcut)
        highpass_stage = (lowcut, fs)
        special = btype.lower()
        if special == "lohi":
            stages = [lowpass_stage, highpass_stage]
        elif special == "hilo":
            stages = [highpass_stage, lowpass_stage]
        else:
            raise ValueError(f"special btype `{btype}` is not supported")

    filtered = data
    for stage_low, stage_high in stages:
        b, a = butter_bandpass(stage_low, stage_high, fs, order=order, verbose=verbose)
        filtered = filtfilt(b, a, filtered)

    return filtered.astype(data.dtype)
def get_ampl(
    sig: np.ndarray,
    fs: Real,
    fmt: str = "lead_first",
    window: Real = 0.2,
    critical_points: Optional[Sequence] = None,
) -> Union[float, np.ndarray]:
    """Get amplitude of a signal (near critical points if given).

    The amplitude is the largest peak-to-peak value (max minus min)
    found in windows of (about) `window` seconds. The windows are either
    centered at the critical points, or, if no critical points are given,
    slid over the whole signal with 50% overlap.

    Parameters
    ----------
    sig : numpy.ndarray
        (ECG) signal.
    fs : numbers.Real
        Sampling frequency of the signal.
    fmt : str, default "lead_first"
        Format of the signal, can be
        "channel_last" (alias "lead_last"), or
        "channel_first" (alias "lead_first").
        Has no effect if `sig` is 1d array (single-lead).
    window : numbers.Real, default 0.2
        Length of the window for computing amplitude, with units in seconds.
        It is rounded to an even number of samples.
    critical_points : array_like, optional
        Positions (sample indices) of critical points
        near which to compute amplitude,
        e.g. can be rpeaks, t peaks, etc.

    Returns
    -------
    ampl : float or numpy.ndarray
        Amplitude of the signal: a scalar for a 1d signal,
        or one value per lead for a 2d signal.

    Raises
    ------
    ValueError
        If `fmt` is not one of the supported formats.

    NOTE
    ----
    When `critical_points` is not given, a signal that is shorter than one
    window is treated as one single window, and a last window anchored at
    the end of the signal is added whenever the regular windows do not
    reach the end, so that no sample is left out.

    """
    dtype = sig.dtype
    _fmt = fmt.lower()
    if _fmt in ["channel_last", "lead_last"]:
        _sig = sig.T
    elif _fmt in ["channel_first", "lead_first"]:
        _sig = sig
    else:
        raise ValueError(f"unknown format `{fmt}`")

    # force the window to span an even number of samples
    half_window = int(round(window * fs)) // 2
    _window = half_window * 2
    siglen = _sig.shape[-1]

    if critical_points is not None:
        # windows clipped at the borders are brought to a common length
        # by `ensure_siglen` (presumably by padding; see its documentation)
        segments = [
            ensure_siglen(
                _sig[
                    ...,
                    max(0, p - half_window) : min(siglen, p + half_window),
                ],
                siglen=_window,
                fmt="lead_first",
            )
            for p in critical_points
        ]
    else:
        # windows of length `_window`, hopping by `half_window`
        segments = [_sig[..., idx * half_window : idx * half_window + _window] for idx in range(siglen // half_window - 1)]
        if siglen < _window:
            # not even one full window: use the whole signal
            segments = [_sig]
        elif siglen % half_window != 0:
            # the regular windows stop short of the end of the signal
            segments.append(_sig[..., siglen - _window :])

    # shape (..., window, n_windows)
    s = np.stack(segments, axis=-1).astype(dtype)

    # peak-to-peak value of each window, then the largest over all windows
    ampl = np.max(np.max(s, axis=-2) - np.min(s, axis=-2), axis=-1)
    return ampl
def _expand_channel_stat(
    stat: Union[Real, Iterable[Real]],
    name: str,
    sig: np.ndarray,
    sig_fmt: str,
) -> Union[Real, np.ndarray]:
    """Reshape `mean` or `std` (as indicated by `name`) so that it broadcasts against `sig`.

    A non-iterable `stat` is returned unchanged. An iterable `stat`
    is converted to an array of the dtype of `sig`, with singleton axes
    inserted according to its shape and to `sig_fmt`.
    An ``AssertionError`` is raised if `sig` is 1d or if the shapes are
    not compatible.
    """
    if not isinstance(stat, Iterable):
        return stat
    assert sig.ndim in [2, 3], f"`{name}` should be a real number for 1d signal"

    channel_first = sig_fmt.lower() in ["channel_first", "lead_first"]
    stat_shape = np.shape(stat)
    mismatch = f"shape of `{name}` = {stat_shape} not compatible with the `sig` = {np.shape(sig)}"

    def as_array() -> np.ndarray:
        # converted lazily, only after a shape rule has matched
        return np.array(stat, dtype=sig.dtype)

    if sig.ndim == 2:
        assert stat_shape in [(sig.shape[0],), (sig.shape[-1],)], mismatch
        if channel_first:
            return as_array()[..., np.newaxis]
        return as_array()[np.newaxis, ...]

    # sig.ndim == 3, the first dimension is the batch dimension
    if stat_shape == (sig.shape[0],):
        # one value per sample of the batch
        return as_array()[..., np.newaxis, np.newaxis]
    if channel_first:
        if stat_shape == (sig.shape[1],):
            # one value per channel, shared by the whole batch
            return np.repeat(as_array()[np.newaxis, ..., np.newaxis], sig.shape[0], axis=0)
        if stat_shape == sig.shape[:2]:
            # one value per (sample, channel)
            return as_array()[..., np.newaxis]
    else:  # "channel_last" or "lead_last"
        if stat_shape == (sig.shape[-1],):
            return np.repeat(as_array()[np.newaxis, np.newaxis, ...], sig.shape[0], axis=0)
        if stat_shape == (sig.shape[0], sig.shape[-1]):
            return np.expand_dims(as_array(), axis=1)
    raise AssertionError(mismatch)


def normalize(
    sig: np.ndarray,
    method: str,
    mean: Union[Real, Iterable[Real]] = 0.0,
    std: Union[Real, Iterable[Real]] = 1.0,
    sig_fmt: str = "channel_first",
    per_channel: bool = False,
) -> np.ndarray:
    """Normalize a signal.

    Perform z-score normalization on `sig`,
    to make it has fixed mean and standard deviation,
    or perform min-max normalization on `sig`,
    or normalize `sig` using `mean` and `std` via (sig - mean) / std.
    More precisely,

    .. math::

        \\begin{align*}
        \\text{Min-Max normalization:}\\quad & \\frac{sig - \\min(sig)}{\\max(sig) - \\min(sig)} \\\\
        \\text{Naive normalization:}\\quad & \\frac{sig - m}{s} \\\\
        \\text{Z-score normalization:}\\quad & \\left(\\frac{sig - mean(sig)}{std(sig)}\\right) \\cdot s + m
        \\end{align*}

    Parameters
    ----------
    sig : numpy.ndarray
        The signal to be normalized, a 1d, 2d or 3d array.
        For a 3d array, the first dimension is the batch dimension.
    method : {"naive", "min-max", "z-score"}
        Normalization method, case insensitive.
    mean : numbers.Real or array_like, default 0.0
        Mean value of the normalized signal,
        or mean values for each lead of the normalized signal.
        Useless if `method` is "min-max".
    std : numbers.Real or array_like, default 1.0
        Standard deviation of the normalized signal,
        or standard deviations for each lead of the normalized signal.
        Must be positive. Useless if `method` is "min-max".
    sig_fmt : str, default "channel_first"
        Format of the signal, can be of one of
        "channel_last" (alias "lead_last"), or
        "channel_first" (alias "lead_first"),
        ignored if sig is 1d array (single-lead).
    per_channel : bool, default False
        If True, normalization will be done per channel.
        Ignored (with a warning) if `sig` is 1d array (single-lead).

    Returns
    -------
    nm_sig : numpy.ndarray
        The normalized signal, cast to the dtype of `sig`.

    Raises
    ------
    AssertionError
        If the arguments are invalid or not compatible with each other.

    NOTE
    ----
    A small constant (1e-7) is added to the denominators of the z-score
    and min-max normalizations, hence in cases where normalization is
    infeasible (``std = 0``), only the mean value will be shifted.

    """
    assert sig.ndim in [1, 2, 3], "signal `sig` should be 1d or 2d or 3d array"
    if sig.ndim == 1 and per_channel:
        warnings.warn(
            "per-channel normalization is not supported for 1d signal, " "`per_channel` will be set to False",
            RuntimeWarning,
        )
        per_channel = False
    dtype = sig.dtype
    _method = method.lower()
    assert _method in [
        "z-score",
        "naive",
        "min-max",
    ], f"unknown normalization method `{method}`"

    if not per_channel:
        if sig.ndim == 2:
            assert isinstance(mean, Real) and isinstance(
                std, Real
            ), "`mean` and `std` should be real numbers in the non per-channel setting for 2d signal"
        else:  # 1d signals also take this branch
            assert (isinstance(mean, Real) or np.shape(mean) == (sig.shape[0],)) and (
                isinstance(std, Real) or np.shape(std) == (sig.shape[0],)
            ), (
                f"`mean` and `std` should be real numbers or have shape ({sig.shape[0]},) "
                "in the non per-channel setting for 3d signal"
            )
    if isinstance(std, Real):
        assert std > 0, "standard deviation should be positive"
    else:
        assert (np.array(std) > 0).all(), "standard deviations should all be positive"
    assert sig_fmt.lower() in [
        "channel_first",
        "lead_first",
        "channel_last",
        "lead_last",
    ], f"format `{sig_fmt}` of the signal not supported!"

    # make `mean` and `std` broadcastable against `sig`
    _mean = _expand_channel_stat(mean, "mean", sig, sig_fmt)
    _std = _expand_channel_stat(std, "std", sig, sig_fmt)

    if _method == "naive":
        nm_sig = (sig - _mean) / _std
        return nm_sig.astype(dtype)

    eps = 1e-7  # to avoid dividing by zero
    channel_first = sig_fmt.lower() in ["channel_first", "lead_first"]

    # axes over which the statistics of `sig` are computed
    if sig.ndim == 3:  # the first dimension is the batch dimension
        if not per_channel:
            options = dict(axis=(1, 2), keepdims=True)
        elif channel_first:
            options = dict(axis=2, keepdims=True)
        else:
            options = dict(axis=1, keepdims=True)
    else:
        if not per_channel:
            options = dict(axis=None)
        elif channel_first:
            options = dict(axis=1, keepdims=True)
        else:
            options = dict(axis=0, keepdims=True)

    if _method == "z-score":
        nm_sig = ((sig - np.mean(sig, dtype=dtype, **options)) / (np.std(sig, dtype=dtype, **options) + eps)) * _std + _mean
    elif _method == "min-max":
        nm_sig = (sig - np.amin(sig, **options)) / (np.amax(sig, **options) - np.amin(sig, **options) + eps)
    return nm_sig.astype(dtype)