# BrainConnect-ASD: brain_gcn/utils/data/datamodule.py
"""
PyTorch Lightning DataModule for ABIDE I.
Full pipeline (called once via prepare_data / setup):
1. Download ABIDE via nilearn (download.py)
2. Preprocess subjects → .npz cache (preprocess.py)
3. Stratified train / val / test split
4. Build population adjacency from training subjects (functional_connectivity.py)
5. Expose train / val / test DataLoaders
Usage:
dm = ABIDEDataModule(data_dir="data", n_subjects=100)
dm.prepare_data()
dm.setup()
for bold_windows, adj, label in dm.train_dataloader():
...
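Site-holdout usage (a sketch; available site names depend on the downloaded subset,
e.g. the "UCLA_1,UCLA_2" example from _site_holdout_split):
    dm = ABIDEDataModule(data_dir="data", split_strategy="site_holdout",
                         test_site="UCLA_1,UCLA_2")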
"""
from __future__ import annotations
import argparse
import logging
from collections import Counter
from pathlib import Path
import numpy as np
import pytorch_lightning as pl
import torch
from sklearn.model_selection import StratifiedShuffleSplit
from torch.utils.data import DataLoader
from .dataset import ABIDEDataset
from .download import fetch_abide, extract_subjects
from .functional_connectivity import compute_population_adj
from .preprocess import preprocess_all
log = logging.getLogger(__name__)
def collate_fn(batch):
"""
    Custom collate: stack bold_windows, adjacency matrices, labels, and site_ids.
Returns:
bold_windows : (B, W, N)
adj : (B, N, N)
labels : (B,)
site_ids : (B,)
"""
    bold_windows_list, adjs, labels, site_ids = zip(*batch)
    return (
        torch.stack(bold_windows_list),
        torch.stack(adjs),
        torch.stack(labels),
        torch.stack(site_ids),
    )
class ABIDEDataModule(pl.LightningDataModule):
def __init__(
self,
data_dir: str = "data",
n_subjects: int | None = None,
window_len: int = 50,
step: int = 5,
max_windows: int | None = 30,
fc_threshold: float = 0.2,
use_dynamic_adj: bool = False,
use_dynamic_adj_sequence: bool = False,
use_population_adj: bool = True,
preserve_fc_sign: bool = False,
use_fc_variance: bool = False,
use_fisher_z: bool = False,
use_fc_degree_features: bool = False,
use_fc_row_features: bool = False,
n_pca_components: int = 0,
batch_size: int = 32,
val_ratio: float = 0.1,
test_ratio: float = 0.1,
split_strategy: str = "stratified",
val_site: str | None = None,
test_site: str | None = None,
num_workers: int = 4,
overwrite_cache: bool = False,
force_prepare: bool = False,
):
"""
Parameters
----------
data_dir : root directory for raw + processed data
n_subjects : cap for ABIDE download (None = all ~884)
window_len : sliding window length in TRs
step : sliding window step in TRs
max_windows : truncate each subject to this many windows
(ensures uniform batch shapes without padding)
fc_threshold : sparsify FC: zero edges with |fc| < threshold
use_dynamic_adj : per-subject: use mean of window FCs (vs. full-scan FC)
use_dynamic_adj_sequence: per-subject: return one adjacency per window.
Ignored when use_population_adj=True.
use_population_adj: compute a single population-level adj from training
set and use it for all subjects (recommended)
batch_size : samples per batch
val_ratio : fraction of data for validation
test_ratio : fraction of data for test
        split_strategy : "stratified" (label-stratified random split) or "site_holdout"
                         (hold out whole acquisition sites for val/test)
        val_site : validation site for site_holdout. If unset, the smallest site not
                   held out for test is used.
        test_site : test site(s) for site_holdout; may be comma-separated. If unset,
                    the second-largest site is used.
num_workers : DataLoader worker processes
overwrite_cache : re-preprocess even if .npz files exist
force_prepare : download/preprocess even when processed .npz files exist
"""
super().__init__()
self.data_dir = Path(data_dir)
self.raw_dir = self.data_dir / "raw"
self.processed_dir = self.data_dir / "processed"
self.n_subjects = n_subjects
self.window_len = window_len
self.step = step
self.max_windows = max_windows
self.fc_threshold = fc_threshold
self.use_dynamic_adj = use_dynamic_adj
self.use_dynamic_adj_sequence = use_dynamic_adj_sequence
self.use_population_adj = use_population_adj
self.preserve_fc_sign = preserve_fc_sign
self.use_fc_variance = use_fc_variance
self.use_fisher_z = use_fisher_z
self.use_fc_degree_features = use_fc_degree_features
self.use_fc_row_features = use_fc_row_features
self.n_pca_components = n_pca_components
self.batch_size = batch_size
self.val_ratio = val_ratio
self.test_ratio = test_ratio
self.split_strategy = split_strategy
self.val_site = val_site
self.test_site = test_site
self.num_workers = num_workers
self.overwrite_cache = overwrite_cache
self.force_prepare = force_prepare
self._population_adj: np.ndarray | None = None
self._site_fc_mean: dict[str, np.ndarray] = {}
self._site_to_int: dict[str, int] = {}
self._pca_mean: np.ndarray | None = None # (D,) mean FC vector
self._pca_components: np.ndarray | None = None # (K, D) principal axes
self._train_paths: list[Path] = []
self._val_paths: list[Path] = []
self._test_paths: list[Path] = []
# ------------------------------------------------------------------
# Lightning hooks
# ------------------------------------------------------------------
def prepare_data(self):
"""Download + preprocess (runs on rank 0 only in distributed settings)."""
cached_paths = list(self.processed_dir.glob("*.npz"))
n_cached = len(cached_paths)
# Skip only when we already have enough subjects and no explicit override
have_enough = (
self.n_subjects is None or n_cached >= self.n_subjects
)
if cached_paths and have_enough and not self.overwrite_cache and not self.force_prepare:
log.info(
"Found %d cached subject files in %s; skipping download/preprocess.",
n_cached,
self.processed_dir,
)
return
if n_cached > 0 and not self.overwrite_cache:
log.info(
"Have %d subjects, want %s — downloading remaining subjects.",
n_cached,
self.n_subjects or "all",
)
dataset = fetch_abide(
data_dir=self.raw_dir,
n_subjects=self.n_subjects,
)
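        # Assumption about preprocess.py's sliding-window scheme: a scan with T
        # timepoints yields roughly floor((T - window_len) / step) + 1 windows, so
        # requiring T >= window_len + step guarantees at least two windows per subject.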
subjects = extract_subjects(dataset, min_timepoints=self.window_len + self.step)
preprocess_all(
subjects,
processed_dir=self.processed_dir,
window_len=self.window_len,
step=self.step,
overwrite=self.overwrite_cache,
)
def setup(self, stage: str | None = None):
"""Build train/val/test splits and optionally the population adjacency."""
all_paths = sorted(self.processed_dir.glob("*.npz"))
if not all_paths:
raise RuntimeError(
f"No .npz files found in {self.processed_dir}. "
"Run prepare_data() first."
)
        # Read labels/sites for splitting (single pass so each .npz is opened once)
        labels_list, sites_list = [], []
        for p in all_paths:
            data = np.load(p, allow_pickle=True)
            labels_list.append(int(data["label"]))
            sites_list.append(str(data["site"]))
        labels = np.array(labels_list)
        sites = np.array(sites_list)
# Build site → int mapping from ALL subjects (consistent across splits)
self._site_to_int = {
site: i for i, site in enumerate(sorted(set(sites.tolist())))
}
log.info("Sites (%d): %s", len(self._site_to_int), sorted(self._site_to_int))
if self.split_strategy == "stratified":
train_paths, val_paths, test_paths = self._stratified_split(
all_paths, labels, self.val_ratio, self.test_ratio
)
elif self.split_strategy == "site_holdout":
train_paths, val_paths, test_paths = self._site_holdout_split(
all_paths, labels, sites, self.val_site, self.test_site
)
else:
raise ValueError(f"Unknown split_strategy: {self.split_strategy}")
self._train_paths = train_paths
self._val_paths = val_paths
self._test_paths = test_paths
log.info(
"Split (%s): train=%d val=%d test=%d",
self.split_strategy,
len(train_paths), len(val_paths), len(test_paths),
)
# Build population adjacency from training subjects only
if self.use_population_adj:
self._population_adj = self._build_population_adj(train_paths)
# Compute per-site mean FC from training set (FC-domain site normalization)
self._site_fc_mean = self._build_site_fc_mean(train_paths)
# PCA on training FC upper triangles (reduces p>>n overfitting)
if self.n_pca_components > 0:
self._pca_mean, self._pca_components = self._build_pca(train_paths)
def train_dataloader(self) -> DataLoader:
return DataLoader(
self._make_dataset(self._train_paths),
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
collate_fn=collate_fn,
pin_memory=torch.cuda.is_available(),
)
def val_dataloader(self) -> DataLoader:
return DataLoader(
self._make_dataset(self._val_paths),
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
collate_fn=collate_fn,
pin_memory=torch.cuda.is_available(),
)
def test_dataloader(self) -> DataLoader:
return DataLoader(
self._make_dataset(self._test_paths),
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
collate_fn=collate_fn,
pin_memory=torch.cuda.is_available(),
)
# ------------------------------------------------------------------
# Properties exposed to the model
# ------------------------------------------------------------------
@property
def num_nodes(self) -> int:
"""Number of ROIs (200 for cc200 atlas)."""
data = np.load(self._train_paths[0], allow_pickle=True)
return data["mean_fc"].shape[0]
@property
def num_windows(self) -> int:
"""Number of brain-state snapshots (sliding windows) per subject."""
if self.max_windows is not None:
return self.max_windows
data = np.load(self._train_paths[0], allow_pickle=True)
return data["bold_windows"].shape[0]
@property
def population_adj(self) -> np.ndarray | None:
return self._population_adj
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def _make_dataset(self, paths: list[Path]) -> ABIDEDataset:
return ABIDEDataset(
npz_paths=paths,
population_adj=self._population_adj,
use_dynamic_adj=self.use_dynamic_adj,
use_dynamic_adj_sequence=self.use_dynamic_adj_sequence,
fc_threshold=self.fc_threshold,
max_windows=self.max_windows,
site_fc_mean=self._site_fc_mean,
preserve_fc_sign=self.preserve_fc_sign,
site_to_int=self._site_to_int,
use_fc_variance=self.use_fc_variance,
use_fisher_z=self.use_fisher_z,
pca_mean=self._pca_mean,
pca_components=self._pca_components,
use_fc_degree_features=self.use_fc_degree_features,
use_fc_row_features=self.use_fc_row_features,
)
@property
def num_sites(self) -> int:
return len(self._site_to_int)
@staticmethod
def _stratified_split(
paths: list[Path],
labels: np.ndarray,
val_ratio: float,
test_ratio: float,
) -> tuple[list[Path], list[Path], list[Path]]:
paths = np.array(paths)
# First split off test set
sss_test = StratifiedShuffleSplit(n_splits=1, test_size=test_ratio, random_state=42)
train_val_idx, test_idx = next(sss_test.split(paths, labels))
# Then split val from train
val_size = val_ratio / (1.0 - test_ratio)
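        # e.g. val_ratio=0.1, test_ratio=0.1 -> val_size = 0.1 / 0.9 ~= 0.111, i.e.
        # ~11.1% of the remaining 90%, which is ~10% of all subjects for validation.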
sss_val = StratifiedShuffleSplit(n_splits=1, test_size=val_size, random_state=42)
train_idx, val_idx = next(sss_val.split(paths[train_val_idx], labels[train_val_idx]))
return (
list(paths[train_val_idx[train_idx]]),
list(paths[train_val_idx[val_idx]]),
list(paths[test_idx]),
)
@staticmethod
def _site_holdout_split(
paths: list[Path],
labels: np.ndarray,
sites: np.ndarray,
val_site: str | None,
test_site: str | None,
) -> tuple[list[Path], list[Path], list[Path]]:
paths_arr = np.array(paths)
site_counts = Counter(sites.tolist())
if len(site_counts) < 3:
raise ValueError("site_holdout split needs at least 3 sites.")
sorted_sites = [site for site, _ in site_counts.most_common()]
# test_site may be a comma-separated list of sites (e.g. "UCLA_1,UCLA_2")
test_sites = [s.strip() for s in test_site.split(",")] if test_site else [sorted_sites[1]]
if val_site is None:
val_site = next((s for s in reversed(sorted_sites) if s not in test_sites), None)
if val_site is None or val_site in test_sites:
raise ValueError("site_holdout split needs distinct val_site and test_site.")
for ts in test_sites:
if ts not in site_counts:
raise ValueError(f"Unknown test_site '{ts}'. Available: {sorted(site_counts)}")
if val_site not in site_counts:
raise ValueError(f"Unknown val_site '{val_site}'. Available: {sorted(site_counts)}")
train_mask = np.ones(len(sites), dtype=bool)
for ts in test_sites:
train_mask &= (sites != ts)
train_mask &= (sites != val_site)
val_mask = sites == val_site
test_mask = np.zeros(len(sites), dtype=bool)
for ts in test_sites:
test_mask |= (sites == ts)
ABIDEDataModule._assert_both_labels(labels[train_mask], "train")
ABIDEDataModule._assert_both_labels(labels[val_mask], "val")
ABIDEDataModule._assert_both_labels(labels[test_mask], "test")
return (
list(paths_arr[train_mask]),
list(paths_arr[val_mask]),
list(paths_arr[test_mask]),
)
@staticmethod
def _assert_both_labels(labels: np.ndarray, split_name: str) -> None:
unique = set(labels.tolist())
if unique != {0, 1}:
raise ValueError(
f"{split_name} split must contain both labels, got {sorted(unique)}."
)
def _build_pca(self, train_paths: list[Path]) -> tuple[np.ndarray, np.ndarray]:
"""Compute PCA on training-set FC upper triangles using truncated SVD.
Returns
-------
mean_vec : (D,) mean FC vector (for centering)
components : (K, D) top-K principal axes (rows = PCs)
        With D=19900 edge features and N≈660 training subjects, the centered data span
        at most an (N-1)-dimensional subspace anyway. Using K<<N avoids p>>n
        overfitting: the MLP trains on K features rather than 19900.
"""
K = self.n_pca_components
log.info("Computing PCA (K=%d) from %d training FC matrices ...", K, len(train_paths))
# Build training matrix: (N_train, D)
rows = []
for p in train_paths:
data = np.load(p, allow_pickle=True)
fc = data["mean_fc"].astype(np.float32)
n = fc.shape[0]
r, c = np.triu_indices(n, k=1)
if self.use_fisher_z:
fc = np.arctanh(np.clip(fc, -0.9999, 0.9999))
rows.append(fc[r, c])
X = np.stack(rows, axis=0) # (N_train, D)
mean_vec = X.mean(axis=0) # (D,)
X_centered = X - mean_vec # (N_train, D)
# Truncated SVD via economy SVD on the smaller dimension
# X = U S Vt → principal components = Vt[:K]
# Since N << D, use X @ Xt for the eigen-decomposition shortcut
# (N_train × N_train covariance, then recover Vt)
C = (X_centered @ X_centered.T) / (len(train_paths) - 1) # (N, N)
eigenvalues, U = np.linalg.eigh(C) # ascending
# eigh returns ascending; we want descending
idx = np.argsort(-eigenvalues)
U = U[:, idx[:K]] # (N, K)
components = (X_centered.T @ U) # (D, K)
# Normalise each column to unit length → rows of Vt
components /= np.linalg.norm(components, axis=0, keepdims=True) + 1e-8
components = components.T.astype(np.float32) # (K, D)
var_explained = eigenvalues[idx[:K]].sum() / (eigenvalues.sum() + 1e-8)
log.info("PCA: top-%d components explain %.1f%% of FC variance.", K, 100 * var_explained)
return mean_vec.astype(np.float32), components
def _build_site_fc_mean(self, train_paths: list[Path]) -> dict[str, np.ndarray]:
"""Compute per-site mean FC matrix (N, N) from training subjects.
Subtracting this at load time removes scanner-specific connectivity biases
(a simple FC-domain site normalization). BOLD is already z-scored so
BOLD-domain corrections have no effect."""
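        # Load-time usage sketch (the actual subtraction is assumed to live in
        # ABIDEDataset):
        #     mean_fc = mean_fc - site_fc_mean[subject_site]
        # Means come from training subjects only, so no val/test information leaks in.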
log.info("Computing per-site FC means from %d training subjects ...", len(train_paths))
site_sums: dict[str, np.ndarray] = {}
site_counts: dict[str, int] = {}
for p in train_paths:
data = np.load(p, allow_pickle=True)
site = str(data["site"])
fc = data["mean_fc"].astype(np.float32) # (N, N)
if site not in site_sums:
site_sums[site] = np.zeros_like(fc)
site_counts[site] = 0
site_sums[site] += fc
site_counts[site] += 1
return {s: site_sums[s] / site_counts[s] for s in site_sums}
def _build_population_adj(self, train_paths: list[Path]) -> np.ndarray:
log.info("Building population adjacency from %d training subjects ...", len(train_paths))
mean_fcs = []
for p in train_paths:
data = np.load(p, allow_pickle=True)
mean_fcs.append(data["mean_fc"].astype(np.float32))
adj = compute_population_adj(mean_fcs, threshold=self.fc_threshold)
log.info(
"Population adj: %d nodes, %.1f%% edges non-zero.",
adj.shape[0],
100.0 * (adj > 0).sum() / adj.size,
)
return adj
# ------------------------------------------------------------------
# argparse integration
# ------------------------------------------------------------------
@staticmethod
def add_data_specific_arguments(parent_parser: argparse.ArgumentParser):
parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument("--data_dir", type=str, default="data")
parser.add_argument("--n_subjects", type=int, default=None)
parser.add_argument("--window_len", type=int, default=50)
parser.add_argument("--step", type=int, default=5)
parser.add_argument("--max_windows", type=int, default=30)
parser.add_argument("--fc_threshold", type=float, default=0.2)
parser.add_argument("--use_dynamic_adj", action="store_true")
parser.add_argument("--use_dynamic_adj_sequence", action="store_true")
parser.add_argument("--use_population_adj", action=argparse.BooleanOptionalAction, default=True)
parser.add_argument("--preserve_fc_sign", action="store_true",
help="Keep signed FC values in adjacency (required for fc_mlp).")
parser.add_argument("--use_fc_variance", action="store_true",
help="Append std(fc_windows) as a second feature channel alongside mean FC.")
parser.add_argument("--use_fc_degree_features", action="store_true",
help="Replace BOLD std node features with per-ROI mean |FC| per window.")
parser.add_argument("--use_fc_row_features", action="store_true",
help="Use FC rows as node features (W,N,N). Requires graph_temporal + in_features=num_nodes.")
parser.add_argument("--use_fisher_z", action="store_true",
help="Apply Fisher r-to-z transform to FC values before classification.")
parser.add_argument("--n_pca_components", type=int, default=0,
help="If >0, reduce FC to this many PCA components before the MLP.")
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--val_ratio", type=float, default=0.1)
parser.add_argument("--test_ratio", type=float, default=0.1)
parser.add_argument("--split_strategy", choices=["stratified", "site_holdout"], default="stratified")
parser.add_argument("--val_site", type=str, default=None)
parser.add_argument("--test_site", type=str, default=None)
parser.add_argument("--num_workers", type=int, default=4)
parser.add_argument(
"--overwrite_cache",
action="store_true",
help="Force re-download and re-preprocess even if .npz files already exist.",
)
return parser
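

# Minimal smoke-test sketch (not part of the module's public API): it wires the argparse
# helper above to the DataModule and logs the resulting split sizes. It assumes network
# access for the ABIDE download; pass --n_subjects for a small local run.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    parser = ABIDEDataModule.add_data_specific_arguments(
        argparse.ArgumentParser(add_help=False)
    )
    args = parser.parse_args()
    dm = ABIDEDataModule(**vars(args))
    dm.prepare_data()
    dm.setup()
    log.info(
        "train=%d val=%d test=%d | nodes=%d windows=%d sites=%d",
        len(dm._train_paths), len(dm._val_paths), len(dm._test_paths),
        dm.num_nodes, dm.num_windows, dm.num_sites,
    )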