File size: 5,502 Bytes
b4b2877 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 | #!/usr/bin/env python3
"""
Experiment G: Per-subject diagnostic analysis.
Load the best scene-recognition checkpoint(s) from previous T1 runs and
produce a per-test-volunteer breakdown of F1 and accuracy. This reveals
whether aggregate metrics are driven by one or two outlier subjects, a
question reviewers often ask.
Inference only (no training); uses CUDA when available, otherwise CPU.
"""
import os
import sys
import json
import glob
import argparse
import numpy as np
import torch
from sklearn.metrics import accuracy_score, f1_score
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data.dataset import (
MultimodalSceneDataset, TEST_VOLS, SCENE_LABELS, NUM_CLASSES,
get_dataloaders,
)
from nets.models import build_model
def per_subject_eval(model, device, modalities, stats, downsample):
    """Evaluate ``model`` separately on each test volunteer.

    For every volunteer in ``TEST_VOLS`` a single-volunteer
    ``MultimodalSceneDataset`` is built with the supplied ``stats`` and the
    model is run on it one sample at a time.

    Args:
        model: Classifier called as ``model(x, mask)``; the argmax of its
            output over dim 1 is taken as the predicted class.
        device: Torch device used for inference.
        modalities: Modalities passed through to ``MultimodalSceneDataset``.
        stats: Training-set statistics object, passed through to the dataset.
        downsample: Downsampling factor passed through to the dataset.

    Returns:
        Dict mapping each volunteer id to ``{'n': 0}`` when that volunteer has
        no samples, otherwise to a dict with ``'n'``, ``'acc'``, macro
        ``'f1'``, per-sample ``'preds'`` and ``'labels'`` (plain ints), and
        ``'samples'`` (the dataset's ``sample_info``).
    """
    # Set once up front; nothing in this function switches back to train mode.
    model.eval()
    breakdown = {}
    for vol in TEST_VOLS:
        ds = MultimodalSceneDataset([vol], modalities, downsample=downsample,
                                    stats=stats)
        n_samples = len(ds)
        if n_samples == 0:
            breakdown[vol] = {'n': 0}
            continue

        preds, ys = [], []
        with torch.no_grad():
            for i in range(n_samples):
                x, y = ds[i]
                # Add a batch dimension of 1. Assumes x is (seq_len, feat_dim)
                # -- TODO confirm against MultimodalSceneDataset.
                x = x.to(device).unsqueeze(0)
                # All-True mask over dim 1 (presumably: no padding for a single
                # unbatched sequence). Allocated directly on the target device.
                mask = torch.ones(1, x.size(1), dtype=torch.bool, device=device)
                logits = model(x, mask)
                preds.append(int(logits.argmax(dim=1).item()))
                # int() turns a tensor / numpy scalar label into a plain int so
                # the labels stay JSON-serialisable next to the predictions.
                ys.append(int(y))

        breakdown[vol] = {
            'n': n_samples,
            'acc': float(accuracy_score(ys, preds)),
            'f1': float(f1_score(ys, preds, average='macro', zero_division=0)),
            'preds': preds,
            'labels': ys,
            'samples': ds.sample_info,
        }
    return breakdown
