PULSE-code / experiments /analysis /exp_per_subject.py
velvet-pine-22's picture
Upload folder using huggingface_hub
b4b2877 verified
#!/usr/bin/env python3
"""
Experiment G: Per-subject diagnostic analysis.
Load the best scene-recognition checkpoint(s) from previous T1 runs and
produce a per-test-volunteer breakdown of F1 and Accuracy. Reveals whether
aggregate metrics are driven by one or two outlier subjects, as reviewers
often ask.
Inference only; no training. Uses CUDA when available, otherwise CPU.
"""
import os
import sys
import json
import glob
import argparse
import numpy as np
import torch
from sklearn.metrics import accuracy_score, f1_score
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data.dataset import (
MultimodalSceneDataset, TEST_VOLS, SCENE_LABELS, NUM_CLASSES,
get_dataloaders,
)
from nets.models import build_model
def per_subject_eval(model, device, modalities, stats, downsample):
    """Evaluate one model on each test volunteer separately.

    For every volunteer in ``TEST_VOLS`` a single-volunteer
    ``MultimodalSceneDataset`` is built and the model is run over it one
    sample at a time (batch of 1, all-ones boolean mask).

    Args:
        model: Network invoked as ``model(x, mask)``; the predicted class is
            the ``argmax`` over dim 1 of its output.
        device: Torch device on which inputs and masks are placed.
        modalities: Modality list forwarded to the dataset constructor.
        stats: Object forwarded to the dataset as ``stats=`` (presumably
            train-set normalisation statistics -- confirm against the caller).
        downsample: Downsample factor forwarded to the dataset constructor.

    Returns:
        dict: volunteer -> ``{'n': 0}`` when that volunteer has no samples,
        otherwise a dict with keys ``n``, ``acc``, ``f1`` (macro-averaged,
        ``zero_division=0``), ``preds``, ``labels`` and ``samples``.
    """
    breakdown = {}
    # Eval mode does not depend on the volunteer; set it once up front.
    model.eval()
    for vol in TEST_VOLS:
        ds = MultimodalSceneDataset([vol], modalities, downsample=downsample,
                                    stats=stats)
        n_samples = len(ds)
        if n_samples == 0:
            breakdown[vol] = {'n': 0}
            continue
        preds, ys = [], []
        with torch.no_grad():
            # Index explicitly: only __len__/__getitem__ are known to exist
            # on the dataset.
            for i in range(n_samples):
                x, y = ds[i]
                x = x.to(device).unsqueeze(0)
                # Build the mask on the target device directly instead of
                # allocating on CPU and copying.
                mask = torch.ones(1, x.size(1), dtype=torch.bool,
                                  device=device)
                logits = model(x, mask)
                preds.append(int(logits.argmax(dim=1).item()))
                # Labels may come back as 0-d tensors or NumPy ints; store
                # plain ints so the result stays JSON-serialisable and
                # sklearn sees a homogeneous label list.
                ys.append(int(y))
        breakdown[vol] = {
            'n': n_samples,
            'acc': float(accuracy_score(ys, preds)),
            'f1': float(f1_score(ys, preds, average='macro', zero_division=0)),
            'preds': preds,
            'labels': ys,
            'samples': ds.sample_info,
        }
    return breakdown
