|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Evaluation script: compare student 1-step variants vs multi-step teacher.
|
|
|
| Verified native inference regime (ground truth from A/B testing):
|
| height=320, width=512, num_frames=49, guidance_scale=7, teacher_steps=50.
|
| no_cfg (guidance_scale=1) does NOT produce valid output for this URSA checkpoint.
|
|
|
| Student generation modes
|
| ------------------------
|
| cfg : 1-step, guidance_scale=7 (verified working student mode)
|
| baked : 1-step, guidance_scale=1 (for students trained with CFG KD)
|
|
|
| Teacher generation modes
|
| ------------------------
|
| cfg : 50-step, guidance_scale=7 (verified working teacher mode)
|
|
|
| Usage:
|
| python scripts/eval_onestep_ursa.py \\
|
| --teacher_ckpt /path/to/URSA \\
|
| --student_ckpt ./outputs/dimo/final/student.pt \\
|
| --modes cfg \\
|
| --eval_cfg_scale 7.0 \\
|
| --num_frames 49 --height 320 --width 512 \\
|
| --teacher_steps 50 \\
|
| --out_dir ./outputs/eval
|
| """
|
|
|
| import argparse
|
| import os
|
| import sys
|
|
|
| import numpy as np
|
| import torch
|
|
|
| _REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
| if _REPO_ROOT not in sys.path:
|
| sys.path.insert(0, _REPO_ROOT)
|
|
|
| from diffnext.pipelines import URSAPipeline
|
| from diffnext.utils import export_to_video
|
|
|
|
|
|
|
|
|
|
|
|
|
# Built-in prompts used by main() when no --prompt_file is supplied.
DEFAULT_PROMPTS = [
    "a lone grizzly bear walks through a misty forest at dawn, sunlight catching its fur.",
    "beautiful fireworks in the sky with red, white and blue.",
    "a wave crashes on a rocky shoreline at sunset, slow motion.",
    "a hummingbird hovers in front of a red flower, wings a blur.",
    "timelapse of clouds rolling over mountain peaks.",
    "a neon-lit city street at night with rain-soaked reflections.",
    "a kitten playing with a ball of yarn on a wooden floor.",
    "astronaut floating weightlessly inside a space station.",
]

# Default value of the --seeds option: every prompt is rendered once per seed.
DEFAULT_SEEDS = list(range(4))
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_args():
    """Build the command-line interface and parse the process arguments.

    Returns:
        argparse.Namespace with the checkpoint paths, output directory,
        video geometry, student/teacher generation modes, prompt and seed
        selection, and device/precision settings.
    """
    parser = argparse.ArgumentParser(description="URSA 1-step student eval vs teacher")
    add = parser.add_argument

    # Checkpoints and output location.
    add("--teacher_ckpt", required=True, help="URSA diffusers pipeline dir")
    add(
        "--student_ckpt",
        required=True,
        help="student.pt checkpoint from train_onestep_ursa_dimo.py",
    )
    add("--out_dir", default="./outputs/eval")

    # Video geometry and playback rate of the exported files.
    add("--num_frames", type=int, default=49)
    add("--height", type=int, default=320)
    add("--width", type=int, default=512)
    add("--fps", type=int, default=12)

    # Generation settings for the student and the teacher.
    add(
        "--modes",
        nargs="+",
        default=["cfg"],
        choices=["no_cfg", "cfg", "baked"],
        help="Student generation modes. Default: ['cfg']. "
             "no_cfg is known to produce blank/blurry output.",
    )
    add(
        "--eval_cfg_scale",
        type=float,
        default=7.0,
        help="Guidance scale for 'cfg' mode (verified working value=7)",
    )
    add(
        "--teacher_steps",
        type=int,
        default=50,
        help="Inference steps for teacher (verified default=50)",
    )
    add(
        "--teacher_modes",
        nargs="+",
        default=["cfg"],
        choices=["no_cfg", "cfg"],
        help="Teacher modes. Default: ['cfg']. "
             "no_cfg is NOT a valid baseline for this checkpoint.",
    )
    add(
        "--guidance_trunc",
        type=float,
        default=0.9,
        help="Truncation threshold for inference CFG (passed to pipeline)",
    )
    add("--max_prompt_length", type=int, default=320)
    add("--vae_batch_size", type=int, default=1)

    # Prompt and seed selection.
    add(
        "--prompt_file",
        default=None,
        help="Optional: text file with one prompt per line",
    )
    add("--seeds", nargs="*", type=int, default=DEFAULT_SEEDS)

    # Hardware and numeric precision.
    add("--device", type=int, default=0)
    add("--mixed_precision", default="bf16", choices=["fp16", "bf16", "fp32"])

    return parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
|
def slug(text: str, max_len: int = 40) -> str:
    """Turn *text* into a short, filename-friendly identifier.

    The text is lower-cased, every character that is neither alphanumeric
    nor a space is dropped, whitespace-separated words are joined with
    underscores, and the result is cut to at most *max_len* characters.

    Returns:
        The slug, or ``"prompt"`` when nothing survives the filtering.
    """
    kept_chars = [ch for ch in text.lower() if ch.isalnum() or ch == " "]
    words = "".join(kept_chars).split()
    shortened = "_".join(words)[:max_len]
    if not shortened:
        return "prompt"
    return shortened
|
|
|
|
|
def frames_to_mp4(frames, path: str, fps: int = 12):
    """Write *frames* to the video file *path* via ``export_to_video``.

    Args:
        frames: Frame sequence; a 4-D numpy array is split along its first
            axis into a list of frames, anything else is passed on as is.
        path: Destination file; its parent directory is created if missing.
        fps: Frame rate forwarded to ``export_to_video``.
    """
    target_dir = os.path.dirname(os.path.abspath(path))
    os.makedirs(target_dir, exist_ok=True)

    is_frame_stack = isinstance(frames, np.ndarray) and frames.ndim == 4
    sequence = list(frames) if is_frame_stack else frames

    export_to_video(sequence, output_video_path=path, fps=fps)
|
|
|
|
|
def _extract_frames(frames_output):
    """Normalise pipeline output into a list of uint8 numpy frames.

    Args:
        frames_output: Either a numpy array (a 5-D array is treated as a
            batch and only its first element is kept; otherwise the array is
            split along its first axis) or a list of frames, where non-array
            items are converted with ``np.array``.

    Returns:
        list of ``np.uint8`` arrays, one per frame (presumably [H, W, 3];
        the layout is not checked here).

    Raises:
        TypeError: If *frames_output* is neither a numpy array nor a list.
    """
    if isinstance(frames_output, np.ndarray):
        if frames_output.ndim == 5:
            frames_output = frames_output[0]
        frames = list(frames_output)
    elif isinstance(frames_output, list):
        frames = [f if isinstance(f, np.ndarray) else np.array(f) for f in frames_output]
    else:
        raise TypeError(f"Unexpected frames type: {type(frames_output)}")

    result = []
    for f in frames:
        if f.dtype != np.uint8:
            # Heuristic kept from the original: a frame whose maximum is
            # <= 1.0 is assumed to be normalised to [0, 1] and is rescaled.
            if f.max() <= 1.0:
                f = f * 255
            # Clip on both paths before casting: a bare astype(uint8) wraps
            # out-of-range values (e.g. 300 -> 44, -5 -> 251).
            f = np.clip(f, 0, 255).astype(np.uint8)
        result.append(f)
    return result
|
|
|
|
|
# Negative prompt that main() passes to the pipeline whenever the guidance
# scale is greater than 1.
DEFAULT_NEGATIVE_PROMPT = "".join([
    "worst quality, low quality, inconsistent motion, static, still, ",
    "blurry, jittery, distorted, ugly",
])
|
|
|
|
|
def _gen(pipe, prompt, seed, num_frames, height, width, guidance_scale,
         num_inference_steps, guidance_trunc, max_prompt_length, vae_batch_size,
         device, negative_prompt=None):
    """Run one seeded pipeline call and return its frames as uint8 arrays.

    A fresh ``torch.Generator`` on *device* is seeded with *seed* and handed
    to the pipeline. The remaining arguments are forwarded to the pipeline
    call unchanged, and the output is requested as numpy (``output_type="np"``).

    Returns:
        The list produced by ``_extract_frames`` from the pipeline's ``frames``.
    """
    rng = torch.Generator(device=device)
    rng.manual_seed(seed)

    call_kwargs = dict(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_frames=num_frames,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        guidance_trunc=guidance_trunc,
        max_prompt_length=max_prompt_length,
        vae_batch_size=vae_batch_size,
        output_type="np",
        generator=rng,
    )
    result = pipe(**call_kwargs)
    return _extract_frames(result.frames)
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_prompts(prompt_file):
    """Return the prompts to evaluate.

    When *prompt_file* is given it is read as UTF-8 with one prompt per line;
    blank lines and lines whose first non-blank character is ``#`` are
    skipped. Otherwise ``DEFAULT_PROMPTS`` is returned.
    """
    if not prompt_file:
        return DEFAULT_PROMPTS
    with open(prompt_file, encoding="utf-8") as fh:
        candidates = (line.strip() for line in fh)
        # Strip before testing for "#" so indented comment lines are skipped too.
        return [text for text in candidates if text and not text.startswith("#")]


def _report_regime(args):
    """Print the run configuration and warn when it leaves the verified regime."""
    _NATIVE = dict(height=320, width=512, num_frames=49, guidance_scale=7.0, teacher_steps=50)
    is_native = (
        args.height == _NATIVE["height"]
        and args.width == _NATIVE["width"]
        and args.num_frames == _NATIVE["num_frames"]
        and args.eval_cfg_scale == _NATIVE["guidance_scale"]
        and args.teacher_steps == _NATIVE["teacher_steps"]
    )
    print(f"[eval] verified_native_regime={is_native}")
    print(f"[eval] geometry=({args.num_frames},{args.height},{args.width}), "
          f"guidance_scale={args.eval_cfg_scale}, teacher_steps={args.teacher_steps}")
    if not is_native:
        print(f"[WARN] Current config deviates from the verified native URSA regime "
              f"({_NATIVE['num_frames']}×{_NATIVE['height']}×{_NATIVE['width']}, "
              f"cfg={_NATIVE['guidance_scale']}, steps={_NATIVE['teacher_steps']}).")

    all_modes = list(args.modes) + list(args.teacher_modes)
    if "no_cfg" in all_modes:
        print("[WARN] no_cfg is known to fail for this URSA checkpoint. "
              "Outputs may be blank or blurry.")


def main():
    """Generate student (1-step) and teacher (multi-step) videos for comparison.

    For every prompt and seed, each requested student mode is rendered with
    the student weights and each requested teacher mode with the original
    pipeline weights. Both share one pipeline object whose transformer
    weights are swapped in place. Each result is written as an .mp4 under
    ``args.out_dir``.

    Raises:
        SystemExit: If there are no prompts or no seeds to evaluate.
    """
    args = parse_args()

    dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
    dtype = dtype_map[args.mixed_precision]
    device = torch.device("cuda", args.device) if torch.cuda.is_available() else torch.device("cpu")
    os.makedirs(args.out_dir, exist_ok=True)

    _report_regime(args)

    prompts = _load_prompts(args.prompt_file)
    # Fail before the expensive model load when there is nothing to render.
    if not prompts:
        raise SystemExit(f"[eval] No prompts found in {args.prompt_file}; nothing to evaluate.")
    if not args.seeds:
        raise SystemExit("[eval] --seeds was given without values; nothing to evaluate.")

    print(f"[eval] {len(prompts)} prompts × {len(args.seeds)} seeds "
          f"| student modes={args.modes} | teacher modes={args.teacher_modes}")

    print(f"[eval] Loading pipeline from {args.teacher_ckpt} …")
    pipe = URSAPipeline.from_pretrained(
        args.teacher_ckpt, torch_dtype=dtype, trust_remote_code=True
    ).to(device)

    print(f"[eval] Loading student weights from {args.student_ckpt} …")
    student_state = torch.load(args.student_ckpt, map_location=device, weights_only=True)
    # Snapshot the teacher weights so they can be restored after the student
    # weights have overwritten the transformer's parameters.
    teacher_state = {k: v.clone() for k, v in pipe.transformer.state_dict().items()}

    gen_kwargs = dict(
        num_frames=args.num_frames,
        height=args.height,
        width=args.width,
        guidance_trunc=args.guidance_trunc,
        max_prompt_length=args.max_prompt_length,
        vae_batch_size=args.vae_batch_size,
    )

    student_guidance = {
        "no_cfg": 1.0,
        "cfg": args.eval_cfg_scale,
        "baked": 1.0,
    }
    teacher_guidance = {
        "no_cfg": 1.0,
        "cfg": args.eval_cfg_scale,
    }

    weights = {"student": student_state, "teacher": teacher_state}
    active = None  # which entry of `weights` the transformer currently holds

    def _use(which):
        """Load the named weights into the transformer unless already active."""
        nonlocal active
        if active != which:
            pipe.transformer.load_state_dict(weights[which], strict=True)
            pipe.transformer.eval()
            active = which

    for idx, prompt in enumerate(prompts):
        p_slug = slug(prompt)
        print(f"\n[{idx+1}/{len(prompts)}] {prompt[:70]}")

        for seed in args.seeds:

            for mode in args.modes:
                g_scale = student_guidance[mode]
                # A negative prompt is only supplied when guidance is active.
                neg = DEFAULT_NEGATIVE_PROMPT if g_scale > 1 else None
                _use("student")

                with torch.no_grad():
                    frames = _gen(pipe, prompt, seed,
                                  guidance_scale=g_scale,
                                  num_inference_steps=1,
                                  negative_prompt=neg,
                                  device=device, **gen_kwargs)

                path = os.path.join(
                    args.out_dir,
                    f"{idx:02d}_s{seed}_{p_slug}_student_1step_{mode}.mp4",
                )
                frames_to_mp4(frames, path, fps=args.fps)
                print(f" [student/{mode:6s}] seed={seed} scale={g_scale} → {path}")

            for t_mode in args.teacher_modes:
                g_scale = teacher_guidance[t_mode]
                neg = DEFAULT_NEGATIVE_PROMPT if g_scale > 1 else None
                _use("teacher")

                with torch.no_grad():
                    frames = _gen(pipe, prompt, seed,
                                  guidance_scale=g_scale,
                                  num_inference_steps=args.teacher_steps,
                                  negative_prompt=neg,
                                  device=device, **gen_kwargs)

                path = os.path.join(
                    args.out_dir,
                    f"{idx:02d}_s{seed}_{p_slug}_teacher_{args.teacher_steps}step_{t_mode}.mp4",
                )
                frames_to_mp4(frames, path, fps=args.fps)
                print(f" [teacher/{t_mode:6s}] seed={seed} scale={g_scale} "
                      f"steps={args.teacher_steps} → {path}")

    print(f"\n[eval] Done. Results in {args.out_dir}")
    _print_interpretation_guide(args)
|
|
|
|
|
def _print_interpretation_guide(args):
    """Print a boxed legend explaining the generated file-name suffixes.

    Args:
        args: Object exposing ``eval_cfg_scale`` and ``teacher_steps``; both
            are interpolated into the legend.

    The box width is computed from the longest line, so the right-hand
    border stays aligned whatever the interpolated values are.
    """
    steps = args.teacher_steps
    scale = args.eval_cfg_scale

    entries = [
        ("student_1step_cfg", f"1-step + CFG={scale}", "(verified working student mode)"),
        ("student_1step_baked", "1-step, guidance_scale=1", "(for students trained with CFG KD)"),
        (f"teacher_{steps}step_cfg", f"{steps}-step + CFG={scale}", "(verified working teacher mode)"),
    ]
    label_w = max(len(label) for label, _, _ in entries)

    title = ["Interpretation guide for generated videos"]
    body = []
    for label, summary, remark in entries:
        body.append(f"{label:<{label_w}} : {summary}")
        # Remarks are indented to line up under the summary column.
        body.append(f"{'':<{label_w}}   {remark}")
    note = [
        "NOTE: no_cfg (guidance_scale=1) is NOT a valid baseline",
        "      for this URSA checkpoint — outputs are blank or blurry.",
    ]

    inner = max(60, *(len(text) for text in title + body + note))
    rule = "═" * (inner + 2)

    def row(text):
        return f"║ {text:<{inner}} ║"

    lines = [""]  # leading blank line separates the box from earlier output
    lines.append(f"╔{rule}╗")
    lines.extend(row(text) for text in title)
    lines.append(f"╠{rule}╣")
    lines.extend(row(text) for text in body)
    lines.append(f"╠{rule}╣")
    lines.extend(row(text) for text in note)
    lines.append(f"╚{rule}╝")
    print("\n".join(lines))
|
|
|
|
|
# Script entry point: run the evaluation only when executed directly.
if __name__ == "__main__":
    main()
|
|
|