"""
1D analogue of Grad-CAM,
used to inspect the attention areas of deep ECG models.
References
----------
https://github.com/jacobgil/pytorch-grad-cam
"""
from typing import List, Optional, Sequence, Tuple
import numpy as np
import torch
from torch import Tensor, nn
# Names exported by ``from <module> import *``.
__all__ = ["GradCam"]
class FeatureExtractor(object):
    """
    Run a module child by child, collecting the outputs of the targeted
    children and registering hooks that record the gradients of those outputs.

    The direct children of `model` are applied one after another in
    registration order, so `model` is expected to behave like a sequential
    container.
    """

    def __init__(self, model: nn.Module, target_layers: Sequence[str]) -> None:
        """
        Parameters
        ----------
        model: Module,
            the module whose direct children are applied sequentially
        target_layers: sequence of str,
            names of the direct children of `model` whose outputs (and the
            gradients of those outputs) are to be recorded;
            a single name given as a plain string is also accepted
        """
        self.model = model
        # A bare string is itself a sequence, so ``name in "fc2"`` would be a
        # substring test (matching e.g. "fc" as well); wrap it for exact matching.
        if isinstance(target_layers, str):
            target_layers = [target_layers]
        self.target_layers = target_layers
        self.gradients = []

    def save_gradient(self, grad: Tensor) -> None:
        """Tensor hook: append `grad` to the list of recorded gradients."""
        self.gradients.append(grad)

    def __call__(self, x: Tensor) -> Tuple[List[Tensor], Tensor]:
        """
        Forward `x` through the children of `self.model`.

        Parameters
        ----------
        x: Tensor,
            input of the first child of `self.model`

        Returns
        -------
        outputs: list of Tensor,
            outputs of the targeted children, in forward order
        x: Tensor,
            output of the last child (the input itself if there are no children)

        NOTE that the recorded gradients are reset at every call, and are only
        filled in once `backward` is called on something computed from `x`.
        """
        outputs = []
        self.gradients = []
        for name, module in self.model._modules.items():
            if module is None:
                # placeholder child (registered as None), nothing to apply
                continue
            x = module(x)
            if name in self.target_layers:
                x.register_hook(self.save_gradient)
                outputs.append(x)
        return outputs, x
class ModelOutputs(object):
    """
    Class for making a forward pass, and getting:
    1. The network output.
    2. Activations from intermediate targeted layers.
    3. Gradients from intermediate targeted layers.
    """

    def __init__(self, model: nn.Module, feature_module: nn.Module, target_layers: Sequence[str]) -> None:
        """
        Parameters
        ----------
        model: Module,
            the whole network, whose direct children are applied sequentially
        feature_module: Module,
            the direct child of `model` that contains the targeted layers
        target_layers: sequence of str,
            names of the children of `feature_module` to be inspected
        """
        self.model = model
        self.feature_module = feature_module
        self.feature_extractor = FeatureExtractor(self.feature_module, target_layers)

    def get_gradients(self) -> List[Tensor]:
        """Return the gradients recorded by the feature extractor."""
        return self.feature_extractor.gradients

    def __call__(self, x: Tensor) -> Tuple[List[Tensor], Tensor]:
        """
        Forward `x` through the direct children of `self.model`.

        Parameters
        ----------
        x: Tensor,
            input of the network

        Returns
        -------
        target_activations: list of Tensor,
            activations returned by the feature extractor;
            empty if `self.feature_module` is not a direct child of `self.model`
        x: Tensor,
            output of the last child of `self.model`
        """
        target_activations = []
        for name, module in self.model._modules.items():
            if module is None:
                # placeholder child (registered as None), nothing to apply
                continue
            if module is self.feature_module:
                # identity check: this very child is run through the extractor
                target_activations, x = self.feature_extractor(x)
            elif "avgpool" in name.lower():
                x = module(x)
                # flatten everything but the batch axis; `reshape` (unlike `view`)
                # also works when the pooled output is not contiguous
                x = x.reshape(x.size(0), -1)
            else:
                x = module(x)
        return target_activations, x
class GradCam(object):
    """
    1D analogue of Grad-CAM: weights the channels of a targeted activation by
    the average gradient of a class score with respect to that channel.

    NOTE: still experimental; the resulting map is returned raw, i.e. without
    rectification, normalization or resizing to the input length.
    """

    __name__ = "GradCam"

    def __init__(
        self,
        model: nn.Module,
        feature_module: nn.Module,
        target_layer_names: Sequence[str],
        target_channel_last: bool = False,
        device: str = "cpu",
    ) -> None:
        """
        Parameters
        ----------
        model: Module,
            the network to inspect; it is put in eval mode and moved to `device`
        feature_module: Module,
            the direct child of `model` that contains the targeted layers
        target_layer_names: sequence of str,
            names of the children of `feature_module` to be inspected
        target_channel_last: bool, default False,
            whether the targeted activation has its channel axis last,
            i.e. is of shape (batch_size, seq_len, channels)
        device: str, default "cpu",
            device on which the computation is done
        """
        self.model = model
        self.feature_module = feature_module
        self.target_layer_names = target_layer_names
        self.target_channel_last = target_channel_last
        self.device = torch.device(device)
        self.model.eval()
        self.model.to(self.device)
        self.extractor = ModelOutputs(self.model, self.feature_module, self.target_layer_names)

    def forward(self, input: Tensor) -> Tensor:
        """Plain forward pass of `self.model`."""
        return self.model(input)

    def __call__(self, input: Tensor, index: Optional[int] = None) -> np.ndarray:
        """
        Compute the class activation map of the last targeted layer.

        Parameters
        ----------
        input: Tensor,
            input tensor of shape (batch_size (=1), channels, seq_len)
        index: int, optional,
            the index of the output of the final classifying layer of `self.model`,
            defaults to the index of the largest output of the first sample

        Returns
        -------
        cam: ndarray,
            of dtype float32 and shape (seq_len,), where seq_len is the length
            of the targeted activation (not necessarily of `input`)

        Raises
        ------
        ValueError,
            if no activation or no gradient was captured
        """
        # gradients are needed even if the caller is inside `torch.no_grad()`
        with torch.enable_grad():
            # output of shape (batch_size (=1), n_classes)
            features, output = self.extractor(input.to(self.device))
            if len(features) == 0:
                raise ValueError(
                    "no activation was captured: `feature_module` has to be a direct child of "
                    "`model`, and `target_layer_names` have to name children of `feature_module`"
                )
            if index is None:
                index = int(np.argmax(output.detach().cpu().numpy()[0]))
            # same scalar as summing `one_hot(index) * output`
            score = output[..., index].sum()
            self.feature_module.zero_grad()
            self.model.zero_grad()
            score.backward(retain_graph=True)

        gradients = self.extractor.get_gradients()
        if len(gradients) == 0:
            raise ValueError("no gradient was captured for the targeted layers")

        # Hooks fire in backward order, i.e. the gradient recorded first belongs to
        # the targeted layer applied last, which is the one `features[-1]` holds.
        # of shape (channels, seq_len) or (seq_len, channels)
        grads_val = gradients[0].detach().cpu().numpy()[0]
        target = features[-1].detach().cpu().numpy()[0]

        if self.target_channel_last:
            weights = grads_val.mean(axis=-2)  # average over seq_len -> (channels,)
            cam = target @ weights
        else:
            weights = grads_val.mean(axis=-1)  # average over seq_len -> (channels,)
            cam = weights @ target
        return cam.astype(np.float32)