#!/usr/bin/env python3
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------
"""Evaluation script: compare student 1-step variants vs multi-step teacher.

Verified native inference regime (from A/B testing — ground truth):
  height=320, width=512, num_frames=49, guidance_scale=7, teacher_steps=50.
  no_cfg (guidance_scale=1) does NOT produce valid output for this URSA checkpoint.

Student generation modes
------------------------
  cfg   : 1-step, guidance_scale=7  (verified working student mode)
  baked : 1-step, guidance_scale=1  (for students trained with CFG KD)

Teacher generation modes
------------------------
  cfg   : 50-step, guidance_scale=7  (verified working teacher mode)

Usage:
    python scripts/eval_onestep_ursa.py \\
        --teacher_ckpt  /path/to/URSA \\
        --student_ckpt  ./outputs/dimo/final/student.pt \\
        --modes cfg \\
        --eval_cfg_scale 7.0 \\
        --num_frames 49 --height 320 --width 512 \\
        --teacher_steps 50 \\
        --out_dir ./outputs/eval
"""

import argparse
import os
import sys

import numpy as np
import torch

_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if _REPO_ROOT not in sys.path:
    sys.path.insert(0, _REPO_ROOT)

from diffnext.pipelines import URSAPipeline
from diffnext.utils import export_to_video

# ---------------------------------------------------------------------------
# Default prompts and seeds
# ---------------------------------------------------------------------------

DEFAULT_PROMPTS = [
    "a lone grizzly bear walks through a misty forest at dawn, sunlight catching its fur.",
    "beautiful fireworks in the sky with red, white and blue.",
    "a wave crashes on a rocky shoreline at sunset, slow motion.",
    "a hummingbird hovers in front of a red flower, wings a blur.",
    "timelapse of clouds rolling over mountain peaks.",
    "a neon-lit city street at night with rain-soaked reflections.",
    "a kitten playing with a ball of yarn on a wooden floor.",
    "astronaut floating weightlessly inside a space station.",
]

DEFAULT_SEEDS = [0, 1, 2, 3]


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def parse_args():
    """Parse the command-line arguments for the student-vs-teacher evaluation."""
    p = argparse.ArgumentParser(description="URSA 1-step student eval vs teacher")
    p.add_argument("--teacher_ckpt", required=True,
                   help="URSA diffusers pipeline dir")
    p.add_argument("--student_ckpt", required=True,
                   help="student.pt checkpoint from train_onestep_ursa_dimo.py")
    p.add_argument("--out_dir", default="./outputs/eval")

    # Geometry (verified native: 320×512×49)
    p.add_argument("--num_frames", type=int, default=49)
    p.add_argument("--height", type=int, default=320)
    p.add_argument("--width", type=int, default=512)
    p.add_argument("--fps", type=int, default=12)

    # Generation — default: cfg only (no_cfg is known to fail)
    p.add_argument("--modes", nargs="+", default=["cfg"],
                   choices=["no_cfg", "cfg", "baked"],
                   help="Student generation modes. Default: ['cfg']. "
                        "no_cfg is known to produce blank/blurry output.")
    p.add_argument("--eval_cfg_scale", type=float, default=7.0,
                   help="Guidance scale for 'cfg' mode (verified working value=7)")
    p.add_argument("--teacher_steps", type=int, default=50,
                   help="Inference steps for teacher (verified default=50)")
    p.add_argument("--teacher_modes", nargs="+", default=["cfg"],
                   choices=["no_cfg", "cfg"],
                   help="Teacher modes. Default: ['cfg']. "
                        "no_cfg is NOT a valid baseline for this checkpoint.")
    p.add_argument("--guidance_trunc", type=float, default=0.9,
                   help="Truncation threshold for inference CFG (passed to pipeline)")
    p.add_argument("--max_prompt_length", type=int, default=320)
    p.add_argument("--vae_batch_size", type=int, default=1)

    # Data
    p.add_argument("--prompt_file", default=None,
                   help="Optional: text file with one prompt per line")
    p.add_argument("--seeds", nargs="*", type=int, default=DEFAULT_SEEDS)

    # Device
    p.add_argument("--device", type=int, default=0)
    p.add_argument("--mixed_precision", default="bf16",
                   choices=["fp16", "bf16", "fp32"])
    return p.parse_args()


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def slug(text: str, max_len: int = 40) -> str:
    """Turn *text* into a lowercase, filename-friendly identifier.

    Characters that are neither alphanumeric nor a space are dropped,
    whitespace runs collapse to single underscores, and the result is cut to
    *max_len* characters. Returns ``"prompt"`` if nothing survives.
    """
    s = text.lower()
    s = "".join(c if c.isalnum() or c == " " else "" for c in s)
    s = "_".join(s.split())[:max_len]
    return s or "prompt"


