File size: 5,502 Bytes
b4b2877
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#!/usr/bin/env python3
"""
Experiment G: Per-subject diagnostic analysis.

Load the best scene-recognition checkpoint(s) from previous T1 runs and
produce a per-test-volunteer breakdown of F1 and Accuracy. Reveals whether
aggregate metrics are driven by one or two outlier subjects, as reviewers
often ask.

Inference only (no training); runs on CUDA when available, otherwise on CPU.
"""

import os
import sys
import json
import glob
import argparse
import numpy as np
import torch
from sklearn.metrics import accuracy_score, f1_score

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data.dataset import (
    MultimodalSceneDataset, TEST_VOLS, SCENE_LABELS, NUM_CLASSES,
    get_dataloaders,
)
from nets.models import build_model


def per_subject_eval(model, device, modalities, stats, downsample):
    """Evaluate one model on each test volunteer separately.

    For every volunteer in ``TEST_VOLS`` a single-volunteer
    ``MultimodalSceneDataset`` is built and the model is run on it one sample
    at a time (batch of one, all-True mask).

    Args:
        model: Network invoked as ``model(x, mask)``; switched to eval mode.
        device: Torch device the inputs and the mask are placed on.
        modalities: Modalities forwarded to ``MultimodalSceneDataset``.
        stats: Forwarded to the dataset as ``stats=``.
        downsample: Forwarded to the dataset as ``downsample=``.

    Returns:
        Dict keyed by volunteer. A volunteer without samples maps to
        ``{'n': 0}``; otherwise the entry holds ``n``, ``acc``, macro ``f1``,
        the raw ``preds`` and ``labels`` lists and the dataset's
        ``sample_info``.
    """
    breakdown = {}
    # Eval mode does not depend on the volunteer, so set it once up front.
    model.eval()
    for vol in TEST_VOLS:
        ds = MultimodalSceneDataset([vol], modalities, downsample=downsample,
                                    stats=stats)
        n_samples = len(ds)
        if n_samples == 0:
            breakdown[vol] = {'n': 0}
            continue
        preds, ys = [], []
        with torch.no_grad():
            # Index explicitly: map-style datasets are only guaranteed to
            # provide __len__/__getitem__.
            for i in range(n_samples):
                x, y = ds[i]
                x = x.to(device).unsqueeze(0)
                # All-True mask covering dimension 1 of the batched input
                # (presumably the sequence axis -- confirm against the model).
                mask = torch.ones(1, x.size(1), dtype=torch.bool,
                                  device=device)
                logits = model(x, mask)
                preds.append(logits.argmax(dim=1).item())
                # The label type is decided by the dataset (not visible
                # here); coerce to a plain int so the list works with sklearn
                # and stays JSON-serialisable even for 0-d tensor labels.
                ys.append(int(y))
        breakdown[vol] = {
            'n': n_samples,
            'acc': float(accuracy_score(ys, preds)),
            'f1': float(f1_score(ys, preds, average='macro', zero_division=0)),
            'preds': preds,
            'labels': ys,
            'samples': ds.sample_info,
        }
    return breakdown


