from collections import OrderedDict
from itertools import repeat
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Sequence, Tuple, Union
import h5py
import numpy as np
import torch
import torch.utils.data as torchdata
from ..models import nn as mnn
from ..models.utils import top_n_accuracy
from ..utils.const import CACHED_DATA_DIR
from ._register import register_fed_dataset
from .fed_dataset import FedNLPDataset
__all__ = [
"FedShakespeare",
]
FED_SHAKESPEARE_DATA_DIR = CACHED_DATA_DIR / "fed_shakespeare"
FED_SHAKESPEARE_DATA_DIR.mkdir(parents=True, exist_ok=True)
@register_fed_dataset()
class FedShakespeare(FedNLPDataset):
"""Federated Shakespeare dataset.
Shakespeare dataset is built from the collective works of William Shakespeare.
This dataset is used to perform tasks of next character prediction.
FedML [1]_ loaded data from TensorFlow Federated (TFF) shakespeare load_data API [2]_
and saved the unzipped data into hdf5 files.
Data partition is the same as TFF, with the following statistics.
+-------------+---------------+----------------+--------------+---------------+
| DATASET     | TRAIN CLIENTS | TRAIN EXAMPLES | TEST CLIENTS | TEST EXAMPLES |
+=============+===============+================+==============+===============+
| SHAKESPEARE | 715           | 16,068         | 715          | 2,356         |
+-------------+---------------+----------------+--------------+---------------+

Each client corresponds to a speaking role with at least two lines.

Parameters
----------
datadir : Union[str, pathlib.Path], optional
The directory to store the dataset.
If ``None``, use the default directory.
seed : int, default 0
The random seed.
**extra_config : dict, optional
Extra configurations.

References
----------
.. [1] https://github.com/FedML-AI/FedML/tree/master/python/fedml/data/fed_shakespeare
.. [2] https://www.tensorflow.org/federated/api_docs/python/tff/simulation/datasets/shakespeare/load_data
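
Examples
--------
A minimal usage sketch (the dataset archive is downloaded and cached on first use):

>>> ds = FedShakespeare()
>>> ds.view_sample(client_idx=0, sample_idx=0)
>>> train_dl, test_dl = ds.get_dataloader(client_idx=0)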
"""
SEQUENCE_LENGTH = 80  # from McMahan et al., AISTATS 2017
# Vocabulary re-used from the Federated Learning for Text Generation tutorial.
# https://www.tensorflow.org/federated/tutorials/federated_learning_for_text_generation
CHAR_VOCAB = list("dhlptx@DHLPTX $(,048cgkoswCGKOSW[_#'/37;?bfjnrvzBFJNRVZ\"&*.26:\naeimquyAEIMQUY]!%)-159\r")
_pad = "<pad>"
_bos = "<bos>"
_eos = "<eos>"
_oov = "<oov>"
_words = [_pad] + CHAR_VOCAB + [_bos] + [_eos]
word_dict = OrderedDict({w: i for i, w in enumerate(_words)})
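# Resulting index layout (assuming the 86-character TFF vocabulary above):
# <pad> = 0, vocabulary characters = 1..86, <bos> = 87, <eos> = 88;
# the ``<oov>`` token itself is not in ``word_dict``; out-of-vocabulary
# characters map to len(word_dict) = 89 (see ``char_to_id``).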
__name__ = "FedShakespeare"
def _preload(self, datadir: Optional[Union[str, Path]] = None) -> None:
"""Preload the dataset.
Parameters
----------
datadir : Union[pathlib.Path, str], optional
Directory to store data.
If ``None``, use default directory.
Returns
-------
None
"""
self.datadir = Path(datadir or FED_SHAKESPEARE_DATA_DIR).expanduser().resolve()
self.DEFAULT_TRAIN_CLIENTS_NUM = 715
self.DEFAULT_TEST_CLIENTS_NUM = 715
self.DEFAULT_BATCH_SIZE = 4
self.DEFAULT_TRAIN_FILE = "shakespeare_train.h5"
self.DEFAULT_TEST_FILE = "shakespeare_test.h5"
# group name defined by tff in h5 file
self._EXAMPLE = "examples"
self._SNIPPETS = "snippets"
self.criterion = torch.nn.CrossEntropyLoss(ignore_index=0)
self.download_if_needed()
train_file_path = self.datadir / self.DEFAULT_TRAIN_FILE
test_file_path = self.datadir / self.DEFAULT_TEST_FILE
with h5py.File(str(train_file_path), "r") as train_h5, h5py.File(str(test_file_path), "r") as test_h5:
self._client_ids_train = list(train_h5[self._EXAMPLE].keys())
self._client_ids_test = list(test_h5[self._EXAMPLE].keys())
def get_dataloader(
self,
train_bs: Optional[int] = None,
test_bs: Optional[int] = None,
client_idx: Optional[int] = None,
) -> Tuple[torchdata.DataLoader, torchdata.DataLoader]:
"""Get local dataloader at client `client_idx` or get the global dataloader.
Parameters
----------
train_bs : int, optional
Batch size for training dataloader.
If ``None``, use default batch size.
test_bs : int, optional
Batch size for testing dataloader.
If ``None``, use default batch size.
client_idx : int, optional
Index of the client to get dataloader.
If ``None``, get the dataloader containing all data.
Usually used for centralized training.
Returns
-------
train_dl : :class:`torch.utils.data.DataLoader`
Training dataloader.
test_dl : :class:`torch.utils.data.DataLoader`
Testing dataloader.
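
Examples
--------
A sketch of iterating one batch of a single client's local data
(continuing from ``ds = FedShakespeare()``; both tensors have
``SEQUENCE_LENGTH`` columns):

>>> train_dl, test_dl = ds.get_dataloader(train_bs=4, client_idx=0)
>>> inputs, targets = next(iter(train_dl))
>>> inputs.shape[-1] == targets.shape[-1] == FedShakespeare.SEQUENCE_LENGTH
True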
"""
train_h5 = h5py.File(str(self.datadir / self.DEFAULT_TRAIN_FILE), "r")
test_h5 = h5py.File(str(self.datadir / self.DEFAULT_TEST_FILE), "r")
train_ds = []
test_ds = []
# load data
if client_idx is None:
# get ids of all clients
train_ids = self._client_ids_train
test_ids = self._client_ids_test
else:
# get ids of single client
train_ids = [self._client_ids_train[client_idx]]
test_ids = [self._client_ids_test[client_idx]]
for client_id in train_ids:
raw_train = train_h5[self._EXAMPLE][client_id][self._SNIPPETS][()]
raw_train = [x.decode("utf8") for x in raw_train]
train_ds.extend(self.preprocess(raw_train))
for client_id in test_ids:
raw_test = test_h5[self._EXAMPLE][client_id][self._SNIPPETS][()]
raw_test = [x.decode("utf8") for x in raw_test]
test_ds.extend(self.preprocess(raw_test))
# split data
train_x, train_y = FedShakespeare._split_target(train_ds)
test_x, test_y = FedShakespeare._split_target(test_ds)
train_ds = torchdata.TensorDataset(torch.tensor(train_x), torch.tensor(train_y))
test_ds = torchdata.TensorDataset(torch.tensor(test_x), torch.tensor(test_y))
train_dl = torchdata.DataLoader(
dataset=train_ds,
batch_size=train_bs or self.DEFAULT_BATCH_SIZE,
shuffle=True,
drop_last=False,
)
test_dl = torchdata.DataLoader(
dataset=test_ds,
batch_size=test_bs or self.DEFAULT_BATCH_SIZE,
shuffle=True,
drop_last=False,
)
train_h5.close()
test_h5.close()
return train_dl, test_dl
@staticmethod
def _split_target(sequence_batch: List[List[int]]) -> Tuple[np.ndarray, np.ndarray]:
"""Split each ``N + 1``-length sequence into shifted-by-1 input and target sequences."""
sequence_batch = np.asarray(sequence_batch)
input_text = sequence_batch[..., :-1]
target_text = sequence_batch[..., 1:]
return (input_text, target_text)
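# Illustrative example (token ids are hypothetical):
#   sequence_batch = [[87, 20, 9, 3, 0]]   # <bos>, three characters, <pad>
#   input_text     = [[87, 20, 9, 3]]      # drop the last token
#   target_text    = [[20, 9, 3, 0]]       # drop the first token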
def preprocess(self, sentences: Sequence[str], max_seq_len: Optional[int] = None) -> List[List[int]]:
"""Preprocess a list of sentences.
Parameters
----------
sentences : Sequence[str]
List of sentences to be preprocessed.
max_seq_len : int, optional
Maximum sequence length.
If ``None``, use the class default ``SEQUENCE_LENGTH``.

Returns
-------
List[List[int]]
List of token id sequences, each of length ``max_seq_len + 1``.
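
Examples
--------
A sketch of the expected output (continuing from ``ds = FedShakespeare()``;
with the default ``SEQUENCE_LENGTH`` of 80, each chunk has 81 token ids):

>>> seqs = ds.preprocess(["What's in a name?"])
>>> len(seqs), len(seqs[0])
(1, 81)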
"""
sequences = []
if max_seq_len is None:
max_seq_len = self.SEQUENCE_LENGTH
def to_ids(sentence: str, num_oov_buckets: int = 1) -> Iterator[List[int]]:
"""Map a sentence to a list of token ids and yield chunks of length
``max_seq_len + 1``, padding the final chunk with ``<pad>`` if necessary.
Parameters
----------
sentence : str
Sentence to be converted.
num_oov_buckets : int, default 1
The number of out-of-vocabulary buckets (not used by this implementation).
Returns
-------
Iterator[List[int]]
Chunks of token ids, each of length ``max_seq_len + 1``.
"""
tokens = [self.char_to_id(c) for c in sentence]
tokens = [self.char_to_id(self._bos)] + tokens + [self.char_to_id(self._eos)]
if len(tokens) % (max_seq_len + 1) != 0:
pad_length = (-len(tokens)) % (max_seq_len + 1)
tokens += list(repeat(self.char_to_id(self._pad), pad_length))
return (tokens[i : i + max_seq_len + 1] for i in range(0, len(tokens), max_seq_len + 1))
for sen in sentences:
sequences.extend(to_ids(sen))
return sequences
def id_to_word(self, idx: int) -> str:
"""Convert an integer index to a character."""
return self.words[idx]
def char_to_id(self, char: str) -> int:
"""Convert a character to an integer index; out-of-vocabulary characters map to ``len(word_dict)``."""
return self.word_dict.get(char, len(self.word_dict))
@property
def words(self) -> List[str]:
"""Get the word list."""
return self._words
def get_word_dict(self) -> Dict[str, int]:
"""Get the word dictionary."""
return self.word_dict
def evaluate(self, probs: torch.Tensor, truths: torch.Tensor) -> Dict[str, float]:
"""Evaluation using predictions and ground truth.
Parameters
----------
probs : torch.Tensor
Predicted probabilities.
truths : torch.Tensor
Ground truth labels.
Returns
-------
Dict[str, float]
Evaluation results.
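
Examples
--------
A sketch with random tensors standing in for model outputs (continuing from
``ds = FedShakespeare()``; the 90 classes and the shapes shown are illustrative):

>>> probs = torch.randn(32, 90)
>>> truths = torch.randint(0, 90, (32,))
>>> metrics = ds.evaluate(probs, truths)
>>> sorted(metrics.keys())
['acc', 'loss', 'num_samples', 'top3_acc', 'top5_acc']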
"""
return {
"acc": top_n_accuracy(probs, truths, 1),
"top3_acc": top_n_accuracy(probs, truths, 3),
"top5_acc": top_n_accuracy(probs, truths, 5),
"loss": self.criterion(probs, truths).item(),
"num_samples": probs.shape[0],
}
@property
def url(self) -> str:
"""URL for downloading the dataset."""
return "https://fedml.s3-us-west-1.amazonaws.com/shakespeare.tar.bz2"
@property
def candidate_models(self) -> Dict[str, torch.nn.Module]:
"""A set of candidate models."""
return {
"rnn": mnn.RNN_OriginalFedAvg(),
}
@property
def doi(self) -> List[str]:
"""DOI(s) related to the dataset."""
return [
"10.48550/ARXIV.1812.06127", # FedProx
"10.48550/ARXIV.2007.13518", # FedML
]
def view_sample(self, client_idx: int, sample_idx: Optional[int] = None) -> None:
"""View a sample (snippet) from the dataset.

Parameters
----------
client_idx : int
Index of the client on which the sample is located.
sample_idx : int, optional
Index of the sample in the client.
If ``None``, view all snippets of the client.
Returns
-------
None
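
Examples
--------
A sketch (continuing from ``ds = FedShakespeare()``):

>>> ds.view_sample(client_idx=0)                # print all snippets of the client
>>> ds.view_sample(client_idx=0, sample_idx=0)  # print only the first snippet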
"""
if client_idx >= len(self._client_ids_train):
raise ValueError(f"client_idx must be less than {len(self._client_ids_train)}")
client_id = self._client_ids_train[client_idx]  # train and test share the same client ids
train_h5 = h5py.File(str(self.datadir / self.DEFAULT_TRAIN_FILE), "r")
test_h5 = h5py.File(str(self.datadir / self.DEFAULT_TEST_FILE), "r")
raw_train = train_h5[self._EXAMPLE][client_id][self._SNIPPETS][()]
raw_train = [x.decode("utf8") for x in raw_train]
raw_test = test_h5[self._EXAMPLE][client_id][self._SNIPPETS][()]
raw_test = [x.decode("utf8") for x in raw_test]
snippets = raw_train + raw_test
new_line = "\n" + "-" * 50 + "\n"
if sample_idx is not None:
assert sample_idx < len(snippets), "sample_idx out of range"
print(f"Client ID (Title):{new_line}{client_id}{new_line}")
if sample_idx is None:
print(f"Snippets:{new_line}{new_line.join([repr(x) for x in snippets])}")
else:
print(f"Snippet {sample_idx}:{new_line}{repr(snippets[sample_idx])}")
train_h5.close()
test_h5.close()