Upload run_wan_I2V_CP+FB_GroupOffload.py
Browse files
run_wan_I2V_CP+FB_GroupOffload.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Distributed Wan2.1 image-to-video generation with context parallelism,
# first-block caching, and group offloading.
import torch
import torch.distributed as dist
import numpy as np
import time
# NOTE(review): UMT5EncoderModel is defined in `transformers`, not `diffusers`;
# importing it from diffusers raises ImportError, so it is removed from that
# import line (the correct transformers import below is kept).
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, WanTransformer3DModel
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel, UMT5EncoderModel
from PIL import Image
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
import io
from diffusers.hooks.group_offloading import apply_group_offloading
# --- Distributed setup -------------------------------------------------
# One process per GPU; the default process group is created from the
# torchrun-provided environment.
dist.init_process_group()

# Pin this process to its own CUDA device.
local_rank = dist.get_rank()
torch.cuda.set_device(local_rank)

# Wall-clock timer covering model/pipeline loading.
start_loading = time.time()

# Compute dtype for the large models.
data_type = torch.bfloat16
# Wan 2.1 I2V checkpoints:
#   Wan-AI/Wan2.1-I2V-14B-480P-Diffusers (480p)
#   Wan-AI/Wan2.1-I2V-14B-720P-Diffusers (720p)
model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"

# Load each component explicitly so dtypes can be set per component:
# the CLIP image encoder and the VAE are kept in float32 — presumably for
# numerical stability (TODO confirm) — while the text encoder and the
# transformer run in bfloat16.
image_encoder = CLIPVisionModel.from_pretrained(
    model_id, subfolder="image_encoder", torch_dtype=torch.float32
)
text_encoder = UMT5EncoderModel.from_pretrained(
    model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16
)
vae = AutoencoderKLWan.from_pretrained(
    model_id, subfolder="vae", torch_dtype=torch.float32
)
transformer = WanTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
# Stream the text encoder's blocks between CPU and GPU in groups of four,
# so only part of the model is resident on the GPU at any time.
apply_group_offloading(
    text_encoder,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=4,
)
# Assemble the full I2V pipeline from the explicitly loaded components.
pipe = WanImageToVideoPipeline.from_pretrained(
    model_id,
    vae=vae,
    transformer=transformer,
    text_encoder=text_encoder,
    image_encoder=image_encoder,
    torch_dtype=torch.bfloat16,
)

# UniPC scheduler; flow_shift=5.0 matches the diffusers recommendation for
# the 720p Wan checkpoints.
pipe.scheduler = UniPCMultistepScheduler.from_config(
    pipe.scheduler.config, flow_shift=5.0
)

# NOTE(review): pipe.to("cuda") moves every component to the GPU — confirm
# this does not conflict with the group offloading applied to text_encoder.
pipe.to("cuda")
# Shard attention across all ranks (context parallelism) via para_attn.
from para_attn.context_parallel import init_context_parallel_mesh
from para_attn.context_parallel.diffusers_adapters import parallelize_pipe

mesh = init_context_parallel_mesh(pipe.device.type)
parallelize_pipe(pipe, mesh=mesh)

end_loading = time.time()  # loading finished (models + parallelization)
# First-block cache: skip transformer re-evaluation on steps where the first
# block's residual changes by less than the threshold.
from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe

apply_cache_on_pipe(pipe, residual_diff_threshold=0.05)

# Alternatives kept for reference:
# pipe.enable_model_cpu_offload(gpu_id=dist.get_rank())
# pipe.enable_vae_tiling()
# torch._inductor.config.reorder_for_compute_comm_overlap = True
# pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")

print("Pipeline Loaded.....")
loading_time = end_loading - start_loading
# Generation prompts.
prompt = (
    "Cars racing in slow motion"
)
negative_prompt = (
    "bright colors, overexposed, static, blurred details, subtitles, style, artwork, painting, picture, still, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, malformed limbs, fused fingers, still picture, cluttered background, three legs, many people in the background, walking backwards"
)

# Conditioning image for I2V.
image = load_image(
    "https://storage.googleapis.com/falserverless/gallery/car_720p.png"
)

# The Wan pipeline requires height/width divisible by
# vae_scale_factor_spatial * patch_size. The original code took image.size
# verbatim, which crashes for images whose dimensions are not a multiple of
# that value; snap the dimensions down to the nearest valid multiple. For the
# 720p sample image (already a multiple) this is a no-op.
width, height = image.size
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
width = width // mod_value * mod_value
height = height // mod_value * mod_value
if (width, height) != image.size:
    image = image.resize((width, height))
# Time only the denoising/inference phase.
t0 = time.time()

# Rank 0 decodes to PIL frames (needed for saving the video); all other
# ranks keep tensors to avoid redundant decoding.
generation_kwargs = dict(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=81,
    guidance_scale=5.0,
    num_inference_steps=30,
    output_type="pil" if dist.get_rank() == 0 else "pt",
)
output = pipe(**generation_kwargs).frames[0]

inference_time = time.time() - t0
# Only rank 0 reports timings and writes the video; other ranks just exit.
if dist.get_rank() == 0:
    banner = "=" * 50
    print(banner)
    print(f"Device: {torch.cuda.get_device_name()}")
    print(f"Number of GPUs: {dist.get_world_size()}")
    print(f"Pipeline Loading Time: {loading_time:.2f} seconds")
    print(f"Pipeline Inference Time: {inference_time:.2f} seconds")
    print(f"Resolution: {width}x{height}")
    print(banner)
    print("Saving video to wan-i2v.mp4")
    export_to_video(output, "wan-i2v.mp4", fps=16)

dist.destroy_process_group()