BrainConnect-ASD / brain_gcn /population_main.py
Yatsuiii's picture
Upload folder using huggingface_hub
16d6869 verified
"""
Population Graph GCN — training entry point.
Architecture: Parisot et al. 2017/2018 (subject nodes, phenotypic edges).
- Nodes : subjects (N ≈ 1102)
- Features: PCA-reduced FC upper triangle (D=256)
- Edges : sex_match × age_gaussian_similarity > threshold
- Training: transductive — all nodes in graph, loss masked to train split
Usage
-----
python -m brain_gcn.population_main \\
--data_dir data \\
--pheno_csv data/raw/abide_s3/phenotypic.csv \\
--use_combat \\
--n_pca 256 \\
--hidden_dim 64 \\
--dropout 0.5 \\
--lr 5e-4 \\
--weight_decay 1e-3 \\
--epochs 500 \\
--seed 42
"""
from __future__ import annotations
import argparse
import random
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.model_selection import StratifiedShuffleSplit
from torchmetrics.classification import BinaryAUROC, BinaryAccuracy, BinaryRecall, BinarySpecificity, BinaryF1Score
from brain_gcn.models.population_gcn import PopulationGCN
from brain_gcn.utils.data.population_graph import (
apply_pca,
build_population_adj,
extract_fc_features,
fit_pca,
harmonize_combat,
load_phenotypic,
normalize_adj,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def seed_everything(seed: int) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
def class_weights(labels: np.ndarray) -> torch.Tensor:
n_td = int((labels == 0).sum())
n_asd = int((labels == 1).sum())
total = n_td + n_asd
return torch.tensor([total / (2.0 * n_td), total / (2.0 * n_asd)], dtype=torch.float32)
def build_masks(n: int, train_idx, val_idx, test_idx, device):
def _mask(idx):
m = torch.zeros(n, dtype=torch.bool, device=device)
m[idx] = True
return m
return _mask(train_idx), _mask(val_idx), _mask(test_idx)
# ---------------------------------------------------------------------------
# Evaluation
# ---------------------------------------------------------------------------
@torch.no_grad()
def evaluate(logits: torch.Tensor, labels: torch.Tensor, mask: torch.Tensor):
probs = torch.softmax(logits[mask], dim=-1)
preds = probs.argmax(dim=-1)
tgts = labels[mask]
auc_m = BinaryAUROC()
acc_m = BinaryAccuracy()
sens_m = BinaryRecall()
spec_m = BinarySpecificity()
f1_m = BinaryF1Score()
auc = auc_m(probs[:, 1].cpu(), tgts.cpu()).item()
acc = acc_m(preds.cpu(), tgts.cpu()).item()
sens = sens_m(preds.cpu(), tgts.cpu()).item()
spec = spec_m(preds.cpu(), tgts.cpu()).item()
f1 = f1_m(preds.cpu(), tgts.cpu()).item()
return dict(auc=auc, acc=acc, sens=sens, spec=spec, f1=f1)
# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------
def train(args: argparse.Namespace) -> dict:
seed_everything(args.seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
# ------------------------------------------------------------------
# 1. Data
# ------------------------------------------------------------------
processed_dir = Path(args.data_dir) / "processed"
pheno = load_phenotypic(args.pheno_csv, processed_dir)
print(f"Subjects matched: {len(pheno)} (ASD={pheno['label'].sum()} TD={(pheno['label']==0).sum()})")
subject_ids = pheno["SUB_ID"].tolist()
labels_np = pheno["label"].values.astype(np.int64)
# ------------------------------------------------------------------
# 2. Train / val / test split (stratified)
# ------------------------------------------------------------------
sss = StratifiedShuffleSplit(n_splits=1, test_size=args.test_ratio, random_state=args.seed)
train_val_idx, test_idx = next(sss.split(subject_ids, labels_np))
val_size = args.val_ratio / (1.0 - args.test_ratio)
sss2 = StratifiedShuffleSplit(n_splits=1, test_size=val_size, random_state=args.seed)
rel_train, rel_val = next(sss2.split(train_val_idx, labels_np[train_val_idx]))
train_idx = train_val_idx[rel_train]
val_idx = train_val_idx[rel_val]
print(f"Split: train={len(train_idx)} val={len(val_idx)} test={len(test_idx)}")
# ------------------------------------------------------------------
# 3. FC features
# ------------------------------------------------------------------
print("Loading FC features …")
all_feats = extract_fc_features(processed_dir, subject_ids) # (N, 19900)
if args.use_combat:
print("Running ComBat harmonization …")
all_feats = harmonize_combat(
features=all_feats,
sites=pheno["SITE_ID"].tolist(),
labels=labels_np,
ages=pheno["AGE_AT_SCAN"].values,
sexes=pheno["sex_enc"].values,
)
# PCA fitted on training subjects only
scaler, pca = fit_pca(all_feats[train_idx], n_components=args.n_pca)
all_feats_pca = apply_pca(all_feats, scaler, pca) # (N, n_pca)
# ------------------------------------------------------------------
# 4. Population graph
# ------------------------------------------------------------------
print("Building population graph …")
adj_np = build_population_adj(
pheno,
threshold=args.graph_threshold,
use_site=args.use_site_edges,
)
adj_norm = torch.FloatTensor(normalize_adj(adj_np)).to(device)
# ------------------------------------------------------------------
# 5. Tensors
# ------------------------------------------------------------------
X = torch.FloatTensor(all_feats_pca).to(device) # (N, D)
labels = torch.LongTensor(labels_np).to(device) # (N,)
cw = class_weights(labels_np).to(device)
N = len(subject_ids)
train_mask, val_mask, test_mask = build_masks(N, train_idx, val_idx, test_idx, device)
# ------------------------------------------------------------------
# 6. Model
# ------------------------------------------------------------------
model = PopulationGCN(
in_dim=X.shape[1],
hidden_dim=args.hidden_dim,
dropout=args.dropout,
).to(device)
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
optimizer, T_0=args.cosine_t0, T_mult=2, eta_min=1e-6
)
# ------------------------------------------------------------------
# 7. Train
# ------------------------------------------------------------------
best_val_auc = 0.0
best_state = None
patience_left = args.patience
print(f"\n{'ep':>5s} | {'tr_loss':>8s} | {'val_auc':>8s} | {'val_acc':>8s} | {'val_sens':>9s} | {'val_spec':>9s}")
print("-" * 60)
for epoch in range(1, args.epochs + 1):
# ---- train ----
model.train()
optimizer.zero_grad()
logits = model(X, adj_norm)
loss = F.cross_entropy(logits[train_mask], labels[train_mask], weight=cw)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
scheduler.step()
# ---- validate ----
model.eval()
with torch.no_grad():
logits_eval = model(X, adj_norm)
val_m = evaluate(logits_eval, labels, val_mask)
if val_m["auc"] > best_val_auc:
best_val_auc = val_m["auc"]
best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
patience_left = args.patience
else:
patience_left -= 1
if epoch % 10 == 0 or epoch == 1:
print(
f"{epoch:>5d} | {loss.item():>8.4f} | {val_m['auc']:>8.4f} | "
f"{val_m['acc']:>8.4f} | {val_m['sens']:>9.4f} | {val_m['spec']:>9.4f}"
)
if patience_left <= 0:
print(f"\nEarly stop at epoch {epoch}. Best val_auc={best_val_auc:.4f}")
break
# ------------------------------------------------------------------
# 8. Test
# ------------------------------------------------------------------
model.load_state_dict({k: v.to(device) for k, v in best_state.items()})
model.eval()
with torch.no_grad():
logits_final = model(X, adj_norm)
test_m = evaluate(logits_final, labels, test_mask)
print(f"\n{'='*60}")
print(f"[TEST] auc={test_m['auc']:.4f} acc={test_m['acc']:.4f} "
f"sens={test_m['sens']:.4f} spec={test_m['spec']:.4f} f1={test_m['f1']:.4f}")
print(f"{'='*60}")
# Save checkpoint
ckpt_dir = Path("checkpoints") / "population_gcn"
ckpt_dir.mkdir(parents=True, exist_ok=True)
ckpt_path = ckpt_dir / f"best_auc{best_val_auc:.3f}.pt"
torch.save({"model_state": best_state, "args": vars(args), "test_metrics": test_m}, ckpt_path)
print(f"Checkpoint saved: {ckpt_path}")
return test_m
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
def build_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="Population Graph GCN for ABIDE ASD classification")
p.add_argument("--data_dir", type=str, default="data")
p.add_argument("--pheno_csv", type=str, default="data/raw/abide_s3/phenotypic.csv")
p.add_argument("--use_combat", action="store_true", help="Apply ComBat site harmonization")
p.add_argument("--use_site_edges", action="store_true", help="Include site-match in graph edges")
p.add_argument("--n_pca", type=int, default=256)
p.add_argument("--graph_threshold", type=float, default=0.5)
p.add_argument("--hidden_dim", type=int, default=64)
p.add_argument("--dropout", type=float, default=0.5)
p.add_argument("--lr", type=float, default=5e-4)
p.add_argument("--weight_decay", type=float, default=1e-3)
p.add_argument("--cosine_t0", type=int, default=100)
p.add_argument("--epochs", type=int, default=500)
p.add_argument("--patience", type=int, default=60)
p.add_argument("--val_ratio", type=float, default=0.1)
p.add_argument("--test_ratio", type=float, default=0.1)
p.add_argument("--seed", type=int, default=42)
return p
def main() -> None:
torch.set_float32_matmul_precision("medium")
args = build_parser().parse_args()
train(args)
if __name__ == "__main__":
main()