# Source: uploaded by deep123shah456 (commit 4a3ae84, "Upload 3 files", verified).
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
from sklearn.metrics import (
roc_curve, auc,
confusion_matrix,
classification_report,
f1_score,
precision_score,
recall_score,
)
from sklearn.preprocessing import label_binarize
from sklearn.calibration import calibration_curve # [GSOC UPGRADE 2] For reliability diagrams
# ─────────────────────────────────────────────────────────────────────────────
# CONSTANTS
# ─────────────────────────────────────────────────────────────────────────────
CLASSES = ['No Sub', 'CDM', 'Vortex']
CLASS_COLORS = ['#2196F3', '#F44336', '#4CAF50']  # Blue, Red, Green

# Plot style used by every figure in this module via plt.style.context().
# matplotlib 3.6 renamed the bundled seaborn styles ('seaborn-whitegrid' ->
# 'seaborn-v0_8-whitegrid') and deprecated the old names, so:
#   1. try the new name (this also applies it globally, as before);
#   2. otherwise use the old name only if this matplotlib still ships it;
#   3. otherwise use matplotlib's built-in 'default' style, so that
#      plt.style.context(PLOT_STYLE) can never raise OSError.
try:
    plt.style.use('seaborn-v0_8-whitegrid')
    PLOT_STYLE = 'seaborn-v0_8-whitegrid'
except OSError:
    PLOT_STYLE = (
        'seaborn-whitegrid'
        if 'seaborn-whitegrid' in plt.style.available
        else 'default'
    )
# ─────────────────────────────────────────────────────────────────────────────
# INTERNAL HELPER
# ─────────────────────────────────────────────────────────────────────────────
def _ensure_dir(path):
"""Creates the parent directory of a file path if it does not exist."""
if path:
parent = os.path.dirname(path)
if parent:
os.makedirs(parent, exist_ok=True)
# ─────────────────────────────────────────────────────────────────────────────
# 1. ROC-AUC (PRIMARY ML4SCI METRIC) + FPR@TPR=0.90
# ─────────────────────────────────────────────────────────────────────────────
def _fpr_at_tpr(fpr, tpr, target_tpr=0.90):
    """Return the lowest FPR at which a ROC curve reaches ``target_tpr``.

    The ROC curve is the piecewise-linear path through the ``(fpr, tpr)``
    points returned by ``sklearn.metrics.roc_curve`` (both non-decreasing).
    ``np.interp(target_tpr, tpr, fpr)`` assumes strictly increasing ``tpr``;
    when several points share the target TPR it returns the FPR of the
    *last* one, i.e. the worst operating point. Here we take the first point
    with TPR >= target and, if it overshoots, interpolate linearly on the
    segment that crosses the target. For curves without tied TPR values the
    result equals ``np.interp``.

    Returns NaN if the curve is empty or contains NaN TPR values (sklearn
    produces NaN TPR for a class with no positive samples).
    """
    fpr = np.asarray(fpr, dtype=float)
    tpr = np.asarray(tpr, dtype=float)
    if tpr.size == 0 or np.isnan(tpr).any():
        return float('nan')
    k = int(np.searchsorted(tpr, target_tpr, side='left'))
    # Target outside the curve's TPR range: clamp like np.interp does.
    if k == 0:
        return float(fpr[0])
    if k == tpr.size:
        return float(fpr[-1])
    if tpr[k] == target_tpr:
        return float(fpr[k])
    # tpr[k - 1] < target_tpr < tpr[k]: interpolate along that segment.
    t0, t1 = tpr[k - 1], tpr[k]
    f0, f1 = fpr[k - 1], fpr[k]
    return float(f0 + (target_tpr - t0) * (f1 - f0) / (t1 - t0))


