#!/usr/bin/env python3
# Source: World_Model / URSA / scripts / eval_onestep_ursa.py
# (Hugging Face upload by BryanW via "Add files using upload-large-folder tool", commit 2ee4cd6, verified)
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------
"""Evaluation script: compare student 1-step variants vs multi-step teacher.
Verified native inference regime (from A/B testing — ground truth):
height=320, width=512, num_frames=49, guidance_scale=7, teacher_steps=50.
no_cfg (guidance_scale=1) does NOT produce valid output for this URSA checkpoint.
Student generation modes
------------------------
cfg : 1-step, guidance_scale=7 (verified working student mode)
baked : 1-step, guidance_scale=1 (for students trained with CFG KD)
Teacher generation modes
------------------------
cfg : 50-step, guidance_scale=7 (verified working teacher mode)
Usage:
python scripts/eval_onestep_ursa.py \\
--teacher_ckpt /path/to/URSA \\
--student_ckpt ./outputs/dimo/final/student.pt \\
--modes cfg \\
--eval_cfg_scale 7.0 \\
--num_frames 49 --height 320 --width 512 \\
--teacher_steps 50 \\
--out_dir ./outputs/eval
"""
import argparse
import os
import sys
import numpy as np
import torch
# Make the repository root (the parent of this script's directory) importable,
# so the `diffnext` imports below resolve when the script is run directly.
_REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _REPO_ROOT not in sys.path:
    sys.path.insert(0, _REPO_ROOT)
from diffnext.pipelines import URSAPipeline
from diffnext.utils import export_to_video
# ---------------------------------------------------------------------------
# Default prompts and seeds
# ---------------------------------------------------------------------------
# Built-in evaluation prompts, used when --prompt_file is not given.
DEFAULT_PROMPTS = list((
    "a lone grizzly bear walks through a misty forest at dawn, sunlight catching its fur.",
    "beautiful fireworks in the sky with red, white and blue.",
    "a wave crashes on a rocky shoreline at sunset, slow motion.",
    "a hummingbird hovers in front of a red flower, wings a blur.",
    "timelapse of clouds rolling over mountain peaks.",
    "a neon-lit city street at night with rain-soaked reflections.",
    "a kitten playing with a ball of yarn on a wooden floor.",
    "astronaut floating weightlessly inside a space station.",
))
# Default --seeds value: 0, 1, 2, 3 (each prompt is generated once per seed).
DEFAULT_SEEDS = list(range(4))
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def parse_args(argv=None):
    """Parse command-line options for the student-vs-teacher evaluation.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before; passing a list
            lets the parser be driven from tests or notebooks.

    Returns:
        argparse.Namespace with checkpoint paths, output geometry, student and
        teacher generation modes/scales, prompt source, seeds and device settings.
    """
    p = argparse.ArgumentParser(description="URSA 1-step student eval vs teacher")
    p.add_argument("--teacher_ckpt", required=True, help="URSA diffusers pipeline dir")
    p.add_argument("--student_ckpt", required=True,
                   help="student.pt checkpoint from train_onestep_ursa_dimo.py")
    p.add_argument("--out_dir", default="./outputs/eval")
    # Geometry (verified native: 320×512×49)
    p.add_argument("--num_frames", type=int, default=49)
    p.add_argument("--height", type=int, default=320)
    p.add_argument("--width", type=int, default=512)
    p.add_argument("--fps", type=int, default=12)
    # Generation — default: cfg only (no_cfg is known to fail)
    p.add_argument("--modes", nargs="+", default=["cfg"],
                   choices=["no_cfg", "cfg", "baked"],
                   help="Student generation modes. Default: ['cfg']. "
                        "no_cfg is known to produce blank/blurry output.")
    p.add_argument("--eval_cfg_scale", type=float, default=7.0,
                   help="Guidance scale for 'cfg' mode (verified working value=7)")
    p.add_argument("--teacher_steps", type=int, default=50,
                   help="Inference steps for teacher (verified default=50)")
    p.add_argument("--teacher_modes", nargs="+", default=["cfg"],
                   choices=["no_cfg", "cfg"],
                   help="Teacher modes. Default: ['cfg']. "
                        "no_cfg is NOT a valid baseline for this checkpoint.")
    p.add_argument("--guidance_trunc", type=float, default=0.9,
                   help="Truncation threshold for inference CFG (passed to pipeline)")
    p.add_argument("--max_prompt_length", type=int, default=320)
    p.add_argument("--vae_batch_size", type=int, default=1)
    # Data
    p.add_argument("--prompt_file", default=None,
                   help="Optional: text file with one prompt per line")
    # nargs="+" rejects a bare `--seeds` with no values. With "*" that yielded
    # an empty list and the run silently generated no videos at all.
    # The default is copied so callers mutating args.seeds cannot alter the
    # module-level DEFAULT_SEEDS.
    p.add_argument("--seeds", nargs="+", type=int, default=list(DEFAULT_SEEDS))
    # Device
    p.add_argument("--device", type=int, default=0)
    p.add_argument("--mixed_precision", default="bf16", choices=["fp16", "bf16", "fp32"])
    return p.parse_args(argv)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def slug(text: str, max_len: int = 40) -> str:
    """Turn a prompt into a short, filesystem-friendly identifier.

    The text is lower-cased. Every character that is neither alphanumeric nor
    a plain space is dropped (tabs and punctuation included). Runs of spaces
    collapse into single underscores, and the result is cut to ``max_len``
    characters. Falls back to ``"prompt"`` when nothing survives.
    """
    kept_chars = (ch for ch in text.lower() if ch.isalnum() or ch == " ")
    words = "".join(kept_chars).split()
    truncated = "_".join(words)[:max_len]
    return truncated or "prompt"