def frames_to_mp4(frames, path: str, fps: int = 12):
    """Write *frames* to an MP4 at *path*, creating the parent directory first.

    A 4-D numpy array is split along its first axis into a list of frames
    before being handed to ``export_to_video``.
    """
    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
    if isinstance(frames, np.ndarray) and frames.ndim == 4:
        frames = list(frames)
    export_to_video(frames, output_video_path=path, fps=fps)


def _extract_frames(frames_output):
    """Normalise pipeline output → list of uint8 numpy frames (expected [H, W, 3]).

    Accepted inputs:
      * a numpy array of frames, or a 5-D array with a leading batch axis
        (only the first video is kept);
      * a list of frames (numpy arrays or anything ``np.array`` accepts,
        e.g. PIL images), or a list of such videos (only the first is kept).

    Non-uint8 frames are converted with ONE range decision for the whole
    clip: if every non-uint8 value is <= 1.0 the clip is treated as [0, 1]
    and scaled by 255, otherwise it is treated as [0, 255]. Values are
    clipped to [0, 255] before the cast so they saturate instead of wrapping.

    Raises:
        TypeError: if *frames_output* is neither a numpy array nor a list.
    """
    if isinstance(frames_output, np.ndarray):
        frames_output = frames_output[0] if frames_output.ndim == 5 else frames_output
        frames = list(frames_output)
    elif isinstance(frames_output, list):
        # A list whose first entry is itself a whole video is a batch.
        if frames_output and (
            isinstance(frames_output[0], list)
            or (isinstance(frames_output[0], np.ndarray) and frames_output[0].ndim == 4)
        ):
            frames_output = frames_output[0]
        frames = [f if isinstance(f, np.ndarray) else np.array(f) for f in frames_output]
    else:
        raise TypeError(f"Unexpected frames type: {type(frames_output)}")

    if not frames:
        return []

    # Decide the value range once per clip: a per-frame test would rescale a
    # nearly black frame of a [0, 255] float video by 255.
    non_u8 = [f for f in frames if f.dtype != np.uint8]
    unit_range = bool(non_u8) and max(float(f.max()) for f in non_u8) <= 1.0

    result = []
    for f in frames:
        if f.dtype != np.uint8:
            if unit_range:
                f = f * 255
            f = np.clip(f, 0, 255).astype(np.uint8)
        result.append(f)
    return result


def _load_prompts(prompt_file):
    """Return the list of evaluation prompts.

    With *prompt_file* set, reads it as UTF-8, one prompt per line, skipping
    blank lines and lines whose first non-whitespace character is ``#``.
    Without it, returns a copy of ``DEFAULT_PROMPTS``.
    """
    if not prompt_file:
        return list(DEFAULT_PROMPTS)
    with open(prompt_file, encoding="utf-8") as fh:
        stripped = (line.strip() for line in fh)
        return [s for s in stripped if s and not s.startswith("#")]


