BrainConnect-ASD / brain_gcn /pretrain_main.py
Yatsuiii's picture
Upload folder using huggingface_hub
16d6869 verified
"""
BC-MAE Pre-training Script.
Self-supervised pre-training on ALL ABIDE subjects (no labels needed).
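Input per subject: (N, N) site-corrected mean FC matrix (N=200 ROIs)
- Loaded from the "mean_fc" array of each <data_dir>/processed/*.npz file,
  with the per-site mean FC matrix subtracted
- Row i is ROI i's FC profile and is used as one token; BrainMAE's window
  axis is reused for ROIs (num_windows = num_rois)
- NOTE: the previously documented (W=30, N=200) input (mean |FC| per ROI per
  window from fc_windows.npz, the same feature as --use_fc_degree_features in
  the classification pipeline) is not what this script loads
Task: BrainMAE masks 50% of the ROI tokens, reconstructs them from visible ones.
Loss: MSE on masked tokens only.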
Saves: checkpoints/mae/mae-best-*.ckpt (full BrainMAETask checkpoint)
Usage:
python -m brain_gcn.pretrain_main \\
--data_dir data \\
--max_epochs 200 \\
--hidden_dim 128 \\
--lr 1e-3
Then fine-tune with:
python -m brain_gcn.finetune_main \\
--mae_ckpt checkpoints/mae/mae-best-*.ckpt \\
--data_dir data
"""
from __future__ import annotations
import argparse
from pathlib import Path
import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torch.utils.data import DataLoader, Dataset
from brain_gcn.models.mae import BrainMAE
# ---------------------------------------------------------------------------
# Dataset
# ---------------------------------------------------------------------------
class MAEDataset(Dataset):
    """Per-subject FC matrices for spatial BC-MAE pre-training.

    Every ``*.npz`` file found in ``npz_dir`` is one sample.  A sample is the
    file's ``mean_fc`` matrix cast to float32, with the mean matrix of the
    subject's site subtracted when one is available.  Row ``i`` of the
    returned tensor is ROI ``i``'s connectivity profile and serves as one
    token for the MAE, which masks a fraction of the ROIs and reconstructs
    their FC rows.

    Args:
        npz_dir: Directory holding the per-subject ``.npz`` files.  Each file
            must provide a ``site`` entry and a ``mean_fc`` entry.
        site_fc_mean: Optional mapping from site name to the matrix that is
            subtracted from that site's subjects.  ``None`` (or an empty
            mapping) disables the correction.

    Raises:
        FileNotFoundError: If ``npz_dir`` contains no ``.npz`` files.
    """

    def __init__(
        self,
        npz_dir: str | Path,
        site_fc_mean: dict[str, np.ndarray] | None = None,
    ):
        # Sorted so that sample indices are stable across runs and machines.
        self.paths = sorted(Path(npz_dir).glob("*.npz"))
        if not self.paths:
            raise FileNotFoundError(f"No .npz files found in {npz_dir}")
        self.site_fc_mean = site_fc_mean or {}

    def __len__(self) -> int:
        """Return the number of subject files."""
        return len(self.paths)

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Load subject ``idx`` and return its (site-corrected) FC matrix."""
        # np.load on an .npz returns a lazy archive handle; the context manager
        # closes it as soon as the arrays are read instead of leaving the open
        # file to the garbage collector in every DataLoader worker.
        # NOTE(review): allow_pickle=True can execute code embedded in the
        # archive -- only point this dataset at trusted, locally produced files.
        with np.load(self.paths[idx], allow_pickle=True) as data:
            site = str(data["site"])
            fc = data["mean_fc"].astype(np.float32)  # (N, N)

        site_mean = self.site_fc_mean.get(site)
        if site_mean is not None:
            fc = fc - site_mean

        # (N, N) float32 tensor -- row i is ROI i's FC profile.
        return torch.FloatTensor(fc)
def _compute_site_fc_mean(npz_dir: str | Path) -> dict[str, np.ndarray]:
    """Return the mean ``mean_fc`` matrix of every acquisition site.

    All ``*.npz`` files in ``npz_dir`` are used; no train/validation split is
    applied here because pre-training is fully self-supervised.

    Args:
        npz_dir: Directory holding the per-subject ``.npz`` files.  Each file
            must provide a ``site`` entry and a ``mean_fc`` entry.  Both
            ``str`` and ``Path`` are accepted, matching ``MAEDataset``.

    Returns:
        Mapping from site name to the float32 element-wise mean of that
        site's ``mean_fc`` matrices.  Empty when no files are found.
    """
    site_sums: dict[str, np.ndarray] = {}
    site_counts: dict[str, int] = {}

    # Sorted so the floating-point accumulation order is reproducible.
    for path in sorted(Path(npz_dir).glob("*.npz")):
        # Close each archive right after reading it rather than keeping one
        # handle per subject alive until garbage collection.
        # NOTE(review): allow_pickle=True can execute code embedded in the
        # archive -- only use with trusted, locally produced files.
        with np.load(path, allow_pickle=True) as data:
            site = str(data["site"])
            fc = data["mean_fc"].astype(np.float32)

        if site not in site_sums:
            site_sums[site] = np.zeros_like(fc)
            site_counts[site] = 0
        site_sums[site] += fc
        site_counts[site] += 1

    return {site: total / site_counts[site] for site, total in site_sums.items()}
# ---------------------------------------------------------------------------
# Lightning module
# ---------------------------------------------------------------------------
class BrainMAETask(pl.LightningModule):
    """Lightning module that pre-trains ``BrainMAE`` on its own loss.

    The wrapped model is called on a batch and returns ``(loss, extra)``;
    only the loss is used here, for optimisation and for logging as
    ``train_loss`` / ``val_loss`` (epoch-level).

    Args:
        num_rois, num_windows, hidden_dim, decoder_dim, num_heads,
        encoder_layers, decoder_layers, dropout, mask_ratio: Forwarded
            unchanged to ``BrainMAE``; see that class for their meaning.
        lr: Peak learning rate of the AdamW optimiser.
        weight_decay: AdamW weight decay.
        warmup_epochs: Number of epochs of linear learning-rate warm-up.
        max_epochs: Epoch at which the cosine decay reaches zero.

    All constructor arguments are stored via ``save_hyperparameters`` and are
    therefore available as ``self.hparams``.
    """

    def __init__(
        self,
        num_rois: int = 200,
        num_windows: int = 30,
        hidden_dim: int = 128,
        decoder_dim: int = 64,
        num_heads: int = 4,
        encoder_layers: int = 4,
        decoder_layers: int = 2,
        dropout: float = 0.1,
        mask_ratio: float = 0.5,
        lr: float = 1e-3,
        weight_decay: float = 1e-4,
        warmup_epochs: int = 10,
        max_epochs: int = 200,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.mae = BrainMAE(
            num_rois=num_rois,
            num_windows=num_windows,
            hidden_dim=hidden_dim,
            decoder_dim=decoder_dim,
            num_heads=num_heads,
            encoder_layers=encoder_layers,
            decoder_layers=decoder_layers,
            dropout=dropout,
            mask_ratio=mask_ratio,
        )

    def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
        """Compute the MAE loss for one training batch and log it."""
        loss, _ = self.mae(batch)
        self.log("train_loss", loss, prog_bar=True, on_epoch=True, on_step=False)
        return loss

    def validation_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
        """Compute the MAE loss for one validation batch and log it."""
        loss, _ = self.mae(batch)
        self.log("val_loss", loss, prog_bar=True, on_epoch=True, on_step=False)
        return loss

    def configure_optimizers(self):
        """Build AdamW with linear warm-up followed by cosine decay.

        The scheduler is stepped once per epoch.  The multiplier applied to
        ``lr`` rises linearly to 1.0 over ``warmup_epochs`` and then follows
        half a cosine down to 0.0 at ``max_epochs``.
        """
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
        warmup = self.hparams.warmup_epochs
        total = self.hparams.max_epochs

        def _lr_lambda(epoch: int) -> float:
            if epoch < warmup:
                # (epoch + 1): LambdaLR evaluates this at epoch 0 before any
                # step, so a plain epoch / warmup would run the entire first
                # epoch with a learning rate of exactly zero.
                return (epoch + 1) / max(1, warmup)
            progress = (epoch - warmup) / max(1, total - warmup)
            # Clamp so the cosine cannot swing back up if training continues
            # past max_epochs (e.g. a resumed run with a longer Trainer).
            progress = min(1.0, max(0.0, progress))
            return float(0.5 * (1.0 + np.cos(np.pi * progress)))

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, _lr_lambda)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "epoch"},
        }
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for BC-MAE pre-training.

    Every option is a ``--name value`` pair with a typed default; none is
    required.  The table below lists them in the order they are registered.
    """
    # (option name, value type, default)
    option_table = (
        # data
        ("data_dir", str, "data"),
        # accepted on the command line, though main() in this file never reads it
        ("max_windows", int, 30),
        # model / schedule
        ("max_epochs", int, 200),
        ("hidden_dim", int, 128),
        ("decoder_dim", int, 64),
        ("num_heads", int, 4),
        ("encoder_layers", int, 4),
        ("decoder_layers", int, 2),
        ("dropout", float, 0.1),
        ("mask_ratio", float, 0.5),
        # optimisation
        ("lr", float, 1e-3),
        ("weight_decay", float, 1e-4),
        ("warmup_epochs", int, 10),
        # data loading / split
        ("batch_size", int, 32),
        ("num_workers", int, 4),
        ("val_ratio", float, 0.1),
        # runtime
        ("accelerator", str, "auto"),
        ("devices", str, "auto"),
        ("seed", int, 42),
        ("ckpt_dir", str, "checkpoints/mae"),
    )

    parser = argparse.ArgumentParser(description="BC-MAE Pre-training")
    for option, value_type, default_value in option_table:
        parser.add_argument(f"--{option}", type=value_type, default=default_value)
    return parser
def main() -> None:
    """Run BC-MAE pre-training from the command line.

    Steps: parse arguments and seed everything, compute per-site mean FC
    matrices from ``<data_dir>/processed``, build the dataset and a seeded
    random train/validation split, train ``BrainMAETask`` with early stopping
    on ``val_loss``, and print the path of the best checkpoint together with
    the fine-tuning command to run next.

    Raises:
        FileNotFoundError: If no ``.npz`` files are found (from ``MAEDataset``).
        ValueError: If the split would leave no subject for training.
    """
    torch.set_float32_matmul_precision("medium")
    args = build_parser().parse_args()
    pl.seed_everything(args.seed, workers=True)

    processed_dir = Path(args.data_dir) / "processed"
    print(f"Computing site FC means from {processed_dir} ...")
    site_fc_mean = _compute_site_fc_mean(processed_dir)
    print(f" {len(site_fc_mean)} sites found.")

    full_ds = MAEDataset(processed_dir, site_fc_mean=site_fc_mean)
    n = len(full_ds)
    # Always hold out at least one subject so val_loss can be monitored.
    n_val = max(1, int(n * args.val_ratio))
    n_train = n - n_val
    if n_train < 1:
        # Fail here with a clear message instead of an opaque sampler/split
        # error further down.
        raise ValueError(
            f"No subjects left for training: {n} subject(s) found and "
            f"val_ratio={args.val_ratio} reserves {n_val} for validation."
        )
    rng = torch.Generator().manual_seed(args.seed)
    train_ds, val_ds = torch.utils.data.random_split(
        full_ds, [n_train, n_val], generator=rng
    )
    print(f"Pre-training split: {n_train} train / {n_val} val ({n} total)")

    pin = torch.cuda.is_available()
    train_dl = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=pin,
    )
    val_dl = DataLoader(
        val_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=pin,
    )

    # Read the ROI count from the first subject; close the archive right away.
    with np.load(full_ds.paths[0], allow_pickle=True) as first:
        num_rois = int(first["mean_fc"].shape[0])
    # Spatial MAE: each of the N ROIs is a "window", its FC row (N-dim) is the patch feature
    num_windows = num_rois
    print(f"Spatial BC-MAE: {num_rois} ROIs × {num_rois}-dim FC rows")

    task = BrainMAETask(
        num_rois=num_rois,
        num_windows=num_windows,  # = num_rois -- spatial MAE
        hidden_dim=args.hidden_dim,
        decoder_dim=args.decoder_dim,
        num_heads=args.num_heads,
        encoder_layers=args.encoder_layers,
        decoder_layers=args.decoder_layers,
        dropout=args.dropout,
        mask_ratio=args.mask_ratio,
        lr=args.lr,
        weight_decay=args.weight_decay,
        warmup_epochs=args.warmup_epochs,
        max_epochs=args.max_epochs,
    )

    ckpt_dir = Path(args.ckpt_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        accelerator=args.accelerator,
        devices=args.devices,
        deterministic=True,
        log_every_n_steps=1,
        callbacks=[
            EarlyStopping(monitor="val_loss", mode="min", patience=30),
            # Keep only the single best checkpoint by validation loss.
            ModelCheckpoint(
                dirpath=str(ckpt_dir),
                monitor="val_loss",
                mode="min",
                save_top_k=1,
                filename="mae-best-{epoch:03d}-{val_loss:.4f}",
            ),
        ],
    )
    trainer.fit(task, train_dl, val_dl)

    best = trainer.checkpoint_callback.best_model_path
    print("\nPre-training complete.")
    print(f"Best checkpoint: {best}")
    print("\nNext step:")
    print(f" python -m brain_gcn.finetune_main --mae_ckpt {best} --data_dir {args.data_dir}")


# Start pre-training only when this module is executed as a script.
if __name__ == "__main__":
    main()