| """ |
| Tiny-batch overfit test for Stage 2 (CFM). |
| Takes a single batch, caches the Stage 1 mu_anchor, and trains a single-subject |
| CFM for 500 steps. If the loss does not approach 0, the architecture cannot learn. |
| """ |
|
|
| import sys |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn as nn |
| import numpy as np |
| from omegaconf import OmegaConf |
|
|
# Make the project's src/ directory importable when this file is run as a script.
src_dir = Path(__file__).resolve().parent.parent / "src"
sys.path.append(str(src_dir))
|
|
| from training import make_data_loaders, SUBJECTS |
| from medarc_architecture import MultiSubjectConvLinearEncoder |
| from stage2.CFM import CFM |
|
|
|
|
def main():
    """Run a tiny-batch overfit sanity check for the Stage 2 CFM.

    Grabs a single training batch, computes the Stage 1 ``mu_anchor`` for it,
    and trains a single-subject CFM on that one batch. If the loss does not
    approach 0 (and Pearson r does not approach 1), the architecture cannot
    learn the task.

    CLI flags: ``--cfg-path`` (config YAML), ``--subject-idx`` (0-based index
    into the subjects list), ``--steps`` (optimizer steps), ``--lr``.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--cfg-path", type=str, default=str(Path(__file__).resolve().parent.parent / "config.yml"))
    parser.add_argument("--subject-idx", type=int, default=0, help="Index into subjects list (0-based)")
    parser.add_argument("--steps", type=int, default=500)
    parser.add_argument("--lr", type=float, default=3e-4)
    args = parser.parse_args()

    cfg = OmegaConf.load(args.cfg_path)
    # Fall back to CPU when CUDA is unavailable, regardless of cfg.device.
    device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
    subjects_list = cfg.get("subjects", SUBJECTS)
    sub = subjects_list[args.subject_idx]
    print(f"Device: {device} | Subject: {sub} | Steps: {args.steps}")

    # Cache exactly one training batch — the whole point is to overfit it.
    data_loaders = make_data_loaders(cfg)
    train_loader = data_loaders["train"]
    batch = next(iter(train_loader))
    feats = [f.to(device) for f in batch["features"]]
    fmri = batch["fmri"].to(device)
    print(f"Batch stats - std: {fmri.std():.6f}, mean: {fmri.mean():.6f}")
    target_dim = fmri.shape[-1]
    feat_dims = [f.shape[-1] for f in batch["features"]]
    print(f"Batch shape: fmri={list(fmri.shape)}, feats={[list(f.shape) for f in feats]}")

    # Stage 1 encoder supplies the conditioning signal (mu_anchor) for the CFM.
    stage1_model = MultiSubjectConvLinearEncoder(
        num_subjects=len(subjects_list),
        feat_dims=feat_dims,
        **cfg.stage1.model,
    ).to(device)

    ckpt_path = Path(cfg.out_dir) / "stage1_best.pt"
    if ckpt_path.exists():
        stage1_model.load_state_dict(torch.load(ckpt_path, map_location=device))
        print(f"Loaded Stage 1 from {ckpt_path}")
    else:
        # A random-init anchor still lets the overfit test exercise the CFM.
        print("No Stage 1 checkpoint found — using random init as mu_anchor")

    stage1_model.eval()
    with torch.no_grad():
        mu_anchor = stage1_model(feats)

    # Select the requested subject and swap the last two axes for the CFM.
    # NOTE(review): assumes fmri/mu_anchor index subjects on dim 1 — confirm upstream.
    i = args.subject_idx
    x1 = fmri[:, i].transpose(1, 2)
    mu = mu_anchor[:, i].transpose(1, 2)

    cfm = CFM(
        feat_dim=target_dim,
        cfm_params=cfg.stage2.cfm,
        velocity_net_params=cfg.stage2.velocity_net,
        source_ve_params=cfg.stage2.source_ve,
        transport_params=cfg.stage2.transport,
    ).to(device)

    optimizer = torch.optim.AdamW(cfm.parameters(), lr=args.lr)

    # Loop-invariant work hoisted out of the training loop: the metric import
    # and the sampler-step config lookup used to be re-executed every eval step.
    from metric import pearsonr_score
    n_ts = cfg.stage2.get("n_timesteps", 25)

    print(f"\nOverfitting single batch for {args.steps} steps...")
    for step in range(1, args.steps + 1):
        cfm.train()  # re-enter train mode after the periodic eval below
        loss, _ = cfm.compute_loss(x1, mu)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Periodically sample from the CFM and score it against the batch.
        if step % 50 == 0 or step == 1:
            cfm.eval()
            with torch.inference_mode():
                pred = cfm(mu, n_timesteps=n_ts)

            # Flatten to (batch*time, target_dim) rows for the Pearson metric.
            y_true = x1.transpose(1, 2).reshape(-1, target_dim).cpu().numpy()
            y_pred = pred.transpose(1, 2).reshape(-1, target_dim).cpu().numpy()
            r = np.mean(pearsonr_score(y_true, y_pred))

            print(f" Step {step:4d} loss={loss.item():.6f} pearson_r={r:.4f}")

    print("\nDone. If loss did not approach 0 and r did not approach 1, "
          "the architecture cannot learn the task.")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|