#!/usr/bin/env python3
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------
"""URSA → URSA one-step distillation via Di[M]O-style on-policy training.

Verified native inference regime (from A/B testing — ground truth):
    height=320, width=512, num_frames=49, guidance_scale=7, teacher_steps=50.
    no_cfg (guidance_scale=1) does NOT produce valid output for this URSA checkpoint.
    All defaults below align to this verified regime, except that teacher CFG
    is OFF by default and must be switched on explicitly with
    --enable_teacher_cfg (its scale already defaults to 7).

Algorithm (9 stages per iteration)
------------------------------------
teacher : frozen URSA — provides supervision at pseudo-intermediate x_t.
student : trainable copy — 1-step target.
aux     : trainable copy — approximates teacher at x_t; reduces REINFORCE variance.

Stage 1 : tokenise prompts (cond + uncond when CFG enabled) → txt_ids [B,L]
Stage 2 : sample x_init [B,T,H,W] ~ Uniform(K) (+ optional p_init mixing)
Stage 3 : student 1-step forward on x_init (cond only, no grad) → x_hat;
          logp and entropy H are recomputed with grad in Stage 9
Stage 4 : pseudo-intermediate x_t = scheduler.add_noise(x_hat, t)
Stage 5 : teacher forward on x_t (dual-branch cond+uncond CFG, s=7,
          when --enable_teacher_cfg is given; cond-only otherwise)
Stage 6 : aux forward → Jeffrey KD
Stage 7 : student forward on x_t → KL KD
Stage 8 : reward = -KL(z_T_cond, z_S_cond) [detached]
          (guided teacher logits instead with --reward_use_guided)
Stage 9 : two-backward student update (KD first, then policy-gradient + entropy)

Usage:
    # Smoke test (verified native regime):
    python scripts/train_onestep_ursa_dimo.py \\
        --teacher_ckpt /path/to/URSA --prompt_file prompts.txt \\
        --enable_teacher_cfg --teacher_cfg_scale 7.0 \\
        --num_frames 49 --height 320 --width 512 --dry_run

    # Full training:
    python scripts/train_onestep_ursa_dimo.py \\
        --teacher_ckpt /path/to/URSA --prompt_file prompts.txt \\
        --enable_teacher_cfg --teacher_cfg_scale 7.0 \\
        --num_frames 49 --height 320 --width 512 \\
        --batch_size 1 --num_steps 10000 --out_dir ./outputs/dimo_cfg
"""

import argparse
import copy
import json
import math
import os
import sys

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if _REPO_ROOT not in sys.path:
    sys.path.insert(0, _REPO_ROOT)

from diffnext.pipelines import URSAPipeline
from src.distill.prompt_dataset import InfiniteDataLoader, PromptDataset, make_collate_fn, CSVSpec
from src.distill.utils_ursa_inputs import (
    build_ursa_inputs,
    compute_latents_shape,
    corrupt_tokens,
    extract_visual_logits,
    sample_t_curriculum,
)


def _get_logits(out):
    """Return the logits tensor from a model forward output.

    Accepts a tuple/list (first element is taken), an object exposing
    ``.sample`` or ``.logits`` (checked in that order), or a bare tensor
    (returned unchanged).
    """
    if isinstance(out, (tuple, list)):
        return out[0]
    if hasattr(out, "sample"):
        return out.sample
    if hasattr(out, "logits"):
        return out.logits
    return out


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def parse_args():
    """Parse and validate command-line arguments for the distillation script.

    Returns:
        argparse.Namespace with all options.

    Exits (via ``ArgumentParser.error``) when a temperature is non-positive,
    because every temperature is used as a divisor of the logits.
    """
    p = argparse.ArgumentParser(description="URSA DiMO one-step distillation")

    # Model / data
    p.add_argument("--teacher_ckpt", required=True)
    p.add_argument("--prompt_file", required=True)
    p.add_argument("--out_dir", default="./outputs/dimo")

    # Video geometry (verified native: 320×512×49)
    p.add_argument("--num_frames", type=int, default=49)
    p.add_argument("--height", type=int, default=320)
    p.add_argument("--width", type=int, default=512)
    p.add_argument("--max_prompt_length", type=int, default=320)

    # Training
    p.add_argument("--batch_size", type=int, default=1)
    p.add_argument("--num_steps", type=int, default=10_000)
    p.add_argument("--lr_student", type=float, default=1e-5)
    p.add_argument("--lr_aux", type=float, default=1e-5)
    p.add_argument("--weight_decay", type=float, default=0.01)
    p.add_argument("--grad_clip", type=float, default=1.0)
    p.add_argument("--mixed_precision", default="bf16", choices=["fp16", "bf16", "fp32"])
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--log_every", type=int, default=50)
    p.add_argument("--save_every", type=int, default=1000)

    # Loss weights
    p.add_argument("--lambda_pg", type=float, default=1.0)
    p.add_argument("--lambda_kd", type=float, default=0.5)
    p.add_argument("--lambda_ent", type=float, default=0.01)
    p.add_argument("--tau", type=float, default=1.0, help="Student sampling temperature")
    p.add_argument("--tau_kd", type=float, default=1.0, help="KD softmax temperature")

    # ---- Teacher CFG (DiMO true_cfg style) ----------------------------
    p.add_argument("--enable_teacher_cfg", action="store_true", default=False,
                   help="Enable teacher-side CFG for KD target. "
                        "False → prior single-branch behavior (fallback).")
    p.add_argument("--teacher_cfg_scale", type=float, default=7.0,
                   help="CFG scale s (verified working value=7)")
    p.add_argument("--teacher_cfg_prob", type=float, default=1.0,
                   help="Max prob of using guided target per sample (after warmup)")
    p.add_argument("--teacher_cfg_warmup_steps", type=int, default=2000,
                   help="Steps to ramp teacher_cfg_prob 0 → teacher_cfg_prob")
    p.add_argument("--teacher_cfg_trunc", type=float, default=0.9,
                   help="t threshold: when t >= trunc, s=1. Set >=1.0 to disable.")
    p.add_argument("--lambda_kd_uncond", type=float, default=0.3,
                   help="Weight for uncond-branch KD / aux loss")
    p.add_argument("--reward_use_guided", action="store_true", default=False,
                   help="[RISKY] Use guided teacher logits for REINFORCE reward.")

    # ---- Eval CFG (inference-time) -----------------------------------
    p.add_argument("--eval_cfg_scale", type=float, default=7.0)
    # A store_true flag whose default is already True can never be switched
    # off; the paired --no_use_cfg_eval flag writes False to the same dest.
    p.add_argument("--use_cfg_eval", dest="use_cfg_eval", action="store_true", default=True,
                   help="Use CFG at evaluation time (default: on).")
    p.add_argument("--no_use_cfg_eval", dest="use_cfg_eval", action="store_false",
                   help="Disable CFG at evaluation time.")

    # DiMO extensions
    p.add_argument("--use_surrogate_grad", action="store_true",
                   help="DiMO surrogate MSE trick applied to Stage-3 logits")
    p.add_argument("--lambda_surr", type=float, default=1.0)
    p.add_argument("--fake_rounds", type=int, default=1,
                   help="Aux updates per generator update (DiMO=2)")

    # Stability
    p.add_argument("--t_curriculum_steps", type=int, default=10_000)
    p.add_argument("--p_mix_corrupt_frac", type=float, default=0.2)
    p.add_argument("--p_init_mix_ratio", type=float, default=0.2)
    p.add_argument("--collapse_warn_frac", type=float, default=0.2)

    # Debug
    p.add_argument("--dry_run", action="store_true",
                   help="Run 1 step + grad-flow check, then exit")
    p.add_argument("--debug_dump", type=int, default=0,
                   help="Dump token histogram + x_hat every N steps (0=off)")
    p.add_argument("--device", type=int, default=0)

    args = p.parse_args()
    # Temperatures divide logits (z / tau); zero or negative values would
    # produce inf/NaN losses deep inside training instead of a clear error.
    if args.tau <= 0:
        p.error(f"--tau must be > 0 (got {args.tau})")
    if args.tau_kd <= 0:
        p.error(f"--tau_kd must be > 0 (got {args.tau_kd})")
    return args


