#!/usr/bin/env python3
"""
VBVR-Wan2.1 Image-to-Video Inference Example

Generate a video from a reference image using the VBVR-Wan2.1 model.

Usage:
    python example.py --model_path /path/to/VBVR-Wan2.1
"""

import os
import argparse

import numpy as np
import torch
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel

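# ──────────────────────── Configuration ────────────────────────
parser = argparse.ArgumentParser(
    description="Generate a video from a reference image with VBVR-Wan2.1."
)
parser.add_argument("--model_path", type=str, default="VBVR-Wan2.1",
                    help="Path to the VBVR-Wan2.1 model directory")
parser.add_argument("--image_path", type=str, default=None,
                    help="Input image path (default: assets/first_frame.png inside model_path)")
parser.add_argument("--output_path", type=str, default="output.mp4",
                    help="Path of the MP4 file to write")
parser.add_argument("--max_area", type=int, default=720 * 1280,
                    help="Maximum pixel area (height * width) used to choose the output resolution")
parser.add_argument("--num_frames", type=int, default=81,
                    help="Number of frames to generate (num_frames - 1 should be divisible by 4)")
parser.add_argument("--num_inference_steps", type=int, default=50,
                    help="Number of denoising steps")
parser.add_argument("--guidance_scale", type=float, default=5.0,
                    help="Classifier-free guidance scale")
parser.add_argument("--seed", type=int, default=42,
                    help="Random seed for reproducible generation")
args = parser.parse_args()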

model_path = args.model_path
image_path = args.image_path or os.path.join(model_path, "assets", "first_frame.png")
output_path = args.output_path
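
# Illustrative sketch (hypothetical helpers, not called anywhere in this script).
# Assumption: the Wan2.1 VAE compresses time by a factor of 4 and encodes the
# first frame on its own, so valid clip lengths have the form 4k + 1. Recent
# diffusers Wan pipelines are believed to log a warning and adjust other values
# the same way _snap_num_frames does. Under this assumption the default of 81
# frames maps to 21 latent frames, roughly 5 seconds of video at fps=16.
def _snap_num_frames(num_frames, temporal_factor=4):
    """Map num_frames to (num_frames // temporal_factor) * temporal_factor + 1."""
    if num_frames % temporal_factor != 1:
        num_frames = num_frames // temporal_factor * temporal_factor + 1
    return num_frames


def _num_latent_frames(num_frames, temporal_factor=4):
    """One latent frame for the first frame plus one per group of temporal_factor frames."""
    return (num_frames - 1) // temporal_factor + 1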

# Prompt
prompt = (
    "The scene contains two types of shapes, each type has three shapes of "
    "different sizes arranged randomly. Keep all shapes unchanged in appearance "
    "(type, size, and color). Only rearrange their positions: first group the "
    "shapes by type, then within each group, sort the shapes from smallest to "
    "largest (left to right), and arrange all shapes in a single horizontal "
    "line from left to right."
)
negative_prompt = (
    "Bright tones, overexposed, static, blurred details, subtitles, style, "
    "works, paintings, images, static, overall gray, worst quality, low quality, "
    "JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, "
    "poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, "
    "still picture, messy background, three legs, many people in the background, "
    "walking backwards"
)

# ──────────────────────── Load Pipeline ────────────────────────

print(f"Loading model from: {model_path}")

image_encoder = CLIPVisionModel.from_pretrained(
    model_path, subfolder="image_encoder", torch_dtype=torch.float32
)
vae = AutoencoderKLWan.from_pretrained(
    model_path, subfolder="vae", torch_dtype=torch.float32
)
pipe = WanImageToVideoPipeline.from_pretrained(
    model_path,
    vae=vae,
    image_encoder=image_encoder,
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

# ──────────────────────── Load Image ────────────────────────

print(f"Loading image: {image_path}")
image = load_image(image_path)

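# Choose a resolution that preserves the input aspect ratio, targets about
# max_area pixels, and is a multiple of the VAE spatial downsampling factor
# times the transformer patch size.
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(args.max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(args.max_area / aspect_ratio)) // mod_value * mod_value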
image = image.resize((width, height))
print(f"Image resized to: {width}x{height} (max_area={args.max_area})")
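
# Illustrative sketch (hypothetical helper, not called anywhere in this script):
# the same sizing rule written as a pure function. It returns (width, height),
# the order PIL's Image.resize expects. Assumption: mod_value = 16 for Wan2.1,
# i.e. a VAE spatial factor of 8 times a transformer patch size of 2. Under that
# assumption, at the default max_area a 1280x720 input keeps its size, a square
# input becomes 960x960, and a 640x480 input becomes 1104x816.
def _snap_resolution(img_width, img_height, max_area, mod_value=16):
    """Aspect-preserving size near max_area pixels, both sides a multiple of mod_value."""
    aspect_ratio = img_height / img_width
    height = round((max_area * aspect_ratio) ** 0.5) // mod_value * mod_value
    width = round((max_area / aspect_ratio) ** 0.5) // mod_value * mod_value
    return width, height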

# ──────────────────────── Generate Video ────────────────────────

print(f"Generating video: {args.num_frames} frames @ {width}x{height}, "
      f"{args.num_inference_steps} steps")
generator = torch.Generator(device="cuda").manual_seed(args.seed)

output = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=args.num_frames,
    num_inference_steps=args.num_inference_steps,
    guidance_scale=args.guidance_scale,
    generator=generator,
)

export_to_video(output.frames[0], output_path, fps=16)
print(f"Video saved to: {output_path}")