# BrainConnect-ASD / brain_gcn/utils/visualization.py
# Uploaded by Yatsuiii using huggingface_hub (commit 16d6869, verified).
"""
Comprehensive visualization and analysis suite.
Features:
- Model comparison plots
- Brain connectivity heatmaps
- Training curves and loss landscapes
- Confusion matrices and ROC curves (already in evaluation.py)
- Feature importance and attention maps
- Interactive dashboards (via plotly)
- Statistical group comparisons
- Model ensemble visualization
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Tuple
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
import seaborn as sns
from sklearn.metrics import confusion_matrix
import torch
log = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Brain Connectivity Visualization
# ---------------------------------------------------------------------------
class BrainConnectivityVisualizer:
    """Visualize functional connectivity patterns.

    Each public method draws one matplotlib figure, writes it to
    ``output_path`` when one is given (150 dpi, tight bounding box) and
    always closes the figure afterwards -- also when plotting raises -- so
    long batch runs do not accumulate open figures.
    """

    @staticmethod
    def _save(fig: plt.Figure, output_path: str | Path | None) -> None:
        """Tighten the layout of *fig* and save it when *output_path* is truthy."""
        fig.tight_layout()
        if output_path:
            fig.savefig(output_path, dpi=150, bbox_inches='tight')
            log.info("Saved to %s", output_path)

    @staticmethod
    def _symmetric_limit(*arrays: np.ndarray) -> float:
        """Return ``m > 0`` such that ``[-m, m]`` covers every finite entry of *arrays*.

        Non-finite entries (NaN/inf) are ignored, so one missing value cannot
        turn the colour limits into NaN. When no finite non-zero value exists,
        the limit falls back to 1.0 instead of collapsing the scale to zero.
        """
        limit = 0.0
        for arr in arrays:
            magnitudes = np.abs(np.asarray(arr, dtype=float))
            magnitudes = magnitudes[np.isfinite(magnitudes)]
            if magnitudes.size:
                limit = max(limit, float(magnitudes.max()))
        return limit if limit > 0 else 1.0

    @staticmethod
    def _window_strength(
        fc_windows: np.ndarray,
        exclude_diagonal: bool = False,
    ) -> np.ndarray:
        """Return the mean absolute connectivity of each window in a 3-D stack.

        With ``exclude_diagonal=True`` the ``i == j`` (self-connection)
        entries are left out of each window's mean.

        Raises
        ------
        ValueError
            If *fc_windows* is not 3-D.
        """
        magnitudes = np.abs(np.asarray(fc_windows, dtype=float))
        if magnitudes.ndim != 3:
            raise ValueError(
                f"fc_windows must have shape (W, N, N), got {magnitudes.shape}"
            )
        if exclude_diagonal:
            off_diagonal = ~np.eye(magnitudes.shape[1], magnitudes.shape[2], dtype=bool)
            return magnitudes[:, off_diagonal].mean(axis=1)
        return magnitudes.mean(axis=(1, 2))

    @staticmethod
    def plot_connectivity_matrix(
        connectivity: np.ndarray,
        title: str = "Functional Connectivity",
        output_path: str | Path | None = None,
        cmap: str = "coolwarm",
        vmin: float | None = None,
        vmax: float | None = None,
    ) -> None:
        """Plot a connectivity matrix as a heatmap.

        Parameters
        ----------
        connectivity : (N, N) array
            Connectivity matrix.
        title : str
            Plot title.
        output_path : str or Path, optional
            Where to save the figure; nothing is written when omitted.
        cmap : str
            Matplotlib colormap name.
        vmin, vmax : float, optional
            Colour-scale limits; matplotlib derives them from the data when None.
        """
        fig, ax = plt.subplots(figsize=(10, 8))
        try:
            im = ax.imshow(connectivity, cmap=cmap, vmin=vmin, vmax=vmax, aspect='auto')
            ax.set_xlabel("ROI")
            ax.set_ylabel("ROI")
            ax.set_title(title, fontsize=14, fontweight='bold')
            cbar = fig.colorbar(im, ax=ax)
            cbar.set_label("Correlation", rotation=270, labelpad=20)
            BrainConnectivityVisualizer._save(fig, output_path)
        finally:
            plt.close(fig)

    @staticmethod
    def plot_connectivity_comparison(
        conn_asd: np.ndarray,
        conn_td: np.ndarray,
        title: str = "Connectivity Comparison (ASD vs TD)",
        output_path: str | Path | None = None,
    ) -> None:
        """Compare group connectivity patterns side by side.

        Draws three panels: ASD mean, TD mean (sharing one symmetric colour
        scale), and the ASD - TD difference (with its own symmetric scale).

        Parameters
        ----------
        conn_asd, conn_td : (N, N) arrays
            Mean connectivity for each group; must have identical shapes.
        title : str
            Figure super-title.
        output_path : str or Path, optional
            Where to save the figure; nothing is written when omitted.

        Raises
        ------
        ValueError
            If the two matrices do not have the same shape.
        """
        conn_asd = np.asarray(conn_asd, dtype=float)
        conn_td = np.asarray(conn_td, dtype=float)
        if conn_asd.shape != conn_td.shape:
            raise ValueError(
                "Group connectivity matrices must share a shape, got "
                f"{conn_asd.shape} (ASD) and {conn_td.shape} (TD)"
            )
        diff = conn_asd - conn_td
        # Both group panels share one scale so their colours are comparable.
        group_lim = BrainConnectivityVisualizer._symmetric_limit(conn_asd, conn_td)
        diff_lim = BrainConnectivityVisualizer._symmetric_limit(diff)
        panels = (
            (conn_asd, "ASD Mean", 'coolwarm', group_lim),
            (conn_td, "TD Mean", 'coolwarm', group_lim),
            (diff, "ASD - TD", 'RdBu_r', diff_lim),
        )

        fig, axes = plt.subplots(1, 3, figsize=(15, 4))
        try:
            for ax, (matrix, panel_title, cmap, lim) in zip(axes, panels):
                im = ax.imshow(matrix, cmap=cmap, vmin=-lim, vmax=lim)
                ax.set_title(panel_title, fontweight='bold')
                ax.set_xlabel("ROI")
                ax.set_ylabel("ROI")
                fig.colorbar(im, ax=ax)
            fig.suptitle(title, fontsize=14, fontweight='bold', y=1.02)
            BrainConnectivityVisualizer._save(fig, output_path)
        finally:
            plt.close(fig)

    @staticmethod
    def plot_dynamic_connectivity(
        fc_windows: np.ndarray,
        output_path: str | Path | None = None,
        exclude_diagonal: bool = False,
    ) -> None:
        """Visualize connectivity dynamics over time.

        Plots the mean absolute correlation strength of each window.

        Parameters
        ----------
        fc_windows : (W, N, N) array
            Connectivity per window.
        output_path : str or Path, optional
            Where to save the figure; nothing is written when omitted.
        exclude_diagonal : bool
            Leave self-connections out of each window's mean (default False
            keeps the historical behaviour of averaging the full matrix).

        Raises
        ------
        ValueError
            If *fc_windows* is not 3-D.
        """
        strength = BrainConnectivityVisualizer._window_strength(fc_windows, exclude_diagonal)
        windows = np.arange(len(strength))

        fig, ax = plt.subplots(figsize=(12, 4))
        try:
            ax.plot(windows, strength, linewidth=2, color='steelblue')
            ax.fill_between(windows, strength, alpha=0.3, color='steelblue')
            ax.set_xlabel("Time Window")
            ax.set_ylabel("Mean Connectivity Strength")
            ax.set_title("Dynamic Functional Connectivity", fontweight='bold')
            ax.grid(alpha=0.3)
            BrainConnectivityVisualizer._save(fig, output_path)
        finally:
            plt.close(fig)
# ---------------------------------------------------------------------------
# Model Analysis & Comparison
# ---------------------------------------------------------------------------
class ModelAnalyzer:
    """Analyze and compare model performance.

    Each public method draws one matplotlib figure, writes it to
    ``output_path`` when one is given (150 dpi, tight bounding box) and
    always closes the figure afterwards, also when plotting raises.
    """

    @staticmethod
    def _save(fig: plt.Figure, output_path: str | Path | None) -> None:
        """Tighten the layout of *fig* and save it when *output_path* is truthy."""
        fig.tight_layout()
        if output_path:
            fig.savefig(output_path, dpi=150, bbox_inches='tight')
            log.info("Saved to %s", output_path)

    @staticmethod
    def _value_axis_limits(values: list[float]) -> Tuple[float, float]:
        """Return ``(low, high)`` y-limits for a bar chart of *values*.

        The axis always includes 0 and is padded by 10 % of the data span on
        each side where bars extend; for all-positive data this is
        ``(0, 1.1 * max)``. Non-finite values are ignored, and an empty or
        all-zero span falls back to ``(0, 1)`` instead of a degenerate axis.
        """
        finite = [float(v) for v in values if np.isfinite(v)]
        low = min([0.0, *finite])
        high = max([0.0, *finite])
        span = high - low
        if span == 0:
            return 0.0, 1.0
        pad = 0.1 * span
        return (low - pad if low < 0 else 0.0), (high + pad if high > 0 else 0.0)

    @staticmethod
    def _class_ids(y_true: np.ndarray, y_pred: np.ndarray, n_labels: int) -> np.ndarray:
        """Return the class values that index the confusion-matrix rows/columns.

        When every observed value is a non-negative integer below *n_labels*,
        the full range ``0 .. n_labels-1`` is returned, so a class that never
        occurs still gets its (all-zero) row and the matrix keeps lining up
        with the tick labels. Otherwise the sorted observed values are returned.
        """
        observed = np.unique(np.concatenate([np.ravel(y_true), np.ravel(y_pred)]))
        if (
            observed.size
            and np.issubdtype(observed.dtype, np.number)
            and np.all(observed == np.round(observed))
            and observed.min() >= 0
            and observed.max() < n_labels
        ):
            return np.arange(n_labels)
        return observed

    @staticmethod
    def plot_model_comparison(
        results: dict[str, dict],
        metric: str = "test_auc",
        output_path: str | Path | None = None,
    ) -> None:
        """Compare one metric across models as a labelled bar chart.

        Parameters
        ----------
        results : dict
            ``{model_name: {metric: value, ...}, ...}``. A model lacking
            *metric* is plotted as 0 and reported with a warning.
        metric : str
            Metric to compare.
        output_path : str or Path, optional
            Where to save the figure; nothing is written when omitted.

        Raises
        ------
        ValueError
            If *results* is empty.
        """
        models = list(results)
        if not models:
            raise ValueError("results is empty: no models to compare")
        missing = [m for m in models if metric not in results[m]]
        if missing:
            log.warning(
                "Metric %r missing for %s; plotting it as 0",
                metric, ", ".join(map(str, missing)),
            )
        values = [float(results[m].get(metric, 0)) for m in models]

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            bars = ax.bar(models, values, color='steelblue', alpha=0.7, edgecolor='black')
            # Value label at the end of each bar (below it for negative bars).
            for bar, val in zip(bars, values):
                ax.text(
                    bar.get_x() + bar.get_width() / 2.0,
                    bar.get_height(),
                    f'{val:.4f}',
                    ha='center',
                    va='bottom' if val >= 0 else 'top',
                    fontsize=10,
                )
            ax.set_ylabel(metric.capitalize(), fontweight='bold')
            ax.set_title(f"Model Comparison: {metric}", fontweight='bold', fontsize=14)
            ax.set_ylim(ModelAnalyzer._value_axis_limits(values))
            ax.grid(axis='y', alpha=0.3)
            plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
            ModelAnalyzer._save(fig, output_path)
        finally:
            plt.close(fig)

    @staticmethod
    def plot_confusion_matrix(
        y_true: np.ndarray,
        y_pred: np.ndarray,
        labels: list[str] | None = None,
        output_path: str | Path | None = None,
    ) -> None:
        """Plot a confusion-matrix heatmap with raw counts.

        Parameters
        ----------
        y_true, y_pred : (N,) arrays
            True and predicted labels.
        labels : list[str], optional
            Class names in class-id order (e.g. ``["TD", "ASD"]``); defaults to
            ``["Class 0", "Class 1"]``. If the number of classes found does not
            match, the raw class values are used as tick labels instead.
        output_path : str or Path, optional
            Where to save the figure; nothing is written when omitted.
        """
        if labels is None:
            labels = ["Class 0", "Class 1"]
        class_ids = ModelAnalyzer._class_ids(y_true, y_pred, len(labels))
        cm = confusion_matrix(y_true, y_pred, labels=class_ids)
        tick_labels = list(labels)
        if len(class_ids) != len(tick_labels):
            log.warning(
                "Found %d classes but %d labels; using raw class values as tick labels",
                len(class_ids), len(tick_labels),
            )
            tick_labels = [str(c) for c in class_ids]

        fig, ax = plt.subplots(figsize=(8, 6))
        try:
            sns.heatmap(
                cm, annot=True, fmt='d', cmap='Blues', ax=ax,
                xticklabels=tick_labels, yticklabels=tick_labels,
                cbar_kws={'label': 'Count'},
            )
            ax.set_ylabel("True Label", fontweight='bold')
            ax.set_xlabel("Predicted Label", fontweight='bold')
            ax.set_title("Confusion Matrix", fontweight='bold', fontsize=14)
            ModelAnalyzer._save(fig, output_path)
        finally:
            plt.close(fig)
# ---------------------------------------------------------------------------
# Training Analysis
# ---------------------------------------------------------------------------
class TrainingAnalyzer:
    """Analyze training dynamics.

    Each public method draws one matplotlib figure, writes it to
    ``output_path`` when one is given (150 dpi, tight bounding box) and
    always closes the figure afterwards, also when plotting raises.
    """

    @staticmethod
    def _save(fig: plt.Figure, output_path: str | Path | None) -> None:
        """Tighten the layout of *fig* and save it when *output_path* is truthy."""
        fig.tight_layout()
        if output_path:
            fig.savefig(output_path, dpi=150, bbox_inches='tight')
            log.info("Saved to %s", output_path)

    @staticmethod
    def _epochs(series: list[float]) -> range:
        """Return 1-based epoch numbers matching the length of *series*."""
        return range(1, len(series) + 1)

    @staticmethod
    def _use_log_scale(lrs: list[float]) -> bool:
        """Return True when every learning rate is positive, so a log axis shows them all."""
        return all(lr > 0 for lr in lrs)

    @staticmethod
    def plot_training_curves(
        train_loss: list[float],
        val_loss: list[float],
        train_metric: list[float] | None = None,
        val_metric: list[float] | None = None,
        metric_name: str = "AUC",
        output_path: str | Path | None = None,
    ) -> None:
        """Plot loss curves and, when both metric series are given, metric curves.

        Each series is plotted against its own 1-based epoch index, so train
        and validation histories of different lengths are drawn (with a
        warning) instead of failing inside matplotlib.

        Parameters
        ----------
        train_loss, val_loss : list[float]
            Training/validation loss per epoch.
        train_metric, val_metric : list[float], optional
            Training/validation metric per epoch; the metric panel is drawn
            only when both are provided.
        metric_name : str
            Name of the metric (e.g. "AUC", "Accuracy").
        output_path : str or Path, optional
            Where to save the figure; nothing is written when omitted.
        """
        show_metric = train_metric is not None and val_metric is not None
        pairs = [("loss", train_loss, val_loss)]
        if show_metric:
            pairs.append((metric_name, train_metric, val_metric))
        for name, train_series, val_series in pairs:
            if len(train_series) != len(val_series):
                log.warning(
                    "Train %s has %d entries but validation has %d; each curve "
                    "is plotted against its own epoch index",
                    name, len(train_series), len(val_series),
                )

        if show_metric:
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 4))
        else:
            fig, ax1 = plt.subplots(figsize=(8, 5))
        epochs = TrainingAnalyzer._epochs
        try:
            ax1.plot(epochs(train_loss), train_loss, 'o-', label='Train', linewidth=2, markersize=4)
            ax1.plot(epochs(val_loss), val_loss, 's-', label='Validation', linewidth=2, markersize=4)
            ax1.set_xlabel("Epoch", fontweight='bold')
            ax1.set_ylabel("Loss", fontweight='bold')
            ax1.set_title("Training Loss", fontweight='bold')
            ax1.legend()
            ax1.grid(alpha=0.3)

            if show_metric:
                ax2.plot(epochs(train_metric), train_metric, 'o-', label='Train', linewidth=2, markersize=4)
                ax2.plot(epochs(val_metric), val_metric, 's-', label='Validation', linewidth=2, markersize=4)
                ax2.set_xlabel("Epoch", fontweight='bold')
                ax2.set_ylabel(metric_name, fontweight='bold')
                ax2.set_title(f"Training {metric_name}", fontweight='bold')
                ax2.legend()
                ax2.grid(alpha=0.3)

            TrainingAnalyzer._save(fig, output_path)
        finally:
            plt.close(fig)

    @staticmethod
    def plot_learning_rate_schedule(
        lrs: list[float],
        output_path: str | Path | None = None,
    ) -> None:
        """Visualize a learning-rate schedule.

        A log y-axis is used when all rates are positive. Otherwise a linear
        axis is used, because a log axis cannot display zero rates (for
        example, warm-up starting at 0 or cosine decay down to 0).

        Parameters
        ----------
        lrs : list[float]
            Learning rate per epoch.
        output_path : str or Path, optional
            Where to save the figure; nothing is written when omitted.
        """
        fig, ax = plt.subplots(figsize=(10, 5))
        try:
            ax.plot(TrainingAnalyzer._epochs(lrs), lrs, 'o-', linewidth=2, markersize=5)
            if TrainingAnalyzer._use_log_scale(lrs):
                ax.set_yscale('log')
            else:
                log.warning("Non-positive learning rates present; using a linear y-axis")
            ax.set_xlabel("Epoch", fontweight='bold')
            ax.set_ylabel("Learning Rate", fontweight='bold')
            ax.set_title("Learning Rate Schedule", fontweight='bold', fontsize=14)
            ax.grid(alpha=0.3)
            TrainingAnalyzer._save(fig, output_path)
        finally:
            plt.close(fig)
# ---------------------------------------------------------------------------
# Attention & Feature Importance
# ---------------------------------------------------------------------------
class AttentionVisualizer:
    """Visualize model attention mechanisms."""

    @staticmethod
    def _top_indices(weights: np.ndarray, top_k: int) -> np.ndarray:
        """Return the indices of the ``top_k`` largest *weights*, largest first.

        ``top_k`` is clamped to ``[0, len(weights)]``, so 0 selects nothing
        (a bare ``[-0:]`` slice would select everything). NaN weights rank
        below every real weight.
        """
        values = np.asarray(weights, dtype=float)
        ranking = np.where(np.isnan(values), -np.inf, values)
        k = max(0, min(int(top_k), values.size))
        return np.argsort(ranking)[::-1][:k]

    @staticmethod
    def plot_roi_attention(
        attention_weights: np.ndarray,
        roi_names: list[str] | None = None,
        output_path: str | Path | None = None,
        top_k: int = 20,
    ) -> None:
        """Plot the top ROIs by attention weight as a horizontal bar chart.

        Parameters
        ----------
        attention_weights : (N,) array
            Attention weight per ROI. Singleton dimensions, e.g. ``(1, N)``,
            are squeezed away.
        roi_names : list[str], optional
            ROI names indexed like *attention_weights*; defaults to "ROI i".
        output_path : str or Path, optional
            Where to save the figure; nothing is written when omitted.
        top_k : int
            Number of top ROIs to show (capped at the number of ROIs).

        Raises
        ------
        ValueError
            If the weights are not one-dimensional after squeezing.
        """
        weights = np.atleast_1d(np.squeeze(np.asarray(attention_weights, dtype=float)))
        if weights.ndim != 1:
            raise ValueError(
                f"attention_weights must be 1-D (one weight per ROI), got shape {weights.shape}"
            )
        top_idx = AttentionVisualizer._top_indices(weights, top_k)
        top_weights = weights[top_idx]
        if roi_names is None:
            roi_labels = [f"ROI {i}" for i in top_idx]
        else:
            roi_labels = [roi_names[i] for i in top_idx]

        # barh needs concrete colours; a colormap name such as 'viridis' is rejected.
        colors = plt.cm.viridis(np.linspace(0, 1, len(top_weights)))
        positions = np.arange(len(top_weights))

        fig, ax = plt.subplots(figsize=(10, 8))
        try:
            ax.barh(positions, top_weights, color=colors)
            ax.set_yticks(positions)
            ax.set_yticklabels(roi_labels, fontsize=10)
            # Show rank 1 (largest weight) at the top of the chart.
            ax.invert_yaxis()
            ax.set_xlabel("Attention Weight", fontweight='bold')
            ax.set_title(f"Top {len(top_idx)} ROIs by Attention", fontweight='bold', fontsize=14)
            ax.grid(axis='x', alpha=0.3)
            fig.tight_layout()
            if output_path:
                fig.savefig(output_path, dpi=150, bbox_inches='tight')
                log.info("Saved to %s", output_path)
        finally:
            plt.close(fig)
# ---------------------------------------------------------------------------
# Statistical Visualization
# ---------------------------------------------------------------------------
class StatisticalVisualizer:
    """Visualize statistical group differences."""

    @staticmethod
    def _finite(values: np.ndarray) -> np.ndarray:
        """Return *values* as a flat float array with NaN/inf entries removed."""
        arr = np.asarray(values, dtype=float).ravel()
        return arr[np.isfinite(arr)]

    @staticmethod
    def plot_group_comparison(
        asd_values: np.ndarray,
        td_values: np.ndarray,
        metric_name: str = "Metric",
        output_path: str | Path | None = None,
    ) -> None:
        """Draw a violin plot comparing TD (left) and ASD (right) values.

        Non-finite values are dropped with a warning, because they would
        break the kernel density estimate behind each violin.

        Parameters
        ----------
        asd_values, td_values : (N,) arrays
            Metric values for each group.
        metric_name : str
            Name of the metric (y-axis label and title).
        output_path : str or Path, optional
            Where to save the figure; nothing is written when omitted.

        Raises
        ------
        ValueError
            If a group has no finite values.
        """
        groups = []
        for name, raw in (("TD", td_values), ("ASD", asd_values)):
            clean = StatisticalVisualizer._finite(raw)
            n_dropped = np.size(raw) - clean.size
            if n_dropped:
                log.warning(
                    "Dropped %d non-finite %s value(s) for %s",
                    n_dropped, name, metric_name,
                )
            if clean.size == 0:
                raise ValueError(f"No finite {name} values to plot for {metric_name}")
            groups.append(clean)

        fig, ax = plt.subplots(figsize=(8, 6))
        try:
            ax.violinplot(groups, positions=[0, 1], showmeans=True, showmedians=True)
            ax.set_xticks([0, 1])
            ax.set_xticklabels(["TD", "ASD"])
            ax.set_ylabel(metric_name, fontweight='bold')
            ax.set_title(f"Group Comparison: {metric_name}", fontweight='bold', fontsize=14)
            ax.grid(axis='y', alpha=0.3)
            fig.tight_layout()
            if output_path:
                fig.savefig(output_path, dpi=150, bbox_inches='tight')
                log.info("Saved to %s", output_path)
        finally:
            plt.close(fig)
# ---------------------------------------------------------------------------
# Visualization Registry
# ---------------------------------------------------------------------------
class VisualizationRegistry:
    """Single namespace that exposes every visualizer class.

    Each attribute holds the visualizer class itself (not an instance), so
    its static plotting methods are reachable directly, e.g.
    ``VisualizationRegistry.TRAINING.plot_learning_rate_schedule(lrs)``.
    """

    # Connectivity heatmaps, group comparison and dynamic-FC traces.
    BRAIN_CONNECTIVITY: type[BrainConnectivityVisualizer] = BrainConnectivityVisualizer
    # Cross-model metric bars and confusion matrices.
    MODEL_ANALYSIS: type[ModelAnalyzer] = ModelAnalyzer
    # Loss/metric curves and learning-rate schedules.
    TRAINING: type[TrainingAnalyzer] = TrainingAnalyzer
    # Ranked ROI attention weights.
    ATTENTION: type[AttentionVisualizer] = AttentionVisualizer
    # TD vs ASD violin plots.
    STATISTICS: type[StatisticalVisualizer] = StatisticalVisualizer
def _find_group(connectivity_data: dict, group: str):
    """Return the entry of *connectivity_data* for *group*, matching keys case-insensitively.

    An exact key match wins; otherwise the first string key equal to *group*
    ignoring case is used. Returns None when no key matches.
    """
    if group in connectivity_data:
        return connectivity_data[group]
    wanted = group.lower()
    for key, value in connectivity_data.items():
        if isinstance(key, str) and key.lower() == wanted:
            return value
    return None


def create_analysis_summary(
    results_dir: str | Path,
    model_results: dict,
    connectivity_data: dict | None = None,
    metric: str = "test_auc",
) -> None:
    """Generate a standard set of analysis figures in *results_dir*.

    Parameters
    ----------
    results_dir : str or Path
        Output directory for figures; created (with parents) if missing.
    model_results : dict
        Dictionary of ``{model_name: {metric: value}}``.
    connectivity_data : dict, optional
        ``{group: connectivity_matrix}``. When it contains both an ASD and a
        TD entry (keys matched case-insensitively, e.g. ``'asd'`` or
        ``'ASD'``), a group connectivity comparison is also saved. A partial
        mapping is skipped with a warning.
    metric : str
        Metric compared across models. The output file is named
        ``01_model_comparison_<name>.png``, where ``<name>`` is *metric*
        without a leading ``"test_"``; the default therefore produces
        ``01_model_comparison_auc.png``.
    """
    results_dir = Path(results_dir)
    results_dir.mkdir(parents=True, exist_ok=True)

    # Model comparison
    suffix = metric[len("test_"):] if metric.startswith("test_") else metric
    ModelAnalyzer.plot_model_comparison(
        model_results,
        metric=metric,
        output_path=results_dir / f"01_model_comparison_{suffix}.png",
    )

    # Connectivity comparison if both groups are provided
    if connectivity_data:
        conn_asd = _find_group(connectivity_data, "asd")
        conn_td = _find_group(connectivity_data, "td")
        if conn_asd is not None and conn_td is not None:
            BrainConnectivityVisualizer.plot_connectivity_comparison(
                conn_asd,
                conn_td,
                output_path=results_dir / "02_connectivity_comparison.png",
            )
        else:
            log.warning(
                "connectivity_data needs both ASD and TD entries (got keys %s); "
                "skipping connectivity comparison",
                sorted(map(str, connectivity_data)),
            )

    log.info("Analysis summary saved to %s", results_dir)