Meet2304 committed on
Commit
4c590a1
·
1 Parent(s): 9466b22

Add GRAD-CAM++ and LayerCAM visualizations

Browse files
Files changed (1) hide show
  1. app.py +51 -29
app.py CHANGED
@@ -22,8 +22,8 @@ from transformers import (
22
  AutoImageProcessor
23
  )
24
 
25
- # GRAD-CAM
26
- from pytorch_grad_cam import GradCAM
27
  from pytorch_grad_cam.utils.image import show_cam_on_image
28
  from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
29
 
@@ -113,13 +113,13 @@ def get_target_layers(model):
113
  return [model.convnextv2.encoder.stages[-1].layers[-1]]
114
 
115
 
116
- def apply_gradcam(
117
  pixel_values: torch.Tensor,
118
  original_image: np.ndarray,
119
  target_class: Optional[int] = None
120
- ) -> Tuple[np.ndarray, int, float]:
121
  """
122
- Apply GRAD-CAM to visualize model attention.
123
 
124
  Args:
125
  pixel_values: Preprocessed image tensor
@@ -127,7 +127,7 @@ def apply_gradcam(
127
  target_class: Target class index (None for predicted class)
128
 
129
  Returns:
130
- Tuple of (visualization, predicted_class, confidence)
131
  """
132
  # Wrap the model
133
  wrapped_model = ConvNeXtGradCAMWrapper(model)
