| import torch, warnings, glob, os, types |
| import numpy as np |
| from PIL import Image |
| from einops import repeat, reduce |
| from typing import Optional, Union |
| from dataclasses import dataclass |
| from modelscope import snapshot_download |
| from einops import rearrange |
| import numpy as np |
| from PIL import Image |
| from tqdm import tqdm |
| from typing import Optional |
| from typing_extensions import Literal |
| import torch.nn.functional as F |
| from PIL import Image, ImageOps |
|
|
| from diffsynth.core import ModelConfig |
| from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit, PipelineUnitRunner |
| from diffsynth.models import ModelManager, load_state_dict |
| from diffsynth.models.wan_video_dit_mvid import WanModel, RMSNorm, sinusoidal_embedding_1d |
| from diffsynth.models.wan_video_dit_s2v import rope_precompute |
| from diffsynth.models.wan_video_text_encoder import WanTextEncoder, T5RelativeEmbedding, T5LayerNorm |
| from diffsynth.models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample |
| from diffsynth.models.wan_video_image_encoder import WanImageEncoder |
| from diffsynth.models.wan_video_vace import VaceWanModel |
| from diffsynth.models.wan_video_motion_controller import WanMotionControllerModel |
| from diffsynth.schedulers.flow_match import FlowMatchScheduler |
| from diffsynth.prompters import WanPrompter |
| from diffsynth.vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm |
| from diffsynth.lora import GeneralLoRALoader |
| from diffsynth.utils.data import save_video |
|
|
|
|
|
|
| import random |
| from torchvision.transforms import Compose, Normalize, ToTensor |
|
|
|
|
class WanVideoPipeline(BasePipeline):
    """Wan video pipeline with reference-image ("face") conditioning.

    Conditioning inputs are prepared by the ``PipelineUnit`` chain in ``self.units``.
    When reference images are supplied, their VAE latents (``ref_images_latents``)
    are appended to the video latents along dim 2.

    During inference they are:
    - zeroed for the extra guidance passes,
    - re-imposed after every scheduler step,
    - removed before VAE decoding.

    During training the matching trailing predictions are dropped before the loss.
    """

    def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=None):
        super().__init__(
            device=device, torch_dtype=torch_dtype,
            height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1
        )
        self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
        self.prompter = WanPrompter(tokenizer_path=tokenizer_path)
        self.text_encoder: WanTextEncoder = None
        self.image_encoder: WanImageEncoder = None
        self.dit: WanModel = None
        self.dit2: WanModel = None
        self.vae: WanVideoVAE = None
        self.motion_controller: WanMotionControllerModel = None
        self.vace: VaceWanModel = None
        # Models loaded to the device for the denoising loop.
        # The second tuple is used once the loop switches to dit2.
        self.in_iteration_models = ("dit", "motion_controller", "vace")
        self.in_iteration_models_2 = ("dit2", "motion_controller", "vace")
        self.unit_runner = PipelineUnitRunner()
        # Units run in order; later units read values (e.g. "latents") produced by earlier ones.
        self.units = [
            WanVideoUnit_ShapeChecker(),
            WanVideoUnit_NoiseInitializer(),
            WanVideoUnit_PromptEmbedder(),
            WanVideoUnit_InputVideoEmbedder(),
            WanVideoUnit_ImageEmbedderVAE(),
            WanVideoUnit_ImageEmbedderCLIP(),
            WanVideoUnit_ImageEmbedderFused(),
            WanVideoUnit_VideoEmbedderFused(),
            WanVideoUnit_RefEmbedderFused(),
            WanVideoUnit_FunControl(),
            WanVideoUnit_FunReference(),
            WanVideoUnit_FunCameraControl(),
            WanVideoUnit_SpeedControl(),
            WanVideoUnit_UnifiedSequenceParallel(),
            WanVideoUnit_TeaCache(),
            WanVideoUnit_CfgMerger(),
        ]
        self.model_fn = model_fn_wan_video

    def extrac_ref_latents(self, ref_images, vae, device, dtype, min_value=-1., max_value=1.):
        """VAE-encode reference images one by one and concatenate the latents along dim 2.

        Each image goes through these steps before encoding as a single-frame video:
        - converted with ``np.array`` to float32,
        - rescaled linearly from [0, 255] to [min_value, max_value],
        - permuted from (H, W, C) to (C, 1, H, W).

        Args:
            ref_images: Iterable of images; assumes RGB PIL images or (H, W, C) arrays -- TODO confirm.
            vae: VAE exposing ``encode(list_of_videos, device=...)``.
            device: Device the pixel tensors are moved to before encoding.
            dtype: dtype the pixel tensors are cast to before encoding.
            min_value: Lower bound of the target pixel range.
            max_value: Upper bound of the target pixel range.

        Returns:
            The per-image latents concatenated along dim 2.
        """
        ref_vae_latents = []
        for img in ref_images:
            img = torch.Tensor(np.array(img, dtype=np.float32))
            img = img.to(dtype=dtype, device=device)
            img = img * ((max_value - min_value) / 255) + min_value
            img_vae_latent = vae.encode([img.permute(2, 0, 1).unsqueeze(1)], device=device)
            ref_vae_latents.append(img_vae_latent)
        return torch.cat(ref_vae_latents, dim=2)

    def load_lora(self, module, path, alpha=1):
        """Load the LoRA state dict at ``path`` and apply it to ``module`` with strength ``alpha``."""
        loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
        lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
        loader.load(module, lora, alpha=alpha)

    def training_loss(self, **inputs):
        """Compute the weighted flow-matching loss at one randomly sampled timestep.

        ``inputs`` must contain ``input_latents``, ``noise`` and the keyword arguments
        consumed by ``self.model_fn``. Optional ``min_timestep_boundary`` and
        ``max_timestep_boundary`` (fractions of ``num_train_timesteps``; defaults 0 and 1)
        restrict the sampled timestep range.

        If ``ref_images_latents`` is present and not None:
        - it is appended to the noisy latents along dim 2;
        - with probability ``inputs["args"].zero_face_ratio``, zeros of the same shape
          are appended instead;
        - the matching trailing predictions are dropped before the loss.
        """
        max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * self.scheduler.num_train_timesteps)
        min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * self.scheduler.num_train_timesteps)
        timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
        timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)

        inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep)
        # WanVideoUnit_RefEmbedderFused returns {} without reference images, so the key may be missing.
        ref_images_latents = inputs.get("ref_images_latents")
        if ref_images_latents is not None:
            # Occasionally drop the reference condition, presumably so the model also
            # learns the zeroed-reference branch used for guidance in __call__.
            if random.random() < inputs["args"].zero_face_ratio:
                ref_images_latents = torch.zeros_like(ref_images_latents)
            inputs["latents"] = torch.cat([inputs["latents"], ref_images_latents], dim=2)
        training_target = self.scheduler.training_target(inputs["input_latents"], inputs["noise"], timestep)

        noise_pred = self.model_fn(**inputs, timestep=timestep)
        if ref_images_latents is not None:
            # Keep only the video frames. Slicing by the kept length avoids the
            # ``[:-0]`` pitfall, which would yield an empty tensor.
            num_video_frames = noise_pred.shape[2] - ref_images_latents.shape[2]
            noise_pred = noise_pred[:, :, :num_video_frames]

        loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
        loss = loss * self.scheduler.training_weight(timestep)
        return loss

    def _enable_dit_vram_management(self, dit, num_persistent_param_in_dit, vram_limit):
        """Wrap one DiT (``self.dit`` or ``self.dit2``) for VRAM management.

        Both DiTs go through this helper so they share the same module map.
        """
        dtype = next(iter(dit.parameters())).dtype
        device = "cpu" if vram_limit is not None else self.device
        enable_vram_management(
            dit,
            module_map = {
                torch.nn.Linear: AutoWrappedLinear,
                torch.nn.Conv3d: AutoWrappedModule,
                torch.nn.LayerNorm: WanAutoCastLayerNorm,
                RMSNorm: AutoWrappedModule,
                torch.nn.Conv2d: AutoWrappedModule,
                torch.nn.Conv1d: AutoWrappedModule,
                torch.nn.Embedding: AutoWrappedModule,
            },
            module_config = dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device=device,
                computation_dtype=self.torch_dtype,
                computation_device=self.device,
            ),
            max_num_param=num_persistent_param_in_dit,
            overflow_module_config = dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device="cpu",
                computation_dtype=self.torch_dtype,
                computation_device=self.device,
            ),
            vram_limit=vram_limit,
        )

    def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5):
        """Wrap every loaded sub-model so its layers are offloaded and onloaded on demand.

        Args:
            num_persistent_param_in_dit: If given, passed as ``max_num_param`` to the DiTs,
                and ``vram_limit`` is ignored (set to None).
            vram_limit: Available VRAM budget. Defaults to ``self.get_vram()``;
                ``vram_buffer`` is subtracted from it.
            vram_buffer: Amount reserved out of ``vram_limit``.
        """
        self.vram_management_enabled = True
        if num_persistent_param_in_dit is not None:
            vram_limit = None
        else:
            if vram_limit is None:
                vram_limit = self.get_vram()
            vram_limit = vram_limit - vram_buffer
        if self.text_encoder is not None:
            dtype = next(iter(self.text_encoder.parameters())).dtype
            enable_vram_management(
                self.text_encoder,
                module_map = {
                    torch.nn.Linear: AutoWrappedLinear,
                    torch.nn.Embedding: AutoWrappedModule,
                    T5RelativeEmbedding: AutoWrappedModule,
                    T5LayerNorm: AutoWrappedModule,
                },
                module_config = dict(
                    offload_dtype=dtype,
                    offload_device="cpu",
                    onload_dtype=dtype,
                    onload_device="cpu",
                    computation_dtype=self.torch_dtype,
                    computation_device=self.device,
                ),
                vram_limit=vram_limit,
            )
        if self.dit is not None:
            self._enable_dit_vram_management(self.dit, num_persistent_param_in_dit, vram_limit)
        if self.dit2 is not None:
            self._enable_dit_vram_management(self.dit2, num_persistent_param_in_dit, vram_limit)
        if self.vae is not None:
            dtype = next(iter(self.vae.parameters())).dtype
            enable_vram_management(
                self.vae,
                module_map = {
                    torch.nn.Linear: AutoWrappedLinear,
                    torch.nn.Conv2d: AutoWrappedModule,
                    RMS_norm: AutoWrappedModule,
                    CausalConv3d: AutoWrappedModule,
                    Upsample: AutoWrappedModule,
                    torch.nn.SiLU: AutoWrappedModule,
                    torch.nn.Dropout: AutoWrappedModule,
                },
                module_config = dict(
                    offload_dtype=dtype,
                    offload_device="cpu",
                    onload_dtype=dtype,
                    onload_device=self.device,
                    computation_dtype=self.torch_dtype,
                    computation_device=self.device,
                ),
            )
        if self.image_encoder is not None:
            dtype = next(iter(self.image_encoder.parameters())).dtype
            enable_vram_management(
                self.image_encoder,
                module_map = {
                    torch.nn.Linear: AutoWrappedLinear,
                    torch.nn.Conv2d: AutoWrappedModule,
                    torch.nn.LayerNorm: AutoWrappedModule,
                },
                module_config = dict(
                    offload_dtype=dtype,
                    offload_device="cpu",
                    onload_dtype=dtype,
                    onload_device="cpu",
                    computation_dtype=dtype,
                    computation_device=self.device,
                ),
            )
        if self.motion_controller is not None:
            dtype = next(iter(self.motion_controller.parameters())).dtype
            enable_vram_management(
                self.motion_controller,
                module_map = {
                    torch.nn.Linear: AutoWrappedLinear,
                },
                module_config = dict(
                    offload_dtype=dtype,
                    offload_device="cpu",
                    onload_dtype=dtype,
                    onload_device="cpu",
                    computation_dtype=dtype,
                    computation_device=self.device,
                ),
            )
        if self.vace is not None:
            # Use VACE's own parameter dtype. Previously ``dtype`` leaked in from
            # whichever branch above ran last, or was undefined.
            dtype = next(iter(self.vace.parameters())).dtype
            device = "cpu" if vram_limit is not None else self.device
            enable_vram_management(
                self.vace,
                module_map = {
                    torch.nn.Linear: AutoWrappedLinear,
                    torch.nn.Conv3d: AutoWrappedModule,
                    torch.nn.LayerNorm: AutoWrappedModule,
                    RMSNorm: AutoWrappedModule,
                },
                module_config = dict(
                    offload_dtype=dtype,
                    offload_device="cpu",
                    onload_dtype=dtype,
                    onload_device=device,
                    computation_dtype=self.torch_dtype,
                    computation_device=self.device,
                ),
                vram_limit=vram_limit,
            )

    def initialize_usp(self):
        """Set up torch.distributed and xfuser for sequence parallelism.

        - Initialise torch.distributed with NCCL and ``env://``.
        - Initialise xfuser's parallel groups, with every rank in the Ulysses
          sequence-parallel group.
        - Select the CUDA device matching this rank.
        """
        import torch.distributed as dist
        from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
        dist.init_process_group(backend="nccl", init_method="env://")
        init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
        initialize_model_parallel(
            sequence_parallel_degree=dist.get_world_size(),
            ring_degree=1,
            ulysses_degree=dist.get_world_size(),
        )
        torch.cuda.set_device(dist.get_rank())

    def enable_usp(self):
        """Patch the DiT(s) to use the xDiT unified-sequence-parallel forward functions."""
        from xfuser.core.distributed import get_sequence_parallel_world_size
        # NOTE(review): relative import, while the rest of the file imports ``diffsynth.*``
        # absolutely; this only resolves if the module sits one level below the package
        # that contains ``distributed`` -- confirm.
        from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward

        for block in self.dit.blocks:
            block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
        self.dit.forward = types.MethodType(usp_dit_forward, self.dit)
        if self.dit2 is not None:
            for block in self.dit2.blocks:
                block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
            self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2)
        self.sp_size = get_sequence_parallel_world_size()
        self.use_unified_sequence_parallel = True

    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device] = "cuda",
        model_configs: list[ModelConfig] = [],
        tokenizer_config: ModelConfig = ModelConfig(model_id="/root/paddle_job/workspace/qizipeng/wanx_pretrainedmodels/Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"),
        audio_processor_config: ModelConfig = None,
        redirect_common_files: bool = True,
        use_usp=False,
    ):
        """Build a pipeline and load the models described by ``model_configs``.

        Args:
            torch_dtype: Default dtype for loaded models and computation.
            device: Default device for loaded models and computation.
            model_configs: Model files to load. Their ``model_id`` may be rewritten in place
                when ``redirect_common_files`` is True. The list itself is not mutated.
            tokenizer_config: Location of the text tokenizer. NOTE(review): the default is a
                single shared instance, and ``download_if_necessary`` runs on it each call.
            audio_processor_config: Optional Wav2Vec2 processor location.
            redirect_common_files: Map well-known shared files to one canonical model id.
            use_usp: Initialise and enable unified sequence parallelism.
        """
        if redirect_common_files:
            redirect_dict = {
                "models_t5_umt5-xxl-enc-bf16.pth": "Wan-AI/Wan2.1-T2V-1.3B",
                "Wan2.1_VAE.pth": "Wan-AI/Wan2.1-T2V-1.3B",
                "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth": "Wan-AI/Wan2.1-I2V-14B-480P",
            }
            for model_config in model_configs:
                if model_config.origin_file_pattern is None or model_config.model_id is None:
                    continue
                if model_config.origin_file_pattern in redirect_dict and model_config.model_id != redirect_dict[model_config.origin_file_pattern]:
                    print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to ({redirect_dict[model_config.origin_file_pattern]}, {model_config.origin_file_pattern}). You can use `redirect_common_files=False` to disable file redirection.")
                    model_config.model_id = redirect_dict[model_config.origin_file_pattern]

        pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
        if use_usp:
            pipe.initialize_usp()

        model_manager = ModelManager()
        for model_config in model_configs:
            model_config.download_if_necessary(use_usp=use_usp)
            model_manager.load_model(
                model_config.path,
                device=model_config.offload_device or device,
                torch_dtype=model_config.offload_dtype or torch_dtype
            )

        pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder")
        # ``index=2`` may return a list of two DiTs (two-expert setups) or a single model.
        dit = model_manager.fetch_model("wan_video_dit", index=2)
        if isinstance(dit, list):
            pipe.dit, pipe.dit2 = dit
        else:
            pipe.dit = dit
        pipe.vae = model_manager.fetch_model("wan_video_vae")

        if pipe.vae is not None:
            pipe.height_division_factor = pipe.vae.upsampling_factor * 2
            pipe.width_division_factor = pipe.vae.upsampling_factor * 2

        tokenizer_config.download_if_necessary(use_usp=use_usp)
        pipe.prompter.fetch_models(pipe.text_encoder)
        # Honour ``tokenizer_config``. Fall back to the previously hard-coded local
        # umt5-xxl tokenizer only when the configured path cannot be resolved.
        # NOTE(review): the default ``tokenizer_config`` uses ``/root/paddle_job/...``
        # while this fallback uses ``/root/paddlejob/...`` -- one is probably a typo.
        tokenizer_path = tokenizer_config.path
        if not isinstance(tokenizer_path, str) or not os.path.exists(tokenizer_path):
            tokenizer_path = '/root/paddlejob/workspace/qizipeng/wanx_pretrainedmodels/Wan2.2-TI2V-5B/google/umt5-xxl'
        pipe.prompter.fetch_tokenizer(tokenizer_path)

        if audio_processor_config is not None:
            audio_processor_config.download_if_necessary(use_usp=use_usp)
            from transformers import Wav2Vec2Processor
            pipe.audio_processor = Wav2Vec2Processor.from_pretrained(audio_processor_config.path)

        if use_usp:
            pipe.enable_usp()
        return pipe

    @torch.no_grad()
    def __call__(
        self,
        args,

        prompt: str,
        negative_prompt: Optional[str] = "",

        input_image: Optional[Image.Image] = None,

        end_image: Optional[Image.Image] = None,

        input_video: Optional[list[Image.Image]] = None,
        input_pre_video: Optional[list[Image.Image]] = None,
        ref_images: Optional[list[Image.Image]] = None,
        prev_latent=None,
        denoising_strength: Optional[float] = 1.0,

        input_audio: Optional[str] = None,
        audio_sample_rate: Optional[int] = 16000,
        s2v_pose_video: Optional[list[Image.Image]] = None,

        control_video: Optional[list[Image.Image]] = None,
        reference_image: Optional[Image.Image] = None,

        camera_control_direction: Optional[Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"]] = None,
        camera_control_speed: Optional[float] = 1/54,
        camera_control_origin: Optional[tuple] = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0),

        vace_video: Optional[list[Image.Image]] = None,
        vace_video_mask: Optional[Image.Image] = None,
        vace_reference_image: Optional[Image.Image] = None,
        vace_scale: Optional[float] = 1.0,

        seed: Optional[int] = None,
        rand_device: Optional[str] = "cpu",

        height: Optional[int] = 480,
        width: Optional[int] = 832,
        num_frames=81,

        cfg_scale: Optional[float] = 5.0,
        cfg_scale_face: Optional[float] = 5.0,
        cfg_merge: Optional[bool] = False,

        switch_DiT_boundary: Optional[float] = 0.875,

        num_inference_steps: Optional[int] = 50,
        sigma_shift: Optional[float] = 5.0,

        motion_bucket_id: Optional[int] = None,

        tiled: Optional[bool] = True,
        tile_size: Optional[tuple[int, int]] = (30, 52),
        tile_stride: Optional[tuple[int, int]] = (15, 26),

        sliding_window_size: Optional[int] = None,
        sliding_window_stride: Optional[int] = None,

        tea_cache_l1_thresh: Optional[float] = None,
        tea_cache_model_id: Optional[str] = "",

        progress_bar_cmd=tqdm,
        num_ref_images: Optional[int] = None,
    ):
        """Run the full denoising loop and decode the result.

        ``args`` is forwarded as the first positional argument of ``self.model_fn``.
        ``prev_latent`` is accepted but not used here.

        Returns:
            ``(video, latents)``: the decoded video (via ``vae_output_to_video``) and the
            final latents. The latents have any VACE reference frame and the appended
            reference-image frames removed.
        """
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)

        inputs_posi = {
            "prompt": prompt,
            "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps,
        }
        inputs_nega = {
            "negative_prompt": negative_prompt,
            "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps,
        }
        inputs_shared = {
            "input_image": input_image,
            "end_image": end_image,
            "input_video": input_video, "denoising_strength": denoising_strength,
            "input_pre_video": input_pre_video,
            "ref_images": ref_images,
            "control_video": control_video, "reference_image": reference_image,
            "camera_control_direction": camera_control_direction, "camera_control_speed": camera_control_speed, "camera_control_origin": camera_control_origin,
            "vace_video": vace_video, "vace_video_mask": vace_video_mask, "vace_reference_image": vace_reference_image, "vace_scale": vace_scale,
            "seed": seed, "rand_device": rand_device,
            "height": height, "width": width, "num_frames": num_frames,
            "cfg_scale": cfg_scale, "cfg_merge": cfg_merge,
            "sigma_shift": sigma_shift,
            "motion_bucket_id": motion_bucket_id,
            "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
            "sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride,
            "input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video,
            "num_ref_images": num_ref_images,
            "batch_size": 1
        }
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)

        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            # Switch to the second DiT once the timestep falls below the boundary.
            if timestep.item() < switch_DiT_boundary * self.scheduler.num_train_timesteps and self.dit2 is not None and models["dit"] is not self.dit2:
                self.load_models_to_device(self.in_iteration_models_2)
                models["dit"] = self.dit2

            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)

            noise_pred_posi = self.model_fn(args, **models, **inputs_shared, **inputs_posi, timestep=timestep)
            if cfg_scale != 1.0:
                if cfg_merge:
                    # Positive and negative branches were batched together (presumably by
                    # WanVideoUnit_CfgMerger -- TODO confirm); apply standard CFG.
                    noise_pred_posi, noise_pred_nega = noise_pred_posi.chunk(2, dim=0)
                    noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
                else:
                    # Zero the reference latents (in place) for the two extra passes.
                    if 'ref_images_latents' in inputs_shared:
                        inputs_shared['latents'][:, :, -inputs_shared["ref_images_latents"].shape[2]:] = torch.zeros_like(inputs_shared['ref_images_latents'])
                    # Positive prompt, no reference.
                    noise_pred_nega_face = self.model_fn(args, **models, **inputs_shared, **inputs_posi, timestep=timestep)
                    # Negative prompt, no reference.
                    noise_all_eng = self.model_fn(args, **models, **inputs_shared, **inputs_nega, timestep=timestep)
                    # NOTE(review): cfg_scale multiplies the reference term (posi - nega_face) and
                    # cfg_scale_face multiplies the text term (nega_face - all_neg); the names look
                    # swapped relative to these roles -- confirm intended.
                    noise_pred = noise_all_eng + cfg_scale * (noise_pred_posi - noise_pred_nega_face) + cfg_scale_face * (noise_pred_nega_face - noise_all_eng)
            else:
                noise_pred = noise_pred_posi

            inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"])

            # Re-impose the clean reference latents after every step.
            if "ref_images_latents" in inputs_shared:
                inputs_shared["latents"][:, :, -inputs_shared["ref_images_latents"].shape[2]:] = inputs_shared["ref_images_latents"]

        # Drop the leading VACE reference frame (NoiseInitializer adds one).
        if vace_reference_image is not None:
            inputs_shared["latents"] = inputs_shared["latents"][:, :, 1:]

        # Drop the trailing reference-image frames before decoding.
        if "ref_images_latents" in inputs_shared:
            inputs_shared["latents"] = inputs_shared["latents"][:, :, :-inputs_shared["ref_images_latents"].shape[2]]
        self.load_models_to_device(['vae'])
        video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        video = self.vae_output_to_video(video)
        self.load_models_to_device([])

        return video, inputs_shared["latents"]