def frames_to_mp4(frames, path: str, fps: int = 12):
    """Encode ``frames`` to a video file at ``path`` via ``export_to_video``.

    The parent directory of ``path`` is created if missing. A 4-D numpy array
    is split along its first axis into a list of frames before export. Any
    other input is forwarded unchanged.
    """
    parent_dir = os.path.dirname(os.path.abspath(path))
    os.makedirs(parent_dir, exist_ok=True)
    is_stacked_clip = isinstance(frames, np.ndarray) and frames.ndim == 4
    payload = list(frames) if is_stacked_clip else frames
    export_to_video(payload, output_video_path=path, fps=fps)
def _extract_frames(frames_output):
"""Normalise pipeline output β†’ list of uint8 numpy arrays [H, W, 3]."""
if isinstance(frames_output, np.ndarray):
frames_output = frames_output[0] if frames_output.ndim == 5 else frames_output
frames = list(frames_output)
elif isinstance(frames_output, list):
frames = [np.array(f) if not isinstance(f, np.ndarray) else f for f in frames_output]
else:
raise TypeError(f"Unexpected frames type: {type(frames_output)}")
result = []
for f in frames:
if f.dtype != np.uint8:
f = (f * 255).clip(0, 255).astype(np.uint8) if f.max() <= 1.0 else f.astype(np.uint8)
result.append(f)
return result
# Negative prompt passed to the pipeline whenever guidance_scale > 1 (CFG modes).
DEFAULT_NEGATIVE_PROMPT = ", ".join([
    "worst quality", "low quality", "inconsistent motion", "static", "still",
    "blurry", "jittery", "distorted", "ugly",
])
def _gen(pipe, prompt, seed, num_frames, height, width, guidance_scale,
         num_inference_steps, guidance_trunc, max_prompt_length, vae_batch_size,
         device, negative_prompt=None):
    """Run one seeded pipeline call and return its frames as uint8 arrays.

    A fresh ``torch.Generator`` on ``device`` is seeded with ``seed`` on every
    call, so repeated calls with the same seed start from the same RNG state.
    All other arguments are forwarded to ``pipe`` by keyword, together with
    ``output_type="np"``. The result's ``.frames`` is normalised by
    ``_extract_frames``.
    """
    rng = torch.Generator(device=device).manual_seed(seed)
    call_kwargs = dict(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_frames=num_frames,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        guidance_trunc=guidance_trunc,
        max_prompt_length=max_prompt_length,
        vae_batch_size=vae_batch_size,
        output_type="np",
        generator=rng,
    )
    pipeline_output = pipe(**call_kwargs)
    return _extract_frames(pipeline_output.frames)
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def _warn_if_off_regime(args):
    """Print the active regime and warn if it deviates from the verified one.

    Also warns when any student or teacher mode is ``no_cfg``. Prints only.
    """
    native = dict(height=320, width=512, num_frames=49, guidance_scale=7.0, teacher_steps=50)
    is_native = (
        args.height == native["height"]
        and args.width == native["width"]
        and args.num_frames == native["num_frames"]
        and args.eval_cfg_scale == native["guidance_scale"]
        and args.teacher_steps == native["teacher_steps"]
    )
    print(f"[eval] verified_native_regime={is_native}")
    print(f"[eval] geometry=({args.num_frames},{args.height},{args.width}), "
          f"guidance_scale={args.eval_cfg_scale}, teacher_steps={args.teacher_steps}")
    if not is_native:
        print(f"[WARN] Current config deviates from the verified native URSA regime "
              f"({native['num_frames']}×{native['height']}×{native['width']}, "
              f"cfg={native['guidance_scale']}, steps={native['teacher_steps']}).")
    if "no_cfg" in list(args.modes) + list(args.teacher_modes):
        print("[WARN] no_cfg is known to fail for this URSA checkpoint. "
              "Outputs may be blank or blurry.")


