Varad1707 committed on
Commit
25eb67d
·
verified ·
1 Parent(s): cd66daa

Upload wan_i2v_pipeline.py

Browse files
Files changed (1) hide show
  1. wan_i2v_pipeline.py +204 -0
wan_i2v_pipeline.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.distributed as dist
3
+ import time
4
+ from dataclasses import dataclass
5
+ from typing import Union
6
+ from pathlib import Path
7
+
8
+ from diffusers import (
9
+ AutoencoderKLWan,
10
+ WanImageToVideoPipeline,
11
+ WanTransformer3DModel,
12
+ UniPCMultistepScheduler
13
+ )
14
+ from diffusers.utils import export_to_video, load_image
15
+ from transformers import CLIPVisionModel, UMT5EncoderModel
16
+ from PIL import Image
17
+
18
@dataclass
class WanPipelineConfig:
    """Configuration for the Wan 2.1 image-to-video pipeline.

    Defaults target the 14B 720P I2V checkpoint at a 1024x576, 81-frame clip.
    """

    model_id: str = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"  # HF Hub repo id
    data_type: torch.dtype = torch.bfloat16  # dtype for transformer / text encoder
    device: str = "cuda"                     # device the assembled pipeline moves to
    width: int = 1024                        # output frame width in pixels
    height: int = 576                        # output frame height in pixels
    num_frames: int = 81                     # frames per generated clip
    guidance_scale: float = 5.0              # guidance scale passed to the pipeline
    num_inference_steps: int = 30            # denoising steps per generation
    fps: int = 16                            # playback fps used when exporting video