def plot_multiclass_roc_auc(
    all_labels,
    all_probs,
    classes=CLASSES,
    save_path=None,
    title='Receiver Operating Characteristic (ROC)',
    model_name='Model',
):
    """
    Generate, print and optionally save a one-vs-rest multi-class ROC plot.

    [GSOC UPGRADE 1] - Physics-Informed Metrics:
    In dark matter searches, physicists care deeply about operational
    thresholds. Besides AUC we report the False Positive Rate (FPR) at a 90%
    True Positive Rate (TPR). Minimising this FPR while keeping a 90%
    detection efficiency is a gold-standard physics requirement.

    Args:
        all_labels : array-like of shape (N,) — integer labels 0..K-1.
        all_probs  : array-like of shape (N, K) — per-class scores; column i
                     belongs to classes[i].
        classes    : list[str] — class names (K = len(classes)). Two classes
                     are supported as well.
        save_path  : str or None — if given, the figure is saved there
                     (dpi=300); parent directories are created.
        title      : str — plot title (model_name is appended on a new line).
        model_name : str — used in the printed report and the title.

    Returns:
        dict with
            'per_class'    : {class_index: AUC},
            'macro'        : unweighted mean of the per-class AUCs,
            'micro'        : AUC of the pooled (micro-averaged) curve,
            'fpr_90_macro' : mean over classes of FPR @ 90% TPR.
    """
    all_labels = np.asarray(all_labels)
    all_probs = np.asarray(all_probs)
    n_classes = len(classes)

    y_bin = label_binarize(all_labels, classes=list(range(n_classes)))
    if y_bin.shape[1] == 1:
        # label_binarize collapses the two-class case to one column (the
        # positive class); expand it so column i always belongs to class i.
        y_bin = np.hstack([1 - y_bin, y_bin])

    fpr, tpr, roc_auc, fpr_at_90 = {}, {}, {}, {}

    # ── Per-class one-vs-rest curves ──────────────────────────────────────
    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(y_bin[:, i], all_probs[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])
        fpr_at_90[i] = _fpr_at_tpr(fpr[i], tpr[i], 0.90)

    # ── Micro-average: pool every (sample, class) decision into one curve ─
    fpr['micro'], tpr['micro'], _ = roc_curve(y_bin.ravel(), all_probs.ravel())
    roc_auc['micro'] = auc(fpr['micro'], tpr['micro'])
    fpr_at_90['micro'] = _fpr_at_tpr(fpr['micro'], tpr['micro'], 0.90)

    # ── Macro-average: unweighted mean over classes (the headline number) ─
    roc_auc['macro'] = np.mean([roc_auc[i] for i in range(n_classes)])
    fpr_at_90['macro'] = np.mean([fpr_at_90[i] for i in range(n_classes)])

    # ── Print summary ─────────────────────────────────────────────────────
    print(f"\n{'=' * 65}")
    print(f" ROC-AUC & PHYSICS REPORT — {model_name}")
    print(f"{'=' * 65}")
    for i, cls in enumerate(classes):
        print(f" {cls:<12}: AUC = {roc_auc[i]:.4f} | FPR @ 90% TPR = {fpr_at_90[i]:.4f}")
    print(f" {'-' * 61}")
    print(f" {'Micro-Avg':<12}: AUC = {roc_auc['micro']:.4f} | FPR @ 90% TPR = {fpr_at_90['micro']:.4f}")
    print(f" {'Macro-Avg':<12}: AUC = {roc_auc['macro']:.4f} | FPR @ 90% TPR = {fpr_at_90['macro']:.4f}")
    print(f"{'=' * 65}\n")

    # ── Plot ──────────────────────────────────────────────────────────────
    with plt.style.context(PLOT_STYLE):
        fig, ax = plt.subplots(figsize=(10, 8))

        # Per-class curves; colors cycle so no class is dropped when there
        # are more classes than entries in CLASS_COLORS.
        for i, cls in enumerate(classes):
            ax.plot(
                fpr[i], tpr[i],
                color=CLASS_COLORS[i % len(CLASS_COLORS)], lw=2.5,
                label=f'{cls} (AUC = {roc_auc[i]:.3f})',
            )

        # Micro-average
        ax.plot(
            fpr['micro'], tpr['micro'],
            color='darkorange', lw=2, linestyle='--',
            label=f'Micro-avg (AUC = {roc_auc["micro"]:.3f})',
        )

        # Random classifier baseline
        ax.plot(
            [0, 1], [0, 1],
            'k--', lw=1.2, label='Random Classifier (AUC = 0.500)',
        )

        # Macro AUC annotation box
        ax.text(
            0.62, 0.12,
            f'Macro-Avg AUC = {roc_auc["macro"]:.4f}\nFPR @ 90% TPR = {fpr_at_90["macro"]:.4f}',
            transform=ax.transAxes,
            fontsize=11, fontweight='bold',
            bbox=dict(boxstyle='round,pad=0.4', facecolor='lightyellow',
                      edgecolor='gray', alpha=0.9),
        )

        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('False Positive Rate', fontsize=13)
        ax.set_ylabel('True Positive Rate', fontsize=13)
        ax.set_title(f'{title}\n{model_name}', fontsize=14, fontweight='bold')
        ax.legend(loc='lower right', fontsize=11)
        ax.grid(alpha=0.3)
        fig.tight_layout()

        if save_path:
            _ensure_dir(save_path)
            fig.savefig(save_path, bbox_inches='tight', dpi=300)
            print(f"📊 ROC curve saved → {save_path}")
        plt.show()
        # Close this figure explicitly (plt.close() only closes the current one).
        plt.close(fig)

    return {
        'per_class': {i: roc_auc[i] for i in range(n_classes)},
        'macro': roc_auc['macro'],
        'micro': roc_auc['micro'],
        'fpr_90_macro': fpr_at_90['macro'],
    }
# ─────────────────────────────────────────────────────────────────────────────
# 2. CONFUSION MATRIX
# ─────────────────────────────────────────────────────────────────────────────
def save_confusion_matrix(
    y_true,
    y_pred,
    classes=CLASSES,
    save_path=None,
    title='Confusion Matrix',
    cmap='Blues',
    normalize=False,
):
    """
    Plot a confusion matrix as an annotated heatmap and optionally save it.

    Rows are true labels and columns are predicted labels, both ordered as
    ``classes`` (integer label i corresponds to classes[i]).

    Args:
        y_true, y_pred : array-like of integer labels 0..K-1.
        classes   : list[str] — class names used for the tick labels.
        save_path : str or None — if given, the figure is saved there
                    (dpi=300); parent directories are created.
        title     : str — plot title (' (Normalised)' is appended if normalize).
        cmap      : str — seaborn/matplotlib colormap name.
        normalize : bool — if True, each row is divided by its total so cells
                    show the fraction of that true class; rows with no
                    samples are shown as 0 instead of NaN.
    """
    # Fixing ``labels`` keeps the matrix K x K even when a class is absent
    # from both y_true and y_pred, so it always lines up with the tick labels.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        cm_display = np.divide(
            cm, row_totals,
            out=np.zeros(cm.shape, dtype=float),
            where=row_totals > 0,
        )
        fmt = '.2f'
        title = title + ' (Normalised)'
    else:
        cm_display = cm
        fmt = 'd'

    with plt.style.context(PLOT_STYLE):
        fig, ax = plt.subplots(figsize=(8, 6))
        sns.heatmap(
            cm_display,
            annot=True, fmt=fmt, cmap=cmap,
            xticklabels=classes, yticklabels=classes,
            linewidths=0.5, linecolor='gray',
            ax=ax,
        )
        ax.set_title(title, fontsize=14, fontweight='bold', pad=14)
        ax.set_ylabel('True Physics Label', fontsize=12)
        ax.set_xlabel('Model Predicted Label', fontsize=12)
        ax.tick_params(axis='x', rotation=0)
        ax.tick_params(axis='y', rotation=0)
        fig.tight_layout()

        if save_path:
            _ensure_dir(save_path)
            fig.savefig(save_path, bbox_inches='tight', dpi=300)
            print(f"📊 Confusion matrix saved → {save_path}")
        plt.show()
        plt.close(fig)
