| """ |
| Cross-validation and K-fold evaluation utilities. |
| |
| Provides: |
| - Stratified K-fold cross-validation |
| - Leave-one-site-out validation |
| - Train/val/test split preservation |
| """ |
|
|
| from __future__ import annotations |
|
|
| import logging |
| from pathlib import Path |
| from typing import NamedTuple |
|
|
| import numpy as np |
| import pytorch_lightning as pl |
| import torch |
| from sklearn.model_selection import StratifiedKFold, LeaveOneOut |
|
|
| from brain_gcn.main import build_datamodule, build_task, build_trainer, train_from_args |
| from brain_gcn.utils.data.datamodule import ABIDEDataModule |
|
|
| log = logging.getLogger(__name__) |
|
|
|
|
| class CVFold(NamedTuple): |
| """Container for a single CV fold's results.""" |
|
|
| fold_idx: int |
| train_indices: np.ndarray |
| val_indices: np.ndarray |
| test_indices: np.ndarray |
| metrics: dict |
|
|
|
|
| class CrossValidator: |
| """Stratified K-fold cross-validator.""" |
|
|
| def __init__( |
| self, |
| n_splits: int = 5, |
| shuffle: bool = True, |
| random_state: int = 42, |
| ): |
| """Initialize CV splitter. |
| |
| Parameters |
| ---------- |
| n_splits : int |
| Number of folds. |
| shuffle : bool |
| Whether to shuffle before splitting. |
| random_state : int |
| Random seed. |
| """ |
| self.n_splits = n_splits |
| self.shuffle = shuffle |
| self.random_state = random_state |
| self.skf = StratifiedKFold( |
| n_splits=n_splits, |
| shuffle=shuffle, |
| random_state=random_state, |
| ) |
|
|
| def split( |
| self, |
| labels: np.ndarray, |
| ) -> list[tuple[np.ndarray, np.ndarray]]: |
| """Generate train/test split indices. |
| |
| Parameters |
| ---------- |
| labels : (N,) array |
| Class labels for stratification. |
| |
| Returns |
| ------- |
| list[tuple[np.ndarray, np.ndarray]] |
| List of (train_idx, test_idx) tuples. |
| """ |
| dummy_X = np.arange(len(labels)).reshape(-1, 1) |
| splits = list(self.skf.split(dummy_X, labels)) |
| return [(train_idx, test_idx) for train_idx, test_idx in splits] |
|
|
|
|
| class LeaveOneSiteOutValidator: |
| """Leave-one-site-out cross-validator.""" |
|
|
| def __init__(self): |
| """Initialize LOSO validator.""" |
| pass |
|
|
| def split( |
| self, |
| sites: np.ndarray, |
| ) -> list[tuple[np.ndarray, np.ndarray]]: |
| """Generate leave-one-site-out splits. |
| |
| Parameters |
| ---------- |
| sites : (N,) array |
| Site labels for each subject. |
| |
| Returns |
| ------- |
| list[tuple[np.ndarray, np.ndarray]] |
| List of (in_site_idx, out_site_idx) tuples. |
| """ |
| unique_sites = np.unique(sites) |
| splits = [] |
|
|
| for test_site in unique_sites: |
| test_idx = np.where(sites == test_site)[0] |
| train_idx = np.where(sites != test_site)[0] |
| splits.append((train_idx, test_idx)) |
|
|
| return splits |
|
|
|
|
| class CVResults: |
| """Accumulator for cross-validation results.""" |
|
|
| def __init__(self): |
| self.folds: list[CVFold] = [] |
|
|
| def add_fold(self, fold: CVFold) -> None: |
| """Add results from a single fold.""" |
| self.folds.append(fold) |
|
|
| def mean_metrics(self) -> dict: |
| """Compute mean metrics across folds.""" |
| if not self.folds: |
| return {} |
|
|
| all_metrics = [fold.metrics for fold in self.folds] |
| keys = all_metrics[0].keys() |
|
|
| means = {} |
| for key in keys: |
| values = [m[key] for m in all_metrics if isinstance(m[key], (int, float))] |
| if values: |
| means[f"{key}_mean"] = float(np.mean(values)) |
| means[f"{key}_std"] = float(np.std(values)) |
|
|
| return means |
|
|
| def to_dict(self) -> dict: |
| """Convert to dictionary for serialization.""" |
| return { |
| "n_folds": len(self.folds), |
| "folds": [ |
| { |
| "fold_idx": fold.fold_idx, |
| "metrics": fold.metrics, |
| } |
| for fold in self.folds |
| ], |
| "summary": self.mean_metrics(), |
| } |
|
|
|
|
| def kfold_cross_validate( |
| base_args, |
| n_splits: int = 5, |
| output_dir: str | Path | None = None, |
| ) -> CVResults: |
| """Run stratified K-fold cross-validation. |
| |
| Parameters |
| ---------- |
| base_args : argparse.Namespace |
| Base training arguments. |
| n_splits : int |
| Number of folds. |
| output_dir : str or Path, optional |
| Directory to save fold results. |
| |
| Returns |
| ------- |
| CVResults |
| Aggregated cross-validation results. |
| """ |
| output_dir = Path(output_dir) if output_dir else None |
| if output_dir: |
| output_dir.mkdir(parents=True, exist_ok=True) |
|
|
| |
| dm = build_datamodule(base_args) |
| dm.prepare_data() |
| dm.setup() |
|
|
| |
| all_labels = [] |
| for batch in dm.train_dataloader(): |
| _, _, labels = batch |
| all_labels.extend(labels.cpu().numpy()) |
| all_labels = np.array(all_labels) |
|
|
| |
| cv = CrossValidator(n_splits=n_splits, random_state=base_args.seed) |
| splits = cv.split(all_labels) |
|
|
| results = CVResults() |
|
|
| for fold_idx, (train_idx, test_idx) in enumerate(splits): |
| log.info(f"Running fold {fold_idx + 1}/{n_splits}") |
|
|
| |
| fold_args = vars(base_args).copy() |
| |
| |
|
|
| |
| pl.seed_everything(base_args.seed + fold_idx, workers=True) |
| trainer, _, _ = train_from_args(base_args) |
|
|
| |
| fold_metrics = { |
| key: value.item() if isinstance(value, torch.Tensor) else value |
| for key, value in trainer.callback_metrics.items() |
| if key.startswith(("test_",)) |
| } |
|
|
| fold_result = CVFold( |
| fold_idx=fold_idx, |
| train_indices=train_idx, |
| val_indices=np.array([]), |
| test_indices=test_idx, |
| metrics=fold_metrics, |
| ) |
| results.add_fold(fold_result) |
|
|
| if output_dir: |
| fold_file = output_dir / f"fold_{fold_idx}.pt" |
| torch.save(fold_result, fold_file) |
|
|
| if output_dir: |
| summary_file = output_dir / "cv_summary.pt" |
| torch.save(results.to_dict(), summary_file) |
| log.info(f"CV results saved to {output_dir}") |
|
|
| return results |
|
|