test / wan_i2v.py
Varad1707's picture
Upload wan_i2v.py
95bb22a verified
import torch
import torch.distributed as dist
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline,WanTransformer3DModel
from diffusers.utils import export_to_video, load_image
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from transformers import CLIPVisionModel, UMT5EncoderModel
from para_attn.context_parallel import init_context_parallel_mesh
from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
import gc
import time
import logging
from PIL import Image
# Configure the root logger at INFO level so the logger.info(...) progress
# messages emitted throughout this file are shown.
logging.basicConfig(level=logging.INFO)
# Module-level logger shared by the pipeline class and the script entry point.
logger = logging.getLogger(__name__)
class WanI2V(WanImageToVideoPipeline):
    """Context-parallel wrapper around the Wan 2.1 image-to-video pipeline.

    Supports the 14B 480P and 720P checkpoints
    (Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers).

    The diffusers pipeline that actually runs is built in ``load_model`` and
    kept on ``self.pipe``; generation and warm-up calls are forwarded to it.
    Constructing an instance joins the ``torch.distributed`` process group,
    loads the model, applies the para_attn optimizations and runs a warm-up.
    """

    # Scheduler flow shift used for each supported checkpoint.
    _FLOW_SHIFT_BY_MODEL = {
        "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers": 3.0,
        "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers": 5.0,
    }

    def __init__(self, model_id, apply_cache=True):
        """Join the process group, then load, optimize and warm up the model.

        Args:
            model_id: Model ID of one of the supported checkpoints.
            apply_cache: When true, para_attn's first-block cache is applied
                to the pipeline in ``optimize_pipe``.

        Raises:
            ValueError: If ``model_id`` is not a supported checkpoint.
            Exception: Any failure while loading, optimizing or warming up is
                logged and re-raised unchanged.
        """
        # Validate before touching the process group or loading any weights,
        # so an unsupported ID fails immediately with a clear message.
        if model_id not in self._FLOW_SHIFT_BY_MODEL:
            supported = ", ".join(sorted(self._FLOW_SHIFT_BY_MODEL))
            raise ValueError(
                f"Unsupported model {model_id!r}; expected one of: {supported}"
            )

        logger.info(f"Initializing WanI2V pipeline with model {model_id}")
        # Only create the default process group if the caller has not already.
        if not dist.is_initialized():
            dist.init_process_group()
        # NOTE(review): uses the global rank as the CUDA device index, which
        # presumably assumes a single-node launch -- confirm for multi-node.
        torch.cuda.set_device(dist.get_rank())
        start_load_time = time.perf_counter()

        # `device` and `dtype` are read-only properties on diffusers'
        # DiffusionPipeline (a base class of this one), so they cannot be
        # assigned here; the values live under private names instead.
        self._device = "cuda"
        self._dtype = torch.bfloat16
        self.pipe = None
        self.model_id = model_id
        self.apply_cache = apply_cache
        self.flow_shift = self._FLOW_SHIFT_BY_MODEL[model_id]

        try:
            self.load_model()
            self.optimize_pipe()
            logger.info(f"Pipeline initialized {model_id} in {time.perf_counter() - start_load_time} seconds")
            self.warmup()
        except Exception as e:
            # logger.exception records the traceback; the bare raise keeps the
            # original exception and its traceback intact for the caller.
            logger.exception(f"Error initializing {model_id}: {e}")
            raise

    def load_model(self):
        """Load every model component and assemble the pipeline on the GPU.

        The text encoder and transformer are loaded in bfloat16, while the VAE
        and the CLIP image encoder are loaded in float32. The scheduler is
        replaced by a UniPC multistep scheduler configured with this
        checkpoint's flow shift.
        """
        self.text_encoder = UMT5EncoderModel.from_pretrained(
            self.model_id, subfolder="text_encoder", torch_dtype=self._dtype
        )
        self.vae = AutoencoderKLWan.from_pretrained(
            self.model_id, subfolder="vae", torch_dtype=torch.float32
        )
        self.transformer = WanTransformer3DModel.from_pretrained(
            self.model_id, subfolder="transformer", torch_dtype=self._dtype
        )
        self.image_encoder = CLIPVisionModel.from_pretrained(
            self.model_id, subfolder="image_encoder", torch_dtype=torch.float32
        )
        self.pipe = WanImageToVideoPipeline.from_pretrained(
            self.model_id,
            vae=self.vae,
            transformer=self.transformer,
            text_encoder=self.text_encoder,
            image_encoder=self.image_encoder,
            torch_dtype=self._dtype,
        )
        self.pipe.scheduler = UniPCMultistepScheduler.from_config(
            self.pipe.scheduler.config, flow_shift=self.flow_shift
        )
        self.pipe.to(self._device)

    def optimize_pipe(self):
        """Apply memory and parallelism optimizations to ``self.pipe``.

        Parallelizes the pipeline across the process group with para_attn's
        context parallelism and, when ``apply_cache`` is set, enables its
        first-block cache with a residual-difference threshold of 0.1.
        """
        # NOTE(review): gradient checkpointing targets training; generation
        # here runs under torch.inference_mode(), so this is presumably a
        # no-op -- confirm before removing.
        self.pipe.transformer.enable_gradient_checkpointing()
        self.pipe.enable_attention_slicing(slice_size="auto")
        parallelize_pipe(
            self.pipe,
            mesh=init_context_parallel_mesh(
                self.pipe.device.type,
            ),
        )
        if self.apply_cache:
            # Imported lazily so the cache adapter is only needed when used.
            from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe

            apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.1)

    def clear_memory(self):
        """Run garbage collection and release cached CUDA memory."""
        # Collect first: tensors kept alive only by unreachable objects must
        # be freed before the CUDA caching allocator can hand their blocks
        # back in empty_cache().
        gc.collect()
        torch.cuda.empty_cache()

    def generate_video(
        self,
        prompt: str,
        negative_prompt: str,
        image: Image.Image,
        num_frames: int = 81,
        guidance_scale: float = 5.0,
        num_inference_steps: int = 30,
        height: int = 576,
        width: int = 1024,
    ):
        """Generate one video from a conditioning image and a text prompt.

        Rank 0 asks the pipeline for PIL frames; every other rank asks for
        tensors ("pt").

        Args:
            prompt: Text describing the desired video.
            negative_prompt: Text describing what to avoid.
            image: Conditioning image.
            num_frames: Number of frames to generate.
            guidance_scale: Guidance scale forwarded to the pipeline.
            num_inference_steps: Number of denoising steps.
            height: Output height in pixels.
            width: Output width in pixels.

        Returns:
            The frames of the first generated video (``.frames[0]``). When the
            frames are tensors they are returned as a list of CPU tensors.
        """
        is_main_rank = dist.get_rank() == 0
        with torch.inference_mode():
            output = self.pipe(
                image=image,
                prompt=prompt,
                negative_prompt=negative_prompt,
                height=height,
                width=width,
                num_frames=num_frames,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                output_type="pil" if is_main_rank else "pt",
            ).frames[0]
        if is_main_rank:
            self.clear_memory()
        if isinstance(output[0], torch.Tensor):
            # Tensor.cpu() returns the tensor itself when it already lives on
            # the CPU, so no device check is needed.
            output = [frame.cpu() for frame in output]
        return output

    def warmup(self):
        """Run one full generation on a sample image and log how long it took."""
        logger.info("Running Warm Up!")
        prompt = "A car driving on a road"
        negative_prompt = "blurry, low quality, dark"
        image = load_image("https://storage.googleapis.com/falserverless/gallery/car_720p.png")
        start_time = time.perf_counter()
        with torch.inference_mode():
            self.pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=image,
                num_inference_steps=30,
                height=576,
                width=1024,
                num_frames=81,
                guidance_scale=5.0,
            )
        self.get_matrix(start_time, time.perf_counter(), 576, 1024)
        logger.info("Warm Up Completed!")
        self.clear_memory()

    def shutdown(self):
        """Destroy the default process group if one is initialized."""
        # Guarded so that calling shutdown twice, or after a failed start-up,
        # does not raise.
        if dist.is_initialized():
            dist.destroy_process_group()

    def get_matrix(self, start_time: float, end_time: float, height: int, width: int):
        """Log the elapsed inference time together with the frame size.

        Args:
            start_time: Timestamp, in seconds, taken before inference.
            end_time: Timestamp, in seconds, taken after inference; must come
                from the same clock as ``start_time``.
            height: Frame height to report.
            width: Frame width to report.
        """
        logger.info("-"*20)
        logger.info(f"inference time : {end_time - start_time}")
        logger.info(f"height : {height}")
        logger.info(f"width : {width}")
        logger.info("-"*20)
