# Source: VBVR-Wan2.1 repository, example.py (commit 43f5d8f, "Update README", by wruisi)
#!/usr/bin/env python3
"""
VBVR-Wan2.1 Image-to-Video Inference Example
Generate a video from a reference image using the VBVR-Wan2.1 model.
Usage:
python example.py --model_path /path/to/VBVR-Wan2.1
"""
import os
import argparse
import numpy as np
import torch
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel
# ─────────────── Configuration ───────────────
def _build_cli_parser():
    """Build the command-line parser for this example script."""
    p = argparse.ArgumentParser()
    p.add_argument("--model_path", type=str, default="VBVR-Wan2.1")
    p.add_argument("--image_path", type=str, default=None,
                   help="Input image path (default: assets/first_frame.png inside model_path)")
    p.add_argument("--output_path", type=str, default="output.mp4")
    p.add_argument("--max_area", type=int, default=720 * 1280,
                   help="Max pixel area for resolution calculation")
    p.add_argument("--num_frames", type=int, default=81)
    p.add_argument("--num_inference_steps", type=int, default=50)
    p.add_argument("--guidance_scale", type=float, default=5.0)
    p.add_argument("--seed", type=int, default=42)
    return p


parser = _build_cli_parser()
args = parser.parse_args()
model_path = args.model_path
# When --image_path is omitted, fall back to the sample frame shipped with the model.
image_path = args.image_path or os.path.join(model_path, "assets", "first_frame.png")
output_path = args.output_path

# Positive text prompt describing the desired shape rearrangement.
prompt = (
    "The scene contains two types of shapes, each type has three shapes of "
    "different sizes arranged randomly. Keep all shapes unchanged in appearance "
    "(type, size, and color). Only rearrange their positions: first group the "
    "shapes by type, then within each group, sort the shapes from smallest to "
    "largest (left to right), and arrange all shapes in a single horizontal "
    "line from left to right."
)
# Negative prompt listing artifacts to steer the sampler away from.
negative_prompt = (
    "Bright tones, overexposed, static, blurred details, subtitles, style, "
    "works, paintings, images, static, overall gray, worst quality, low quality, "
    "JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, "
    "poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, "
    "still picture, messy background, three legs, many people in the background, "
    "walking backwards"
)
# ──────────────────────── Load Pipeline ────────────────────────
print(f"Loading model from: {model_path}")

# The image encoder and VAE are loaded in float32, while the pipeline
# itself (transformer/text encoder weights) runs in bfloat16.
clip_encoder = CLIPVisionModel.from_pretrained(
    model_path, subfolder="image_encoder", torch_dtype=torch.float32
)
wan_vae = AutoencoderKLWan.from_pretrained(
    model_path, subfolder="vae", torch_dtype=torch.float32
)
pipe = WanImageToVideoPipeline.from_pretrained(
    model_path,
    vae=wan_vae,
    image_encoder=clip_encoder,
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")
# ──────────────────────── Load Image ────────────────────────
print(f"Loading image: {image_path}")
image = load_image(image_path)

# Choose the largest resolution that fits within --max_area while keeping the
# input aspect ratio; each side is rounded down to a multiple of the model's
# spatial granularity (VAE spatial scale factor x transformer patch size).
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
raw_h = round(np.sqrt(args.max_area * aspect_ratio))
raw_w = round(np.sqrt(args.max_area / aspect_ratio))
height = raw_h - raw_h % mod_value
width = raw_w - raw_w % mod_value
image = image.resize((width, height))
print(f"Image resized to: {width}x{height} (max_area={args.max_area})")
# ──────────────────────── Generate Video ────────────────────────
print(f"Generating video: {args.num_frames} frames @ {width}x{height}, "
      f"{args.num_inference_steps} steps")

# A fixed-seed CUDA generator makes the sampling deterministic for a given seed.
generator = torch.Generator(device="cuda").manual_seed(args.seed)

result = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=args.num_frames,
    num_inference_steps=args.num_inference_steps,
    guidance_scale=args.guidance_scale,
    generator=generator,
)

# `frames` is a batch of frame sequences; export the first sequence at 16 fps.
export_to_video(result.frames[0], output_path, fps=16)
print(f"Video saved to: {output_path}")