Spaces:
Runtime error
Runtime error
Add GRAD-CAM++ and LayerCAM visualizations
Browse files
app.py
CHANGED
|
@@ -22,8 +22,8 @@ from transformers import (
|
|
| 22 |
AutoImageProcessor
|
| 23 |
)
|
| 24 |
|
| 25 |
-
# GRAD-CAM
|
| 26 |
-
from pytorch_grad_cam import GradCAM
|
| 27 |
from pytorch_grad_cam.utils.image import show_cam_on_image
|
| 28 |
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
|
| 29 |
|
|
@@ -113,13 +113,13 @@ def get_target_layers(model):
|
|
| 113 |
return [model.convnextv2.encoder.stages[-1].layers[-1]]
|
| 114 |
|
| 115 |
|
| 116 |
-
def
|
| 117 |
pixel_values: torch.Tensor,
|
| 118 |
original_image: np.ndarray,
|
| 119 |
target_class: Optional[int] = None
|
| 120 |
-
) -> Tuple[np.ndarray, int, float]:
|
| 121 |
"""
|
| 122 |
-
Apply GRAD-CAM to visualize model attention.
|
| 123 |
|
| 124 |
Args:
|
| 125 |
pixel_values: Preprocessed image tensor
|
|
@@ -127,7 +127,7 @@ def apply_gradcam(
|
|
| 127 |
target_class: Target class index (None for predicted class)
|
| 128 |
|
| 129 |
Returns:
|
| 130 |
-
Tuple of (
|
| 131 |
"""
|
| 132 |
# Wrap the model
|
| 133 |
wrapped_model = ConvNeXtGradCAMWrapper(model)
|
|
@@ -135,8 +135,10 @@ def apply_gradcam(
|
|
| 135 |
# Get target layers
|
| 136 |
target_layers = get_target_layers(model)
|
| 137 |
|
| 138 |
-
# Initialize
|
| 139 |
-
|
|
|
|
|
|
|
| 140 |
|
| 141 |
# Get prediction
|
| 142 |
model.eval()
|
|
@@ -150,26 +152,41 @@ def apply_gradcam(
|
|
| 150 |
if target_class is None:
|
| 151 |
target_class = predicted_class
|
| 152 |
|
| 153 |
-
# Create target for
|
| 154 |
targets = [ClassifierOutputTarget(target_class)]
|
| 155 |
|
| 156 |
-
# Generate
|
| 157 |
-
|
| 158 |
-
|
|
|
|
| 159 |
|
| 160 |
# Resize original image to match CAM dimensions
|
| 161 |
-
cam_h, cam_w =
|
| 162 |
rgb_image_for_overlay = cv2.resize(original_image, (cam_w, cam_h)).astype(np.float32) / 255.0
|
| 163 |
|
| 164 |
-
# Create
|
| 165 |
-
|
| 166 |
rgb_image_for_overlay,
|
| 167 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
use_rgb=True,
|
| 169 |
colormap=cv2.COLORMAP_JET
|
| 170 |
)
|
| 171 |
|
| 172 |
-
return
|
| 173 |
|
| 174 |
|
| 175 |
# ========== GRADIO INTERFACE FUNCTIONS ==========
|
|
@@ -212,16 +229,16 @@ def predict_basic(image):
|
|
| 212 |
|
| 213 |
def predict_with_explainability(image):
|
| 214 |
"""
|
| 215 |
-
Prediction with
|
| 216 |
|
| 217 |
Args:
|
| 218 |
image: PIL Image or numpy array
|
| 219 |
|
| 220 |
Returns:
|
| 221 |
-
Tuple of (probabilities_dict, gradcam_image, info_text)
|
| 222 |
"""
|
| 223 |
if image is None:
|
| 224 |
-
return None, None, "Please upload an image."
|
| 225 |
|
| 226 |
try:
|
| 227 |
# Convert to PIL Image if needed
|
|
@@ -239,8 +256,10 @@ def predict_with_explainability(image):
|
|
| 239 |
probabilities = F.softmax(logits, dim=-1)[0]
|
| 240 |
predicted_class = logits.argmax(-1).item()
|
| 241 |
|
| 242 |
-
# Apply
|
| 243 |
-
|
|
|
|
|
|
|
| 244 |
|
| 245 |
# Format probabilities for Gradio
|
| 246 |
probs_dict = {DISPLAY_NAMES[i]: float(probabilities[i]) for i in range(len(DISPLAY_NAMES))}
|
|
@@ -248,13 +267,13 @@ def predict_with_explainability(image):
|
|
| 248 |
# Create info text
|
| 249 |
info_text = f"**Predicted Class:** {DISPLAY_NAMES[predicted_class]}\n\n"
|
| 250 |
info_text += f"**Confidence:** {confidence*100:.2f}%\n\n"
|
| 251 |
-
info_text += "The
|
| 252 |
|
| 253 |
-
return probs_dict,
|
| 254 |
|
| 255 |
except Exception as e:
|
| 256 |
print(f"Error in prediction with explainability: {e}")
|
| 257 |
-
return None, None, f"Error: {str(e)}"
|
| 258 |
|
| 259 |
|
| 260 |
# ========== GRADIO INTERFACE ==========
|
|
@@ -306,8 +325,8 @@ with gr.Blocks(css=custom_css, title="Project Phoenix - Cervical Cancer Cell Cla
|
|
| 306 |
)
|
| 307 |
|
| 308 |
# Tab 2: Prediction with Explainability
|
| 309 |
-
with gr.TabItem("🔍 Prediction + Explainability (
|
| 310 |
-
gr.Markdown("Upload an image to classify and visualize model attention using GRAD-CAM.")
|
| 311 |
|
| 312 |
with gr.Row():
|
| 313 |
with gr.Column():
|
|
@@ -316,13 +335,16 @@ with gr.Blocks(css=custom_css, title="Project Phoenix - Cervical Cancer Cell Cla
|
|
| 316 |
|
| 317 |
with gr.Column():
|
| 318 |
output_label_explain = gr.Label(label="Classification Results", num_top_classes=5)
|
| 319 |
-
|
|
|
|
|
|
|
|
|
|
| 320 |
output_info = gr.Markdown(label="Analysis")
|
| 321 |
|
| 322 |
predict_btn_explain.click(
|
| 323 |
fn=predict_with_explainability,
|
| 324 |
inputs=input_image_explain,
|
| 325 |
-
outputs=[output_label_explain, output_gradcam, output_info],
|
| 326 |
api_name="predict_with_explainability",
|
| 327 |
queue=False
|
| 328 |
)
|
|
|
|
| 22 |
AutoImageProcessor
|
| 23 |
)
|
| 24 |
|
| 25 |
+
# GRAD-CAM variants
|
| 26 |
+
from pytorch_grad_cam import GradCAM, GradCAMPlusPlus, LayerCAM
|
| 27 |
from pytorch_grad_cam.utils.image import show_cam_on_image
|
| 28 |
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
|
| 29 |
|
|
|
|
| 113 |
return [model.convnextv2.encoder.stages[-1].layers[-1]]
|
| 114 |
|
| 115 |
|
| 116 |
+
def apply_cam_methods(
|
| 117 |
pixel_values: torch.Tensor,
|
| 118 |
original_image: np.ndarray,
|
| 119 |
target_class: Optional[int] = None
|
| 120 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int, float]:
|
| 121 |
"""
|
| 122 |
+
Apply GRAD-CAM, GRAD-CAM++, and LayerCAM to visualize model attention.
|
| 123 |
|
| 124 |
Args:
|
| 125 |
pixel_values: Preprocessed image tensor
|
|
|
|
| 127 |
target_class: Target class index (None for predicted class)
|
| 128 |
|
| 129 |
Returns:
|
| 130 |
+
Tuple of (gradcam_viz, gradcam_pp_viz, layercam_viz, predicted_class, confidence)
|
| 131 |
"""
|
| 132 |
# Wrap the model
|
| 133 |
wrapped_model = ConvNeXtGradCAMWrapper(model)
|
|
|
|
| 135 |
# Get target layers
|
| 136 |
target_layers = get_target_layers(model)
|
| 137 |
|
| 138 |
+
# Initialize all CAM methods
|
| 139 |
+
gradcam = GradCAM(model=wrapped_model, target_layers=target_layers)
|
| 140 |
+
gradcam_pp = GradCAMPlusPlus(model=wrapped_model, target_layers=target_layers)
|
| 141 |
+
layercam = LayerCAM(model=wrapped_model, target_layers=target_layers)
|
| 142 |
|
| 143 |
# Get prediction
|
| 144 |
model.eval()
|
|
|
|
| 152 |
if target_class is None:
|
| 153 |
target_class = predicted_class
|
| 154 |
|
| 155 |
+
# Create target for CAM methods
|
| 156 |
targets = [ClassifierOutputTarget(target_class)]
|
| 157 |
|
| 158 |
+
# Generate all CAM visualizations
|
| 159 |
+
grayscale_gradcam = gradcam(input_tensor=pixel_values, targets=targets)[0, :]
|
| 160 |
+
grayscale_gradcam_pp = gradcam_pp(input_tensor=pixel_values, targets=targets)[0, :]
|
| 161 |
+
grayscale_layercam = layercam(input_tensor=pixel_values, targets=targets)[0, :]
|
| 162 |
|
| 163 |
# Resize original image to match CAM dimensions
|
| 164 |
+
cam_h, cam_w = grayscale_gradcam.shape
|
| 165 |
rgb_image_for_overlay = cv2.resize(original_image, (cam_w, cam_h)).astype(np.float32) / 255.0
|
| 166 |
|
| 167 |
+
# Create visualizations for all methods
|
| 168 |
+
viz_gradcam = show_cam_on_image(
|
| 169 |
rgb_image_for_overlay,
|
| 170 |
+
grayscale_gradcam,
|
| 171 |
+
use_rgb=True,
|
| 172 |
+
colormap=cv2.COLORMAP_JET
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
viz_gradcam_pp = show_cam_on_image(
|
| 176 |
+
rgb_image_for_overlay,
|
| 177 |
+
grayscale_gradcam_pp,
|
| 178 |
+
use_rgb=True,
|
| 179 |
+
colormap=cv2.COLORMAP_JET
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
viz_layercam = show_cam_on_image(
|
| 183 |
+
rgb_image_for_overlay,
|
| 184 |
+
grayscale_layercam,
|
| 185 |
use_rgb=True,
|
| 186 |
colormap=cv2.COLORMAP_JET
|
| 187 |
)
|
| 188 |
|
| 189 |
+
return viz_gradcam, viz_gradcam_pp, viz_layercam, predicted_class, float(probabilities[predicted_class].item())
|
| 190 |
|
| 191 |
|
| 192 |
# ========== GRADIO INTERFACE FUNCTIONS ==========
|
|
|
|
| 229 |
|
| 230 |
def predict_with_explainability(image):
|
| 231 |
"""
|
| 232 |
+
Prediction with multiple CAM explainability methods.
|
| 233 |
|
| 234 |
Args:
|
| 235 |
image: PIL Image or numpy array
|
| 236 |
|
| 237 |
Returns:
|
| 238 |
+
Tuple of (probabilities_dict, gradcam_image, gradcam_pp_image, layercam_image, info_text)
|
| 239 |
"""
|
| 240 |
if image is None:
|
| 241 |
+
return None, None, None, None, "Please upload an image."
|
| 242 |
|
| 243 |
try:
|
| 244 |
# Convert to PIL Image if needed
|
|
|
|
| 256 |
probabilities = F.softmax(logits, dim=-1)[0]
|
| 257 |
predicted_class = logits.argmax(-1).item()
|
| 258 |
|
| 259 |
+
# Apply all CAM methods
|
| 260 |
+
viz_gradcam, viz_gradcam_pp, viz_layercam, pred_class, confidence = apply_cam_methods(
|
| 261 |
+
pixel_values, original_image
|
| 262 |
+
)
|
| 263 |
|
| 264 |
# Format probabilities for Gradio
|
| 265 |
probs_dict = {DISPLAY_NAMES[i]: float(probabilities[i]) for i in range(len(DISPLAY_NAMES))}
|
|
|
|
| 267 |
# Create info text
|
| 268 |
info_text = f"**Predicted Class:** {DISPLAY_NAMES[predicted_class]}\n\n"
|
| 269 |
info_text += f"**Confidence:** {confidence*100:.2f}%\n\n"
|
| 270 |
+
info_text += "The heatmaps show regions the model focused on for classification using different visualization methods."
|
| 271 |
|
| 272 |
+
return probs_dict, viz_gradcam, viz_gradcam_pp, viz_layercam, info_text
|
| 273 |
|
| 274 |
except Exception as e:
|
| 275 |
print(f"Error in prediction with explainability: {e}")
|
| 276 |
+
return None, None, None, None, f"Error: {str(e)}"
|
| 277 |
|
| 278 |
|
| 279 |
# ========== GRADIO INTERFACE ==========
|
|
|
|
| 325 |
)
|
| 326 |
|
| 327 |
# Tab 2: Prediction with Explainability
|
| 328 |
+
with gr.TabItem("🔍 Prediction + Explainability (CAM Methods)"):
|
| 329 |
+
gr.Markdown("Upload an image to classify and visualize model attention using GRAD-CAM, GRAD-CAM++, and LayerCAM.")
|
| 330 |
|
| 331 |
with gr.Row():
|
| 332 |
with gr.Column():
|
|
|
|
| 335 |
|
| 336 |
with gr.Column():
|
| 337 |
output_label_explain = gr.Label(label="Classification Results", num_top_classes=5)
|
| 338 |
+
with gr.Row():
|
| 339 |
+
output_gradcam = gr.Image(label="GRAD-CAM")
|
| 340 |
+
output_gradcam_pp = gr.Image(label="GRAD-CAM++")
|
| 341 |
+
output_layercam = gr.Image(label="LayerCAM")
|
| 342 |
output_info = gr.Markdown(label="Analysis")
|
| 343 |
|
| 344 |
predict_btn_explain.click(
|
| 345 |
fn=predict_with_explainability,
|
| 346 |
inputs=input_image_explain,
|
| 347 |
+
outputs=[output_label_explain, output_gradcam, output_gradcam_pp, output_layercam, output_info],
|
| 348 |
api_name="predict_with_explainability",
|
| 349 |
queue=False
|
| 350 |
)
|