"""
"""
import datetime
import inspect
import logging
import os
import re
import signal
import sys
import time
import types
import warnings
from contextlib import contextmanager
from copy import deepcopy
from functools import reduce, wraps
from glob import glob
from numbers import Number, Real
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
import numpy as np
import pandas as pd
from bib_lookup import CitationMixin as _CitationMixin
from deprecated import deprecated
from ..cfg import _DATA_CACHE, DEFAULTS
__all__ = [
"get_record_list_recursive",
"get_record_list_recursive2",
"get_record_list_recursive3",
"dict_to_str",
"str2bool",
"ms2samples",
"samples2ms",
"plot_single_lead",
"init_logger",
"get_date_str",
"list_sum",
"read_log_txt",
"read_event_scalars",
"dicts_equal",
"default_class_repr",
"ReprMixin",
"CitationMixin",
"MovingAverage",
"nildent",
"add_docstring",
"remove_parameters_returns_from_docstring",
"timeout",
"Timer",
"get_kwargs",
"get_required_args",
"add_kwargs",
"make_serializable",
]
def get_record_list_recursive(db_dir: Union[str, bytes, os.PathLike], rec_ext: str, relative: bool = True) -> List[str]:
"""Get the list of records in a recursive manner.
For example, there are two folders "patient1", "patient2" in `db_dir`,
and there are records "A0001", "A0002", ... in "patient1";
"B0001", "B0002", ... in "patient2",
    then the output would be "patient1{sep}A0001", ..., "patient2{sep}B0001", ...,
    where ``{sep}`` is the path separator determined by the system.
Parameters
----------
db_dir : `path-like`
        The parent (root) path to search for records.
rec_ext : str
Extension of the record files.
relative : bool, default True
Whether to return the relative path of the records.
Returns
-------
List[str]
The list of records, in lexicographical order.
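    Examples
    --------
    A sketch; the database layout under ``/path/to/db`` is hypothetical:
    >>> get_record_list_recursive("/path/to/db", "mat")  # doctest: +SKIP
    ['patient1/A0001', 'patient1/A0002', 'patient2/B0001']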
"""
if not rec_ext.startswith("."):
res = Path(db_dir).rglob(f"*.{rec_ext}")
else:
res = Path(db_dir).rglob(f"*{rec_ext}")
res = [str((item.relative_to(db_dir) if relative else item).with_suffix("")) for item in res if str(item).endswith(rec_ext)]
res = sorted(res)
return res
@deprecated(reason="use `get_record_list_recursive3` instead")
def get_record_list_recursive2(db_dir: Union[str, bytes, os.PathLike], rec_pattern: str) -> List[str]:
"""Get the list of records in a recursive manner.
For example, there are two folders "patient1", "patient2" in `db_dir`,
and there are records "A0001", "A0002", ... in "patient1";
"B0001", "B0002", ... in "patient2",
    then the output would be "patient1{sep}A0001", ..., "patient2{sep}B0001", ...,
    where ``{sep}`` is the path separator determined by the system.
Parameters
----------
db_dir : `path-like`
        The parent (root) path to search for records.
rec_pattern : str
Pattern of the record filenames, e.g. ``"A*.mat"``.
Returns
-------
List[str]
The list of records, in lexicographical order.
"""
res = []
roots = [str(db_dir)]
while len(roots) > 0:
new_roots = []
for r in roots:
tmp = [os.path.join(r, item) for item in os.listdir(r)]
# res += [item for item in tmp if os.path.isfile(item)]
res += glob(os.path.join(r, rec_pattern), recursive=False)
new_roots += [item for item in tmp if os.path.isdir(item)]
roots = deepcopy(new_roots)
res = [os.path.splitext(item)[0].replace(str(db_dir), "").strip(os.sep) for item in res]
res = sorted(res)
return res
def get_record_list_recursive3(
db_dir: Union[str, bytes, os.PathLike],
rec_patterns: Union[str, Dict[str, str]],
relative: bool = True,
) -> Union[List[str], Dict[str, List[str]]]:
"""Get the list of records in a recursive manner.
For example, there are two folders "patient1", "patient2" in `db_dir`,
and there are records "A0001", "A0002", ... in "patient1";
"B0001", "B0002", ... in "patient2",
    then the output would be "patient1{sep}A0001", ..., "patient2{sep}B0001", ...,
    where ``{sep}`` is the path separator determined by the system.
Parameters
----------
db_dir : `path-like`
        The parent (root) path to search for records.
rec_patterns : str or dict
Pattern of the record filenames, e.g. ``"A(?:\\d+).mat"``,
        or patterns of several subsets, e.g. ``{"A": "A(?:\\d+).mat"}``.
relative : bool, default True
Whether to return the relative path of the records.
Returns
-------
List[str] or dict
The list of records, in lexicographical order.
"""
if isinstance(rec_patterns, str):
res = []
elif isinstance(rec_patterns, dict):
res = {k: [] for k in rec_patterns.keys()}
_db_dir = Path(db_dir).resolve() # make absolute
roots = [_db_dir]
while len(roots) > 0:
new_roots = []
for r in roots:
tmp = os.listdir(r)
if isinstance(rec_patterns, str):
res += [r / item for item in filter(re.compile(rec_patterns).search, tmp)]
elif isinstance(rec_patterns, dict):
for k in rec_patterns.keys():
res[k] += [r / item for item in filter(re.compile(rec_patterns[k]).search, tmp)]
new_roots += [r / item for item in tmp if (r / item).is_dir()]
roots = deepcopy(new_roots)
if isinstance(rec_patterns, str):
res = [str((item.relative_to(_db_dir) if relative else item).with_suffix("")) for item in res]
res = sorted(res)
elif isinstance(rec_patterns, dict):
for k in rec_patterns.keys():
res[k] = [str((item.relative_to(_db_dir) if relative else item).with_suffix("")) for item in res[k]]
res[k] = sorted(res[k])
return res
def dict_to_str(d: Union[dict, list, tuple], current_depth: int = 1, indent_spaces: int = 4) -> str:
"""Convert a (possibly) nested dict into a `str` of json-like formatted form.
    This nested dict might also contain lists or tuples of dicts (and of str, int, etc.).
Parameters
----------
d : dict or list or tuple
A (possibly) nested :class:`dict`, or a list of :class:`dict`.
current_depth : int, default 1
Depth of `d` in the (possible) parent :class:`dict` or :class:`list`.
indent_spaces : int, default 4
The indent spaces of each depth.
Returns
-------
str
The formatted string.
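    Examples
    --------
    >>> print(dict_to_str({"a": 1, "b": [1, 2]}))
    {
        "a": 1,
        "b": [
            1, 2
        ]
    }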
"""
assert isinstance(d, (dict, list, tuple))
if len(d) == 0:
s = r"{}" if isinstance(d, dict) else "[]"
return s
# flat_types = (Number, bool, str,)
flat_types = (
Number,
bool,
)
flat_sep = ", "
s = "\n"
unit_indent = " " * indent_spaces
prefix = unit_indent * current_depth
if isinstance(d, (list, tuple)):
if all([isinstance(v, flat_types) for v in d]):
len_per_line = 110
current_len = len(prefix) + 1 # + 1 for a comma
val = []
for idx, v in enumerate(d):
add_v = f"\042{v}\042" if isinstance(v, str) else str(v)
add_len = len(add_v) + len(flat_sep)
if current_len + add_len > len_per_line:
val = ", ".join([item for item in val])
s += f"{prefix}{val},\n"
val = [add_v]
current_len = len(prefix) + 1 + len(add_v)
else:
val.append(add_v)
current_len += add_len
if len(val) > 0:
val = ", ".join([item for item in val])
s += f"{prefix}{val}\n"
else:
for idx, v in enumerate(d):
if isinstance(v, (dict, list, tuple)):
s += f"{prefix}{dict_to_str(v, current_depth+1)}"
else:
val = f"\042{v}\042" if isinstance(v, str) else v
s += f"{prefix}{val}"
if idx < len(d) - 1:
s += ",\n"
else:
s += "\n"
elif isinstance(d, dict):
for idx, (k, v) in enumerate(d.items()):
key = f"\042{k}\042" if isinstance(k, str) else k
if isinstance(v, (dict, list, tuple)):
s += f"{prefix}{key}: {dict_to_str(v, current_depth+1)}"
else:
val = f"\042{v}\042" if isinstance(v, str) else v
s += f"{prefix}{key}: {val}"
if idx < len(d) - 1:
s += ",\n"
else:
s += "\n"
s += unit_indent * (current_depth - 1)
s = f"{{{s}}}" if isinstance(d, dict) else f"[{s}]"
return s
def str2bool(v: Union[str, bool]) -> bool:
"""Converts a "boolean" value possibly
in the format of :class:`str` to :class:`bool`.
Modified from [#str2bool]_.
Parameters
----------
v : str or bool
The "boolean" value.
Returns
-------
bool
`v` in the format of a bool.
References
----------
.. [#str2bool] https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
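    Examples
    --------
    >>> str2bool("yes"), str2bool("F")
    (True, False)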
"""
if isinstance(v, bool):
b = v
elif v.lower() in ("yes", "true", "t", "y", "1"):
b = True
elif v.lower() in ("no", "false", "f", "n", "0"):
b = False
else:
raise ValueError("Boolean value expected.")
return b
@deprecated("Use `np.diff` instead.")
def diff_with_step(a: np.ndarray, step: int = 1) -> np.ndarray:
"""Compute ``a[n+step] - a[n]`` for all valid `n`.
Parameters
----------
a : numpy.ndarray
The input data.
step : int, default 1
The step size to compute the difference.
Returns
-------
numpy.ndarray
The difference array.
"""
if step >= len(a):
raise ValueError(f"`step` ({step}) should be less than the length ({len(a)}) of `a`")
d = a[step:] - a[:-step]
return d
def ms2samples(t: Real, fs: Real) -> int:
"""Convert time duration in ms to number of samples.
Parameters
----------
t : numbers.Real
Time duration in ms.
fs : numbers.Real
Sampling frequency.
Returns
-------
n_samples : int
Number of samples converted from `t`,
with sampling frequency `fs`.
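    Examples
    --------
    >>> ms2samples(1200, 500)  # 1.2 seconds sampled at 500 Hz
    600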
"""
    n_samples = int(t * fs // 1000)  # cast, since `t * fs` may be a float
return n_samples
def samples2ms(n_samples: int, fs: Real) -> Real:
"""Convert number of samples to time duration in ms.
Parameters
----------
n_samples : int
Number of sample points.
fs : numbers.Real
Sampling frequency.
Returns
-------
t : numbers.Real
Time duration in ms converted from `n_samples`,
with sampling frequency `fs`.
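    Examples
    --------
    >>> samples2ms(600, 500)  # 600 samples at 500 Hz
    1200.0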
"""
t = n_samples * 1000 / fs
return t
def plot_single_lead(
t: np.ndarray,
sig: np.ndarray,
ax: Optional[Any] = None,
ticks_granularity: int = 0,
**kwargs,
) -> None:
"""Plot single lead ECG signal.
Parameters
----------
t : numpy.ndarray
The array of time points.
sig : numpy.ndarray
The signal itself.
ax : matplotlib.axes.Axes, default None
The axes to plot on.
ticks_granularity : int, default 0
Granularity to plot axis ticks, the higher the more ticks.
0 (no ticks) --> 1 (major ticks) --> 2 (major + minor ticks)
Returns
-------
None
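    Examples
    --------
    A sketch with a synthetic 2-second signal sampled at 500 Hz
    (amplitudes in μV, matching the default y-axis scaling):
    >>> fs = 500
    >>> t = np.arange(2 * fs) / fs
    >>> sig = 500 * np.sin(2 * np.pi * 1.2 * t)
    >>> plot_single_lead(t, sig, ticks_granularity=2)  # doctest: +SKIP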
"""
if "plt" not in dir():
import matplotlib.pyplot as plt
palette = {
"p_waves": "green",
"qrs": "red",
"t_waves": "pink",
}
plot_alpha = 0.4
y_range = np.max(np.abs(sig)) + 100
if ax is None:
fig_sz_w = int(round(4.8 * (t[-1] - t[0])))
fig_sz_h = 6 * y_range / 1500
fig, ax = plt.subplots(figsize=(fig_sz_w, fig_sz_h))
label = kwargs.get("label", None)
if label:
ax.plot(t, sig, label=kwargs.get("label"))
else:
ax.plot(t, sig)
    ax.axhline(y=0, linestyle="-", linewidth=1.0, color="red")
# NOTE that `Locator` has default `MAXTICKS` equal to 1000
if ticks_granularity >= 1:
ax.xaxis.set_major_locator(plt.MultipleLocator(0.2))
ax.yaxis.set_major_locator(plt.MultipleLocator(500))
ax.grid(which="major", linestyle="-", linewidth="0.5", color="red")
if ticks_granularity >= 2:
ax.xaxis.set_minor_locator(plt.MultipleLocator(0.04))
ax.yaxis.set_minor_locator(plt.MultipleLocator(100))
ax.grid(which="minor", linestyle=":", linewidth="0.5", color="black")
waves = kwargs.get("waves", {"p_waves": [], "qrs": [], "t_waves": []})
for w, l_itv in waves.items():
for itv in l_itv:
ax.axvspan(itv[0], itv[1], color=palette[w], alpha=plot_alpha)
if label:
ax.legend(loc="upper left")
ax.set_xlim(t[0], t[-1])
ax.set_ylim(-y_range, y_range)
ax.set_xlabel("Time [s]")
ax.set_ylabel("Voltage [μV]")
def init_logger(
log_dir: Optional[Union[str, Path, bool]] = None,
log_file: Optional[str] = None,
log_name: Optional[str] = None,
suffix: Optional[str] = None,
mode: str = "a",
verbose: int = 0,
) -> logging.Logger:
"""Initialize a logger.
Parameters
----------
log_dir : `path-like` or bool, optional
Directory of the log file,
default to `DEFAULTS.log_dir`.
        If it is `False`, then no log file will be created.
log_file : str, optional
Name of the log file,
default to ``{DEFAULTS.prefix}-log-{get_date_str()}.txt``.
log_name : str, optional
Name of the logger.
suffix : str, optional
Suffix of the logger name.
Ignored if `log_name` is not `None`.
mode : {"a", "w"}, default "a"
Mode to open the log file.
verbose : int, default 0
Verbosity level for the logger.
Returns
-------
logger : logging.Logger
The logger.
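    Examples
    --------
    A minimal sketch that logs to the console only (no log file is created):
    >>> logger = init_logger(log_dir=False, log_name="demo", verbose=1)
    >>> logger.info("hello")  # doctest: +SKIP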
"""
if log_dir is False:
log_file = None
else:
if log_file is None:
log_file = f"{DEFAULTS.prefix}-log-{get_date_str()}.txt"
log_dir = Path(log_dir).expanduser().resolve() if log_dir is not None else DEFAULTS.log_dir
log_dir.mkdir(parents=True, exist_ok=True)
log_file = log_dir / log_file
print(f"log file path: {str(log_file)}")
log_name = (log_name or DEFAULTS.prefix) + (f"-{suffix}" if suffix else "")
# if a logger with the same name already exists, remove it
if log_name in logging.root.manager.loggerDict:
logging.getLogger(log_name).handlers = []
logger = logging.getLogger(log_name) # to prevent from using the root logger
c_handler = logging.StreamHandler(sys.stdout)
if log_file is not None:
f_handler = logging.FileHandler(str(log_file))
if verbose >= 2:
# print("level of `c_handler` is set DEBUG")
c_handler.setLevel(logging.DEBUG)
if log_file is not None:
# print("level of `f_handler` is set DEBUG")
f_handler.setLevel(logging.DEBUG)
logger.setLevel(logging.DEBUG)
elif verbose >= 1:
# print("level of `c_handler` is set INFO")
c_handler.setLevel(logging.INFO)
if log_file is not None:
# print("level of `f_handler` is set DEBUG")
f_handler.setLevel(logging.DEBUG)
logger.setLevel(logging.DEBUG)
else:
# print("level of `c_handler` is set WARNING")
c_handler.setLevel(logging.WARNING)
if log_file is not None:
# print("level of `f_handler` is set INFO")
f_handler.setLevel(logging.INFO)
logger.setLevel(logging.INFO)
c_format = logging.Formatter("%(name)s - %(levelname)s - %(message)s")
c_handler.setFormatter(c_format)
logger.addHandler(c_handler)
if log_file is not None:
f_format = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
f_handler.setFormatter(f_format)
logger.addHandler(f_handler)
return logger
def get_date_str(fmt: Optional[str] = None):
"""Get the current time in the :class:`str` format.
Parameters
----------
fmt : str, optional
Format of the string of date,
default to ``"%m-%d_%H-%M"``.
Returns
-------
str
Current time in the :class:`str` format.
"""
now = datetime.datetime.now()
date_str = now.strftime(fmt or "%m-%d_%H-%M")
return date_str
def list_sum(lst: Sequence[list]) -> list:
"""Sum a sequence of lists.
Parameters
----------
lst : Sequence[list]
The sequence of lists to obtain the summation.
Returns
-------
list
        Sum of `lst`,
i.e. if ``lst = [list1, list2, ...]``,
then ``l_sum = list1 + list2 + ...``.
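    Examples
    --------
    >>> list_sum([[1, 2], [3], []])
    [1, 2, 3]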
"""
l_sum = reduce(lambda a, b: a + b, lst, [])
return l_sum
def read_log_txt(
fp: Union[str, bytes, os.PathLike],
epoch_startswith: str = "Train epoch_",
scalar_startswith: Union[str, Iterable[str]] = "train/|test/",
) -> pd.DataFrame:
"""Read from log txt file, in case tensorboard not working.
Parameters
----------
fp : `path-like`
Path to the log txt file.
epoch_startswith : str, default "Train epoch_"
        Indicator of the start of an epoch.
scalar_startswith : str or Iterable[str], default "train/|test/"
Indicators of the scalar recordings.
        If it is a :class:`str`, it should consist of indicators separated by ``"|"``.
Returns
-------
summary : pandas.DataFrame
Scalars summary, in the format of a :class:`~pandas.DataFrame`.
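    Examples
    --------
    A sketch over a hypothetical log file produced during training:
    >>> summary = read_log_txt("/path/to/train.log")  # doctest: +SKIP
    >>> summary.columns  # doctest: +SKIP
    Index(['epoch', 'train/loss', 'test/loss'], dtype='object')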
"""
content = Path(fp).read_text().splitlines()
if isinstance(scalar_startswith, str):
field_pattern = f"({scalar_startswith})"
else:
field_pattern = f"""({"|".join(scalar_startswith)})"""
summary = []
new_line = None
for line in content:
if re.findall(f"{epoch_startswith}([\\d]+)", line):
if new_line:
summary.append(new_line)
epoch = re.findall(f"{epoch_startswith}([\\d]+)", line)[0]
new_line = {"epoch": epoch}
if re.findall(field_pattern, line):
field, val = line.split(":")[-2:]
field = field.strip()
val = float(val.strip())
new_line[field] = val
    if new_line:  # guard against appending None when no epoch line was found
        summary.append(new_line)
summary = pd.DataFrame(summary)
return summary
def read_event_scalars(
fp: Union[str, bytes, os.PathLike], keys: Optional[Union[str, Iterable[str]]] = None
) -> Union[pd.DataFrame, Dict[str, pd.DataFrame]]:
"""Read scalars from event file, in case tensorboard not working.
Parameters
----------
fp : `path-like`
Path to the event file.
keys : str or Iterable[str], optional
Field names of the scalars to read.
        If None, scalars of all fields will be read.
Returns
-------
summary : pandas.DataFrame or dict of pandas.DataFrame
The wall_time, step, value of the scalars.
"""
try:
from tensorflow.python.summary.event_accumulator import EventAccumulator
except Exception:
try:
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
except Exception:
raise ImportError("cannot import `EventAccumulator` from `tensorflow` or `tensorboard`")
event_acc = EventAccumulator(fp)
event_acc.Reload()
if keys:
if isinstance(keys, str):
_keys = [keys]
else:
_keys = keys
else:
_keys = event_acc.scalars.Keys()
summary = {}
for k in _keys:
df = pd.DataFrame([[item.wall_time, item.step, item.value] for item in event_acc.scalars.Items(k)])
df.columns = ["wall_time", "step", "value"]
summary[k] = df
if isinstance(keys, str):
summary = summary[k]
return summary
def dicts_equal(d1: dict, d2: dict, allow_array_diff_types: bool = True) -> bool:
"""Determine if two dicts are equal.
Parameters
----------
d1, d2 : dict
The two dicts to compare equality.
allow_array_diff_types : bool, default True
        Whether to allow two arrays of different types to compare equal,
including `list`, `tuple`, `numpy.ndarray`, `torch.Tensor`,
**NOT** including `pandas.DataFrame`, `pandas.Series`.
Returns
-------
bool
True if `d1` equals `d2`, False otherwise.
NOTE
----
The existence of :class:`~numpy.ndarray`, :class:`~torch.Tensor`,
:class:`~pandas.DataFrame` and :class:`~pandas.Series` would probably
    cause errors when directly using the default ``__eq__`` method of :class:`dict`.
For example:
.. code-block:: python
>>> {"a": np.array([1,2])} == {"a": np.array([1,2])}
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
Example
-------
>>> d1 = {"a": pd.DataFrame([{"hehe":1,"haha":2}])[["haha","hehe"]]}
>>> d2 = {"a": pd.DataFrame([{"hehe":1,"haha":2}])[["hehe","haha"]]}
>>> dicts_equal(d1, d2)
True
"""
    import torch  # local import, to avoid importing torch at module load time
if len(d1) != len(d2):
return False
for k, v in d1.items():
if k not in d2:
return False
if not allow_array_diff_types and not isinstance(d2[k], type(v)):
return False
if allow_array_diff_types and isinstance(v, (list, tuple, np.ndarray, torch.Tensor)):
if not isinstance(d2[k], (list, tuple, np.ndarray, torch.Tensor)):
return False
if not np.array_equal(v, d2[k]):
return False
elif allow_array_diff_types and not isinstance(v, (list, tuple, np.ndarray, torch.Tensor)):
if not isinstance(d2[k], type(v)):
return False
if isinstance(v, dict):
if not dicts_equal(v, d2[k]):
return False
elif isinstance(v, (list, tuple, np.ndarray, torch.Tensor)):
return np.array_equal(v, d2[k])
elif isinstance(v, pd.DataFrame):
if v.shape != d2[k].shape or set(v.columns) != set(d2[k].columns):
# consider: should one check index be equal?
return False
# for c in v.columns:
# if not (v[c] == d2[k][c]).all():
# return False
if not (v.values == d2[k][v.columns].values).all():
return False
elif isinstance(v, pd.Series):
if v.shape != d2[k].shape or v.name != d2[k].name:
return False
if not (v == d2[k]).all():
return False
# TODO: consider whether there are any other dtypes that should be treated similarly
else: # other dtypes whose equality can be checked directly
if v != d2[k]:
return False
return True
def add_docstring(doc: str, mode: str = "replace") -> Callable:
"""Decorator to add docstring to a function or a class.
Parameters
----------
doc : str
The docstring to be added.
mode : {"replace", "append", "prepend"}, optional
        How `doc` is added to the original docstring,
        by default "replace", case insensitive.
"""
def decorator(func_or_cls: Callable) -> Callable:
if func_or_cls.__doc__ is None:
func_or_cls.__doc__ = ""
pattern = "(\\s^\n){1,}"
if mode.lower() == "replace":
func_or_cls.__doc__ = doc
elif mode.lower() == "append":
tmp = re.sub(pattern, "", func_or_cls.__doc__)
new_lines = 1 - (len(tmp) - len(tmp.rstrip("\n")))
tmp = re.sub(pattern, "", doc)
new_lines -= len(tmp) - len(tmp.lstrip("\n"))
new_lines = max(0, new_lines) * "\n"
func_or_cls.__doc__ += new_lines + doc
elif mode.lower() == "prepend":
tmp = re.sub(pattern, "", doc)
new_lines = 1 - (len(tmp) - len(tmp.rstrip("\n")))
tmp = re.sub(pattern, "", func_or_cls.__doc__)
new_lines -= len(tmp) - len(tmp.lstrip("\n"))
new_lines = max(0, new_lines) * "\n"
func_or_cls.__doc__ = doc + new_lines + func_or_cls.__doc__
else:
raise ValueError(f"mode `{mode}` is not supported")
return func_or_cls
return decorator
def default_class_repr(c: object, align: str = "center", depth: int = 1) -> str:
"""Default class representation.
Parameters
----------
c : object
The object to be represented.
align : str, default "center"
Alignment of the class arguments.
depth : int, default 1
Depth of the class arguments to be displayed.
Returns
-------
str
The representation of the class.
"""
indent = 4 * depth * " "
closing_indent = 4 * (depth - 1) * " "
if not hasattr(c, "extra_repr_keys"):
return repr(c)
elif len(c.extra_repr_keys()) > 0:
max_len = max([len(k) for k in c.extra_repr_keys()])
extra_str = (
"(\n"
+ ",\n".join(
[
f"""{indent}{k.ljust(max_len, " ") if align.lower() in ["center", "c"] else k} = {default_class_repr(eval(f"c.{k}"),align,depth+1)}"""
for k in c.__dir__()
if k in c.extra_repr_keys()
]
)
+ f"{closing_indent}\n)"
)
else:
extra_str = ""
return f"{c.__class__.__name__}{extra_str}"
class ReprMixin(object):
"""Mixin class for enhanced
:meth:`__repr__` and :meth:`__str__` methods.
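    Examples
    --------
    A minimal sketch; subclasses control the output via ``extra_repr_keys``:
    >>> class Demo(ReprMixin):
    ...     def __init__(self):
    ...         self.alpha = 1
    ...     def extra_repr_keys(self):
    ...         return ["alpha"]
    >>> Demo()
    Demo(
        alpha = 1
    )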
"""
def __repr__(self) -> str:
return default_class_repr(self)
__str__ = __repr__
class CitationMixin(_CitationMixin):
"""Mixin class for getting citations from DOIs."""
    # backward compatibility
if (_DATA_CACHE / "database_citation.csv").exists():
try:
df_old = pd.read_csv(_DATA_CACHE / "database_citation.csv")
except pd.errors.EmptyDataError:
df_old = pd.DataFrame(columns=["doi", "citation"])
if set(df_old.columns) != set(["doi", "citation"]):
df_old = pd.DataFrame(columns=["doi", "citation"])
df_old = df_old[["doi", "citation"]]
if _CitationMixin.citation_cache.exists():
df = pd.read_csv(_CitationMixin.citation_cache)
else:
df = pd.DataFrame(columns=["doi", "citation"])
# merge the old and new tables and drop duplicates
df = pd.concat([df, df_old], axis=0, ignore_index=True)
df = df.drop_duplicates(subset="doi", keep="first")
df = df.reset_index(drop=True)
df.to_csv(_CitationMixin.citation_cache, index=False)
del df_old, df
# delete the old cache
(_DATA_CACHE / "database_citation.csv").unlink()
    def get_citation(
self,
lookup: bool = True,
format: Optional[str] = None,
style: Optional[str] = None,
timeout: Optional[float] = None,
print_result: bool = True,
    ) -> Optional[str]:
"""Get bib citation from DOIs.
Overrides the default method to make the `print_result` argument
have default value ``True``.
Parameters
----------
lookup : bool, default True
Whether to look up the citation from the cache.
format : str, optional
The format of the citation. If not specified, the citation
will be returned in the default format (bibtex).
style : str, optional
The style of the citation. If not specified, the citation
will be returned in the default style (apa).
Valid only when `format` is ``"text"``.
timeout : float, optional
The timeout for the request.
print_result : bool, default True
Whether to print the citation.
Returns
-------
str or None
bib citation(s) from the DOI(s),
or None if `print_result` is True.
"""
return super().get_citation(
lookup=lookup,
format=format,
style=style,
timeout=timeout,
print_result=print_result,
)
class MovingAverage(object):
"""Class for computing moving average.
For more information, see [#ma_wiki]_.
Parameters
----------
data : array_like, optional
The series data to compute its moving average.
kwargs : dict, optional
        Auxiliary keyword arguments.
References
----------
.. [#ma_wiki] https://en.wikipedia.org/wiki/Moving_average
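    Examples
    --------
    A cumulative moving average of a short series:
    >>> ma = MovingAverage([1, 2, 3, 4, 5])
    >>> ma(method="cma")
    array([1. , 1.5, 2. , 2.5, 3. ])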
"""
def __init__(self, data: Optional[Sequence] = None, **kwargs: Any) -> None:
if data is None:
self.data = np.array([])
else:
self.data = np.array(data)
self.verbose = kwargs.get("verbose", 0)
def __call__(self, data: Optional[Sequence] = None, method: str = "ema", **kwargs: Any) -> np.ndarray:
"""Compute moving average.
Parameters
----------
data : array_like, optional
The series data to compute its moving average.
        method : str, default "ema"
            Method for computing the moving average, can be one of
- "sma", "simple", "simple moving average";
- "ema", "ewma", "exponential", "exponential weighted",
"exponential moving average", "exponential weighted moving average";
- "cma", "cumulative", "cumulative moving average";
- "wma", "weighted", "weighted moving average".
kwargs : dict, optional
Keyword arguments for the specific moving average method.
Returns
-------
ma : numpy.ndarray
The moving average of the input data.
"""
m = method.lower().replace("_", " ")
if m in ["sma", "simple", "simple moving average"]:
func = self._sma
elif m in [
"ema",
"ewma",
"exponential",
"exponential weighted",
"exponential moving average",
"exponential weighted moving average",
]:
func = self._ema
elif m in ["cma", "cumulative", "cumulative moving average"]:
func = self._cma
elif m in ["wma", "weighted", "weighted moving average"]:
func = self._wma
else:
raise NotImplementedError(f"method `{method}` is not implemented yet")
if data is not None:
self.data = np.array(data)
return func(**kwargs)
def _sma(self, window: int = 5, center: bool = False, **kwargs: Any) -> np.ndarray:
"""Simple moving average.
Parameters
----------
window : int, default 5
            Window length of the moving average.
center : bool, default False
If True, when computing the output value at each point,
the window will be centered at that point;
            otherwise a trailing window of length `window` ending at the current point is used.
Returns
-------
numpy.ndarray
The simple moving average of the input data.
"""
if len(kwargs) > 0:
warnings.warn(
f"the following arguments are not used: `{kwargs}` for simple moving average",
RuntimeWarning,
)
smoothed = []
if center:
hw = window // 2
window = hw * 2 + 1
for n in range(window):
smoothed.append(np.mean(self.data[: n + 1]))
prev = smoothed[-1]
for n, d in enumerate(self.data[window:]):
s = prev + (d - self.data[n]) / window
prev = s
smoothed.append(s)
smoothed = np.array(smoothed)
if center:
smoothed[hw:-hw] = smoothed[window - 1 :]
for n in range(hw):
smoothed[n] = np.mean(self.data[: n + hw + 1])
smoothed[-n - 1] = np.mean(self.data[-n - hw - 1 :])
return smoothed
def _ema(self, weight: float = 0.6, **kwargs: Any) -> np.ndarray:
"""Exponential moving average
This is also the function used in Tensorboard Scalar panel,
whose parameter `smoothing` is the `weight` here.
Parameters
----------
weight : float, default 0.6
Weight of the previous data point.
Returns
-------
numpy.ndarray
The exponential moving average of the input data.
"""
if len(kwargs) > 0:
warnings.warn(
f"the following arguments are not used: `{kwargs}` for exponential moving average",
RuntimeWarning,
)
smoothed = []
prev = self.data[0]
for d in self.data:
s = prev * weight + (1 - weight) * d
prev = s
smoothed.append(s)
smoothed = np.array(smoothed)
return smoothed
def _cma(self, **kwargs) -> np.ndarray:
"""Cumulative moving average.
Parameters
----------
None
Returns
-------
numpy.ndarray
The cumulative moving average of the input data.
"""
if len(kwargs) > 0:
warnings.warn(
f"the following arguments are not used: `{kwargs}` for cumulative moving average",
RuntimeWarning,
)
smoothed = []
prev = 0
for n, d in enumerate(self.data):
s = prev + (d - prev) / (n + 1)
prev = s
smoothed.append(s)
smoothed = np.array(smoothed)
return smoothed
def _wma(self, window: int = 5, **kwargs: Any) -> np.ndarray:
"""Weighted moving average.
Parameters
----------
window : int, default 5
Window length of the moving average.
Returns
-------
numpy.ndarray
The weighted moving average of the input data.
"""
if len(kwargs) > 0:
warnings.warn(
f"the following arguments are not used: `{kwargs}` for weighted moving average",
RuntimeWarning,
)
conv = np.arange(1, window + 1)[::-1]
deno = np.sum(conv)
smoothed = np.convolve(conv, self.data, mode="same") / deno
return smoothed
def nildent(text: str) -> str:
"""
Kill all leading white spaces in each line of `text`,
while keeping all lines (including empty)
Parameters
----------
text : str
Text to be processed.
Returns
-------
str
Processed text.
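    Examples
    --------
    >>> nildent("  indented")
    'indented'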
"""
new_text = "\n".join([line.lstrip() for line in text.splitlines()]) + ("\n" if text.endswith("\n") else "")
return new_text
def remove_parameters_returns_from_docstring(
doc: str,
parameters: Optional[Union[str, List[str]]] = None,
returns: Optional[Union[str, List[str]]] = None,
parameters_indicator: str = "Parameters",
returns_indicator: str = "Returns",
) -> str:
"""Remove parameters and/or returns from docstring,
which is of the format of `numpydoc`.
Parameters
----------
doc : str
Docstring to be processed.
parameters : str or List[str], optional
Parameters to be removed.
returns : str or List[str], optional
Returned values to be removed.
parameters_indicator : str, default "Parameters"
Indicator of the parameters section.
returns_indicator : str, default "Returns"
Indicator of the returns section.
Returns
-------
str
The processed docstring.
TODO
----
When one section is empty, remove the whole section,
or add a line of `None` to the section.
"""
if parameters is None:
parameters = []
elif isinstance(parameters, str):
parameters = [parameters]
if returns is None:
returns = []
elif isinstance(returns, str):
returns = [returns]
new_doc = doc.splitlines()
parameters_indent = None
returns_indent = None
start_idx = None
parameters_starts = False
returns_starts = False
indices2remove = []
for idx, line in enumerate(new_doc):
if (
line.strip().startswith(parameters_indicator)
and idx < len(new_doc) - 1
and new_doc[idx + 1].strip() == "-" * len(parameters_indicator)
):
parameters_indent = " " * line.index(parameters_indicator)
parameters_starts = True
returns_starts = False
if (
line.strip().startswith(returns_indicator)
and idx < len(new_doc) - 1
and new_doc[idx + 1].strip() == "-" * len(returns_indicator)
):
returns_indent = " " * line.index(returns_indicator)
returns_starts = True
parameters_starts = False
if start_idx is not None and len(line.strip()) == 0:
indices2remove.extend(list(range(start_idx, idx)))
start_idx = None
if parameters_starts and len(line.lstrip()) == len(line) - len(parameters_indent):
if any([line.lstrip().startswith(p) for p in parameters]):
if start_idx is not None:
indices2remove.extend(list(range(start_idx, idx)))
start_idx = idx
elif start_idx is not None:
if line.lstrip().startswith(returns_indicator) and len(new_doc[idx - 1].strip()) == 0:
indices2remove.extend(list(range(start_idx, idx - 1)))
else:
indices2remove.extend(list(range(start_idx, idx)))
start_idx = None
if returns_starts and len(line.lstrip()) == len(line) - len(returns_indent):
if any([line.lstrip().startswith(p) for p in returns]):
if start_idx is not None:
indices2remove.extend(list(range(start_idx, idx)))
start_idx = idx
elif start_idx is not None:
indices2remove.extend(list(range(start_idx, idx)))
start_idx = None
if start_idx is not None:
indices2remove.extend(list(range(start_idx, len(new_doc))))
new_doc.extend(["\n", parameters_indicator or returns_indicator])
new_doc = "\n".join([line for idx, line in enumerate(new_doc) if idx not in indices2remove])
return new_doc
@contextmanager
def timeout(duration: float):
"""A context manager that raises a
:class:`TimeoutError` after a specified time.
Modified from [#timeout]_.
Parameters
----------
duration : float
The time duration in seconds,
should be non-negative, 0 for no timeout.
References
----------
.. [#timeout] https://stackoverflow.com/questions/492519/timeout-on-a-function-call
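    Examples
    --------
    A sketch; this relies on ``signal.SIGALRM`` and hence is POSIX-only:
    >>> import time
    >>> with timeout(1):
    ...     time.sleep(2)  # doctest: +SKIP
    Traceback (most recent call last):
        ...
    TimeoutError: block timed out after `1` seconds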
"""
if np.isinf(duration):
duration = 0
elif duration < 0:
raise ValueError("`duration` must be non-negative")
elif duration > 0: # granularity is 1 second, so round up
duration = max(1, int(duration))
def timeout_handler(signum, frame):
        raise TimeoutError(f"block timed out after `{duration}` seconds")
    signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(duration)
    try:
        yield
    finally:
        # make sure the alarm is cancelled even if the block raises
        signal.alarm(0)
class Timer(ReprMixin):
"""Context manager to time the execution of a block of code.
Parameters
----------
name : str, optional
Name of the timer, defaults to "default timer".
verbose : int, default 0
Verbosity level of the timer.
Example
-------
>>> with Timer("task name", verbose=2) as timer:
... do_something()
    ... timer.add_timer("subtask 1", level=2)
... do_subtask_1()
... timer.stop_timer("subtask 1")
    ... timer.add_timer("subtask 2", level=2)
... do_subtask_2()
... timer.stop_timer("subtask 2")
... do_something_else()
"""
__name__ = "Timer"
def __init__(self, name: Optional[str] = None, verbose: int = 0) -> None:
self.name = name or "default timer"
self.verbose = verbose
self.timers = {self.name: 0.0}
self.ends = {self.name: 0.0}
self.levels = {self.name: 1}
def __enter__(self) -> "Timer":
self.timers = {self.name: time.perf_counter()}
self.ends = {self.name: 0.0}
self.levels = {self.name: 1}
return self
def __exit__(self, *args) -> None:
for k in self.timers:
self.stop_timer(k)
self.timers[k] = self.ends[k] - self.timers[k]
    def add_timer(self, name: str, level: int = 1) -> None:
"""Add a new timer for some sub-task.
Parameters
----------
name : str
Name of the timer to be added.
level : int, default 1
Verbosity level of the timer.
Returns
-------
None
"""
self.timers[name] = time.perf_counter()
self.ends[name] = 0
self.levels[name] = level
    def stop_timer(self, name: str) -> None:
"""Stop a timer.
Parameters
----------
name : str
Name of the timer to be stopped.
Returns
-------
None
"""
if self.ends[name] == 0:
self.ends[name] = time.perf_counter()
if self.verbose >= self.levels[name]:
time_cost, unit = self._simplify_time_expr(self.ends[name] - self.timers[name])
print(f"{name} took {time_cost:.4f} {unit}")
def _simplify_time_expr(self, time_cost: float) -> Tuple[float, str]:
"""Simplify the time expression.
Parameters
----------
time_cost : float
The time cost, with units in seconds.
Returns
-------
time_cost : float
The time cost.
unit : str
Unit of the time cost.
"""
if time_cost <= 0.1:
return 1000 * time_cost, "ms"
return time_cost, "s"
def get_kwargs(func_or_cls: callable, kwonly: bool = False) -> Dict[str, Any]:
"""Get the kwargs of a function or class.
Parameters
----------
func_or_cls : callable
The function or class to get the kwargs of.
kwonly : bool, default False
Whether to get the kwonly kwargs of the function or class.
Returns
-------
kwargs : Dict[str, Any]
The kwargs of the function or class.
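    Examples
    --------
    >>> def f(a, b=1, *, c=2):
    ...     pass
    >>> get_kwargs(f)
    {'c': 2, 'b': 1}
    >>> get_kwargs(f, kwonly=True)
    {'c': 2}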
"""
fas = inspect.getfullargspec(func_or_cls)
kwargs = {}
if fas.kwonlydefaults is not None:
kwargs = deepcopy(fas.kwonlydefaults)
if not kwonly and fas.defaults is not None:
kwargs.update({k: v for k, v in zip(fas.args[-len(fas.defaults) :], fas.defaults)})
if len(kwargs) == 0:
# perhaps `inspect.getfullargspec` does not work
# we should use `inspect.signature` instead
# for example, the model init functions defined in
# https://github.com/pytorch/vision/blob/release/0.13/torchvision/models/resnet.py
# TODO: discard old code, and use only this block
signature = inspect.signature(func_or_cls)
valid_kinds = [inspect.Parameter.KEYWORD_ONLY]
if not kwonly:
valid_kinds.append(inspect.Parameter.POSITIONAL_OR_KEYWORD)
for k, v in signature.parameters.items():
if v.default is not inspect.Parameter.empty and v.kind in valid_kinds:
kwargs[k] = v.default
return kwargs
def get_required_args(func_or_cls: callable) -> List[str]:
"""Get the required positional arguments of a function or class.
Parameters
----------
func_or_cls : callable
The function or class to get the required arguments of.
Returns
-------
required_args : List[str]
Names of required arguments of the function or class.
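    Examples
    --------
    >>> def f(a, b=1, *, c=2):
    ...     pass
    >>> get_required_args(f)
    ['a']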
"""
signature = inspect.signature(func_or_cls)
required_args = [
k
for k, v in signature.parameters.items()
if v.default is inspect.Parameter.empty
and v.kind in [inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.POSITIONAL_ONLY]
]
return required_args
def add_kwargs(func: callable, **kwargs: Any) -> callable:
"""Add keyword arguments to a function.
    This function is used to add keyword arguments to a function
    in order to make it compatible with other functions.
Parameters
----------
func : callable
The function to be decorated.
kwargs : dict
The keyword arguments to be added.
Returns
-------
callable
The decorated function, with the keyword arguments added.
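    Examples
    --------
    A minimal sketch; the added ``c`` is accepted by the wrapper
    but filtered out before the original function is called:
    >>> def f(a, b=1):
    ...     return a + b
    >>> g = add_kwargs(f, c=0)
    >>> g(1, c=100)
    2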
"""
old_kwargs = get_kwargs(func)
func_signature = inspect.signature(func)
func_parameters = func_signature.parameters.copy() # ordered dict
full_kwargs = deepcopy(old_kwargs)
kind = inspect.Parameter.POSITIONAL_OR_KEYWORD
for k, v in func_parameters.items():
if v.kind == inspect.Parameter.KEYWORD_ONLY:
kind = inspect.Parameter.KEYWORD_ONLY
break
for k, v in kwargs.items():
if k in old_kwargs:
raise ValueError(f"keyword argument `{k}` already exists!")
full_kwargs[k] = v
func_parameters[k] = inspect.Parameter(k, kind, default=v)
# move the VAR_POSITIONAL and VAR_KEYWORD in `func_parameters` to the end
for k, v in func_parameters.items():
if v.kind == inspect.Parameter.VAR_POSITIONAL:
func_parameters.move_to_end(k)
break
for k, v in func_parameters.items():
if v.kind == inspect.Parameter.VAR_KEYWORD:
func_parameters.move_to_end(k)
break
if isinstance(func, types.MethodType):
# can not assign `__signature__` to a bound method directly
func.__func__.__signature__ = func_signature.replace(parameters=func_parameters.values())
else:
func.__signature__ = func_signature.replace(parameters=func_parameters.values())
# docstring is automatically copied by `functools.wraps`
@wraps(func)
def wrapper(*args: Any, **kwargs_: Any) -> Any:
assert set(kwargs_).issubset(full_kwargs), (
"got unexpected keyword arguments: " f"{list(set(kwargs_).difference(full_kwargs))}"
)
filtered_kwargs = {k: v for k, v in kwargs_.items() if k in old_kwargs}
return func(*args, **filtered_kwargs)
return wrapper
def make_serializable(x: Union[np.ndarray, np.generic, dict, list, tuple]) -> Union[list, dict, Number]:
"""Make an object serializable.
This function is used to convert all numpy arrays to list in an object,
and also convert numpy data types to python data types in the object,
so that it can be serialized by :mod:`json`.
Parameters
----------
x : Union[numpy.ndarray, numpy.generic, dict, list, tuple]
Input data, which can be numpy array (or numpy data type),
or dict, list, tuple containing numpy arrays (or numpy data type).
Returns
-------
Union[list, dict, numbers.Number]
Converted data.
Examples
--------
>>> import numpy as np
>>> from fl_sim.utils.misc import make_serializable
>>> x = np.array([1, 2, 3])
>>> make_serializable(x)
[1, 2, 3]
>>> x = {"a": np.array([1, 2, 3]), "b": np.array([4, 5, 6])}
>>> make_serializable(x)
{'a': [1, 2, 3], 'b': [4, 5, 6]}
>>> x = [np.array([1, 2, 3]), np.array([4, 5, 6])]
>>> make_serializable(x)
[[1, 2, 3], [4, 5, 6]]
>>> x = (np.array([1, 2, 3]), np.array([4, 5, 6]).mean())
>>> obj = make_serializable(x)
>>> obj
[[1, 2, 3], 5.0]
>>> type(obj[1]), type(x[1])
(float, numpy.float64)
"""
if isinstance(x, np.ndarray):
return x.tolist()
elif isinstance(x, (list, tuple)):
# to avoid cases where the list contains numpy data types
return [make_serializable(v) for v in x]
elif isinstance(x, dict):
for k, v in x.items():
x[k] = make_serializable(v)
elif isinstance(x, np.generic):
return x.item()
# the other types will be returned directly
return x