-
Notifications
You must be signed in to change notification settings - Fork 3
/
vanilla_gradient.py
36 lines (28 loc) · 1.36 KB
/
vanilla_gradient.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
import numpy as np
from saliency_mask import SaliencyMask
class VanillaGradient(SaliencyMask):
    """Vanilla-gradient saliency: gradient of one class logit w.r.t. the input.

    ``get_smoothed_mask`` additionally averages processed gradients over
    Gaussian-noised copies of the input (SmoothGrad-style averaging).
    ``self.model`` is presumably stored by ``SaliencyMask.__init__`` — TODO
    confirm against the base class.
    """

    def __init__(self, model):
        super().__init__(model)

    def get_mask(self, image_tensor, target_class=None):
        """Return the gradient of one class logit with respect to the input.

        Args:
            image_tensor: Batch passed directly to ``self.model``. Only the
                first sample's gradient is returned. Presumably laid out as
                (N, C, H, W) — TODO confirm against callers.
            target_class: Index of the logit to differentiate. When ``None``,
                the highest-scoring class of the first sample is used. ``0`` is
                a valid index and is honoured.

        Returns:
            NumPy array holding the first sample's input gradient with its
            leading (channel) axis moved to the end, e.g. (H, W, C).
        """
        # Detach before cloning so the working copy is always a leaf tensor:
        # setting requires_grad on a non-leaf (an input that already requires
        # grad) raises, and a leaf gets ``.grad`` filled without retain_grad().
        image_tensor = image_tensor.detach().clone().requires_grad_(True)
        logits = self.model(image_tensor)
        if target_class is None:
            # Explicit None check: class index 0 is falsy but legitimate.
            target_class = logits[0].argmax().item()
        # One-hot seed for backward(): only the chosen logit of sample 0
        # contributes to the input gradient.
        target = torch.zeros_like(logits)
        target[0, target_class] = 1
        self.model.zero_grad()
        logits.backward(target)
        return np.moveaxis(image_tensor.grad.detach().cpu().numpy()[0], 0, -1)

    def get_smoothed_mask(self, image_tensor, target_class=None, samples=25,
                          std=0.15, process=lambda x: x**2):
        """Average processed gradients over noisy copies of the input.

        Args:
            image_tensor: 4-D input batch (unpacked below as N, C, dim2, dim3);
                presumably (N, C, H, W) — TODO confirm.
            target_class: Forwarded unchanged to ``get_mask`` for every sample.
            samples: Number of noisy copies to average; must be >= 1.
            std: Noise standard deviation as a fraction of the input's value
                range (max - min over the whole tensor).
            process: Applied to each per-sample mask before summation; squares
                the gradient by default.

        Returns:
            NumPy float64 array of shape (dim2, dim3, C), matching the layout
            returned by ``get_mask``.

        Raises:
            ValueError: If ``samples`` is less than 1 (the average would divide
                by zero or a negative count).
        """
        if samples < 1:
            raise ValueError(f"samples must be >= 1, got {samples}")
        # Scale noise to the dynamic range of the input; .item() yields a
        # plain Python number so it can multiply a tensor directly.
        noise_std = std * (image_tensor.max() - image_tensor.min()).item()
        _, channels, dim2, dim3 = image_tensor.size()
        grad_sum = np.zeros((dim2, dim3, channels))
        for _ in range(samples):
            # Draw directly on the input's device (default float dtype, as
            # before) instead of creating on CPU and copying over.
            noise = torch.randn(image_tensor.shape, device=image_tensor.device) * noise_std
            grad_sum += process(self.get_mask(image_tensor + noise, target_class))
        return grad_sum / samples

    @staticmethod
    def apply_region(mask, region):
        """Weight a channels-last mask by a region map, broadcast over channels.

        Args:
            mask: Array whose last axis is channels, e.g. a ``get_mask`` result.
            region: Array shaped like ``mask`` without its trailing channel
                axis; a new trailing axis is added so it broadcasts.

        Returns:
            ``mask * region[..., np.newaxis]``.
        """
        return mask * region[..., np.newaxis]