BrainConnect-ASD / brain_gcn /utils /cross_validation.py
Yatsuiii's picture
Upload folder using huggingface_hub
16d6869 verified
"""
Cross-validation and K-fold evaluation utilities.
Provides:
- Stratified K-fold cross-validation
- Leave-one-site-out validation
- Train/val/test split preservation
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import NamedTuple
import numpy as np
import pytorch_lightning as pl
import torch
from sklearn.model_selection import StratifiedKFold, LeaveOneOut
from brain_gcn.main import build_datamodule, build_task, build_trainer, train_from_args
from brain_gcn.utils.data.datamodule import ABIDEDataModule
log = logging.getLogger(__name__)
class CVFold(NamedTuple):
    """Container for a single CV fold's results."""
    # 0-based index of this fold within the CV run.
    fold_idx: int
    # Subject indices used for training in this fold.
    train_indices: np.ndarray
    # Subject indices used for validation (left empty in plain K-fold; see
    # kfold_cross_validate, which passes np.array([])).
    val_indices: np.ndarray
    # Subject indices held out for evaluation in this fold.
    test_indices: np.ndarray
    # Metric name -> value mapping collected after training.
    metrics: dict  # {'test_auc': ..., 'test_acc': ...}
class CrossValidator:
    """Stratified K-fold cross-validator.

    Thin wrapper around :class:`sklearn.model_selection.StratifiedKFold`
    that splits on labels alone, so callers do not need a feature matrix.
    """

    def __init__(
        self,
        n_splits: int = 5,
        shuffle: bool = True,
        random_state: int = 42,
    ):
        """Initialize CV splitter.

        Parameters
        ----------
        n_splits : int
            Number of folds.
        shuffle : bool
            Whether to shuffle before splitting.
        random_state : int
            Random seed. Only forwarded to sklearn when ``shuffle`` is
            True (sklearn raises a ValueError otherwise).
        """
        self.n_splits = n_splits
        self.shuffle = shuffle
        self.random_state = random_state
        self.skf = StratifiedKFold(
            n_splits=n_splits,
            shuffle=shuffle,
            # StratifiedKFold rejects a non-None random_state when
            # shuffle=False; pass it through only when it has an effect.
            random_state=random_state if shuffle else None,
        )

    def split(
        self,
        labels: np.ndarray,
    ) -> list[tuple[np.ndarray, np.ndarray]]:
        """Generate train/test split indices.

        Parameters
        ----------
        labels : (N,) array
            Class labels for stratification.

        Returns
        -------
        list[tuple[np.ndarray, np.ndarray]]
            List of (train_idx, test_idx) tuples.
        """
        # StratifiedKFold only inspects y for stratification; a dummy
        # single-column X satisfies its signature.
        dummy_X = np.arange(len(labels)).reshape(-1, 1)
        return list(self.skf.split(dummy_X, labels))
class LeaveOneSiteOutValidator:
    """Leave-one-site-out cross-validator."""

    def __init__(self):
        """Initialize LOSO validator."""

    def split(
        self,
        sites: np.ndarray,
    ) -> list[tuple[np.ndarray, np.ndarray]]:
        """Generate leave-one-site-out splits.

        Parameters
        ----------
        sites : (N,) array
            Site labels for each subject.

        Returns
        -------
        list[tuple[np.ndarray, np.ndarray]]
            List of (in_site_idx, out_site_idx) tuples.
        """
        # One split per unique site (np.unique yields them sorted): the
        # held-out site becomes the test set, all other subjects train.
        return [
            (
                np.flatnonzero(sites != held_out),
                np.flatnonzero(sites == held_out),
            )
            for held_out in np.unique(sites)
        ]