def _report_regime(args):
    """Print the active configuration and warn when it leaves the verified regime.

    Only prints; it never modifies *args* or aborts the run.
    """
    native = dict(height=320, width=512, num_frames=49,
                  guidance_scale=7.0, teacher_steps=50)
    is_native = (
        args.height == native["height"]
        and args.width == native["width"]
        and args.num_frames == native["num_frames"]
        and args.eval_cfg_scale == native["guidance_scale"]
        and args.teacher_steps == native["teacher_steps"]
    )
    print(f"[eval] verified_native_regime={is_native}")
    print(f"[eval] geometry=({args.num_frames},{args.height},{args.width}), "
          f"guidance_scale={args.eval_cfg_scale}, teacher_steps={args.teacher_steps}")
    if not is_native:
        print(f"[WARN] Current config deviates from the verified native URSA regime "
              f"({native['num_frames']}×{native['height']}×{native['width']}, "
              f"cfg={native['guidance_scale']}, steps={native['teacher_steps']}).")

    if "no_cfg" in [*args.modes, *args.teacher_modes]:
        print("[WARN] no_cfg is known to fail for this URSA checkpoint. "
              "Outputs may be blank or blurry.")


DEFAULT_NEGATIVE_PROMPT = (
    "worst quality, low quality, inconsistent motion, static, still, "
    "blurry, jittery, distorted, ugly"
)


def _gen(pipe, prompt, seed, num_frames, height, width,
         guidance_scale, num_inference_steps, guidance_trunc,
         max_prompt_length, vae_batch_size, device, negative_prompt=None):
    """Run one seeded pipeline call and return its frames as a list of uint8 arrays.

    A fresh ``torch.Generator`` seeded with *seed* is created on *device* for
    every call, so each (prompt, seed) pair is sampled from the same RNG
    state regardless of what ran before.
    """
    gen = torch.Generator(device=device).manual_seed(seed)
    out = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_frames=num_frames,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        guidance_trunc=guidance_trunc,
        max_prompt_length=max_prompt_length,
        vae_batch_size=vae_batch_size,
        output_type="np",
        generator=gen,
    )
    return _extract_frames(out.frames)


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    """Generate 1-step student videos and multi-step teacher reference videos.

    One pipeline is loaded from ``--teacher_ckpt``. Its transformer weights
    are swapped between the teacher snapshot and the student checkpoint as
    needed, and one MP4 is written per (prompt, seed, mode).
    """
    args = parse_args()

    dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
    dtype = dtype_map[args.mixed_precision]
    device = torch.device("cuda", args.device) if torch.cuda.is_available() else torch.device("cpu")
    os.makedirs(args.out_dir, exist_ok=True)

    # -- Verified regime validation ----------------------------------------
    _report_regime(args)

    # -- Load prompts -----------------------------------------------------
    prompts = _load_prompts(args.prompt_file)
    if not prompts:
        sys.exit(f"[eval] ERROR: no prompts found in {args.prompt_file}")
    if not args.seeds:
        sys.exit("[eval] ERROR: --seeds was given without any values")

    print(f"[eval] {len(prompts)} prompts × {len(args.seeds)} seeds "
          f"| student modes={args.modes} | teacher modes={args.teacher_modes}")

    # -- Load pipeline ---------------------------------------------------
    print(f"[eval] Loading pipeline from {args.teacher_ckpt} …")
    pipe = URSAPipeline.from_pretrained(
        args.teacher_ckpt, torch_dtype=dtype, trust_remote_code=True
    ).to(device)

    # -- Load student checkpoint -----------------------------------------
    print(f"[eval] Loading student weights from {args.student_ckpt} …")
    # Both weight sets are kept on the CPU: load_state_dict copies them into
    # the transformer's existing (device-resident) parameters on each swap,
    # so only one transformer copy occupies accelerator memory.
    student_state = torch.load(args.student_ckpt, map_location="cpu", weights_only=True)
    teacher_state = {k: v.detach().to("cpu", copy=True)
                     for k, v in pipe.transformer.state_dict().items()}

    weight_sets = {"student": student_state, "teacher": teacher_state}
    # The freshly loaded pipeline already holds the teacher weights.
    active = "teacher"
    pipe.transformer.eval()

    def use_weights(name):
        """Load weight set *name* into the transformer unless it is already active."""
        nonlocal active
        if active != name:
            pipe.transformer.load_state_dict(weight_sets[name], strict=True)
            pipe.transformer.eval()
            active = name

    # Common kwargs passed to every pipeline call
    gen_kwargs = dict(
        num_frames=args.num_frames,
        height=args.height,
        width=args.width,
        guidance_trunc=args.guidance_trunc,
        max_prompt_length=args.max_prompt_length,
        vae_batch_size=args.vae_batch_size,
    )

    # Mode → guidance_scale mapping
    #   no_cfg : guidance_scale=1
    #   cfg    : guidance_scale=eval_cfg_scale
    #   baked  : guidance_scale=1 (student trained with guided KD)
    student_guidance = {
        "no_cfg": 1.0,
        "cfg": args.eval_cfg_scale,
        "baked": 1.0,
    }
    teacher_guidance = {
        "no_cfg": 1.0,
        "cfg": args.eval_cfg_scale,
    }

    # -- Evaluation loop -------------------------------------------------
    for idx, prompt in enumerate(prompts):
        p_slug = slug(prompt)
        print(f"\n[{idx+1}/{len(prompts)}] {prompt[:70]}")

        for seed in args.seeds:
            # ---- Student: selected modes --------------------------------
            for mode in args.modes:
                g_scale = student_guidance[mode]
                # A negative prompt is only supplied when guidance is active.
                neg = DEFAULT_NEGATIVE_PROMPT if g_scale > 1 else None
                use_weights("student")
                with torch.no_grad():
                    frames = _gen(pipe, prompt, seed,
                                  guidance_scale=g_scale,
                                  num_inference_steps=1,
                                  negative_prompt=neg,
                                  device=device,
                                  **gen_kwargs)
                path = os.path.join(
                    args.out_dir,
                    f"{idx:02d}_s{seed}_{p_slug}_student_1step_{mode}.mp4",
                )
                frames_to_mp4(frames, path, fps=args.fps)
                print(f"  [student/{mode:6s}] seed={seed} scale={g_scale} → {path}")

            # ---- Teacher: reference videos ------------------------------
            for t_mode in args.teacher_modes:
                g_scale = teacher_guidance[t_mode]
                neg = DEFAULT_NEGATIVE_PROMPT if g_scale > 1 else None
                use_weights("teacher")
                with torch.no_grad():
                    frames = _gen(pipe, prompt, seed,
                                  guidance_scale=g_scale,
                                  num_inference_steps=args.teacher_steps,
                                  negative_prompt=neg,
                                  device=device,
                                  **gen_kwargs)
                path = os.path.join(
                    args.out_dir,
                    f"{idx:02d}_s{seed}_{p_slug}_teacher_{args.teacher_steps}step_{t_mode}.mp4",
                )
                frames_to_mp4(frames, path, fps=args.fps)
                print(f"  [teacher/{t_mode:6s}] seed={seed} scale={g_scale} "
                      f"steps={args.teacher_steps} → {path}")

    print(f"\n[eval] Done. Results in {args.out_dir}")
    _print_interpretation_guide(args)


