"""
Multi-model comparison runner.
v2 changes:
- Captures test_sens, test_spec, and ensemble metrics in results CSV
- Passes dynamic_graph_temporal flag through correctly
- Uses site_holdout as default (inherited from updated main.py defaults)
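
Example invocation (the script filename below is illustrative; the flags are
the ones defined in build_experiment_parser in this module):

    python run_experiments.py --models fc_mlp gcn graph_temporal \
        --results_csv results/experiment_summary.csv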
"""
from __future__ import annotations
import argparse
import csv
import logging
from copy import deepcopy
from pathlib import Path
import torch
from brain_gcn.main import build_parser, train_from_args, validate_args
log = logging.getLogger(__name__)
DEFAULT_MODELS = ("fc_mlp", "gcn", "graph_temporal")
def metric_value(value) -> float | int | str:
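    # Lightning's trainer.callback_metrics stores values as tensors; convert
    # them to plain Python scalars/strings so csv.DictWriter can serialize them.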
if isinstance(value, torch.Tensor):
if value.numel() == 1:
return float(value.detach().cpu())
        # Multi-element tensor: reduce to its mean and warn.
scalar_mean = float(value.detach().cpu().mean())
log.warning(
f"Multi-element metric tensor with shape {value.shape} — "
f"flattening to scalar_mean={scalar_mean:.4f}. "
"Consider reducing to single-value metrics in training_step."
)
return scalar_mean
if isinstance(value, (float, int, str)):
return value
return str(value)
def build_experiment_parser() -> argparse.ArgumentParser:
parser = build_parser()
parser.description = "Run Brain-Connectivity-GCN model comparisons"
parser.add_argument(
"--models",
nargs="+",
choices=["fc_mlp", "gru", "gcn", "graph_temporal", "brain_mode"],
default=list(DEFAULT_MODELS),
help="Model modes to run in order.",
)
    parser.add_argument(
        "--results_csv",
        type=str,
        default="results/experiment_summary.csv",
        help="Output CSV path for the per-model metric summary.",
    )
parser.add_argument(
"--dynamic_graph_temporal",
action="store_true",
help="Run graph_temporal with per-window adjacency sequences.",
)
parser.set_defaults(test=True)
return parser
def args_for_model(base_args: argparse.Namespace, model_name: str) -> argparse.Namespace:
args = deepcopy(base_args)
args.model_name = model_name
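    # Data was prepared and set up once in main(); skip per-run preparation.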
args.prepare_data = False
if model_name in ("fc_mlp", "adv_fc_mlp", "brain_mode", "adv_brain_mode"):
# These use per-subject FC as flat features — no population/dynamic adj
args.use_population_adj = False
args.use_dynamic_adj_sequence = False
args.use_dynamic_adj = False
args.use_fc_degree_features = False
elif model_name == "graph_temporal":
# Always use per-window FC as dynamic adjacency — population adj is uninformative
# Node features: per-ROI mean |FC| per window (connectivity strength, not BOLD std)
args.use_population_adj = False
args.use_dynamic_adj_sequence = True
args.use_dynamic_adj = False
args.use_fc_degree_features = True
elif model_name == "gcn":
# Per-subject mean FC as static adjacency — population adj is same for all subjects
# Node features: per-ROI mean |FC| per window (more discriminative than BOLD std)
args.use_population_adj = False
args.use_dynamic_adj_sequence = False
args.use_dynamic_adj = False
args.use_fc_degree_features = True
elif model_name == "gru":
# GRU ignores adjacency; per-subject FC still better than population adj
args.use_population_adj = False
args.use_dynamic_adj_sequence = False
args.use_dynamic_adj = False
args.use_fc_degree_features = False
else:
args.use_dynamic_adj_sequence = False
args.use_fc_degree_features = False
validate_args(args)
return args
def summarize_run(model_name: str, trainer) -> dict[str, float | int | str]:
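    # callback_metrics holds whatever the LightningModule logged, so the set of
    # keys can differ across models; write_results unions fieldnames over rows.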
row: dict[str, float | int | str] = {"model_name": model_name}
for key, value in sorted(trainer.callback_metrics.items()):
if key.startswith(("train_", "val_", "test_")):
row[key] = metric_value(value)
return row
def write_results(path: Path, rows: list[dict[str, float | int | str]]) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
fieldnames = sorted({key for row in rows for key in row})
# model_name first, then alphabetical
fieldnames = ["model_name"] + [k for k in fieldnames if k != "model_name"]
with path.open("w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
writer.writerows(rows)
def main() -> None:
parser = build_experiment_parser()
args = parser.parse_args()
    # Prepare and set up the data once, before the model loop.
    # setup() must run before preprocess_all so train_subjects reflects the actual split.
from brain_gcn.main import build_datamodule
prep_args = deepcopy(args)
prep_args.prepare_data = True
dm = build_datamodule(prep_args)
dm.prepare_data()
dm.setup() # Call setup here to establish actual train/val/test boundary
rows = []
for model_name in args.models:
run_args = args_for_model(args, model_name)
trainer, _, _ = train_from_args(run_args)
rows.append(summarize_run(model_name, trainer))
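        # Rewrite the CSV after every model so an interrupted sweep still
        # leaves partial results on disk.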
write_results(Path(args.results_csv), rows)
print(f"[{model_name}] done — partial results written to {args.results_csv}")
print(f"\nWrote {len(rows)} rows to {args.results_csv}")
if __name__ == "__main__":
main()