class CVResults:
    """Accumulator for cross-validation results."""

    def __init__(self):
        # Folds in the order they were added.
        self.folds: list[CVFold] = []

    def add_fold(self, fold: CVFold) -> None:
        """Add results from a single fold."""
        self.folds.append(fold)

    def mean_metrics(self) -> dict:
        """Compute mean and std of each metric across folds.

        Returns
        -------
        dict
            ``{'<metric>_mean': float, '<metric>_std': float}`` for every
            numeric metric observed in *any* fold. Non-numeric values are
            skipped, and folds missing a metric are simply excluded from
            that metric's aggregate (the previous implementation indexed
            every fold with the first fold's keys and raised KeyError on
            a missing one).
        """
        if not self.folds:
            return {}
        all_metrics = [fold.metrics for fold in self.folds]
        # Union of keys across all folds, first-seen order preserved.
        keys = dict.fromkeys(key for m in all_metrics for key in m)
        means = {}
        for key in keys:
            values = [
                m[key]
                for m in all_metrics
                if key in m and isinstance(m[key], (int, float))
            ]
            if values:
                means[f"{key}_mean"] = float(np.mean(values))
                means[f"{key}_std"] = float(np.std(values))
        return means

    def to_dict(self) -> dict:
        """Convert to dictionary for serialization."""
        return {
            "n_folds": len(self.folds),
            "folds": [
                {
                    "fold_idx": fold.fold_idx,
                    "metrics": fold.metrics,
                }
                for fold in self.folds
            ],
            "summary": self.mean_metrics(),
        }
def kfold_cross_validate(
    base_args,
    n_splits: int = 5,
    output_dir: str | Path | None = None,
) -> CVResults:
    """Run stratified K-fold cross-validation.

    Parameters
    ----------
    base_args : argparse.Namespace
        Base training arguments.
    n_splits : int
        Number of folds.
    output_dir : str or Path, optional
        Directory to save fold results.

    Returns
    -------
    CVResults
        Aggregated cross-validation results.

    Notes
    -----
    The fold indices are computed and recorded in the results, but the
    datamodule does not yet accept external splits, so each fold trains
    on the datamodule's own internal split (re-seeded per fold).
    """
    output_dir = Path(output_dir) if output_dir else None
    if output_dir:
        output_dir.mkdir(parents=True, exist_ok=True)
    # Build the data module once, solely to obtain labels for stratification.
    dm = build_datamodule(base_args)
    dm.prepare_data()
    dm.setup()
    # Collect labels from the training loader; batches unpack as (_, _, labels).
    all_labels = []
    for batch in dm.train_dataloader():
        _, _, labels = batch
        all_labels.extend(labels.cpu().numpy())
    all_labels = np.array(all_labels)
    # Stratified splits, seeded from the base run seed for reproducibility.
    cv = CrossValidator(n_splits=n_splits, random_state=base_args.seed)
    splits = cv.split(all_labels)
    results = CVResults()
    for fold_idx, (train_idx, test_idx) in enumerate(splits):
        log.info(f"Running fold {fold_idx + 1}/{n_splits}")
        # Note: For full implementation, would need to modify datamodule
        # to accept external train/test splits. For now, train normally
        # with a per-fold seed so runs are not identical.
        pl.seed_everything(base_args.seed + fold_idx, workers=True)
        trainer, _, _ = train_from_args(base_args)
        # Keep only scalar test-phase metrics reported by the trainer.
        fold_metrics = {
            key: value.item() if isinstance(value, torch.Tensor) else value
            for key, value in trainer.callback_metrics.items()
            if key.startswith("test_")
        }
        fold_result = CVFold(
            fold_idx=fold_idx,
            train_indices=train_idx,
            val_indices=np.array([]),  # Not used in standard K-fold
            test_indices=test_idx,
            metrics=fold_metrics,
        )
        results.add_fold(fold_result)
        if output_dir:
            # Persist per-fold results for post-hoc analysis.
            fold_file = output_dir / f"fold_{fold_idx}.pt"
            torch.save(fold_result, fold_file)
    if output_dir:
        summary_file = output_dir / "cv_summary.pt"
        torch.save(results.to_dict(), summary_file)
        log.info(f"CV results saved to {output_dir}")
    return results