"""Tiny-batch overfit test for Stage 2 (CFM).

Takes a single batch, caches the Stage 1 ``mu_anchor`` for it, and trains a
single-subject CFM on that one batch (500 steps by default, see ``--steps``).
If the loss does not approach 0, the architecture cannot learn.
"""

import argparse
import sys
from pathlib import Path

import torch
import torch.nn as nn
import numpy as np
from omegaconf import OmegaConf

# Project modules are imported from the sibling ``src`` directory, so it has
# to be on ``sys.path`` before the imports below run.
src_dir = Path(__file__).resolve().parent.parent / "src"
sys.path.append(str(src_dir))

from training import make_data_loaders, SUBJECTS
from medarc_architecture import MultiSubjectConvLinearEncoder
from stage2.CFM import CFM
from metric import pearsonr_score


def _parse_args():
    """Parse the command-line options of the overfit test."""
    default_cfg = Path(__file__).resolve().parent.parent / "config.yml"
    parser = argparse.ArgumentParser()
    parser.add_argument("--cfg-path", type=str, default=str(default_cfg))
    parser.add_argument("--subject-idx", type=int, default=0,
                        help="Index into subjects list (0-based)")
    parser.add_argument("--steps", type=int, default=500)
    parser.add_argument("--lr", type=float, default=3e-4)
    return parser.parse_args()


def _flatten_to_numpy(x, target_dim):
    """Turn a (B, V, T) tensor into a (B*T, target_dim) NumPy array."""
    return x.transpose(1, 2).reshape(-1, target_dim).cpu().numpy()


def main():
    """Overfit a single-subject CFM on one cached training batch.

    Loads the config, pulls one batch from the training loader, computes the
    Stage 1 output (``mu_anchor``) for it once, and then optimises a fresh CFM
    on that batch only. The loss and the batch Pearson r are printed at step 1
    and at every 50th step.
    """
    args = _parse_args()

    cfg = OmegaConf.load(args.cfg_path)
    # Fall back to CPU when CUDA is unavailable, whatever the config requests.
    device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
    subjects_list = cfg.get("subjects", SUBJECTS)
    sub = subjects_list[args.subject_idx]
    print(f"Device: {device} | Subject: {sub} | Steps: {args.steps}")

    # --- Data ---
    data_loaders = make_data_loaders(cfg)
    train_loader = data_loaders["train"]
    batch = next(iter(train_loader))

    feats = [f.to(device) for f in batch["features"]]
    fmri = batch["fmri"].to(device)  # (B, S, T, V)
    print(f"Batch stats - std: {fmri.std():.6f}, mean: {fmri.mean():.6f}")

    target_dim = fmri.shape[-1]
    feat_dims = [f.shape[-1] for f in batch["features"]]
    print(f"Batch shape: fmri={list(fmri.shape)}, feats={[list(f.shape) for f in feats]}")

    # --- Stage 1: cache mu_anchor ---
    stage1_model = MultiSubjectConvLinearEncoder(
        num_subjects=len(subjects_list),
        feat_dims=feat_dims,
        **cfg.stage1.model,
    ).to(device)

    ckpt_path = Path(cfg.out_dir) / "stage1_best.pt"
    if ckpt_path.exists():
        # The loaded object is handed straight to load_state_dict, so the file
        # is expected to hold a plain state dict.
        # NOTE(review): torch.load unpickles the file; only point this at
        # checkpoints you trust (consider weights_only=True where the
        # installed PyTorch supports it).
        stage1_model.load_state_dict(torch.load(ckpt_path, map_location=device))
        print(f"Loaded Stage 1 from {ckpt_path}")
    else:
        print("No Stage 1 checkpoint found — using random init as mu_anchor")

    stage1_model.eval()
    with torch.no_grad():
        mu_anchor = stage1_model(feats)  # (B, S, T, V)

    # Extract single subject
    i = args.subject_idx
    x1 = fmri[:, i].transpose(1, 2)  # (B, V, T)
    mu = mu_anchor[:, i].transpose(1, 2)  # (B, V, T)

    # --- Stage 2: single-subject CFM ---
    cfm = CFM(
        feat_dim=target_dim,
        cfm_params=cfg.stage2.cfm,
        velocity_net_params=cfg.stage2.velocity_net,
        source_ve_params=cfg.stage2.source_ve,
        transport_params=cfg.stage2.transport,
    ).to(device)

    optimizer = torch.optim.AdamW(cfm.parameters(), lr=args.lr)

    # The evaluation settings and the flattened target do not change between
    # steps, so they are computed once here instead of at every evaluation.
    n_ts = cfg.stage2.get("n_timesteps", 25)
    y_true = _flatten_to_numpy(x1, target_dim)

    # --- Overfit loop ---
    print(f"\nOverfitting single batch for {args.steps} steps...")
    for step in range(1, args.steps + 1):
        # The periodic evaluation below switches to eval mode, so training
        # mode is restored at the start of every step.
        cfm.train()
        loss, _ = cfm.compute_loss(x1, mu)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 50 == 0 or step == 1:
            # Quick inference check
            cfm.eval()
            with torch.inference_mode():
                pred = cfm(mu, n_timesteps=n_ts)  # (B, V, T)

            # Pearson r on this batch
            y_pred = _flatten_to_numpy(pred, target_dim)
            r = np.mean(pearsonr_score(y_true, y_pred))

            # The reported loss was computed before this step's update.
            print(f" Step {step:4d} loss={loss.item():.6f} pearson_r={r:.4f}")

    print("\nDone. If loss did not approach 0 and r did not approach 1, "
          "the architecture cannot learn the task.")


if __name__ == "__main__":
    main()