Source code for fl_sim.data_processing.fed_shakespeare

from collections import OrderedDict
from itertools import repeat
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple, Union

import h5py
import numpy as np
import torch
import torch.utils.data as torchdata

from ..models import nn as mnn
from ..models.utils import top_n_accuracy
from ..utils.const import CACHED_DATA_DIR
from ._register import register_fed_dataset
from .fed_dataset import FedNLPDataset

__all__ = [
    "FedShakespeare",
]


FED_SHAKESPEARE_DATA_DIR = CACHED_DATA_DIR / "fed_shakespeare"
FED_SHAKESPEARE_DATA_DIR.mkdir(parents=True, exist_ok=True)


@register_fed_dataset()
class FedShakespeare(FedNLPDataset):
    """Federated Shakespeare dataset.

    The Shakespeare dataset is built from the collective works of William Shakespeare
    and is used for the task of next-character prediction. FedML [1]_ loaded the data
    from the TensorFlow Federated (TFF) shakespeare ``load_data`` API [2]_ and saved
    the unzipped data into hdf5 files. The data partition is the same as in TFF,
    with the following statistics.

    +-------------+---------------+----------------+--------------+---------------+
    | DATASET     | TRAIN CLIENTS | TRAIN EXAMPLES | TEST CLIENTS | TEST EXAMPLES |
    +=============+===============+================+==============+===============+
    | SHAKESPEARE | 715           | 16,068         | 715          | 2,356         |
    +-------------+---------------+----------------+--------------+---------------+

    Each client corresponds to a speaking role with at least two lines.

    Parameters
    ----------
    datadir : Union[str, pathlib.Path], optional
        The directory to store the dataset.
        If ``None``, use default directory.
    seed : int, default 0
        The random seed.
    **extra_config : dict, optional
        Extra configurations.

    References
    ----------
    .. [1] https://github.com/FedML-AI/FedML/tree/master/python/fedml/data/fed_shakespeare
    .. [2] https://www.tensorflow.org/federated/api_docs/python/tff/simulation/datasets/shakespeare/load_data

    """

    SEQUENCE_LENGTH = 80  # from McMahan et al., AISTATS 2017
    # Vocabulary re-used from the Federated Learning for Text Generation tutorial.
    # https://www.tensorflow.org/federated/tutorials/federated_learning_for_text_generation
    CHAR_VOCAB = list("dhlptx@DHLPTX $(,048cgkoswCGKOSW[_#'/37;?bfjnrvzBFJNRVZ\"&*.26:\naeimquyAEIMQUY]!%)-159\r")
    _pad = "<pad>"
    _bos = "<bos>"
    _eos = "<eos>"
    _oov = "<oov>"
    _words = [_pad] + CHAR_VOCAB + [_bos] + [_eos]
    word_dict = OrderedDict({w: i for i, w in enumerate(_words)})

    __name__ = "FedShakespeare"

    def _preload(self, datadir: Optional[Union[str, Path]] = None) -> None:
        """Preload the dataset.

        Parameters
        ----------
        datadir : Union[pathlib.Path, str], optional
            Directory to store data.
            If ``None``, use default directory.

        Returns
        -------
        None

        """
        self.datadir = Path(datadir or FED_SHAKESPEARE_DATA_DIR).expanduser().resolve()

        self.DEFAULT_TRAIN_CLIENTS_NUM = 715
        self.DEFAULT_TEST_CLIENTS_NUM = 715
        self.DEFAULT_BATCH_SIZE = 4
        self.DEFAULT_TRAIN_FILE = "shakespeare_train.h5"
        self.DEFAULT_TEST_FILE = "shakespeare_test.h5"

        # group names defined by TFF in the h5 files
        self._EXAMPLE = "examples"
        self._SNIPPETS = "snippets"

        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=0)

        self.download_if_needed()

        train_file_path = self.datadir / self.DEFAULT_TRAIN_FILE
        test_file_path = self.datadir / self.DEFAULT_TEST_FILE
        with h5py.File(str(train_file_path), "r") as train_h5, h5py.File(str(test_file_path), "r") as test_h5:
            self._client_ids_train = list(train_h5[self._EXAMPLE].keys())
            self._client_ids_test = list(test_h5[self._EXAMPLE].keys())
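
    # A minimal, illustrative usage sketch, kept in comments so nothing runs at
    # import time.  It assumes the default constructor arguments; the numbers
    # come from the class docstring and the defaults set in ``_preload`` above:
    #
    # >>> ds = FedShakespeare()         # downloads into FED_SHAKESPEARE_DATA_DIR if needed
    # >>> len(ds._client_ids_train)     # one client per speaking role
    # 715
    # >>> ds.word_dict["<pad>"]         # pad token is index 0, ignored by the loss
    # 0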

    def get_dataloader(
        self,
        train_bs: Optional[int] = None,
        test_bs: Optional[int] = None,
        client_idx: Optional[int] = None,
    ) -> Tuple[torchdata.DataLoader, torchdata.DataLoader]:
        """Get local dataloader at client `client_idx` or get the global dataloader.

        Parameters
        ----------
        train_bs : int, optional
            Batch size for training dataloader.
            If ``None``, use default batch size.
        test_bs : int, optional
            Batch size for testing dataloader.
            If ``None``, use default batch size.
        client_idx : int, optional
            Index of the client to get dataloader.
            If ``None``, get the dataloader containing all data.
            Usually used for centralized training.

        Returns
        -------
        train_dl : :class:`torch.utils.data.DataLoader`
            Training dataloader.
        test_dl : :class:`torch.utils.data.DataLoader`
            Testing dataloader.

        """
        train_h5 = h5py.File(str(self.datadir / self.DEFAULT_TRAIN_FILE), "r")
        test_h5 = h5py.File(str(self.datadir / self.DEFAULT_TEST_FILE), "r")
        train_ds = []
        test_ds = []

        # load data
        if client_idx is None:
            # get ids of all clients
            train_ids = self._client_ids_train
            test_ids = self._client_ids_test
        else:
            # get ids of single client
            train_ids = [self._client_ids_train[client_idx]]
            test_ids = [self._client_ids_test[client_idx]]

        for client_id in train_ids:
            raw_train = train_h5[self._EXAMPLE][client_id][self._SNIPPETS][()]
            raw_train = [x.decode("utf8") for x in raw_train]
            train_ds.extend(self.preprocess(raw_train))
        for client_id in test_ids:
            raw_test = test_h5[self._EXAMPLE][client_id][self._SNIPPETS][()]
            raw_test = [x.decode("utf8") for x in raw_test]
            test_ds.extend(self.preprocess(raw_test))

        # split data
        train_x, train_y = FedShakespeare._split_target(train_ds)
        test_x, test_y = FedShakespeare._split_target(test_ds)
        train_ds = torchdata.TensorDataset(torch.tensor(train_x), torch.tensor(train_y))
        test_ds = torchdata.TensorDataset(torch.tensor(test_x), torch.tensor(test_y))
        train_dl = torchdata.DataLoader(
            dataset=train_ds,
            batch_size=train_bs or self.DEFAULT_BATCH_SIZE,
            shuffle=True,
            drop_last=False,
        )
        test_dl = torchdata.DataLoader(
            dataset=test_ds,
            batch_size=test_bs or self.DEFAULT_BATCH_SIZE,
            shuffle=True,
            drop_last=False,
        )
        train_h5.close()
        test_h5.close()

        return train_dl, test_dl
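
    # Illustrative sketch of ``get_dataloader`` (a hedged example: the shapes
    # assume the default ``DEFAULT_BATCH_SIZE`` of 4 and ``SEQUENCE_LENGTH`` of 80,
    # and that client 0 has at least 4 training chunks; ``ds`` is the instance
    # from the sketch above):
    #
    # >>> train_dl, test_dl = ds.get_dataloader(client_idx=0)
    # >>> inputs, targets = next(iter(train_dl))
    # >>> inputs.shape, targets.shape   # both (batch_size, SEQUENCE_LENGTH)
    # (torch.Size([4, 80]), torch.Size([4, 80]))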

    @staticmethod
    def _split_target(sequence_batch: Sequence[Sequence[int]]) -> Tuple[np.ndarray, np.ndarray]:
        """Split each length-(N + 1) sequence into shifted-by-one input and target sequences."""
        sequence_batch = np.asarray(sequence_batch)
        input_text = sequence_batch[..., :-1]
        target_text = sequence_batch[..., 1:]
        return (input_text, target_text)
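
    # For example, each padded length-(N + 1) id sequence is turned into an
    # input/target pair shifted by one position:
    #
    # >>> FedShakespeare._split_target([[1, 2, 3, 4, 5]])
    # (array([[1, 2, 3, 4]]), array([[2, 3, 4, 5]]))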

    def preprocess(self, sentences: Sequence[str], max_seq_len: Optional[int] = None) -> List[List[int]]:
        """Preprocess a list of sentences.

        Parameters
        ----------
        sentences : Sequence[str]
            List of sentences to be preprocessed.
        max_seq_len : int, optional
            Maximum sequence length.
            If ``None``, use default sequence length.

        Returns
        -------
        List[List[int]]
            List of tokenized sentences.

        """
        sequences = []
        if max_seq_len is None:
            max_seq_len = self.SEQUENCE_LENGTH

        def to_ids(sentence: str, num_oov_buckets: int = 1) -> List[List[int]]:
            """Map a sentence to token id chunks of length ``max_seq_len + 1``.

            The sentence is wrapped with ``<bos>``/``<eos>``, padded with ``<pad>``
            to a multiple of ``max_seq_len + 1``, and cut into chunks.

            Parameters
            ----------
            sentence : str
                Sentence to be converted.
            num_oov_buckets : int, default 1
                The number of out-of-vocabulary buckets.

            Returns
            -------
            List[List[int]]
                Token id chunks of the sentence, each of length ``max_seq_len + 1``.

            """
            tokens = [self.char_to_id(c) for c in sentence]
            tokens = [self.char_to_id(self._bos)] + tokens + [self.char_to_id(self._eos)]
            if len(tokens) % (max_seq_len + 1) != 0:
                pad_length = (-len(tokens)) % (max_seq_len + 1)
                tokens += list(repeat(self.char_to_id(self._pad), pad_length))
            return [tokens[i : i + max_seq_len + 1] for i in range(0, len(tokens), max_seq_len + 1)]

        for sen in sentences:
            sequences.extend(to_ids(sen))
        return sequences
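
    # Illustrative sketch: a snippet shorter than ``SEQUENCE_LENGTH`` still yields
    # one chunk of ``SEQUENCE_LENGTH + 1`` ids (bos + characters + eos + padding),
    # given ``ds = FedShakespeare()`` as in the sketch above:
    #
    # >>> chunks = ds.preprocess(["To be, or not to be"])
    # >>> len(chunks), len(chunks[0])
    # (1, 81)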

    def id_to_word(self, idx: int) -> str:
        """Convert an integer index to a character."""
        return self.words[idx]

    def char_to_id(self, char: str) -> int:
        """Convert a character to an integer index."""
        return self.word_dict.get(char, len(self.word_dict))
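
    # Characters outside ``CHAR_VOCAB`` (and the special tokens) all map to a
    # single out-of-vocabulary bucket at index ``len(word_dict)``, e.g.:
    #
    # >>> ds.char_to_id("a") < len(ds.word_dict)
    # True
    # >>> ds.char_to_id("é") == len(ds.word_dict)
    # True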

    @property
    def words(self) -> List[str]:
        """Get the word list."""
        return self._words

    def get_word_dict(self) -> Dict[str, int]:
        """Get the word dictionary."""
        return self.word_dict

    def evaluate(self, probs: torch.Tensor, truths: torch.Tensor) -> Dict[str, float]:
        """Evaluation using predictions and ground truth.

        Parameters
        ----------
        probs : torch.Tensor
            Predicted probabilities.
        truths : torch.Tensor
            Ground truth labels.

        Returns
        -------
        Dict[str, float]
            Evaluation results.

        """
        return {
            "acc": top_n_accuracy(probs, truths, 1),
            "top3_acc": top_n_accuracy(probs, truths, 3),
            "top5_acc": top_n_accuracy(probs, truths, 5),
            "loss": self.criterion(probs, truths).item(),
            "num_samples": probs.shape[0],
        }
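
    # Hedged shape note: ``self.criterion`` is ``CrossEntropyLoss`` with
    # ``ignore_index=0`` (the pad token), so the standard convention of classes
    # in dimension 1 applies.  A hypothetical call with flat per-character
    # predictions could look like (fake logits and labels, for illustration only):
    #
    # >>> probs = torch.randn(8, len(ds.word_dict))
    # >>> truths = torch.randint(1, len(ds.word_dict), (8,))
    # >>> sorted(ds.evaluate(probs, truths).keys())
    # ['acc', 'loss', 'num_samples', 'top3_acc', 'top5_acc']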

    @property
    def url(self) -> str:
        """URL for downloading the dataset."""
        return "https://fedml.s3-us-west-1.amazonaws.com/shakespeare.tar.bz2"

    @property
    def candidate_models(self) -> Dict[str, torch.nn.Module]:
        """A set of candidate models."""
        return {
            "rnn": mnn.RNN_OriginalFedAvg(),
        }

    @property
    def doi(self) -> List[str]:
        """DOI(s) related to the dataset."""
        return [
            "10.48550/ARXIV.1812.06127",  # FedProx
            "10.48550/ARXIV.2007.13518",  # FedML
        ]

    def view_sample(self, client_idx: int, sample_idx: Optional[int] = None) -> None:
        """View a sample from the dataset.

        Parameters
        ----------
        client_idx : int
            Index of the client on which the sample is located.
        sample_idx : int, optional
            Index of the sample in the client.
            If ``None``, print all snippets of the client.

        Returns
        -------
        None

        """
        if client_idx >= len(self._client_ids_train):
            raise ValueError(f"client_idx must be less than {len(self._client_ids_train)}")
        client_id = self._client_ids_train[client_idx]  # train and test share the same client ids
        train_h5 = h5py.File(str(self.datadir / self.DEFAULT_TRAIN_FILE), "r")
        test_h5 = h5py.File(str(self.datadir / self.DEFAULT_TEST_FILE), "r")
        raw_train = train_h5[self._EXAMPLE][client_id][self._SNIPPETS][()]
        raw_train = [x.decode("utf8") for x in raw_train]
        raw_test = test_h5[self._EXAMPLE][client_id][self._SNIPPETS][()]
        raw_test = [x.decode("utf8") for x in raw_test]
        snippets = raw_train + raw_test
        new_line = "\n" + "-" * 50 + "\n"
        if sample_idx is not None:
            assert sample_idx < len(snippets), "sample_idx out of range"
        print(f"Client ID (Title):{new_line}{client_id}{new_line}")
        if sample_idx is None:
            print(f"Snippets:{new_line}{new_line.join([repr(x) for x in snippets])}")
        else:
            print(f"Snippet {sample_idx}:{new_line}{repr(snippets[sample_idx])}")
        train_h5.close()
        test_h5.close()
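

# A hedged, end-to-end sketch of typical usage.  Executing it downloads the
# FedML Shakespeare hdf5 archive into ``FED_SHAKESPEARE_DATA_DIR`` on first use;
# the client index and batch sizes below are arbitrary illustrative choices,
# not defaults of the library.
if __name__ == "__main__":
    dataset = FedShakespeare()
    dataset.view_sample(client_idx=0, sample_idx=0)
    train_dl, test_dl = dataset.get_dataloader(train_bs=8, test_bs=8, client_idx=0)
    inputs, targets = next(iter(train_dl))
    print(inputs.shape, targets.shape)  # both (8, FedShakespeare.SEQUENCE_LENGTH)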