# Source code for torch_ecg.utils._ecg_plot

"""
BSD 3-Clause License

Copyright (c) 2024, The Alphanumerics Lab

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived from
   this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


Modified from https://github.com/alphanumericslab/ecg-image-kit/blob/main/codes/ecg-image-generator/ecg_plot.py
"""

import os
import random
from math import ceil
from random import randint
from typing import Dict, List, Optional, Tuple, Union

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import AutoMinorLocator

__all__ = ["ecg_plot"]


# Default layout constants for ``ecg_plot``.
standard_values = dict(
    # Size of one major grid cell in data units; ``ecg_plot`` treats one
    # cell as 0.5 mV vertically and 0.2 s horizontally.
    y_grid_size=0.5,
    x_grid_size=0.2,
    # Printed size of one major grid cell in inches (5 mm).
    y_grid_inch=5 / 25.4,
    x_grid_inch=5 / 25.4,
    grid_line_width=0.5,
    # Lead-name label placement and font size.
    lead_name_offset=0.5,
    lead_fontsize=11.0,
    # Defaults for the corresponding ``ecg_plot`` keyword arguments.
    x_gap=1.0,
    y_gap=0.5,
    display_factor=1.0,
    line_width=0.75,
    row_height=8.0,
    # Length of the calibration (DC) pulse, in seconds.
    dc_offset_length=0.2,
    lead_length=3.0,
    V1_length=12.0,
    # Figure size in inches used when no paper size is requested.
    width=11.0,
    height=8.5,
)

# RGB colours for the major grid lines, selected in ``ecg_plot`` through
# ``"colour" + str(standard_colours)``.
standard_major_colors = dict(
    colour1=(0.4274, 0.196, 0.1843),  # brown
    colour2=(1, 0.796, 0.866),  # pink
    colour3=(0.0, 0.0, 0.4),  # blue
    colour4=(0, 0.3, 0.0),  # green
    colour5=(1, 0, 0),  # red
)


# RGB colours for the minor grid lines; keys pair one-to-one with
# ``standard_major_colors``.
standard_minor_colors = dict(
    colour1=(0.5882, 0.4196, 0.3960),
    colour2=(0.996, 0.9294, 0.9725),
    colour3=(0.0, 0, 0.7),
    colour4=(0, 0.8, 0.3),
    colour5=(0.996, 0.8745, 0.8588),
)

# Paper sizes in inches, stored as (short side, long side).
# ``ecg_plot`` reads index 1 as the figure width and index 0 as the figure
# height, so every entry must keep this order for a landscape page.
# "A1" was previously transposed, which made it the only portrait entry.
papersize_values = {
    "A0": (33.1, 46.8),
    "A1": (23.39, 33.1),
    "A2": (16.54, 23.39),
    "A3": (11.69, 16.54),
    "A4": (8.27, 11.69),
    "letter": (8.5, 11),
}

# Lead labels ``ecg_plot`` uses, in plotting order, whenever exactly 12 leads
# are plotted. With 4 columns each line below is one row of the sheet; rows
# are laid out with an increasing y offset, so the first line is drawn lowest.
leadNames_12 = [
    *("III", "aVF", "V3", "V6"),
    *("II", "aVL", "V2", "V5"),
    *("I", "aVR", "V1", "V4"),
]


def inches_to_dots(value: float, resolution: int) -> float:
    """Convert a length in inches to dots.

    Parameters
    ----------
    value : float
        Length in inches.
    resolution : int
        Number of dots per inch.

    Returns
    -------
    float
        ``value`` multiplied by ``resolution``.

    """
    dots = value * resolution
    return dots


def create_signal_dictionary(signal: np.ndarray, full_leads: List[str]) -> Dict[str, np.ndarray]:
    """Map each lead name to the matching row of ``signal``.

    Parameters
    ----------
    signal : numpy.ndarray
        Indexable container whose ``k``-th item is the signal of
        ``full_leads[k]``.
    full_leads : List[str]
        Lead names, in the same order as the items of ``signal``.

    Returns
    -------
    Dict[str, numpy.ndarray]
        Dictionary with lead names as keys, in the order of ``full_leads``.
        If a name occurs more than once, the last occurrence wins.

    Raises
    ------
    IndexError
        If ``signal`` has fewer items than ``full_leads``.

    """
    # Index ``signal`` explicitly rather than zipping, so a signal with too
    # few rows still raises instead of being truncated silently.
    return {lead: signal[k] for k, lead in enumerate(full_leads)}