def _print_interpretation_guide(args):
    """Print a boxed legend explaining the generated video filenames.

    The box is built row by row so its right border stays aligned whatever
    values ``eval_cfg_scale`` and ``teacher_steps`` take.
    """
    width = 64  # number of "═" characters between the corners
    rows = [
        "Interpretation guide for generated videos",
        None,  # None marks a horizontal separator
        f"student_1step_cfg   : 1-step + CFG={args.eval_cfg_scale}",
        "                      (verified working student mode)",
        "student_1step_baked : 1-step, guidance_scale=1",
        "                      (for students trained with CFG KD)",
        f"teacher_{args.teacher_steps}step_cfg : "
        f"{args.teacher_steps}-step + CFG={args.eval_cfg_scale}",
        "                      (verified working teacher mode)",
        None,
        "NOTE: no_cfg (guidance_scale=1) is NOT a valid baseline",
        "      for this URSA checkpoint — outputs are blank or blurry.",
    ]
    lines = ["", "╔" + "═" * width + "╗"]
    for row in rows:
        if row is None:
            lines.append("╠" + "═" * width + "╣")
        else:
            lines.append("║ " + row.ljust(width - 2) + " ║")
    lines.append("╚" + "═" * width + "╝")
    print("\n".join(lines))


if __name__ == "__main__":
    main()