def run_on_checkpoint(ckpt_path, args_json_path, output_dir):
    """Evaluate one saved checkpoint per test volunteer and write a JSON report.

    Args:
        ckpt_path: Path to the saved weights, loaded with ``torch.load``.
        args_json_path: Path to the run's JSON file; its ``'args'`` entry
            supplies the settings used to rebuild the model.
        output_dir: Directory the report is written to (created if missing).
            The file is named after the checkpoint's parent directory with a
            ``_per_subject.json`` suffix.

    Returns:
        dict: The summary that was written, with keys ``ckpt``,
        ``modalities``, ``overall``, ``per_subject`` and ``detail``.
    """
    # Context manager so the file handle is closed deterministically.
    with open(args_json_path) as f:
        ckpt_args = json.load(f)['args']
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Modalities are stored either as a list or as a comma-separated string.
    modalities = ckpt_args['modalities']
    if not isinstance(modalities, list):
        modalities = modalities.split(',')
    downsample = ckpt_args.get('downsample', 5)

    # Only `info` (feature dimensions) is needed here; the loaders themselves
    # are discarded.
    _, _, _, info = get_dataloaders(modalities,
                                    batch_size=ckpt_args.get('batch_size', 16),
                                    downsample=downsample)

    # The stats object is not returned above, so rebuild the train set to
    # compute it. The file-level imports take dataset symbols from
    # `data.dataset`; prefer that module and fall back to the legacy
    # `experiments.dataset` path only if it cannot provide TRAIN_VOLS.
    try:
        from data.dataset import TRAIN_VOLS as train_vols
    except ImportError:
        train_vols = __import__('experiments.dataset',
                                fromlist=['TRAIN_VOLS']).TRAIN_VOLS
    tr_ds = MultimodalSceneDataset(train_vols, modalities,
                                   downsample=downsample)
    stats = tr_ds.get_stats()

    model = build_model(
        ckpt_args.get('model', 'transformer'),
        ckpt_args.get('fusion', 'late'),
        info['feat_dim'], info['modality_dims'], NUM_CLASSES,
        hidden_dim=ckpt_args.get('hidden_dim', 128),
        proj_dim=ckpt_args.get('proj_dim', 0),
        late_agg=ckpt_args.get('late_agg', 'mean'),
    ).to(device)

    try:
        sd = torch.load(ckpt_path, weights_only=True, map_location=device)
    except Exception:
        # NOTE(review): this fallback drops `weights_only`, which on some
        # torch versions means full unpickling -- only run this on trusted
        # checkpoints.
        sd = torch.load(ckpt_path, map_location=device)

    # strict=False tolerates key mismatches; surface them so metrics from a
    # partially-initialised model are not mistaken for real results.
    load_result = model.load_state_dict(sd, strict=False)
    missing = list(getattr(load_result, 'missing_keys', []) or [])
    unexpected = list(getattr(load_result, 'unexpected_keys', []) or [])
    if missing or unexpected:
        print(f"  WARN {ckpt_path}: {len(missing)} missing / "
              f"{len(unexpected)} unexpected state-dict keys")

    breakdown = per_subject_eval(model, device, modalities, stats, downsample)

    # Pool every non-empty volunteer for the overall metrics.
    all_preds, all_ys = [], []
    for entry in breakdown.values():
        if entry.get('n', 0) > 0:
            all_preds.extend(entry['preds'])
            all_ys.extend(entry['labels'])
    overall_f1 = float(f1_score(all_ys, all_preds, average='macro', zero_division=0))
    overall_acc = float(accuracy_score(all_ys, all_preds))

    # Compact per-subject table plus the full per-sample detail.
    summary = {
        'ckpt': ckpt_path,
        'modalities': modalities,
        'overall': {'acc': overall_acc, 'f1': overall_f1,
                    'n': len(all_preds)},
        'per_subject': {
            v: {'n': b.get('n'), 'acc': b.get('acc'), 'f1': b.get('f1')}
            for v, b in breakdown.items()
        },
        'detail': breakdown,
    }

    os.makedirs(output_dir, exist_ok=True)
    run_name = os.path.basename(os.path.dirname(ckpt_path))
    out_path = os.path.join(output_dir, run_name + '_per_subject.json')
    with open(out_path, 'w') as f:
        json.dump(summary, f, indent=2)

    print(f"Per-subject breakdown saved: {out_path}")
    print(f"Overall F1: {overall_f1:.4f} Acc: {overall_acc:.4f}")
    for v, b in summary['per_subject'].items():
        if b.get('n'):
            print(f" {v}: n={b['n']} acc={b.get('acc'):.3f} f1={b.get('f1'):.3f}")
        else:
            print(f" {v}: (empty)")
    return summary
def main():
    """Command-line entry point: evaluate every run found under ``--exp_root``.

    A run is a direct child of ``--exp_root`` (other than ``slurm_logs``)
    that holds both ``model_best.pt`` and ``results.json``. Each run is
    passed to ``run_on_checkpoint``; a run that raises is reported and
    skipped so the remaining runs are still processed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_root', type=str, required=True,
                        help='Directory containing run subdirs with model_best.pt and results.json')
    parser.add_argument('--output_dir', type=str, required=True)
    opts = parser.parse_args()

    # Lazily pair up the two expected files for every candidate subdirectory.
    candidates = (
        (os.path.join(opts.exp_root, name, 'model_best.pt'),
         os.path.join(opts.exp_root, name, 'results.json'))
        for name in sorted(os.listdir(opts.exp_root))
        if name != 'slurm_logs'
    )
    # Keep only the pairs where both files are actually present.
    found = [pair for pair in candidates if all(map(os.path.exists, pair))]
    print(f"Found {len(found)} runs with checkpoints.")

    for weights_path, results_path in found:
        try:
            run_on_checkpoint(weights_path, results_path, opts.output_dir)
        except Exception as err:
            # Best-effort batch: one broken run must not stop the others.
            print(f" FAIL {weights_path}: {err}")


if __name__ == '__main__':
    main()