def _extent_in_dots(artist, fig, resolution):
    """Return the window extent of ``artist`` as ``(x0, y0, x1, y1)`` rescaled from ``fig.dpi`` to ``resolution``."""
    renderer = fig.canvas.get_renderer()
    box = artist.get_window_extent(renderer=renderer)
    return (
        box.x0 * resolution / fig.dpi,
        box.y0 * resolution / fig.dpi,
        box.x1 * resolution / fig.dpi,
        box.y1 * resolution / fig.dpi,
    )


def _lead_box(trace_line, pulse_line, fig, resolution):
    """Return ``[x1, y1, x2, y2]`` of a trace, widened to include its DC pulse when one was drawn."""
    x1, y1, x2, y2 = _extent_in_dots(trace_line, fig, resolution)
    if pulse_line is not None:
        pulse_x1, pulse_y1, _, pulse_y2 = _extent_in_dots(pulse_line, fig, resolution)
        # The pulse sits to the left of the trace: it supplies the left edge
        # and may extend the vertical range; the right edge stays the trace's.
        x1 = pulse_x1
        y1 = min(pulse_y1, y1)
        y2 = max(pulse_y2, y2)
    return [x1, y1, x2, y2]


def _pick_colours(style, standard_colours):
    """Choose the major-grid, minor-grid and trace colours for the requested style."""
    if style == "bw":
        return (0.4, 0.4, 0.4), (0.75, 0.75, 0.75), (0, 0, 0)

    # NOTE: the bare ``randint`` calls below discard their result. They are
    # kept so that a seeded ``random`` module yields the same colours as before.
    if standard_colours > 0:
        key = "colour" + str(standard_colours)
        color_major = standard_major_colors[key]
        color_minor = standard_minor_colors[key]
        randint(0, 24)
        grey = random.uniform(0, 0.2)
        return color_major, color_minor, (grey, grey, grey)

    randint(0, 24)
    major_red = random.uniform(0, 0.8)
    randint(0, 24)
    major_green = random.uniform(0, 0.5)
    randint(0, 24)
    major_blue = random.uniform(0, 0.5)
    randint(0, 24)
    minor_offset = random.uniform(0, 0.2)
    minor_red = major_red + minor_offset
    minor_green = random.uniform(0, 0.5) + minor_offset
    minor_blue = random.uniform(0, 0.5) + minor_offset
    randint(0, 24)
    grey = random.uniform(0, 0.2)
    return (
        (major_red, major_green, major_blue),
        (minor_red, minor_green, minor_blue),
        (grey, grey, grey),
    )


def _write_bbox_file(directory, stem, boxes, margin):
    """Write one ``x1,y1,x2,y2,label`` line per box to ``<directory>/<stem>.txt``, shifting coordinates by ``margin``."""
    os.makedirs(directory, exist_ok=True)
    with open(os.path.join(directory, stem + ".txt"), "w") as f:
        for box in boxes:
            coords = box[:4]
            if margin != 0:
                coords = [c + margin for c in coords]
            f.write(",".join(str(v) for v in [*coords, box[4]]) + "\n")


