FedCIFAR

class fl_sim.data_processing.FedCIFAR(n_class: int = 100, datadir: Path | str | None = None, transform: str | Callable | None = 'none', seed: int = 0, **extra_config: Any)

Bases: FedVisionDataset

Federated CIFAR10/100 dataset.

This dataset is loaded through the TensorFlow Federated (TFF) cifar100 load_data API [1] and saved as h5py files. It is pre-divided into 500 training clients holding 50,000 examples in total and 100 test clients holding 10,000 examples in total.

The images are stored in channel-last format, i.e., N x H x W x C, not the channel-first format (N x C x H x W) that PyTorch usually expects. A single image (and similarly its label and coarse_label) can be accessed with

import h5py

with h5py.File(path, "r") as f:
    image = f["examples"]["0"]["image"][0]

where path is the path to the h5py file, “0” is the client id, and 0 is the index of the image in the client’s dataset.
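Because PyTorch layers expect channel-first input, arrays read from the h5py file need their channel axis moved before use. The sketch below shows this with NumPy on a randomly generated stand-in batch instead of real file contents; the helper name to_channel_first is hypothetical and not part of fl_sim, and the 32 x 32 x 3 shape is the standard CIFAR image size.

```python
import numpy as np


def to_channel_first(images: np.ndarray) -> np.ndarray:
    """Convert a batch from N x H x W x C (channel last) to N x C x H x W (channel first).

    Hypothetical helper for illustration; not part of fl_sim.
    """
    if images.ndim != 4:
        raise ValueError(f"expected a 4-D batch, got shape {images.shape}")
    # Move the channel axis from position 3 to position 1, then copy the
    # result into a C-contiguous array.
    return np.ascontiguousarray(images.transpose(0, 3, 1, 2))


# Stand-in for the images of one client: 4 uint8 images of size 32 x 32 x 3.
rng = np.random.default_rng(0)
batch_nhwc = rng.integers(0, 256, size=(4, 32, 32, 3), dtype=np.uint8)
batch_nchw = to_channel_first(batch_nhwc)
print(batch_nchw.shape)  # (4, 3, 32, 32)
```

For a single H x W x C image, image.transpose(2, 0, 1) performs the same reordering.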

Most methods in this class are adapted from FedML [2].

Parameters:
  • n_class ({10, 100}, default 100) – Number of classes in the dataset: 10 for CIFAR10, 100 for CIFAR100.

  • datadir (str or pathlib.Path, default None) – Path to the dataset directory. If None, the built-in default directory is used.

  • transform (str or callable, default "none") – Transformation to apply to the images. If "none", only static normalization is applied. If a callable, it is passed as the transform argument of VisionDataset. If None, the default dynamic augmentation transform is used.

  • seed (int, default 0) – Random seed for data shuffling.

  • **extra_config (dict, optional) – Extra configurations for the dataset.
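The three kinds of transform value described above can be pictured as a small dispatch function. This is a hedged sketch, not fl_sim's code: resolve_transform, static_normalize, and flip_then_normalize are hypothetical names, the per-channel mean and std are caller-supplied placeholders rather than the statistics the library uses, random horizontal flipping stands in for whatever dynamic augmentation the library applies, and rejecting strings other than "none" is an assumption.

```python
from typing import Callable, Optional, Sequence, Union

import numpy as np

Transform = Callable[[np.ndarray], np.ndarray]


def static_normalize(mean: Sequence[float], std: Sequence[float]) -> Transform:
    """Deterministic per-channel (x - mean) / std on an H x W x C float image."""
    m = np.asarray(mean, dtype=np.float64)
    s = np.asarray(std, dtype=np.float64)

    def apply(img: np.ndarray) -> np.ndarray:
        return (img - m) / s  # broadcasts over the last (channel) axis

    return apply


def flip_then_normalize(
    mean: Sequence[float], std: Sequence[float], p: float = 0.5, seed: Optional[int] = None
) -> Transform:
    """Dynamic transform: random horizontal flip with probability p, then normalization."""
    rng = np.random.default_rng(seed)
    normalize = static_normalize(mean, std)

    def apply(img: np.ndarray) -> np.ndarray:
        if rng.random() < p:
            img = img[:, ::-1, :]  # reverse the width axis
        return normalize(img)

    return apply


def resolve_transform(
    transform: Union[str, Callable, None], mean: Sequence[float], std: Sequence[float]
) -> Transform:
    """Hypothetical dispatch mirroring the documented meaning of `transform`."""
    if isinstance(transform, str):
        if transform == "none":
            return static_normalize(mean, std)  # "none": static normalization only
        # Assumption: other strings are rejected in this sketch.
        raise ValueError(f"unsupported transform string: {transform!r}")
    if callable(transform):
        return transform  # a user-supplied callable is used as-is
    if transform is None:
        return flip_then_normalize(mean, std)  # None: default dynamic augmentation
    raise TypeError(f"unsupported transform type: {type(transform).__name__}")
```

Keeping "none" (static) and None (dynamic) as distinct values lets evaluation use a deterministic pipeline while training can opt into randomness.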

References

property candidate_models: Dict[str, Module]

Candidate models for this dataset, as a mapping from model name to model.

property doi: str

DOI(s) related to the dataset.

evaluate(probs: Tensor, truths: Tensor) → Dict[str, float]

Evaluate predictions against the ground truth.

Parameters:
  • probs (Tensor) – Predicted class probabilities.

  • truths (Tensor) – Ground-truth labels.

Returns:

Evaluation results.

Return type:

Dict[str, float]
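The metrics returned by evaluate are not listed here. As an illustration only, the sketch below computes top-1 accuracy from class probabilities and integer ground-truth labels. It uses NumPy arrays in place of torch.Tensor, and both the function name evaluate_sketch and the "acc" key are assumptions rather than the library's actual output.

```python
from typing import Dict

import numpy as np


def evaluate_sketch(probs: np.ndarray, truths: np.ndarray) -> Dict[str, float]:
    """probs: N x K class probabilities; truths: N integer labels in [0, K)."""
    if probs.ndim != 2 or truths.shape != (probs.shape[0],):
        raise ValueError("expected probs of shape (N, K) and truths of shape (N,)")
    preds = probs.argmax(axis=1)           # predicted class per sample
    acc = float((preds == truths).mean())  # fraction of correct predictions
    return {"acc": acc}


probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.3, 0.3, 0.4],
                  [0.6, 0.3, 0.1]])
truths = np.array([0, 1, 0, 0])
print(evaluate_sketch(probs, truths))  # {'acc': 0.75}
```

With torch.Tensor inputs, probs.argmax(dim=1) and (preds == truths).float().mean().item() play the same roles.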

extra_repr_keys() → List[str]

Extra keys for __repr__() and __str__().

get_dataloader(train_bs: int | None = None, test_bs: int | None = None, client_idx: int | None = None) → Tuple[DataLoader, DataLoader]

Get the local dataloaders of client client_idx, or the global dataloaders if client_idx is None.

Parameters:
  • train_bs (int, optional) – Batch size for training dataloader. If None, use default batch size.

  • test_bs (int, optional) – Batch size for testing dataloader. If None, use default batch size.

  • client_idx (int, optional) – Index of the client whose dataloaders are returned. If None, the dataloaders cover the data of all clients, which is typically used for centralized training.

Returns:

The training and testing dataloaders, in that order.

Return type:

Tuple[DataLoader, DataLoader]

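The client_idx switch can be pictured in plain Python. In the sketch below, select_split and batches are hypothetical names, a dict of lists stands in for per-client data, and a simple generator replaces torch.utils.data.DataLoader. It only illustrates the local-versus-global choice, not how fl_sim actually builds its loaders.

```python
import random
from itertools import chain
from typing import Dict, Iterator, List, Optional, Sequence, TypeVar

T = TypeVar("T")


def select_split(client_data: Dict[int, List[T]], client_idx: Optional[int]) -> List[T]:
    """Return one client's samples, or every client's samples when client_idx is None."""
    if client_idx is None:
        # Global (centralized) view: concatenate all clients in key order.
        return list(chain.from_iterable(client_data[k] for k in sorted(client_data)))
    return list(client_data[client_idx])


def batches(
    samples: Sequence[T], batch_size: int, shuffle: bool = False, seed: int = 0
) -> Iterator[List[T]]:
    """Yield mini-batches; the last batch may be smaller than batch_size."""
    order = list(range(len(samples)))
    if shuffle:
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        yield [samples[i] for i in order[start:start + batch_size]]


client_data = {0: ["a0", "a1", "a2"], 1: ["b0", "b1"]}
print(list(batches(select_split(client_data, 0), batch_size=2)))  # [['a0', 'a1'], ['a2']]
print(len(select_split(client_data, None)))  # 5
```

With the documented API, ds.get_dataloader(train_bs=32, test_bs=32, client_idx=0) would return the loaders of client 0, and omitting client_idx would return the global loaders, where ds is a FedCIFAR instance.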
property label_map: dict

Label map for the dataset.

random_grid_view(nrow: int, ncol: int, save_path: Path | str | None = None) → None

Randomly select nrow x ncol images from the dataset and plot them in a grid.

Parameters:
  • nrow (int) – Number of rows in the grid.

  • ncol (int) – Number of columns in the grid.

  • save_path (Union[str, Path], optional) – Path to save the figure. If None, do not save the figure.

Return type:

None
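Choosing the nrow x ncol images amounts to sampling that many distinct (client, image) positions. The sketch below assumes the per-client image counts are known and samples without replacement over one flattened index range using the standard library; sample_grid_positions is a hypothetical name, and the plotting step is omitted.

```python
import random
from bisect import bisect_right
from itertools import accumulate
from typing import List, Sequence, Tuple


def sample_grid_positions(
    client_sizes: Sequence[int], nrow: int, ncol: int, seed: int = 0
) -> List[List[Tuple[int, int]]]:
    """Return an nrow x ncol grid of distinct (client_idx, image_idx) pairs."""
    total = sum(client_sizes)
    n = nrow * ncol
    if n > total:
        raise ValueError(f"requested {n} images but only {total} are available")
    # offsets[k] = number of images held by clients 0..k (cumulative sum).
    offsets = list(accumulate(client_sizes))
    positions = []
    # Sample n distinct global indices, then map each one back to the client
    # that holds it and its position within that client.
    for g in random.Random(seed).sample(range(total), n):
        client_idx = bisect_right(offsets, g)
        start = offsets[client_idx - 1] if client_idx > 0 else 0
        positions.append((client_idx, g - start))
    # Split the flat list into nrow rows of ncol entries each.
    return [positions[r * ncol:(r + 1) * ncol] for r in range(nrow)]


# Five hypothetical clients of 100 images each.
grid = sample_grid_positions([100] * 5, nrow=2, ncol=3, seed=1)
```

Sampling over a flattened index range keeps every image equally likely, regardless of how the images are spread across clients.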

view_image(client_idx: int, image_idx: int) → None

View a single image.

Parameters:
  • client_idx (int) – Index of the client on which the image is located.

  • image_idx (int) – Index of the image in the client.

Return type:

None