|
|
|
|
|
|
class WanVideoUnit_ShapeChecker(PipelineUnit):
    """Normalise the requested height, width and frame count via ``pipe.check_resize_height_width``."""

    def __init__(self):
        super().__init__(input_params=("height", "width", "num_frames"))

    def process(self, pipe: WanVideoPipeline, height, width, num_frames):
        checked_height, checked_width, checked_frames = pipe.check_resize_height_width(height, width, num_frames)
        return dict(height=checked_height, width=checked_width, num_frames=checked_frames)
|
|
|
|
|
|
class WanVideoUnit_NoiseInitializer(PipelineUnit):
    """Sample the initial latent noise tensor for the requested video size.

    The latent has ``(num_frames - 1) // 4 + 1`` frames, plus one extra frame when a
    VACE reference image is given. In that case the last noise frame is rotated to
    the front.
    """

    def __init__(self):
        super().__init__(input_params=("height", "width", "num_frames", "seed", "rand_device", "vace_reference_image", "batch_size"))

    def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image, batch_size = 1):
        has_vace_reference = vace_reference_image is not None
        latent_frames = (num_frames - 1) // 4 + 1 + (1 if has_vace_reference else 0)
        factor = pipe.vae.upsampling_factor
        noise_shape = (batch_size, pipe.vae.model.z_dim, latent_frames, height // factor, width // factor)
        noise = pipe.generate_noise(noise_shape, seed=seed, rand_device=rand_device)
        if has_vace_reference:
            # Move the final noise frame to index 0 along the temporal axis.
            noise = torch.concat((noise[:, :, -1:], noise[:, :, :-1]), dim=2)
        return {"noise": noise}
| |
|
|
|
|
class WanVideoUnit_InputVideoEmbedder(PipelineUnit):
    """VAE-encode the input videos and produce the starting latents.

    ``input_video`` is iterated as a batch of videos, one latent per element,
    concatenated along dim 0. Without input videos the latents are the raw noise.
    In training mode the clean latents are returned alongside the noise. Otherwise
    they are noised at the first scheduler timestep.
    ``vace_reference_image`` is accepted but not used here.
    """

    def __init__(self):
        super().__init__(
            input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image"),
            onload_model_names=("vae",)
        )

    def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image):
        if input_video is None:
            return {"latents": noise}
        pipe.load_models_to_device(["vae"])
        encode_options = dict(device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        encoded_videos = [
            pipe.vae.encode(pipe.preprocess_video(video), **encode_options).to(dtype=pipe.torch_dtype, device=pipe.device)
            for video in input_video
        ]
        input_latents = torch.cat(encoded_videos, dim=0)
        if pipe.scheduler.training:
            return {"latents": noise, "input_latents": input_latents}
        noisy_latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
        return {"latents": noisy_latents}
|
|
|
|
|
|
class WanVideoUnit_PromptEmbedder(PipelineUnit):
    """Encode the positive and negative prompts into the ``context`` tensor.

    Runs separately for the positive and the negative branch. ``prompt`` may be a
    single string or a sequence of strings (one per batch element). The per-prompt
    embeddings are concatenated along dim 0.
    """

    def __init__(self):
        super().__init__(
            seperate_cfg=True,
            input_params_posi={"prompt": "prompt", "positive": "positive"},
            input_params_nega={"prompt": "negative_prompt", "positive": "positive"},
            onload_model_names=("text_encoder",)
        )

    def process(self, pipe: WanVideoPipeline, prompt, positive) -> dict:
        # A bare string used to be iterated character by character. The default
        # negative prompt "" then yielded an empty list and torch.cat raised.
        if isinstance(prompt, str):
            prompt = [prompt]
        # NOTE(review): forces the text encoder onto pipe.device regardless of VRAM
        # management settings -- confirm this is intended.
        pipe.text_encoder = pipe.text_encoder.to(pipe.device)
        prompt_emb_list = [
            pipe.prompter.encode_prompt(prompt_, positive=positive, device=pipe.device)
            for prompt_ in prompt
        ]
        prompt_emb = torch.cat(prompt_emb_list, dim=0)
        return {"context": prompt_emb}
|
|
|
|
|
|
class WanVideoUnit_ImageEmbedder(PipelineUnit):
    """
    Deprecated: combined CLIP + VAE conditioning on the first (and optional last) frame.

    Not registered in ``WanVideoPipeline.units``. WanVideoUnit_ImageEmbedderCLIP and
    WanVideoUnit_ImageEmbedderVAE produce ``clip_feature`` and ``y`` in a similar way.
    """
    def __init__(self):
        super().__init__(
            input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
            onload_model_names=("image_encoder", "vae")
        )

    def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
        """Return ``{"clip_feature", "y"}``, or ``{}`` without an input image or image encoder."""
        if input_image is None or pipe.image_encoder is None:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
        clip_context = pipe.image_encoder.encode_image([image])
        # Per-frame mask at 1/8 spatial resolution: 1 marks a conditioned frame (first frame, plus last if end_image).
        msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
        msk[:, 1:] = 0
        if end_image is not None:
            end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
            # First frame, zero frames, then the end frame along the temporal axis (dim 1).
            vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
            if pipe.dit.has_image_pos_emb:
                clip_context = torch.concat([clip_context, pipe.image_encoder.encode_image([end_image])], dim=1)
            msk[:, -1:] = 1
        else:
            vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)

        # Repeat the first mask frame 4x so the frame count divides by 4, then fold groups of 4 into the leading axis.
        # Assumes num_frames - 1 is divisible by 4 (see time_division_factor/remainder) -- TODO confirm.
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
        msk = msk.transpose(1, 2)[0]

        y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
        # Stack the 4 mask channels in front of the VAE latent channels.
        y = torch.concat([msk, y])
        y = y.unsqueeze(0)
        clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
        # Re-cast: concat with the float32 mask may have promoted y's dtype.
        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"clip_feature": clip_context, "y": y}
