FedVisionDataset
- class fl_sim.data_processing.FedVisionDataset(datadir: Path | str | None = None, transform: str | Callable | None = 'none', seed: int = 0, **extra_config: Any)
Bases: FedDataset, ABC
Base class for all federated vision datasets.
Methods that have to be implemented by subclasses: get_dataloader, _preload, evaluate.
Properties that have to be implemented by subclasses: url, candidate_models, doi, label_map.
- Parameters:
datadir (Union[pathlib.Path, str], optional) – Directory to store data. If None, use the default directory.
transform (Union[str, Callable], default "none") – Transform to apply to data. Conventions: "none" means no transform, using TensorDataset; None means the default transform from torchvision.
seed (int, default 0) – Random seed for data partitioning.
**extra_config (dict, optional) – Extra configurations.
- get_class(label: Tensor) → str
Get class name from label.
- Parameters:
label (torch.Tensor) – Label.
- Returns:
Class name.
- Return type:
str
- get_classes(labels: Tensor) → List[str]
Get class names from labels.
- Parameters:
labels (torch.Tensor) – Labels.
- Returns:
Class names.
- Return type:
List[str]
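Assuming label_map is a dict from integer labels to class names (an assumption; its exact type is not documented here), the two lookups could be sketched as follows. LABEL_MAP and these free functions are hypothetical, not the library's methods.

```python
from typing import Dict, Iterable, List

# Hypothetical label map; each concrete dataset defines its own label_map property.
LABEL_MAP: Dict[int, str] = {0: "airplane", 1: "automobile", 2: "bird"}


def get_class(label, label_map: Dict[int, str] = LABEL_MAP) -> str:
    # int() accepts Python ints, NumPy integers and single-element torch tensors.
    return label_map[int(label)]


def get_classes(labels: Iterable, label_map: Dict[int, str] = LABEL_MAP) -> List[str]:
    # Iterating a 1-D tensor yields 0-d tensors, which get_class converts with int().
    return [get_class(label, label_map) for label in labels]
```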
- abstract get_dataloader(train_bs: int, test_bs: int, client_idx: int | None = None) → Tuple[DataLoader, DataLoader]
Get the training and testing dataloaders for client client_idx, or the global dataloaders if client_idx is None.
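A stdlib-only sketch of the client_idx branching. It assumes, hypothetically, that a partition maps each client index to its train and test sample indices and that the global split is the union of all client splits; a real subclass would wrap the selected indices in DataLoaders using train_bs and test_bs. select_indices is a hypothetical helper.

```python
from typing import Dict, List, Optional, Tuple

# Hypothetical partition: client index -> (train sample indices, test sample indices).
Partition = Dict[int, Tuple[List[int], List[int]]]


def select_indices(
    partition: Partition, client_idx: Optional[int] = None
) -> Tuple[List[int], List[int]]:
    """Return (train, test) indices for one client, or for all clients if client_idx is None."""
    if client_idx is not None:
        train, test = partition[client_idx]
        return list(train), list(test)
    # Assumption: the global split is the union of all client splits.
    train_all: List[int] = []
    test_all: List[int] = []
    for idx in sorted(partition):
        train, test = partition[idx]
        train_all.extend(train)
        test_all.extend(test)
    return train_all, test_all
```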
- load_partition_data(batch_size: int | None = None) → tuple
Partition the data among all local clients.
- Parameters:
batch_size (int, optional) – Batch size for the dataloaders. If None, use the default batch size.
- Returns:
- train_clients_num (int) – Number of training clients.
- train_data_num (int) – Number of training samples.
- test_data_num (int) – Number of testing samples.
- train_data_global (torch.utils.data.DataLoader) – Global training dataloader.
- test_data_global (torch.utils.data.DataLoader) – Global testing dataloader.
- data_local_num_dict (dict) – Number of local training samples for each client.
- train_data_local_dict (dict) – Local training dataloader for each client.
- test_data_local_dict (dict) – Local testing dataloader for each client.
- n_class (int) – Number of classes.
- Return type:
tuple
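Because the result is a plain 9-tuple, a typing.NamedTuple with the documented field order can make call sites easier to read. PartitionData is a hypothetical helper, not part of fl_sim; the dict key type (client index) is an assumption.

```python
from typing import Any, Dict, NamedTuple


class PartitionData(NamedTuple):
    """Names for the 9-tuple returned by load_partition_data, in documented order."""

    train_clients_num: int
    train_data_num: int
    test_data_num: int
    train_data_global: Any  # torch.utils.data.DataLoader
    test_data_global: Any  # torch.utils.data.DataLoader
    data_local_num_dict: Dict[int, int]  # assumed keyed by client index
    train_data_local_dict: Dict[int, Any]  # client index -> DataLoader
    test_data_local_dict: Dict[int, Any]  # client index -> DataLoader
    n_class: int
```

For example, data = PartitionData(*dataset.load_partition_data(batch_size=32)) makes data.n_class and data.train_data_local_dict available by name, where dataset stands for an instance of any concrete subclass. A tuple of the wrong length raises TypeError on construction.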
- load_partition_data_distributed(process_id: int, batch_size: int | None = None) → tuple
Get the local dataloaders for client process_id, or the global dataloaders.
- Parameters:
process_id (int) – ID of the process whose data is loaded.
batch_size (int, optional) – Batch size for the dataloaders. If None, use the default batch size.
- Returns:
- train_clients_num (int) – Number of training clients.
- train_data_num (int) – Number of training samples.
- train_data_global (torch.utils.data.DataLoader or None) – Global training dataloader.
- test_data_global (torch.utils.data.DataLoader or None) – Global testing dataloader.
- local_data_num (int) – Number of local training samples.
- train_data_local (torch.utils.data.DataLoader or None) – Local training dataloader.
- test_data_local (torch.utils.data.DataLoader or None) – Local testing dataloader.
- n_class (int) – Number of classes.
- Return type:
tuple
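The distributed variant returns an 8-tuple in which any of the dataloaders may be None. A hypothetical NamedTuple wrapper with the documented field order can make those None checks explicit. Which fields are None for a given process_id is not specified here, so the helper methods only inspect what they receive.

```python
from typing import Any, NamedTuple, Optional


class DistributedPartitionData(NamedTuple):
    """Names for the 8-tuple returned by load_partition_data_distributed (hypothetical helper)."""

    train_clients_num: int
    train_data_num: int
    train_data_global: Optional[Any]  # DataLoader or None
    test_data_global: Optional[Any]  # DataLoader or None
    local_data_num: int
    train_data_local: Optional[Any]  # DataLoader or None
    test_data_local: Optional[Any]  # DataLoader or None
    n_class: int

    def has_global(self) -> bool:
        # True only if both global dataloaders were returned.
        return self.train_data_global is not None and self.test_data_global is not None

    def has_local(self) -> bool:
        # True only if both local dataloaders were returned.
        return self.train_data_local is not None and self.test_data_local is not None
```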
- static show_image(tensor: Tensor | ndarray) → Image
Show an image from a tensor or array.
- Parameters:
tensor (Union[torch.Tensor, np.ndarray]) – Image tensor with shape (C, H, W), (H, W, C) or (H, W), where C is the number of channels, H the height and W the width. C must be 1 or 3.
- Returns:
PIL image.
- Return type:
Image.Image
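The accepted layouts can be normalized with NumPy before building a PIL image. to_display_array is a hypothetical helper, not the library's implementation. Treating a leading axis of size 1 or 3 as the channel axis, and assuming floating-point inputs lie in [0, 1], are both assumptions.

```python
import numpy as np


def to_display_array(arr) -> np.ndarray:
    """Normalize an image to (H, W) or (H, W, 3) uint8.

    Heuristic (an assumption): a 3-D array whose first axis has size 1 or 3
    is (C, H, W); otherwise its last axis must have size 1 or 3, i.e. (H, W, C).
    """
    arr = np.asarray(arr)
    if arr.ndim == 3:
        if arr.shape[0] in (1, 3):
            arr = np.transpose(arr, (1, 2, 0))  # (C, H, W) -> (H, W, C)
        elif arr.shape[-1] not in (1, 3):
            raise ValueError(f"channel dimension must be 1 or 3, got shape {arr.shape}")
        if arr.shape[-1] == 1:
            arr = arr[..., 0]  # single-channel images are stored as (H, W)
    elif arr.ndim != 2:
        raise ValueError(f"expected a 2-D or 3-D array, got shape {arr.shape}")
    if np.issubdtype(arr.dtype, np.floating):
        # Assumption: float images are in [0, 1]; rescale to 0..255.
        arr = np.clip(arr, 0.0, 1.0) * 255.0
    return arr.astype(np.uint8)
```

PIL.Image.fromarray(to_display_array(arr)) then yields a grayscale or RGB image; a torch tensor can be converted first with tensor.detach().cpu().numpy().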