"""
"""
from typing import Dict, List, Optional, Union
import einops
import numpy as np
import torch
from torch import Tensor
from ..utils.torch_compat import torch_norm
__all__ = [
"CLFMixin",
"REGMixin",
"DiffMixin",
"reset_parameters",
"top_n_accuracy",
]


class CLFMixin(object):
    """Mixin class for classifiers."""

    __name__ = "CLFMixin"

    def predict_proba(
        self,
        input: Union[Tensor, np.ndarray],
        multi_label: bool = False,
        batched: bool = False,
    ) -> np.ndarray:
"""Predict probabilities for each class.
Parameters
----------
input : torch.Tensor or numpy.ndarray
The input data.
multi_label : bool, default False
Whether the model is a multi-label classifier.
batched : bool, default False
Whether the input is batched.
Returns
-------
proba : numpy.ndarray
The predicted probabilities.
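
        Examples
        --------
        A minimal sketch; ``MyClassifier`` is a hypothetical model that
        inherits this mixin and exposes a ``device`` attribute.

        .. code-block:: python

            model = MyClassifier(in_features=10, n_classes=4)
            x = torch.randn(8, 10)  # a batch of 8 samples
            proba = model.predict_proba(x, batched=True)
            proba.shape  # (8, 4); each row sums to 1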
"""
        self.eval()
        if isinstance(input, np.ndarray):
            input = torch.from_numpy(input)
        input = input.to(self.device)
        if not batched:
            input = input.unsqueeze(0)
        output = self.forward(input)
        if multi_label:
            # independent per-class probabilities
            proba = torch.sigmoid(output)
        else:
            # single-label: class probabilities sum to 1
            proba = torch.softmax(output, dim=-1)
        proba = proba.cpu().detach().numpy()
        if not batched:
            proba = proba.squeeze(0)
        return proba

    def predict(
        self,
        input: Union[Tensor, np.ndarray],
        thr: Optional[float] = None,
        class_map: Optional[Dict[int, str]] = None,
        batched: bool = False,
    ) -> list:
"""Predict the class labels.
Parameters
----------
input : torch.Tensor or numpy.ndarray
The input data.
thr : float, optional
The threshold for multi-label classification.
None for single-label classification.
class_map : dict, optional
The mapping from class index to class name.
batched : bool, default False
Whether the input is batched.
Returns
-------
labels : list
The predicted class labels.
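
        Examples
        --------
        A minimal sketch; ``MyClassifier`` is a hypothetical model that
        inherits this mixin. Passing ``thr`` switches to multi-label
        prediction.

        .. code-block:: python

            model = MyClassifier(in_features=10, n_classes=3)
            x = torch.randn(4, 10)
            model.predict(x, batched=True)  # e.g. [2, 0, 1, 2]
            model.predict(x, thr=0.5, batched=True)  # e.g. [[0, 2], [1], [0], [2]]
            model.predict(x, class_map={0: "cat", 1: "dog", 2: "bird"}, batched=True)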
"""
        proba = self.predict_proba(input, multi_label=thr is not None, batched=batched)
        if thr is None:
            output = proba.argmax(axis=-1).tolist()
            if class_map is not None:
                if batched:
                    output = [class_map[i] for i in output]
                else:
                    output = class_map[output]
            return output
        if batched:
            output = [[] for _ in range(input.shape[0])]
        else:
            output = [[]]
        if not batched:
            proba = proba[np.newaxis, ...]
        indices = np.where(proba > thr)
        if len(indices) > 2:
            raise ValueError("multi-label classification is not supported for outputs with 3 or more dimensions")
        for i, j in zip(*indices):
            output[i].append(j)
        for idx in range(len(output)):
            if len(output[idx]) == 0:
                # no class exceeds the threshold: fall back to the argmax
                output[idx] = [proba[idx].argmax()]
        if class_map is not None:
            output = [[class_map[i] for i in item] for item in output]
        if not batched:
            output = output[0]
        return output


class REGMixin(object):
    """Mixin for regressors."""

    __name__ = "REGMixin"

    def predict(self, input: Tensor) -> np.ndarray:
"""Predict the regression target.
Parameters
----------
input : torch.Tensor
The input data.
Returns
-------
output : numpy.ndarray
The predicted regression target.
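
        Examples
        --------
        A minimal sketch; ``MyRegressor`` is a hypothetical model that
        inherits this mixin.

        .. code-block:: python

            model = MyRegressor(in_features=10, out_features=1)
            y = model.predict(torch.randn(4, 10))  # numpy array of shape (4, 1)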
"""
        self.eval()
        output = self.forward(input)
        return output.cpu().detach().numpy()


class DiffMixin(object):
    """Mixin for differences of two models.

    Examples
    --------
    .. code-block:: python

        from torch import nn

        class ModelA(nn.Module, DiffMixin):
            def __init__(self, out_dim):
                super().__init__()
                self.fc = nn.Linear(10, out_dim)

        model_1 = ModelA(10)
        model_2 = ModelA(10)
        model_1.diff(model_2, norm=2)
    """
    def diff(self, other: object, norm: Optional[Union[str, int, float]] = None) -> Union[float, List[Tensor]]:
"""Compute the difference between two models.
Parameters
----------
other : object
Another model, which has the same structure as this one.
norm : str or int or float, optional
The norm to compute the difference.
None for the raw difference.
Refer to :func:`torch.linalg.norm` for more details.
Returns
-------
diff : float or List[torch.Tensor]
The difference.
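
        Examples
        --------
        A minimal sketch, reusing the hypothetical ``ModelA`` from the class
        docstring.

        .. code-block:: python

            model_1, model_2 = ModelA(10), ModelA(10)
            model_1.diff(model_2)  # list of raw per-parameter differences
            model_1.diff(model_2, norm=2)  # scalar L2 norm of the difference
            model_1.diff(model_2, norm="inf")  # largest per-parameter max-norm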
"""
        assert isinstance(other, type(self)), "the two models should have the same structure"
        if norm is not None:
            # convert string-typed infinities to float
            if norm == "inf":
                norm = float("inf")
            elif norm == "-inf":
                norm = -float("inf")
            assert isinstance(norm, (int, float)) or norm in ["nuc", "fro"], (
                "norm should be an int or float or one of 'nuc' (nuclear norm) or 'fro' (Frobenius norm)"
            )
        try:
            if norm is not None:
                diff = [
                    torch_norm(p1.data.cpu() - p2.data.cpu(), norm).item()
                    for p1, p2 in zip(self.parameters(), other.parameters())
                ]
            else:
                diff = [p1.data.cpu() - p2.data.cpu() for p1, p2 in zip(self.parameters(), other.parameters())]
        except RuntimeError as e:
            if norm == "nuc" and "Expected a tensor with 2 dimensions" in str(e):
                raise ValueError("nuclear norm is not supported for the current model structure") from e
            elif "must match the size" in str(e):
                raise ValueError("the two models should have the same structure") from e
            else:
                raise e
        # aggregate the per-parameter norms into a single scalar
        if norm in ["nuc", "fro"]:
            diff = np.sqrt(np.sum([d**2 for d in diff]))
        elif norm == float("inf"):
            diff = np.max(diff)
        elif norm == -float("inf"):
            diff = np.min(diff)
        elif isinstance(norm, (int, float)):
            # L_p norm for finite p
            diff = np.sum([d**norm for d in diff]) ** (1 / norm)
        return diff


def reset_parameters(module: torch.nn.Module) -> None:
    """Reset the parameters of a module and its children.

    Parameters
    ----------
    module : torch.nn.Module
        The module to reset.
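
    Examples
    --------
    A minimal sketch using standard :mod:`torch.nn` layers.

    .. code-block:: python

        from torch import nn

        model = nn.Sequential(nn.Linear(10, 4), nn.ReLU(), nn.Linear(4, 2))
        reset_parameters(model)  # re-initializes both Linear layers in place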
"""
    for layer in module.children():
        if hasattr(layer, "reset_parameters"):
            layer.reset_parameters()
        else:
            # containers without their own reset: recurse into the children
            reset_parameters(layer)


def top_n_accuracy(preds: Tensor, labels: Tensor, n: int = 1) -> float:
    """Top-n accuracy.

    Parameters
    ----------
    preds : torch.Tensor
        Shape ``(batch_size, n_classes)`` or ``(batch_size, n_classes, d_1, ..., d_n)``.
    labels : torch.Tensor
        Shape ``(batch_size,)`` or ``(batch_size, d_1, ..., d_n)``.
    n : int, default 1
        A prediction counts as correct if the label is among the ``n``
        highest-scoring classes.

    Returns
    -------
    float
        The top-n accuracy.
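
    Examples
    --------
    A small worked example with 3 classes: the second sample's label is only
    the runner-up, so it counts for top-2 but not for top-1.

    .. code-block:: python

        preds = torch.tensor([[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]])
        labels = torch.tensor([1, 1])
        top_n_accuracy(preds, labels, n=1)  # 0.5
        top_n_accuracy(preds, labels, n=2)  # 1.0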
"""
    assert preds.shape[0] == labels.shape[0]
    batch_size, n_classes, *extra_dims = preds.shape
    _, indices = torch.topk(preds, n, dim=1)  # of shape (batch_size, n) or (batch_size, n, d_1, ..., d_n)
    pattern = " ".join([f"d_{i+1}" for i in range(len(extra_dims))])
    pattern = f"batch_size {pattern} -> batch_size n {pattern}"
    correct = torch.sum(indices == einops.repeat(labels, pattern, n=n))
    acc = correct.item() / batch_size
    # normalize by the extra dimensions as well, so that acc lies in [0, 1]
    for d in extra_dims:
        acc = acc / d
    return acc