def run_on_checkpoint(ckpt_path, args_json_path, output_dir):
    """Evaluate one checkpoint per test volunteer and write a JSON report.

    Args:
        ckpt_path: Path to the saved model weights.
        args_json_path: Path to a JSON file whose ``'args'`` entry holds the
            training arguments used to rebuild the model.
        output_dir: Directory (created if missing) that receives
            ``<run dir name>_per_subject.json``.

    Returns:
        The summary dict that was written: checkpoint path, modalities,
        overall accuracy/F1, a compact per-subject table and the full
        per-subject detail.
    """
    # Close the args file deterministically instead of leaking the handle.
    with open(args_json_path) as f:
        ckpt_args = json.load(f)['args']
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    modalities = ckpt_args['modalities']
    if not isinstance(modalities, list):
        # Stored as a comma-separated string in some runs.
        modalities = modalities.split(',')
    downsample = ckpt_args.get('downsample', 5)

    # Only the info dict (feature dimensions) is needed from the loaders.
    _, _, _, info = get_dataloaders(modalities,
                                    batch_size=ckpt_args.get('batch_size', 16),
                                    downsample=downsample)
    # The stats object itself is not returned above, so rebuild the training
    # set and ask it for its statistics.
    tr_ds = MultimodalSceneDataset(_resolve_train_vols(), modalities,
                                   downsample=downsample)
    stats = tr_ds.get_stats()

    model = build_model(
        ckpt_args.get('model', 'transformer'),
        ckpt_args.get('fusion', 'late'),
        info['feat_dim'], info['modality_dims'], NUM_CLASSES,
        hidden_dim=ckpt_args.get('hidden_dim', 128),
        proj_dim=ckpt_args.get('proj_dim', 0),
        late_agg=ckpt_args.get('late_agg', 'mean'),
    ).to(device)
    sd = _load_checkpoint_state(ckpt_path, device)
    # strict=False tolerates key mismatches; surface them so a partially
    # initialised model is not evaluated without anyone noticing.
    load_result = model.load_state_dict(sd, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f"  WARN {ckpt_path}: "
              f"{len(load_result.missing_keys)} missing key(s) "
              f"{list(load_result.missing_keys)}, "
              f"{len(load_result.unexpected_keys)} unexpected key(s) "
              f"{list(load_result.unexpected_keys)}")

    breakdown = per_subject_eval(model, device, modalities, stats, downsample)

    # Pool every non-empty subject to get the overall metrics.
    all_preds, all_ys = [], []
    for info_v in breakdown.values():
        if info_v.get('n', 0) > 0:
            all_preds.extend(info_v['preds'])
            all_ys.extend(info_v['labels'])
    overall_f1 = float(f1_score(all_ys, all_preds, average='macro', zero_division=0))
    overall_acc = float(accuracy_score(all_ys, all_preds))

    summary = {
        'ckpt': ckpt_path,
        'modalities': modalities,
        'overall': {'acc': overall_acc, 'f1': overall_f1,
                    'n': len(all_preds)},
        'per_subject': {
            v: {'n': b.get('n'), 'acc': b.get('acc'), 'f1': b.get('f1')}
            for v, b in breakdown.items()
        },
        'detail': breakdown,
    }
    os.makedirs(output_dir, exist_ok=True)
    # Name the report after the run directory that holds the checkpoint.
    run_name = os.path.basename(os.path.dirname(ckpt_path))
    out_path = os.path.join(output_dir, run_name + '_per_subject.json')
    with open(out_path, 'w') as f:
        json.dump(summary, f, indent=2)
    print(f"Per-subject breakdown saved: {out_path}")
    print(f"Overall F1: {overall_f1:.4f}  Acc: {overall_acc:.4f}")
    for v, b in summary['per_subject'].items():
        if b.get('n'):
            print(f"  {v}: n={b['n']} acc={b.get('acc'):.3f} f1={b.get('f1'):.3f}")
        else:
            print(f"  {v}: (empty)")
    return summary


def _resolve_train_vols():
    """Return the TRAIN_VOLS split, trying the historical module path first."""
    try:
        from experiments.dataset import TRAIN_VOLS
    except ImportError:
        # The rest of this file takes its dataset symbols from data.dataset;
        # fall back to it when the historical path is unavailable.
        from data.dataset import TRAIN_VOLS
    return TRAIN_VOLS


def _load_checkpoint_state(ckpt_path, device):
    """Load checkpoint contents, preferring the restricted weights-only path."""
    try:
        return torch.load(ckpt_path, weights_only=True, map_location=device)
    except Exception as exc:
        # NOTE(security): depending on the installed torch version this
        # fallback may perform full pickle loading, which can execute
        # arbitrary code. Only point this script at trusted checkpoints.
        print(f"  WARN weights_only load failed for {ckpt_path} "
              f"({type(exc).__name__}: {exc}); retrying with default torch.load")
        return torch.load(ckpt_path, map_location=device)


def main():
    """Run the per-subject analysis on every run directory under --exp_root.

    A run is any direct subdirectory (except ``slurm_logs``) that contains
    both ``model_best.pt`` and ``results.json``. Runs are processed in sorted
    order; a failing run is reported with its traceback and does not stop the
    remaining ones.
    """
    import traceback

    p = argparse.ArgumentParser()
    p.add_argument('--exp_root', type=str, required=True,
                   help='Directory containing run subdirs with model_best.pt and results.json')
    p.add_argument('--output_dir', type=str, required=True)
    args = p.parse_args()

    runs = []
    for sub in sorted(os.listdir(args.exp_root)):
        if sub == 'slurm_logs':
            continue
        ckpt = os.path.join(args.exp_root, sub, 'model_best.pt')
        res = os.path.join(args.exp_root, sub, 'results.json')
        if os.path.exists(ckpt) and os.path.exists(res):
            runs.append((ckpt, res))
    print(f"Found {len(runs)} runs with checkpoints.")

    failures = 0
    for ckpt, res in runs:
        try:
            run_on_checkpoint(ckpt, res, args.output_dir)
        except Exception as e:
            # Best effort: keep going, but leave enough detail to debug the
            # failure (the bare message alone is often not actionable).
            failures += 1
            print(f"  FAIL {ckpt}: {e}")
            traceback.print_exc()
    if failures:
        print(f"{failures} of {len(runs)} runs failed.")

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()