|
|
|
|
|
|
class WanVideoUnit_ImageEmbedderCLIP(PipelineUnit):
    """Produce ``clip_feature`` from the input image (and, if supported, the end image).

    Skipped when there is no input image, no image encoder, or the DiT does not
    require a CLIP embedding. The end-image embedding is appended along dim 1 only
    when ``pipe.dit.has_image_pos_emb`` is set.
    """

    def __init__(self):
        super().__init__(
            input_params=("input_image", "end_image", "height", "width"),
            onload_model_names=("image_encoder",)
        )

    def process(self, pipe: WanVideoPipeline, input_image, end_image, height, width):
        wants_clip = input_image is not None and pipe.image_encoder is not None and pipe.dit.require_clip_embedding
        if not wants_clip:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        target_size = (width, height)
        start_frame = pipe.preprocess_image(input_image.resize(target_size)).to(pipe.device)
        clip_context = pipe.image_encoder.encode_image([start_frame])
        if end_image is not None:
            end_frame = pipe.preprocess_image(end_image.resize(target_size)).to(pipe.device)
            if pipe.dit.has_image_pos_emb:
                end_context = pipe.image_encoder.encode_image([end_frame])
                clip_context = torch.concat([clip_context, end_context], dim=1)
        return {"clip_feature": clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)}