def ecg_plot(
    ecg: Dict[str, np.ndarray],
    sample_rate: int,
    columns: int,
    rec_file_name: Union[str, bytes, os.PathLike],
    output_dir: Union[str, bytes, os.PathLike],
    resolution: int = 200,
    pad_inches: int = 0,
    lead_index: Optional[List[str]] = None,
    full_mode: str = "None",
    store_text_bbox: bool = False,
    units: str = "mV",
    papersize: Optional[str] = None,
    x_gap: float = standard_values["x_gap"],
    y_gap: float = standard_values["y_gap"],
    display_factor: float = standard_values["display_factor"],
    line_width: float = standard_values["line_width"],
    title: Optional[str] = None,
    style: Optional[str] = None,
    row_height: float = standard_values["row_height"],
    show_lead_name: bool = True,
    show_grid: bool = False,
    show_dc_pulse: bool = False,
    y_grid: Optional[float] = None,
    x_grid: Optional[float] = None,
    standard_colours: int = 0,
    bbox: bool = False,
    print_txt: bool = False,
    save_format: Optional[str] = None,
) -> Tuple[float, float]:
    """Function to plot raw ECG signal.

    Parameters
    ----------
    ecg : Dict[str, np.ndarray]
        Dictionary of ECG signals with lead names as keys,
        values as 1D numpy arrays.
    sample_rate : int
        Sampling rate of the ECG signal.
    columns : int
        Number of columns to be plotted in each row.
    rec_file_name : `path-like`
        Name of the record file. Only its last path component is used,
        as the stem of the output file names.
    output_dir : `path-like`
        Output directory.
    resolution : int, default ``200``
        Resolution (dots per inch) of the output image.
    pad_inches : int, default ``0``
        Padding of white margin along the image in inches.
    lead_index : List[str], optional
        Order of lead indices to be plotted.
        By default, the order is the same as the order in ``ecg``.
        When exactly 12 leads are given, the names and order of
        ``leadNames_12`` are used instead.
    full_mode : str, default ``"None"``
        Sets the lead to add at the bottom of the paper ECG as a long strip.
        If ``"None"``, no lead is added at the bottom.
        If not ``"None"``, the lead ``"full" + full_mode`` must be present in ``ecg``.
    store_text_bbox : bool, default ``False``
        If ``True``, stores the bounding box of the lead-name text in a text file
        (only written when the plot is saved).
    units : str, default ``"mV"``
        Units of the ECG signal. NOT used currently.
    papersize : {``"A0"``, ``"A1"``, ``"A2"``, ``"A3"``, ``"A4"``, ``"letter"``}, default ``None``
        Size of the paper to plot the ECG on.
    x_gap : float, default ``1.0``
        Gap between paper x axis border and signal plot.
        NOT used currently: the gap is recomputed so that the rows are
        horizontally centred on the page.
    y_gap : float, default ``0.5``
        Gap between paper y axis border and signal plot. NOT used currently.
    display_factor : float, default ``1.0``
        Factor to scale the ECG signal by. NOT used currently.
    line_width : float, default ``0.75``
        Width of line tracing the ECG.
    title : str, optional
        Title of the figure.
    style : {``"bw"``, ``"color"``}, optional
        Sets the style of the plot. Any value other than ``"bw"``
        selects the colour path.
    row_height : float, default ``8.0``
        Gap between corresponding ECG rows.
        NOT used currently: the row height is recomputed from the page
        height and the number of rows.
    show_lead_name : bool, default ``True``
        Option to show lead names or skip.
    show_grid : bool, default ``False``
        Turn grid on or off.
    show_dc_pulse : bool, default ``False``
        Option to show DC pulse. The pulse is drawn in front of the first
        lead of every row; all traces are shifted right by the pulse length.
    y_grid : float, optional
        Sets the y grid size in inches.
    x_grid : float, optional
        Sets the x grid size in inches.
    standard_colours : {``0``, ``1``, ``2``, ``3``, ``4``, ``5``}, default ``0``
        Sets the colour of the plot grid.
        ``0`` draws random grid colours.
    bbox : bool, default ``False``
        If ``True``, stores the bounding box of the lead in a text file
        (only written when the plot is saved).
    print_txt : bool, default ``False``
        If ``True``, prints the metadata of the plot. Currently a no-op.
    save_format : {``"png"``, ``"pdf"``, ``"svg"``, ``"jpg"``}, optional
        Format to save the plot in.
        If ``None``, the plot is not saved.

    Returns
    -------
    x_grid_dots : float
        X grid size in dots.
    y_grid_dots : float
        Y grid size in dots.

    Both values are ``0`` when ``ecg`` is empty.

    Raises
    ------
    AssertionError
        If one row would hold more than 10 seconds of signal, or if
        ``save_format`` is not one of the supported formats.

    """
    matplotlib.use("Agg")

    # Unused draws, kept so that a seeded ``random`` module produces the same
    # grid/trace colours as before.
    randint(0, 99)
    random.uniform(-0.05, 0.004)

    # check if the ecg dict is empty
    if not ecg:
        return 0, 0

    secs = len(next(iter(ecg.values()))) / sample_rate
    assert secs * columns <= 10, "Total time of ECG signal of one row cannot exceed 10 seconds"

    # Validate the output format before any figure is created, so that an
    # invalid value does not leave a half-built figure behind.
    if save_format is not None:
        save_format = f".{save_format.strip('.')}".lower()
        assert save_format in [
            ".png",
            ".pdf",
            ".svg",
            ".jpg",
        ], f"save_format must be one of '.png', '.pdf', '.svg', '.jpg', but got {save_format}"

    if lead_index is None:
        lead_index = list(ecg.keys())

    rows = int(ceil(len(lead_index) / columns))
    if full_mode != "None":
        # one extra row for the long strip at the bottom
        rows += 1

    # Grid calibration
    # Each big grid corresponds to 0.2 seconds and 0.5 mV
    y_grid_size = standard_values["y_grid_size"]
    x_grid_size = standard_values["x_grid_size"]
    grid_line_width = standard_values["grid_line_width"]
    lead_name_offset = standard_values["lead_name_offset"]
    lead_fontsize = standard_values["lead_fontsize"]

    # Page size in inches
    if not papersize:
        width = standard_values["width"]
        height = standard_values["height"]
    else:
        width = papersize_values[papersize][1]
        height = papersize_values[papersize][0]

    if y_grid is None:
        y_grid = standard_values["y_grid_inch"]
    if x_grid is None:
        x_grid = standard_values["x_grid_inch"]
    y_grid_dots = y_grid * resolution
    x_grid_dots = x_grid * resolution

    # Axis limits in data units (seconds horizontally, mV vertically).
    # The ``row_height`` and ``x_gap`` arguments are overridden here.
    x_min, y_min = 0, 0
    x_max = width * x_grid_size / x_grid
    y_max = height * y_grid_size / y_grid
    row_height = (height * y_grid_size / y_grid) / (rows + 2)
    # centre the rows horizontally, snapped down to a whole major grid cell
    x_gap = np.floor(((x_max - (columns * secs)) / 2) / 0.2) * 0.2

    # Set figure and subplot sizes
    fig, ax = plt.subplots(figsize=(width, height))
    fig.subplots_adjust(hspace=0, wspace=0, left=0, right=1, bottom=0, top=1)
    fig.suptitle(title)

    # Mark grid based on whether we want black and white or colour
    color_major, color_minor, color_line = _pick_colours(style, standard_colours)

    # Set grid
    # Standard ecg has grid size of 0.5 mV and 0.2 seconds. Set ticks accordingly
    if show_grid:
        ax.set_xticks(np.arange(x_min, x_max, x_grid_size))
        ax.set_yticks(np.arange(y_min, y_max, y_grid_size))
        ax.minorticks_on()
        ax.xaxis.set_minor_locator(AutoMinorLocator(5))
        # set grid line style
        ax.grid(which="major", linestyle="-", linewidth=grid_line_width, color=color_major)
        ax.grid(which="minor", linestyle="-", linewidth=grid_line_width, color=color_minor)
    else:
        ax.grid(False)

    ax.set_ylim(y_min, y_max)
    ax.set_xlim(x_min, x_max)
    ax.tick_params(axis="x", colors="white")
    ax.tick_params(axis="y", colors="white")

    # Step size will be number of seconds per sample i.e 1/sampling_rate
    step = 1.0 / sample_rate
    pulse_duration = sample_rate * standard_values["dc_offset_length"] * step
    dc_offset = pulse_duration if show_dc_pulse else 0

    # DC pulse template: a rectangle with 2 leading and 2 trailing zeros.
    # Built once here because it does not depend on the lead, and because the
    # full-strip section below needs it even when ``lead_index`` is empty.
    x_range = np.arange(0, pulse_duration + 4 * step, step)
    dc_pulse = np.ones(len(x_range))
    dc_pulse = np.concatenate(((0, 0), dc_pulse[2:-2], (0, 0)))

    text_bbox = []
    lead_bbox = []
    use_standard_layout = len(lead_index) == 12

    # y_offset starts at row_height / 2 to account for the half of the
    # waveform that lies below its baseline.
    y_offset = row_height / 2
    for i, requested_name in enumerate(lead_index):
        lead_name = leadNames_12[i] if use_standard_layout else requested_name

        starts_row = i % columns == 0
        if starts_row:
            y_offset += row_height
        # horizontal shift of this lead within its row
        x_offset = (i % columns) * secs

        # Print lead name at .5 ( or 5 mm distance) from plot
        if show_lead_name:
            label = ax.text(x_offset + x_gap, y_offset - lead_name_offset - 0.2, lead_name, fontsize=lead_fontsize)
            if store_text_bbox:
                text_bbox.append([*_extent_in_dots(label, fig, resolution), lead_name])

        # The DC pulse precedes the first lead of every row. The previous
        # condition, ``columns == 4 and i == 0 or i == 4 or i == 8``, bound as
        # ``(columns == 4 and i == 0) or i == 4 or i == 8`` and therefore drew
        # pulses at leads 4 and 8 for any column count. The row-start rule is
        # identical for 1 column and for 4 columns with up to 12 leads.
        pulse_line = None
        if show_dc_pulse and starts_row:
            (pulse_line,) = ax.plot(
                x_range + x_offset + x_gap,
                dc_pulse + y_offset,
                linewidth=line_width * 1.5,
                color=color_line,
            )

        signal = ecg[lead_name]
        # ``np.arange(n) * step`` always has exactly ``n`` samples, whereas
        # ``np.arange(0, n * step, step)`` may have one more due to rounding.
        (trace_line,) = ax.plot(
            np.arange(len(signal)) * step + x_offset + dc_offset + x_gap,
            signal + y_offset,
            linewidth=line_width,
            color=color_line,
        )
        if bbox:
            lead_bbox.append([*_lead_box(trace_line, pulse_line, fig, resolution), 0])

    # Long strip of one lead at the bottom of the page
    if full_mode != "None":
        strip = ecg["full" + full_mode]

        if show_lead_name:
            label = ax.text(x_gap, row_height / 2 - lead_name_offset, full_mode, fontsize=lead_fontsize)
            if store_text_bbox:
                text_bbox.append([*_extent_in_dots(label, fig, resolution), full_mode])

        pulse_line = None
        if show_dc_pulse:
            (pulse_line,) = ax.plot(
                x_range + x_gap,
                dc_pulse + row_height / 2 - lead_name_offset + 0.8,
                linewidth=line_width * 1.5,
                color=color_line,
            )

        (trace_line,) = ax.plot(
            np.arange(len(strip)) * step + x_gap + dc_offset,
            strip + row_height / 2 - lead_name_offset + 0.8,
            linewidth=line_width,
            color=color_line,
        )
        if bbox:
            lead_bbox.append([*_lead_box(trace_line, pulse_line, fig, resolution), 1])

    # ``print_txt`` is accepted for interface compatibility; printing the
    # record metadata is not implemented.

    # paper speed and gain labels
    ax.text(2, 0.5, "25mm/s", fontsize=lead_fontsize)
    ax.text(4, 0.5, "10mm/mV", fontsize=lead_fontsize)

    if save_format is None:
        # NOTE(review): the figure is intentionally left open here, as
        # before, so that callers can still reach it (e.g. via ``plt.gcf()``)
        # - confirm whether any caller relies on this.
        return x_grid_dots, y_grid_dots

    tail = os.path.basename(rec_file_name)
    fig.savefig(os.path.join(output_dir, tail + save_format), dpi=resolution, bbox_inches="tight", pad_inches=pad_inches)
    # Only close the figure. Calling ``plt.clf()`` / ``plt.cla()`` after the
    # close would create (and leak) a brand-new empty figure, or clear an
    # unrelated figure that happens to be current.
    plt.close(fig)

    # The white margin added by ``pad_inches`` shifts every bounding box.
    margin = pad_inches * resolution
    if store_text_bbox:
        _write_bbox_file(os.path.join(output_dir, "text_bounding_box"), tail, text_bbox, margin)
    if bbox:
        _write_bbox_file(os.path.join(output_dir, "lead_bounding_box"), tail, lead_bbox, margin)

    return x_grid_dots, y_grid_dots