Client
- class fl_sim.nodes.Client(client_id: int, device: device, model: Module, dataset: FedDataset, config: ClientConfig)
Bases: Node
The class that simulates a client node.
The client node is responsible for training its local model and for communicating with the server node.
- Parameters:
client_id (int) – The id of the client.
device (torch.device) – The device to train the model on.
model (torch.nn.Module) – The model to train.
dataset (FedDataset) – The dataset to train on.
config (ClientConfig) – The config for the client.
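As a rough illustration of the client's role, the sketch below models one round of communication with the server as an exchange of model parameters: the client loads the weights broadcast by the server, and later sends back a copy of its own. `ToyClient`, `receive()` and `send()` are hypothetical names chosen for this example, not part of fl_sim; the real client may exchange richer messages than a plain `state_dict`.

```python
import copy

import torch
from torch import nn


class ToyClient:
    """Hypothetical stand-in for a client node; this is NOT the fl_sim API."""

    def __init__(self, client_id: int, device: torch.device, model: nn.Module) -> None:
        self.client_id = client_id
        self.device = device
        # The client keeps its own copy of the model on its own device.
        self.model = model.to(device)

    def receive(self, global_state: dict) -> None:
        # Overwrite the local weights with the parameters broadcast by the server.
        self.model.load_state_dict(global_state)

    def send(self) -> dict:
        # Deep-copy so that further local training cannot mutate what the server received.
        return copy.deepcopy(self.model.state_dict())


server_model = nn.Linear(3, 1)
client = ToyClient(0, torch.device("cpu"), nn.Linear(3, 1))
client.receive(server_model.state_dict())
update = client.send()
print(sorted(update.keys()))  # ['bias', 'weight']
```

Copying the state on `send()` keeps the server's view of the update independent of whatever the client does to its model afterwards.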
- get_all_data() → Tuple[Tensor, Tensor]
Get all the data held by the client.
This helper gives fast access to all of the client's data: both the training and the validation splits, returned as a pair of tensors holding the features and the labels.
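One plausible way to build such a helper is to concatenate the splits along the sample dimension. The sketch assumes the splits are already available as tensors; `get_all_data_sketch` and its arguments are hypothetical, and the actual method reads from the client's own dataset instead.

```python
from typing import Tuple

import torch
from torch import Tensor


def get_all_data_sketch(
    train_x: Tensor, train_y: Tensor, val_x: Tensor, val_y: Tensor
) -> Tuple[Tensor, Tensor]:
    """Hypothetical helper: merge a client's training and validation splits."""
    # Concatenate along dim 0 (samples); the per-sample feature shapes must match.
    features = torch.cat([train_x, val_x], dim=0)
    labels = torch.cat([train_y, val_y], dim=0)
    return features, labels


X, y = get_all_data_sketch(
    torch.randn(8, 4), torch.randint(0, 2, (8,)),
    torch.randn(2, 4), torch.randint(0, 2, (2,)),
)
print(X.shape, y.shape)  # torch.Size([10, 4]) torch.Size([10])
```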
- solve_inner() → None
Main part of the inner-loop solver.
Basic example:
    self.model.train()  # training mode: enables dropout and batch-norm statistic updates
    epoch_losses = []
    for epoch in range(self.config.num_epochs):
        batch_losses = []
        for batch_idx, (data, target) in enumerate(self.train_loader):
            data, target = data.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()  # clear gradients left over from the previous step
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            batch_losses.append(loss.item())
        # mean training loss over the batches of this epoch
        epoch_losses.append(sum(batch_losses) / len(batch_losses))
        self.lr_scheduler.step()  # advance the learning-rate schedule once per epoch
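The basic example depends on attributes of a `Client` instance (`model`, `optimizer`, `criterion`, `lr_scheduler`, `train_loader`, `config.num_epochs`). To run the same loop outside the class, the sketch below passes those objects in explicitly. The function name `local_train`, the toy regression data and the `StepLR` schedule are assumptions made for illustration only.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def local_train(model, train_loader, criterion, optimizer, lr_scheduler, num_epochs, device):
    """Standalone version of the basic example; returns the mean loss of each epoch."""
    model.train()
    epoch_losses = []
    for _ in range(num_epochs):
        batch_losses = []
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())
        epoch_losses.append(sum(batch_losses) / len(batch_losses))
        lr_scheduler.step()  # one scheduler step per epoch, after the optimizer steps
    return epoch_losses


torch.manual_seed(0)
x = torch.randn(64, 3)
y = x @ torch.tensor([[1.0], [-2.0], [0.5]])  # noiseless linear target
loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)
model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
losses = local_train(model, loader, nn.MSELoss(), optimizer, scheduler,
                     num_epochs=10, device=torch.device("cpu"))
print([round(v, 4) for v in losses])
```

Returning `epoch_losses` lets the caller inspect the per-epoch averages; in the basic example the list is built but never returned.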
- abstract train() → None
Main part of the inner-loop solver.
This method is abstract, so each concrete client class must override it with its own local training routine. The basic example given for solve_inner() applies here unchanged.
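Assuming the base class relies on Python's `abc` machinery (Sphinx labels methods decorated with `abc.abstractmethod` as "abstract"), a subclass must define `train()` before it can be instantiated. The minimal sketch below shows that contract with hypothetical classes (`BaseClientSketch`, `CountingClient`) that do not depend on fl_sim.

```python
from abc import ABC, abstractmethod


class BaseClientSketch(ABC):
    """Hypothetical stand-in for an abstract client; not the fl_sim class."""

    def __init__(self, num_epochs: int) -> None:
        self.num_epochs = num_epochs
        self.history: list = []

    @abstractmethod
    def train(self) -> None:
        """Local training; every concrete client must override this."""


class CountingClient(BaseClientSketch):
    def train(self) -> None:
        # Placeholder "training": record one entry per local epoch.
        for epoch in range(self.num_epochs):
            self.history.append(epoch)


try:
    BaseClientSketch(num_epochs=2)  # fails: train() is still abstract
except TypeError as err:
    print("cannot instantiate:", type(err).__name__)  # cannot instantiate: TypeError

client = CountingClient(num_epochs=3)
client.train()
print(client.history)  # [0, 1, 2]
```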