DiffMixin
- class fl_sim.models.DiffMixin
Bases: object
Mixin for computing the difference between two models.
Examples
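    from torch import nn
    from fl_sim.models import DiffMixin

    class ModelA(nn.Module, DiffMixin):
        def __init__(self, out_dim):
            super().__init__()
            self.fc = nn.Linear(10, out_dim)

    model_1 = ModelA(10)
    model_2 = ModelA(10)
    # norm=2 requests a single float instead of the raw difference
    model_1.diff(model_2, norm=2)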
- diff(other: object, norm: str | int | float | None = None) -> float | List[Tensor]
Compute the difference between two models.
- Parameters:
other (object) – Another model, which has the same structure as this one.
norm (str or int or float, optional) – The norm used to measure the difference. If None, the raw difference is returned. Refer to torch.linalg.norm() for more details.
- Returns:
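diff – The difference: the raw difference as a list of tensors if norm is None, otherwise the norm of the difference as a float.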
- Return type:
float or List[torch.Tensor]
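
To illustrate the two return modes, here is a minimal standalone sketch of how such a difference could be computed. It is not the fl_sim implementation: the helper name param_diff is hypothetical, and it assumes that the raw difference holds one tensor per parameter and that the norm is taken over all differences flattened into a single vector.

```python
import torch
from torch import nn


def param_diff(model_a, model_b, norm=None):
    """Hypothetical helper (not part of fl_sim): difference of two models."""
    # Pair parameters in registration order; both models must share a structure.
    diffs = [
        p_a.detach() - p_b.detach()
        for p_a, p_b in zip(model_a.parameters(), model_b.parameters())
    ]
    if norm is None:
        # Raw difference: one tensor per parameter.
        return diffs
    # Assumption: flatten all differences into one vector and take one norm.
    flat = torch.cat([d.reshape(-1) for d in diffs])
    return torch.linalg.norm(flat, ord=norm).item()


model_1 = nn.Linear(3, 2)
model_2 = nn.Linear(3, 2)
with torch.no_grad():
    # Use known values so the result is easy to check by hand.
    for p in model_1.parameters():
        p.fill_(1.0)
    for p in model_2.parameters():
        p.zero_()

raw = param_diff(model_1, model_2)          # list with a (2, 3) and a (2,) tensor
l2 = param_diff(model_1, model_2, norm=2)   # Euclidean norm of 8 ones
```

Because the sketch flattens the difference into a vector, only vector orders accepted by torch.linalg.norm() (for example 1, 2, or float("inf")) apply here. Whether DiffMixin.diff treats string norms the same way is not stated in this reference.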