cfgpp commited on
Commit
566c475
·
verified ·
1 Parent(s): 428ac38

Update model_utils.py

Browse files
Files changed (1) hide show
  1. model_utils.py +41 -0
model_utils.py CHANGED
@@ -54,5 +54,46 @@ def predict(model, img_tensor, device):
54
  return dict(sorted_probs)
55
 
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
 
 
54
  return dict(sorted_probs)
55
 
56
 
57
+ import cv2
58
+ import numpy as np
59
+ import torch
60
+
61
+ def generate_gradcam(model, input_tensor, target_class, device):
62
+ features = []
63
+ gradients = []
64
+
65
+ def forward_hook(module, input, output):
66
+ features.append(output.detach())
67
+
68
+ def backward_hook(module, grad_input, grad_output):
69
+ gradients.append(grad_output[0].detach())
70
+
71
+ last_conv_layer = model.features[-1]
72
+ forward_handle = last_conv_layer.register_forward_hook(forward_hook)
73
+ backward_handle = last_conv_layer.register_backward_hook(backward_hook)
74
+
75
+ model.zero_grad()
76
+ output = model(input_tensor.unsqueeze(0).to(device))
77
+ class_score = output[0][target_class]
78
+ class_score.backward()
79
+
80
+ grads = gradients[0]
81
+ fmap = features[0]
82
+
83
+ weights = grads.mean(dim=[2, 3], keepdim=True)
84
+ cam = (weights * fmap).sum(dim=1).squeeze()
85
+ cam = torch.relu(cam).cpu().numpy()
86
+
87
+ cam = cam - cam.min()
88
+ cam = cam / cam.max()
89
+ cam = cv2.resize(cam, (224, 224))
90
+
91
+ forward_handle.remove()
92
+ backward_handle.remove()
93
+
94
+ return cam
95
+
96
+
97
+
98
 
99