torch_ecg.utils

This module contains a collection of utility functions and classes that are used throughout the package.

Neural network auxiliary functions and classes

extend_predictions(preds, classes, ...)

Extend the prediction arrays to cover a larger range of classes.

compute_output_shape(layer_type, input_shape)

Compute the output shape of a (transpose) convolution/maxpool/avgpool layer.

compute_conv_output_shape(input_shape[, ...])

Compute the output shape of a convolution layer.
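
To illustrate the arithmetic behind this kind of computation, the sketch below evaluates the output length of a 1D convolution along the temporal axis, using the standard formula given in the PyTorch Conv1d documentation. The helper name conv1d_output_length and its parameters are hypothetical and not part of torch_ecg; the full signature of compute_conv_output_shape is elided in the listing above.

```python
def conv1d_output_length(length, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of a 1D convolution (hypothetical helper, not torch_ecg API).

    Uses floor((L + 2*p - d*(k - 1) - 1) / s) + 1, as in the PyTorch Conv1d docs.
    """
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# A 5000-sample signal, kernel size 15, stride 2, "same"-style padding of 7.
out_len = conv1d_output_length(5000, kernel_size=15, stride=2, padding=7)
print(out_len)
```

Batch and channel dimensions are left out here, since only the temporal dimension depends on the kernel size, stride, padding and dilation.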

compute_deconv_output_shape(input_shape[, ...])

Compute the output shape of a transpose convolution layer.

compute_maxpool_output_shape(input_shape[, ...])

Compute the output shape of a maxpool layer.

compute_avgpool_output_shape(input_shape[, ...])

Compute the output shape of an avgpool layer.

compute_sequential_output_shape(model[, ...])

Compute the output shape of a sequential model.

compute_module_size(module[, requires_grad, ...])

Compute the size (number of parameters) of a Module.

default_collate_fn(batch)

Default collate function for model training.

compute_receptive_field([kernel_sizes, ...])

Compute the receptive field of several types of Module.
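
As a minimal sketch of what a receptive field computation involves, the function below applies the usual recurrence for a stack of convolution or pooling layers: each layer widens the receptive field by (kernel_size - 1) * dilation times the product of the strides of all preceding layers. Only kernel_sizes appears in the signature above; the strides and dilations parameters and the function name are assumptions made for this illustration.

```python
def receptive_field_sketch(kernel_sizes, strides=None, dilations=None):
    """Receptive field of stacked conv/pool layers (illustrative, not torch_ecg API)."""
    n_layers = len(kernel_sizes)
    strides = strides if strides is not None else [1] * n_layers
    dilations = dilations if dilations is not None else [1] * n_layers

    rf = 1    # receptive field, in input samples
    jump = 1  # distance in input samples between adjacent outputs of the current layer
    for k, s, d in zip(kernel_sizes, strides, dilations):
        rf += (k - 1) * d * jump
        jump *= s
    return rf


print(receptive_field_sketch([3, 3, 3]))          # three stacked kernel-3 layers
print(receptive_field_sketch([3, 3], [2, 1]))     # the first layer downsamples by 2
```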

adjust_cnn_filter_lengths(config, fs[, ...])

Adjust the filter lengths in the config for convolutional neural networks, according to the new sampling frequency.

SizeMixin()

Mixin class for size-related methods.

CkptMixin()

Mixin class providing class methods for loading from checkpoints.

Signal processing functions

smooth(x[, window_len, window, mode, keep_dtype])

Smooth the 1d data using a window with requested size.

resample_irregular_timeseries(sig[, ...])

Resample the 2d irregular timeseries sig into a 1d or 2d regular time series with frequency output_fs. Elements of sig are of the form [time, value], where time is in ms.

detect_peaks(x[, mph, mpd, threshold, ...])

Detect peaks in data based on their amplitude and other features.

remove_spikes_naive(sig[, threshold, inplace])

Remove signal spikes using a naive method.

butter_bandpass_filter(data, lowcut, ...[, ...])

Apply a Butterworth bandpass filter to the signals.

get_ampl(sig, fs[, fmt, window, critical_points])

Get amplitude of a signal (near critical points if given).

normalize(sig, method[, mean, std, sig_fmt, ...])

Normalize a signal.

normalize_t(sig[, method, mean, std, ...])

Perform z-score normalization on sig so that it has a fixed mean and standard deviation, perform min-max normalization on sig, or normalize sig using the given mean and std via \((sig - mean) / std\).
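
The z-score variant can be sketched as follows for a single-channel signal stored as a plain Python list. This is an illustration under stated assumptions, not the torch_ecg implementation: normalize_t operates on tensors, and the function name, the eps guard against division by zero, and the use of the population standard deviation are all choices made for this example.

```python
from statistics import fmean, pstdev


def z_score_normalize(sig, mean=0.0, std=1.0, eps=1e-7):
    """Rescale sig so that it has (approximately) the given mean and std.

    Hypothetical helper; eps avoids division by zero for constant signals.
    """
    mu = fmean(sig)
    sigma = pstdev(sig)
    return [(x - mu) / (sigma + eps) * std + mean for x in sig]


normalized = z_score_normalize([1.0, 2.0, 3.0, 4.0])
shifted = z_score_normalize([1.0, 2.0, 3.0, 4.0], mean=2.0, std=3.0)
```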

resample_t(sig[, fs, dst_fs, siglen, inplace])

Resample signal tensors to a new sampling frequency or a new signal length.

Data operations

get_mask(shape, critical_points, left_bias, ...)

Get the mask around the given critical points.

class_weight_to_sample_weight(y[, class_weight])

Transform class weight to sample weight.

ensure_lead_fmt(values[, n_leads, fmt])

Ensure that the multi-lead (ECG) signal is in the specified format.

ensure_siglen(values, siglen[, fmt, tolerance])

Ensure that the (ECG) signal has the specified length.

masks_to_waveforms(masks, class_map, fs[, ...])

Convert masks into lists of ECGWaveForm for each lead.

mask_to_intervals(mask[, vals, right_inclusive])

Convert a mask into a list of intervals, or a dict of lists of intervals.
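
For a binary mask, the conversion can be sketched as a single pass that records where each run of nonzero values starts and ends. The function name is hypothetical, and the meaning given to right_inclusive here (whether the end index of each interval is the last nonzero index or one past it) is an assumption based on the parameter name in the signature above.

```python
def binary_mask_to_intervals(mask, right_inclusive=False):
    """Convert a 1D binary mask into a list of [start, end] index intervals.

    Illustrative sketch only; not the torch_ecg implementation.
    """
    intervals = []
    start = None
    for idx, val in enumerate(mask):
        if val and start is None:
            start = idx  # a run of nonzero values begins
        elif not val and start is not None:
            end = idx - 1 if right_inclusive else idx
            intervals.append([start, end])
            start = None
    if start is not None:  # the mask ends inside a run
        end = len(mask) - 1 if right_inclusive else len(mask)
        intervals.append([start, end])
    return intervals


mask = [0, 1, 1, 0, 0, 1, 1, 1]
print(binary_mask_to_intervals(mask))
print(binary_mask_to_intervals(mask, right_inclusive=True))
```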

uniform(low, high, num)

Generate a list of uniformly distributed numbers.

stratified_train_test_split(df, stratified_cols)

Perform stratified train-test split on the dataframe.

cls_to_bin(cls_array[, num_classes])

Convert a categorical array to a one-hot array.
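
A minimal sketch of one-hot encoding for integer class labels is shown below, using plain lists. The function name is hypothetical, and inferring num_classes from the largest label when it is not given is an assumption made for this example.

```python
def to_one_hot(cls_array, num_classes=None):
    """One-hot encode a sequence of integer class labels (illustrative helper)."""
    if num_classes is None:
        # Assumption: classes are labeled 0, 1, ..., max(cls_array).
        num_classes = max(cls_array) + 1
    return [[1 if c == k else 0 for k in range(num_classes)] for c in cls_array]


print(to_one_hot([0, 2, 1]))
```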

generate_weight_mask(target_mask, fg_weight, ...)

Generate a weight mask for a binary target mask, accounting for the foreground weight and boundary weight.

Interval operations

overlaps(interval, another)

Find the overlap between two intervals.

validate_interval(interval[, join_book_endeds])

Check whether interval is an Interval or a GeneralizedInterval.

in_interval(val, interval[, left_closed, ...])

Check whether val is inside interval or not.

in_generalized_interval(val, ...[, ...])

Check whether val is inside generalized_interval or not.

intervals_union(interval_list[, ...])

Find the union of intervals.
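
The union of a list of intervals can be sketched with the usual sort-and-merge approach. The function name is hypothetical, and this sketch joins intervals that merely touch at an endpoint; whether intervals_union does so may depend on its elided optional arguments.

```python
def merge_intervals(interval_list):
    """Union of [start, end] intervals as a sorted list of disjoint intervals.

    Illustrative sketch; touching intervals such as [8, 10] and [10, 12] are joined.
    """
    merged = []
    for start, end in sorted(interval_list):
        if merged and start <= merged[-1][1]:
            # Overlaps (or touches) the previous interval: extend it.
            merged[-1][1] = max(merged[-1][1], end)
        else:
            merged.append([start, end])
    return merged


print(merge_intervals([[8, 10], [1, 3], [2, 6], [10, 12]]))
```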

generalized_intervals_union(interval_list[, ...])

Calculate the union of a list (or tuple) of GeneralizedInterval.

intervals_intersection(interval_list[, ...])

Calculate the intersection of all intervals in interval_list.

generalized_intervals_intersection(...[, ...])

Calculate the intersection of intervals.

generalized_interval_complement(...)

Calculate the complement of an interval in another interval.

get_optimal_covering(total_interval, ...[, ...])

Compute an optimal covering of to_cover by intervals.

interval_len(interval)

Compute the length of an interval.

generalized_interval_len(generalized_interval)

Compute the length of a generalized interval.

find_extrema(signal[, mode])

Locate local extrema points in a 1D signal.

is_intersect(interval, another_interval)

Determine whether two (generalized) intervals intersect.

max_disjoint_covering(intervals[, ...])

Find the largest covering (in terms of interval length) of a sequence of intervals.

Metrics computations

top_n_accuracy(labels, outputs[, n])

Compute top n accuracy.
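
Top n accuracy counts a sample as correct when its true label is among the n highest-scoring classes. The sketch below assumes labels is a sequence of integer class indices and outputs is a sequence of per-class score lists; these input formats and the function name are assumptions for illustration and may differ from what top_n_accuracy accepts.

```python
def top_n_accuracy_sketch(labels, outputs, n=1):
    """Fraction of samples whose true label is among the n top-scoring classes."""
    hits = 0
    for label, scores in zip(labels, outputs):
        # Class indices sorted by descending score, truncated to the top n.
        top_n = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True)[:n]
        hits += label in top_n
    return hits / len(labels)


labels = [0, 1, 2]
outputs = [[0.7, 0.2, 0.1], [0.5, 0.3, 0.2], [0.1, 0.6, 0.3]]
acc_top1 = top_n_accuracy_sketch(labels, outputs, n=1)
acc_top2 = top_n_accuracy_sketch(labels, outputs, n=2)
```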

confusion_matrix(labels, outputs[, num_classes])

Compute a binary confusion matrix.

ovr_confusion_matrix(labels, outputs[, ...])

Compute binary one-vs-rest confusion matrices.

metrics_from_confusion_matrix(labels, outputs)

Compute macro metrics, and metrics for each class.

compute_wave_delineation_metrics(...[, ...])

Compute metrics for ECG wave delineation.

QRS_score(rpeaks_truths, rpeaks_preds, fs[, thr])

QRS accuracy score, proposed in CPSC2019.

Decorators and Mixins

add_docstring(doc[, mode])

Decorator to add docstring to a function or a class.

remove_parameters_returns_from_docstring(doc)

Remove parameters and/or returns from a docstring in numpydoc format.

default_class_repr(c[, align, depth])

Default class representation.

ReprMixin()

Mixin class for enhanced __repr__() and __str__() methods.

CitationMixin()

Mixin class for getting citations from DOIs.

get_kwargs(func_or_cls[, kwonly])

Get the kwargs of a function or class.

get_required_args(func_or_cls)

Get the required positional arguments of a function or class.

add_kwargs(func, **kwargs)

Add keyword arguments to a function.

Path operations

get_record_list_recursive3(db_dir, rec_patterns)

Get the list of records in a recursive manner.

String operations

dict_to_str(d[, current_depth, indent_spaces])

Convert a (possibly nested) dict into a JSON-like formatted str.

str2bool(v)

Convert a "boolean" value, possibly given as a str, to bool.

nildent(text)

Remove all leading whitespace from each line of text, while keeping all lines (including empty ones).
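
The behavior described above can be sketched in one line with str.lstrip applied per line. The function name is hypothetical, and splitting on "\n" (rather than using str.splitlines) is a choice made here so that empty lines, including a trailing one, are preserved.

```python
def strip_leading_whitespace(text):
    """Remove leading whitespace from every line, keeping the line count unchanged."""
    return "\n".join(line.lstrip() for line in text.split("\n"))


cleaned = strip_leading_whitespace("  first\n\n\tsecond\n third\n")
```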

get_date_str([fmt])

Get the current time as a formatted str.

Visualization functions

ecg_plot(ecg, sample_rate, columns, ...[, ...])

Plot a raw ECG signal.

Miscellaneous

init_logger([log_dir, log_file, log_name, ...])

Initialize a logger.

list_sum(lst)

Sum a sequence of lists.

dicts_equal(d1, d2[, allow_array_diff_types])

Determine if two dicts are equal.

MovingAverage([data])

Class for computing moving average.

Timer([name, verbose])

Context manager to time the execution of a block of code.
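
The general pattern of a timing context manager can be sketched with the standard library alone. The simple_timer function below is a hypothetical stand-in, not the torch_ecg Timer class; it borrows the name and verbose parameter names from the signature above, while the returned dict is an assumption made for this example.

```python
import time
from contextlib import contextmanager


@contextmanager
def simple_timer(name="block", verbose=True):
    """Time the enclosed block; the elapsed seconds are stored in the yielded dict."""
    record = {"name": name, "elapsed": None}
    start = time.perf_counter()
    try:
        yield record
    finally:
        # Runs even if the block raises, so the timing is always recorded.
        record["elapsed"] = time.perf_counter() - start
        if verbose:
            print(f"{name} took {record['elapsed']:.6f} s")


with simple_timer("demo", verbose=False) as record:
    total = sum(range(100_000))
```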

timeout(duration)

A context manager that raises a TimeoutError after a specified time.