def run_on_checkpoint(ckpt_path, args_json_path, output_dir):
    """Evaluate one saved checkpoint per test volunteer and write a JSON report.

    The model is rebuilt from the training arguments stored under the
    ``'args'`` key of ``args_json_path``. Weights are loaded from
    ``ckpt_path`` and each volunteer in ``TEST_VOLS`` is evaluated separately
    via ``per_subject_eval``. All predictions are then pooled into an overall
    score.

    Args:
        ckpt_path: Path to the saved state dict (e.g. ``model_best.pt``).
        args_json_path: Path to the run's JSON file whose ``'args'`` entry
            holds the training configuration.
        output_dir: Directory receiving ``<run_dir_name>_per_subject.json``;
            created if missing.

    Returns:
        The summary dict written to disk, with keys ``'ckpt'``,
        ``'modalities'``, ``'overall'``, ``'per_subject'`` and ``'detail'``.
        ``'overall'`` metrics are NaN when no test volunteer has samples.
    """
    # TRAIN_VOLS comes from the same module as TEST_VOLS and the dataset
    # class, so the train/test split definitions stay consistent.
    from data.dataset import TRAIN_VOLS

    with open(args_json_path) as f:
        ckpt_args = json.load(f)['args']
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    raw_modalities = ckpt_args['modalities']
    # The args file may store modalities either as a list or as a
    # comma-separated string.
    if isinstance(raw_modalities, list):
        modalities = raw_modalities
    else:
        modalities = raw_modalities.split(',')
    downsample = ckpt_args.get('downsample', 5)

    # Only `info` (feature / modality dimensions) is used from this call.
    _, _, _, info = get_dataloaders(modalities,
                                    batch_size=ckpt_args.get('batch_size', 16),
                                    downsample=downsample)
    # The stats object itself is not returned above, so rebuild the training
    # set to compute it.
    tr_ds = MultimodalSceneDataset(TRAIN_VOLS, modalities, downsample=downsample)
    stats = tr_ds.get_stats()

    model = build_model(
        ckpt_args.get('model', 'transformer'),
        ckpt_args.get('fusion', 'late'),
        info['feat_dim'], info['modality_dims'], NUM_CLASSES,
        hidden_dim=ckpt_args.get('hidden_dim', 128),
        proj_dim=ckpt_args.get('proj_dim', 0),
        late_agg=ckpt_args.get('late_agg', 'mean'),
    ).to(device)

    try:
        sd = torch.load(ckpt_path, weights_only=True, map_location=device)
    except Exception:
        # Fallback for torch versions without `weights_only`, or checkpoints
        # holding non-tensor objects. NOTE(security): this performs a full
        # unpickle -- only run it on trusted checkpoints.
        sd = torch.load(ckpt_path, map_location=device)
    load_result = model.load_state_dict(sd, strict=False)
    # strict=False tolerates key mismatches; surface them so a partially
    # loaded (largely untrained) model does not go unnoticed.
    missing = list(load_result.missing_keys)
    unexpected = list(load_result.unexpected_keys)
    if missing or unexpected:
        print(f"  WARNING {ckpt_path}: {len(missing)} missing / "
              f"{len(unexpected)} unexpected state-dict keys "
              f"(e.g. {(missing + unexpected)[:5]})")

    breakdown = per_subject_eval(model, device, modalities, stats, downsample)

    # Overall metrics over the pooled predictions of all non-empty subjects.
    all_preds, all_ys = [], []
    for b in breakdown.values():
        if b.get('n', 0) > 0:
            all_preds.extend(b['preds'])
            all_ys.extend(b['labels'])
    if all_ys:
        overall_f1 = float(f1_score(all_ys, all_preds, average='macro',
                                    zero_division=0))
        overall_acc = float(accuracy_score(all_ys, all_preds))
    else:
        # No test samples at all: report NaN instead of calling the sklearn
        # metrics on empty inputs.
        overall_f1 = overall_acc = float('nan')

    summary = {
        'ckpt': ckpt_path,
        'modalities': modalities,
        'overall': {'acc': overall_acc, 'f1': overall_f1,
                    'n': len(all_preds)},
        'per_subject': {
            v: {'n': b.get('n'), 'acc': b.get('acc'), 'f1': b.get('f1')}
            for v, b in breakdown.items()
        },
        'detail': breakdown,
    }

    os.makedirs(output_dir, exist_ok=True)
    # Name the report after the run directory that contains the checkpoint.
    run_name = os.path.basename(os.path.dirname(ckpt_path))
    out_path = os.path.join(output_dir, run_name + '_per_subject.json')
    with open(out_path, 'w') as f:
        json.dump(summary, f, indent=2)

    print(f"Per-subject breakdown saved: {out_path}")
    print(f"Overall F1: {overall_f1:.4f} Acc: {overall_acc:.4f}")
    for v, b in summary['per_subject'].items():
        if b.get('n'):
            print(f" {v}: n={b['n']} acc={b['acc']:.3f} f1={b['f1']:.3f}")
        else:
            print(f" {v}: (empty)")
    return summary
def main():
    """Run the per-subject analysis on every checkpointed run under --exp_root.

    A run is any entry of ``--exp_root`` (except ``slurm_logs``) that contains
    both ``model_best.pt`` and ``results.json``. Reports go to
    ``--output_dir``. A failing run is reported with its traceback and does
    not stop the remaining runs.
    """
    import traceback

    p = argparse.ArgumentParser()
    p.add_argument('--exp_root', type=str, required=True,
                   help='Directory containing run subdirs with model_best.pt and results.json')
    p.add_argument('--output_dir', type=str, required=True)
    args = p.parse_args()
    # Fail with a clean usage error instead of a raw FileNotFoundError.
    if not os.path.isdir(args.exp_root):
        p.error(f"--exp_root is not a directory: {args.exp_root}")

    runs = []
    for sub in sorted(os.listdir(args.exp_root)):
        if sub == 'slurm_logs':
            continue
        run_dir = os.path.join(args.exp_root, sub)
        ckpt = os.path.join(run_dir, 'model_best.pt')
        res = os.path.join(run_dir, 'results.json')
        if os.path.exists(ckpt) and os.path.exists(res):
            runs.append((ckpt, res))
    print(f"Found {len(runs)} runs with checkpoints.")

    n_failed = 0
    for ckpt, res in runs:
        try:
            run_on_checkpoint(ckpt, res, args.output_dir)
        except Exception as e:
            # Deliberately best-effort across runs, but keep the traceback so
            # failures (import errors, shape mismatches, ...) are diagnosable.
            n_failed += 1
            print(f" FAIL {ckpt}: {e}")
            traceback.print_exc()
    if n_failed:
        print(f"{n_failed}/{len(runs)} runs failed.")


if __name__ == '__main__':
    main()
|