if __name__ == "__main__":
    # Benchmark driver: builds the pipeline once, then times repeated
    # generations for each sample input at a single resolution.

    # Negative prompt shared by every sample input.
    NEGATIVE_PROMPT = "bright colors, overexposed, static, blurred details, subtitles, style, artwork, painting, picture, still, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, malformed limbs, fused fingers, still picture, cluttered background, three legs, many people in the background, walking backwards"

    inputs = {
        "1": {
            "prompt": "Cars racing in slow motion",
            "negative_prompt": NEGATIVE_PROMPT,
            "image": load_image("https://storage.googleapis.com/falserverless/gallery/car_720p.png"),
        },
        "2": {
            "prompt": "A cat in a car",
            "negative_prompt": NEGATIVE_PROMPT,
            "image": load_image("https://fancypawscatclinic.com/uploads/SiteAssets/426/images/services/payment-options-cat-720px.jpg"),
        },
    }
    RESOLUTION_CONFIG = {
        "Horizontal": {
            "height": 576,
            "width": 1024,
        },
        "Vertical": {
            "height": 1024,
            "width": 576,
        },
        "Square": {
            "height": 768,
            "width": 768,
        },
    }
    # Number of timed generations per input, kept independent of how many
    # fields an input entry happens to have.
    RUNS_PER_INPUT = 3

    resolution = "Horizontal"
    height = RESOLUTION_CONFIG[resolution]["height"]
    width = RESOLUTION_CONFIG[resolution]["width"]

    # Bound to a separate name so the WanI2V class itself is not shadowed.
    pipeline = WanI2V("Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", apply_cache=True)
    try:
        for sample in inputs.values():
            for _ in range(RUNS_PER_INPUT):
                start_time = time.perf_counter()
                # Pass the selected resolution explicitly so the size that is
                # generated always matches the size that is logged.
                pipeline.generate_video(
                    sample["prompt"],
                    sample["negative_prompt"],
                    sample["image"],
                    height=height,
                    width=width,
                )
                end_time = time.perf_counter()
                pipeline.get_matrix(start_time, end_time, height, width)
    finally:
        # Leave the process group even if a generation fails.
        pipeline.shutdown()