Varad1707 committed on
Commit
3447d56
·
verified ·
1 Parent(s): 10496cb

Upload run_wan_I2V_CP+FB_Trans.py

Browse files
Files changed (1) hide show
  1. run_wan_I2V_CP+FB_Trans.py +111 -0
run_wan_I2V_CP+FB_Trans.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.distributed as dist
3
+ import numpy as np
4
+ import time
5
+ from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
6
+ from diffusers.utils import export_to_video, load_image
7
+ from transformers import CLIPVisionModel
8
+ from PIL import Image
9
+ from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
10
+ import io
11
+
12
+ dist.init_process_group()
13
+
14
+ torch.cuda.set_device(dist.get_rank())
15
+
16
+ start_loading = time.time()
17
+
18
+ data_type = torch.bfloat16
19
+
20
+ # Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
21
+ model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
22
+
23
+ image_encoder = CLIPVisionModel.from_pretrained(
24
+ model_id, subfolder="image_encoder", torch_dtype=data_type
25
+ )
26
+ vae = AutoencoderKLWan.from_pretrained(
27
+ model_id, subfolder="vae", torch_dtype=torch.float32
28
+ )
29
+ pipe = WanImageToVideoPipeline.from_pretrained(
30
+ model_id, vae=vae, image_encoder=image_encoder, torch_dtype=data_type
31
+ )
32
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=5.0)
33
+ pipe.to("cuda")
34
+
35
+ # Import and apply parallel attention
36
+ from para_attn.context_parallel import init_context_parallel_mesh
37
+ from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
38
+
39
+ parallelize_pipe(
40
+ pipe,
41
+ mesh=init_context_parallel_mesh(
42
+ pipe.device.type,
43
+ ),
44
+ )
45
+ end_loading = time.time()
46
+
47
+ from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe,apply_cache_on_transformer
48
+ # apply_cache_on_pipe(pipe , residual_diff_threshold=0.05)
49
+ apply_cache_on_transformer(pipe.transformer)
50
+
51
+ # pipe.enable_model_cpu_offload(gpu_id=dist.get_rank())
52
+ # pipe.enable_vae_tiling()
53
+
54
+ # torch._inductor.config.reorder_for_compute_comm_overlap = True
55
+ # pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")
56
+ print("Pipeline Loaded.....")
57
+ loading_time = end_loading - start_loading
58
+
59
+ prompt = (
60
+ "Cars racing in slow motion"
61
+ )
62
+ negative_prompt = (
63
+ "bright colors, overexposed, static, blurred details, subtitles, style, artwork, painting, picture, still, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, malformed limbs, fused fingers, still picture, cluttered background, three legs, many people in the background, walking backwards"
64
+ )
65
+
66
+ image = load_image(
67
+ "https://storage.googleapis.com/falserverless/gallery/car_720p.png"
68
+ )
69
+
70
+ # max_area = 1024 * 1024
71
+ # aspect_ratio = image.height / image.width
72
+ # mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
73
+ # print(f"MOD VALUE :{mod_value}")
74
+ # height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
75
+ # width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
76
+ width , height = image.size
77
+ # image = image.resize((width, height))
78
+
79
+ # Start measuring inference time
80
+ start_inference = time.time()
81
+
82
+ # Run the pipeline
83
+ output = pipe(
84
+ image=image,
85
+ prompt=prompt,
86
+ negative_prompt=negative_prompt,
87
+ height=height,
88
+ width=width,
89
+ num_frames=81,
90
+ guidance_scale=5.0,
91
+ num_inference_steps=30,
92
+ output_type="pil" if dist.get_rank() == 0 else "pt",
93
+ ).frames[0]
94
+
95
+ # End of inference time measurement
96
+ end_inference = time.time()
97
+ inference_time = end_inference - start_inference
98
+
99
+ # Save output and print timing info
100
+ if dist.get_rank() == 0:
101
+ print(f"{'=' * 50}")
102
+ print(f"Device: {torch.cuda.get_device_name()}")
103
+ print(f"Number of GPUs: {dist.get_world_size()}")
104
+ print(f"Pipeline Loading Time: {loading_time:.2f} seconds")
105
+ print(f"Pipeline Inference Time: {inference_time:.2f} seconds")
106
+ print(f"Resolution: {width}x{height}")
107
+ print(f"{'=' * 50}")
108
+ print("Saving video to wan-i2v.mp4")
109
+ export_to_video(output, "wan-i2v.mp4", fps=16)
110
+
111
+ dist.destroy_process_group()