Varad1707 committed on
Commit
cd66daa
·
verified ·
1 Parent(s): 3447d56

Upload run_wan_I2V_CP+FB_GroupOffload.py

Browse files
Files changed (1) hide show
  1. run_wan_I2V_CP+FB_GroupOffload.py +140 -0
run_wan_I2V_CP+FB_GroupOffload.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Wan2.1 14B image-to-video inference with context parallelism,
first-block caching, and group offloading.

Launch with one process per GPU, e.g.:
    torchrun --nproc_per_node=<N> run_wan_I2V_CP+FB_GroupOffload.py

Rank 0 decodes the frames and writes ``wan-i2v.mp4``; the other ranks
participate in the parallel denoising only.
"""
import io  # currently unused; kept for downstream experimentation
import time

import torch
import torch.distributed as dist
import numpy as np
from PIL import Image  # currently unused; kept for downstream experimentation
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, WanTransformer3DModel
from diffusers.hooks.group_offloading import apply_group_offloading
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel, UMT5EncoderModel

# One process per GPU: rank maps 1:1 onto a CUDA device.
dist.init_process_group()
torch.cuda.set_device(dist.get_rank())

start_loading = time.time()

data_type = torch.bfloat16

# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"

# Image encoder and VAE are loaded in float32; the 14B transformer and the
# text encoder run in bfloat16.
image_encoder = CLIPVisionModel.from_pretrained(
    model_id, subfolder="image_encoder", torch_dtype=torch.float32
)
text_encoder = UMT5EncoderModel.from_pretrained(
    model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16
)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
transformer = WanTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)

# Stream the text encoder between CPU and GPU four blocks at a time to
# reduce peak VRAM during prompt encoding.
onload_device = torch.device("cuda")
offload_device = torch.device("cpu")
apply_group_offloading(
    text_encoder,
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="block_level",
    num_blocks_per_group=4,
)

pipe = WanImageToVideoPipeline.from_pretrained(
    model_id,
    vae=vae,
    transformer=transformer,
    text_encoder=text_encoder,
    image_encoder=image_encoder,
    torch_dtype=torch.bfloat16,
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=5.0)
pipe.to("cuda")

# Shard the denoising across all ranks (context parallelism).
from para_attn.context_parallel import init_context_parallel_mesh
from para_attn.context_parallel.diffusers_adapters import parallelize_pipe

parallelize_pipe(
    pipe,
    mesh=init_context_parallel_mesh(pipe.device.type),
)
end_loading = time.time()

# First-block cache: reuse transformer outputs between steps whose
# first-block residual changed by less than the threshold.
from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe

apply_cache_on_pipe(pipe, residual_diff_threshold=0.05)

print("Pipeline Loaded.....")
loading_time = end_loading - start_loading

prompt = (
    "Cars racing in slow motion"
)
negative_prompt = (
    "bright colors, overexposed, static, blurred details, subtitles, style, artwork, painting, picture, still, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, malformed limbs, fused fingers, still picture, cluttered background, three legs, many people in the background, walking backwards"
)

image = load_image(
    "https://storage.googleapis.com/falserverless/gallery/car_720p.png"
)

# NOTE(review): width/height are taken directly from the source image and are
# assumed to already be multiples of the model's spatial granularity
# (vae_scale_factor_spatial * patch_size) — confirm, or round down first.
width, height = image.size

# Measure inference separately from loading.
start_inference = time.time()

try:
    output = pipe(
        image=image,
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_frames=81,
        guidance_scale=5.0,
        num_inference_steps=30,
        # Only rank 0 decodes to PIL frames for saving; other ranks keep tensors.
        output_type="pil" if dist.get_rank() == 0 else "pt",
    ).frames[0]

    end_inference = time.time()
    inference_time = end_inference - start_inference

    # Save output and print timing info on rank 0 only.
    if dist.get_rank() == 0:
        print(f"{'=' * 50}")
        print(f"Device: {torch.cuda.get_device_name()}")
        print(f"Number of GPUs: {dist.get_world_size()}")
        print(f"Pipeline Loading Time: {loading_time:.2f} seconds")
        print(f"Pipeline Inference Time: {inference_time:.2f} seconds")
        print(f"Resolution: {width}x{height}")
        print(f"{'=' * 50}")
        print("Saving video to wan-i2v.mp4")
        export_to_video(output, "wan-i2v.mp4", fps=16)
finally:
    # Always tear down the process group, even if inference raised, so
    # ranks do not hang on a dead collective.
    dist.destroy_process_group()