| |
|
|
|
|
class WanVideoUnit_ImageEmbedderVAE(PipelineUnit):
    """Build the VAE image-conditioning tensor ``y``.

    The first frame (and the last, when ``end_image`` is given) is VAE-encoded with
    zero frames in between. A 4-channel frame mask is prepended to the latent
    channels, and a batch dimension is added.
    """

    def __init__(self):
        super().__init__(
            input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
            onload_model_names=("vae",)
        )

    def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
        if input_image is None or not pipe.dit.require_vae_embedding:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        target_size = (width, height)
        mask_h, mask_w = height // 8, width // 8
        first_frame = pipe.preprocess_image(input_image.resize(target_size)).to(pipe.device)

        # 1 marks a conditioned frame: the first one, and the last one when end_image is given.
        frame_mask = torch.ones(1, num_frames, mask_h, mask_w, device=pipe.device)
        frame_mask[:, 1:] = 0
        segments = [first_frame.transpose(0, 1)]
        if end_image is None:
            segments.append(torch.zeros(3, num_frames - 1, height, width).to(first_frame.device))
        else:
            last_frame = pipe.preprocess_image(end_image.resize(target_size)).to(pipe.device)
            segments.append(torch.zeros(3, num_frames - 2, height, width).to(first_frame.device))
            segments.append(last_frame.transpose(0, 1))
            frame_mask[:, -1:] = 1
        vae_input = torch.concat(segments, dim=1)

        # Repeat the first mask frame 4x, then fold every group of 4 frames into the leading axis.
        head = torch.repeat_interleave(frame_mask[:, 0:1], repeats=4, dim=1)
        frame_mask = torch.concat([head, frame_mask[:, 1:]], dim=1)
        frame_mask = frame_mask.view(1, frame_mask.shape[1] // 4, 4, mask_h, mask_w).transpose(1, 2)[0]

        encoded = pipe.vae.encode(
            [vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)],
            device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
        )[0]
        encoded = encoded.to(dtype=pipe.torch_dtype, device=pipe.device)
        # The final cast matters: concatenating with the float32 mask promotes the dtype.
        y = torch.concat([frame_mask, encoded]).unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"y": y}
