| """ |
| Training entry point for Brain-Connectivity-GCN. |
| |
| v2 changes: |
| - site_holdout as default split_strategy |
| - Class weights computed from training labels → weighted CE loss |
| - save_top_k=5 for checkpoint ensembling |
| - ensemble_predict() utility after training |
| - batch_size default lowered to 16 (site holdout = smaller train sets) |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| from pathlib import Path |
|
|
| import numpy as np |
| import pytorch_lightning as pl |
| import torch |
| from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint |
| from torchmetrics.classification import BinaryAUROC |
|
|
| from brain_gcn.models.brain_gcn import BrainModeNetwork |
| from brain_gcn.tasks import ClassificationTask |
| from brain_gcn.utils.data.datamodule import ABIDEDataModule |
|
|
|
|
| |
| |
| |
|
|
def build_parser() -> argparse.ArgumentParser:
    """Assemble the command-line parser for a training run.

    The data module and the classification task register their own options
    first; the run-level options (trainer settings, seeding, checkpoint tag,
    test/ensemble switches) are then added on top.

    Returns:
        The fully populated ``argparse.ArgumentParser``.
    """
    cli = argparse.ArgumentParser(description="Train Brain-Connectivity-GCN classifier")

    # Each component contributes its flags and hands the parser back.
    for register in (
        ABIDEDataModule.add_data_specific_arguments,
        ClassificationTask.add_model_specific_arguments,
    ):
        cli = register(cli)

    # Plain typed options with a default and no help text.
    for flag, kind, default in (
        ("--max_epochs", int, 200),
        ("--accelerator", str, "auto"),
        ("--devices", str, "auto"),
        ("--seed", int, 42),
    ):
        cli.add_argument(flag, type=kind, default=default)

    cli.add_argument(
        "--ckpt_tag",
        type=str,
        default="",
        help="Optional suffix appended to checkpoint directory name (e.g. seed-specific).",
    )
    cli.add_argument("--log_every_n_steps", type=int, default=1)

    # Boolean switches, all off unless given on the command line.
    cli.add_argument("--prepare_data", action="store_true")
    cli.add_argument("--test", action="store_true")
    cli.add_argument(
        "--no_ensemble",
        action="store_true",
        help="Skip top-5 checkpoint ensembling at test time.",
    )
    return cli
|
|
|
|
| |
| |
| |
|
|
def validate_args(args: argparse.Namespace) -> None:
    """Reject option combinations that conflict with a population adjacency.

    Args:
        args: Parsed command-line namespace; reads ``model_name``,
            ``use_population_adj`` and ``use_dynamic_adj_sequence``.

    Raises:
        ValueError: If ``use_population_adj`` is set together with a model
            from the per-subject list below, or together with
            ``use_dynamic_adj_sequence``.
    """
    per_subject_models = (
        "fc_mlp",
        "adv_fc_mlp",
        "brain_mode",
        "adv_brain_mode",
        "dynamic_fc_attn",
    )

    wants_subject_fc = args.model_name in per_subject_models
    if wants_subject_fc and args.use_population_adj:
        message = (
            f"{args.model_name} needs per-subject connectivity. "
            "Re-run with --no-use_population_adj."
        )
        raise ValueError(message)

    if args.use_dynamic_adj_sequence and args.use_population_adj:
        message = (
            "Dynamic adjacency sequences are per-subject. "
            "Re-run with --no-use_population_adj."
        )
        raise ValueError(message)
|
|
|
|
| |
| |
| |
|
|
def build_datamodule(args: argparse.Namespace) -> ABIDEDataModule:
    """Create the ``ABIDEDataModule`` described by the parsed arguments.

    ``preserve_fc_sign`` is switched on for the models listed in
    ``signed_fc_models`` even when the flag was not given. Options that may be
    missing from the namespace fall back to the defaults in
    ``optional_defaults``.

    Args:
        args: Parsed command-line namespace.

    Returns:
        The configured (not yet prepared or set up) data module.
    """
    signed_fc_models = ("fc_mlp", "adv_fc_mlp", "brain_mode", "adv_brain_mode")
    keep_sign = getattr(args, "preserve_fc_sign", False)
    if args.model_name in signed_fc_models and not keep_sign:
        keep_sign = True

    # Options forwarded verbatim; these must exist on the namespace.
    required_names = (
        "data_dir",
        "n_subjects",
        "window_len",
        "step",
        "max_windows",
        "fc_threshold",
        "use_dynamic_adj",
        "use_dynamic_adj_sequence",
        "use_population_adj",
        "batch_size",
        "val_ratio",
        "test_ratio",
        "split_strategy",
        "val_site",
        "test_site",
        "num_workers",
    )
    # Options that may be absent, with the value used in that case.
    optional_defaults = {
        "use_fc_variance": False,
        "use_fisher_z": False,
        "use_fc_degree_features": False,
        "use_fc_row_features": False,
        "n_pca_components": 0,
        "overwrite_cache": False,
    }

    dm_kwargs = {name: getattr(args, name) for name in required_names}
    for name, fallback in optional_defaults.items():
        dm_kwargs[name] = getattr(args, name, fallback)
    dm_kwargs["preserve_fc_sign"] = keep_sign
    # The CLI flag is --prepare_data; the data module calls it force_prepare.
    dm_kwargs["force_prepare"] = args.prepare_data

    return ABIDEDataModule(**dm_kwargs)
|
|
|
|
def _compute_class_weights(dm: ABIDEDataModule) -> torch.Tensor:
    """Return balanced class weights computed from the training labels.

    Each weight is ``total / (2 * n_class)``, where ``total`` counts only the
    training samples labelled 0 or 1, so the rarer class gets the larger
    weight.

    Args:
        dm: Data module whose ``_train_paths`` lists the cached ``.npz`` files
            of the training split. Each file must contain a scalar ``label``
            entry.

    Returns:
        A float32 tensor ``[w_label0, w_label1]`` (label 0 = TD, label 1 = ASD,
        following the naming used in this file).

    Raises:
        ValueError: If the training split has no sample for one of the two
            labels; the weight would otherwise need a division by zero.
    """
    labels = []
    for path in dm._train_paths:
        # NOTE(review): allow_pickle=True lets numpy unpickle object arrays,
        # which executes arbitrary code for a malicious file. Kept for
        # compatibility with the existing cache; only load trusted files.
        # The context manager closes the NpzFile handle after each read.
        with np.load(path, allow_pickle=True) as data:
            labels.append(int(data["label"]))

    labels_arr = np.asarray(labels, dtype=np.int64)
    n_td = int((labels_arr == 0).sum())
    n_asd = int((labels_arr == 1).sum())
    if n_td == 0 or n_asd == 0:
        raise ValueError(
            "Cannot compute balanced class weights: training split has "
            f"{n_td} samples with label 0 and {n_asd} with label 1."
        )

    total = n_td + n_asd
    w_td = total / (2.0 * n_td)
    w_asd = total / (2.0 * n_asd)
    return torch.tensor([w_td, w_asd], dtype=torch.float32)
|
|
|
|
def _discriminative_mode_init(dm: ABIDEDataModule, num_modes: int) -> torch.Tensor:
    """Build a data-driven initial mode matrix from the training split.

    Reads every cached ``.npz`` file in ``dm._train_paths``, groups the
    ``mean_fc`` matrices by label (1 = ASD, any other label = TD) and passes
    both stacks to ``BrainModeNetwork.discriminative_init``. ``build_task``
    calls this for the ``brain_mode`` and ``adv_brain_mode`` models.

    Per the original author's note, ``discriminative_init`` derives the modes
    from (mean_FC_ASD - mean_FC_TD) via SVD and returns a (K, N) matrix. Its
    implementation is not in this file, so treat that as unverified here.

    Args:
        dm: Data module exposing ``_train_paths``.
        num_modes: Number of modes requested from ``discriminative_init``.

    Returns:
        Whatever ``BrainModeNetwork.discriminative_init`` returns for the two
        per-class FC stacks.

    Raises:
        ValueError: If the training split contains no ASD or no TD subject.
    """
    fc_asd, fc_td = [], []
    for path in dm._train_paths:
        # NOTE(review): allow_pickle=True is unsafe for untrusted files; kept
        # for compatibility with the existing cache. The context manager
        # closes the NpzFile handle after each subject is read.
        with np.load(path, allow_pickle=True) as data:
            fc = data["mean_fc"].astype(np.float32)
            lbl = int(data["label"])
        (fc_asd if lbl == 1 else fc_td).append(fc)

    # np.stack on an empty list fails with a generic message; report which
    # class is missing instead.
    if not fc_asd or not fc_td:
        raise ValueError(
            "Cannot build discriminative mode init: training split has "
            f"{len(fc_asd)} ASD and {len(fc_td)} TD subjects; both classes are required."
        )

    fc_asd_arr = np.stack(fc_asd)
    fc_td_arr = np.stack(fc_td)
    return BrainModeNetwork.discriminative_init(fc_asd_arr, fc_td_arr, num_modes)
|
|
|
|
def build_task(args: argparse.Namespace, dm: ABIDEDataModule) -> ClassificationTask:
    """Create the ``ClassificationTask`` for this run.

    Class weights are computed from the training split. If that fails for any
    reason, a warning is printed and ``None`` is passed. For the ``brain_mode``
    and ``adv_brain_mode`` models a discriminative mode initialisation is also
    attempted; on failure a message is printed and ``mode_init`` stays
    ``None``.

    Args:
        args: Parsed command-line namespace.
        dm: Data module that has already been set up (its ``num_sites``,
            ``num_nodes`` and training paths are read).

    Returns:
        The configured task.
    """
    # Best-effort: training must still start when the weights are unavailable.
    weights = None
    try:
        weights = _compute_class_weights(dm)
    except Exception as exc:
        print(f"WARNING: Could not compute class weights ({exc}). Using uniform weights.")

    n_modes = getattr(args, "num_modes", 16)
    initial_modes = None
    if args.model_name in ("brain_mode", "adv_brain_mode"):
        try:
            initial_modes = _discriminative_mode_init(dm, n_modes)
        except Exception as exc:
            print(f"[BMN] discriminative init failed ({exc}), using QR init.")

    # Row features feed one value per node; otherwise a single scalar feature.
    if getattr(args, "use_fc_row_features", False):
        feature_dim = dm.num_nodes
    else:
        feature_dim = 1

    # Hyper-parameters forwarded verbatim from the command line.
    forwarded = (
        "hidden_dim",
        "dropout",
        "readout",
        "model_name",
        "lr",
        "weight_decay",
        "bold_noise_std",
        "drop_edge_p",
        "cosine_t0",
        "cosine_t_mult",
        "cosine_eta_min",
    )
    task_kwargs = {name: getattr(args, name) for name in forwarded}
    task_kwargs.update(
        class_weights=weights,
        num_sites=dm.num_sites,
        adv_site_weight=getattr(args, "adv_site_weight", 1.0),
        num_nodes=dm.num_nodes,
        num_modes=n_modes,
        orth_weight=getattr(args, "orth_weight", 0.01),
        mode_init=initial_modes,
        in_features=feature_dim,
    )
    return ClassificationTask(**task_kwargs)
|
|
|
|
def build_trainer(args: argparse.Namespace) -> tuple[pl.Trainer, Path]:
    """Create the Lightning trainer and its checkpoint directory.

    The directory is ``checkpoints/<name>``, where ``<name>`` starts with the
    model name and gains optional suffixes for PCA components, number of modes
    plus split strategy (brain-mode models only), and the user's
    ``--ckpt_tag``. A small ``run_config.json`` describing the adjacency
    settings is written into it.

    Args:
        args: Parsed command-line namespace.

    Returns:
        ``(trainer, ckpt_dir)``.
    """
    name_parts = [args.model_name]

    n_pca = getattr(args, "n_pca_components", 0)
    if n_pca > 0:
        name_parts.append(f"pca{n_pca}")

    if args.model_name in ("brain_mode", "adv_brain_mode"):
        # Only the first four characters of the split strategy are used.
        split_tag = getattr(args, "split_strategy", "site_holdout")[:4]
        name_parts.append(f"k{getattr(args, 'num_modes', 16)}_{split_tag}")

    user_tag = getattr(args, "ckpt_tag", "")
    if user_tag:
        name_parts.append(user_tag)

    ckpt_dir = Path("checkpoints") / "_".join(name_parts)
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    # Record the adjacency settings next to the checkpoints so that they can
    # be compared with the data module when the checkpoints are reloaded.
    run_config = {
        "model_name": args.model_name,
        "use_dynamic_adj_sequence": args.use_dynamic_adj_sequence,
        "use_population_adj": args.use_population_adj,
    }
    with open(ckpt_dir / "run_config.json", "w") as handle:
        json.dump(run_config, handle, indent=2)

    # Both callbacks track validation AUC (higher is better).
    early_stop = EarlyStopping(monitor="val_auc", mode="max", patience=40)
    checkpointing = ModelCheckpoint(
        dirpath=str(ckpt_dir),
        monitor="val_auc",
        mode="max",
        save_top_k=5,
        filename="brain-gcn-{epoch:03d}-{val_auc:.3f}",
    )

    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        accelerator=args.accelerator,
        devices=args.devices,
        deterministic=True,
        log_every_n_steps=args.log_every_n_steps,
        callbacks=[early_stop, checkpointing],
    )
    return trainer, ckpt_dir
|
|
|
|
| |
| |
| |
|
|
def ensemble_predict(
    ckpt_dir: str | Path,
    dm: ABIDEDataModule,
    device: str = "cpu",
) -> torch.Tensor:
    """Average softmax probabilities over the checkpoints saved in ``ckpt_dir``.

    Every ``*.ckpt`` file in the directory takes part, in sorted filename
    order. ``build_trainer`` keeps the five best checkpoints per run, but files
    left over from an earlier run in the same directory are included as well.

    If ``run_config.json`` exists in the directory, its adjacency settings are
    compared with the data module's to prevent silently evaluating checkpoints
    against differently prepared inputs.

    Args:
        ckpt_dir: Directory containing the ``.ckpt`` files.
        dm: Data module providing ``test_dataloader()``; each batch is expected
            to carry the BOLD windows at index 0 and the adjacency at index 1.
        device: Device the models and batches are moved to.

    Returns:
        probs : (N_test, num_classes) averaged probability tensor (on CPU).

    Raises:
        FileNotFoundError: If the directory contains no ``.ckpt`` file.
        ValueError: If the saved run configuration disagrees with the data
            module (previously an ``assert``, which ``python -O`` removes).
    """
    ckpt_dir = Path(ckpt_dir)
    ckpt_paths = sorted(ckpt_dir.glob("*.ckpt"))
    if not ckpt_paths:
        raise FileNotFoundError(f"No checkpoints found in {ckpt_dir}")

    config_path = ckpt_dir / "run_config.json"
    if config_path.exists():
        with open(config_path, encoding="utf-8") as f:
            saved_config = json.load(f)
        # Explicit raises rather than asserts, so the check survives -O.
        for key in ("use_dynamic_adj_sequence", "use_population_adj"):
            saved_value = saved_config[key]
            current_value = getattr(dm, key)
            if saved_value != current_value:
                raise ValueError(
                    f"Checkpoint {key}={saved_value} "
                    f"but datamodule has {current_value}"
                )

    all_probs: list[torch.Tensor] = []
    for ckpt in ckpt_paths:
        task = ClassificationTask.load_from_checkpoint(ckpt, map_location=device, strict=False)
        task.eval().to(device)
        batch_probs: list[torch.Tensor] = []
        with torch.no_grad():
            for batch in dm.test_dataloader():
                bold_windows, adj = batch[0], batch[1]
                logits = task(bold_windows.to(device), adj.to(device))
                # Collect on CPU so results from any device can be stacked.
                batch_probs.append(torch.softmax(logits, dim=-1).cpu())
        all_probs.append(torch.cat(batch_probs, dim=0))

    # Stack to (n_checkpoints, N_test, num_classes), then average over checkpoints.
    return torch.stack(all_probs).mean(0)
|
|
|
|
| |
| |
| |
|
|
def train_from_args(
    args: argparse.Namespace,
) -> tuple[pl.Trainer, ClassificationTask, ABIDEDataModule]:
    """Run a full training (and optionally test) cycle from parsed arguments.

    Steps: seed everything, validate the arguments, prepare and set up the
    data module, build the task and trainer, then fit. With ``--test`` the
    held-out split is evaluated afterwards: by default through the checkpoint
    ensemble (falling back to the single best checkpoint if that fails), or
    directly on the best checkpoint when ``--no_ensemble`` is given.

    Args:
        args: Parsed command-line namespace.

    Returns:
        ``(trainer, task, datamodule)`` after training.
    """
    pl.seed_everything(args.seed, workers=True)
    validate_args(args)

    dm = build_datamodule(args)
    # Prepared and set up explicitly because build_task reads the training
    # split before the trainer gets a chance to do it.
    dm.prepare_data()
    dm.setup()

    task = build_task(args, dm)
    trainer, ckpt_dir = build_trainer(args)
    trainer.fit(task, datamodule=dm)

    if not args.test:
        return trainer, task, dm

    if getattr(args, "no_ensemble", False):
        trainer.test(task, datamodule=dm, ckpt_path="best")
        return trainer, task, dm

    try:
        ensemble_probs = ensemble_predict(ckpt_dir, dm)
        predicted = ensemble_probs.argmax(dim=-1)

        # Labels sit at index 2 of each test batch.
        targets = torch.cat([batch[2] for batch in dm.test_dataloader()])
        acc = (predicted == targets).float().mean().item()
        auc = BinaryAUROC()(ensemble_probs[:, 1], targets).item()
        print(f"\n[Ensemble] test_acc={acc:.4f} test_auc={auc:.4f}")

        # Store the ensemble scores in the trainer's callback metrics.
        trainer.callback_metrics["test_acc_ensemble"] = torch.tensor(acc)
        trainer.callback_metrics["test_auc_ensemble"] = torch.tensor(auc)
    except Exception as exc:
        print(f"[Ensemble] failed ({exc}), falling back to single-best ckpt.")
        trainer.test(task, datamodule=dm, ckpt_path="best")

    return trainer, task, dm
|
|
|
|
def main() -> None:
    """Command-line entry point: parse the arguments and start training."""
    # Set PyTorch's float32 matmul precision before any model is built.
    torch.set_float32_matmul_precision("medium")
    cli = build_parser()
    parsed = cli.parse_args()
    train_from_args(parsed)
|
|
|
|
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|