Varad1707 committed on
Commit
ca58a44
·
verified ·
1 Parent(s): 5dad980

Upload run_wan_I2V_FB.py

Browse files
Files changed (1) hide show
  1. run_wan_I2V_FB.py +88 -0
run_wan_I2V_FB.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Run the Wan 2.1 image-to-video (I2V) pipeline with para-attention
first-block caching, timing both pipeline loading and inference, and
export the result to ``wan-i2v_fb.mp4``.

Requires a CUDA device; downloads the 14B 720p Diffusers checkpoint on
first run.
"""
import io
import time

import numpy as np
import torch
import torch.distributed as dist
from PIL import Image
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel

start_loading = time.time()
data_type = torch.bfloat16

# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"

# Image encoder and transformer run in bfloat16; the VAE is kept in
# float32 (as loaded below), presumably for decode stability.
image_encoder = CLIPVisionModel.from_pretrained(
    model_id, subfolder="image_encoder", torch_dtype=data_type
)
vae = AutoencoderKLWan.from_pretrained(
    model_id, subfolder="vae", torch_dtype=torch.float32
)
pipe = WanImageToVideoPipeline.from_pretrained(
    model_id, vae=vae, image_encoder=image_encoder, torch_dtype=data_type
)
# flow_shift=5.0 overrides the scheduler default for this resolution.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=5.0)
pipe.to("cuda")

# Imported after pipeline construction because apply_cache_on_pipe patches
# the live pipeline object.
from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe

# First-block cache: reuse previous-step transformer output whenever the
# first block's residual changes by less than the threshold.
apply_cache_on_pipe(pipe, residual_diff_threshold=0.1)

print("Pipeline Loaded.....")
loading_time = time.time() - start_loading

prompt = "Cars racing in slow motion"
negative_prompt = (
    "bright colors, overexposed, static, blurred details, subtitles, style, artwork, painting, picture, still, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, malformed limbs, fused fingers, still picture, cluttered background, three legs, many people in the background, walking backwards"
)

image = load_image(
    "https://storage.googleapis.com/falserverless/gallery/car_720p.png"
)

# max_area = 1024 * 1024
# aspect_ratio = image.height / image.width
# mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
# print(f"MOD VALUE :{mod_value}")
# height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
# width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
# NOTE(review): the image dimensions are used as-is; the pipeline expects
# height/width divisible by vae_scale_factor_spatial * patch_size — the
# source image is presumably already aligned, verify for other inputs.
width, height = image.size
# image = image.resize((width, height))

# Start measuring inference time
start_inference = time.time()

# Run the pipeline
output = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=81,
    guidance_scale=5.0,
    num_inference_steps=30,
    output_type="pil",
).frames[0]

# End of inference time measurement
inference_time = time.time() - start_inference

# Save output and print timing info.
# BUG FIX: dist.get_world_size() raises when no process group has been
# initialized (this script never calls dist.init_process_group), which
# crashed the run after inference finished.  Fall back to 1 GPU when
# torch.distributed is unavailable or uninitialized.
if dist.is_available() and dist.is_initialized():
    num_gpus = dist.get_world_size()
else:
    num_gpus = 1

output_path = "wan-i2v_fb.mp4"

# if dist.get_rank() == 0:
print(f"{'=' * 50}")
print(f"Device: {torch.cuda.get_device_name()}")
print(f"Number of GPUs: {num_gpus}")
print(f"Pipeline Loading Time: {loading_time:.2f} seconds")
print(f"Pipeline Inference Time: {inference_time:.2f} seconds")
print(f"Resolution: {width}x{height}")
print(f"{'=' * 50}")
# BUG FIX: the message previously named "wan-i2v.mp4" but the file was
# written to "wan-i2v_fb.mp4"; both now use the same path variable.
print(f"Saving video to {output_path}")
export_to_video(output, output_path, fps=16)