def _load_prompts(prompt_file):
    """Return prompts from ``prompt_file``, or DEFAULT_PROMPTS when it is falsy.

    The file is read as UTF-8, one prompt per line. Blank lines and lines whose
    first non-whitespace character is ``#`` are skipped. Prompts are stripped.
    """
    if not prompt_file:
        return DEFAULT_PROMPTS
    with open(prompt_file, encoding="utf-8") as fh:
        stripped = (line.strip() for line in fh)
        # Test the stripped text, so indented "# ..." comments are skipped too.
        return [line for line in stripped if line and not line.startswith("#")]


def _ensure_weights(pipe, states, loaded, wanted):
    """Load ``states[wanted]`` into ``pipe.transformer`` unless already loaded.

    Returns ``wanted`` so the caller can track which weights are now active.
    """
    if loaded != wanted:
        pipe.transformer.load_state_dict(states[wanted], strict=True)
    pipe.transformer.eval()
    return wanted


def main():
    """Generate student (1-step) and teacher (multi-step) videos for every prompt × seed.

    Loads the teacher pipeline once and swaps transformer weights between the
    student checkpoint and a snapshot of the teacher weights. For each
    (prompt, seed), one MP4 is written per student mode and per teacher mode
    into ``--out_dir``.
    """
    args = parse_args()
    dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
    dtype = dtype_map[args.mixed_precision]
    device = torch.device("cuda", args.device) if torch.cuda.is_available() else torch.device("cpu")
    os.makedirs(args.out_dir, exist_ok=True)

    # -- Verified regime validation ----------------------------------------
    _warn_if_off_regime(args)

    # -- Load prompts -----------------------------------------------------
    prompts = _load_prompts(args.prompt_file)
    print(f"[eval] {len(prompts)} prompts × {len(args.seeds)} seeds "
          f"| student modes={args.modes} | teacher modes={args.teacher_modes}")

    # -- Load pipeline ---------------------------------------------------
    print(f"[eval] Loading pipeline from {args.teacher_ckpt} …")
    pipe = URSAPipeline.from_pretrained(
        args.teacher_ckpt, torch_dtype=dtype, trust_remote_code=True
    ).to(device)

    # -- Load student checkpoint -----------------------------------------
    print(f"[eval] Loading student weights from {args.student_ckpt} …")
    # Both weight snapshots live in host memory. The pipeline already holds one
    # live copy of the transformer on `device`, and load_state_dict copies
    # across devices, so there is no need for two more full copies on the GPU.
    # NOTE(review): assumes student.pt is a plain transformer state_dict, as
    # the strict load below requires.
    student_state = torch.load(args.student_ckpt, map_location="cpu", weights_only=True)
    # copy=True guarantees a real snapshot even when the pipeline is on CPU,
    # where a plain .cpu() would alias the live parameters.
    teacher_state = {
        k: v.detach().to("cpu", copy=True) for k, v in pipe.transformer.state_dict().items()
    }
    states = {"student": student_state, "teacher": teacher_state}
    # The pipeline currently holds the teacher weights that were just snapshotted.
    loaded = "teacher"

    # Common kwargs passed to every pipeline call
    gen_kwargs = dict(
        num_frames=args.num_frames,
        height=args.height,
        width=args.width,
        guidance_trunc=args.guidance_trunc,
        max_prompt_length=args.max_prompt_length,
        vae_batch_size=args.vae_batch_size,
    )
    # Mode → guidance_scale mapping
    #   no_cfg : single forward, no guidance
    #   cfg    : dual forward, eval_cfg_scale
    #   baked  : single forward, no guidance (student trained with guided KD)
    student_guidance = {
        "no_cfg": 1.0,
        "cfg": args.eval_cfg_scale,
        "baked": 1.0,
    }
    teacher_guidance = {
        "no_cfg": 1.0,
        "cfg": args.eval_cfg_scale,
    }

    # -- Evaluation loop -------------------------------------------------
    for idx, prompt in enumerate(prompts):
        p_slug = slug(prompt)
        print(f"\n[{idx+1}/{len(prompts)}] {prompt[:70]}")
        for seed in args.seeds:
            # ---- Student: selected modes --------------------------------
            for mode in args.modes:
                g_scale = student_guidance[mode]
                neg = DEFAULT_NEGATIVE_PROMPT if g_scale > 1 else None
                loaded = _ensure_weights(pipe, states, loaded, "student")
                with torch.no_grad():
                    frames = _gen(pipe, prompt, seed,
                                  guidance_scale=g_scale,
                                  num_inference_steps=1,
                                  negative_prompt=neg,
                                  device=device, **gen_kwargs)
                path = os.path.join(
                    args.out_dir,
                    f"{idx:02d}_s{seed}_{p_slug}_student_1step_{mode}.mp4",
                )
                frames_to_mp4(frames, path, fps=args.fps)
                print(f"  [student/{mode:6s}] seed={seed} scale={g_scale} → {path}")

            # ---- Teacher: reference videos ------------------------------
            for t_mode in args.teacher_modes:
                g_scale = teacher_guidance[t_mode]
                neg = DEFAULT_NEGATIVE_PROMPT if g_scale > 1 else None
                loaded = _ensure_weights(pipe, states, loaded, "teacher")
                with torch.no_grad():
                    frames = _gen(pipe, prompt, seed,
                                  guidance_scale=g_scale,
                                  num_inference_steps=args.teacher_steps,
                                  negative_prompt=neg,
                                  device=device, **gen_kwargs)
                path = os.path.join(
                    args.out_dir,
                    f"{idx:02d}_s{seed}_{p_slug}_teacher_{args.teacher_steps}step_{t_mode}.mp4",
                )
                frames_to_mp4(frames, path, fps=args.fps)
                print(f"  [teacher/{t_mode:6s}] seed={seed} scale={g_scale} "
                      f"steps={args.teacher_steps} → {path}")

    print(f"\n[eval] Done. Results in {args.out_dir}")
    _print_interpretation_guide(args)
def _print_interpretation_guide(args):
print(f"""
╔══════════════════════════════════════════════════════════════╗
β•‘ Interpretation guide for generated videos β•‘
╠══════════════════════════════════════════════════════════════╣
β•‘ student_1step_cfg : 1-step + CFG={args.eval_cfg_scale:<4} β•‘
β•‘ (verified working student mode) β•‘
β•‘ student_1step_baked : 1-step, guidance_scale=1 β•‘
β•‘ (for students trained with CFG KD) β•‘
β•‘ teacher_{args.teacher_steps}step_cfg : {args.teacher_steps}-step + CFG={args.eval_cfg_scale:<4} β•‘
β•‘ (verified working teacher mode) β•‘
╠══════════════════════════════════════════════════════════════╣
β•‘ NOTE: no_cfg (guidance_scale=1) is NOT a valid baseline β•‘
β•‘ for this URSA checkpoint β€” outputs are blank or blurry. β•‘
β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•""")
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()