class fl_sim.data_processing.FedMNIST(datadir: Path | str | None = None, transform: str | Callable | None = 'none', seed: int = 0, **extra_config: Any)

Bases: FedVisionDataset

MNIST is a dataset for studying image classification of handwritten digits 0-9.

To simulate a heterogeneous setting, FedML distributes the data among 1000 devices such that each device holds samples of only 2 digits and the number of samples per device follows a power law. This dataset is adopted from [1] and is also used in [2].
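The partitioning scheme described above can be sketched as follows. This is a minimal illustration under stated assumptions, not FedML's actual procedure: the helper name `power_law_partition`, the Pareto-distributed sample counts (via `random.Random.paretovariate`), the `alpha`/`scale` values, and the choice of two random digits per device are all assumptions made for illustration.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence


def power_law_partition(
    labels: Sequence[int],
    num_clients: int,
    classes_per_client: int = 2,
    alpha: float = 1.5,
    scale: int = 10,
    seed: int = 0,
) -> Dict[int, List[int]]:
    """Hypothetical non-IID split: every client draws from at most
    `classes_per_client` digits, and its sample count is heavy-tailed."""
    rng = random.Random(seed)
    # Per-class pools of sample indices, shuffled so that draws are random.
    pools: Dict[int, List[int]] = defaultdict(list)
    for idx, y in enumerate(labels):
        pools[y].append(idx)
    for pool in pools.values():
        rng.shuffle(pool)
    classes = sorted(pools)

    partition: Dict[int, List[int]] = {}
    for client in range(num_clients):
        chosen = rng.sample(classes, classes_per_client)
        # Pareto-distributed (power-law) target size, at least one sample per class.
        target = max(classes_per_client, int(scale * rng.paretovariate(alpha)))
        base, extra = divmod(target, classes_per_client)
        indices: List[int] = []
        for k, cls in enumerate(chosen):
            share = base + (1 if k < extra else 0)
            # Take samples without replacement; a pool may run dry on long tails.
            indices.extend(pools[cls][:share])
            del pools[cls][:share]
        partition[client] = indices
    return partition
```

Called with the training labels and `num_clients=1000`, this produces a split with the same qualitative shape as the one described (few digits per device, heavy-tailed device sizes), although the exact counts will differ from FedML's.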

NOTE: the maximum value of the raw data is 264.2510681152344, which could lead to numerical instability; the data are therefore normalized to the range [0, 1].
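The note does not state which normalization is applied. A minimal sketch, assuming linear min-max scaling over the whole dataset (dividing by the global maximum would be an equally plausible reading when the minimum is 0); the helper name `min_max_normalize` is hypothetical:

```python
from typing import List, Sequence


def min_max_normalize(values: Sequence[float]) -> List[float]:
    """Rescale values linearly to [0, 1] using the global min and max."""
    lo, hi = min(values), max(values)
    if hi == lo:
        # Constant input: avoid division by zero and map everything to 0.
        return [0.0 for _ in values]
    return [(v - lo) / (hi - lo) for v in values]


raw = [0.0, 12.5, 264.2510681152344]
scaled = min_max_normalize(raw)  # first element 0.0, last element 1.0
```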

Parameters:

  • datadir (Union[pathlib.Path, str], optional) – Directory to store the data. If None, the default directory is used.

  • transform (Union[str, Callable], default "none") – Transform to apply to the data. By convention, "none" means no transform is applied and the data are wrapped in a TensorDataset.

  • seed (int, default 0) – Random seed for data partitioning.

  • **extra_config (dict, optional) – Extra configurations.
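The `transform` argument accepts either a string naming a convention or a callable. One way such a union-typed argument could be resolved is sketched below; the `resolve_transform` helper, treating `None` like `"none"`, and raising on unknown names are hypothetical choices, and only the `"none"` convention comes from the parameter description.

```python
from typing import Any, Callable, Optional, Union

Transform = Callable[[Any], Any]


def resolve_transform(transform: Union[str, Transform, None]) -> Optional[Transform]:
    """Hypothetical resolver: None or "none" -> no transform; callables pass through."""
    if transform is None:
        return None
    if isinstance(transform, str):
        if transform.lower() == "none":
            # "none": samples are used as-is (the documented TensorDataset case).
            return None
        raise ValueError(f"Unknown transform name: {transform!r}")
    if callable(transform):
        return transform
    raise TypeError(
        f"transform must be str, callable or None, got {type(transform).__name__}"
    )
```

Returning `None` rather than an identity function lets the caller skip the transform step entirely for the "none" case.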


property candidate_models: Dict[str, Module]

A set of candidate models.

property doi: List[str]

DOI(s) related to the dataset.

evaluate(probs: Tensor, truths: Tensor) → Dict[str, float]

Evaluate predictions against the ground truth.

Returns:

Evaluation results.

Return type:

Dict[str, float]
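The documentation does not list which metrics `evaluate` returns. As a hedged illustration, assuming `probs` holds per-class probabilities with one row per sample and `truths` holds integer labels, a plain-Python top-1 accuracy returning a `Dict[str, float]` might look like this; the function name `evaluate_sketch` and the metric key `"acc"` are hypothetical.

```python
from typing import Dict, List, Sequence


def evaluate_sketch(
    probs: Sequence[Sequence[float]], truths: Sequence[int]
) -> Dict[str, float]:
    """Hypothetical stand-in for evaluate(): top-1 accuracy from class probabilities."""
    if len(probs) != len(truths):
        raise ValueError("probs and truths must have the same length")
    # Predicted class = index of the largest probability in each row.
    preds: List[int] = [max(range(len(row)), key=row.__getitem__) for row in probs]
    correct = sum(int(p == t) for p, t in zip(preds, truths))
    return {"acc": correct / len(truths) if truths else 0.0}


probs = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]
print(evaluate_sketch(probs, [1, 0, 0]))  # {'acc': 0.6666666666666666}
```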

extra_repr_keys() → List[str]

Extra keys for __repr__() and __str__().
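A common pattern behind an `extra_repr_keys()` hook is for a base class to build `__repr__` from the attribute names the hook returns. The sketch below uses a hypothetical `ReprMixin` and a toy `ToyDataset`; it illustrates the pattern and is not the actual fl_sim base-class code.

```python
from typing import List


class ReprMixin:
    """Hypothetical mixin: __repr__/__str__ list the attributes named by extra_repr_keys()."""

    def extra_repr_keys(self) -> List[str]:
        return []

    def __repr__(self) -> str:
        fields = ", ".join(
            f"{key}={getattr(self, key)!r}" for key in self.extra_repr_keys()
        )
        return f"{type(self).__name__}({fields})"

    __str__ = __repr__


class ToyDataset(ReprMixin):
    def __init__(self, seed: int = 0, transform: str = "none") -> None:
        self.seed = seed
        self.transform = transform

    def extra_repr_keys(self) -> List[str]:
        # Extend the parent's keys instead of replacing them.
        return super().extra_repr_keys() + ["seed", "transform"]


print(ToyDataset())  # ToyDataset(seed=0, transform='none')
```

Subclasses only override the list of keys, so the formatting logic lives in one place.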

get_dataloader(train_bs: int | None = None, test_bs: int | None = None, client_idx: int | None = None) → Tuple[DataLoader, DataLoader]

Get the local training and testing dataloaders of client client_idx, or the global ones.

Parameters:

  • train_bs (int, optional) – Batch size for the training dataloader. If None, the default batch size is used.

  • test_bs (int, optional) – Batch size for the testing dataloader. If None, the default batch size is used.

  • client_idx (int, optional) – Index of the client whose dataloaders are returned. If None, the dataloaders contain all data, which is typically used for centralized training.

Return type:

Tuple[DataLoader, DataLoader]
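The `client_idx` switch between local and global data can be illustrated without PyTorch. In this sketch, lists of batches stand in for `DataLoader`s, data are stored per client in dicts, and the helper names `batches` and `get_batches` as well as the default batch size of 32 are hypothetical.

```python
from typing import Dict, Iterator, List, Optional, Sequence, Tuple

DEFAULT_BS = 32  # hypothetical default batch size


def batches(data: Sequence[int], bs: int) -> Iterator[List[int]]:
    """Yield consecutive batches of size bs (the last batch may be smaller)."""
    for start in range(0, len(data), bs):
        yield list(data[start:start + bs])


def get_batches(
    train: Dict[int, List[int]],
    test: Dict[int, List[int]],
    train_bs: Optional[int] = None,
    test_bs: Optional[int] = None,
    client_idx: Optional[int] = None,
) -> Tuple[List[List[int]], List[List[int]]]:
    if client_idx is None:
        # Global view: concatenate every client's data (centralized training).
        train_data = [x for c in sorted(train) for x in train[c]]
        test_data = [x for c in sorted(test) for x in test[c]]
    else:
        # Local view: only the data held by one client.
        train_data, test_data = train[client_idx], test[client_idx]
    train_bs = DEFAULT_BS if train_bs is None else train_bs
    test_bs = DEFAULT_BS if test_bs is None else test_bs
    return list(batches(train_data, train_bs)), list(batches(test_data, test_bs))


train_parts = {0: [1, 2, 3], 1: [4, 5]}
test_parts = {0: [6], 1: [7, 8]}
print(get_batches(train_parts, test_parts, train_bs=2, client_idx=1))  # ([[4, 5]], [[7, 8]])
```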


property label_map: dict

Label map for the dataset.

random_grid_view(nrow: int, ncol: int, save_path: Path | str | None = None) → None

Randomly select nrow × ncol images from the dataset and plot them in a grid.

Parameters:

  • nrow (int) – Number of rows in the grid.

  • ncol (int) – Number of columns in the grid.

  • save_path (Union[str, Path], optional) – Path to save the figure. If None, the figure is not saved.

Return type:

None

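Selecting nrow × ncol distinct random images and assigning each to a grid cell can be sketched as below. Plotting is omitted, the helper name `random_grid_indices` is hypothetical, and sampling without replacement is an assumption about how the selection is done.

```python
import random
from typing import List, Optional


def random_grid_indices(
    num_images: int, nrow: int, ncol: int, seed: Optional[int] = None
) -> List[List[int]]:
    """Pick nrow * ncol distinct image indices and lay them out row by row."""
    if nrow * ncol > num_images:
        raise ValueError("grid has more cells than there are images")
    rng = random.Random(seed)
    picked = rng.sample(range(num_images), nrow * ncol)
    # Cell (r, c) of the grid shows image picked[r * ncol + c].
    return [picked[r * ncol:(r + 1) * ncol] for r in range(nrow)]


grid = random_grid_indices(num_images=1000, nrow=2, ncol=3, seed=0)
# grid holds two rows of three distinct indices in [0, 1000)
```

A plotting layer could then draw each selected image into the matching axis of a matplotlib `plt.subplots(nrow, ncol)` figure and call `fig.savefig(save_path)` when a path is given.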
property url: str

URL for downloading the dataset.

view_image(client_idx: int, image_idx: int) → None

View a single image.

Parameters:

  • client_idx (int) – Index of the client on which the image is located.

  • image_idx (int) – Index of the image within the client's data.

Return type:

None
