File size: 5,376 Bytes
16d6869
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""
Multi-model comparison runner.

v2 changes:
  - Captures test_sens, test_spec, and ensemble metrics in results CSV
  - Passes dynamic_graph_temporal flag through correctly
  - Uses site_holdout as default (inherited from updated main.py defaults)
"""

from __future__ import annotations

import argparse
import csv
import logging
from copy import deepcopy
from pathlib import Path

import torch

from brain_gcn.main import build_parser, train_from_args, validate_args

log = logging.getLogger(__name__)


DEFAULT_MODELS = ("fc_mlp", "gcn", "graph_temporal")


def metric_value(value) -> float | int | str:
    if isinstance(value, torch.Tensor):
        if value.numel() == 1:
            return float(value.detach().cpu())
        # Multi-element tensor: flatten to scalar_mean or scalar_max
        scalar_mean = float(value.detach().cpu().mean())
        log.warning(
            f"Multi-element metric tensor with shape {value.shape} — "
            f"flattening to scalar_mean={scalar_mean:.4f}. "
            "Consider reducing to single-value metrics in training_step."
        )
        return scalar_mean
    if isinstance(value, (float, int, str)):
        return value
    return str(value)


def build_experiment_parser() -> argparse.ArgumentParser:
    """Return the base training parser extended with comparison options."""
    p = build_parser()
    p.description = "Run Brain-Connectivity-GCN model comparisons"
    model_choices = ["fc_mlp", "gru", "gcn", "graph_temporal", "brain_mode"]
    p.add_argument(
        "--models",
        nargs="+",
        choices=model_choices,
        default=list(DEFAULT_MODELS),
        help="Model modes to run in order.",
    )
    p.add_argument(
        "--results_csv",
        type=str,
        default="results/experiment_summary.csv",
    )
    p.add_argument(
        "--dynamic_graph_temporal",
        action="store_true",
        help="Run graph_temporal with per-window adjacency sequences.",
    )
    # Comparison runs always evaluate on the test split.
    p.set_defaults(test=True)
    return p


def args_for_model(base_args: argparse.Namespace, model_name: str) -> argparse.Namespace:
    """Return a deep copy of *base_args* configured for *model_name*.

    Selects the adjacency/feature flags appropriate to each model mode,
    then validates the resulting namespace before returning it.
    """
    run_args = deepcopy(base_args)
    run_args.model_name = model_name
    run_args.prepare_data = False

    if model_name in ("fc_mlp", "adv_fc_mlp", "brain_mode", "adv_brain_mode"):
        # These use per-subject FC as flat features — no population/dynamic adj.
        overrides = {
            "use_population_adj": False,
            "use_dynamic_adj_sequence": False,
            "use_dynamic_adj": False,
            "use_fc_degree_features": False,
        }
    elif model_name == "graph_temporal":
        # Per-window FC as dynamic adjacency — population adj is uninformative.
        # Node features: per-ROI mean |FC| per window (connectivity strength).
        overrides = {
            "use_population_adj": False,
            "use_dynamic_adj_sequence": True,
            "use_dynamic_adj": False,
            "use_fc_degree_features": True,
        }
    elif model_name == "gcn":
        # Per-subject mean FC as static adjacency; degree features are more
        # discriminative than BOLD std.
        overrides = {
            "use_population_adj": False,
            "use_dynamic_adj_sequence": False,
            "use_dynamic_adj": False,
            "use_fc_degree_features": True,
        }
    elif model_name == "gru":
        # GRU ignores adjacency; per-subject FC still beats population adj.
        overrides = {
            "use_population_adj": False,
            "use_dynamic_adj_sequence": False,
            "use_dynamic_adj": False,
            "use_fc_degree_features": False,
        }
    else:
        # Unknown modes: only force off the sequence/degree extras, leaving
        # the remaining adjacency flags exactly as the user supplied them.
        overrides = {
            "use_dynamic_adj_sequence": False,
            "use_fc_degree_features": False,
        }

    for flag_name, flag_value in overrides.items():
        setattr(run_args, flag_name, flag_value)

    validate_args(run_args)
    return run_args


def summarize_run(model_name: str, trainer) -> dict[str, float | int | str]:
    """Collect train/val/test metrics from *trainer* into one CSV row.

    Only keys prefixed ``train_``/``val_``/``test_`` are kept; values are
    converted to scalars via :func:`metric_value`. Keys are sorted so the
    row has a deterministic ordering.
    """
    wanted = ("train_", "val_", "test_")
    metrics = {
        key: metric_value(raw)
        for key, raw in sorted(trainer.callback_metrics.items())
        if key.startswith(wanted)
    }
    return {"model_name": model_name, **metrics}


def write_results(path: Path, rows: list[dict[str, float | int | str]]) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    fieldnames = sorted({key for row in rows for key in row})
    # model_name first, then alphabetical
    fieldnames = ["model_name"] + [k for k in fieldnames if k != "model_name"]
    with path.open("w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def main() -> None:
    """Run each requested model mode in sequence and record the results.

    Data preparation happens once up front; the results CSV is re-written
    after every model so a crash mid-sweep loses no completed runs.
    """
    args = build_experiment_parser().parse_args()

    # Prepare and set up the datamodule exactly once, before the model loop.
    # setup() is called here so train_subjects reflects the actual split.
    from brain_gcn.main import build_datamodule

    prep = deepcopy(args)
    prep.prepare_data = True
    datamodule = build_datamodule(prep)
    datamodule.prepare_data()
    datamodule.setup()

    results_path = Path(args.results_csv)
    summaries = []
    for name in args.models:
        trainer, _, _ = train_from_args(args_for_model(args, name))
        summaries.append(summarize_run(name, trainer))
        write_results(results_path, summaries)
        print(f"[{name}] done — partial results written to {args.results_csv}")

    print(f"\nWrote {len(summaries)} rows to {args.results_csv}")


if __name__ == "__main__":
    main()