# Source code for fl_sim.data_processing._register

import warnings
from typing import Any, List, Optional

from .fed_dataset import FedDataset

# Module-level registry: maps a dataset name (str) to the class registered
# under that name. Written by ``register_fed_dataset`` and read by
# ``list_fed_dataset`` / ``get_fed_dataset``.
_built_in_fed_datasets = {}


def register_fed_dataset(name: Optional[str] = None, override: bool = True) -> Any:
    """Decorator to register a new federated dataset.

    Parameters
    ----------
    name : str, optional
        Name of the federated dataset.
        If not specified, the class name will be used.
    override : bool, default True
        Whether to override the existing federated dataset with the same name.
        If False and the name is already registered, the existing entry is
        kept and a ``RuntimeWarning`` is issued instead.

    Returns
    -------
    A decorator that registers the class it is applied to
    and returns that class unchanged.

    Raises
    ------
    AssertionError
        If the decorated class is not a subclass of ``FedDataset``.

    """

    def wrapper(cls_: Any) -> Any:
        if name is not None:
            _name = name
        else:
            # Fall back to the type's name if the object itself has no ``__name__``.
            _name = getattr(cls_, "__name__", cls_.__class__.__name__)

        # Explicit check rather than an ``assert`` statement, so the validation
        # is not stripped when Python runs with ``-O``. ``AssertionError`` is
        # kept as the exception type for backward compatibility.
        if not issubclass(cls_, FedDataset):
            raise AssertionError(f"{cls_} is not a valid dataset")

        if _name in _built_in_fed_datasets and not override:
            # Keep the existing registration; only warn about the name clash.
            warnings.warn(f"{_name} has already been registered", RuntimeWarning)
        else:
            _built_in_fed_datasets[_name] = cls_
        return cls_

    return wrapper
def list_fed_dataset() -> List[str]:
    """Return the names of all federated datasets registered so far."""
    registered_names = [*_built_in_fed_datasets]
    return registered_names
def get_fed_dataset(name: str) -> Any:
    """Look up a registered federated dataset by its name.

    Parameters
    ----------
    name : str
        Name under which the federated dataset was registered.

    Returns
    -------
    The object stored in the registry under ``name``.

    Raises
    ------
    ValueError
        If no federated dataset is registered under ``name``.

    """
    if name in _built_in_fed_datasets:
        return _built_in_fed_datasets[name]
    raise ValueError(f"Federated dataset {name} is not registered")