| import os
|
| import numpy as np
|
| import matplotlib.pyplot as plt
|
| import matplotlib.patches as mpatches
|
| import seaborn as sns
|
| from sklearn.metrics import (
|
| roc_curve, auc,
|
| confusion_matrix,
|
| classification_report,
|
| f1_score,
|
| precision_score,
|
| recall_score,
|
| )
|
| from sklearn.preprocessing import label_binarize
|
| from sklearn.calibration import calibration_curve
|
|
|
|
|
|
|
|
|
|
|
# Class names in label-index order: entry i names integer label i.
CLASSES = ['No Sub', 'CDM', 'Vortex']
# One plotting colour per entry of CLASSES (same order).
CLASS_COLORS = ['#2196F3', '#F44336', '#4CAF50']

# Apply the seaborn white-grid look globally at import time. Newer matplotlib
# only knows the renamed 'seaborn-v0_8-*' style; older releases raise OSError
# for it, in which case only the legacy name is remembered (it is not applied
# globally here) for the ``plt.style.context(PLOT_STYLE)`` blocks below.
try:
    plt.style.use('seaborn-v0_8-whitegrid')
except OSError:
    PLOT_STYLE = 'seaborn-whitegrid'
else:
    PLOT_STYLE = 'seaborn-v0_8-whitegrid'