# ─────────────────────────────────────────────────────────────────────────────
# 3. CLASSIFICATION REPORT
# ─────────────────────────────────────────────────────────────────────────────
def generate_classification_report(
    y_true,
    y_pred,
    classes=CLASSES,
    model_name='Model',
):
    """
    Print sklearn's classification report plus headline F1/precision/recall.

    Args:
        y_true, y_pred : array-like of integer labels 0..K-1.
        classes    : list[str] — class names; label i corresponds to classes[i].
        model_name : str — used in the printed header.

    Returns:
        dict with
            'report_str'      : the text report,
            'f1_macro'        : macro F1 over the labels present in the data,
            'f1_weighted'     : support-weighted F1,
            'f1_per_class'    : {class_name: F1} for every entry of ``classes``
                                (0.0 when undefined),
            'precision_macro' : macro precision,
            'recall_macro'    : macro recall.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    all_labels = list(range(len(classes)))

    print(f"\n{'=' * 52}")
    print(f" CLASSIFICATION REPORT — {model_name}")
    print(f"{'=' * 52}")
    # Passing ``labels`` is required: without it sklearn infers the labels
    # from the data and raises ValueError when a class is missing from both
    # y_true and y_pred (label count != len(target_names)).
    report_str = classification_report(
        y_true, y_pred,
        labels=all_labels, target_names=classes,
        digits=4, zero_division=0,
    )
    print(report_str)

    f1_macro = f1_score(y_true, y_pred, average='macro', zero_division=0)
    f1_weighted = f1_score(y_true, y_pred, average='weighted', zero_division=0)
    # Per-class F1 for every class name, even classes absent from the data.
    f1_per_cls = f1_score(y_true, y_pred, average=None, zero_division=0,
                          labels=all_labels)
    f1_per_class_named = {cls: float(f1_per_cls[i]) for i, cls in enumerate(classes)}
    precision_macro = float(precision_score(y_true, y_pred, average='macro', zero_division=0))
    recall_macro = float(recall_score(y_true, y_pred, average='macro', zero_division=0))

    print(f" Macro F1 : {f1_macro:.4f} ← use in result dicts")
    print(f" Weighted F1 : {f1_weighted:.4f}")
    print(f" Macro Precision : {precision_macro:.4f}")
    print(f" Macro Recall : {recall_macro:.4f}")
    print(f"{'=' * 52}\n")

    return {
        'report_str': report_str,
        'f1_macro': float(f1_macro),
        'f1_weighted': float(f1_weighted),
        'f1_per_class': f1_per_class_named,
        'precision_macro': precision_macro,
        'recall_macro': recall_macro,
    }
# ─────────────────────────────────────────────────────────────────────────────
# 4. LEARNING CURVES (per-epoch train vs val tracking)
# ─────────────────────────────────────────────────────────────────────────────
def plot_learning_curves(
    train_losses,
    val_losses,
    train_accs,
    val_accs,
    save_path=None,
    model_name='Model',
):
    """
    Plot train vs. validation loss and accuracy per epoch, side by side.

    Each series is plotted against epochs 1..len(series), so series of
    different lengths no longer raise. The best validation loss epoch gets a
    dashed marker line. The best validation accuracy is marked and annotated.
    Both markers are skipped if the corresponding validation series is empty.

    Args:
        train_losses, val_losses : sequences of per-epoch loss values.
        train_accs, val_accs     : sequences of per-epoch accuracy in percent
                                   (the y-axis is labelled 'Accuracy (%)').
        save_path  : str or None — if given, the figure is saved there (dpi=300).
        model_name : str — used in the figure title.
    """
    def _epochs(series):
        # 1-based epoch numbers matching this series' own length.
        return range(1, len(series) + 1)

    with plt.style.context(PLOT_STYLE):
        fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(14, 5))
        fig.suptitle(f'Learning Curves — {model_name}', fontsize=14, fontweight='bold')

        # ── Loss subplot ─────────────────────────────────────────────────
        ax_loss.plot(_epochs(train_losses), train_losses, 'b-o', markersize=4,
                     linewidth=2, label='Train Loss')
        ax_loss.plot(_epochs(val_losses), val_losses, 'r-o', markersize=4,
                     linewidth=2, label='Val Loss')
        ax_loss.set_xlabel('Epoch', fontsize=12)
        ax_loss.set_ylabel('Cross-Entropy Loss', fontsize=12)
        ax_loss.set_title('Loss over Epochs', fontsize=12)
        ax_loss.grid(alpha=0.4)
        if len(val_losses) > 0:
            best_val_loss_epoch = int(np.argmin(val_losses)) + 1
            ax_loss.axvline(
                x=best_val_loss_epoch, color='red',
                linestyle='--', alpha=0.5,
                label=f'Best Val Loss (epoch {best_val_loss_epoch})',
            )
        # Build the legend once, after every labelled artist exists.
        ax_loss.legend(fontsize=10)

        # ── Accuracy subplot ─────────────────────────────────────────────
        ax_acc.plot(_epochs(train_accs), train_accs, 'b-o', markersize=4,
                    linewidth=2, label='Train Acc')
        ax_acc.plot(_epochs(val_accs), val_accs, 'r-o', markersize=4,
                    linewidth=2, label='Val Acc')
        ax_acc.set_xlabel('Epoch', fontsize=12)
        ax_acc.set_ylabel('Accuracy (%)', fontsize=12)
        ax_acc.set_title('Accuracy over Epochs', fontsize=12)
        ax_acc.legend(fontsize=11)
        ax_acc.grid(alpha=0.4)
        if len(val_accs) > 0:
            best_val_acc_epoch = int(np.argmax(val_accs)) + 1
            best_val_acc_value = max(val_accs)
            ax_acc.axvline(
                x=best_val_acc_epoch, color='red',
                linestyle='--', alpha=0.5,
            )
            # Text offset (-5) assumes accuracy is expressed in percent.
            ax_acc.annotate(
                f'Best: {best_val_acc_value:.1f}%\n(epoch {best_val_acc_epoch})',
                xy=(best_val_acc_epoch, best_val_acc_value),
                xytext=(best_val_acc_epoch + 0.5, best_val_acc_value - 5),
                fontsize=9,
                arrowprops=dict(arrowstyle='->', color='black'),
            )

        fig.tight_layout()
        if save_path:
            _ensure_dir(save_path)
            fig.savefig(save_path, bbox_inches='tight', dpi=300)
            print(f"📊 Learning curves saved → {save_path}")
        plt.show()
        plt.close(fig)
# ─────────────────────────────────────────────────────────────────────────────
# 5. TTA DEGRADATION ANALYSIS
# ─────────────────────────────────────────────────────────────────────────────
def plot_tta_degradation(
    labels_before,
    preds_before,
    labels_after,
    preds_after,
    classes=CLASSES,
    save_path=None,
    model_name='Ensemble',
):
    """
    Compare per-class F1 with and without rotational test-time augmentation.

    Left panel: grouped bars of per-class F1 for standard evaluation vs.
    evaluation after TTA (0°/90°/180°/270°). Right panel: per-class
    ΔF1 = F1_after − F1_before (green for >= 0, red for < 0). A per-class
    summary table is printed afterwards.

    Args:
        labels_before, preds_before : true / predicted integer labels for the
                                      standard evaluation.
        labels_after, preds_after   : true / predicted integer labels after TTA.
        classes    : list[str] — class names; label i corresponds to classes[i].
        save_path  : str or None — if given, the figure is saved there (dpi=300).
        model_name : str — used in the figure title and the printed summary.
    """
    all_labels = list(range(len(classes)))
    # A fixed label list gives every class an F1 entry (0 when undefined),
    # so the before/after arrays always align with ``classes``.
    f1_before = f1_score(labels_before, preds_before, average=None,
                         labels=all_labels, zero_division=0)
    f1_after = f1_score(labels_after, preds_after, average=None,
                        labels=all_labels, zero_division=0)
    delta_f1 = f1_after - f1_before

    x = np.arange(len(classes))
    bar_w = 0.35

    with plt.style.context(PLOT_STYLE):
        fig, (ax_bar, ax_delta) = plt.subplots(1, 2, figsize=(14, 5))
        fig.suptitle(
            f'TTA Rotational Variance Analysis — {model_name}',
            fontsize=14, fontweight='bold',
        )

        # ── Left: grouped F1 bars ────────────────────────────────────────
        bars_before = ax_bar.bar(x - bar_w / 2, f1_before, bar_w,
                                 label='Standard Eval', color='#2196F3', alpha=0.85)
        bars_after = ax_bar.bar(x + bar_w / 2, f1_after, bar_w,
                                label='After TTA (0°/90°/180°/270°)',
                                color='#F44336', alpha=0.85)
        ax_bar.set_xticks(x)
        ax_bar.set_xticklabels(classes, fontsize=12)
        ax_bar.set_ylabel('F1-Score', fontsize=12)
        ax_bar.set_ylim([0, 1.1])
        ax_bar.set_title('F1-Score: Before vs After Rotational TTA', fontsize=12)
        ax_bar.legend(fontsize=11)
        ax_bar.grid(axis='y', alpha=0.4)
        for bar in (*bars_before, *bars_after):
            height = bar.get_height()
            ax_bar.text(bar.get_x() + bar.get_width() / 2, height + 0.01,
                        f'{height:.3f}', ha='center', va='bottom', fontsize=9)

        # ── Right: per-class delta ───────────────────────────────────────
        colors_delta = ['#4CAF50' if d >= 0 else '#F44336' for d in delta_f1]
        ax_delta.bar(classes, delta_f1, color=colors_delta, alpha=0.85, edgecolor='black')
        ax_delta.axhline(y=0, color='black', linewidth=1.2)
        ax_delta.set_ylabel('ΔF1 (After − Before)', fontsize=12)
        ax_delta.set_title('Per-Class F1 Degradation Under Rotation', fontsize=12)
        ax_delta.grid(axis='y', alpha=0.4)
        for i, d in enumerate(delta_f1):
            # Label just outside the bar end: above positive bars, hanging
            # below negative ones (va='top'), so text never overlaps the bar.
            if d >= 0:
                y_text, v_align = d + 0.005, 'bottom'
            else:
                y_text, v_align = d - 0.005, 'top'
            ax_delta.text(i, y_text, f'{d:+.3f}', ha='center', va=v_align,
                          fontsize=10, fontweight='bold')

        worst_idx = int(np.argmin(delta_f1))
        if delta_f1[worst_idx] < 0:
            caption = (
                f'→ "{classes[worst_idx]}" suffers the largest degradation, motivating '
                f'E(2)-Equivariant Networks.'
            )
        else:
            # Don't claim a degradation when every class held or improved.
            caption = '→ No class degrades under rotational TTA.'
        ax_delta.set_xlabel(caption, fontsize=9, style='italic')

        fig.tight_layout()
        if save_path:
            _ensure_dir(save_path)
            fig.savefig(save_path, bbox_inches='tight', dpi=300)
            print(f"📊 TTA degradation plot saved → {save_path}")
        plt.show()
        plt.close(fig)

    rule = '=' * 52
    print(f"\n{rule}")
    print(f" TTA DEGRADATION SUMMARY — {model_name}")
    print(rule)
    print(f" {'Class':<12} {'F1 Before':>10} {'F1 After':>10} {'Delta':>8}")
    print(f" {'-' * 46}")
    for cls, fb, fa, d in zip(classes, f1_before, f1_after, delta_f1):
        print(f" {cls:<12} {fb:>10.4f} {fa:>10.4f} {d:>+8.4f}")
    print(f"{rule}\n")
# ─────────────────────────────────────────────────────────────────────────────
# 6. MODEL COMPARISON SUMMARY TABLE
# ─────────────────────────────────────────────────────────────────────────────
def print_model_comparison_table(results: list[dict]):
    """
    Print a fixed-width comparison table of model results to stdout.

    Every result dict must provide 'model' (str), 'val_acc' (percent),
    'macro_auc', 'cdm_auc' and 'f1_macro'. 'fpr_90_macro' (FPR @ 90% TPR,
    lower is better) is optional. Results that lack it, e.g. from before the
    metric existed, show 'N/A'. They used to show 0.0000, which reads as a
    perfect score.
    """
    rule = '=' * 82
    print(
        f"\n{rule}\n"
        f" {'Model':<26} {'Val Acc':>8} {'MacroAUC':>10} "
        f"{'CDM AUC':>9} {'FPR@90%':>9} {'F1 Macro':>10}\n"
        f" {'-' * 78}"
    )
    for r in results:
        fpr_val = r.get('fpr_90_macro')
        fpr_cell = 'N/A' if fpr_val is None else f"{fpr_val:.4f}"
        print(
            f" {r['model']:<26} "
            f"{r['val_acc']:>7.1f}% "
            f"{r['macro_auc']:>10.4f} "
            f"{r['cdm_auc']:>9.4f} "
            f"{fpr_cell:>9} "
            f"{r['f1_macro']:>10.4f}"
        )
    print(f"{rule}\n")
# ─────────────────────────────────────────────────────────────────────────────
# 7. CALIBRATION CURVES (RELIABILITY DIAGRAM) - [GSOC UPGRADE 2]
# ─────────────────────────────────────────────────────────────────────────────
def plot_calibration_curves(
    y_true,
    y_probs,
    classes=CLASSES,
    save_path=None,
    model_name='Model',
):
    """
    Plot a one-vs-rest reliability diagram (calibration curve) per class.

    For each class, predicted probabilities are grouped into 10 uniform bins.
    The plot shows the mean predicted probability of each bin against the
    observed fraction of positives. A perfectly calibrated model lies on the
    dotted diagonal. Neural networks are often over-confident, and that shows
    up as curves below the diagonal. This checks how trustworthy the
    probabilities are, not just how accurate the argmax is.

    Args:
        y_true     : array-like of shape (N,) — integer labels 0..K-1.
        y_probs    : array-like of shape (N, K) — predicted probabilities;
                     column i belongs to classes[i].
        classes    : list[str] — class names (two classes are supported too).
        save_path  : str or None — if given, the figure is saved there (dpi=300).
        model_name : str — used in the title.
    """
    y_true = np.asarray(y_true)
    y_probs = np.asarray(y_probs)
    n_classes = len(classes)
    y_bin = label_binarize(y_true, classes=list(range(n_classes)))
    if y_bin.shape[1] == 1:
        # Two-class case: label_binarize returns only the positive column.
        y_bin = np.hstack([1 - y_bin, y_bin])

    with plt.style.context(PLOT_STYLE):
        fig, ax = plt.subplots(figsize=(8, 8))

        # Perfectly calibrated reference line
        ax.plot([0, 1], [0, 1], "k:", label="Perfectly Calibrated")

        for i, cls in enumerate(classes):
            # Fraction of positives and mean predicted probability per bin.
            prob_true, prob_pred = calibration_curve(y_bin[:, i], y_probs[:, i], n_bins=10)
            ax.plot(prob_pred, prob_true, "s-",
                    color=CLASS_COLORS[i % len(CLASS_COLORS)], label=f"{cls}")

        ax.set_ylabel("Fraction of positives (Empirical True Probability)", fontsize=12)
        ax.set_xlabel("Mean predicted value (Model Confidence)", fontsize=12)
        ax.set_title(f"Reliability Diagram (Calibration Curves)\n{model_name}",
                     fontsize=14, fontweight='bold')
        ax.legend(loc="lower right", fontsize=11)
        ax.grid(alpha=0.3)
        fig.tight_layout()

        if save_path:
            _ensure_dir(save_path)
            fig.savefig(save_path, bbox_inches='tight', dpi=300)
            print(f"📊 Calibration curves saved → {save_path}")
        plt.show()
        plt.close(fig)
# ─────────────────────────────────────────────────────────────────────────────
# 8. GRADCAM — ResNet models (+ input saliency for EquivariantCNN) [GSOC UPGRADE: NEW]
# ─────────────────────────────────────────────────────────────────────────────
#
# GradCAM computes a class-discriminative localisation map by weighting the
# activation maps of the final convolutional layer by the gradients of the
# class score with respect to those activations.
#
# Reference: Selvaraju et al. (2017) — "Grad-CAM: Visual Explanations from
# Deep Networks via Gradient-based Localization"
#
# DESIGN NOTE — EquivariantCNN edge case:
# Standard GradCAM hooks attach to the output of a nn.Conv2d layer and
# capture a (B, C, H, W) activation tensor. escnn's R2Conv does NOT return
# a plain tensor — it returns a GeometricTensor wrapper. Registering
# a backward hook on R2Conv raises a RuntimeError because GeometricTensor
# gradients are tracked on the `.tensor` attribute, not on the module output
# directly.
#
# Alternative (not what the code does): hook the GROUP POOLING output
# (self.group_pool), which returns a plain tensor after the final group
# symmetry is collapsed.
# Implemented: for EquivariantCNN, compute_gradcam registers no module hooks.
# Instead it computes input-gradient saliency: |d(class score) / d(input
# pixels)| averaged over input channels (Simonyan et al., 2013). It uses
# torch.autograd.grad on a freshly cloned input, so no hook or .grad state
# is reused across calls.
# The function detects EquivariantCNN by class name and routes accordingly.
# ─────────────────────────────────────────────────────────────────────────────
def _resnet_gradcam(model, image_tensor, class_idx):
    """Grad-CAM on ``model.model.layer4[-1]`` for the ResNet wrappers.

    Returns ``(cam, pred_class, pred_probs)``:
        cam        : ReLU'd, un-normalised 2-D torch tensor at layer4 resolution.
        pred_class : int — argmax class of the logits.
        pred_probs : np.ndarray — softmax probabilities.
    ``class_idx=None`` targets ``pred_class``.
    """
    import torch
    import torch.nn.functional as F

    captured = {}

    def _forward_hook(module, inputs, output):
        captured['act'] = output.detach()

    def _backward_hook(module, grad_input, grad_output):
        captured['grad'] = grad_output[0].detach()

    target_module = model.model.layer4[-1]
    fwd_handle = target_module.register_forward_hook(_forward_hook)
    bwd_handle = target_module.register_full_backward_hook(_backward_hook)
    try:
        model.zero_grad()
        inp = image_tensor.detach().requires_grad_(True)
        with torch.enable_grad():
            logits = model(inp)
            pred_class = int(logits.argmax(dim=1).item())
            if class_idx is None:
                class_idx = pred_class
            # NOTE: backward() also leaves gradients in the parameters' .grad.
            logits[0, class_idx].backward()
    finally:
        # Always remove the hooks so repeated calls don't stack them.
        fwd_handle.remove()
        bwd_handle.remove()

    pred_probs = F.softmax(logits.detach(), dim=1).squeeze(0).cpu().numpy()

    act = captured['act'].squeeze(0)    # (C, h, w) layer4 activations
    grad = captured['grad'].squeeze(0)  # (C, h, w) d(score) / d(activations)
    # Grad-CAM channel weights: global-average-pooled gradients.
    weights = grad.mean(dim=(1, 2))     # (C,)
    cam = (weights[:, None, None] * act).sum(dim=0).float()
    # ReLU keeps only features with a positive influence on the target class.
    return F.relu(cam), pred_class, pred_probs


def _input_saliency_cam(model, image_tensor, class_idx):
    """Input-gradient saliency for models that cannot be hooked (escnn).

    Computes |d(class score) / d(input)| averaged over input channels
    (Simonyan et al., 2013). It uses ``torch.autograd.grad`` on a freshly
    cloned input, so nothing is accumulated into ``.grad`` and no graph is
    kept between calls. Returns ``(saliency, pred_class, pred_probs)`` with
    the same conventions as ``_resnet_gradcam``.
    """
    import torch
    import torch.nn.functional as F

    # Prediction pass without building a graph.
    with torch.no_grad():
        logits = model(image_tensor.detach().clone())
    pred_class = int(logits.argmax(dim=1).item())
    pred_probs = F.softmax(logits, dim=1).squeeze(0).cpu().numpy()
    if class_idx is None:
        class_idx = pred_class

    # Separate forward pass with grad on a fresh leaf tensor, used only to
    # differentiate the target score with respect to the input pixels.
    inp = image_tensor.detach().clone().requires_grad_(True)
    with torch.enable_grad():
        score = model(inp)[0, class_idx]
        (grads,) = torch.autograd.grad(score, inp)  # (1, C, H, W)
    return grads.abs().squeeze(0).mean(dim=0), pred_class, pred_probs


def _normalise_cam_and_resize(cam, out_hw):
    """Min-max scale a 2-D CAM tensor to [0, 1] and resize it to ``out_hw``.

    A constant map (max == min) becomes all zeros instead of dividing by
    zero. ``out_hw`` is ``(height, width)``; the result is a numpy array.
    """
    import torch
    import cv2

    cam_min, cam_max = cam.min(), cam.max()
    if cam_max > cam_min:
        cam = (cam - cam_min) / (cam_max - cam_min)
    else:
        cam = torch.zeros_like(cam)
    cam_np = cam.detach().cpu().numpy()
    h_out, w_out = int(out_hw[0]), int(out_hw[1])
    # cv2.resize takes the target size as (width, height).
    return cv2.resize(cam_np, (w_out, h_out), interpolation=cv2.INTER_LINEAR)


def compute_gradcam(
    model,
    image_tensor,
    class_idx=None,
    device=None,
):
    """
    Computes a class-discriminative heatmap for a single image.

    Supports:
        - ResNetBaseline / ResNetTransfer → Grad-CAM on model.model.layer4[-1]
        - ViTChampion                     → delegates to compute_attention_rollout()
        - EquivariantCNN                  → input-gradient saliency (no module
                                            hooks; see DESIGN NOTE above)

    The model is put in eval mode. For the ResNet path, backward() leaves
    gradients in the model parameters' .grad attributes.

    Args:
        model        : One of the four DeepLense model classes (dispatched by
                       class name).
        image_tensor : torch.Tensor of shape (1, C, H, W) — single image, batched.
        class_idx    : int — target class index for the map. If None, the
                       predicted (argmax) class is used.
        device       : torch.device. If None, inferred from model parameters.

    Returns:
        cam        : np.ndarray of shape (H, W) — heatmap in [0, 1].
        pred_class : int — the model's predicted (argmax) class. This is the
                     prediction even when ``class_idx`` selects another target.
        pred_probs : np.ndarray — softmax probabilities ((3,) for the
                     three-class DeepLense models).

    Raises:
        ValueError : for unsupported model classes.
    """
    if device is None:
        device = next(model.parameters()).device
    model_class = type(model).__name__

    # ViT has no spatial conv feature map: delegate to Attention Rollout.
    if model_class == 'ViTChampion':
        return compute_attention_rollout(model, image_tensor, class_idx, device)

    image_tensor = image_tensor.to(device)
    model.eval()

    if model_class == 'EquivariantCNN':
        cam, pred_class, pred_probs = _input_saliency_cam(model, image_tensor, class_idx)
    elif model_class in ('ResNetBaseline', 'ResNetTransfer'):
        cam, pred_class, pred_probs = _resnet_gradcam(model, image_tensor, class_idx)
    else:
        raise ValueError(
            f"compute_gradcam: unsupported model class '{model_class}'. "
            "Use ResNetBaseline, ResNetTransfer, ViTChampion, or EquivariantCNN."
        )

    cam_resized = _normalise_cam_and_resize(cam, image_tensor.shape[2:4])
    return cam_resized, pred_class, pred_probs
# ─────────────────────────────────────────────────────────────────────────────
# 9. ATTENTION ROLLOUT — ViT models [GSOC UPGRADE: NEW]
# ─────────────────────────────────────────────────────────────────────────────
#
# Standard GradCAM cannot directly be applied to ViT because there is no
# spatial feature map equivalent to a CNN's layer4. Instead we use
# Attention Rollout (Abnar & Zuidema, 2020), which recursively multiplies
# the attention weights across all transformer layers and extracts the
# CLS-token row to obtain a per-patch importance map.
#
# Reference: Abnar & Zuidema (2020) — "Quantifying Attention Flow in
# Transformers"
# ─────────────────────────────────────────────────────────────────────────────
def compute_attention_rollout(
    model,
    image_tensor,
    class_idx=None,
    device=None,
):
    """
    Computes an Attention Rollout map for a ViTChampion model.

    The ViT-B/16 divides the 224×224 image into a 14×14 grid of patches.
    Attention Rollout gives an importance score per patch, which we reshape
    to (14, 14) and upsample to the input resolution for overlay.

    Attention weights are captured with a forward hook on every
    ``model.model.encoder.layers[i].self_attention`` (an
    ``nn.MultiheadAttention``). Its weights are not exposed through the
    hook's output, so the hook recomputes softmax(Q Kᵀ / sqrt(d_head)) from
    the module's inputs and averages over heads. The recomputation uses the
    module's own learned query/key projections (``in_proj_weight`` /
    ``in_proj_bias``, or ``q_proj_weight`` / ``k_proj_weight``). Using the
    raw inputs as Q and K does not reproduce the model's attention.

    Args:
        model        : ViTChampion instance.
        image_tensor : torch.Tensor of shape (1, 3, 224, 224).
        class_idx    : int — unused by rollout (class-agnostic); kept for API
                       compatibility with compute_gradcam().
        device       : torch.device. If None, inferred from model parameters.

    Returns:
        rollout_map : np.ndarray of shape (H, W) — attention map in [0, 1].
        pred_class  : int — predicted (argmax) class index.
        pred_probs  : np.ndarray of shape (3,) — softmax probabilities.

    Raises:
        RuntimeError : if no attention maps were captured.
    """
    import torch
    import torch.nn.functional as F
    import cv2

    if device is None:
        device = next(model.parameters()).device
    image_tensor = image_tensor.to(device)
    model.eval()

    attention_maps = []

    def _attn_hook(module, args, output):
        # nn.MultiheadAttention.forward(query, key, value, ...): for
        # self-attention the same tensor is passed as query and key.
        query = args[0]
        key = args[1] if len(args) > 1 else args[0]
        if not getattr(module, 'batch_first', False):
            # (S, B, D) -> (B, S, D)
            query, key = query.transpose(0, 1), key.transpose(0, 1)

        # Learned Q/K projections: packed in_proj_* when q/k/v share the
        # embedding dim, separate q/k_proj_weight otherwise.
        if module.in_proj_weight is not None:
            w_q, w_k, _ = module.in_proj_weight.chunk(3, dim=0)
        else:
            w_q, w_k = module.q_proj_weight, module.k_proj_weight
        if module.in_proj_bias is not None:
            b_q, b_k, _ = module.in_proj_bias.chunk(3, dim=0)
        else:
            b_q = b_k = None

        q = F.linear(query, w_q, b_q)
        k = F.linear(key, w_k, b_k)
        batch, s_q, _ = q.shape
        s_k = k.shape[1]
        n_heads, d_head = module.num_heads, module.head_dim
        q = q.reshape(batch, s_q, n_heads, d_head).transpose(1, 2)  # (B, H, S_q, d_h)
        k = k.reshape(batch, s_k, n_heads, d_head).transpose(1, 2)  # (B, H, S_k, d_h)
        attn = torch.softmax((q @ k.transpose(-2, -1)) * d_head ** -0.5, dim=-1)
        # Average over heads -> (B, S_q, S_k)
        attention_maps.append(attn.mean(dim=1).detach().cpu())

    handles = [
        encoder_block.self_attention.register_forward_hook(_attn_hook)
        for encoder_block in model.model.encoder.layers
    ]
    try:
        with torch.no_grad():
            logits = model(image_tensor)
    finally:
        for handle in handles:
            handle.remove()

    probs = F.softmax(logits, dim=1)
    pred_class = int(logits.argmax(dim=1).item())
    pred_probs = probs.squeeze(0).cpu().numpy()

    if not attention_maps:
        raise RuntimeError(
            "compute_attention_rollout: no attention maps were captured from "
            "model.model.encoder.layers[*].self_attention."
        )

    # ── Rollout ───────────────────────────────────────────────────────────
    # One (B, S, S) map per layer; S = patches + 1 CLS token (197 for
    # ViT-B/16 at 224×224). Only the single image in the batch is used.
    n_tokens = attention_maps[0].shape[-1]
    identity = torch.eye(n_tokens)
    rollout = identity.clone()
    for attn in attention_maps:
        # Account for the residual connection: A_hat = 0.5 * A + 0.5 * I,
        # then re-normalise rows so each stays a distribution.
        attn_with_res = 0.5 * attn[0] + 0.5 * identity
        attn_with_res = attn_with_res / attn_with_res.sum(dim=-1, keepdim=True)
        rollout = attn_with_res @ rollout

    # CLS token (row 0) attention to every patch token (columns 1:).
    cls_attention = rollout[0, 1:]
    grid_size = int(round(cls_attention.shape[0] ** 0.5))  # 14 for ViT-B/16
    attn_map = cls_attention.reshape(grid_size, grid_size).numpy()
    attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)

    # Upsample to the input resolution; cv2.resize takes (width, height).
    h_in, w_in = image_tensor.shape[2], image_tensor.shape[3]
    rollout_resized = cv2.resize(attn_map, (w_in, h_in), interpolation=cv2.INTER_LINEAR)
    return rollout_resized, pred_class, pred_probs
# ─────────────────────────────────────────────────────────────────────────────
# 10. GRADCAM OVERLAY VISUALISATION [GSOC UPGRADE: NEW]
# ─────────────────────────────────────────────────────────────────────────────
def overlay_gradcam(
    image_np,
    cam,
    alpha=0.45,
    colormap='jet',
):
    """
    Overlays a GradCAM / Attention Rollout heatmap onto the original image.

    The image is min-max rescaled to [0, 1] (unless it is constant), the CAM
    is colour-mapped to RGB, and the two are alpha-blended. The input image
    is not modified.

    Args:
        image_np : np.ndarray of shape (H, W), (H, W, 1) or (H, W, 3).
        cam      : np.ndarray of shape (H, W), float in [0, 1].
        alpha    : float — heatmap opacity (default 0.45).
        colormap : str — matplotlib colormap name (default 'jet').

    Returns:
        blended : np.ndarray of shape (H, W, 3), float in [0, 1].
    """
    # Promote grayscale inputs to three identical channels.
    if image_np.ndim == 2:
        image_rgb = np.stack([image_np] * 3, axis=-1)
    elif image_np.shape[2] == 1:
        image_rgb = np.concatenate([image_np] * 3, axis=-1)
    else:
        image_rgb = image_np
    # astype() returns a new array, so the caller's image is never modified.
    image_rgb = image_rgb.astype(float)

    img_min, img_max = image_rgb.min(), image_rgb.max()
    if img_max > img_min:
        image_rgb = (image_rgb - img_min) / (img_max - img_min)

    # pyplot.get_cmap works across matplotlib versions; matplotlib.cm.get_cmap
    # was deprecated in 3.7 and removed in 3.9.
    cmap_fn = plt.get_cmap(colormap)
    heatmap = cmap_fn(np.asarray(cam))[..., :3]  # drop alpha → (H, W, 3)

    blended = (1 - alpha) * image_rgb + alpha * heatmap
    return np.clip(blended, 0.0, 1.0)
def save_gradcam_visualization(
    model,
    dataloader,
    device,
    classes=CLASSES,
    save_dir=None,
    model_name='Model',
    n_samples_per_class=2,
    denormalize_fn=None,
):
    """
    Generates and saves GradCAM / Attention Rollout visualisations for
    n_samples_per_class images from each class.

    Creates one figure per class; each row shows
    [Original Image | GradCAM Overlay | CAM Heatmap].
    The map for each image is computed by compute_gradcam() with the image's
    true class as the target. The "Pred" shown in the overlay title is the
    argmax of the returned probabilities.

    Args:
        model               : Any DeepLense model supported by compute_gradcam().
        dataloader          : A DataLoader returning (images, labels).
        device              : torch.device.
        classes             : list[str] — class names; label i ↔ classes[i].
                              Samples with labels outside 0..K-1 are skipped.
        save_dir            : str — directory to save plots (created if needed).
                              If None, plots are only shown.
        model_name          : str — used in plot titles and filenames.
        n_samples_per_class : int — how many images to visualise per class.
        denormalize_fn      : callable(tensor) → np.ndarray, optional.
                              If None: ImageNet de-normalisation for 3-channel
                              tensors, otherwise x * 0.5 + 0.5 on channel 0.

    Returns:
        None
    """
    model.eval()
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # ── Collect n_samples_per_class images per class ──────────────────────
    n_classes = len(classes)
    collected = {i: [] for i in range(n_classes)}
    n_needed = n_samples_per_class * n_classes
    n_found = 0
    for images, labels in dataloader:
        for img, lbl in zip(images, labels):
            c = int(lbl.item())
            # Ignore labels that have no class name instead of raising KeyError.
            if c in collected and len(collected[c]) < n_samples_per_class:
                collected[c].append(img)
                n_found += 1
            if n_found >= n_needed:
                break
        if n_found >= n_needed:
            break

    # ── Default de-normalisation ──────────────────────────────────────────
    def _default_denorm(t):
        """Undo ImageNet normalisation (3 channels) or Normalize(0.5, 0.5)."""
        t = t.cpu().numpy()
        if t.shape[0] == 3:  # RGB: (3, H, W) -> (H, W, 3)
            mean = np.array([0.485, 0.456, 0.406])
            std = np.array([0.229, 0.224, 0.225])
            t = t.transpose(1, 2, 0) * std + mean
        else:  # Grayscale: first channel of (1, H, W)
            t = t[0] * 0.5 + 0.5
        return np.clip(t, 0, 1)

    denorm = denormalize_fn if denormalize_fn is not None else _default_denorm

    # ── Generate plots ────────────────────────────────────────────────────
    for class_idx, class_name in enumerate(classes):
        imgs = collected[class_idx]
        if not imgs:
            print(f" ⚠️ No samples found for class '{class_name}' — skipping.")
            continue

        n_rows, n_cols = len(imgs), 3  # Original | Overlay | Heatmap
        # squeeze=False keeps axes 2-D even for a single row.
        fig, axes = plt.subplots(n_rows, n_cols,
                                 figsize=(n_cols * 4, n_rows * 4), squeeze=False)
        fig.suptitle(
            f'GradCAM — {model_name} | Class: {class_name}',
            fontsize=14, fontweight='bold',
        )

        for row_idx, img_tensor in enumerate(imgs):
            input_batch = img_tensor.unsqueeze(0)  # (1, C, H, W)
            img_np = denorm(img_tensor)
            cam, _, pred_probs = compute_gradcam(
                model, input_batch, class_idx=class_idx, device=device
            )
            # Derive the prediction from the probabilities: compute_gradcam's
            # second return value may echo the requested target class.
            pred_cls = int(np.argmax(pred_probs))
            overlay = overlay_gradcam(img_np, cam)

            # Original
            ax_orig = axes[row_idx, 0]
            if img_np.ndim == 2 or (img_np.ndim == 3 and img_np.shape[-1] == 1):
                ax_orig.imshow(img_np if img_np.ndim == 2 else img_np[:, :, 0], cmap='gray')
            else:
                ax_orig.imshow(img_np)
            ax_orig.set_title(f'Original\nTrue: {class_name}', fontsize=10)
            ax_orig.axis('off')

            # Overlay
            ax_ovr = axes[row_idx, 1]
            ax_ovr.imshow(overlay)
            ax_ovr.set_title(
                f'GradCAM Overlay\nPred: {classes[pred_cls]} '
                f'({pred_probs[pred_cls] * 100:.1f}%)',
                fontsize=10,
            )
            ax_ovr.axis('off')

            # Pure heatmap
            ax_cam = axes[row_idx, 2]
            im = ax_cam.imshow(cam, cmap='jet', vmin=0, vmax=1)
            ax_cam.set_title('CAM Heatmap', fontsize=10)
            ax_cam.axis('off')
            fig.colorbar(im, ax=ax_cam, fraction=0.046, pad=0.04)

        fig.tight_layout()
        if save_dir:
            safe_name = class_name.lower().replace(' ', '_')
            save_path = os.path.join(
                save_dir, f'gradcam_{model_name.lower().replace(" ", "_")}_{safe_name}.png'
            )
            fig.savefig(save_path, bbox_inches='tight', dpi=200)
            print(f"📊 GradCAM saved → {save_path}")
        plt.show()
        plt.close(fig)