@@ -135,8 +135,10 @@ def apply_gradcam(
135
  # Get target layers
136
  target_layers = get_target_layers(model)
137
 
138
- # Initialize GRAD-CAM
139
- cam = GradCAM(model=wrapped_model, target_layers=target_layers)
 
 
140
 
141
  # Get prediction
142
  model.eval()
@@ -150,26 +152,41 @@ def apply_gradcam(
150
  if target_class is None:
151
  target_class = predicted_class
152
 
153
- # Create target for GRAD-CAM
154
  targets = [ClassifierOutputTarget(target_class)]
155
 
156
- # Generate GRAD-CAM
157
- grayscale_cam = cam(input_tensor=pixel_values, targets=targets)
158
- grayscale_cam = grayscale_cam[0, :]
 
159
 
160
  # Resize original image to match CAM dimensions
161
- cam_h, cam_w = grayscale_cam.shape
162
  rgb_image_for_overlay = cv2.resize(original_image, (cam_w, cam_h)).astype(np.float32) / 255.0
163
 
164
- # Create visualization
165
- visualization = show_cam_on_image(
166
  rgb_image_for_overlay,
167
- grayscale_cam,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  use_rgb=True,
169
  colormap=cv2.COLORMAP_JET
170
  )
171
 
172
- return visualization, predicted_class, float(probabilities[predicted_class].item())
173
 
174
 
175
  # ========== GRADIO INTERFACE FUNCTIONS ==========
@@ -212,16 +229,16 @@ def predict_basic(image):
212
 
213
  def predict_with_explainability(image):
214
  """
215
- Prediction with GRAD-CAM explainability.
216
 
217
  Args:
218
  image: PIL Image or numpy array
219
 
220
  Returns:
221
- Tuple of (probabilities_dict, gradcam_image, info_text)
222
  """
223
  if image is None:
224
- return None, None, "Please upload an image."
225
 
226
  try:
227
  # Convert to PIL Image if needed
@@ -239,8 +256,10 @@ def predict_with_explainability(image):
239
  probabilities = F.softmax(logits, dim=-1)[0]
240
  predicted_class = logits.argmax(-1).item()
241
 
242
- # Apply GRAD-CAM
243
- visualization, pred_class, confidence = apply_gradcam(pixel_values, original_image)
 
 
244
 
245
  # Format probabilities for Gradio
246
  probs_dict = {DISPLAY_NAMES[i]: float(probabilities[i]) for i in range(len(DISPLAY_NAMES))}
@@ -248,13 +267,13 @@ def predict_with_explainability(image):
248
  # Create info text
249
  info_text = f"**Predicted Class:** {DISPLAY_NAMES[predicted_class]}\n\n"
250
  info_text += f"**Confidence:** {confidence*100:.2f}%\n\n"
251
- info_text += "The heatmap shows regions the model focused on for classification."
252
 
253
- return probs_dict, visualization, info_text
254
 
255
  except Exception as e:
256
  print(f"Error in prediction with explainability: {e}")
257
- return None, None, f"Error: {str(e)}"
258
 
259
 
260
  # ========== GRADIO INTERFACE ==========
@@ -306,8 +325,8 @@ with gr.Blocks(css=custom_css, title="Project Phoenix - Cervical Cancer Cell Cla
306
  )
307
 
308
  # Tab 2: Prediction with Explainability
309
- with gr.TabItem("🔍 Prediction + Explainability (GRAD-CAM)"):
310
- gr.Markdown("Upload an image to classify and visualize model attention using GRAD-CAM.")
311
 
312
  with gr.Row():
313
  with gr.Column():
@@ -316,13 +335,16 @@ with gr.Blocks(css=custom_css, title="Project Phoenix - Cervical Cancer Cell Cla
316
 
317
  with gr.Column():
318
  output_label_explain = gr.Label(label="Classification Results", num_top_classes=5)
319
- output_gradcam = gr.Image(label="GRAD-CAM Heatmap")
 
 
 
320
  output_info = gr.Markdown(label="Analysis")
321
 
322
  predict_btn_explain.click(
323
  fn=predict_with_explainability,
324
  inputs=input_image_explain,
325
- outputs=[output_label_explain, output_gradcam, output_info],
326
  api_name="predict_with_explainability",
327
  queue=False
328
  )
 
22
  AutoImageProcessor
23
  )
24
 
25
+ # GRAD-CAM variants
26
+ from pytorch_grad_cam import GradCAM, GradCAMPlusPlus, LayerCAM
27
  from pytorch_grad_cam.utils.image import show_cam_on_image
28
  from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
29
 
 
113
  return [model.convnextv2.encoder.stages[-1].layers[-1]]
114
 
115
 
116
+ def apply_cam_methods(
117
  pixel_values: torch.Tensor,
118
  original_image: np.ndarray,
119
  target_class: Optional[int] = None
120
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int, float]:
121
  """
122
+ Apply GRAD-CAM, GRAD-CAM++, and LayerCAM to visualize model attention.
123
 
124
  Args:
125
  pixel_values: Preprocessed image tensor
 
127
  target_class: Target class index (None for predicted class)
128
 
129
  Returns:
130
+ Tuple of (gradcam_viz, gradcam_pp_viz, layercam_viz, predicted_class, confidence)
131
  """
132
  # Wrap the model
133
  wrapped_model = ConvNeXtGradCAMWrapper(model)
 
135
  # Get target layers
136
  target_layers = get_target_layers(model)
137
 
138
+ # Initialize all CAM methods
139
+ gradcam = GradCAM(model=wrapped_model, target_layers=target_layers)
140
+ gradcam_pp = GradCAMPlusPlus(model=wrapped_model, target_layers=target_layers)
141
+ layercam = LayerCAM(model=wrapped_model, target_layers=target_layers)
142
 
143
  # Get prediction
144
  model.eval()
 
152
  if target_class is None:
153
  target_class = predicted_class
154
 
155
+ # Create target for CAM methods
156
  targets = [ClassifierOutputTarget(target_class)]
157
 
158
+ # Generate all CAM visualizations
159
+ grayscale_gradcam = gradcam(input_tensor=pixel_values, targets=targets)[0, :]
160
+ grayscale_gradcam_pp = gradcam_pp(input_tensor=pixel_values, targets=targets)[0, :]
161
+ grayscale_layercam = layercam(input_tensor=pixel_values, targets=targets)[0, :]
162
 
163
  # Resize original image to match CAM dimensions
164
+ cam_h, cam_w = grayscale_gradcam.shape
165
  rgb_image_for_overlay = cv2.resize(original_image, (cam_w, cam_h)).astype(np.float32) / 255.0
166
 
167
+ # Create visualizations for all methods
168
+ viz_gradcam = show_cam_on_image(
169
  rgb_image_for_overlay,
170
+ grayscale_gradcam,
171
+ use_rgb=True,
172
+ colormap=cv2.COLORMAP_JET
173
+ )
174
+
175
+ viz_gradcam_pp = show_cam_on_image(
176
+ rgb_image_for_overlay,
177
+ grayscale_gradcam_pp,
178
+ use_rgb=True,
179
+ colormap=cv2.COLORMAP_JET
180
+ )
181
+
182
+ viz_layercam = show_cam_on_image(
183
+ rgb_image_for_overlay,
184
+ grayscale_layercam,
185
  use_rgb=True,
186
  colormap=cv2.COLORMAP_JET
187
  )
188
 
189
+ return viz_gradcam, viz_gradcam_pp, viz_layercam, predicted_class, float(probabilities[predicted_class].item())
190
 
191
 
192
  # ========== GRADIO INTERFACE FUNCTIONS ==========
 
229
 
230
  def predict_with_explainability(image):
231
  """
232
+ Prediction with multiple CAM explainability methods.
233
 
234
  Args:
235
  image: PIL Image or numpy array
236
 
237
  Returns:
238
+ Tuple of (probabilities_dict, gradcam_image, gradcam_pp_image, layercam_image, info_text)
239
  """
240
  if image is None:
241
+ return None, None, None, None, "Please upload an image."
242
 
243
  try:
244
  # Convert to PIL Image if needed
 
256
  probabilities = F.softmax(logits, dim=-1)[0]
257
  predicted_class = logits.argmax(-1).item()
258
 
259
+ # Apply all CAM methods
260
+ viz_gradcam, viz_gradcam_pp, viz_layercam, pred_class, confidence = apply_cam_methods(
261
+ pixel_values, original_image
262
+ )
263
 
264
  # Format probabilities for Gradio
265
  probs_dict = {DISPLAY_NAMES[i]: float(probabilities[i]) for i in range(len(DISPLAY_NAMES))}
 
267
  # Create info text
268
  info_text = f"**Predicted Class:** {DISPLAY_NAMES[predicted_class]}\n\n"
269
  info_text += f"**Confidence:** {confidence*100:.2f}%\n\n"
270
+ info_text += "The heatmaps show regions the model focused on for classification using different visualization methods."
271
 
272
+ return probs_dict, viz_gradcam, viz_gradcam_pp, viz_layercam, info_text
273
 
274
  except Exception as e:
275
  print(f"Error in prediction with explainability: {e}")
276
+ return None, None, None, None, f"Error: {str(e)}"
277
 
278
 
279
  # ========== GRADIO INTERFACE ==========
 
325
  )
326
 
327
  # Tab 2: Prediction with Explainability
328
+ with gr.TabItem("🔍 Prediction + Explainability (CAM Methods)"):
329
+ gr.Markdown("Upload an image to classify and visualize model attention using GRAD-CAM, GRAD-CAM++, and LayerCAM.")
330
 
331
  with gr.Row():
332
  with gr.Column():
 
335
 
336
  with gr.Column():
337
  output_label_explain = gr.Label(label="Classification Results", num_top_classes=5)
338
+ with gr.Row():
339
+ output_gradcam = gr.Image(label="GRAD-CAM")
340
+ output_gradcam_pp = gr.Image(label="GRAD-CAM++")
341
+ output_layercam = gr.Image(label="LayerCAM")
342
  output_info = gr.Markdown(label="Analysis")
343
 
344
  predict_btn_explain.click(
345
  fn=predict_with_explainability,
346
  inputs=input_image_explain,
347
+ outputs=[output_label_explain, output_gradcam, output_gradcam_pp, output_layercam, output_info],
348
  api_name="predict_with_explainability",
349
  queue=False
350
  )