FedRotatedMNIST
- class fl_sim.data_processing.FedRotatedMNIST(datadir: Path | str | None = None, num_rotations: int = 4, num_clients: int = 2400, transform: str | Callable | None = 'none', seed: int = 0)
Bases: FedVisionDataset
MNIST dataset with rotation augmentation.
The rotation angles are fixed and are multiples of 360 / num_rotations degrees [Ghosh et al.[1]].
The original MNIST dataset (https://pytorch.org/vision/stable/_modules/torchvision/datasets/mnist.html#MNIST) contains 60,000 training images and 10,000 test images. The images are 28x28 grayscale images of handwritten digits in 10 classes (0-9).
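To make the rotation scheme concrete, here is a minimal sketch of the angle arithmetic described above. The helper names (`rotation_angles`, `client_rotation_group`, `rotate_image`) are hypothetical, and the contiguous, equal-size grouping of clients by rotation is an assumption for illustration, not necessarily how FedRotatedMNIST assigns clients. Rotation uses `numpy.rot90`, which is exact (no interpolation) but only handles multiples of 90 degrees, which covers the typical `num_rotations` values 2 and 4.

```python
from typing import List

import numpy as np


def rotation_angles(num_rotations: int) -> List[float]:
    """Angles in degrees that are multiples of 360 / num_rotations."""
    step = 360 / num_rotations
    return [k * step for k in range(num_rotations)]


def client_rotation_group(client_idx: int, num_clients: int, num_rotations: int) -> int:
    """HYPOTHETICAL mapping: clients are split into num_rotations contiguous,
    equal-size groups, and group k sees images rotated by the k-th angle."""
    if num_clients % num_rotations != 0:
        raise ValueError("this sketch requires num_clients divisible by num_rotations")
    clients_per_group = num_clients // num_rotations
    return client_idx // clients_per_group


def rotate_image(img: np.ndarray, angle: float) -> np.ndarray:
    """Rotate a 2-D image counter-clockwise (as displayed) by a multiple of 90 degrees."""
    if angle % 90 != 0:
        raise ValueError("np.rot90 only handles multiples of 90 degrees")
    return np.rot90(img, k=int(angle // 90))


if __name__ == "__main__":
    angles = rotation_angles(4)
    print(angles)  # [0.0, 90.0, 180.0, 270.0]
    group = client_rotation_group(1799, num_clients=2400, num_rotations=4)
    digit = np.zeros((28, 28), dtype=np.uint8)  # stand-in for one MNIST image
    rotated = rotate_image(digit, angles[group])
    print(group, rotated.shape)
```

With `num_clients=2400` and `num_rotations=4`, this hypothetical grouping gives 600 clients per rotation angle.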
- Parameters:
datadir (str or pathlib.Path, optional) – Path to store the dataset. If not specified, the default path is used.
num_rotations (int, default 4) – Number of rotations to apply to the images in the dataset. Typical values are 2, 4.
num_clients (int, default 2400) – Number of clients to simulate. Typical values are 1200, 2400, 4800.
transform (str or callable, default 'none') – Transform (augmentation) to apply to the dataset. If 'none', no augmentation is applied; only the normalization transform is applied.
seed (int, default 0) – Random seed for reproducibility.
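With transform='none', only normalization is applied. The sketch below assumes the commonly used MNIST statistics (mean 0.1307, std 0.3081) from the torchvision MNIST examples; the constants FedRotatedMNIST actually uses are not stated on this page. In torchvision the same step is typically written as `transforms.ToTensor()` followed by `transforms.Normalize((0.1307,), (0.3081,))`.

```python
import numpy as np

# ASSUMPTION: widely used MNIST mean/std; not confirmed for FedRotatedMNIST.
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def normalize(images_uint8: np.ndarray) -> np.ndarray:
    """Scale uint8 pixels to [0, 1], then standardize with MNIST_MEAN / MNIST_STD.

    Mirrors ToTensor() (divide by 255) followed by Normalize(mean, std);
    the input shape is preserved.
    """
    x = images_uint8.astype(np.float32) / 255.0
    return (x - MNIST_MEAN) / MNIST_STD


if __name__ == "__main__":
    batch = np.random.default_rng(0).integers(0, 256, size=(2, 28, 28), dtype=np.uint8)
    out = normalize(batch)
    print(out.shape, out.dtype)
```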
References
- evaluate(probs: Tensor, truths: Tensor) → Dict[str, float]
Evaluate the predictions against the ground-truth labels.
- Parameters:
probs (torch.Tensor) – Predicted probabilities.
truths (torch.Tensor) – Ground truth labels.
- Returns:
Evaluation results, mapping metric names to their values.
- Return type:
Dict[str, float]
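A minimal sketch of an evaluation from predicted probabilities and labels, assuming top-1 accuracy as the metric. The function name `evaluate_sketch` and the `"acc"` key are hypothetical; the metrics actually returned by `evaluate` are not listed on this page. Inputs go through `np.asarray`, so detached CPU `torch.Tensor` objects work as well as NumPy arrays.

```python
from typing import Dict

import numpy as np


def evaluate_sketch(probs, truths) -> Dict[str, float]:
    """HYPOTHETICAL evaluation: top-1 accuracy from class probabilities.

    probs:  array-like of shape (N, num_classes).
    truths: integer labels of shape (N,), or one-hot of shape (N, num_classes).
    """
    probs = np.asarray(probs)
    truths = np.asarray(truths)
    if truths.ndim == 2:  # one-hot -> integer class labels
        truths = truths.argmax(axis=1)
    preds = probs.argmax(axis=1)  # predicted class = highest probability
    # "acc" is an illustrative key name, not the documented output format.
    return {"acc": float((preds == truths).mean())}


if __name__ == "__main__":
    p = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]]
    print(evaluate_sketch(p, [0, 1, 1, 1]))  # {'acc': 0.75}
```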
- get_dataloader(train_bs: int | None = None, test_bs: int | None = None, client_idx: int | None = None) → Tuple[DataLoader, DataLoader]
Get the local dataloaders of client client_idx, or the global dataloaders if client_idx is None.
- Parameters:
train_bs (int, optional) – Batch size for the training dataloader. If None, the default batch size is used.
test_bs (int, optional) – Batch size for the testing dataloader. If None, the default batch size is used.
client_idx (int, optional) – Index of the client whose dataloaders are returned. If None, the dataloaders contain all data, which is usually used for centralized training.
- Returns:
train_dl (torch.utils.data.DataLoader) – Training dataloader.
test_dl (torch.utils.data.DataLoader) – Testing dataloader.
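The batch-size and client_idx semantics of get_dataloader can be sketched without PyTorch by working on index lists. Everything here is hypothetical: `get_batches`, the per-client index dictionaries, and the default batch size of 20 are illustration-only assumptions. The real method returns `torch.utils.data.DataLoader` objects (a real implementation might, for instance, wrap a `torch.utils.data.Subset` per client and shuffle the training data); this sketch keeps the order fixed so its output is deterministic.

```python
from typing import Dict, List, Optional, Sequence, Tuple

# HYPOTHETICAL defaults; the actual default batch sizes are not stated here.
DEFAULT_TRAIN_BS = 20
DEFAULT_TEST_BS = 20

Batches = List[List[int]]


def _batch(indices: Sequence[int], bs: int) -> Batches:
    """Split indices into consecutive batches of size bs (the last may be smaller)."""
    return [list(indices[i : i + bs]) for i in range(0, len(indices), bs)]


def get_batches(
    train_parts: Dict[int, List[int]],
    test_parts: Dict[int, List[int]],
    train_bs: Optional[int] = None,
    test_bs: Optional[int] = None,
    client_idx: Optional[int] = None,
) -> Tuple[Batches, Batches]:
    """Mimic the documented dispatch: None batch size -> default;
    None client_idx -> all clients' data (centralized view)."""
    train_bs = DEFAULT_TRAIN_BS if train_bs is None else train_bs
    test_bs = DEFAULT_TEST_BS if test_bs is None else test_bs
    if client_idx is None:
        # Global view: concatenate every client's indices in client order.
        train_idx = [i for c in sorted(train_parts) for i in train_parts[c]]
        test_idx = [i for c in sorted(test_parts) for i in test_parts[c]]
    else:
        # Local view: only the selected client's indices.
        train_idx = train_parts[client_idx]
        test_idx = test_parts[client_idx]
    return _batch(train_idx, train_bs), _batch(test_idx, test_bs)
```

For example, with `train_parts = {0: [0, 1, 2], 1: [3, 4]}`, requesting `client_idx=1` with `train_bs=2` yields the single training batch `[[3, 4]]`, while omitting `client_idx` batches all five indices as `[[0, 1], [2, 3], [4]]`.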