Code walkthrough (key design decisions)
utils_ursa_inputs.py
build_ursa_inputs(transformer, txt_ids, visual_tokens, latents_shape, device)
Strictly replicates the token concatenation logic of URSAPipeline.__call__:
img_ids = pad(latents_flat + lm_vocab_size, (1,0), value=bov_token_id)
input_ids = cat([txt_ids, img_ids], dim=1)
blk_pos = flex_rope.get_pos(latents_shape, L)
rope_pos = cat([txt_pos, blk_pos[0]]).unsqueeze(0).expand(B,-1,-1)
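The concatenation above can be illustrated without tensors. The following is a minimal sketch on plain Python lists, assuming one token list per batch sample; the function name build_input_ids and all numeric values (vocabulary size, token ids) are hypothetical, and the RoPE position logic is left out because flex_rope.get_pos is specific to the repository.

```python
def build_input_ids(txt_ids, latents_flat, lm_vocab_size, bov_token_id):
    """Concatenate text ids and offset visual ids, one list per batch sample."""
    input_ids = []
    for txt, lat in zip(txt_ids, latents_flat):
        # Shift visual codes past the text vocabulary, then prepend <bov>.
        # This mirrors pad(latents_flat + lm_vocab_size, (1, 0), value=bov_token_id).
        img = [bov_token_id] + [code + lm_vocab_size for code in lat]
        # Mirrors cat([txt_ids, img_ids], dim=1).
        input_ids.append(list(txt) + img)
    return input_ids


# Hypothetical values, chosen only to make the offsets easy to read.
example = build_input_ids(
    txt_ids=[[11, 12, 13]],
    latents_flat=[[0, 5, 2, 7]],
    lm_vocab_size=1000,
    bov_token_id=999,
)
print(example)  # [[11, 12, 13, 999, 1000, 1005, 1002, 1007]]
```

Each sample ends up with len(txt) + 1 + N tokens, which is why the visual logits are later read from the last N+1 positions.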
extract_visual_logits(logits, N, K)
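Guard against pitfall 1: take z = logits[:, -(N+1):-1] (the causal slice), then check whether the last dimension already equals K to decide whether it needs to be sliced as well.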
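The index arithmetic is easy to get wrong, so here is a dependency-free sketch on nested lists shaped [B][L][V]. In a causal model the logits at position i predict token i+1, so the N visual tokens are predicted from the <bov> position through the second-to-last position. How the last dimension is cut when it is larger than K is an assumption here: the sketch keeps the last K entries, which may not match the actual implementation.

```python
def extract_visual_logits(logits, N, K):
    """Return logits for the N visual tokens, shaped [B][N][K].

    logits is a nested list shaped [B][L][V].
    """
    out = []
    for seq in logits:
        # Causal slice: positions <bov>, v_1, ..., v_{N-1} predict v_1, ..., v_N.
        z = seq[-(N + 1):-1]
        if len(z) != N:
            raise ValueError("sequence is shorter than N + 1 positions")
        # Assumption: when V != K, the visual codebook is the last K entries.
        z = [row if len(row) == K else row[-K:] for row in z]
        out.append(z)
    return out


# Toy input: B=1, L=6, V=5, with logits[b][l][v] = 10 * l + v.
toy = [[[10 * l + v for v in range(5)] for l in range(6)]]
print(extract_visual_logits(toy, N=3, K=2))
# [[[23, 24], [33, 34], [43, 44]]]
```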
sample_t_curriculum: for the first 10k steps, sample t = 1-(1-u)^2, which biases t toward larger values; after that, revert to uniform sampling.
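A minimal sketch of this schedule, assuming u is drawn uniformly from [0, 1) and the switch happens at a fixed step count. The signature below (step, rng, warmup_steps) is hypothetical and is not taken from the repository.

```python
import random


def sample_t_curriculum(step, rng, warmup_steps=10_000):
    """Sample a noise level t in [0, 1)."""
    u = rng.random()
    if step < warmup_steps:
        # Warm-up phase: t = 1 - (1 - u)^2 pushes mass toward t = 1.
        return 1.0 - (1.0 - u) ** 2
    # After warm-up: plain uniform sampling.
    return u


rng = random.Random(0)
early = [sample_t_curriculum(0, rng) for _ in range(20_000)]
late = [sample_t_curriculum(10_000, rng) for _ in range(20_000)]
print(sum(early) / len(early), sum(late) / len(late))
```

Since E[(1-u)^2] = 1/3 for uniform u, the warm-up mean of t is 2/3 compared with 1/2 afterwards, so the two printed averages should land near those values.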
train_onestep_ursa_dimo.py training loop
The 9 stages of each step follow the full procedure of the DiMO paper:
| Stage | Operation | Gradient |
|-------|-----------|----------|
| 1-2 | tokenize + sample x_init (80% uniform / 20% corrupt) | none |
| 3 | student 1-step forward on x_init → x_hat, logp, H | ✅ student |
| 4 | add_noise(x_hat, t) → x_t | none (cut off by discrete sampling) |
| 5 | teacher on x_t → p_T | none (no_grad) |
| 6 | aux on x_t → Jeffrey(p_T, p_A) → backward → aux update | ✅ aux only |
| 7 | student on x_t → KL(p_T ‖ p_S_t) | ✅ student |
| 8 | REINFORCE: r=-loss_aux, adv=r-EMA, loss_pg=-(adv·logp) | ✅ student (via logp) |
| 9 | L_s = λ_pg·loss_pg + λ_kd·loss_kd - λ_ent·H → student update | ✅ student |
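The arithmetic of stages 6 to 9 can be sketched on scalar values and small categorical distributions. This is only an illustration under stated assumptions: Jeffrey divergence is taken as KL(p‖q) + KL(q‖p) without a 1/2 factor, the advantage uses the EMA baseline value from before the current update, and the decay and the default λ_pg and λ_kd weights are hypothetical. In real training adv would be detached and gradients would flow only through logp, loss_kd, and H.

```python
import math


def kl(p, q, eps=1e-12):
    """KL(p || q) for two categorical distributions given as lists."""
    return sum(pi * (math.log(pi + eps) - math.log(qi + eps)) for pi, qi in zip(p, q))


def jeffrey(p, q):
    """Symmetric divergence, assumed here to be KL(p||q) + KL(q||p)."""
    return kl(p, q) + kl(q, p)


class EmaBaseline:
    """Exponential moving average of the reward (decay is a hypothetical value)."""

    def __init__(self, decay=0.99):
        self.decay = decay
        self.value = None

    def update(self, r):
        if self.value is None:
            self.value = r
        else:
            self.value = self.decay * self.value + (1.0 - self.decay) * r
        return self.value


def student_loss(loss_aux, loss_kd, logp, entropy, baseline,
                 lam_pg=1.0, lam_kd=1.0, lam_ent=0.01):
    """Stages 8-9: REINFORCE term plus KD term minus entropy bonus."""
    r = -loss_aux                                   # reward
    b = baseline.value if baseline.value is not None else r
    adv = r - b                                     # treated as a constant
    baseline.update(r)
    loss_pg = -(adv * logp)
    return lam_pg * loss_pg + lam_kd * loss_kd - lam_ent * entropy


# Hypothetical per-token distributions for teacher, aux, and student.
p_T, p_A, p_S = [0.7, 0.2, 0.1], [0.5, 0.3, 0.2], [0.6, 0.3, 0.1]
H = -sum(p * math.log(p) for p in p_S)
base = EmaBaseline()
print(student_loss(jeffrey(p_T, p_A), kl(p_T, p_S), math.log(p_S[0]), H, base))
```

On the very first call the advantage is zero, so the result reduces to λ_kd·loss_kd - λ_ent·H; the policy-gradient term only contributes once the baseline has a history.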
Example commands
End-to-end smoke test (single GPU, 17 frames at 256×256, 2000 steps):
python scripts/train_onestep_ursa_dimo.py \
    --teacher_ckpt /gfs/space/private/fengzl/World_Model/URSA-1.7B/ \
    --prompt_file /gfs/space/private/fengzl/World_Model/Koala-36M-v1/ \
    --num_frames 17 --height 256 --width 256 \
    --batch_size 1 --num_steps 2000 \
    --log_every 50 --save_every 500 \
    --out_dir ./outputs/dimo_test
Evaluation (1-step student vs. 25-step teacher):
python scripts/eval_onestep_ursa.py \
    --teacher_ckpt /gfs/space/private/fengzl/World_Model/URSA-1.7B/ \
    --student_ckpt ./outputs/dimo_test/final/student.pt \
    --num_frames 17 --height 256 --width 256 \
    --teacher_steps 25 \
    --out_dir ./outputs/eval
Scaling up to full resolution (49 frames at 320×512):
python scripts/train_onestep_ursa_dimo.py \
    --teacher_ckpt /gfs/space/private/fengzl/World_Model/URSA-1.7B/ \
    --prompt_file /gfs/space/private/fengzl/World_Model/Koala-36M-v1/ \
    --num_frames 49 --height 320 --width 512 \
    --batch_size 2 --num_steps 50000 \
    --lambda_ent 0.01 --t_curriculum_steps 10000 \
    --mixed_precision bf16 --out_dir ./outputs/dimo_full
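Three stability mechanisms (all three are required)
t curriculum: for the first 10k steps t is biased toward larger values, so the teacher distribution is sharper and the KD signal is stronger, which keeps the student from random-walking early in training.
p_init mixing: 20% of batches use corrupt(x_hat_prev, r=0.2), so the student learns "one-step repair".
Entropy regularization λ_ent: starts at 0.01; if tok_entropy is observed to drop, raise it to 0.05.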
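The second and third mechanisms can be sketched as follows. Several details are assumptions rather than the repository's behavior: corrupt is taken to resample each token uniformly with probability r, the 80/20 choice is made once per batch, and the rule for detecting an entropy drop (a fixed number of consecutive decreases in the logged tok_entropy) is hypothetical.

```python
import random


def corrupt(tokens, r, vocab_size, rng):
    """Assumed corruption: resample each token uniformly with probability r."""
    return [rng.randrange(vocab_size) if rng.random() < r else tok for tok in tokens]


def sample_x_init(x_hat_prev, num_tokens, vocab_size, rng, p_corrupt=0.2, r=0.2):
    """p_init mixing: 20% corrupted previous output, 80% uniform random tokens."""
    if x_hat_prev is not None and rng.random() < p_corrupt:
        return corrupt(x_hat_prev, r, vocab_size, rng)
    return [rng.randrange(vocab_size) for _ in range(num_tokens)]


def next_lambda_ent(entropy_history, lam_low=0.01, lam_high=0.05, window=3):
    """Return lam_high once tok_entropy fell for `window` consecutive logs."""
    recent = entropy_history[-(window + 1):]
    falling = len(recent) == window + 1 and all(
        later < earlier for earlier, later in zip(recent, recent[1:])
    )
    return lam_high if falling else lam_low


rng = random.Random(0)
x_init = sample_x_init([7] * 16, num_tokens=16, vocab_size=1024, rng=rng)
print(len(x_init), next_lambda_ent([5.0, 4.8, 4.5, 4.1]))
```

next_lambda_ent is stateless, so a training loop that wants the higher weight to persist would need to remember that it has already been raised.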
8-GPU launch command
accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml \
    --machine_rank 0 --num_machines 1 --num_processes 8 \
    scripts/train_distill_dimo.py \
    config=./configs/distill_dimo.yaml \
    experiment.output_dir=./experiments/distill_dimo \
    distill.teacher_ckpt=/gfs/space/private/fengzl/World_Model/URSA-1.7B \
    distill.prompt_source=/gfs/space/private/fengzl/World_Model/Koala-36M-v1 \
    distill.batch_size_per_gpu=1
Smoke test (50 steps, saves a checkpoint)
accelerate launch --num_processes 8 --mixed_precision bf16 \
    scripts/train_distill_dimo.py \
    config="./configs/distill_dimo.yaml" \
    experiment.output_dir="./experiments/smoke" \
    distill.teacher_ckpt="/gfs/space/private/fengzl/World_Model/URSA-1.7B" \
    distill.prompt_source="/gfs/space/private/fengzl/World_Model/Koala-36M-v1" \
    training.max_train_steps=50 \
    experiment.save_every=50
Loading student.pt for 1-step inference
from diffnext.pipelines import URSAPipeline
import torch

pipe = URSAPipeline.from_pretrained(
    "/path/to/URSA-1.7B-IBQ1024",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).to("cuda")

# Replace the transformer weights with the student's
state = torch.load("experiments/distill_dimo/checkpoints/final/student.pt", map_location="cuda")
pipe.transformer.load_state_dict(state, strict=True)

# 1-step generation (num_inference_steps=1)
frames = pipe(
    prompt="a dog running on a beach",
    height=256,
    width=256,
    num_frames=17,
    num_inference_steps=1,
    guidance_scale=3.0,
).frames
Latest: after changing the resolution and cfg
accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml \
    --machine_rank 0 --num_machines 1 --num_processes 8 \
    scripts/train_distill_dimo.py \
    config="./configs/distill_dimo.yaml" \
    experiment.output_dir="./experiments/distill_dimo" \
    distill.teacher_ckpt="/gfs/space/private/fengzl/World_Model/URSA-1.7B" \
    distill.prompt_source="/gfs/space/private/fengzl/World_Model/Koala-36M-v1"