# flow-matching/test/overfit_test.py
# Uploaded via huggingface_hub by sabertoaster (commit 0254260, verified)
"""
Tiny-batch overfit test for Stage 2 (CFM).
Takes a single batch, caches the Stage 1 mu_anchor, and trains a single-subject
CFM for 500 steps. If the loss does not approach 0, the architecture cannot learn.
"""
import sys
from pathlib import Path
import torch
import torch.nn as nn
import numpy as np
from omegaconf import OmegaConf
src_dir = Path(__file__).resolve().parent.parent / "src"
sys.path.append(str(src_dir))
from training import make_data_loaders, SUBJECTS
from medarc_architecture import MultiSubjectConvLinearEncoder
from stage2.CFM import CFM
def main():
    """Overfit a single-subject CFM on one cached training batch.

    Loads the experiment config, grabs a single batch from the train loader,
    computes the frozen Stage 1 prediction (mu_anchor) for that batch, then
    trains a Stage 2 CFM on the selected subject's slice for ``--steps``
    optimizer steps, printing loss and Pearson r every 50 steps.
    If the loss approaches 0 and r approaches 1, the architecture can learn;
    otherwise it cannot fit even a single batch.

    CLI flags:
        --cfg-path     Path to the YAML config (default: ../config.yml).
        --subject-idx  0-based index into the subjects list.
        --steps        Number of optimizer steps (default: 500).
        --lr           AdamW learning rate (default: 3e-4).
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cfg-path",
        type=str,
        default=str(Path(__file__).resolve().parent.parent / "config.yml"),
    )
    parser.add_argument("--subject-idx", type=int, default=0,
                        help="Index into subjects list (0-based)")
    parser.add_argument("--steps", type=int, default=500)
    parser.add_argument("--lr", type=float, default=3e-4)
    args = parser.parse_args()

    cfg = OmegaConf.load(args.cfg_path)
    # Fall back to CPU when CUDA is unavailable, regardless of cfg.device.
    device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
    subjects_list = cfg.get("subjects", SUBJECTS)
    sub = subjects_list[args.subject_idx]
    print(f"Device: {device} | Subject: {sub} | Steps: {args.steps}")

    # --- Data: one fixed batch that we will deliberately overfit ---
    data_loaders = make_data_loaders(cfg)
    train_loader = data_loaders["train"]
    batch = next(iter(train_loader))
    feats = [f.to(device) for f in batch["features"]]
    fmri = batch["fmri"].to(device)  # (B, S, T, V)
    print(f"Batch stats - std: {fmri.std():.6f}, mean: {fmri.mean():.6f}")
    target_dim = fmri.shape[-1]
    feat_dims = [f.shape[-1] for f in batch["features"]]
    print(f"Batch shape: fmri={list(fmri.shape)}, feats={[list(f.shape) for f in feats]}")

    # --- Stage 1: cache mu_anchor (frozen conditioning signal for the CFM) ---
    stage1_model = MultiSubjectConvLinearEncoder(
        num_subjects=len(subjects_list),
        feat_dims=feat_dims,
        **cfg.stage1.model,
    ).to(device)
    ckpt_path = Path(cfg.out_dir) / "stage1_best.pt"
    if ckpt_path.exists():
        stage1_model.load_state_dict(torch.load(ckpt_path, map_location=device))
        print(f"Loaded Stage 1 from {ckpt_path}")
    else:
        # A random-init anchor still exercises the CFM; it is just less informative.
        print("No Stage 1 checkpoint found — using random init as mu_anchor")
    stage1_model.eval()
    with torch.no_grad():
        mu_anchor = stage1_model(feats)  # (B, S, T, V)

    # Slice out the single subject and move voxels before time: (B, V, T).
    i = args.subject_idx
    x1 = fmri[:, i].transpose(1, 2)       # (B, V, T)
    mu = mu_anchor[:, i].transpose(1, 2)  # (B, V, T)

    # --- Stage 2: single-subject CFM ---
    cfm = CFM(
        feat_dim=target_dim,
        cfm_params=cfg.stage2.cfm,
        velocity_net_params=cfg.stage2.velocity_net,
        source_ve_params=cfg.stage2.source_ve,
        transport_params=cfg.stage2.transport,
    ).to(device)
    optimizer = torch.optim.AdamW(cfm.parameters(), lr=args.lr)

    # Hoisted out of the training loop: the metric only needs importing once,
    # not on every logging step.
    from metric import pearsonr_score

    # --- Overfit loop ---
    print(f"\nOverfitting single batch for {args.steps} steps...")
    for step in range(1, args.steps + 1):
        cfm.train()  # restore train mode after any eval() in the logging branch
        loss, _ = cfm.compute_loss(x1, mu)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step % 50 == 0 or step == 1:
            # Quick inference check on the same batch we are overfitting.
            cfm.eval()
            n_ts = cfg.stage2.get("n_timesteps", 25)
            with torch.inference_mode():
                pred = cfm(mu, n_timesteps=n_ts)  # (B, V, T)
            # Pearson r computed over flattened (B*T, V) rows of this batch.
            y_true = x1.transpose(1, 2).reshape(-1, target_dim).cpu().numpy()
            y_pred = pred.transpose(1, 2).reshape(-1, target_dim).cpu().numpy()
            r = np.mean(pearsonr_score(y_true, y_pred))
            print(f" Step {step:4d} loss={loss.item():.6f} pearson_r={r:.4f}")

    print("\nDone. If loss did not approach 0 and r did not approach 1, "
          "the architecture cannot learn the task.")


if __name__ == "__main__":
    main()