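# File: BrainConnect-ASD / brain_gcn/tasks/classification.py
# Hosting-page residue (not part of the module): uploader avatar "Yatsuiii's picture";
# commit message "Upload folder using huggingface_hub"; revision 16d6869 (verified).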
"""
PyTorch Lightning training task for ASD/TD classification.
v2 changes:
- class_weights arg → weighted CrossEntropyLoss (fixes class imbalance)
- CosineAnnealingWarmRestarts scheduler (T_0=50, T_mult=2)
- BOLD noise augmentation in training_step
- Sensitivity (ASD recall) + Specificity (TD recall) metrics added
- drop_edge_p forwarded to build_model
"""
from __future__ import annotations
import argparse
import pytorch_lightning as pl
import torch
from torch import nn
from torchmetrics.classification import (
BinaryAUROC,
BinaryAccuracy,
BinaryF1Score,
BinaryRecall,
BinarySpecificity,
)
from brain_gcn.models import build_model
from brain_gcn.utils.grl import ganin_alpha
class ClassificationTask(pl.LightningModule):
    """Lightning task that trains and evaluates a two-class (ASD/TD) classifier.

    The network itself comes from ``build_model``; this class adds the loss
    functions, the train/val/test metrics, the training-time noise
    augmentation, the optional adversarial site loss, the optional
    orthogonality regulariser, and the optimizer/scheduler setup.

    Class index 1 is treated as the positive class: its softmax probability
    feeds the AUROC metrics, and the recall metrics are labelled as
    sensitivity (ASD) / specificity (TD).
    """

    def __init__(
        self,
        hidden_dim: int = 64,
        dropout: float = 0.5,
        readout: str = "attention",
        model_name: str = "graph_temporal",
        lr: float = 1e-3,
        weight_decay: float = 1e-4,
        class_weights: torch.Tensor | None = None,
        bold_noise_std: float = 0.01,
        drop_edge_p: float = 0.1,
        cosine_t0: int = 50,
        cosine_t_mult: int = 2,
        cosine_eta_min: float = 1e-5,
        num_sites: int = 1,
        adv_site_weight: float = 1.0,
        num_nodes: int = 200,
        num_modes: int = 16,
        orth_weight: float = 0.01,
        mode_init: torch.Tensor | None = None,
        in_features: int = 1,
    ):
        """
        Parameters
        ----------
        hidden_dim : hidden width forwarded to ``build_model``.
        dropout : dropout probability forwarded to ``build_model``.
        readout : readout type forwarded to ``build_model``
            (the CLI offers "mean" and "attention").
        model_name : architecture key forwarded to ``build_model``; the names
            "adv_fc_mlp" and "adv_brain_mode" also enable the adversarial
            site loss during training.
        lr : AdamW learning rate.
        weight_decay : AdamW weight decay.
        class_weights : 1-D tensor of length num_classes for weighted CE.
            Coerced to float32; ``None`` gives an unweighted loss.
        bold_noise_std : std dev of Gaussian noise added during training,
            relative to each sample's own signal std. ``0`` disables it.
        drop_edge_p : edge drop probability for graph models.
        cosine_t0 : CosineAnnealingWarmRestarts first restart epoch.
        cosine_t_mult : restart interval multiplier.
        cosine_eta_min : minimum LR after annealing.
        num_sites : number of acquisition sites (used by the adversarial
            models).
        adv_site_weight : weight on the adversarial site loss term.
        num_nodes : node count forwarded to ``build_model``.
        num_modes : number of learnable modes forwarded to ``build_model``.
        orth_weight : weight on ``model.orthogonality_loss()`` when the model
            exposes that method; ``0`` disables the term.
        mode_init : optional tensor forwarded unchanged to ``build_model``.
        in_features : node feature dimension (1 for BOLD std, N for FC rows).
        """
        super().__init__()
        # Tensor arguments are kept out of the saved hyperparameters.
        # NOTE(review): because class_weights is not saved but is registered
        # as a buffer below, reloading a checkpoint without passing it again
        # may hit unexpected state-dict keys under strict loading — confirm.
        self.save_hyperparameters(ignore=["class_weights", "mode_init"])

        if class_weights is not None:
            # A float64 tensor (e.g. derived from NumPy) would not match the
            # float32 logits inside the loss; normalise the dtype up front.
            # A float32 tensor passes through unchanged.
            class_weights = torch.as_tensor(class_weights, dtype=torch.float32)
        self.register_buffer("class_weights", class_weights)

        self.model = build_model(
            model_name=model_name,
            hidden_dim=hidden_dim,
            num_sites=num_sites,
            num_nodes=num_nodes,
            num_modes=num_modes,
            dropout=dropout,
            readout=readout,
            drop_edge_p=drop_edge_p,
            mode_init=mode_init,
            in_features=in_features,
        )

        self.loss_fn = nn.CrossEntropyLoss(weight=class_weights)
        # Site cross-entropy — unweighted (sites roughly balanced);
        # targets equal to -1 are skipped.
        self.site_loss_fn = nn.CrossEntropyLoss(ignore_index=-1)

        # --- Metrics --------------------------------------------------------
        self.train_acc = BinaryAccuracy()

        self.val_acc = BinaryAccuracy()
        self.val_auc = BinaryAUROC()
        self.val_f1 = BinaryF1Score()
        self.val_sens = BinaryRecall()  # sensitivity = ASD recall
        self.val_spec = BinarySpecificity()  # specificity = TD recall

        self.test_acc = BinaryAccuracy()
        self.test_auc = BinaryAUROC()
        self.test_f1 = BinaryF1Score()
        self.test_sens = BinaryRecall()
        self.test_spec = BinarySpecificity()

    @property
    def _is_adversarial(self) -> bool:
        """Whether the configured model is trained with the adversarial site loss."""
        return self.hparams.model_name in ("adv_fc_mlp", "adv_brain_mode")

    # ------------------------------------------------------------------
    def forward(self, bold_windows: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """Run the wrapped model and return its output (used as class logits)."""
        return self.model(bold_windows, adj)

    def _log_train_accuracy(self, preds: torch.Tensor, labels: torch.Tensor) -> None:
        """Update the training accuracy metric and log it once per epoch."""
        self.train_acc.update(preds, labels)
        self.log("train_acc", self.train_acc, prog_bar=True, on_epoch=True, on_step=False)

    def _log_eval_metrics(
        self,
        stage: str,
        preds: torch.Tensor,
        probs: torch.Tensor,
        labels: torch.Tensor,
    ) -> None:
        """Update and log the accuracy/AUROC/F1/sensitivity/specificity metrics.

        ``stage`` must be "val" or "test"; it selects the ``<stage>_<name>``
        metric attributes created in ``__init__`` and the logged key names.
        """
        # Validation keeps the progress bar compact (accuracy and AUROC only);
        # the test stage shows every metric.
        show_all = stage == "test"
        metric_specs = (
            ("acc", preds, True),
            ("auc", probs, True),  # AUROC needs scores, not hard predictions
            ("f1", preds, show_all),
            ("sens", preds, show_all),
            ("spec", preds, show_all),
        )
        for suffix, scores, prog_bar in metric_specs:
            key = f"{stage}_{suffix}"
            metric = getattr(self, key)
            metric.update(scores, labels)
            self.log(key, metric, prog_bar=prog_bar, on_epoch=True, on_step=False)

    def _step(self, batch, stage: str) -> torch.Tensor:
        """Shared forward/loss/metric logic for one batch.

        Parameters
        ----------
        batch : 4-tuple ``(bold_windows, adj, labels, site_ids)``; the site
            ids are not used here.
        stage : "train", "val" or "test". Any other value logs only the loss.

        Returns
        -------
        The classification loss for the batch.
        """
        bold_windows, adj, labels, _site_ids = batch
        logits = self(bold_windows, adj)
        loss = self.loss_fn(logits, labels)
        preds = torch.argmax(logits, dim=-1)

        self.log(f"{stage}_loss", loss, prog_bar=True, on_epoch=True, on_step=False)

        if stage == "train":
            self._log_train_accuracy(preds, labels)
        elif stage in ("val", "test"):
            # Probability of class 1, used as the score for AUROC.
            probs = torch.softmax(logits, dim=-1)[:, 1]
            self._log_eval_metrics(stage, preds, probs, labels)
        return loss

    def training_step(self, batch, batch_idx: int) -> torch.Tensor:
        """Compute the training loss for one batch.

        Adds Gaussian noise to the input when ``bold_noise_std > 0``, adds the
        weighted site loss for adversarial models, and adds the weighted
        orthogonality term when the model provides ``orthogonality_loss()``.
        """
        bold_windows, adj, labels, site_ids = batch

        if self.hparams.bold_noise_std > 0.0:
            # Scale the noise by each sample's own std (taken over dims 1 and
            # 2) so the perturbation is relative to the signal amplitude.
            signal_std = bold_windows.std(dim=(1, 2), keepdim=True).detach()
            noise = torch.randn_like(bold_windows) * self.hparams.bold_noise_std * signal_std
            bold_windows = bold_windows + noise

        if self._is_adversarial:
            # Dual loss: ASD classification + adversarial site deconfounding
            asd_logits, site_logits = self.model(
                bold_windows, adj, return_site_logits=True
            )
            asd_loss = self.loss_fn(asd_logits, labels)
            site_loss = self.site_loss_fn(site_logits, site_ids)
            loss = asd_loss + self.hparams.adv_site_weight * site_loss
            preds = torch.argmax(asd_logits, dim=-1)

            self.log("train_asd_loss", asd_loss, prog_bar=False, on_epoch=True, on_step=False)
            self.log("train_site_loss", site_loss, prog_bar=False, on_epoch=True, on_step=False)
            self.log("train_loss", loss, prog_bar=True, on_epoch=True, on_step=False)
            self._log_train_accuracy(preds, labels)
        else:
            loss = self._step((bold_windows, adj, labels, site_ids), "train")

        # Orthogonality regularization — BMN only (model exposes orthogonality_loss())
        if hasattr(self.model, "orthogonality_loss") and self.hparams.orth_weight > 0.0:
            orth = self.model.orthogonality_loss()
            loss = loss + self.hparams.orth_weight * orth
            self.log("train_orth_loss", orth, prog_bar=False, on_epoch=True, on_step=False)

        return loss

    def on_train_epoch_start(self) -> None:
        """Anneal the GRL alpha at the start of each epoch."""
        if self._is_adversarial:
            # NOTE(review): assumes trainer.max_epochs is a positive integer;
            # confirm how ganin_alpha behaves for an unbounded run.
            alpha = ganin_alpha(self.current_epoch, self.trainer.max_epochs)
            self.model.grl.alpha = alpha
            self.log("grl_alpha", alpha, prog_bar=False, on_epoch=True, on_step=False)

    def validation_step(self, batch, batch_idx: int) -> torch.Tensor:
        """Evaluate one validation batch and return its loss."""
        return self._step(batch, "val")

    def test_step(self, batch, batch_idx: int) -> torch.Tensor:
        """Evaluate one test batch and return its loss."""
        return self._step(batch, "test")

    # ------------------------------------------------------------------
    def configure_optimizers(self):
        """Return AdamW with a per-epoch CosineAnnealingWarmRestarts schedule."""
        opt = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
        sch = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            opt,
            T_0=self.hparams.cosine_t0,
            T_mult=self.hparams.cosine_t_mult,
            eta_min=self.hparams.cosine_eta_min,
        )
        return {
            "optimizer": opt,
            "lr_scheduler": {"scheduler": sch, "interval": "epoch"},
        }

    # ------------------------------------------------------------------
    @staticmethod
    def add_model_specific_arguments(parent_parser: argparse.ArgumentParser):
        """Return a new parser that extends ``parent_parser`` with task options.

        The defaults mirror the constructor defaults. ``parent_parser`` is
        used as a parent and is not modified.
        """
        parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--hidden_dim", type=int, default=64)
        parser.add_argument("--dropout", type=float, default=0.5)
        parser.add_argument("--readout", choices=["mean", "attention"], default="attention")
        parser.add_argument(
            "--model_name",
            choices=["graph_temporal", "gcn", "gru", "fc_mlp", "adv_fc_mlp",
                     "gat", "transformer", "cnn3d", "graphsage",
                     "brain_mode", "adv_brain_mode", "dynamic_fc_attn"],
            default="graph_temporal",
        )
        parser.add_argument("--lr", type=float, default=1e-3)
        parser.add_argument("--adv_site_weight", type=float, default=1.0,
                            help="Weight on adversarial site loss "
                                 "(adv_fc_mlp and adv_brain_mode only).")
        parser.add_argument("--weight_decay", type=float, default=1e-4)
        parser.add_argument("--bold_noise_std", type=float, default=0.01)
        parser.add_argument("--drop_edge_p", type=float, default=0.1)
        parser.add_argument("--cosine_t0", type=int, default=50)
        parser.add_argument("--cosine_t_mult", type=int, default=2,
                            help="CosineAnnealingWarmRestarts restart interval multiplier")
        parser.add_argument("--cosine_eta_min", type=float, default=1e-5,
                            help="CosineAnnealingWarmRestarts minimum learning rate")
        parser.add_argument("--num_modes", type=int, default=16,
                            help="Brain Mode Network: number of learnable modes K")
        parser.add_argument("--orth_weight", type=float, default=0.01,
                            help="Brain Mode Network: orthogonality regularization weight")
        return parser