FedCIFAR
- class fl_sim.data_processing.FedCIFAR(n_class: Literal[10, 100] = 100, datadir: Path | str | None = None, transform: str | Callable | None = 'none', seed: int = 0, **extra_config: Any)
- Bases: FedVisionDataset
- Federated CIFAR10/100 dataset.
- This dataset is loaded from the TensorFlow Federated (TFF) cifar100 load_data API [1] and saved as h5py files. It is pre-divided into 500 training clients containing 50,000 examples in total, and 100 testing clients containing 10,000 examples in total.
- The images are saved in the channel-last format, i.e., N x H x W x C, NOT the channel-first format (N x C x H x W) usual for PyTorch. A single image (and similarly a label or coarse_label) can be accessed by
- with h5py.File(path, "r") as f: image = f["examples"]["0"]["image"][0]
- where path is the path to the h5py file, "0" is the client id, and 0 is the index of the image in the client's dataset.
- Most methods in this class are adapted from FedML [2].
- Parameters:
- n_class ({10, 100}, default 100) – Number of classes in the dataset: 10 for CIFAR10, 100 for CIFAR100.
- datadir (str or pathlib.Path, default None) – Path to the dataset directory. If None, the built-in default directory is used.
- transform (str, callable or None, default "none") – Transformation to apply to the images. If "none", only static normalization is applied. If a callable, it is passed as the transform argument of VisionDataset. If None, the default dynamic augmentation transform is used.
- seed (int, default 0) – Random seed for data shuffling.
- **extra_config (dict, optional) – Extra configurations for the dataset. 
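To make the channel-last storage layout concrete, the sketch below mimics the h5py hierarchy (examples, then client id, then field) with a nested Python dict of NumPy arrays, an assumption made so the snippet runs without the actual h5py file. It then converts one client's images from N x H x W x C to the N x C x H x W layout that PyTorch models usually expect. The 32 x 32 x 3 CIFAR image shape and the average of 100 images per training client (50,000 examples over 500 clients) follow from the description above; the names fake_file and to_channel_first are hypothetical.

```python
import numpy as np

# Hypothetical stand-in for the h5py file: examples -> client id -> field.
# A real file would be opened with h5py.File(path, "r") and indexed the same way.
num_images = 100  # on average, 50,000 training examples / 500 training clients
fake_file = {
    "examples": {
        "0": {
            "image": np.zeros((num_images, 32, 32, 3), dtype=np.uint8),  # N x H x W x C
            "label": np.zeros(num_images, dtype=np.int64),
        }
    }
}


def to_channel_first(images: np.ndarray) -> np.ndarray:
    """Reorder axes from N x H x W x C to N x C x H x W (hypothetical helper)."""
    return np.transpose(images, (0, 3, 1, 2))


# Read every image of client "0", and a single image as in the access pattern above.
client_images = fake_file["examples"]["0"]["image"][:]
single_image = fake_file["examples"]["0"]["image"][0]  # H x W x C
batch = to_channel_first(client_images)
print(single_image.shape)  # (32, 32, 3)
print(batch.shape)  # (100, 3, 32, 32)
```

With PyTorch tensors, the same reordering is torch.from_numpy(images).permute(0, 3, 1, 2).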
 
- References
- evaluate(probs: Tensor, truths: Tensor) → Dict[str, float]
- Evaluate the predicted probabilities against the ground-truth labels.
- Parameters:
- probs (torch.Tensor) – Predicted probabilities. 
- truths (torch.Tensor) – Ground truth labels. 
 
- Returns:
- Evaluation results, as a dict mapping metric names to their values.
- Return type:
- Dict[str, float]
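The metrics that evaluate actually reports are not listed here. As an illustration only, the sketch below assumes the common case of top-1 accuracy computed from per-class probabilities, with NumPy arrays standing in for torch.Tensor so it runs without PyTorch. The function name evaluate_sketch and the "acc" key are hypothetical and may not match the keys FedCIFAR returns.

```python
from typing import Dict

import numpy as np


def evaluate_sketch(probs: np.ndarray, truths: np.ndarray) -> Dict[str, float]:
    """Hypothetical top-1 accuracy metric.

    probs: shape (N, n_class), predicted class probabilities.
    truths: shape (N,), integer ground-truth labels.
    """
    preds = probs.argmax(axis=1)  # most probable class for each example
    acc = float((preds == truths).mean())  # fraction of correct predictions
    return {"acc": acc}  # hypothetical metric name


probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.3, 0.3, 0.4],
                  [0.6, 0.3, 0.1]])
truths = np.array([0, 1, 1, 0])
print(evaluate_sketch(probs, truths))  # {'acc': 0.75}
```

With torch tensors, the equivalent computation is (probs.argmax(dim=1) == truths).float().mean().item().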
 
- get_dataloader(train_bs: int | None = None, test_bs: int | None = None, client_idx: int | None = None) → Tuple[DataLoader, DataLoader]
- Get the local dataloaders of client client_idx, or the global dataloaders if client_idx is None.
- Parameters:
- train_bs (int, optional) – Batch size for the training dataloader. If None, the default batch size is used.
- test_bs (int, optional) – Batch size for the testing dataloader. If None, the default batch size is used.
- client_idx (int, optional) – Index of the client whose dataloaders are returned. If None, the dataloaders contain the data of all clients, which is usually used for centralized training.
 
- Returns:
- train_dl (torch.utils.data.DataLoader) – Training dataloader.
- test_dl (torch.utils.data.DataLoader) – Testing dataloader.
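The sketch below illustrates the client_idx semantics documented above, using plain Python lists of example ids instead of real DataLoader objects: an integer index selects that client's local split, while None pools every client's data for centralized training. The per-client dict storage, the default batch size of 32 and the names DEFAULT_BS, batches and get_batches_sketch are assumptions for illustration, not the actual FedCIFAR implementation.

```python
from typing import Dict, List, Optional, Tuple

DEFAULT_BS = 32  # hypothetical default batch size


def batches(data: List[int], bs: int) -> List[List[int]]:
    """Split a list into consecutive batches of size bs (the last may be smaller)."""
    return [data[i:i + bs] for i in range(0, len(data), bs)]


def get_batches_sketch(
    train: Dict[int, List[int]],
    test: Dict[int, List[int]],
    train_bs: Optional[int] = None,
    test_bs: Optional[int] = None,
    client_idx: Optional[int] = None,
) -> Tuple[List[List[int]], List[List[int]]]:
    """Hypothetical analogue of get_dataloader over per-client example lists."""
    train_bs = DEFAULT_BS if train_bs is None else train_bs
    test_bs = DEFAULT_BS if test_bs is None else test_bs
    if client_idx is None:
        # Global view: concatenate all clients' data (centralized training).
        train_data = [x for cid in sorted(train) for x in train[cid]]
        test_data = [x for cid in sorted(test) for x in test[cid]]
    else:
        # Local view: only the requested client's data.
        train_data = train[client_idx]
        test_data = test[client_idx]
    return batches(train_data, train_bs), batches(test_data, test_bs)


train = {0: list(range(0, 100)), 1: list(range(100, 200))}
test = {0: list(range(10)), 1: list(range(10, 20))}
tr, te = get_batches_sketch(train, test, train_bs=64, client_idx=1)
print(len(tr), len(te))  # 2 1
tr_all, _ = get_batches_sketch(train, test, train_bs=64)
print(len(tr_all))  # 4
```

With the real class, for an instance ds = FedCIFAR(), the call ds.get_dataloader(train_bs=64, client_idx=0) returns the pair (train_dl, test_dl) of DataLoader objects for client 0, per the signature above.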