29
+
30
class WanI2VPipeline:
    """Distributed image-to-video pipeline built on Wan 2.1 (diffusers).

    Typical flow: construct with a WanPipelineConfig (this initializes the
    torch.distributed process group), call load_models(), then
    generate_video(). Only rank 0 decodes PIL frames and writes the file.
    """

    def __init__(self, config: WanPipelineConfig):
        self.config = config
        self.pipe = None  # populated by load_models()
        self.setup_distributed()

    def setup_distributed(self):
        """Initialize the process group and bind this process to its GPU."""
        if not dist.is_initialized():
            dist.init_process_group()
        # Prefer LOCAL_RANK (set by torchrun) so multi-node launches pick the
        # GPU index on the local host; falling back to the global rank is only
        # correct for single-node runs.
        import os
        torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", dist.get_rank())))

    def load_models(self):
        """Load all model components and assemble the pipeline onto self.pipe.

        Raises:
            RuntimeError: if any component fails to load; the original
                exception is chained as __cause__.
        """
        try:
            print("Loading models...")
            start_time = time.time()

            # Image encoder and VAE are kept in float32; the large transformer
            # and the text encoder use the configured (lower-precision) dtype.
            image_encoder = CLIPVisionModel.from_pretrained(
                self.config.model_id,
                subfolder="image_encoder",
                torch_dtype=torch.float32,
            )

            text_encoder = UMT5EncoderModel.from_pretrained(
                self.config.model_id,
                subfolder="text_encoder",
                torch_dtype=self.config.data_type,
            )

            vae = AutoencoderKLWan.from_pretrained(
                self.config.model_id,
                subfolder="vae",
                torch_dtype=torch.float32,
            )

            transformer = WanTransformer3DModel.from_pretrained(
                self.config.model_id,
                subfolder="transformer",
                torch_dtype=self.config.data_type,
            )

            # Assemble the pipeline from the explicitly-loaded components.
            self.pipe = WanImageToVideoPipeline.from_pretrained(
                self.config.model_id,
                vae=vae,
                transformer=transformer,
                text_encoder=text_encoder,
                image_encoder=image_encoder,
                torch_dtype=self.config.data_type,
            )

            # flow_shift=5.0 is the value recommended upstream for 720p Wan
            # generation (UniPC flow-matching schedule).
            self.pipe.scheduler = UniPCMultistepScheduler.from_config(
                self.pipe.scheduler.config,
                flow_shift=5.0,
            )
            self.pipe.to(self.config.device)

            self._apply_optimizations()

            print(f"Models loaded in {time.time() - start_time:.2f} seconds")

        except Exception as e:
            # Chain the cause so the real failure is not masked.
            raise RuntimeError(f"Failed to load models: {str(e)}") from e

    def _apply_optimizations(self):
        """Apply para_attn context parallelism and first-block caching."""
        # Imported lazily so the module is importable without para_attn.
        from para_attn.context_parallel import init_context_parallel_mesh
        from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
        from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe

        # Shard attention across the process group (context parallelism).
        parallelize_pipe(
            self.pipe,
            mesh=init_context_parallel_mesh(self.pipe.device.type),
        )

        # First-block cache: reuse work when residual differences are below
        # the threshold — see para_attn docs for tuning.
        apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.1)

    def generate_video(
        self,
        image_path: Union[str, Path],
        prompt: str,
        negative_prompt: str,
        output_path: str = "output.mp4"
    ) -> None:
        """Generate a video conditioned on a single image; rank 0 saves it.

        Args:
            image_path: Path or URL of the conditioning image.
            prompt: Text prompt for the generation.
            negative_prompt: Text describing what to avoid.
            output_path: Destination video path (written by rank 0 only).

        Raises:
            RuntimeError: if models are not loaded or generation fails
                (original exception chained).
        """
        # Fail fast with a clear message instead of an opaque NoneType error.
        if self.pipe is None:
            raise RuntimeError("Models are not loaded; call load_models() first")
        try:
            image = self._prepare_image(image_path)

            print("Generating video...")
            start_time = time.time()

            # Only rank 0 needs decoded PIL frames; other ranks keep tensors.
            output = self.pipe(
                image=image,
                prompt=prompt,
                negative_prompt=negative_prompt,
                height=self.config.height,
                width=self.config.width,
                num_frames=self.config.num_frames,
                guidance_scale=self.config.guidance_scale,
                num_inference_steps=self.config.num_inference_steps,
                output_type="pil" if dist.get_rank() == 0 else "pt",
            ).frames[0]

            if dist.get_rank() == 0:
                self._save_video(output, output_path)
                self._print_statistics(start_time)

        except Exception as e:
            raise RuntimeError(f"Video generation failed: {str(e)}") from e
        finally:
            self._cleanup()

    def _prepare_image(self, image_path: Union[str, Path]) -> Image.Image:
        """Load the input image and resize it to the configured resolution."""
        # str() so pathlib.Path inputs work: diffusers' load_image expects a
        # string (it probes for an "http" prefix) or a PIL image.
        image = load_image(str(image_path))
        return image.resize((self.config.width, self.config.height))

    def _save_video(self, frames, output_path: str):
        """Export generated frames to a video file at the configured fps."""
        # Tensor frames must be moved to CPU before export.
        if isinstance(frames[0], torch.Tensor):
            frames = [frame.cpu() if frame.device.type == 'cuda' else frame for frame in frames]
        export_to_video(frames, output_path, fps=self.config.fps)
        print(f"Video saved to {output_path}")

    def _print_statistics(self, start_time: float):
        """Print a summary of the generation run."""
        print(f"{'=' * 50}")
        print(f"Device: {torch.cuda.get_device_name()}")
        print(f"Number of GPUs: {dist.get_world_size()}")
        print(f"Resolution: {self.config.width}x{self.config.height}")
        print(f"Generation Time: {time.time() - start_time:.2f} seconds")
        print(f"{'=' * 50}")

    def _cleanup(self):
        """Release cached GPU memory and run the garbage collector."""
        torch.cuda.empty_cache()
        import gc
        gc.collect()

    def __del__(self):
        """Tear down the process group when the object is destroyed."""
        # Guard: during interpreter shutdown torch.distributed internals may
        # already be torn down; __del__ must never raise.
        try:
            if dist.is_initialized():
                dist.destroy_process_group()
        except Exception:
            pass
182
+
183
# Example usage:
if __name__ == "__main__":
    # Build the pipeline with default settings and load all model weights.
    wan = WanI2VPipeline(WanPipelineConfig())
    wan.load_models()

    positive_prompt = "Cars racing in slow motion"
    avoid = (
        "bright colors, overexposed, static, blurred details, subtitles, "
        "style, artwork, painting, picture, still, overall gray, worst quality, "
        "low quality, JPEG compression residue, ugly, incomplete, extra fingers, "
        "poorly drawn hands, poorly drawn faces, deformed, disfigured, malformed limbs, "
        "fused fingers, still picture, cluttered background, three legs, "
        "many people in the background, walking backwards"
    )

    # Generate and (on rank 0) export the clip.
    wan.generate_video(
        image_path="car_720p.png",
        prompt=positive_prompt,
        negative_prompt=avoid,
        output_path="wan-i2v.mp4",
    )