|
|
|
|
|
|
class WanVideoUnit_ImageEmbedderFused(PipelineUnit):
    """
    VAE-encode the input image and write it into the first latent frame (in place).
    Only active when the DiT fuses VAE embeddings into the latents, as in Wan-AI/Wan2.2-TI2V-5B.
    """

    def __init__(self):
        super().__init__(
            input_params=("input_image", "latents", "height", "width", "tiled", "tile_size", "tile_stride"),
            onload_model_names=("vae",)
        )

    def process(self, pipe: WanVideoPipeline, input_image, latents, height, width, tiled, tile_size, tile_stride):
        if input_image is None or not pipe.dit.fuse_vae_embedding_in_latents:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        resized = input_image.resize((width, height))
        frame = pipe.preprocess_image(resized).transpose(0, 1)
        first_frame_latents = pipe.vae.encode([frame], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        latents[:, :, 0: 1] = first_frame_latents
        return {
            "latents": latents,
            "fuse_vae_embedding_in_latents": True,
            "first_frame_latents": first_frame_latents,
        }
|
|
|
|
class WanVideoUnit_VideoEmbedderFused(PipelineUnit):
    """
    VAE-encode a preceding video clip (``input_pre_video``) and write its latents into
    the leading latent frames (in place).
    Only active when the DiT fuses VAE embeddings into the latents, as in Wan-AI/Wan2.2-TI2V-5B.
    """

    def __init__(self):
        super().__init__(
            input_params=("input_pre_video", "latents", "height", "width", "tiled", "tile_size", "tile_stride"),
            onload_model_names=("vae",)
        )

    def process(self, pipe: WanVideoPipeline, input_pre_video, latents, height, width, tiled, tile_size, tile_stride):
        if input_pre_video is None or not pipe.dit.fuse_vae_embedding_in_latents:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        pre_video_tensor = pipe.preprocess_video(input_pre_video)
        prev_video_latents = pipe.vae.encode(
            pre_video_tensor, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
        ).to(dtype=pipe.torch_dtype, device=pipe.device)
        num_prev_frames = prev_video_latents.shape[2]
        latents[:, :, :num_prev_frames] = prev_video_latents
        return {
            "latents": latents,
            "fuse_vae_embedding_in_latents": True,
            "prev_video_latents": prev_video_latents,
        }
|
|
|
|
class WanVideoUnit_RefEmbedderFused(PipelineUnit):
    """Encode per-sample reference images and append them after the video latents.

    ``ref_images`` is iterated per batch element; each element's images are encoded
    by ``pipe.extrac_ref_latents``. The per-sample results are stacked along dim 0 and
    appended to ``latents`` along dim 2. ``height``, ``width`` and ``num_ref_images``
    are accepted but not used here.
    """

    def __init__(self):
        super().__init__(
            input_params=("ref_images", "latents", "height", "width", "tiled", "tile_size", "tile_stride", "num_ref_images"),
            onload_model_names=("vae",)
        )

    def process(self, pipe: WanVideoPipeline, ref_images, latents, height, width, tiled, tile_size, tile_stride, num_ref_images):
        if ref_images is None or not pipe.dit.fuse_vae_embedding_in_latents:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        per_sample_latents = [
            pipe.extrac_ref_latents(sample_refs, pipe.vae, device=pipe.device, dtype=pipe.torch_dtype)[0][None]
            for sample_refs in ref_images
        ]
        ref_images_latents = torch.concat(per_sample_latents, dim=0)
        combined_latents = torch.concat([latents, ref_images_latents], dim=2)
        return {
            "latents": combined_latents,
            "fuse_vae_embedding_in_latents": True,
            "ref_images_latents": ref_images_latents,
        }
|
|
|
|
class WanVideoUnit_FunReference(PipelineUnit):
    """Encode ``reference_image`` with the VAE, and with the CLIP image encoder when one is loaded.

    Returns ``reference_latents`` and, if ``pipe.image_encoder`` is set,
    ``clip_feature``. Returns ``{}`` when no reference image is given.
    """

    def __init__(self):
        super().__init__(
            # Each name matches one parameter of ``process``; the duplicate entry was removed.
            input_params=("reference_image", "height", "width"),
            onload_model_names=("vae",)
        )

    def process(self, pipe: WanVideoPipeline, reference_image, height, width):
        if reference_image is None:
            return {}
        pipe.load_models_to_device(["vae"])
        reference_image = reference_image.resize((width, height))
        reference_latents = pipe.preprocess_video([reference_image])
        reference_latents = pipe.vae.encode(reference_latents, device=pipe.device)
        if pipe.image_encoder is None:
            return {"reference_latents": reference_latents}
        # NOTE(review): image_encoder is not in onload_model_names, and clip_feature is not
        # cast to pipe.torch_dtype/device as in WanVideoUnit_ImageEmbedderCLIP -- confirm intended.
        clip_feature = pipe.preprocess_image(reference_image)
        clip_feature = pipe.image_encoder.encode_image([clip_feature])
        return {"reference_latents": reference_latents, "clip_feature": clip_feature}
|
|
|
|
|
|
class WanVideoUnit_FunCameraControl(PipelineUnit):
    """Build camera-control latents (from a Plücker embedding) and the image-conditioning tensor ``y``."""

    def __init__(self):
        super().__init__(
            input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image", "tiled", "tile_size", "tile_stride"),
            onload_model_names=("vae",)
        )

    def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, camera_control_speed, camera_control_origin, latents, input_image, tiled, tile_size, tile_stride):
        """Return ``control_camera_latents_input`` and ``y``, or ``{}`` when no camera direction is given."""
        if camera_control_direction is None:
            return {}
        pipe.load_models_to_device(self.onload_model_names)

        # --- camera embedding -> (B, C*4, F//4, H, W) grouping of 4 frames ---
        plucker = pipe.dit.control_adapter.process_camera_coordinates(
            camera_control_direction, num_frames, height, width, camera_control_speed, camera_control_origin)
        camera_video = plucker[:num_frames].permute([3, 0, 1, 2]).unsqueeze(0)
        # The first frame is repeated 4x so the frame count groups evenly into chunks of 4.
        first_repeated = torch.repeat_interleave(camera_video[:, :, 0:1], repeats=4, dim=2)
        padded = torch.concat([first_repeated, camera_video[:, :, 1:]], dim=2).transpose(1, 2)
        b, f, c, h, w = padded.shape
        grouped = padded.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
        grouped = grouped.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
        control_camera_latents_input = grouped.to(device=pipe.device, dtype=pipe.torch_dtype)

        # --- first-frame conditioning ---
        resized_image = input_image.resize((width, height))
        first_frame_latents = pipe.vae.encode(pipe.preprocess_video([resized_image]), device=pipe.device)
        y = torch.zeros_like(latents).to(pipe.device)
        y[:, :, :1] = first_frame_latents
        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)

        # When the DiT expects a different number of extra channels, rebuild y as
        # [frame mask (4 ch), VAE latents of first frame followed by zeros].
        if y.shape[1] != pipe.dit.in_dim - latents.shape[1]:
            image = pipe.preprocess_image(resized_image.resize((width, height))).to(pipe.device)
            trailing_zeros = torch.zeros(3, num_frames-1, height, width).to(image.device)
            vae_input = torch.concat([image.transpose(0, 1), trailing_zeros], dim=1)
            y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
            y = y.to(dtype=pipe.torch_dtype, device=pipe.device)

            frame_mask = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
            frame_mask[:, 1:] = 0
            frame_mask = torch.concat([torch.repeat_interleave(frame_mask[:, 0:1], repeats=4, dim=1), frame_mask[:, 1:]], dim=1)
            frame_mask = frame_mask.view(1, frame_mask.shape[1] // 4, 4, height//8, width//8)
            frame_mask = frame_mask.transpose(1, 2)[0]

            y = torch.cat([frame_mask, y]).unsqueeze(0)
            y = y.to(dtype=pipe.torch_dtype, device=pipe.device)

        return {"control_camera_latents_input": control_camera_latents_input, "y": y}
|
|
|
|
|
|
class WanVideoUnit_SpeedControl(PipelineUnit):
    """Wrap a scalar ``motion_bucket_id`` into a 1-element tensor on the pipeline device/dtype."""

    def __init__(self):
        super().__init__(input_params=("motion_bucket_id",))

    def process(self, pipe: WanVideoPipeline, motion_bucket_id):
        """Return ``{"motion_bucket_id": tensor}`` or ``{}`` when no id is given."""
        if motion_bucket_id is not None:
            bucket = torch.Tensor((motion_bucket_id,))
            return {"motion_bucket_id": bucket.to(dtype=pipe.torch_dtype, device=pipe.device)}
        return {}
|
|
|
|
|
|
class WanVideoUnit_VACE(PipelineUnit):
    """Assemble the VACE conditioning tensor ``vace_context`` from video, mask and reference image."""

    def __init__(self):
        super().__init__(
            input_params=("vace_video", "vace_video_mask", "vace_reference_image", "vace_scale", "height", "width", "num_frames", "tiled", "tile_size", "tile_stride"),
            onload_model_names=("vae",)
        )

    def process(
        self,
        pipe: WanVideoPipeline,
        vace_video, vace_video_mask, vace_reference_image, vace_scale,
        height, width, num_frames,
        tiled, tile_size, tile_stride
    ):
        """Return ``{"vace_context": ..., "vace_scale": vace_scale}``.

        ``vace_context`` is ``None`` when none of video / mask / reference image is given.
        Otherwise it concatenates, on the channel axis, the VAE latents of the
        masked-out ("inactive") and masked-in ("reactive") video and an 8x8
        space-to-depth version of the mask resampled in time to ``(T + 3) // 4``.
        A reference image, when given, is prepended on the time axis.
        """
        if vace_video is None and vace_video_mask is None and vace_reference_image is None:
            return {"vace_context": None, "vace_scale": vace_scale}

        pipe.load_models_to_device(["vae"])
        encode_kwargs = dict(device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)

        if vace_video is None:
            video = torch.zeros((1, 3, num_frames, height, width), dtype=pipe.torch_dtype, device=pipe.device)
        else:
            video = pipe.preprocess_video(vace_video)

        if vace_video_mask is None:
            mask = torch.ones_like(video)
        else:
            mask = pipe.preprocess_video(vace_video_mask, min_value=0, max_value=1)

        inactive = video * (1 - mask) + 0 * mask
        reactive = video * mask + 0 * (1 - mask)
        inactive = pipe.vae.encode(inactive, **encode_kwargs).to(dtype=pipe.torch_dtype, device=pipe.device)
        reactive = pipe.vae.encode(reactive, **encode_kwargs).to(dtype=pipe.torch_dtype, device=pipe.device)
        video_latents = torch.concat((inactive, reactive), dim=1)

        mask_latents = rearrange(mask[0, 0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8)
        target_size = ((mask_latents.shape[2] + 3) // 4, mask_latents.shape[3], mask_latents.shape[4])
        mask_latents = torch.nn.functional.interpolate(mask_latents, size=target_size, mode='nearest-exact')

        if vace_reference_image is not None:
            reference = pipe.preprocess_video([vace_reference_image])
            reference_latents = pipe.vae.encode(reference, **encode_kwargs).to(dtype=pipe.torch_dtype, device=pipe.device)
            # Reference frame: latents + zero "reactive" half, and an all-zero mask slot.
            reference_latents = torch.concat((reference_latents, torch.zeros_like(reference_latents)), dim=1)
            video_latents = torch.concat((reference_latents, video_latents), dim=2)
            mask_latents = torch.concat((torch.zeros_like(mask_latents[:, :, :1]), mask_latents), dim=2)

        return {"vace_context": torch.concat((video_latents, mask_latents), dim=1), "vace_scale": vace_scale}
|
|
|
|
|
|
class WanVideoUnit_UnifiedSequenceParallel(PipelineUnit):
    """Propagate the pipeline's ``use_unified_sequence_parallel`` flag into the model inputs."""

    def __init__(self):
        super().__init__(input_params=())

    def process(self, pipe: WanVideoPipeline):
        """Return ``{"use_unified_sequence_parallel": True}`` when the flag is set and truthy, else ``{}``."""
        enabled = getattr(pipe, "use_unified_sequence_parallel", False)
        return {"use_unified_sequence_parallel": True} if enabled else {}
|
|
|
|
|
|
class WanVideoUnit_TeaCache(PipelineUnit):
    """Create a ``TeaCache`` instance when ``tea_cache_l1_thresh`` is provided.

    Registered with ``seperate_cfg=True`` and identical positive/negative input
    mappings, so (presumably) each CFG branch receives its own cache object.
    """

    def __init__(self):
        cache_inputs = {name: name for name in ("num_inference_steps", "tea_cache_l1_thresh", "tea_cache_model_id")}
        super().__init__(
            seperate_cfg=True,
            input_params_posi=cache_inputs,
            input_params_nega=dict(cache_inputs),
        )

    def process(self, pipe: WanVideoPipeline, num_inference_steps, tea_cache_l1_thresh, tea_cache_model_id):
        """Return ``{"tea_cache": TeaCache(...)}`` or ``{}`` when the threshold is ``None``."""
        if tea_cache_l1_thresh is not None:
            cache = TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id)
            return {"tea_cache": cache}
        return {}
|
|
|
|
class WanVideoUnit_ShotEmbedder(PipelineUnit):
    """Convert shot-cut frame indices into a per-latent-frame shot id tensor (``shot_indices``)."""

    def __init__(self):
        super().__init__(input_params=("shot_cut_frames", "num_frames"))

    # Annotated with WanVideoPipeline like every other unit in this file; the
    # previous `WanVideoHoloCinePipeline` annotation is evaluated at def time and
    # fails unless that name happens to be defined elsewhere.
    def process(self, pipe: WanVideoPipeline, shot_cut_frames, num_frames):
        """Map pixel-frame shot cuts to latent-frame shot ids.

        Pixel frame 0 maps to latent 0 and every following group of 4 frames
        maps to one latent frame (``(frame - 1) // 4 + 1``), matching
        ``num_latent_frames = (num_frames - 1) // 4 + 1``. Cut frames that are
        ``<= 0`` or fall beyond the last latent frame are ignored.

        Returns:
            ``{}`` when ``shot_cut_frames`` is ``None``; otherwise
            ``{"shot_indices": LongTensor of shape (1, num_latent_frames)}`` on
            ``pipe.device``, where entry ``k`` is the 0-based shot id of latent frame ``k``.
        """
        if shot_cut_frames is None:
            return {}

        num_latent_frames = (num_frames - 1) // 4 + 1

        # Latent-frame indices where a new shot starts; 0 always starts shot 0.
        shot_starts = {0}
        for frame_idx in shot_cut_frames:
            if frame_idx > 0:
                latent_idx = (frame_idx - 1) // 4 + 1
                if latent_idx < num_latent_frames:
                    shot_starts.add(latent_idx)

        cuts = sorted(shot_starts) + [num_latent_frames]
        shot_indices = torch.zeros(num_latent_frames, dtype=torch.long)
        for shot_id, (start_latent, end_latent) in enumerate(zip(cuts[:-1], cuts[1:])):
            shot_indices[start_latent:end_latent] = shot_id

        return {"shot_indices": shot_indices.unsqueeze(0).to(device=pipe.device)}
|
|
|
|
|
|
|
|
class WanVideoUnit_CfgMerger(PipelineUnit):
    """When ``cfg_merge`` is set, batch positive and negative inputs together into ``inputs_shared``."""

    def __init__(self):
        super().__init__(take_over=True)
        self.concat_tensor_names = ["context", "clip_feature", "y", "reference_latents"]

    def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
        """Concatenate each listed tensor along dim 0 as (positive, negative).

        If both branch-specific tensors exist they are concatenated; otherwise a
        shared tensor, if present, is duplicated. The per-branch dicts are then
        emptied. Without ``cfg_merge`` all three dicts are returned untouched.
        """
        if inputs_shared["cfg_merge"]:
            for name in self.concat_tensor_names:
                positive, negative = inputs_posi.get(name), inputs_nega.get(name)
                shared = inputs_shared.get(name)
                if positive is not None and negative is not None:
                    pair = (positive, negative)
                elif shared is not None:
                    pair = (shared, shared)
                else:
                    continue
                inputs_shared[name] = torch.concat(pair, dim=0)
            inputs_posi.clear()
            inputs_nega.clear()
        return inputs_shared, inputs_posi, inputs_nega
|
|
|
|
class TeaCache:
    """Step-skipping cache for the DiT block stack.

    ``check`` accumulates a polynomial-rescaled relative L1 change of the
    timestep modulation ``t_mod`` between steps and reports whether the blocks
    may be skipped. ``store`` records the residual the blocks added to the
    hidden states, and ``update`` re-applies that residual on skipped steps.
    The first and last inference steps are always computed.
    """

    def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
        """
        Args:
            num_inference_steps: Number of denoising steps; the step counter wraps after this many ``check`` calls.
            rel_l1_thresh: Skip while the accumulated rescaled distance stays below this value.
            model_id: Key into ``coefficients_dict`` selecting the rescaling polynomial.

        Raises:
            ValueError: If ``model_id`` has no coefficient entry.
        """
        self.num_inference_steps = num_inference_steps
        self.step = 0
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.rel_l1_thresh = rel_l1_thresh
        self.previous_residual = None
        self.previous_hidden_states = None

        # Polynomial coefficients, highest degree first (np.poly1d convention).
        self.coefficients_dict = {
            "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
            "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
            "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
            "Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
        }
        if model_id not in self.coefficients_dict:
            supported_model_ids = ", ".join([i for i in self.coefficients_dict])
            raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
        self.coefficients = self.coefficients_dict[model_id]
        # Loop-invariant: build the rescaling polynomial once instead of on every step.
        self._rescale_func = np.poly1d(self.coefficients)

    def check(self, dit: "WanModel", x, t_mod):
        """Advance one step and decide whether the DiT blocks can be skipped.

        Args:
            dit: Unused; kept for interface compatibility.
            x: Hidden states entering the block stack; snapshotted when the blocks must run.
            t_mod: Timestep modulation tensor used to measure step-to-step change.

        Returns:
            True if the caller may skip the blocks and call ``update``;
            False if the blocks must run (the caller then calls ``store``).
        """
        modulated_inp = t_mod.clone()
        if self.step == 0 or self.step == self.num_inference_steps - 1:
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            rel_change = ((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()
            self.accumulated_rel_l1_distance += self._rescale_func(rel_change)
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
        self.step += 1
        if self.step == self.num_inference_steps:
            self.step = 0
        if should_calc:
            self.previous_hidden_states = x.clone()
        return not should_calc

    def store(self, hidden_states):
        """Record the residual added by the block stack since the last computed ``check``."""
        self.previous_residual = hidden_states - self.previous_hidden_states
        self.previous_hidden_states = None

    def update(self, hidden_states):
        """Approximate the skipped block stack by adding the last stored residual."""
        hidden_states = hidden_states + self.previous_residual
        return hidden_states
|
|
|
|
|
|
class TemporalTiler_BCTHW:
    """Run a model over overlapping temporal windows of (B, C, T, H, W) tensors and blend the outputs."""

    def __init__(self):
        pass

    def build_1d_mask(self, length, left_bound, right_bound, border_width):
        """Return a length-``length`` weight vector with linear ramps on non-boundary edges.

        Each ramp spans ``border_width`` entries with values ``(i + 0.5) / border_width``;
        a side that touches the sequence boundary keeps weight 1.
        """
        weights = torch.ones((length,))
        if border_width != 0:
            ramp = (torch.arange(border_width) + 0.5) / border_width
            if not left_bound:
                weights[:border_width] = ramp
            if not right_bound:
                weights[-border_width:] = ramp.flip(0)
        return weights

    def build_mask(self, data, is_bound, border_width):
        """Build a (1, 1, T, 1, 1) blending mask for a window whose time length is ``data.shape[2]``."""
        _, _, window_len, _, _ = data.shape
        ramp = self.build_1d_mask(window_len, is_bound[0], is_bound[1], border_width[0])
        return ramp.view(1, 1, -1, 1, 1)

    def run(self, model_fn, sliding_window_size, sliding_window_stride, computation_device, computation_dtype, model_kwargs, tensor_names, batch_size=None):
        """Call ``model_fn(**model_kwargs)`` on each temporal window and return the blended result.

        Only names in ``tensor_names`` whose value in ``model_kwargs`` is not None are
        sliced. The output shape follows the first such tensor, with the batch
        multiplied by ``batch_size`` when given. ``model_kwargs`` is restored to the
        full tensors before returning.
        """
        sliced_names = [name for name in tensor_names if model_kwargs.get(name) is not None]
        full_tensors = {name: model_kwargs[name] for name in sliced_names}
        reference = full_tensors[sliced_names[0]]
        B, C, T, H, W = reference.shape
        if batch_size is not None:
            B *= batch_size
        data_device, data_dtype = reference.device, reference.dtype
        value = torch.zeros((B, C, T, H, W), device=data_device, dtype=data_dtype)
        weight = torch.zeros((1, 1, T, 1, 1), device=data_device, dtype=data_dtype)
        overlap = sliding_window_size - sliding_window_stride

        for start in range(0, T, sliding_window_stride):
            previous_start = start - sliding_window_stride
            # The previous window already reached the end; nothing new to cover.
            if previous_start >= 0 and previous_start + sliding_window_size >= T:
                continue
            end = min(start + sliding_window_size, T)
            model_kwargs.update({
                name: full_tensors[name][:, :, start:end, :].to(device=computation_device, dtype=computation_dtype)
                for name in sliced_names
            })
            output = model_fn(**model_kwargs).to(device=data_device, dtype=data_dtype)
            mask = self.build_mask(
                output,
                is_bound=(start == 0, end == T),
                border_width=(overlap,)
            ).to(device=data_device, dtype=data_dtype)
            value[:, :, start:end] += output * mask
            weight[:, :, start:end] += mask

        value /= weight
        model_kwargs.update(full_tensors)
        return value
|
|
|
|
|
|
def model_fn_wan_video(
    args,
    dit: WanModel,
    motion_controller: WanMotionControllerModel = None,
    vace: VaceWanModel = None,
    latents: torch.Tensor = None,
    timestep: torch.Tensor = None,
    context: torch.Tensor = None,
    clip_feature: Optional[torch.Tensor] = None,
    y: Optional[torch.Tensor] = None,
    reference_latents = None,
    vace_context = None,
    vace_scale = 1.0,
    audio_input: Optional[torch.Tensor] = None,
    motion_latents: Optional[torch.Tensor] = None,
    pose_cond: Optional[torch.Tensor] = None,
    tea_cache: TeaCache = None,
    use_unified_sequence_parallel: bool = False,
    motion_bucket_id: Optional[torch.Tensor] = None,
    sliding_window_size: Optional[int] = None,
    sliding_window_stride: Optional[int] = None,
    cfg_merge: bool = False,
    use_gradient_checkpointing: bool = False,
    use_gradient_checkpointing_offload: bool = False,
    control_camera_latents_input = None,
    fuse_vae_embedding_in_latents: bool = False,
    num_ref_images=None,
    prev_video_latents: Optional[torch.Tensor] = None,
    **kwargs,
):
    """Run one Wan DiT denoising forward pass and return the unpatchified prediction.

    Args:
        args: Experiment options object. Reads ``args.shot_rope`` and
            ``args.split_rope``; when ``split_rope`` is set, exactly one of
            ``args.split1`` / ``args.split2`` / ``args.split3`` must be truthy to
            choose how the trailing reference frames are placed in RoPE space.
        dit: The Wan diffusion transformer.
        latents: Noisy latents, treated below as (B, C, T, H, W).
        timestep: Current diffusion timestep.
        context: Text-encoder output; passed through ``dit.text_embedding``.
        num_ref_images: Number of reference frames fused at the end of the latent
            time axis (used by the fused-timestep and split-RoPE paths).
        sliding_window_size, sliding_window_stride: When both are set, the call
            is delegated window-by-window to ``TemporalTiler_BCTHW``.
        **kwargs: Extra pipeline inputs. ``shot_indices`` (per-latent-frame shot
            ids) and ``text_cut_positions`` (a dict whose ``"shots"`` entry lists
            inclusive ``(start, end)`` text-token ranges, one per shot) are read
            from here to build a shot-aware text attention mask. The mask is only
            built when both are present. Other extra keys are ignored.

    Returns:
        The output of ``dit.head`` followed by ``dit.unpatchify``.
    """
    if sliding_window_size is not None and sliding_window_stride is not None:
        model_kwargs = dict(
            # `args` is a required positional parameter of this function, so the
            # per-window recursive call must receive it as well.
            args=args,
            dit=dit,
            motion_controller=motion_controller,
            vace=vace,
            latents=latents,
            timestep=timestep,
            context=context,
            clip_feature=clip_feature,
            y=y,
            reference_latents=reference_latents,
            vace_context=vace_context,
            vace_scale=vace_scale,
            tea_cache=tea_cache,
            use_unified_sequence_parallel=use_unified_sequence_parallel,
            motion_bucket_id=motion_bucket_id,
        )
        return TemporalTiler_BCTHW().run(
            model_fn_wan_video,
            sliding_window_size, sliding_window_stride,
            latents.device, latents.dtype,
            model_kwargs=model_kwargs,
            tensor_names=["latents", "y"],
            batch_size=2 if cfg_merge else 1
        )

    if use_unified_sequence_parallel:
        import torch.distributed as dist
        from xfuser.core.distributed import (get_sequence_parallel_rank,
                                            get_sequence_parallel_world_size,
                                            get_sp_group)

    # ----- Timestep embedding -----
    if dit.seperated_timestep and fuse_vae_embedding_in_latents:
        # Per-token timesteps: the video frames get the current timestep and the
        # trailing `num_ref_images` fused reference frames get timestep 0.
        # NOTE(review): `num_ref_images` must be an int on this path; `// 4`
        # presumably converts latent pixels to tokens for a 2x2 patch -- confirm.
        tokens_per_frame = latents.shape[3] * latents.shape[4] // 4
        timestep = torch.concat([
            torch.ones((latents.shape[2] - num_ref_images, tokens_per_frame), dtype=latents.dtype, device=latents.device) * timestep,
            torch.zeros((num_ref_images, tokens_per_frame), dtype=latents.dtype, device=latents.device)
        ]).flatten()
        t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0))
        if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
            t_chunks = torch.chunk(t, get_sequence_parallel_world_size(), dim=1)
            t_chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, t_chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in t_chunks]
            t = t_chunks[get_sequence_parallel_rank()]
        t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim))
    else:
        t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
        t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))

    # ----- Motion controller and text embedding -----
    if motion_bucket_id is not None and motion_controller is not None:
        t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
    context = dit.text_embedding(context)

    x = latents
    # Merged CFG: replicate latents / timestep to match the context batch.
    if x.shape[0] != context.shape[0]:
        x = torch.concat([x] * context.shape[0], dim=0)
    if timestep.shape[0] != context.shape[0]:
        timestep = torch.concat([timestep] * context.shape[0], dim=0)

    # ----- Image conditioning -----
    if y is not None and dit.require_vae_embedding:
        x = torch.cat([x, y], dim=1)
    if clip_feature is not None and dit.require_clip_embedding:
        clip_embdding = dit.img_emb(clip_feature)
        context = torch.cat([clip_embdding, context], dim=1)

    # Patchify (also injects camera control when provided).
    x, (f, h, w) = dit.patchify(x, control_camera_latents_input)

    # Reference latents are prepended as one extra frame of tokens.
    if reference_latents is not None:
        if len(reference_latents.shape) == 5:
            reference_latents = reference_latents[:, :, 0]
        reference_latents = dit.ref_conv(reference_latents).flatten(2).transpose(1, 2)
        x = torch.concat([reference_latents, x], dim=1)
        f += 1

    # ----- Rotary position embeddings -----
    if args.shot_rope:
        # NOTE(review): experimental. The per-shot frame counts are hard-coded and
        # `batch_freqs` is not consumed anywhere below; the `freqs` passed to the
        # blocks always comes from the split/default branch that follows.
        freq_s, freq_f, freq_h, freq_w = dit.shot_freqs
        shots_nums_batch = [
            [20, 20, 20, 3, 3],
            [20, 20, 20, 3, 3],
        ]
        batch_freqs = []
        for shots_nums in shots_nums_batch:
            sample_freqs = []
            for shot_index, shot_frames in enumerate(shots_nums):
                # Dedicated name: `f` must stay the patchified frame count used by
                # the RoPE branches below and by `dit.unpatchify`.
                rope_s = freq_s[shot_index].view(1, 1, 1, -1).expand(shot_frames, h, w, -1)
                rope_f = freq_f[:shot_frames].view(shot_frames, 1, 1, -1).expand(shot_frames, h, w, -1)
                rope_h = freq_h[:h].view(1, h, 1, -1).expand(shot_frames, h, w, -1)
                rope_w = freq_w[:w].view(1, 1, w, -1).expand(shot_frames, h, w, -1)
                shot_freqs = torch.cat([rope_s, rope_f, rope_h, rope_w], dim=-1)
                sample_freqs.append(shot_freqs.reshape(shot_frames * h * w, 1, -1))
            batch_freqs.append(torch.cat(sample_freqs, dim=0))
        batch_freqs = torch.stack(batch_freqs, dim=0).to(x.device)

    if args.split_rope:
        # Video frames use positions 0..T_video-1. The trailing reference frames
        # are shifted by `offset` so their positions do not overlap the video's.
        device = dit.freqs[0].device
        freq_f, freq_h, freq_w = dit.freqs
        num_video_frames = f - num_ref_images

        f_video = torch.arange(num_video_frames, device=device)
        h_video = torch.arange(h, device=device)
        w_video = torch.arange(w, device=device)
        rope_f_video = freq_f[f_video].view(num_video_frames, 1, 1, -1).expand(num_video_frames, h, w, -1)
        rope_h_video = freq_h[h_video].view(1, h, 1, -1).expand(num_video_frames, h, w, -1)
        rope_w_video = freq_w[w_video].view(1, 1, w, -1).expand(num_video_frames, h, w, -1)
        rope_video = torch.cat([rope_f_video, rope_h_video, rope_w_video], dim=-1)
        rope_video = rope_video.reshape(num_video_frames * h * w, 1, -1).to(x.device)

        offset = num_video_frames + 10
        if args.split1:
            # Shift time, height and width.
            f_ref = torch.arange(num_ref_images, device=device) + offset
            h_ref = torch.arange(h, device=device) + offset
            w_ref = torch.arange(w, device=device) + offset
        elif args.split2:
            # Shift time only; reference frames share the video's spatial grid.
            f_ref = torch.arange(num_ref_images, device=device) + offset
            h_ref = torch.arange(h, device=device)
            w_ref = torch.arange(w, device=device)
        elif args.split3:
            # All reference frames share one shifted time position (works for any
            # num_ref_images, not only 3); height and width are shifted too.
            f_ref = torch.zeros(num_ref_images, dtype=torch.long, device=device) + offset
            h_ref = torch.arange(h, device=device) + offset
            w_ref = torch.arange(w, device=device) + offset
        else:
            raise ValueError("args.split_rope requires one of args.split1, args.split2 or args.split3 to be set.")

        rope_f_ref = freq_f[f_ref].view(num_ref_images, 1, 1, -1).expand(num_ref_images, h, w, -1)
        rope_h_ref = freq_h[h_ref].view(1, h, 1, -1).expand(num_ref_images, h, w, -1)
        rope_w_ref = freq_w[w_ref].view(1, 1, w, -1).expand(num_ref_images, h, w, -1)
        rope_ref = torch.cat([rope_f_ref, rope_h_ref, rope_w_ref], dim=-1)
        rope_ref = rope_ref.reshape(num_ref_images * h * w, 1, -1).to(x.device)

        # Video tokens first, then reference tokens, matching the fused latent order.
        freqs = torch.cat([rope_video, rope_ref], dim=0)
    else:
        freqs = torch.cat([
            dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
        ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)

    # ----- TeaCache -----
    if tea_cache is not None:
        tea_cache_update = tea_cache.check(dit, x, t_mod)
    else:
        tea_cache_update = False

    if vace_context is not None:
        vace_hints = vace(x, vace_context, context, t_mod, freqs)

    # ----- Shot-aware text attention mask -----
    # Video tokens of shot `i` may only attend to the text tokens listed for shot
    # `i` in text_cut_positions["shots"]; every other text position gets a large
    # negative bias.
    # NOTE(review): `attn_mask` (and `shot_latent_indices` below) are not passed to
    # `dit.blocks` in this function, so they currently do not affect the output.
    # Also confirm that S_q (all tokens in `x`, including any reference tokens)
    # matches the length of `vid_shot` before wiring the mask in.
    shot_indices = kwargs.get("shot_indices")
    text_cut_positions = kwargs.get("text_cut_positions")
    attn_mask = None
    if shot_indices is not None and text_cut_positions is not None:
        try:
            B, S_q = x.shape[0], x.shape[1]
            L_text_ctx = context.shape[1]
            shot_ranges = text_cut_positions['shots']
            S_shots = len(shot_ranges)
            device, dtype = x.device, x.dtype

            # shot_table[sid, j] is True iff text token j lies in shot sid's inclusive range.
            shot_table = torch.zeros(S_shots, L_text_ctx, dtype=torch.bool, device=device)
            for sid, (s0, s1) in enumerate(shot_ranges):
                shot_table[sid, int(s0):int(s1) + 1] = True

            # Expand per-latent-frame shot ids to per-token ids (h * w tokens per frame).
            vid_shot = shot_indices.repeat_interleave(h * w, dim=1)
            max_shot_id = int(vid_shot.max())
            assert max_shot_id < S_shots, \
                f"shot index out of bounds: max={max_shot_id}, S_shots={S_shots}"
            allow = shot_table[vid_shot]

            block_value = -1e4
            bias = torch.zeros(B, S_q, L_text_ctx, dtype=dtype, device=device)
            bias = bias.masked_fill(~allow, block_value)
            attn_mask = bias.unsqueeze(1)
        except Exception:
            print("!!!!!! ERROR FOUND IN SHOT ATTENTION MASK !!!!!!!")
            raise

    shot_latent_indices = None
    if getattr(dit, 'use_sparse_self_attn', False) and shot_indices is not None:
        shot_latent_indices = labels_to_cuts(shot_indices.repeat_interleave(h * w, dim=1))

    # ----- Transformer blocks -----
    if use_unified_sequence_parallel:
        if dist.is_initialized() and dist.get_world_size() > 1:
            chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
            pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
            chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks]
            x = chunks[get_sequence_parallel_rank()]
    if tea_cache_update:
        x = tea_cache.update(x)
    else:
        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward

        for block_id, block in enumerate(dit.blocks):
            if use_gradient_checkpointing_offload:
                with torch.autograd.graph.save_on_cpu():
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        x, context, t_mod, freqs,
                        use_reentrant=False,
                    )
            elif use_gradient_checkpointing:
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    x, context, t_mod, freqs,
                    use_reentrant=False,
                )
            else:
                x = block(x, context, t_mod, freqs)
            if vace_context is not None and block_id in vace.vace_layers_mapping:
                current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
                if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
                    current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
                    current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0)
                x = x + current_vace_hint * vace_scale
        if tea_cache is not None:
            tea_cache.store(x)

    x = dit.head(x, t)
    if use_unified_sequence_parallel:
        if dist.is_initialized() and dist.get_world_size() > 1:
            x = get_sp_group().all_gather(x, dim=1)
            x = x[:, :-pad_shape] if pad_shape > 0 else x

    # Drop the prepended reference tokens before unpatchifying.
    if reference_latents is not None:
        x = x[:, reference_latents.shape[1]:]
        f -= 1
    x = dit.unpatchify(x, (f, h, w))
    return x
|
|
|
|
def labels_to_cuts(batch_labels: torch.Tensor):
    """Convert rows of per-position labels into segment boundary lists.

    For each row of a 2-D ``[b, s]`` label tensor, return ``[0, p1, ..., s]``,
    where each ``p`` is a position whose label differs from the previous one.
    Consecutive boundaries therefore delimit runs of equal labels. A row of
    length 0 yields ``[0]``.
    """
    assert batch_labels.dim() == 2, "expect [b, s]"
    labels = batch_labels.to(torch.long)
    seq_len = labels.shape[1]

    # changed[i, j] is True when labels[i, j] != labels[i, j - 1]; column 0 stays False.
    changed = torch.zeros_like(labels, dtype=torch.bool)
    changed[:, 1:] = labels[:, 1:] != labels[:, :-1]

    all_boundaries = []
    for row in changed:
        boundaries = [0, *torch.nonzero(row, as_tuple=False).flatten().tolist()]
        if boundaries[-1] != seq_len:
            boundaries.append(seq_len)
        all_boundaries.append(boundaries)
    return all_boundaries
|
|