| """ |
| BC-MAE Fine-tuning Script. |
| |
| Two-phase fine-tuning of a pre-trained BC-MAE encoder for ASD/TD classification. |
| |
| Phase 1 — Linear probe (encoder frozen, ~50 epochs) |
| Warms up the classification head without distorting the encoder. |
| |
| Phase 2 — Full fine-tune (encoder + head, discriminative LRs, ~150 epochs) |
| Head : lr (full) |
| Encoder: lr × encoder_lr_scale (default 0.1) |
| |
| Data: use_fc_degree_features=True → (W=30, N=200) mean |FC| per window, |
| same feature as pre-training. Labels used only in fine-tuning loss. |
| |
| Usage: |
| python -m brain_gcn.finetune_main \\ |
| --mae_ckpt checkpoints/mae/mae-best-*.ckpt \\ |
| --data_dir data \\ |
| --probe_epochs 50 \\ |
| --finetune_epochs 150 \\ |
| --lr 5e-4 |
| """ |
|
|
from __future__ import annotations
|
|
import argparse
import copy
from pathlib import Path
|
|
import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torch import nn
from torchmetrics.classification import (
    BinaryAUROC,
    BinaryAccuracy,
    BinaryF1Score,
    BinaryRecall,
    BinarySpecificity,
)
|
|
from brain_gcn.models.mae import BrainFCClassifier, BrainFCEncoder
from brain_gcn.utils.data.datamodule import ABIDEDataModule
|
|
|
|
|
|
class MAEClassificationTask(pl.LightningModule):
    """Lightning task wrapping a BrainFCClassifier for binary ASD/TD classification."""

    def __init__(
        self,
        classifier: BrainFCClassifier,
        class_weights: torch.Tensor | None = None,
        lr: float = 5e-4,
        encoder_lr_scale: float = 0.1,
        weight_decay: float = 1e-4,
        bold_noise_std: float = 0.01,
        cosine_t0: int = 30,
        cosine_eta_min: float = 1e-6,
        freeze_encoder: bool = False,
    ):
        super().__init__()
        self.save_hyperparameters(ignore=["classifier", "class_weights"])
        self.model = classifier
        # Keep the weights in a buffer so they move with the module across devices.
        self.register_buffer("class_weights", class_weights)
        self.loss_fn = nn.CrossEntropyLoss(weight=class_weights)
|
|
        self.train_acc = BinaryAccuracy()
        self.val_acc = BinaryAccuracy()
        self.val_auc = BinaryAUROC()
        self.val_f1 = BinaryF1Score()
        self.val_sens = BinaryRecall()
        self.val_spec = BinarySpecificity()
        self.test_acc = BinaryAccuracy()
        self.test_auc = BinaryAUROC()
        self.test_f1 = BinaryF1Score()
        self.test_sens = BinaryRecall()
        self.test_spec = BinarySpecificity()
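        # The metric objects are passed straight to self.log() in the step
        # methods below, so Lightning drives their epoch-level compute/reset cycle.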
|
|
    def forward(self, x: torch.Tensor, adj: torch.Tensor | None = None) -> torch.Tensor:
        return self.model(x, adj)
|
|
    def training_step(self, batch, batch_idx: int) -> torch.Tensor:
        _, adj, labels, _ = batch
        x = adj

        # Light Gaussian augmentation scaled by each sample's feature std.
        # (Despite the name, bold_noise_std perturbs the FC-feature input x.)
        if self.hparams.bold_noise_std > 0.0:
            sig = x.std(dim=(1, 2), keepdim=True).detach()
            x = x + torch.randn_like(x) * self.hparams.bold_noise_std * sig
        logits = self(x)
        loss = self.loss_fn(logits, labels)
        preds = logits.argmax(-1)
        self.train_acc.update(preds, labels)
        self.log("train_loss", loss, prog_bar=True, on_epoch=True, on_step=False)
        self.log("train_acc", self.train_acc, prog_bar=True, on_epoch=True, on_step=False)
        return loss
|
|
    def validation_step(self, batch, batch_idx: int) -> torch.Tensor:
        _, adj, labels, _ = batch
        x = adj
        logits = self(x)
        loss = self.loss_fn(logits, labels)
        probs = torch.softmax(logits, -1)[:, 1]  # P(ASD), class 1, for AUROC
        preds = logits.argmax(-1)
        self.val_acc.update(preds, labels)
        self.val_auc.update(probs, labels)
        self.val_f1.update(preds, labels)
        self.val_sens.update(preds, labels)
        self.val_spec.update(preds, labels)
        self.log("val_loss", loss, prog_bar=True, on_epoch=True, on_step=False)
        self.log("val_acc", self.val_acc, prog_bar=True, on_epoch=True, on_step=False)
        self.log("val_auc", self.val_auc, prog_bar=True, on_epoch=True, on_step=False)
        self.log("val_f1", self.val_f1, prog_bar=False, on_epoch=True, on_step=False)
        self.log("val_sens", self.val_sens, prog_bar=False, on_epoch=True, on_step=False)
        self.log("val_spec", self.val_spec, prog_bar=False, on_epoch=True, on_step=False)
        return loss
|
|
    def test_step(self, batch, batch_idx: int) -> torch.Tensor:
        _, adj, labels, _ = batch
        x = adj
        logits = self(x)
        loss = self.loss_fn(logits, labels)
        probs = torch.softmax(logits, -1)[:, 1]
        preds = logits.argmax(-1)
        self.test_acc.update(preds, labels)
        self.test_auc.update(probs, labels)
        self.test_f1.update(preds, labels)
        self.test_sens.update(preds, labels)
        self.test_spec.update(preds, labels)
        self.log("test_loss", loss, on_epoch=True, on_step=False)
        self.log("test_acc", self.test_acc, prog_bar=True, on_epoch=True, on_step=False)
        self.log("test_auc", self.test_auc, prog_bar=True, on_epoch=True, on_step=False)
        self.log("test_f1", self.test_f1, prog_bar=True, on_epoch=True, on_step=False)
        self.log("test_sens", self.test_sens, prog_bar=True, on_epoch=True, on_step=False)
        self.log("test_spec", self.test_spec, prog_bar=True, on_epoch=True, on_step=False)
        return loss
|
|
    def configure_optimizers(self):
        # Split parameters into encoder vs. head groups so the two can use
        # discriminative learning rates.
        enc_ids = {id(p) for p in self.model.encoder.parameters()}
        enc_params = [p for p in self.model.parameters() if id(p) in enc_ids]
        head_params = [p for p in self.model.parameters() if id(p) not in enc_ids]

        if self.hparams.freeze_encoder:
            param_groups = [{"params": head_params, "lr": self.hparams.lr}]
        else:
            # With the defaults (lr=5e-4, encoder_lr_scale=0.1) the head trains
            # at 5e-4 while the encoder trains at 5e-5.
            param_groups = [
                {"params": head_params, "lr": self.hparams.lr},
                {"params": enc_params, "lr": self.hparams.lr * self.hparams.encoder_lr_scale},
            ]

        opt = torch.optim.AdamW(param_groups, weight_decay=self.hparams.weight_decay)
        sch = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            opt,
            T_0=self.hparams.cosine_t0,
            eta_min=self.hparams.cosine_eta_min,
        )
        return {"optimizer": opt, "lr_scheduler": {"scheduler": sch, "interval": "epoch"}}
|
|
|
|
|
|
def _compute_class_weights(dm: ABIDEDataModule) -> torch.Tensor:
    """Inverse-frequency class weights [w_td, w_asd] computed on the training split."""
    labels = np.array([int(np.load(p, allow_pickle=True)["label"]) for p in dm._train_paths])
    n_td = int((labels == 0).sum())
    n_asd = int((labels == 1).sum())
    total = n_td + n_asd
    return torch.tensor([total / (2.0 * n_td), total / (2.0 * n_asd)], dtype=torch.float32)
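# Worked example (illustrative counts): with 300 TD and 200 ASD training
# subjects the weights are [500 / 600, 500 / 400] = [0.833, 1.25], i.e. the
# minority class is upweighted so both classes contribute equally to the loss.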
|
|
|
|
def _load_encoder(
    ckpt_path: str,
    num_rois: int,
    num_windows: int,
    hidden_dim: int,
    num_heads: int,
    encoder_layers: int,
    dropout: float,
) -> BrainFCEncoder:
    """Extract BrainFCEncoder weights from a BrainMAETask Lightning checkpoint."""
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    state = ckpt["state_dict"]

    # Keep only the encoder weights and strip their "mae.encoder." prefix.
    enc_state = {
        k[len("mae.encoder."):]: v
        for k, v in state.items()
        if k.startswith("mae.encoder.")
    }
    if not enc_state:
        raise KeyError(
            f"No 'mae.encoder.*' keys found in {ckpt_path}. "
            "Make sure you pass a BrainMAETask checkpoint, not a classifier checkpoint."
        )

    encoder = BrainFCEncoder(
        num_rois=num_rois,
        num_windows=num_windows,
        hidden_dim=hidden_dim,
        num_heads=num_heads,
        num_layers=encoder_layers,
        dropout=dropout,
    )
    encoder.load_state_dict(enc_state, strict=True)
    print(f"Loaded encoder from {ckpt_path} ({len(enc_state)} tensors)")
    return encoder
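# If loading fails, inspecting the checkpoint's top-level key prefixes usually
# pinpoints the mismatch, e.g. (path is illustrative):
#   sd = torch.load("checkpoints/mae/mae-best.ckpt", map_location="cpu")["state_dict"]
#   print(sorted({k.split(".")[0] for k in sd}))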
|
|
|
|
def _load_head_weights(task: MAEClassificationTask, ckpt_path: str) -> None:
    """Restore time_attn + head weights from a previous phase checkpoint."""
    sd = torch.load(ckpt_path, map_location="cpu", weights_only=False)["state_dict"]
    mapping = {}
    for k, v in sd.items():
        if k.startswith("model.time_attn.") or k.startswith("model.head."):
            new_k = k[len("model."):]
            mapping[new_k] = v
    if mapping:
        # Update a full copy of the current state dict so the strict load
        # still sees every expected key.
        current = task.model.state_dict()
        current.update(mapping)
        task.model.load_state_dict(current, strict=True)
        print(f"Restored {len(mapping)} head tensors from {ckpt_path}")
|
|
|
|
def _make_trainer(
    max_epochs: int,
    ckpt_dir: Path,
    prefix: str,
    accelerator: str,
    devices: str,
    patience: int = 30,
) -> pl.Trainer:
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    return pl.Trainer(
        max_epochs=max_epochs,
        accelerator=accelerator,
        devices=devices,
        deterministic=True,
        log_every_n_steps=1,
        callbacks=[
            EarlyStopping(monitor="val_auc", mode="max", patience=patience),
            ModelCheckpoint(
                dirpath=str(ckpt_dir),
                monitor="val_auc",
                mode="max",
                save_top_k=3,
                filename=f"{prefix}-{{epoch:03d}}-{{val_auc:.3f}}",
            ),
        ],
    )
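# Both early stopping and checkpointing monitor val_auc, so model selection in
# every phase is driven by validation AUROC rather than validation loss.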
|
|
|
|
|
|
def build_parser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(description="BC-MAE Fine-tuning")
    p.add_argument("--mae_ckpt", type=str, required=True,
                   help="Path to best MAE pre-training checkpoint (.ckpt)")
    p.add_argument("--data_dir", type=str, default="data")
    p.add_argument("--max_windows", type=int, default=30)
    p.add_argument("--hidden_dim", type=int, default=128)
    p.add_argument("--num_heads", type=int, default=4)
    p.add_argument("--encoder_layers", type=int, default=4)
    p.add_argument("--dropout_encoder", type=float, default=0.1)
    p.add_argument("--dropout_head", type=float, default=0.5)

    p.add_argument("--probe_epochs", type=int, default=50,
                   help="Epochs with frozen encoder (linear probe).")
    p.add_argument("--probe_lr", type=float, default=1e-3)

    p.add_argument("--finetune_epochs", type=int, default=150,
                   help="Epochs with full encoder fine-tuning.")
    p.add_argument("--finetune_lr", type=float, default=5e-4)
    p.add_argument("--encoder_lr_scale", type=float, default=0.1,
                   help="Encoder LR = finetune_lr × this. Default 0.1 (10× smaller).")
    p.add_argument("--weight_decay", type=float, default=1e-4)
    p.add_argument("--bold_noise_std", type=float, default=0.01)
    p.add_argument("--cosine_t0", type=int, default=30)
    p.add_argument("--cosine_eta_min", type=float, default=1e-6)

    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--num_workers", type=int, default=4)
    p.add_argument("--split_strategy", choices=["stratified", "site_holdout"],
                   default="stratified")
    p.add_argument("--val_site", type=str, default=None)
    p.add_argument("--test_site", type=str, default=None)

    p.add_argument("--accelerator", type=str, default="auto")
    p.add_argument("--devices", type=str, default="auto")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--ckpt_dir", type=str, default="checkpoints/mae_finetune")
    p.add_argument("--test", action="store_true",
                   help="Run test set evaluation after fine-tuning.")
    p.add_argument("--skip_probe", action="store_true",
                   help="Skip Phase 1 and jump straight to full fine-tuning.")
    return p
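# Example of a leave-site-out evaluation (site names are placeholders; use the
# site identifiers present in your ABIDE preprocessing):
#   python -m brain_gcn.finetune_main --mae_ckpt checkpoints/mae/mae-best.ckpt \
#       --split_strategy site_holdout --val_site UCLA --test_site NYU --test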
|
|
|
|
def main() -> None:
    torch.set_float32_matmul_precision("medium")
    args = build_parser().parse_args()
    pl.seed_everything(args.seed, workers=True)
|
|
    dm = ABIDEDataModule(
        data_dir=args.data_dir,
        use_population_adj=False,
        preserve_fc_sign=True,
        fc_threshold=0.0,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        split_strategy=args.split_strategy,
        val_site=args.val_site,
        test_site=args.test_site,
    )
    dm.prepare_data()
    dm.setup()
|
|
    num_rois = dm.num_nodes
    class_weights = _compute_class_weights(dm)
    print(f"num_rois={num_rois} class_weights={class_weights.tolist()}")
|
|
    # The encoder geometry must match pre-training exactly (W windows, not N
    # ROIs), otherwise the strict state-dict load in _load_encoder fails.
    encoder = _load_encoder(
        ckpt_path=args.mae_ckpt,
        num_rois=num_rois,
        num_windows=args.max_windows,
        hidden_dim=args.hidden_dim,
        num_heads=args.num_heads,
        encoder_layers=args.encoder_layers,
        dropout=args.dropout_encoder,
    )
|
|
    ckpt_dir = Path(args.ckpt_dir)

    best_probe_ckpt: str | None = None
|
|
    if not args.skip_probe:
        print(f"\n{'='*60}")
        print(f"Phase 1: Linear probe ({args.probe_epochs} epochs, LR={args.probe_lr})")
        print(f"{'='*60}")

        classifier_p1 = BrainFCClassifier(
            encoder=encoder,
            hidden_dim=args.hidden_dim,
            num_classes=2,
            dropout=args.dropout_head,
            freeze_encoder=True,
        )
        task_p1 = MAEClassificationTask(
            classifier=classifier_p1,
            class_weights=class_weights,
            lr=args.probe_lr,
            encoder_lr_scale=0.0,
            weight_decay=args.weight_decay,
            bold_noise_std=0.0,
            cosine_t0=args.cosine_t0,
            cosine_eta_min=args.cosine_eta_min,
            freeze_encoder=True,
        )
        trainer_p1 = _make_trainer(
            max_epochs=args.probe_epochs,
            ckpt_dir=ckpt_dir / "probe",
            prefix="probe",
            accelerator=args.accelerator,
            devices=args.devices,
            patience=20,
        )
        trainer_p1.fit(task_p1, datamodule=dm)
        best_probe_ckpt = trainer_p1.checkpoint_callback.best_model_path
        # callback_metrics would only hold the last epoch's values; the
        # checkpoint callback tracks the best monitored score.
        best_probe_auc = trainer_p1.checkpoint_callback.best_model_score
        print(f"Phase 1 best val_auc: {float(best_probe_auc or 0.0):.4f}")
|
|
    print(f"\n{'='*60}")
    print(f"Phase 2: Full fine-tune ({args.finetune_epochs} epochs, "
          f"LR={args.finetune_lr}, enc_scale={args.encoder_lr_scale})")
    print(f"{'='*60}")
|
|
    # Give Phase 2 its own copy of the pre-trained encoder, and make sure no
    # requires_grad=False flags linger from the frozen probe phase.
    encoder_p2 = copy.deepcopy(encoder)
    encoder_p2.requires_grad_(True)
    classifier_p2 = BrainFCClassifier(
        encoder=encoder_p2,
        hidden_dim=args.hidden_dim,
        num_classes=2,
        dropout=args.dropout_head,
        freeze_encoder=False,
    )
    task_p2 = MAEClassificationTask(
        classifier=classifier_p2,
        class_weights=class_weights,
        lr=args.finetune_lr,
        encoder_lr_scale=args.encoder_lr_scale,
        weight_decay=args.weight_decay,
        bold_noise_std=args.bold_noise_std,
        cosine_t0=args.cosine_t0,
        cosine_eta_min=args.cosine_eta_min,
        freeze_encoder=False,
    )
|
|
    # Warm-start the classification head from the best probe checkpoint.
    if best_probe_ckpt:
        _load_head_weights(task_p2, best_probe_ckpt)
|
|
    trainer_p2 = _make_trainer(
        max_epochs=args.finetune_epochs,
        ckpt_dir=ckpt_dir / "finetune",
        prefix="ft",
        accelerator=args.accelerator,
        devices=args.devices,
        patience=40,
    )
    trainer_p2.fit(task_p2, datamodule=dm)
    best_ft_auc = trainer_p2.checkpoint_callback.best_model_score
    print(f"\nPhase 2 best val_auc: {float(best_ft_auc or 0.0):.4f}")
|
|
    if args.test:
        print("\nRunning test set evaluation ...")
        # ckpt_path="best" evaluates the top val_auc checkpoint from Phase 2.
        trainer_p2.test(task_p2, datamodule=dm, ckpt_path="best")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|