|
|
|
|
|
|
|
|
|
|
|
|
|
| def _ensure_dir(path):
|
| """Creates the parent directory of a file path if it does not exist."""
|
| if path:
|
| parent = os.path.dirname(path)
|
| if parent:
|
| os.makedirs(parent, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
def _fpr_at_tpr(fpr, tpr, target_tpr=0.90):
    """Return the lowest FPR among ROC operating points whose TPR >= target_tpr.

    Unlike ``np.interp``, this is well defined when ``tpr`` contains repeated
    values (flat ROC segments). It never reports an FPR for a TPR that no real
    threshold reaches. Returns NaN if no point reaches ``target_tpr``, e.g.
    when the class has no positive samples and ``tpr`` is NaN.
    """
    fpr = np.asarray(fpr, dtype=float)
    tpr = np.asarray(tpr, dtype=float)
    reachable = tpr >= target_tpr
    if not reachable.any():
        return float('nan')
    return float(fpr[reachable].min())


def plot_multiclass_roc_auc(
    all_labels,
    all_probs,
    classes=CLASSES,
    save_path=None,
    title='Receiver Operating Characteristic (ROC)',
    model_name='Model',
):
    """
    Generates and saves a publication-quality multi-class ROC curve.

    One-vs-rest ROC curves are drawn for every class plus the micro-average,
    and a text report is printed.

    [GSOC UPGRADE 1] - Physics-Informed Metrics:
    In dark matter searches, physicists care about operational thresholds.
    For every curve we report the False Positive Rate (FPR) at the first ROC
    operating point whose True Positive Rate (TPR) reaches 90%. This is the
    lowest FPR achievable at a real threshold while still guaranteeing a
    >= 90% detection efficiency. NaN is reported for a class without positives.

    Args:
        all_labels : array-like of shape (n_samples,) — integer class indices.
        all_probs  : array-like of shape (n_samples, n_classes) — per-class scores.
        classes    : class names; entry i names label i. Two-class problems are
                     supported (the binary indicator is expanded per class).
        save_path  : optional output path; parent directories are created.
        title      : figure title (the model name is appended below it).
        model_name : used in the printed report and the title.

    Returns:
        dict with 'per_class' (index -> AUC), 'macro' and 'micro' AUC, and
        'fpr_90_macro' (mean of the per-class FPR @ 90% TPR).
    """
    all_labels = np.asarray(all_labels)
    all_probs = np.asarray(all_probs)
    n_classes = len(classes)

    y_bin = label_binarize(all_labels, classes=list(range(n_classes)))
    if y_bin.shape[1] == 1:
        # label_binarize returns a single column for two classes; expand it so
        # column i is always the one-vs-rest indicator of class i.
        y_bin = np.hstack([1 - y_bin, y_bin])

    # Classes beyond CLASS_COLORS fall back to matplotlib's 'CN' colour cycle
    # instead of being silently dropped by zip().
    colors = [
        CLASS_COLORS[i] if i < len(CLASS_COLORS) else f'C{(i + 1) % 10}'
        for i in range(n_classes)
    ]

    fpr, tpr, roc_auc, fpr_at_90 = {}, {}, {}, {}

    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(y_bin[:, i], all_probs[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])
        fpr_at_90[i] = _fpr_at_tpr(fpr[i], tpr[i], 0.90)

    # Micro-average: pool every (sample, class) decision into one binary problem.
    fpr['micro'], tpr['micro'], _ = roc_curve(y_bin.ravel(), all_probs.ravel())
    roc_auc['micro'] = auc(fpr['micro'], tpr['micro'])
    fpr_at_90['micro'] = _fpr_at_tpr(fpr['micro'], tpr['micro'], 0.90)

    # Macro-average: unweighted mean of the per-class values.
    roc_auc['macro'] = np.mean([roc_auc[i] for i in range(n_classes)])
    fpr_at_90['macro'] = np.mean([fpr_at_90[i] for i in range(n_classes)])

    print(f"\n{'='*65}")
    print(f" ROC-AUC & PHYSICS REPORT β {model_name}")
    print(f"{'='*65}")
    for i, cls in enumerate(classes):
        print(f" {cls:<12}: AUC = {roc_auc[i]:.4f} | FPR @ 90% TPR = {fpr_at_90[i]:.4f}")
    print(f" {'-'*61}")
    print(f" {'Micro-Avg':<12}: AUC = {roc_auc['micro']:.4f} | FPR @ 90% TPR = {fpr_at_90['micro']:.4f}")
    print(f" {'Macro-Avg':<12}: AUC = {roc_auc['macro']:.4f} | FPR @ 90% TPR = {fpr_at_90['macro']:.4f}")
    print(f"{'='*65}\n")

    with plt.style.context(PLOT_STYLE):
        fig, ax = plt.subplots(figsize=(10, 8))

        for i, (cls, color) in enumerate(zip(classes, colors)):
            ax.plot(
                fpr[i], tpr[i],
                color=color, lw=2.5,
                label=f'{cls} (AUC = {roc_auc[i]:.3f})'
            )

        ax.plot(
            fpr['micro'], tpr['micro'],
            color='darkorange', lw=2, linestyle='--',
            label=f'Micro-avg (AUC = {roc_auc["micro"]:.3f})'
        )

        ax.plot(
            [0, 1], [0, 1],
            'k--', lw=1.2, label='Random Classifier (AUC = 0.500)'
        )

        ax.text(
            0.62, 0.12,
            f'Macro-Avg AUC = {roc_auc["macro"]:.4f}\nFPR @ 90% TPR = {fpr_at_90["macro"]:.4f}',
            transform=ax.transAxes,
            fontsize=11, fontweight='bold',
            bbox=dict(boxstyle='round,pad=0.4', facecolor='lightyellow',
                      edgecolor='gray', alpha=0.9)
        )

        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('False Positive Rate', fontsize=13)
        ax.set_ylabel('True Positive Rate', fontsize=13)
        ax.set_title(f'{title}\n{model_name}', fontsize=14, fontweight='bold')
        ax.legend(loc='lower right', fontsize=11)
        ax.grid(alpha=0.3)

        fig.tight_layout()
        _ensure_dir(save_path)
        if save_path:
            fig.savefig(save_path, bbox_inches='tight', dpi=300)
            print(f"π ROC curve saved β {save_path}")
        plt.show()
        plt.close(fig)

    return {
        'per_class': {i: roc_auc[i] for i in range(n_classes)},
        'macro': roc_auc['macro'],
        'micro': roc_auc['micro'],
        'fpr_90_macro': fpr_at_90['macro'],
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_confusion_matrix(
    y_true,
    y_pred,
    classes=CLASSES,
    save_path=None,
    title='Confusion Matrix',
    cmap='Blues',
    normalize=False,
):
    """
    Plots (and optionally saves) a confusion-matrix heatmap.

    Args:
        y_true    : ground-truth integer class indices.
        y_pred    : predicted integer class indices.
        classes   : class names; entry i names label i (rows = true, cols = predicted).
        save_path : optional output path; parent directories are created.
        title     : plot title; ' (Normalised)' is appended when ``normalize``.
        cmap      : colormap name forwarded to seaborn.
        normalize : if True, each row is divided by its total so cells show the
                    fraction of each true class; rows with no samples stay 0.
    """
    # Pin the label set so the matrix is always len(classes) x len(classes),
    # even when a class is absent from both y_true and y_pred. Without it the
    # tick labels would not match the size of the matrix.
    labels = list(range(len(classes)))
    cm = confusion_matrix(y_true, y_pred, labels=labels)

    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # Avoid 0/0 -> NaN for classes with no true samples.
        cm_display = np.divide(
            cm.astype(float), row_sums,
            out=np.zeros(cm.shape, dtype=float),
            where=row_sums > 0,
        )
        fmt = '.2f'
        title = title + ' (Normalised)'
    else:
        cm_display = cm
        fmt = 'd'

    with plt.style.context(PLOT_STYLE):
        fig, ax = plt.subplots(figsize=(8, 6))

        sns.heatmap(
            cm_display,
            annot=True, fmt=fmt, cmap=cmap,
            xticklabels=classes, yticklabels=classes,
            linewidths=0.5, linecolor='gray',
            ax=ax,
        )

        ax.set_title(title, fontsize=14, fontweight='bold', pad=14)
        ax.set_ylabel('True Physics Label', fontsize=12)
        ax.set_xlabel('Model Predicted Label', fontsize=12)
        ax.tick_params(axis='x', rotation=0)
        ax.tick_params(axis='y', rotation=0)

        fig.tight_layout()
        _ensure_dir(save_path)
        if save_path:
            fig.savefig(save_path, bbox_inches='tight', dpi=300)
            print(f"π Confusion matrix saved β {save_path}")
        plt.show()
        plt.close(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_classification_report(
    y_true,
    y_pred,
    classes=CLASSES,
    model_name='Model',
):
    """
    Prints an sklearn classification report plus summary metrics and returns them.

    All metrics are computed over the full label set ``range(len(classes))``.
    A class missing from both y_true and y_pred therefore neither crashes the
    report nor silently changes the set of classes the macro averages cover.
    Undefined precision/recall/F1 (no support or no predictions) are reported as 0.

    Args:
        y_true     : ground-truth integer class indices.
        y_pred     : predicted integer class indices.
        classes    : class names; entry i names label i.
        model_name : used in the printed header.

    Returns:
        dict with 'report_str', 'f1_macro', 'f1_weighted', 'f1_per_class'
        (class name -> F1), 'precision_macro' and 'recall_macro'.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    labels = list(range(len(classes)))

    print(f"\n{'='*52}")
    print(f" CLASSIFICATION REPORT β {model_name}")
    print(f"{'='*52}")
    # Without labels=, sklearn raises when target_names has more entries than
    # the classes actually present in y_true/y_pred.
    report_str = classification_report(
        y_true, y_pred, labels=labels, target_names=classes,
        digits=4, zero_division=0,
    )
    print(report_str)

    f1_macro = f1_score(y_true, y_pred, labels=labels, average='macro', zero_division=0)
    f1_weighted = f1_score(y_true, y_pred, labels=labels, average='weighted', zero_division=0)
    f1_per_cls = f1_score(y_true, y_pred, labels=labels, average=None, zero_division=0)
    f1_per_class_named = {cls: float(f1_per_cls[i]) for i, cls in enumerate(classes)}

    precision_macro = float(precision_score(y_true, y_pred, labels=labels,
                                            average='macro', zero_division=0))
    recall_macro = float(recall_score(y_true, y_pred, labels=labels,
                                      average='macro', zero_division=0))

    print(f" Macro F1 : {f1_macro:.4f} β use in result dicts")
    print(f" Weighted F1 : {f1_weighted:.4f}")
    print(f" Macro Precision : {precision_macro:.4f}")
    print(f" Macro Recall : {recall_macro:.4f}")
    print(f"{'='*52}\n")

    return {
        'report_str': report_str,
        'f1_macro': float(f1_macro),
        'f1_weighted': float(f1_weighted),
        'f1_per_class': f1_per_class_named,
        'precision_macro': precision_macro,
        'recall_macro': recall_macro,
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_learning_curves(
    train_losses,
    val_losses,
    train_accs,
    val_accs,
    save_path=None,
    model_name='Model',
):
    """
    Plots train/validation loss and accuracy side by side.

    The epoch with the lowest validation loss is marked on the loss panel.
    The epoch with the highest validation accuracy is marked and annotated
    on the accuracy panel. The annotation text assumes accuracies are
    percentages.

    Args:
        train_losses, val_losses : per-epoch loss values (same length).
        train_accs, val_accs     : per-epoch accuracy values (same length).
        save_path  : optional output path; parent directories are created.
        model_name : used in the figure title.

    Raises:
        ValueError: if the histories are empty or have different lengths
                    (checked before any figure is created).
    """
    n_epochs = len(train_losses)
    lengths = {len(train_losses), len(val_losses), len(train_accs), len(val_accs)}
    if n_epochs == 0 or len(lengths) != 1:
        raise ValueError(
            "plot_learning_curves: expected non-empty histories of equal length, got "
            f"train_losses={len(train_losses)}, val_losses={len(val_losses)}, "
            f"train_accs={len(train_accs)}, val_accs={len(val_accs)}"
        )

    epochs = range(1, n_epochs + 1)

    with plt.style.context(PLOT_STYLE):
        fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(14, 5))
        fig.suptitle(f'Learning Curves β {model_name}', fontsize=14, fontweight='bold')

        # ---- Loss panel ----
        ax_loss.plot(epochs, train_losses, 'b-o', markersize=4,
                     linewidth=2, label='Train Loss')
        ax_loss.plot(epochs, val_losses, 'r-o', markersize=4,
                     linewidth=2, label='Val Loss')
        ax_loss.set_xlabel('Epoch', fontsize=12)
        ax_loss.set_ylabel('Cross-Entropy Loss', fontsize=12)
        ax_loss.set_title('Loss over Epochs', fontsize=12)
        ax_loss.grid(alpha=0.4)

        best_val_loss_epoch = int(np.argmin(val_losses)) + 1
        ax_loss.axvline(
            x=best_val_loss_epoch, color='red',
            linestyle='--', alpha=0.5,
            label=f'Best Val Loss (epoch {best_val_loss_epoch})'
        )
        # Build the legend once, after the labelled marker line is added.
        ax_loss.legend(fontsize=10)

        # ---- Accuracy panel ----
        ax_acc.plot(epochs, train_accs, 'b-o', markersize=4,
                    linewidth=2, label='Train Acc')
        ax_acc.plot(epochs, val_accs, 'r-o', markersize=4,
                    linewidth=2, label='Val Acc')
        ax_acc.set_xlabel('Epoch', fontsize=12)
        ax_acc.set_ylabel('Accuracy (%)', fontsize=12)
        ax_acc.set_title('Accuracy over Epochs', fontsize=12)
        ax_acc.legend(fontsize=11)
        ax_acc.grid(alpha=0.4)

        best_val_acc_epoch = int(np.argmax(val_accs)) + 1
        best_val_acc_value = max(val_accs)
        ax_acc.axvline(
            x=best_val_acc_epoch, color='red',
            linestyle='--', alpha=0.5,
        )
        ax_acc.annotate(
            f'Best: {best_val_acc_value:.1f}%\n(epoch {best_val_acc_epoch})',
            xy=(best_val_acc_epoch, best_val_acc_value),
            xytext=(best_val_acc_epoch + 0.5, best_val_acc_value - 5),
            fontsize=9,
            arrowprops=dict(arrowstyle='->', color='black'),
        )

        fig.tight_layout()
        _ensure_dir(save_path)
        if save_path:
            fig.savefig(save_path, bbox_inches='tight', dpi=300)
            print(f"π Learning curves saved β {save_path}")
        plt.show()
        plt.close(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_tta_degradation(
    labels_before,
    preds_before,
    labels_after,
    preds_after,
    classes=CLASSES,
    save_path=None,
    model_name='Ensemble',
):
    """
    Compares per-class F1 before and after rotational test-time augmentation.

    Left panel: grouped F1 bars (standard evaluation vs. TTA).
    Right panel: per-class change dF1 = F1_after - F1_before, green when
    non-negative and red when negative. A summary table is printed afterwards.

    Args:
        labels_before, preds_before : true / predicted indices for standard eval.
        labels_after, preds_after   : true / predicted indices after TTA.
        classes    : class names; entry i names label i.
        save_path  : optional output path; parent directories are created.
        model_name : used in the title and the printed summary.
    """
    # Score every class, even ones missing from a particular split.
    dynamic_labels = list(range(len(classes)))

    f1_before = f1_score(labels_before, preds_before, average=None,
                         labels=dynamic_labels, zero_division=0)
    f1_after = f1_score(labels_after, preds_after, average=None,
                        labels=dynamic_labels, zero_division=0)
    delta_f1 = f1_after - f1_before

    x = np.arange(len(classes))
    bar_w = 0.35

    with plt.style.context(PLOT_STYLE):
        fig, (ax_bar, ax_delta) = plt.subplots(1, 2, figsize=(14, 5))
        fig.suptitle(
            f'TTA Rotational Variance Analysis β {model_name}',
            fontsize=14, fontweight='bold'
        )

        bars1 = ax_bar.bar(x - bar_w/2, f1_before, bar_w,
                           label='Standard Eval', color='#2196F3', alpha=0.85)
        bars2 = ax_bar.bar(x + bar_w/2, f1_after, bar_w,
                           label='After TTA (0Β°/90Β°/180Β°/270Β°)',
                           color='#F44336', alpha=0.85)

        ax_bar.set_xticks(x)
        ax_bar.set_xticklabels(classes, fontsize=12)
        ax_bar.set_ylabel('F1-Score', fontsize=12)
        ax_bar.set_ylim([0, 1.1])
        ax_bar.set_title('F1-Score: Before vs After Rotational TTA', fontsize=12)
        ax_bar.legend(fontsize=11)
        ax_bar.grid(axis='y', alpha=0.4)

        for bar in (*bars1, *bars2):
            ax_bar.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                        f'{bar.get_height():.3f}', ha='center', va='bottom', fontsize=9)

        colors_delta = ['#4CAF50' if d >= 0 else '#F44336' for d in delta_f1]
        ax_delta.bar(classes, delta_f1, color=colors_delta, alpha=0.85, edgecolor='black')
        ax_delta.axhline(y=0, color='black', linewidth=1.2)
        ax_delta.set_ylabel('ΞF1 (After β Before)', fontsize=12)
        ax_delta.set_title('Per-Class F1 Degradation Under Rotation', fontsize=12)
        ax_delta.grid(axis='y', alpha=0.4)

        # Labels sit above positive bars and below negative ones, so the text
        # never overlaps its own bar.
        for i, d in enumerate(delta_f1):
            if d >= 0:
                y_text, v_align = d + 0.005, 'bottom'
            else:
                y_text, v_align = d - 0.005, 'top'
            ax_delta.text(i, y_text, f'{d:+.3f}', ha='center', va=v_align,
                          fontsize=10, fontweight='bold')

        # Only single out a "worst" class when something actually got worse.
        if delta_f1.min() < 0:
            worst_cls = classes[int(np.argmin(delta_f1))]
            caption = (f'β "{worst_cls}" suffers the largest degradation, motivating '
                       f'E(2)-Equivariant Networks.')
        else:
            caption = 'No class degrades under rotational TTA.'
        ax_delta.set_xlabel(caption, fontsize=9, style='italic')

        fig.tight_layout()
        _ensure_dir(save_path)
        if save_path:
            fig.savefig(save_path, bbox_inches='tight', dpi=300)
            print(f"π TTA degradation plot saved β {save_path}")
        plt.show()
        plt.close(fig)

    print(f"\n{'='*52}")
    print(f" TTA DEGRADATION SUMMARY β {model_name}")
    print(f"{'='*52}")
    print(f" {'Class':<12} {'F1 Before':>10} {'F1 After':>10} {'Delta':>8}")
    print(f" {'-'*46}")
    for cls, fb, fa, d in zip(classes, f1_before, f1_after, delta_f1):
        print(f" {cls:<12} {fb:>10.4f} {fa:>10.4f} {d:>+8.4f}")
    print(f"{'='*52}\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
def print_model_comparison_table(results: list[dict]):
    """
    Prints a fixed-width comparison table, one row per model result dict.

    Each dict must provide 'model', 'val_acc', 'macro_auc', 'cdm_auc' and
    'f1_macro'. 'fpr_90_macro' is optional: when it is missing or None the
    cell shows 'N/A'. Previously a missing value was shown as 0.0000, which
    reads as a perfect false-positive rate.
    """
    header = (
        f"\n{'='*82}\n"
        f" {'Model':<26} {'Val Acc':>8} {'MacroAUC':>10} "
        f"{'CDM AUC':>9} {'FPR@90%':>9} {'F1 Macro':>10}\n"
        f" {'-'*78}"
    )
    print(header)
    for r in results:
        fpr_val = r.get('fpr_90_macro')
        fpr_cell = f"{'N/A':>9}" if fpr_val is None else f"{fpr_val:>9.4f}"
        print(
            f" {r['model']:<26} "
            f"{r['val_acc']:>7.1f}% "
            f"{r['macro_auc']:>10.4f} "
            f"{r['cdm_auc']:>9.4f} "
            f"{fpr_cell} "
            f"{r['f1_macro']:>10.4f}"
        )
    print(f"{'='*82}\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_calibration_curves(
    y_true,
    y_probs,
    classes=CLASSES,
    save_path=None,
    model_name='Model',
    n_bins=10,
):
    """
    Plots a one-vs-rest Reliability Diagram (calibration curve) per class.

    For each class, the predicted probabilities for that class are binned
    into ``n_bins`` bins. The mean predicted probability of each bin (x) is
    plotted against the observed fraction of samples belonging to the class
    (y). A perfectly calibrated model lies on the dotted diagonal. Points
    below it mean the predicted probability exceeds the observed frequency
    (over-confidence). This diagnoses how trustworthy the probabilities are,
    which accuracy alone does not show.

    Args:
        y_true     : integer class indices, shape (n_samples,).
        y_probs    : per-class probabilities, shape (n_samples, n_classes).
        classes    : class names; entry i names label i. Two-class problems
                     are supported.
        save_path  : optional output path; parent directories are created.
        model_name : used in the title.
        n_bins     : number of probability bins per class (default 10).
    """
    y_true = np.asarray(y_true)
    y_probs = np.asarray(y_probs)
    n_classes = len(classes)

    y_bin = label_binarize(y_true, classes=list(range(n_classes)))
    if y_bin.shape[1] == 1:
        # label_binarize returns a single column for two classes; expand it so
        # column i is always the one-vs-rest indicator of class i.
        y_bin = np.hstack([1 - y_bin, y_bin])

    # Classes beyond CLASS_COLORS fall back to matplotlib's 'CN' colour cycle
    # instead of being silently dropped by zip().
    colors = [
        CLASS_COLORS[i] if i < len(CLASS_COLORS) else f'C{(i + 1) % 10}'
        for i in range(n_classes)
    ]

    with plt.style.context(PLOT_STYLE):
        fig, ax = plt.subplots(figsize=(8, 8))

        ax.plot([0, 1], [0, 1], "k:", label="Perfectly Calibrated")

        for i, (cls, color) in enumerate(zip(classes, colors)):
            prob_true, prob_pred = calibration_curve(y_bin[:, i], y_probs[:, i], n_bins=n_bins)
            ax.plot(prob_pred, prob_true, "s-", color=color, label=f"{cls}")

        ax.set_ylabel("Fraction of positives (Empirical True Probability)", fontsize=12)
        ax.set_xlabel("Mean predicted value (Model Confidence)", fontsize=12)
        ax.set_title(f"Reliability Diagram (Calibration Curves)\n{model_name}", fontsize=14, fontweight='bold')
        ax.legend(loc="lower right", fontsize=11)
        ax.grid(alpha=0.3)

        fig.tight_layout()
        _ensure_dir(save_path)
        if save_path:
            fig.savefig(save_path, bbox_inches='tight', dpi=300)
            print(f"π Calibration curves saved β {save_path}")
        plt.show()
        plt.close(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _normalize_map(saliency):
    """Min-max rescale a torch tensor to [0, 1]; a constant map becomes all zeros."""
    lo, hi = saliency.min(), saliency.max()
    if hi > lo:
        return (saliency - lo) / (hi - lo)
    return saliency.new_zeros(saliency.shape)


def _resize_map(heatmap, height, width):
    """Bilinearly resize a 2-D float map to (height, width) as float32.

    Uses OpenCV (``cv2.INTER_LINEAR``) when it is installed and otherwise
    falls back to ``torch.nn.functional.interpolate`` (bilinear,
    ``align_corners=False``).
    """
    heatmap = np.ascontiguousarray(heatmap, dtype=np.float32)
    try:
        import cv2
    except ImportError:
        import torch
        import torch.nn.functional as F
        resized = F.interpolate(
            torch.from_numpy(heatmap)[None, None],
            size=(height, width), mode='bilinear', align_corners=False,
        )
        return resized[0, 0].numpy()
    return cv2.resize(heatmap, (width, height), interpolation=cv2.INTER_LINEAR)


def compute_gradcam(
    model,
    image_tensor,
    class_idx=None,
    device=None,
):
    """
    Computes a class-activation heatmap for a single image.

    Dispatch is on ``type(model).__name__``:
      - ResNetBaseline / ResNetTransfer: Grad-CAM on ``model.model.layer4[-1]``.
        That block's output activations are weighted by the spatially averaged
        gradient of the target logit, summed over channels, then ReLU'd.
      - ViTChampion: delegates to ``compute_attention_rollout()``.
      - EquivariantCNN: input-gradient saliency |d logit / d input|, averaged
        over input channels. No intermediate layer is hooked.

    Gradients come from ``torch.autograd.grad``, so the model's parameter
    ``.grad`` buffers are neither cleared nor written. Grad mode is forced on,
    so the function also works when called inside ``torch.no_grad()``.
    ``model.eval()`` is called and is not undone.

    Args:
        model        : One of the four DeepLense model classes.
        image_tensor : torch.Tensor of shape (1, C, H, W) — single image, batched.
        class_idx    : int — class whose logit is explained. If None, uses argmax.
        device       : torch.device. If None, inferred from model parameters.

    Returns:
        cam        : np.ndarray of shape (H, W) — heatmap in [0, 1].
        pred_class : int — the model's predicted (argmax) class. This can differ
                     from ``class_idx`` when an explicit target class is given.
        pred_probs : np.ndarray of shape (num_classes,) — softmax probabilities.

    Raises:
        ValueError: if the model class is not one of the supported names.
    """
    import torch
    import torch.nn.functional as F

    if device is None:
        device = next(model.parameters()).device

    model_class = type(model).__name__

    # ViT has no spatial conv feature map; attention rollout is used instead.
    if model_class == 'ViTChampion':
        return compute_attention_rollout(model, image_tensor, class_idx, device)

    image_tensor = image_tensor.to(device)
    model.eval()

    if model_class == 'EquivariantCNN':
        # Saliency w.r.t. a fresh leaf copy of the input; a single forward
        # pass provides both the logits and the gradient graph.
        inp = image_tensor.detach().clone().requires_grad_(True)
        with torch.enable_grad():
            logits = model(inp)
            target = int(logits.argmax(dim=1).item()) if class_idx is None else class_idx
            (input_grad,) = torch.autograd.grad(logits[0, target], inp)
        saliency = input_grad.detach().abs().squeeze(0).mean(dim=0)

    elif model_class in ('ResNetBaseline', 'ResNetTransfer'):
        captured = {}

        def _capture(module, inputs, output):
            # Keep the live (non-detached) output so we can differentiate w.r.t. it.
            captured['act'] = output

        handle = model.model.layer4[-1].register_forward_hook(_capture)
        try:
            # requires_grad on the input keeps a graph through the target layer
            # even when the backbone parameters are frozen.
            inp = image_tensor.detach().requires_grad_(True)
            with torch.enable_grad():
                logits = model(inp)
                target = int(logits.argmax(dim=1).item()) if class_idx is None else class_idx
                act = captured['act']
                (act_grad,) = torch.autograd.grad(logits[0, target], act)
        finally:
            handle.remove()

        act = act.detach().squeeze(0)                            # assumed (channels, h, w)
        weights = act_grad.detach().squeeze(0).mean(dim=(1, 2))  # one weight per channel
        saliency = F.relu((weights[:, None, None] * act).sum(dim=0).float())

    else:
        raise ValueError(
            f"compute_gradcam: unsupported model class '{model_class}'. "
            "Use ResNetBaseline, ResNetTransfer, ViTChampion, or EquivariantCNN."
        )

    logits = logits.detach()
    pred_class = int(logits.argmax(dim=1).item())
    pred_probs = F.softmax(logits, dim=1).squeeze(0).cpu().numpy()

    cam_np = _normalize_map(saliency).cpu().numpy()
    h_in, w_in = image_tensor.shape[2], image_tensor.shape[3]
    return _resize_map(cam_np, h_in, w_in), pred_class, pred_probs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_attention_rollout(
    model,
    image_tensor,
    class_idx=None,
    device=None,
):
    """
    Computes an Attention Rollout map for a ViTChampion model.

    A forward hook on every ``model.model.encoder.layers[i].self_attention``
    module recomputes that layer's attention matrix from the module's own
    query/key projections (``in_proj_weight``/``in_proj_bias``, or the separate
    ``q_proj_weight``/``k_proj_weight``). The matrix is averaged over heads.
    Rollout then multiplies the per-layer matrices, each mixed 50/50 with the
    identity to account for the residual connection and row-renormalised.
    The CLS-token row (token 0) over the remaining tokens is reshaped to a
    square patch grid, min-max scaled and resized to the input resolution.
    For a ViT-B/16 on 224×224 inputs that grid is 14×14.

    Assumes the attention modules are called with positional (query, key, ...)
    arguments, do not use ``add_bias_kv``, and that token 0 is the CLS token
    followed by a square grid of patch tokens.

    Args:
        model        : ViTChampion instance.
        image_tensor : torch.Tensor of shape (1, C, H, W).
        class_idx    : int — unused by rollout (class-agnostic); kept for API compatibility.
        device       : torch.device. If None, inferred from model parameters.

    Returns:
        rollout_map : np.ndarray of shape (H, W) — attention map in [0, 1].
        pred_class  : int — predicted (argmax) class index.
        pred_probs  : np.ndarray of shape (num_classes,) — softmax probabilities.

    Raises:
        RuntimeError: if no attention layer was executed during the forward pass.
    """
    import torch
    import torch.nn.functional as F
    try:
        import cv2
    except ImportError:  # OpenCV is optional; torch is used for resizing instead.
        cv2 = None

    if device is None:
        device = next(model.parameters()).device

    image_tensor = image_tensor.to(device)
    model.eval()

    attention_maps = []

    def _attn_hook(module, inputs, output):
        # The projected Q/K must be used: the raw layer input is NOT the query/key.
        query = inputs[0]
        key = inputs[1] if len(inputs) > 1 else query
        if not getattr(module, 'batch_first', True):
            query = query.transpose(0, 1)
            key = key.transpose(0, 1)

        embed_dim = module.embed_dim
        n_heads = module.num_heads
        head_dim = embed_dim // n_heads

        if module.in_proj_weight is not None:
            w_q, w_k, _ = module.in_proj_weight.split(embed_dim, dim=0)
        else:
            w_q, w_k = module.q_proj_weight, module.k_proj_weight
        if module.in_proj_bias is not None:
            b_q, b_k, _ = module.in_proj_bias.split(embed_dim, dim=0)
        else:
            b_q = b_k = None

        q = F.linear(query, w_q, b_q)
        k = F.linear(key, w_k, b_k)
        batch, n_q, _ = q.shape
        n_k = k.shape[1]
        q = q.reshape(batch, n_q, n_heads, head_dim).transpose(1, 2)
        k = k.reshape(batch, n_k, n_heads, head_dim).transpose(1, 2)

        attn = torch.softmax((q @ k.transpose(-2, -1)) * head_dim ** -0.5, dim=-1)
        attention_maps.append(attn.mean(dim=1).detach().cpu())  # average over heads

    hooks = [
        block.self_attention.register_forward_hook(_attn_hook)
        for block in model.model.encoder.layers
    ]

    try:
        with torch.no_grad():
            logits = model(image_tensor)
            probs = F.softmax(logits, dim=1)
    finally:
        for h in hooks:
            h.remove()

    pred_class = int(logits.argmax(dim=1).item())
    pred_probs = probs.squeeze(0).cpu().numpy()

    if not attention_maps:
        raise RuntimeError(
            "compute_attention_rollout: no self-attention layer ran during the forward pass."
        )

    n_tokens = attention_maps[0].shape[-1]
    eye = torch.eye(n_tokens)
    rollout = eye.clone()
    for attn in attention_maps:
        # Residual connection: mix with identity, then renormalise each row.
        attn_with_res = 0.5 * attn[0] + 0.5 * eye
        attn_with_res = attn_with_res / attn_with_res.sum(dim=-1, keepdim=True)
        rollout = attn_with_res @ rollout

    # CLS token (row 0) attention over the patch tokens.
    cls_attention = rollout[0, 1:]
    grid_size = int(cls_attention.shape[0] ** 0.5)
    attn_map = cls_attention.reshape(grid_size, grid_size).numpy()
    attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
    attn_map = np.ascontiguousarray(attn_map, dtype=np.float32)

    h_in = image_tensor.shape[2]
    w_in = image_tensor.shape[3]
    if cv2 is not None:
        rollout_resized = cv2.resize(attn_map, (w_in, h_in), interpolation=cv2.INTER_LINEAR)
    else:
        rollout_resized = F.interpolate(
            torch.from_numpy(attn_map)[None, None],
            size=(h_in, w_in), mode='bilinear', align_corners=False,
        )[0, 0].numpy()

    return rollout_resized, pred_class, pred_probs
|
|
|
|
|
|
|
|
|
|
|
|
|
def overlay_gradcam(
    image_np,
    cam,
    alpha=0.45,
    colormap='jet',
):
    """
    Overlays a GradCAM / Attention Rollout heatmap onto the original image.

    The image is promoted to 3 channels and min-max rescaled to [0, 1] (unless
    it is constant). The heatmap is colourised with ``colormap`` and
    alpha-blended on top; the result is clipped to [0, 1].

    Args:
        image_np : np.ndarray of shape (H, W), (H, W, 1) or (H, W, 3).
        cam      : np.ndarray of shape (H, W), float in [0, 1].
        alpha    : float — heatmap opacity (default 0.45).
        colormap : str (or Colormap) — matplotlib colormap (default 'jet').

    Returns:
        blended : np.ndarray of shape (H, W, 3), float in [0, 1].
    """
    # pyplot.get_cmap is used because matplotlib.cm.get_cmap was deprecated in
    # 3.7 and removed in 3.9; the pyplot function accepts names, Colormap
    # objects and None.
    import matplotlib.pyplot as plt

    image_np = np.asarray(image_np)
    if image_np.ndim == 2:
        image_rgb = np.stack([image_np] * 3, axis=-1)
    elif image_np.shape[2] == 1:
        image_rgb = np.concatenate([image_np] * 3, axis=-1)
    else:
        image_rgb = image_np.copy()

    image_rgb = image_rgb.astype(float)
    img_min, img_max = image_rgb.min(), image_rgb.max()
    if img_max > img_min:
        image_rgb = (image_rgb - img_min) / (img_max - img_min)

    cmap_fn = plt.get_cmap(colormap)
    heatmap = cmap_fn(np.asarray(cam))[..., :3]  # drop the alpha channel

    blended = (1 - alpha) * image_rgb + alpha * heatmap
    return np.clip(blended, 0.0, 1.0)
|
|
|
|
|
def save_gradcam_visualization(
    model,
    dataloader,
    device,
    classes=CLASSES,
    save_dir=None,
    model_name='Model',
    n_samples_per_class=2,
    denormalize_fn=None,
):
    """
    Generates and saves GradCAM / Attention Rollout visualisations for
    n_samples_per_class images from each class.

    Creates one figure per class, one row per sample:
        [Original Image | GradCAM Overlay | CAM Heatmap]
    The heatmap explains the TRUE class of each sample. The overlay title
    reports the model's actual top prediction, taken from the returned
    softmax probabilities.

    Model-specific handling (ResNet Grad-CAM, ViT attention rollout,
    EquivariantCNN input saliency) is delegated to compute_gradcam().

    Args:
        model               : Any DeepLense model.
        dataloader          : A DataLoader returning (images, labels).
        device              : torch.device.
        classes             : list[str] — class names; entry i names label i.
        save_dir            : str — directory to save plots (created if needed).
                              If None, plots are only shown.
        model_name          : str — used in plot titles and filenames.
        n_samples_per_class : int — how many images to visualise per class.
        denormalize_fn      : callable(tensor) -> np.ndarray, optional.
                              If None, 3-channel tensors get the standard
                              ImageNet de-normalisation and other tensors use
                              channel 0 mapped via x * 0.5 + 0.5.

    Returns:
        None
    """
    model.eval()
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Collect up to n_samples_per_class images per class, stopping early once
    # every class is full. Labels outside range(len(classes)) are ignored.
    collected = {i: [] for i in range(len(classes))}
    n_needed = n_samples_per_class * len(classes)
    n_found = 0

    for images, labels in dataloader:
        for img, lbl in zip(images, labels):
            bucket = collected.get(int(lbl.item()))
            if bucket is not None and len(bucket) < n_samples_per_class:
                # Clone so the stored sample does not keep the whole batch alive.
                bucket.append(img.detach().clone())
                n_found += 1
                if n_found >= n_needed:
                    break
        if n_found >= n_needed:
            break

    def _default_denorm(t):
        """Undo ImageNet normalisation for RGB or grayscale tensors."""
        t = t.cpu().numpy()
        if t.shape[0] == 3:
            mean = np.array([0.485, 0.456, 0.406])
            std = np.array([0.229, 0.224, 0.225])
            t = t.transpose(1, 2, 0) * std + mean
        else:
            t = t[0] * 0.5 + 0.5
        return np.clip(t, 0, 1)

    denorm = denormalize_fn if denormalize_fn is not None else _default_denorm

    for class_idx, class_name in enumerate(classes):
        imgs = collected[class_idx]
        if not imgs:
            print(f" β οΈ No samples found for class '{class_name}' β skipping.")
            continue

        n_cols = 3
        n_rows = len(imgs)

        # squeeze=False keeps a 2-D axes array even for a single row.
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 4, n_rows * 4),
                                 squeeze=False)

        fig.suptitle(
            f'GradCAM β {model_name} | Class: {class_name}',
            fontsize=14, fontweight='bold'
        )

        for row_idx, img_tensor in enumerate(imgs):
            input_batch = img_tensor.unsqueeze(0)
            img_np = denorm(img_tensor)

            cam, _, pred_probs = compute_gradcam(
                model, input_batch, class_idx=class_idx, device=device
            )
            # Derive the prediction from the probabilities: depending on the
            # model type, compute_gradcam's second return value may echo the
            # requested class_idx rather than the model's prediction.
            pred_cls = int(np.argmax(pred_probs))

            overlay = overlay_gradcam(img_np, cam)

            ax_orig = axes[row_idx, 0]
            if img_np.ndim == 2 or (img_np.ndim == 3 and img_np.shape[-1] == 1):
                ax_orig.imshow(img_np if img_np.ndim == 2 else img_np[:, :, 0], cmap='gray')
            else:
                ax_orig.imshow(img_np)
            ax_orig.set_title(f'Original\nTrue: {class_name}', fontsize=10)
            ax_orig.axis('off')

            ax_ovr = axes[row_idx, 1]
            ax_ovr.imshow(overlay)
            ax_ovr.set_title(
                f'GradCAM Overlay\nPred: {classes[pred_cls]} '
                f'({pred_probs[pred_cls]*100:.1f}%)',
                fontsize=10
            )
            ax_ovr.axis('off')

            ax_cam = axes[row_idx, 2]
            im = ax_cam.imshow(cam, cmap='jet', vmin=0, vmax=1)
            ax_cam.set_title('CAM Heatmap', fontsize=10)
            ax_cam.axis('off')
            fig.colorbar(im, ax=ax_cam, fraction=0.046, pad=0.04)

        fig.tight_layout()

        if save_dir:
            safe_name = class_name.lower().replace(' ', '_')
            save_path = os.path.join(
                save_dir, f'gradcam_{model_name.lower().replace(" ", "_")}_{safe_name}.png'
            )
            fig.savefig(save_path, bbox_inches='tight', dpi=200)
            print(f"π GradCAM saved β {save_path}")

        plt.show()
        plt.close(fig)