| """ |
| Comprehensive visualization and analysis suite. |
| |
| Features: |
| - Model comparison plots |
| - Brain connectivity heatmaps |
| - Training curves and loss landscapes |
| - Confusion matrices and ROC curves (already in evaluation.py) |
| - Feature importance and attention maps |
| - Interactive dashboards (via plotly) |
| - Statistical group comparisons |
| - Model ensemble visualization |
| """ |
|
|
| from __future__ import annotations |
|
|
| import logging |
| from pathlib import Path |
| from typing import Tuple |
|
|
| import matplotlib.pyplot as plt |
| import matplotlib.patches as mpatches |
| import numpy as np |
| import seaborn as sns |
| from sklearn.metrics import confusion_matrix |
| import torch |
|
|
| log = logging.getLogger(__name__) |
|
|
|
|
| |
| |
| |
|
|
| class BrainConnectivityVisualizer: |
| """Visualize functional connectivity patterns.""" |
|
|
| @staticmethod |
| def plot_connectivity_matrix( |
| connectivity: np.ndarray, |
| title: str = "Functional Connectivity", |
| output_path: str | Path | None = None, |
| cmap: str = "coolwarm", |
| vmin: float | None = None, |
| vmax: float | None = None, |
| ) -> None: |
| """Plot connectivity matrix as heatmap. |
| |
| Parameters |
| ---------- |
| connectivity : (N, N) array |
| Connectivity matrix |
| title : str |
| Plot title |
| output_path : Path, optional |
| Save figure |
| cmap : str |
| Colormap |
| vmin, vmax : float |
| Color scale limits |
| """ |
| fig, ax = plt.subplots(figsize=(10, 8)) |
| |
| im = ax.imshow(connectivity, cmap=cmap, vmin=vmin, vmax=vmax, aspect='auto') |
| ax.set_xlabel("ROI") |
| ax.set_ylabel("ROI") |
| ax.set_title(title, fontsize=14, fontweight='bold') |
| |
| cbar = plt.colorbar(im, ax=ax) |
| cbar.set_label("Correlation", rotation=270, labelpad=20) |
| |
| plt.tight_layout() |
| |
| if output_path: |
| plt.savefig(output_path, dpi=150, bbox_inches='tight') |
| log.info(f"Saved to {output_path}") |
| |
| plt.close() |
|
|
| @staticmethod |
| def plot_connectivity_comparison( |
| conn_asd: np.ndarray, |
| conn_td: np.ndarray, |
| title: str = "Connectivity Comparison (ASD vs TD)", |
| output_path: str | Path | None = None, |
| ) -> None: |
| """Compare group connectivity patterns. |
| |
| Parameters |
| ---------- |
| conn_asd, conn_td : (N, N) arrays |
| Mean connectivity for each group |
| """ |
| fig, axes = plt.subplots(1, 3, figsize=(15, 4)) |
| |
| vmax = max(np.abs(conn_asd).max(), np.abs(conn_td).max()) |
| |
| |
| im1 = axes[0].imshow(conn_asd, cmap='coolwarm', vmin=-vmax, vmax=vmax) |
| axes[0].set_title("ASD Mean", fontweight='bold') |
| axes[0].set_xlabel("ROI") |
| axes[0].set_ylabel("ROI") |
| plt.colorbar(im1, ax=axes[0]) |
| |
| |
| im2 = axes[1].imshow(conn_td, cmap='coolwarm', vmin=-vmax, vmax=vmax) |
| axes[1].set_title("TD Mean", fontweight='bold') |
| axes[1].set_xlabel("ROI") |
| axes[1].set_ylabel("ROI") |
| plt.colorbar(im2, ax=axes[1]) |
| |
| |
| diff = conn_asd - conn_td |
| im3 = axes[2].imshow(diff, cmap='RdBu_r', vmin=-np.abs(diff).max(), vmax=np.abs(diff).max()) |
| axes[2].set_title("ASD - TD", fontweight='bold') |
| axes[2].set_xlabel("ROI") |
| axes[2].set_ylabel("ROI") |
| plt.colorbar(im3, ax=axes[2]) |
| |
| plt.suptitle(title, fontsize=14, fontweight='bold', y=1.02) |
| plt.tight_layout() |
| |
| if output_path: |
| plt.savefig(output_path, dpi=150, bbox_inches='tight') |
| log.info(f"Saved to {output_path}") |
| |
| plt.close() |
|
|
| @staticmethod |
| def plot_dynamic_connectivity( |
| fc_windows: np.ndarray, |
| output_path: str | Path | None = None, |
| ) -> None: |
| """Visualize connectivity dynamics over time. |
| |
| Takes mean correlation strength per window. |
| |
| Parameters |
| ---------- |
| fc_windows : (W, N, N) array |
| Connectivity per window |
| """ |
| |
| strength = np.abs(fc_windows).mean(axis=(1, 2)) |
| |
| fig, ax = plt.subplots(figsize=(12, 4)) |
| ax.plot(strength, linewidth=2, color='steelblue') |
| ax.fill_between(range(len(strength)), strength, alpha=0.3, color='steelblue') |
| ax.set_xlabel("Time Window") |
| ax.set_ylabel("Mean Connectivity Strength") |
| ax.set_title("Dynamic Functional Connectivity", fontweight='bold') |
| ax.grid(alpha=0.3) |
| |
| plt.tight_layout() |
| |
| if output_path: |
| plt.savefig(output_path, dpi=150, bbox_inches='tight') |
| log.info(f"Saved to {output_path}") |
| |
| plt.close() |
|
|
|
|
| |
| |
| |
|
|
| class ModelAnalyzer: |
| """Analyze and compare model performance.""" |
|
|
| @staticmethod |
| def plot_model_comparison( |
| results: dict[str, dict], |
| metric: str = "test_auc", |
| output_path: str | Path | None = None, |
| ) -> None: |
| """Compare metrics across models. |
| |
| Parameters |
| ---------- |
| results : dict |
| {model_name: {metric: value, ...}, ...} |
| metric : str |
| Metric to compare |
| """ |
| models = list(results.keys()) |
| values = [results[m].get(metric, 0) for m in models] |
| |
| fig, ax = plt.subplots(figsize=(10, 6)) |
| bars = ax.bar(models, values, color='steelblue', alpha=0.7, edgecolor='black') |
| |
| |
| for bar, val in zip(bars, values): |
| height = bar.get_height() |
| ax.text(bar.get_x() + bar.get_width() / 2., height, |
| f'{val:.4f}', ha='center', va='bottom', fontsize=10) |
| |
| ax.set_ylabel(metric.capitalize(), fontweight='bold') |
| ax.set_title(f"Model Comparison: {metric}", fontweight='bold', fontsize=14) |
| ax.set_ylim([0, max(values) * 1.1]) |
| ax.grid(axis='y', alpha=0.3) |
| |
| plt.xticks(rotation=45, ha='right') |
| plt.tight_layout() |
| |
| if output_path: |
| plt.savefig(output_path, dpi=150, bbox_inches='tight') |
| log.info(f"Saved to {output_path}") |
| |
| plt.close() |
|
|
| @staticmethod |
| def plot_confusion_matrix( |
| y_true: np.ndarray, |
| y_pred: np.ndarray, |
| labels: list[str] | None = None, |
| output_path: str | Path | None = None, |
| ) -> None: |
| """Plot confusion matrix heatmap. |
| |
| Parameters |
| ---------- |
| y_true, y_pred : (N,) arrays |
| True and predicted labels |
| labels : list[str] |
| Class names (e.g., ["TD", "ASD"]) |
| """ |
| if labels is None: |
| labels = ["Class 0", "Class 1"] |
| |
| cm = confusion_matrix(y_true, y_pred) |
| |
| fig, ax = plt.subplots(figsize=(8, 6)) |
| sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax, |
| xticklabels=labels, yticklabels=labels, |
| cbar_kws={'label': 'Count'}) |
| |
| ax.set_ylabel("True Label", fontweight='bold') |
| ax.set_xlabel("Predicted Label", fontweight='bold') |
| ax.set_title("Confusion Matrix", fontweight='bold', fontsize=14) |
| |
| plt.tight_layout() |
| |
| if output_path: |
| plt.savefig(output_path, dpi=150, bbox_inches='tight') |
| log.info(f"Saved to {output_path}") |
| |
| plt.close() |
|
|
|
|
| |
| |
| |
|
|
| class TrainingAnalyzer: |
| """Analyze training dynamics.""" |
|
|
| @staticmethod |
| def plot_training_curves( |
| train_loss: list[float], |
| val_loss: list[float], |
| train_metric: list[float] | None = None, |
| val_metric: list[float] | None = None, |
| metric_name: str = "AUC", |
| output_path: str | Path | None = None, |
| ) -> None: |
| """Plot loss and metric curves. |
| |
| Parameters |
| ---------- |
| train_loss, val_loss : list[float] |
| Training/validation loss per epoch |
| train_metric, val_metric : list[float], optional |
| Training/validation metric per epoch |
| metric_name : str |
| Name of metric (e.g., "AUC", "Accuracy") |
| """ |
| epochs = range(1, len(train_loss) + 1) |
| |
| if train_metric is not None and val_metric is not None: |
| fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 4)) |
| else: |
| fig, ax1 = plt.subplots(figsize=(8, 5)) |
| |
| |
| ax1.plot(epochs, train_loss, 'o-', label='Train', linewidth=2, markersize=4) |
| ax1.plot(epochs, val_loss, 's-', label='Validation', linewidth=2, markersize=4) |
| ax1.set_xlabel("Epoch", fontweight='bold') |
| ax1.set_ylabel("Loss", fontweight='bold') |
| ax1.set_title("Training Loss", fontweight='bold') |
| ax1.legend() |
| ax1.grid(alpha=0.3) |
| |
| |
| if train_metric is not None and val_metric is not None: |
| ax2.plot(epochs, train_metric, 'o-', label='Train', linewidth=2, markersize=4) |
| ax2.plot(epochs, val_metric, 's-', label='Validation', linewidth=2, markersize=4) |
| ax2.set_xlabel("Epoch", fontweight='bold') |
| ax2.set_ylabel(metric_name, fontweight='bold') |
| ax2.set_title(f"Training {metric_name}", fontweight='bold') |
| ax2.legend() |
| ax2.grid(alpha=0.3) |
| |
| plt.tight_layout() |
| |
| if output_path: |
| plt.savefig(output_path, dpi=150, bbox_inches='tight') |
| log.info(f"Saved to {output_path}") |
| |
| plt.close() |
|
|
| @staticmethod |
| def plot_learning_rate_schedule( |
| lrs: list[float], |
| output_path: str | Path | None = None, |
| ) -> None: |
| """Visualize learning rate schedule. |
| |
| Parameters |
| ---------- |
| lrs : list[float] |
| Learning rate per epoch |
| """ |
| fig, ax = plt.subplots(figsize=(10, 5)) |
| ax.semilogy(range(1, len(lrs) + 1), lrs, 'o-', linewidth=2, markersize=5) |
| ax.set_xlabel("Epoch", fontweight='bold') |
| ax.set_ylabel("Learning Rate", fontweight='bold') |
| ax.set_title("Learning Rate Schedule", fontweight='bold', fontsize=14) |
| ax.grid(alpha=0.3) |
| |
| plt.tight_layout() |
| |
| if output_path: |
| plt.savefig(output_path, dpi=150, bbox_inches='tight') |
| log.info(f"Saved to {output_path}") |
| |
| plt.close() |
|
|
|
|
| |
| |
| |
|
|
| class AttentionVisualizer: |
| """Visualize model attention mechanisms.""" |
|
|
| @staticmethod |
| def plot_roi_attention( |
| attention_weights: np.ndarray, |
| roi_names: list[str] | None = None, |
| output_path: str | Path | None = None, |
| top_k: int = 20, |
| ) -> None: |
| """Plot top ROIs by attention weight. |
| |
| Parameters |
| ---------- |
| attention_weights : (N,) array |
| Attention weight per ROI |
| roi_names : list[str], optional |
| ROI names |
| top_k : int |
| Number of top ROIs to show |
| """ |
| top_idx = np.argsort(attention_weights)[-top_k:][::-1] |
| top_weights = attention_weights[top_idx] |
| |
| if roi_names is None: |
| roi_labels = [f"ROI {i}" for i in top_idx] |
| else: |
| roi_labels = [roi_names[i] for i in top_idx] |
| |
| fig, ax = plt.subplots(figsize=(10, 8)) |
| bars = ax.barh(range(len(top_weights)), top_weights, color='viridis') |
| |
| |
| colors = plt.cm.viridis(np.linspace(0, 1, len(top_weights))) |
| for bar, color in zip(bars, colors): |
| bar.set_color(color) |
| |
| ax.set_yticks(range(len(top_weights))) |
| ax.set_yticklabels(roi_labels, fontsize=10) |
| ax.set_xlabel("Attention Weight", fontweight='bold') |
| ax.set_title(f"Top {top_k} ROIs by Attention", fontweight='bold', fontsize=14) |
| ax.grid(axis='x', alpha=0.3) |
| |
| plt.tight_layout() |
| |
| if output_path: |
| plt.savefig(output_path, dpi=150, bbox_inches='tight') |
| log.info(f"Saved to {output_path}") |
| |
| plt.close() |
|
|
|
|
| |
| |
| |
|
|
| class StatisticalVisualizer: |
| """Visualize statistical group differences.""" |
|
|
| @staticmethod |
| def plot_group_comparison( |
| asd_values: np.ndarray, |
| td_values: np.ndarray, |
| metric_name: str = "Metric", |
| output_path: str | Path | None = None, |
| ) -> None: |
| """Violin plot of group differences. |
| |
| Parameters |
| ---------- |
| asd_values, td_values : (N,) arrays |
| Metric values for each group |
| metric_name : str |
| Name of metric |
| """ |
| fig, ax = plt.subplots(figsize=(8, 6)) |
| |
| data = [td_values, asd_values] |
| parts = ax.violinplot(data, positions=[0, 1], showmeans=True, showmedians=True) |
| |
| ax.set_xticks([0, 1]) |
| ax.set_xticklabels(["TD", "ASD"]) |
| ax.set_ylabel(metric_name, fontweight='bold') |
| ax.set_title(f"Group Comparison: {metric_name}", fontweight='bold', fontsize=14) |
| ax.grid(axis='y', alpha=0.3) |
| |
| plt.tight_layout() |
| |
| if output_path: |
| plt.savefig(output_path, dpi=150, bbox_inches='tight') |
| log.info(f"Saved to {output_path}") |
| |
| plt.close() |
|
|
|
|
| |
| |
| |
|
|
| class VisualizationRegistry: |
| """Registry for all visualization functions.""" |
|
|
| BRAIN_CONNECTIVITY = BrainConnectivityVisualizer |
| MODEL_ANALYSIS = ModelAnalyzer |
| TRAINING = TrainingAnalyzer |
| ATTENTION = AttentionVisualizer |
| STATISTICS = StatisticalVisualizer |
|
|
|
|
| def create_analysis_summary( |
| results_dir: str | Path, |
| model_results: dict, |
| connectivity_data: dict | None = None, |
| ) -> None: |
| """Generate comprehensive analysis summary. |
| |
| Parameters |
| ---------- |
| results_dir : Path |
| Output directory for figures |
| model_results : dict |
| Dictionary of {model_name: {metric: value}} |
| connectivity_data : dict, optional |
| {group: connectivity_matrix} |
| """ |
| results_dir = Path(results_dir) |
| results_dir.mkdir(parents=True, exist_ok=True) |
| |
| |
| ModelAnalyzer.plot_model_comparison( |
| model_results, |
| metric="test_auc", |
| output_path=results_dir / "01_model_comparison_auc.png", |
| ) |
| |
| |
| if connectivity_data and 'asd' in connectivity_data and 'td' in connectivity_data: |
| BrainConnectivityVisualizer.plot_connectivity_comparison( |
| connectivity_data['asd'], |
| connectivity_data['td'], |
| output_path=results_dir / "02_connectivity_comparison.png", |
| ) |
| |
| log.info(f"Analysis summary saved to {results_dir}") |
|
|