# ---------------------------------------------------------------------------
# Checkpoint
--------------------------------------------------------------------------- def save_checkpoint(model, path: str, name: str = "student"): os.makedirs(path, exist_ok=True) ckpt_path = os.path.join(path, f"{name}.pt") torch.save(model.state_dict(), ckpt_path) print(f"[save] {ckpt_path}") # --------------------------------------------------------------------------- # Stable KL / Jeffrey divergence helpers (float32 + log_softmax) # --------------------------------------------------------------------------- def _stable_kl(z_p: torch.Tensor, z_q: torch.Tensor, tau: float = 1.0) -> torch.Tensor: """KL(p||q) from raw logits, float32 + log_softmax. → [B] (mean over N tokens). p = softmax(z_p/tau), q = softmax(z_q/tau) KL(p||q) = sum_k p_k * (log p_k - log q_k) Both log_p and log_q are computed via log_softmax to avoid log(softmax(...)) numerical issues. """ lp = F.log_softmax(z_p.float() / tau, dim=-1) # [B, N, K] lq = F.log_softmax(z_q.float() / tau, dim=-1) # [B, N, K] return (lp.exp() * (lp - lq)).sum(-1).mean(-1) # [B] def _stable_jeffrey(z_p: torch.Tensor, z_q: torch.Tensor, tau: float = 1.0) -> torch.Tensor: """Symmetric KL (Jeffrey) from logits, float32 + log_softmax. → [B].""" return _stable_kl(z_p, z_q, tau) + _stable_kl(z_q, z_p, tau) # --------------------------------------------------------------------------- # Batch-concat input builder (ONE forward for cond + uncond) # --------------------------------------------------------------------------- def _build_dual_inputs(teacher_ref, txt_cond, txt_uncond, x_t, latents_shape, device): """Concatenate cond+uncond into a single [2B] forward-pass input. Returns (ids_dual [2B, L+N+1], rpos_dual [2B, L+N+1, 3], N). After the forward: chunk(2, dim=0) → (z_cond [B], z_uncond [B]). All three models (teacher/aux/student) share the SAME ids_dual / rpos_dual so the tokens are constructed only once per step. 
""" txt_dual = torch.cat([txt_cond, txt_uncond], dim=0) # [2B, L] x_t_dual = torch.cat([x_t, x_t], dim=0) # [2B, T, H, W] return build_ursa_inputs(teacher_ref, txt_dual, x_t_dual, latents_shape, device) # --------------------------------------------------------------------------- # flex_attn probe / reset helpers # --------------------------------------------------------------------------- def _probe_flex_attn(model, label: str = "") -> object: """Return the FlexAttentionCausal2D object if present, else None.""" return getattr(model, "flex_attn", None) def _print_flex_attn_state(model, label: str): fa = _probe_flex_attn(model, label) if fa is None: print(f" [flex_attn/{label}] not present on model") return print( f" [flex_attn/{label}] offsets={fa.offsets!r} " f"block_mask={'set' if fa.block_mask is not None else 'None'} " f"cu_offsets={'set' if fa.cu_offsets is not None else 'None'}" ) def _reset_flex_attn(model, label: str = "", verbose: bool = False): """Reset flex_attn to None offsets so standard causal attention is used. Our distillation training processes each sample independently (batch dim) so block-packed attention (offsets != None) is not needed and must be cleared to avoid cross-sample mask contamination. 
""" fa = _probe_flex_attn(model, label) if fa is None: return old_offsets = fa.offsets fa.offsets = None fa.block_mask = None fa.cu_offsets = None if verbose: print(f" [flex_attn/{label}] reset: was={old_offsets!r} → None (standard causal)") # --------------------------------------------------------------------------- # Teacher CFG target construction # --------------------------------------------------------------------------- def _compute_cfg_scale(t: torch.Tensor, cfg_scale: float, trunc: float) -> torch.Tensor: """Per-sample CFG scale [B]: s=cfg_scale when t < trunc, else s=1.""" s = torch.full_like(t, cfg_scale) if trunc < 1.0: s = torch.where(t >= trunc, torch.ones_like(t), s) return s def _cfg_warmup_prob(step: int, cfg_prob: float, warmup_steps: int) -> float: """Linear warmup: 0 → cfg_prob over warmup_steps steps.""" if warmup_steps <= 0: return cfg_prob return cfg_prob * min(1.0, step / warmup_steps) def _build_guided_logits( z_T_cond: torch.Tensor, # [B, N, K] float32 z_T_uncond: torch.Tensor, # [B, N, K] float32 t: torch.Tensor, # [B] ∈ (0,1) cfg_scale: float, trunc: float, ) -> torch.Tensor: """z_guided = z_uncond + s*(z_cond - z_uncond), per-sample s [B,1,1].""" s = _compute_cfg_scale(t, cfg_scale, trunc).view(-1, 1, 1) # [B,1,1] return z_T_uncond + s * (z_T_cond - z_T_uncond) # [B, N, K] def _select_target( z_guided: torch.Tensor, # [B, N, K] z_cond: torch.Tensor, # [B, N, K] use_guided: torch.Tensor, # [B] bool — per-sample selection ) -> torch.Tensor: """Per-sample: z_guided where use_guided[b]=True, else z_cond.""" mask = use_guided.view(-1, 1, 1).expand_as(z_cond) return torch.where(mask, z_guided, z_cond) # --------------------------------------------------------------------------- # Gradient-flow debug # --------------------------------------------------------------------------- def debug_grad_flow( teacher, student, aux, txt_cond, txt_uncond, x_t, latents_shape, device, K, N, tau, tau_kd, enable_teacher_cfg, ): """One fwd+bwd without 
optimizer.step(). Asserts: - teacher: zero grads (frozen) - aux: non-zero grads after loss_aux.backward() - student: non-zero grads after loss_student.backward() All cond/uncond forwards are batch-concatenated per requirement (1). """ print("\n" + "=" * 64) print("[grad_flow] Starting gradient flow debug …") B = txt_cond.size(0) # -- Stage 3: student on x_init (cond only) ---------------------- x_init_dbg = torch.randint(0, K, x_t.shape, device=device, dtype=torch.long) ids_init, rpos_init, _ = build_ursa_inputs(teacher, txt_cond, x_init_dbg, latents_shape, device) logits_s = student(ids_init, rope_pos=rpos_init).sample z_s = extract_visual_logits(logits_s.float(), N, K) p_s = F.softmax(z_s / tau, dim=-1) x_hat = torch.multinomial(p_s.view(-1, K), 1).view(B, N) logp = p_s.clamp(1e-8).log().gather(-1, x_hat.unsqueeze(-1)).squeeze(-1).sum(-1) H_mean = -(p_s * p_s.clamp(1e-8).log()).sum(-1).mean() # -- Stage 5: teacher forward — [2B] if CFG, else [B] ------------ if enable_teacher_cfg and txt_uncond is not None: ids_dual, rpos_dual, _ = _build_dual_inputs(teacher, txt_cond, txt_uncond, x_t, latents_shape, device) with torch.no_grad(): logits_T_dual = teacher(ids_dual, rope_pos=rpos_dual).sample.float() z_T_dual = extract_visual_logits(logits_T_dual, N, K) z_T_cond_dbg, z_T_uncond_dbg = z_T_dual.chunk(2, dim=0) t_dbg = torch.full((B,), 0.5, device=device, dtype=torch.float32) z_T_guided_dbg = _build_guided_logits( z_T_cond_dbg.float(), z_T_uncond_dbg.float(), t_dbg, 3.0, 0.9) z_T_target_dbg = z_T_guided_dbg.detach() print(f" [grad_flow] z_T_cond shape={z_T_cond_dbg.shape} " f"min={z_T_cond_dbg.min():.3f} max={z_T_cond_dbg.max():.3f}") print(f" [grad_flow] z_T_uncond shape={z_T_uncond_dbg.shape} " f"min={z_T_uncond_dbg.min():.3f} max={z_T_uncond_dbg.max():.3f}") print(f" [grad_flow] z_T_guided shape={z_T_guided_dbg.shape} " f"min={z_T_guided_dbg.min():.3f} max={z_T_guided_dbg.max():.3f}") ids_t_ref = ids_dual[:B] rpos_t_ref = rpos_dual[:B] ids_fwd = ids_dual rpos_fwd = 
rpos_dual else: ids_t_ref, rpos_t_ref, _ = build_ursa_inputs(teacher, txt_cond, x_t, latents_shape, device) with torch.no_grad(): logits_T = teacher(ids_t_ref, rope_pos=rpos_t_ref).sample.float() z_T_target_dbg = extract_visual_logits(logits_T, N, K).detach() ids_fwd = ids_t_ref rpos_fwd = rpos_t_ref # Dual-path shape check (teacher vs student, same input) with torch.no_grad(): z_T_ref2 = extract_visual_logits( teacher(ids_t_ref, rope_pos=rpos_t_ref).sample.float(), N, K) z_S_ref2 = extract_visual_logits( student(ids_t_ref.detach(), rope_pos=rpos_t_ref.detach()).sample.float(), N, K) if z_T_ref2.shape != z_S_ref2.shape: raise RuntimeError( f"[FATAL] Dual-path shape mismatch: z_T={z_T_ref2.shape} z_S={z_S_ref2.shape}" ) print(f" [grad_flow] Dual-path check OK: shape={z_T_ref2.shape}") # -- Aux backward — [2B] if CFG, else [B] ------------------------- logits_A = aux(ids_fwd.detach(), rope_pos=rpos_fwd.detach()).sample if enable_teacher_cfg and txt_uncond is not None: z_A_dual2 = extract_visual_logits(logits_A.float(), N, K) z_A_cond_dbg, _ = z_A_dual2.chunk(2, dim=0) else: z_A_cond_dbg = extract_visual_logits(logits_A.float(), N, K) loss_aux_sample = _stable_jeffrey(z_T_target_dbg, z_A_cond_dbg, tau_kd) loss_aux = loss_aux_sample.mean() loss_aux.backward() teacher_grads = [p.grad for p in teacher.parameters() if p.grad is not None] aux_grads = [p.grad.norm().item() for p in aux.parameters() if p.grad is not None] print(f" [grad_flow] teacher grads with non-None grad: {len(teacher_grads)} (must be 0)") if aux_grads: print(f" [grad_flow] aux grad norm min={min(aux_grads):.3e} " f"mean={sum(aux_grads)/len(aux_grads):.3e} max={max(aux_grads):.3e}") else: print(" [grad_flow] ⚠️ aux has NO grads") for param in aux.parameters(): param.grad = None # -- Student backward — [B] (cond only for simplicity) ------------ logits_S = student(ids_t_ref.detach(), rope_pos=rpos_t_ref.detach()).sample z_S_cond = extract_visual_logits(logits_S.float(), N, K) loss_kd = 
_stable_kl(z_T_target_dbg, z_S_cond, tau_kd).mean() adv = (loss_aux_sample.detach() * 0 + 1.0) # dummy advantage (shape check) assert not adv.requires_grad, "[BUG] adv must be detached" loss_student = -(adv * logp).mean() + loss_kd - 0.01 * H_mean loss_student.backward() student_grads = [p.grad.norm().item() for p in student.parameters() if p.grad is not None] if student_grads: print(f" [grad_flow] student grad norm min={min(student_grads):.3e} " f"mean={sum(student_grads)/len(student_grads):.3e} " f"max={max(student_grads):.3e}") else: print(" [grad_flow] ⚠️ student has NO grads — diagnosing:") print(f" logp.requires_grad={logp.requires_grad}") print(f" z_s.requires_grad={z_s.requires_grad}") assert len(teacher_grads) == 0, "teacher has grads — not frozen" assert len(aux_grads) > 0, "aux has no grads after loss_aux.backward()" assert len(student_grads) > 0, "student has no grads — grad flow broken" for m in (student, aux): for param in m.parameters(): param.grad = None print(" [grad_flow] All gradient assertions PASSED ✓") print("=" * 64 + "\n") # --------------------------------------------------------------------------- # Main training loop # --------------------------------------------------------------------------- def main(): args = parse_args() device = torch.device("cuda", args.device) if torch.cuda.is_available() else torch.device("cpu") dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32} compute_dtype = dtype_map[args.mixed_precision] torch.manual_seed(args.seed) os.makedirs(args.out_dir, exist_ok=True) # -- Verified regime validation ---------------------------------------- _NATIVE = dict(height=320, width=512, num_frames=49, guidance_scale=7.0) is_native = ( args.height == _NATIVE["height"] and args.width == _NATIVE["width"] and args.num_frames == _NATIVE["num_frames"] ) print(f"[init] verified_native_regime={is_native} " f"geometry=({args.num_frames}×{args.height}×{args.width}) " 
f"teacher_cfg_scale={args.teacher_cfg_scale if args.enable_teacher_cfg else 'OFF'}") if not is_native: print(f"[WARN] Current geometry ({args.num_frames}×{args.height}×{args.width}) " f"is not the verified native URSA regime " f"({_NATIVE['num_frames']}×{_NATIVE['height']}×{_NATIVE['width']}). " "Distillation quality may degrade or become invalid.") if not args.enable_teacher_cfg: print("[WARN] Teacher CFG is DISABLED. no_cfg is known to produce " "blank/blurry output for this URSA checkpoint. " "Distillation without CFG is unlikely to produce useful results.") elif args.teacher_cfg_scale != _NATIVE["guidance_scale"]: print(f"[WARN] teacher_cfg_scale={args.teacher_cfg_scale} differs from " f"the verified working value ({_NATIVE['guidance_scale']}).") if args.enable_teacher_cfg and args.reward_use_guided: print("[WARN] --reward_use_guided is ON — can cause mode collapse, watch tok_entropy.") # -- Load pipeline --------------------------------------------------- print(f"[init] Loading from {args.teacher_ckpt} …") pipe = URSAPipeline.from_pretrained( args.teacher_ckpt, torch_dtype=compute_dtype, trust_remote_code=True ).to(device) tokenizer = pipe.tokenizer scheduler = pipe.scheduler scheduler.to(device=device) vae_t_stride = getattr(pipe.vae.config, "temporal_stride", 4) vae_s_stride = getattr(pipe.vae.config, "spatial_stride", 8) latents_shape = compute_latents_shape( args.num_frames, args.height, args.width, vae_t_stride, vae_s_stride ) T, H, W = latents_shape N = T * H * W K = scheduler.codebook_size print( f"[init] latents_shape=({T},{H},{W}) N={N} K={K} " f"CFG={'ON' if args.enable_teacher_cfg else 'OFF'}" ) # -- Pre-compute uncond token IDs (empty string, [1, L]) -------------- txt_uncond_base = tokenizer( [""], max_length=args.max_prompt_length, padding="max_length", padding_side="left", truncation=True, return_tensors="pt", ).input_ids.to(device) # [1, L] # -- Three models ---------------------------------------------------- teacher = 
pipe.transformer.eval().requires_grad_(False) student = copy.deepcopy(teacher).train().requires_grad_(True) aux = copy.deepcopy(teacher).train().requires_grad_(True) # -- flex_attn: reset offsets to None (standard causal attn) --------- # Our training processes B independent sequences in a batch, so block-packed # offsets are not needed and must be cleared before any forward call. if args.dry_run: print("[init] flex_attn state before reset:") for m, lbl in ((teacher, "teacher"), (student, "student"), (aux, "aux")): _print_flex_attn_state(m, lbl) for m, lbl in ((teacher, "teacher"), (student, "student"), (aux, "aux")): _reset_flex_attn(m, lbl, verbose=True) if args.dry_run: print("[init] flex_attn state after reset:") for m, lbl in ((teacher, "teacher"), (student, "student"), (aux, "aux")): _print_flex_attn_state(m, lbl) opt_student = torch.optim.AdamW( student.parameters(), lr=args.lr_student, weight_decay=args.weight_decay ) opt_aux = torch.optim.AdamW( aux.parameters(), lr=args.lr_aux, weight_decay=args.weight_decay ) # -- Dataset ---------------------------------------------------------- # dataset = PromptDataset(args.prompt_file, shuffle=True, seed=args.seed) collate = make_collate_fn(tokenizer, args.max_prompt_length, device) # loader = DataLoader( # dataset, batch_size=args.batch_size, shuffle=True, # drop_last=True, num_workers=0, collate_fn=collate, # ) dataset = PromptDataset( args.prompt_file, shuffle_files=True, shuffle_buffer=50000, # 例如 50k buffer,够用且不占太多内存 seed=args.seed, infinite=True, csv=CSVSpec(caption_field="caption"), # Koala 默认就是 caption ) loader = DataLoader( dataset, batch_size=args.batch_size, shuffle=False, # IMPORTANT for IterableDataset drop_last=True, num_workers=2, # 视 IO 调大 collate_fn=collate, pin_memory=True, ) inf_loader = InfiniteDataLoader(loader) # -- Pre-training sanity check --------------------------------------- _sanity_check_forward(teacher, scheduler, latents_shape, device, K, args.dry_run) # -- Training state 
-------------------------------------------------- baseline_ema: float = 0.0 x_hat_prev = None initial_tok_entropy: float = None dump_dir = os.path.join(args.out_dir, "debug_dumps") if args.debug_dump > 0 else None num_steps = 1 if args.dry_run else args.num_steps print(f"[train] {'DRY RUN' if args.dry_run else f'{num_steps} steps'} " f"| CFG={args.enable_teacher_cfg}") for step in range(1, num_steps + 1): # ---------------------------------------------------------------- # Stage 1: Tokenise → txt_cond [B, L], txt_uncond [B, L] # ---------------------------------------------------------------- txt_cond = next(inf_loader) # [B, L] txt_cond = txt_cond.to(device, non_blocking=True) B = txt_cond.size(0) txt_uncond = None if args.enable_teacher_cfg: txt_uncond = txt_uncond_base.expand(B, -1) # [B, L] # ---------------------------------------------------------------- # Stage 2: x_init ~ Uniform(K) (+ optional p_init mixing) # ---------------------------------------------------------------- x_init = _sample_x_init(B, T, H, W, K, device, x_hat_prev, args) # ---------------------------------------------------------------- # Stage 3: Student 1-step forward on x_init — COND only. # # Gradient needed: logp and H flow back through p_s → student. 
# ---------------------------------------------------------------- with torch.no_grad(): ids_init, rpos_init, _ = build_ursa_inputs( teacher, txt_cond, x_init, latents_shape, device) logits_s_init = student(ids_init, rope_pos=rpos_init).sample # [B, L+N+1, D] z_s = extract_visual_logits(logits_s_init.float(), N, K) # [B, N, K] p_s = F.softmax(z_s / args.tau, dim=-1) # [B, N, K] x_hat = torch.multinomial(p_s.view(-1, K), 1).view(B, N) # [B, N] # logp = p_s.clamp(1e-8).log().gather( # -1, x_hat.unsqueeze(-1)).squeeze(-1).sum(-1) # [B] # H_mean = -(p_s * p_s.clamp(1e-8).log()).sum(-1).mean() x_hat_4d = x_hat.view(B, T, H, W) # ---------------------------------------------------------------- # Stage 4: Pseudo-intermediate x_t # ---------------------------------------------------------------- t = sample_t_curriculum(B, device, step, warmup_steps=args.t_curriculum_steps) with torch.no_grad(): x_t = scheduler.add_noise(x_hat_4d, t) # [B, T, H, W], long # ---------------------------------------------------------------- # Stage 5: Teacher forward — single [2B] forward when CFG enabled. # # ids_dual / rpos_dual are SHARED by teacher, aux, and student to # avoid redundant input construction. 
# ---------------------------------------------------------------- with torch.no_grad(): if args.enable_teacher_cfg: # ONE [2B] forward = cond (first B) + uncond (last B) ids_dual, rpos_dual, _ = _build_dual_inputs( teacher, txt_cond, txt_uncond, x_t, latents_shape, device) logits_T_dual = teacher(ids_dual, rope_pos=rpos_dual).sample.float() z_T_dual = extract_visual_logits(logits_T_dual, N, K) # [2B, N, K] z_T_cond, z_T_uncond = z_T_dual.chunk(2, dim=0) # [B, N, K] each ids_t = ids_dual[:B] # cond half — alias (no copy) rpos_t = rpos_dual[:B] else: ids_t, rpos_t, _ = build_ursa_inputs( teacher, txt_cond, x_t, latents_shape, device) logits_T = teacher(ids_t, rope_pos=rpos_t).sample.float() z_T_cond = extract_visual_logits(logits_T, N, K) # [B, N, K] z_T_uncond = None ids_dual = ids_t rpos_dual = rpos_t # -- CFG guided target (float32, per-sample Bernoulli) ---------- z_T_guided = None if args.enable_teacher_cfg: z_T_cond_f = z_T_cond.float() z_T_uncond_f = z_T_uncond.float() z_T_guided = _build_guided_logits( z_T_cond_f, z_T_uncond_f, t, args.teacher_cfg_scale, args.teacher_cfg_trunc) # per-sample Bernoulli: use_guided[b] ~ Bernoulli(p_guided) p_guided = _cfg_warmup_prob( step, args.teacher_cfg_prob, args.teacher_cfg_warmup_steps) use_guided = torch.rand(B, device=device) < p_guided # [B] bool use_guided_ratio = use_guided.float().mean().item() z_T_target = _select_target(z_T_guided, z_T_cond_f, use_guided) # [B, N, K] else: use_guided = torch.zeros(B, dtype=torch.bool, device=device) use_guided_ratio = 0.0 z_T_target = z_T_cond.float() # z_T_target is the KD target — must have no grad path to teacher z_T_target = z_T_target.detach() # ---------------------------------------------------------------- # Stage 6: Aux forward (fake_rounds) — single [2B] forward when CFG. 
# ---------------------------------------------------------------- loss_aux_cond_v_last = None loss_aux_uncond_v_last = None loss_aux_cond_sample_last = None for _fr in range(args.fake_rounds): opt_aux.zero_grad() if args.enable_teacher_cfg: # ONE [2B] forward: cond+uncond in one shot logits_A_dual = aux(ids_dual.detach(), rope_pos=rpos_dual.detach()).sample z_A_dual = extract_visual_logits(logits_A_dual.float(), N, K) # [2B, N, K] z_A_cond, z_A_uncond = z_A_dual.chunk(2, dim=0) # Cond: Jeffrey(z_T_target, z_A_cond) loss_aux_cond_sample = _stable_jeffrey(z_T_target, z_A_cond, args.tau_kd) # [B] loss_aux_cond_v = loss_aux_cond_sample.mean() # Uncond: Jeffrey(z_T_uncond, z_A_uncond) z_T_uncond_det = z_T_uncond.float().detach() loss_aux_uncond_sample = _stable_jeffrey(z_T_uncond_det, z_A_uncond, args.tau_kd) loss_aux_uncond_v = loss_aux_uncond_sample.mean() loss_aux_v = loss_aux_cond_v + args.lambda_kd_uncond * loss_aux_uncond_v else: logits_A = aux(ids_t.detach(), rope_pos=rpos_t.detach()).sample z_A_cond = extract_visual_logits(logits_A.float(), N, K) loss_aux_cond_sample = _stable_jeffrey(z_T_target, z_A_cond, args.tau_kd) # [B] loss_aux_cond_v = loss_aux_cond_sample.mean() loss_aux_uncond_v = torch.tensor(0.0, device=device) loss_aux_v = loss_aux_cond_v loss_aux_v.backward() if args.grad_clip > 0: torch.nn.utils.clip_grad_norm_(aux.parameters(), args.grad_clip) opt_aux.step() # make sure aux grads are cleared and no graph is retained for p in aux.parameters(): p.grad = None loss_aux_cond_v_last = loss_aux_cond_v.detach() loss_aux_uncond_v_last = loss_aux_uncond_v.detach() loss_aux_cond_sample_last = loss_aux_cond_sample.detach() # [B] # # ---------------------------------------------------------------- # # Stage 7: Student KD forward on x_t — single [2B] when CFG. # # Dual-path consistency check included. 
# # ---------------------------------------------------------------- # if args.enable_teacher_cfg: # # ONE [2B] forward # logits_S_dual = student(ids_dual.detach(), rope_pos=rpos_dual.detach()).sample # z_S_dual = extract_visual_logits(logits_S_dual.float(), N, K) # [2B, N, K] # z_S_cond, z_S_uncond = z_S_dual.chunk(2, dim=0) # else: # logits_S = student(ids_t.detach(), rope_pos=rpos_t.detach()).sample # z_S_cond = extract_visual_logits(logits_S.float(), N, K) # [B, N, K] # z_S_uncond = None # # Dual-path shape consistency check # if z_T_cond.shape != z_S_cond.shape: # raise RuntimeError( # f"[FATAL] Dual-path shape mismatch: " # f"z_T_cond={z_T_cond.shape} z_S_cond={z_S_cond.shape} — " # "vocab slicing inconsistency." # ) # # KD losses (from raw logits, float32 + log_softmax) # loss_kd_cond = _stable_kl(z_T_target, z_S_cond, args.tau_kd).mean() # loss_kd_uncond_v = torch.tensor(0.0, device=device) # if args.enable_teacher_cfg and z_S_uncond is not None: # z_T_uncond_det2 = z_T_uncond.float().detach() # loss_kd_uncond_v = _stable_kl(z_T_uncond_det2, z_S_uncond, args.tau_kd).mean() # loss_kd = loss_kd_cond + args.lambda_kd_uncond * loss_kd_uncond_v # # ---------------------------------------------------------------- # # Stage 8: REINFORCE reward + advantage # # # # INVARIANT: reward and adv MUST NOT carry student gradients. # # - z_S_cond is detached before entering reward computation. # # - adv is explicitly detached. # # - Runtime assertions enforce this. # # ---------------------------------------------------------------- # if args.enable_teacher_cfg: # if args.reward_use_guided: # z_T_for_rew = z_T_target # already detached (guided, see §5) # else: # z_T_for_rew = z_T_cond.float().detach() # non-guided cond (stable default) # # Both inputs are detached: no student gradient leaks into reward. 
# reward = -_stable_kl( # z_T_for_rew.detach(), z_S_cond.detach(), args.tau) # [B] # else: # reward = -loss_aux_cond_sample_last # [B], already detached # # Mandatory detach assertions: catch reward/adv gradient leaks early. # assert not reward.requires_grad, ( # "[BUG] reward.requires_grad=True — student gradient leaked into reward. " # "Ensure z_S_cond is detached in reward computation." # ) # baseline_ema = 0.99 * baseline_ema + 0.01 * reward.mean().item() # adv = (reward - baseline_ema).detach() # [B] # assert not adv.requires_grad, "[BUG] adv.requires_grad=True — explicit detach failed" # loss_pg = -(adv * logp).mean() # # ---------------------------------------------------------------- # # Stage 9: Student loss + update # # ---------------------------------------------------------------- # opt_student.zero_grad() # lambda_ent_eff = args.lambda_ent * (1.0 + 2.0 * use_guided_ratio) # loss_student = ( # args.lambda_pg * loss_pg # + args.lambda_kd * loss_kd # - lambda_ent_eff * H_mean # ) # # Optional surrogate gradient (DiMO MSE trick — applied to Stage-3 logits z_s) # loss_surr = None # if args.use_surrogate_grad: # with torch.no_grad(): # logits_A_ref = aux(ids_t.detach(), rope_pos=rpos_t.detach()).sample # z_A_ref = extract_visual_logits(logits_A_ref.float(), N, K) # # grad_surr = (p_A - p_T): pushes z_s toward teacher distribution # p_A_ref = F.softmax(z_A_ref.float() / args.tau_kd, dim=-1).detach() # p_T_surr = F.softmax(z_T_target / args.tau_kd, dim=-1).detach() # grad_surr = (p_A_ref - p_T_surr).detach() # loss_surr = 0.5 * F.mse_loss(z_s, (z_s - grad_surr).detach()) # loss_student = loss_student + args.lambda_surr * loss_surr # loss_student.backward() # if args.grad_clip > 0: # torch.nn.utils.clip_grad_norm_(student.parameters(), args.grad_clip) # opt_student.step() # # p_init mixing: save x_hat_4d for next step # x_hat_prev = x_hat_4d.detach().clone() # ---------------------------------------------------------------- # Stage 7: Student KD forward on x_t 
— single [2B] when CFG. # ---------------------------------------------------------------- if args.enable_teacher_cfg: logits_S_dual = _get_logits(student(ids_dual.detach(), rope_pos=rpos_dual.detach())).float() z_S_dual = extract_visual_logits(logits_S_dual, N, K) # [2B, N, K] z_S_cond, z_S_uncond = z_S_dual.chunk(2, dim=0) else: logits_S = _get_logits(student(ids_t.detach(), rope_pos=rpos_t.detach())).float() z_S_cond = extract_visual_logits(logits_S, N, K) z_S_uncond = None if z_T_cond.shape != z_S_cond.shape: raise RuntimeError(f"[FATAL] Dual-path shape mismatch: z_T_cond={z_T_cond.shape} z_S_cond={z_S_cond.shape}") loss_kd_cond = _stable_kl(z_T_target, z_S_cond, args.tau_kd).mean() loss_kd_uncond_v = torch.tensor(0.0, device=device) if args.enable_teacher_cfg and (z_S_uncond is not None): loss_kd_uncond_v = _stable_kl(z_T_uncond.float().detach(), z_S_uncond, args.tau_kd).mean() loss_kd = loss_kd_cond + args.lambda_kd_uncond * loss_kd_uncond_v # ---------------------------------------------------------------- # Stage 8: reward + advantage (detached) # ---------------------------------------------------------------- if args.enable_teacher_cfg and args.reward_use_guided: z_T_for_rew = z_T_target # already detached else: z_T_for_rew = z_T_cond.float().detach() reward = -_stable_kl(z_T_for_rew.detach(), z_S_cond.detach(), args.tau) # [B] assert not reward.requires_grad baseline_ema = 0.99 * baseline_ema + 0.01 * reward.mean().item() adv = (reward - baseline_ema).detach() assert not adv.requires_grad # ---------------------------------------------------------------- # Stage 9: update student in two backward passes (KD then PG/Ent) # ---------------------------------------------------------------- opt_student.zero_grad(set_to_none=True) # (9a) KD backward first (frees KD graph) (args.lambda_kd * loss_kd).backward() # (9b) Policy + entropy: need a fresh forward on x_init WITH grad ids_init, rpos_init, _ = build_ursa_inputs(teacher, txt_cond, x_init, latents_shape, 
device) logits_s_pol = _get_logits(student(ids_init, rope_pos=rpos_init)).float() z_s_pol = extract_visual_logits(logits_s_pol, N, K) logp_tok = F.log_softmax(z_s_pol / args.tau, dim=-1) # [B,N,K] p_s_pol = logp_tok.exp() # fixed action: x_hat sampled in Stage 3 (no_grad) logp_sum = logp_tok.gather(-1, x_hat.unsqueeze(-1)).squeeze(-1).sum(-1) # [B], sum over N tokens logp = logp_sum / N # [B], per-token average logp (RECOMMENDED) H_mean = -(p_s_pol * logp_tok).sum(-1).mean() loss_pg = -(adv * logp).mean() lambda_ent_eff = args.lambda_ent * (1.0 + 2.0 * use_guided_ratio) (loss_pg * args.lambda_pg - H_mean * lambda_ent_eff).backward() # (optional) surrogate grad — put it here; WARNING: extra forward makes it heavier loss_surr = None if args.use_surrogate_grad: with torch.no_grad(): logits_A_ref = _get_logits(aux(ids_t.detach(), rope_pos=rpos_t.detach())).float() z_A_ref = extract_visual_logits(logits_A_ref, N, K) p_A_ref = F.softmax(z_A_ref / args.tau_kd, dim=-1).detach() p_T_ref = F.softmax(z_T_target / args.tau_kd, dim=-1).detach() grad_surr = (p_A_ref - p_T_ref).detach() loss_surr = 0.5 * F.mse_loss(z_s_pol, (z_s_pol - grad_surr).detach()) (args.lambda_surr * loss_surr).backward() if args.grad_clip > 0: torch.nn.utils.clip_grad_norm_(student.parameters(), args.grad_clip) opt_student.step() # p_init mixing: save x_hat_4d for next step x_hat_prev = x_hat_4d.detach() #.clone() # ---------------------------------------------------------------- # Post-step: assertions (step 1), collapse detection, logging # ---------------------------------------------------------------- if step == 1: _run_assertions( x_init, ids_init, rpos_init, z_s, p_s, logp, z_T_cond, z_S_cond, x_t, K, N, B, T, H, W, teacher.config.lm_vocab_size, z_T_uncond=z_T_uncond, z_T_guided=z_T_guided, dry_run=args.dry_run, ) tok_entropy = _token_histogram_entropy(x_hat, K) if initial_tok_entropy is None: initial_tok_entropy = tok_entropy if tok_entropy < args.collapse_warn_frac * initial_tok_entropy: print( 
f"[COLLAPSE WARNING] step={step} tok_entropy={tok_entropy:.3f} " f"initial={initial_tok_entropy:.3f} " f"ratio={tok_entropy/max(initial_tok_entropy, 1e-8):.2f} < " f"{args.collapse_warn_frac}. " "Increase --lambda_ent (try 0.05) or --tau." ) if step % args.log_every == 0 or args.dry_run: surr_str = f" loss_surr={loss_surr.item():.4f}" if loss_surr is not None else "" print( f"[step {step:>6d}] " f"loss_aux_cond={loss_aux_cond_v_last.item():.3e} " f"loss_aux_uncond={loss_aux_uncond_v_last.item():.3e} " f"loss_kd_cond={loss_kd_cond.item():.4f} " f"loss_kd_uncond={loss_kd_uncond_v.item():.4f} " f"loss_pg={loss_pg.item():.4f}" f"{surr_str} " f"H={H_mean.item():.3f} tok_H={tok_entropy:.3f} " f"guided_ratio={use_guided_ratio:.2f} " f"baseline={baseline_ema:.4f} " f"mean_logp_tok={logp.mean().item():.3f}" ) if args.debug_dump > 0 and step % args.debug_dump == 0: _dump_debug(dump_dir, step, x_hat, K) if not args.dry_run and step % args.save_every == 0: ckpt_dir = os.path.join(args.out_dir, f"step_{step:06d}") save_checkpoint(student, ckpt_dir, "student") save_checkpoint(aux, ckpt_dir, "aux") # -- dry_run: full grad-flow check after the single training step ---- if args.dry_run: print("\n[dry_run] Running gradient flow debug …") txt_dbg = next(inf_loader) B_dbg = txt_dbg.size(0) x_t_dbg = torch.randint(0, K, (B_dbg, T, H, W), device=device, dtype=torch.long) txt_u_dbg = (txt_uncond_base.expand(B_dbg, -1) if args.enable_teacher_cfg else None) debug_grad_flow( teacher, student, aux, txt_dbg, txt_u_dbg, x_t_dbg, latents_shape, device, K, N, args.tau, args.tau_kd, args.enable_teacher_cfg, ) _dry_run_patches_789(teacher, latents_shape, K, N, device) print("[dry_run] Done. All checks (1-9) PASSED. 
Exiting.") return # Final save final_dir = os.path.join(args.out_dir, "final") save_checkpoint(student, final_dir, "student") save_checkpoint(aux, final_dir, "aux") print("[done] Training complete.") # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _sample_x_init(B, T, H, W, K, device, x_hat_prev, args): x_init = torch.randint(0, K, (B, T, H, W), device=device, dtype=torch.long) if x_hat_prev is not None and args.p_init_mix_ratio > 0: n_mix = max(1, int(B * args.p_init_mix_ratio)) x_init[:n_mix] = corrupt_tokens(x_hat_prev[:n_mix], r=args.p_mix_corrupt_frac, K=K) return x_init def _token_histogram_entropy(x_hat: torch.Tensor, K: int) -> float: counts = x_hat.flatten().bincount(minlength=K).float() p = counts / counts.sum() p = p[p > 0] return float(-(p * p.log()).sum().item()) def _dump_debug(dump_dir: str, step: int, x_hat: torch.Tensor, K: int): os.makedirs(dump_dir, exist_ok=True) counts = x_hat.flatten().bincount(minlength=K).tolist() with open(os.path.join(dump_dir, f"step_{step:06d}_hist.json"), "w") as fh: json.dump({"step": step, "counts": counts}, fh) torch.save(x_hat.cpu(), os.path.join(dump_dir, f"step_{step:06d}_xhat.pt")) print(f"[debug_dump] step={step} saved to {dump_dir}") def _run_assertions( x_init, ids_init, rpos_init, z_s, p_s, logp, z_T_cond, z_S_cond, x_t, K, N, B, T, H, W, lm_vocab_size, z_T_uncond=None, z_T_guided=None, dry_run=False, ): """Full shape / value-domain / consistency assertions (run at step=1).""" print("[assert] Running shape/value assertions …") L_plus_N1 = ids_init.size(1) txt_len = L_plus_N1 - (N + 1) # x_init assert x_init.dtype == torch.long, f"x_init dtype={x_init.dtype}" assert x_init.min() >= 0 and x_init.max() < K, \ f"x_init out of [0,K): [{x_init.min()}, {x_init.max()}]" # input_ids shape & token value ranges assert ids_init.shape == (B, L_plus_N1), f"ids_init.shape={ids_init.shape}" txt_part = 
ids_init[:, :txt_len] vis_part = ids_init[:, -N:] assert (txt_part < lm_vocab_size).all(), \ f"text tokens bleed into visual range (max={txt_part.max()})" assert (vis_part >= lm_vocab_size).all(), \ f"visual tokens not shifted (min={vis_part.min()}, lm_vocab_size={lm_vocab_size})" assert (vis_part < lm_vocab_size + K).all(), \ f"visual tokens exceed lm_vocab_size+K (max={vis_part.max()})" # rope_pos assert rpos_init.shape == (B, L_plus_N1, 3), \ f"rope_pos shape={rpos_init.shape} expected ({B},{L_plus_N1},3)" # z_s assert z_s.shape == (B, N, K), f"z_s.shape={z_s.shape}" p_err = (p_s.sum(-1) - 1).abs().max().item() assert p_err < 1e-3, f"p_s not normalised: max deviation={p_err:.2e}" # logp assert not torch.isnan(logp).any(), "logp contains NaN" assert not torch.isinf(logp).any(), "logp contains Inf" # x_t assert x_t.min() >= 0 and x_t.max() < K, \ f"x_t out of [0,K) after add_noise: [{x_t.min()}, {x_t.max()}]" # Dual-path shape check assert z_T_cond.shape == z_S_cond.shape, \ f"Dual-path mismatch: z_T_cond={z_T_cond.shape} z_S_cond={z_S_cond.shape}" assert z_T_cond.shape == (B, N, K), f"z_T_cond.shape={z_T_cond.shape}" # z_T logits printout (always in dry_run; also when uncond is available) if dry_run or z_T_uncond is not None: print( f"[assert] z_T_cond shape={z_T_cond.shape} " f"min={z_T_cond.min():.3f} max={z_T_cond.max():.3f} " f"mean={z_T_cond.mean():.3f}" ) if z_T_uncond is not None: assert z_T_uncond.shape == (B, N, K), f"z_T_uncond.shape={z_T_uncond.shape}" print( f"[assert] z_T_uncond shape={z_T_uncond.shape} " f"min={z_T_uncond.min():.3f} max={z_T_uncond.max():.3f} " f"mean={z_T_uncond.mean():.3f}" ) if z_T_guided is not None: assert z_T_guided.shape == (B, N, K), f"z_T_guided.shape={z_T_guided.shape}" g_min = z_T_guided.min().item() g_max = z_T_guided.max().item() g_mean = z_T_guided.mean().item() print( f"[assert] z_T_guided shape={z_T_guided.shape} " f"min={g_min:.3f} max={g_max:.3f} mean={g_mean:.3f}" ) # Explosion guard: guided logits must be finite 
and not excessively large. assert not torch.isnan(z_T_guided).any(), "z_T_guided contains NaN" assert not torch.isinf(z_T_guided).any(), "z_T_guided contains Inf" assert abs(g_min) < 1e4 and abs(g_max) < 1e4, ( f"z_T_guided magnitude too large: min={g_min:.1e} max={g_max:.1e}. " f"Reduce --teacher_cfg_scale (currently may amplify outlier logits)." ) print("[assert] All assertions PASSED ✓") def _sanity_check_forward(teacher, scheduler, latents_shape, device, K, verbose=False): print("[init] Checking logit dimensions …") T, H, W = latents_shape N, B, L = T * H * W, 1, 16 dummy_txt = torch.zeros(B, L, dtype=torch.long, device=device) dummy_vis = torch.zeros(B, T, H, W, dtype=torch.long, device=device) with torch.no_grad(): ids, rpos, _ = build_ursa_inputs(teacher, dummy_txt, dummy_vis, latents_shape, device) logits = teacher(ids, rope_pos=rpos).sample lm_head_size = teacher.config.lm_head_size lm_vocab = teacher.config.lm_vocab_size print( f"[init] logits={logits.shape} K={K} " f"lm_head={lm_head_size} lm_vocab={lm_vocab}" ) assert ids.shape == (B, L + N + 1), f"ids shape {ids.shape}" assert rpos.shape == (B, L + N + 1, 3), f"rpos shape {rpos.shape}" z = extract_visual_logits(logits.float(), N, K) assert z.shape == (B, N, K), f"z shape {z.shape}" assert lm_head_size >= K, f"lm_head_size={lm_head_size} < K={K}" if verbose: print("[init] flex_attn state during sanity check:") _print_flex_attn_state(teacher, "teacher") print("[init] Forward check OK ✓") # --------------------------------------------------------------------------- # Dry-run patches 7 / 8 / 9 # --------------------------------------------------------------------------- def _dry_run_patches_789(teacher, latents_shape, K, N, device): """Three deep self-checks executed only during --dry_run. 
Patch 7 — extract_visual_logits end-to-end alignment: Run a real teacher forward, manually reconstruct z_manual from raw logits using the latent_shift / codebook_size convention, and assert the result matches extract_visual_logits(). Handles the common URSA case where lm_head outputs K logits directly (latent_shift not applied to logit dim). Patch 8 — flex_attn semantics sanity: If the model exposes set_offsets_by_lens, compare visual-logit mean-delta between offsets=None (standard causal) and a single-block offset. A large delta is expected and confirms that our training correctly uses offsets=None. Gracefully skips when flex_attention is unavailable at runtime. Patch 9 — logp / token reshape consistency: With a small (T=3, H=4, W=5) shape, verify x_hat reshape round-trips and spot-check 10 token positions against manually computed log-probability. """ T, H, W = latents_shape L_test, B_test = 16, 1 print("\n" + "=" * 64) print("[patch 7/8/9] Running additional dry_run self-checks …") # ------------------------------------------------------------------------- # Build shared dummy inputs used by both patch 7 and patch 8 # ------------------------------------------------------------------------- dummy_txt = torch.zeros(B_test, L_test, dtype=torch.long, device=device) dummy_vis = torch.zeros(B_test, T, H, W, dtype=torch.long, device=device) with torch.no_grad(): ids_test, rpos_test, _ = build_ursa_inputs( teacher, dummy_txt, dummy_vis, latents_shape, device) logits_full = teacher(ids_test, rope_pos=rpos_test).sample.float() # [1, L+N+1, D] D = logits_full.size(-1) # actual logit last-dim (lm_head_size) latent_shift = teacher.config.lm_vocab_size # text-vocab offset for input token IDs # ========================================================================= # Patch 7 — extract_visual_logits end-to-end alignment # ========================================================================= print("\n[7] extract_visual_logits end-to-end alignment …") z_vis = 
extract_visual_logits(logits_full, N, K) # [1, N, K] assert z_vis.shape == (B_test, N, K), f"z_vis.shape={z_vis.shape}" if D >= latent_shift + K: # Full-vocab head: logit dim covers text (0..latent_shift) + visual tokens. z_seq = logits_full[:, -(N + 1) : -1] # [1, N, D] z_manual = z_seq[..., latent_shift : latent_shift + K] # [1, N, K] delta = (z_vis - z_manual).abs().max().item() print(f" [7] path=full-vocab D={D} latent_shift+K={latent_shift + K}") print(f" [7] z_vis.shape={z_vis.shape} max|z_vis - z_manual|={delta:.2e}") assert delta < 1e-5, ( f"extract_visual_logits mismatch (full-vocab path): delta={delta:.2e}. " "The function should return logits[..., latent_shift:latent_shift+K]." ) print("[7] extract_visual_logits alignment PASSED ✓") else: # Common URSA case: lm_head outputs K logits directly (lm_head_size ≈ K). # latent_shift is the input token-ID offset, NOT a logit-dimension offset. # extract_visual_logits handles this as D==K (happy path) or D>K (offset=D-K). z_seq = logits_full[:, -(N + 1) : -1] # [1, N, D] if D == K: delta = (z_vis - z_seq).abs().max().item() print( f" [7] SKIP latent_shift formula: D={D} == K={K} " f"latent_shift={latent_shift}.\n" f" [7] Explanation: URSA lm_head outputs K visual logits directly.\n" f" [7] latent_shift={latent_shift} is the input token-ID shift " f"(raw_code + lm_vocab_size), NOT a logit-dim offset.\n" f" [7] extract_visual_logits happy-path: z = logits[:, -(N+1):-1] " f"(no vocab-dim slicing).\n" f" [7] Fallback check: z_vis == raw causal slice " f"max_delta={delta:.2e}" ) assert delta < 1e-5, ( f"z_vis != raw causal slice when D==K: delta={delta:.2e}" ) else: # D > K but D < latent_shift + K → extract uses offset = D - K offset = D - K z_manual = z_seq[..., offset:] delta = (z_vis - z_manual).abs().max().item() print( f" [7] SKIP latent_shift formula: D={D} < latent_shift+K={latent_shift + K}.\n" f" [7] extract_visual_logits uses offset={offset} (D-K). 
" f"max_delta={delta:.2e}" ) assert delta < 1e-5, ( f"z_vis != z_seq[..., D-K:]: delta={delta:.2e}" ) print("[7] extract_visual_logits alignment PASSED (fallback path) ✓") # ========================================================================= # Patch 8 — flex_attn semantics sanity # ========================================================================= print("\n[8] flex_attn semantics sanity …") fa = _probe_flex_attn(teacher) if fa is None or not hasattr(fa, "set_offsets_by_lens"): print(" [8] flex_attn.set_offsets_by_lens not available — skip") print("[8] flex_attn semantics sanity PASSED (skipped — no flex_attn) ✓") else: L_total = ids_test.size(1) # L_test + N + 1 txt_block = L_test + (N + 1) # single-block: all tokens in one block block_lens = [txt_block] try: # Forward A: offsets=None — standard causal attention (our training config) _reset_flex_attn(teacher, "teacher") with torch.no_grad(): logits_A = teacher(ids_test, rope_pos=rpos_test).sample.float() z_A = extract_visual_logits(logits_A, N, K) # Forward B: set_offsets_by_lens with a single block. # A single block causes the mask to allow full (bidirectional) attention # within the block, which differs from standard causal attention. fa.set_offsets_by_lens(block_lens) with torch.no_grad(): logits_B = teacher(ids_test, rope_pos=rpos_test).sample.float() z_B = extract_visual_logits(logits_B, N, K) delta_mean = (z_A - z_B).abs().mean().item() delta_max = (z_A - z_B).abs().max().item() print( f" [8] offsets=None vs set_offsets_by_lens({block_lens}):\n" f" [8] mean_abs_delta={delta_mean:.4e} max_abs_delta={delta_max:.4e}" ) if delta_mean > 1e-3: print( f" [8] WARNING: mean_delta={delta_mean:.2e} > 1e-3.\n" " [8] Single-block flex_attn uses FULL (bidirectional) attention\n" " [8] inside the block, whereas offsets=None gives standard CAUSAL\n" " [8] attention. This difference is EXPECTED — it confirms our\n" " [8] training correctly uses offsets=None (no packed sequences)." 
) else: print(f" [8] delta ≤ 1e-3: attention semantics equivalent for this input.") print("[8] flex_attn semantics sanity PASSED ✓") except (NotImplementedError, RuntimeError, Exception) as exc: print(f" [8] flex_attn runtime not available ({type(exc).__name__}: {exc}) — skip") print("[8] flex_attn semantics sanity PASSED (runtime skip) ✓") finally: _reset_flex_attn(teacher, "teacher") # always restore clean state # ========================================================================= # Patch 9 — logp / token reshape consistency # ========================================================================= print("\n[9] logp/token reshape consistency …") T9, H9, W9 = 3, 4, 5 N9, B9 = T9 * H9 * W9, 1 # 60 tokens, batch=1 torch.manual_seed(99) z9 = torch.randn(B9, N9, K) p9 = F.softmax(z9 / 1.0, dim=-1) # [1, 60, K]; each row sums to 1 # ----- token sampling --------------------------------------------------- x_hat_flat = torch.multinomial(p9.view(-1, K), 1) # [N9, 1] (1 sample per row) x_hat_1d = x_hat_flat.view(B9, N9) # [1, 60] x_hat_4d = x_hat_1d.view(B9, T9, H9, W9) # [1, 3, 4, 5] # reshape round-trip: 1d → 4d → 1d must be lossless x_hat_back = x_hat_4d.view(B9, N9) assert torch.equal(x_hat_1d, x_hat_back), ( f"reshape round-trip FAILED: x_hat_1d != x_hat_4d.view(B,N)\n" f" x_hat_1d.shape={x_hat_1d.shape} x_hat_back.shape={x_hat_back.shape}" ) # ----- logp computation (mirrors training code) ------------------------- # logp_all[b, n] = log p9[b, n, x_hat_1d[b, n]] logp_all = ( p9.clamp(1e-8).log() .gather(-1, x_hat_1d.unsqueeze(-1)) .squeeze(-1) ) # [B9, N9] logp_sum = logp_all.sum(-1) # [B9] # ----- spot-check 10 random token positions ----------------------------- torch.manual_seed(7) positions = torch.randperm(N9)[:10].tolist() for pos in positions: tok_id = x_hat_1d[0, pos].item() logp_man = math.log(max(p9[0, pos, tok_id].item(), 1e-8)) logp_gat = logp_all[0, pos].item() diff = abs(logp_man - logp_gat) assert diff < 1e-6, ( f"logp mismatch at pos={pos}, 
tok={tok_id}: " f"manual={logp_man:.8f} gathered={logp_gat:.8f} diff={diff:.2e}" ) # check logp_sum matches sum of logp_all logp_sum_manual = logp_all[0].sum().item() assert abs(logp_sum.item() - logp_sum_manual) < 1e-5, \ f"logp_sum mismatch: {logp_sum.item():.6f} vs {logp_sum_manual:.6f}" print( f" [9] T={T9},H={H9},W={W9} N={N9} K={K} " f"x_hat reshape round-trip ✓ " f"10 logp spot-checks (pos={positions}) ✓ " f"logp_sum={logp_sum.item():.3f}" ) print("[9] logp/token reshape consistency PASSED ✓") print("\n" + "=" * 64) print("[patch 7/8/9] All 3 additional dry_run checks PASSED ✓") print("=" * 64) if __name__ == "__main__": main()