import torch, warnings, glob, os, types import numpy as np from PIL import Image from einops import repeat, reduce from typing import Optional, Union from dataclasses import dataclass from modelscope import snapshot_download from einops import rearrange import numpy as np from PIL import Image from tqdm import tqdm from typing import Optional from typing_extensions import Literal import torch.nn.functional as F from PIL import Image, ImageOps from diffsynth.core import ModelConfig from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit, PipelineUnitRunner from diffsynth.models import ModelManager, load_state_dict from diffsynth.models.wan_video_dit_mvid import WanModel, RMSNorm, sinusoidal_embedding_1d from diffsynth.models.wan_video_dit_s2v import rope_precompute from diffsynth.models.wan_video_text_encoder import WanTextEncoder, T5RelativeEmbedding, T5LayerNorm from diffsynth.models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample from diffsynth.models.wan_video_image_encoder import WanImageEncoder from diffsynth.models.wan_video_vace import VaceWanModel from diffsynth.models.wan_video_motion_controller import WanMotionControllerModel from diffsynth.schedulers.flow_match import FlowMatchScheduler from diffsynth.prompters import WanPrompter from diffsynth.vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm from diffsynth.lora import GeneralLoRALoader from diffsynth.utils.data import save_video import random from torchvision.transforms import Compose, Normalize, ToTensor class WanVideoPipeline(BasePipeline): def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=None): super().__init__( device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1 ) self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True) self.prompter = WanPrompter(tokenizer_path=tokenizer_path) self.text_encoder: WanTextEncoder = None self.image_encoder: WanImageEncoder = None self.dit: WanModel = None self.dit2: WanModel = None self.vae: WanVideoVAE = None self.motion_controller: WanMotionControllerModel = None self.vace: VaceWanModel = None self.in_iteration_models = ("dit", "motion_controller", "vace") self.in_iteration_models_2 = ("dit2", "motion_controller", "vace") self.unit_runner = PipelineUnitRunner() self.units = [ WanVideoUnit_ShapeChecker(), WanVideoUnit_NoiseInitializer(), WanVideoUnit_PromptEmbedder(), # WanVideoUnit_S2V(), WanVideoUnit_InputVideoEmbedder(), WanVideoUnit_ImageEmbedderVAE(), WanVideoUnit_ImageEmbedderCLIP(), WanVideoUnit_ImageEmbedderFused(), WanVideoUnit_VideoEmbedderFused(), WanVideoUnit_RefEmbedderFused(), WanVideoUnit_FunControl(), WanVideoUnit_FunReference(), WanVideoUnit_FunCameraControl(), WanVideoUnit_SpeedControl(), # WanVideoUnit_VACE(), WanVideoUnit_UnifiedSequenceParallel(), WanVideoUnit_TeaCache(), WanVideoUnit_CfgMerger(), ] self.model_fn = model_fn_wan_video def extrac_ref_latents(self, ref_images, vae, device, dtype, min_value=-1., max_value=1.): # Load image. ref_vae_latents = [] for img in ref_images: img = torch.Tensor(np.array(img, dtype=np.float32)) img = img.to(dtype=dtype, device=device) img = img * ((max_value - min_value) / 255) + min_value img_vae_latent = vae.encode([img.permute(2,0,1).unsqueeze(1)], device=device) ###1 C 1 H W ref_vae_latents.append(img_vae_latent) return torch.cat(ref_vae_latents, dim=2) ###1 C ref_num H W def load_lora(self, module, path, alpha=1): loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device) lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device) loader.load(module, lora, alpha=alpha) def training_loss(self, **inputs): max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * self.scheduler.num_train_timesteps) min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * self.scheduler.num_train_timesteps) timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,)) timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device) inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep) if inputs["ref_images_latents"] is not None: if random.random() < inputs["args"].zero_face_ratio: inputs["latents"] = torch.cat([inputs["latents"], torch.zeros_like(inputs['ref_images_latents'])], dim=2) else: inputs["latents"] = torch.cat([inputs["latents"], inputs['ref_images_latents']], dim=2) training_target = self.scheduler.training_target(inputs["input_latents"], inputs["noise"], timestep) # print(inputs["input_latents"].shape, inputs['ref_images_latents'].shape, inputs["num_ref_images"], training_target.shape) noise_pred = self.model_fn(**inputs, timestep=timestep) loss = torch.nn.functional.mse_loss(noise_pred.float()[:, :, :-inputs["num_ref_images"]], training_target.float()) loss = loss * self.scheduler.training_weight(timestep) return loss def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5): self.vram_management_enabled = True if num_persistent_param_in_dit is not None: vram_limit = None else: if vram_limit is None: vram_limit = self.get_vram() vram_limit = vram_limit - vram_buffer if self.text_encoder is not None: dtype = next(iter(self.text_encoder.parameters())).dtype enable_vram_management( self.text_encoder, module_map = { torch.nn.Linear: AutoWrappedLinear, torch.nn.Embedding: AutoWrappedModule, T5RelativeEmbedding: AutoWrappedModule, T5LayerNorm: AutoWrappedModule, }, module_config = dict( offload_dtype=dtype, offload_device="cpu", onload_dtype=dtype, onload_device="cpu", computation_dtype=self.torch_dtype, computation_device=self.device, ), vram_limit=vram_limit, ) if self.dit is not None: dtype = next(iter(self.dit.parameters())).dtype device = "cpu" if vram_limit is not None else self.device enable_vram_management( self.dit, module_map = { torch.nn.Linear: AutoWrappedLinear, torch.nn.Conv3d: AutoWrappedModule, torch.nn.LayerNorm: WanAutoCastLayerNorm, RMSNorm: AutoWrappedModule, torch.nn.Conv2d: AutoWrappedModule, torch.nn.Conv1d: AutoWrappedModule, torch.nn.Embedding: AutoWrappedModule, }, module_config = dict( offload_dtype=dtype, offload_device="cpu", onload_dtype=dtype, onload_device=device, computation_dtype=self.torch_dtype, computation_device=self.device, ), max_num_param=num_persistent_param_in_dit, overflow_module_config = dict( offload_dtype=dtype, offload_device="cpu", onload_dtype=dtype, onload_device="cpu", computation_dtype=self.torch_dtype, computation_device=self.device, ), vram_limit=vram_limit, ) if self.dit2 is not None: dtype = next(iter(self.dit2.parameters())).dtype device = "cpu" if vram_limit is not None else self.device enable_vram_management( self.dit2, module_map = { torch.nn.Linear: AutoWrappedLinear, torch.nn.Conv3d: AutoWrappedModule, torch.nn.LayerNorm: WanAutoCastLayerNorm, RMSNorm: AutoWrappedModule, torch.nn.Conv2d: AutoWrappedModule, }, module_config = dict( offload_dtype=dtype, offload_device="cpu", onload_dtype=dtype, onload_device=device, computation_dtype=self.torch_dtype, computation_device=self.device, ), max_num_param=num_persistent_param_in_dit, overflow_module_config = dict( offload_dtype=dtype, offload_device="cpu", onload_dtype=dtype, onload_device="cpu", computation_dtype=self.torch_dtype, computation_device=self.device, ), vram_limit=vram_limit, ) if self.vae is not None: dtype = next(iter(self.vae.parameters())).dtype enable_vram_management( self.vae, module_map = { torch.nn.Linear: AutoWrappedLinear, torch.nn.Conv2d: AutoWrappedModule, RMS_norm: AutoWrappedModule, CausalConv3d: AutoWrappedModule, Upsample: AutoWrappedModule, torch.nn.SiLU: AutoWrappedModule, torch.nn.Dropout: AutoWrappedModule, }, module_config = dict( offload_dtype=dtype, offload_device="cpu", onload_dtype=dtype, onload_device=self.device, computation_dtype=self.torch_dtype, computation_device=self.device, ), ) if self.image_encoder is not None: dtype = next(iter(self.image_encoder.parameters())).dtype enable_vram_management( self.image_encoder, module_map = { torch.nn.Linear: AutoWrappedLinear, torch.nn.Conv2d: AutoWrappedModule, torch.nn.LayerNorm: AutoWrappedModule, }, module_config = dict( offload_dtype=dtype, offload_device="cpu", onload_dtype=dtype, onload_device="cpu", computation_dtype=dtype, computation_device=self.device, ), ) if self.motion_controller is not None: dtype = next(iter(self.motion_controller.parameters())).dtype enable_vram_management( self.motion_controller, module_map = { torch.nn.Linear: AutoWrappedLinear, }, module_config = dict( offload_dtype=dtype, offload_device="cpu", onload_dtype=dtype, onload_device="cpu", computation_dtype=dtype, computation_device=self.device, ), ) if self.vace is not None: device = "cpu" if vram_limit is not None else self.device enable_vram_management( self.vace, module_map = { torch.nn.Linear: AutoWrappedLinear, torch.nn.Conv3d: AutoWrappedModule, torch.nn.LayerNorm: AutoWrappedModule, RMSNorm: AutoWrappedModule, }, module_config = dict( offload_dtype=dtype, offload_device="cpu", onload_dtype=dtype, onload_device=device, computation_dtype=self.torch_dtype, computation_device=self.device, ), vram_limit=vram_limit, ) def initialize_usp(self): import torch.distributed as dist from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment dist.init_process_group(backend="nccl", init_method="env://") init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size()) initialize_model_parallel( sequence_parallel_degree=dist.get_world_size(), ring_degree=1, ulysses_degree=dist.get_world_size(), ) torch.cuda.set_device(dist.get_rank()) def enable_usp(self): from xfuser.core.distributed import get_sequence_parallel_world_size from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward for block in self.dit.blocks: block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) self.dit.forward = types.MethodType(usp_dit_forward, self.dit) if self.dit2 is not None: for block in self.dit2.blocks: block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2) self.sp_size = get_sequence_parallel_world_size() self.use_unified_sequence_parallel = True @staticmethod def from_pretrained( torch_dtype: torch.dtype = torch.bfloat16, device: Union[str, torch.device] = "cuda", model_configs: list[ModelConfig] = [], tokenizer_config: ModelConfig = ModelConfig(model_id="/root/paddle_job/workspace/qizipeng/wanx_pretrainedmodels/Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"), audio_processor_config: ModelConfig = None, redirect_common_files: bool = True, use_usp=False, ): # Redirect model path if redirect_common_files: redirect_dict = { "models_t5_umt5-xxl-enc-bf16.pth": "Wan-AI/Wan2.1-T2V-1.3B", "Wan2.1_VAE.pth": "Wan-AI/Wan2.1-T2V-1.3B", "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth": "Wan-AI/Wan2.1-I2V-14B-480P", } for model_config in model_configs: if model_config.origin_file_pattern is None or model_config.model_id is None: continue if model_config.origin_file_pattern in redirect_dict and model_config.model_id != redirect_dict[model_config.origin_file_pattern]: print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to ({redirect_dict[model_config.origin_file_pattern]}, {model_config.origin_file_pattern}). You can use `redirect_common_files=False` to disable file redirection.") model_config.model_id = redirect_dict[model_config.origin_file_pattern] # Initialize pipeline pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype) if use_usp: pipe.initialize_usp() # Download and load models model_manager = ModelManager() for model_config in model_configs: model_config.download_if_necessary(use_usp=use_usp) model_manager.load_model( model_config.path, device=model_config.offload_device or device, torch_dtype=model_config.offload_dtype or torch_dtype ) # Load models pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder") dit = model_manager.fetch_model("wan_video_dit", index=2) if isinstance(dit, list): pipe.dit, pipe.dit2 = dit else: pipe.dit = dit pipe.vae = model_manager.fetch_model("wan_video_vae") # Size division factor if pipe.vae is not None: pipe.height_division_factor = pipe.vae.upsampling_factor * 2 pipe.width_division_factor = pipe.vae.upsampling_factor * 2 tokenizer_config.download_if_necessary(use_usp=use_usp) pipe.prompter.fetch_models(pipe.text_encoder) # pipe.prompter.fetch_tokenizer(tokenizer_config.path) pipe.prompter.fetch_tokenizer('/root/paddlejob/workspace/qizipeng/wanx_pretrainedmodels/Wan2.2-TI2V-5B/google/umt5-xxl') if audio_processor_config is not None: audio_processor_config.download_if_necessary(use_usp=use_usp) from transformers import Wav2Vec2Processor pipe.audio_processor = Wav2Vec2Processor.from_pretrained(audio_processor_config.path) # Unified Sequence Parallel if use_usp: pipe.enable_usp() return pipe @torch.no_grad() def __call__( self, args, # Prompt prompt: str, negative_prompt: Optional[str] = "", # Image-to-video input_image: Optional[Image.Image] = None, # First-last-frame-to-video end_image: Optional[Image.Image] = None, # Video-to-video input_video: Optional[list[Image.Image]] = None, input_pre_video: Optional[list[Image.Image]] = None, ref_images: Optional[list[Image.Image]] = None, prev_latent=None, denoising_strength: Optional[float] = 1.0, # Speech-to-video input_audio: Optional[str] = None, audio_sample_rate: Optional[int] = 16000, s2v_pose_video: Optional[list[Image.Image]] = None, # ControlNet control_video: Optional[list[Image.Image]] = None, reference_image: Optional[Image.Image] = None, # Camera control camera_control_direction: Optional[Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"]] = None, camera_control_speed: Optional[float] = 1/54, camera_control_origin: Optional[tuple] = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0), # VACE vace_video: Optional[list[Image.Image]] = None, vace_video_mask: Optional[Image.Image] = None, vace_reference_image: Optional[Image.Image] = None, vace_scale: Optional[float] = 1.0, # Randomness seed: Optional[int] = None, rand_device: Optional[str] = "cpu", # Shape height: Optional[int] = 480, width: Optional[int] = 832, num_frames=81, # Classifier-free guidance cfg_scale: Optional[float] = 5.0, cfg_scale_face: Optional[float] = 5.0, #### face condition negetive cfg_merge: Optional[bool] = False, # Boundary switch_DiT_boundary: Optional[float] = 0.875, # Scheduler num_inference_steps: Optional[int] = 50, sigma_shift: Optional[float] = 5.0, # Speed control motion_bucket_id: Optional[int] = None, # VAE tiling tiled: Optional[bool] = True, tile_size: Optional[tuple[int, int]] = (30, 52), tile_stride: Optional[tuple[int, int]] = (15, 26), # Sliding window sliding_window_size: Optional[int] = None, sliding_window_stride: Optional[int] = None, # Teacache tea_cache_l1_thresh: Optional[float] = None, tea_cache_model_id: Optional[str] = "", # progress_bar progress_bar_cmd=tqdm, num_ref_images: Optional[int] = None, ): # Scheduler self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) # Inputs inputs_posi = { "prompt": prompt, "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps, } inputs_nega = { "negative_prompt": negative_prompt, "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps, } inputs_shared = { "input_image": input_image, "end_image": end_image, "input_video": input_video, "denoising_strength": denoising_strength, "input_pre_video":input_pre_video, "ref_images":ref_images, "control_video": control_video, "reference_image": reference_image, "camera_control_direction": camera_control_direction, "camera_control_speed": camera_control_speed, "camera_control_origin": camera_control_origin, "vace_video": vace_video, "vace_video_mask": vace_video_mask, "vace_reference_image": vace_reference_image, "vace_scale": vace_scale, "seed": seed, "rand_device": rand_device, "height": height, "width": width, "num_frames": num_frames, "cfg_scale": cfg_scale, "cfg_merge": cfg_merge, "sigma_shift": sigma_shift, "motion_bucket_id": motion_bucket_id, "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, "sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride, "input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "num_ref_images":num_ref_images, "batch_size": 1 } for unit in self.units: inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) # Denoise self.load_models_to_device(self.in_iteration_models) models = {name: getattr(self, name) for name in self.in_iteration_models} for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): # Switch DiT if necessary if timestep.item() < switch_DiT_boundary * self.scheduler.num_train_timesteps and self.dit2 is not None and not models["dit"] is self.dit2: self.load_models_to_device(self.in_iteration_models_2) models["dit"] = self.dit2 # Timestep timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) # Inference noise_pred_posi = self.model_fn(args, **models, **inputs_shared, **inputs_posi, timestep=timestep) ## text img if cfg_scale != 1.0: if cfg_merge: noise_pred_posi, noise_pred_nega = noise_pred_posi.chunk(2, dim=0) else: # noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep) ## O img if 'ref_images_latents' in inputs_shared: inputs_shared['latents'][:, :, -inputs_shared["ref_images_latents"].shape[2]:] = torch.zeros_like(inputs_shared['ref_images_latents']) noise_pred_nega_face = self.model_fn(args, **models, **inputs_shared, **inputs_posi, timestep=timestep) # text, 0 noise_all_eng = self.model_fn(args, **models, **inputs_shared, **inputs_nega, timestep=timestep) # 0, 0 noise_pred = noise_all_eng + cfg_scale * (noise_pred_posi - noise_pred_nega_face) + cfg_scale_face * (noise_pred_nega_face - noise_all_eng) else: noise_pred = noise_pred_posi # Scheduler inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"]) if "ref_images_latents" in inputs_shared: inputs_shared["latents"][:, :, -inputs_shared["ref_images_latents"].shape[2]:] = inputs_shared["ref_images_latents"] # if progress_id in [0,10,20,30,40,43,44,45,46,47,48,49]: # self.load_models_to_device(['vae']) # video = self.vae.decode(inputs_shared["latents"][:, :, :-inputs_shared["ref_images_latents"].shape[2]], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) # video = self.vae_output_to_video(video) # save_video(video, f"./results/videos/video_wyzlarge_arrange5_step_{timestep.item()}_progress_id_{progress_id}.mp4", fps=24, quality=5) # VACE (TODO: remove it) if vace_reference_image is not None: inputs_shared["latents"] = inputs_shared["latents"][:, :, 1:] # Decode if "ref_images_latents" in inputs_shared: inputs_shared["latents"] = inputs_shared["latents"][:, :, :-inputs_shared["ref_images_latents"].shape[2]] self.load_models_to_device(['vae']) video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) video = self.vae_output_to_video(video) self.load_models_to_device([]) return video, inputs_shared["latents"] class WanVideoUnit_ShapeChecker(PipelineUnit): def __init__(self): super().__init__(input_params=("height", "width", "num_frames")) def process(self, pipe: WanVideoPipeline, height, width, num_frames): height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames) return {"height": height, "width": width, "num_frames": num_frames} class WanVideoUnit_NoiseInitializer(PipelineUnit): def __init__(self): super().__init__(input_params=("height", "width", "num_frames", "seed", "rand_device", "vace_reference_image", "batch_size")) def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image, batch_size = 1): length = (num_frames - 1) // 4 + 1 if vace_reference_image is not None: length += 1 shape = (batch_size, pipe.vae.model.z_dim, length, height // pipe.vae.upsampling_factor, width // pipe.vae.upsampling_factor) ### B C F H W # shape = (batch_size, vae.model.z_dim, length, height // vae.upsampling_factor, width // vae.upsampling_factor) noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device) if vace_reference_image is not None: noise = torch.concat((noise[:, :, -1:], noise[:, :, :-1]), dim=2) return {"noise": noise} class WanVideoUnit_InputVideoEmbedder(PipelineUnit): def __init__(self): super().__init__( input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image"), onload_model_names=("vae",) ) def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image): if input_video is None: return {"latents": noise} pipe.load_models_to_device(["vae"]) input_latents = [] for input_video_ in input_video: input_video_ = pipe.preprocess_video(input_video_) input_latent_ = pipe.vae.encode(input_video_, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) input_latents.append(input_latent_) input_latents = torch.cat(input_latents, dim = 0) ### B C F H W # if vace_reference_image is not None: # vace_reference_image = pipe.preprocess_video([vace_reference_image]) # vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device) # input_latents = torch.concat([vace_reference_latents, input_latents], dim=2) if pipe.scheduler.training: return {"latents": noise, "input_latents": input_latents} else: latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) return {"latents": latents} class WanVideoUnit_PromptEmbedder(PipelineUnit): def __init__(self): super().__init__( seperate_cfg=True, input_params_posi={"prompt": "prompt", "positive": "positive"}, input_params_nega={"prompt": "negative_prompt", "positive": "positive"}, onload_model_names=("text_encoder",) ) def process(self, pipe: WanVideoPipeline, prompt, positive) -> dict: # pipe.load_models_to_device(self.onload_model_names) pipe.text_encoder = pipe.text_encoder.to(pipe.device) prompt_emb_list = [] for prompt_ in prompt: prompt_emb_ = pipe.prompter.encode_prompt(prompt_, positive=positive, device=pipe.device) ###B C Token prompt_emb_list.append(prompt_emb_) prompt_emb = torch.cat(prompt_emb_list, dim = 0) return {"context": prompt_emb} class WanVideoUnit_ImageEmbedder(PipelineUnit): """ Deprecated """ def __init__(self): super().__init__( input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), onload_model_names=("image_encoder", "vae") ) def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride): if input_image is None or pipe.image_encoder is None: return {} pipe.load_models_to_device(self.onload_model_names) image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) clip_context = pipe.image_encoder.encode_image([image]) msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) msk[:, 1:] = 0 if end_image is not None: end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1) if pipe.dit.has_image_pos_emb: clip_context = torch.concat([clip_context, pipe.image_encoder.encode_image([end_image])], dim=1) msk[:, -1:] = 1 else: vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) msk = msk.transpose(1, 2)[0] y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] y = y.to(dtype=pipe.torch_dtype, device=pipe.device) y = torch.concat([msk, y]) y = y.unsqueeze(0) clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device) y = y.to(dtype=pipe.torch_dtype, device=pipe.device) return {"clip_feature": clip_context, "y": y} class WanVideoUnit_ImageEmbedderCLIP(PipelineUnit): def __init__(self): super().__init__( input_params=("input_image", "end_image", "height", "width"), onload_model_names=("image_encoder",) ) def process(self, pipe: WanVideoPipeline, input_image, end_image, height, width): if input_image is None or pipe.image_encoder is None or not pipe.dit.require_clip_embedding: return {} pipe.load_models_to_device(self.onload_model_names) image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) clip_context = pipe.image_encoder.encode_image([image]) if end_image is not None: end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) if pipe.dit.has_image_pos_emb: clip_context = torch.concat([clip_context, pipe.image_encoder.encode_image([end_image])], dim=1) clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device) return {"clip_feature": clip_context} class WanVideoUnit_ImageEmbedderVAE(PipelineUnit): def __init__(self): super().__init__( input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), onload_model_names=("vae",) ) def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride): if input_image is None or not pipe.dit.require_vae_embedding: return {} pipe.load_models_to_device(self.onload_model_names) image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) msk[:, 1:] = 0 if end_image is not None: end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1) msk[:, -1:] = 1 else: vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) msk = msk.transpose(1, 2)[0] y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] y = y.to(dtype=pipe.torch_dtype, device=pipe.device) y = torch.concat([msk, y]) y = y.unsqueeze(0) y = y.to(dtype=pipe.torch_dtype, device=pipe.device) return {"y": y} class WanVideoUnit_ImageEmbedderFused(PipelineUnit): """ Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B. """ def __init__(self): super().__init__( input_params=("input_image", "latents", "height", "width", "tiled", "tile_size", "tile_stride"), onload_model_names=("vae",) ) def process(self, pipe: WanVideoPipeline, input_image, latents, height, width, tiled, tile_size, tile_stride): if input_image is None or not pipe.dit.fuse_vae_embedding_in_latents: return {} pipe.load_models_to_device(self.onload_model_names) image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1) z = pipe.vae.encode([image], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) latents[:, :, 0: 1] = z return {"latents": latents, "fuse_vae_embedding_in_latents": True, "first_frame_latents": z} class WanVideoUnit_VideoEmbedderFused(PipelineUnit): """ Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B. """ def __init__(self): super().__init__( input_params=("input_pre_video", "latents", "height", "width", "tiled", "tile_size", "tile_stride"), onload_model_names=("vae",) ) def process(self, pipe: WanVideoPipeline, input_pre_video, latents, height, width, tiled, tile_size, tile_stride): if input_pre_video is None or not pipe.dit.fuse_vae_embedding_in_latents: return {} pipe.load_models_to_device(self.onload_model_names) # image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1) # z = pipe.vae.encode([image], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) input_pre_video = pipe.preprocess_video(input_pre_video) input_pre_video_latent = pipe.vae.encode(input_pre_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) pre_t_num = input_pre_video_latent.shape[2] latents[:, :, :pre_t_num] = input_pre_video_latent return {"latents": latents, "fuse_vae_embedding_in_latents": True, "prev_video_latents": input_pre_video_latent} class WanVideoUnit_RefEmbedderFused(PipelineUnit): def __init__(self): super().__init__( input_params=("ref_images", "latents", "height", "width", "tiled", "tile_size", "tile_stride", "num_ref_images"), onload_model_names=("vae",) ) def process(self, pipe: WanVideoPipeline, ref_images, latents, height, width, tiled, tile_size, tile_stride, num_ref_images): if ref_images is None or not pipe.dit.fuse_vae_embedding_in_latents: return {} pipe.load_models_to_device(self.onload_model_names) ref_images_latents = [] for ref_images_ in ref_images: ref_images_latent_ = pipe.extrac_ref_latents(ref_images_, pipe.vae, device=pipe.device, dtype=pipe.torch_dtype)[0][None] ref_images_latents.append(ref_images_latent_) ##1 C ref_num H W ref_images_latents = torch.concat(ref_images_latents, dim=0) # r = num_ref_images - ref_images_latents.shape[2] # ref_images_latents = F.pad(ref_images_latents, (0, 0, 0, 0, 0, r)) latents = torch.concat([latents, ref_images_latents], dim=2) return {"latents": latents, "fuse_vae_embedding_in_latents": True, "ref_images_latents": ref_images_latents} class WanVideoUnit_FunReference(PipelineUnit): def __init__(self): super().__init__( input_params=("reference_image", "height", "width", "reference_image"), onload_model_names=("vae",) ) def process(self, pipe: WanVideoPipeline, reference_image, height, width): if reference_image is None: return {} pipe.load_models_to_device(["vae"]) reference_image = reference_image.resize((width, height)) reference_latents = pipe.preprocess_video([reference_image]) reference_latents = pipe.vae.encode(reference_latents, device=pipe.device) if pipe.image_encoder is None: return {"reference_latents": reference_latents} clip_feature = pipe.preprocess_image(reference_image) clip_feature = pipe.image_encoder.encode_image([clip_feature]) return {"reference_latents": reference_latents, "clip_feature": clip_feature} class WanVideoUnit_FunCameraControl(PipelineUnit): def __init__(self): super().__init__( input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image", "tiled", "tile_size", "tile_stride"), onload_model_names=("vae",) ) def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, camera_control_speed, camera_control_origin, latents, input_image, tiled, tile_size, tile_stride): if camera_control_direction is None: return {} pipe.load_models_to_device(self.onload_model_names) camera_control_plucker_embedding = pipe.dit.control_adapter.process_camera_coordinates( camera_control_direction, num_frames, height, width, camera_control_speed, camera_control_origin) control_camera_video = camera_control_plucker_embedding[:num_frames].permute([3, 0, 1, 2]).unsqueeze(0) control_camera_latents = torch.concat( [ torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2), control_camera_video[:, :, 1:] ], dim=2 ).transpose(1, 2) b, f, c, h, w = control_camera_latents.shape control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3) control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2) control_camera_latents_input = control_camera_latents.to(device=pipe.device, dtype=pipe.torch_dtype) input_image = input_image.resize((width, height)) input_latents = pipe.preprocess_video([input_image]) input_latents = pipe.vae.encode(input_latents, device=pipe.device) y = torch.zeros_like(latents).to(pipe.device) y[:, :, :1] = input_latents y = y.to(dtype=pipe.torch_dtype, device=pipe.device) if y.shape[1] != pipe.dit.in_dim - latents.shape[1]: image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] y = y.to(dtype=pipe.torch_dtype, device=pipe.device) msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) msk[:, 1:] = 0 msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) msk = msk.transpose(1, 2)[0] y = torch.cat([msk,y]) y = y.unsqueeze(0) y = y.to(dtype=pipe.torch_dtype, device=pipe.device) return {"control_camera_latents_input": control_camera_latents_input, "y": y} class WanVideoUnit_SpeedControl(PipelineUnit): def __init__(self): super().__init__(input_params=("motion_bucket_id",)) def process(self, pipe: WanVideoPipeline, motion_bucket_id): if motion_bucket_id is None: return {} motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=pipe.torch_dtype, device=pipe.device) return {"motion_bucket_id": motion_bucket_id} class WanVideoUnit_VACE(PipelineUnit): def __init__(self): super().__init__( input_params=("vace_video", "vace_video_mask", "vace_reference_image", "vace_scale", "height", "width", "num_frames", "tiled", "tile_size", "tile_stride"), onload_model_names=("vae",) ) def process( self, pipe: WanVideoPipeline, vace_video, vace_video_mask, vace_reference_image, vace_scale, height, width, num_frames, tiled, tile_size, tile_stride ): if vace_video is not None or vace_video_mask is not None or vace_reference_image is not None: pipe.load_models_to_device(["vae"]) if vace_video is None: vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=pipe.torch_dtype, device=pipe.device) else: vace_video = pipe.preprocess_video(vace_video) if vace_video_mask is None: vace_video_mask = torch.ones_like(vace_video) else: vace_video_mask = pipe.preprocess_video(vace_video_mask, min_value=0, max_value=1) inactive = vace_video * (1 - vace_video_mask) + 0 * vace_video_mask reactive = vace_video * vace_video_mask + 0 * (1 - vace_video_mask) inactive = pipe.vae.encode(inactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) reactive = pipe.vae.encode(reactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) vace_video_latents = torch.concat((inactive, reactive), dim=1) vace_mask_latents = rearrange(vace_video_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8) vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact') if vace_reference_image is None: pass else: vace_reference_image = pipe.preprocess_video([vace_reference_image]) vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1) vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2) vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2) vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1) return {"vace_context": vace_context, "vace_scale": vace_scale} else: return {"vace_context": None, "vace_scale": vace_scale} class WanVideoUnit_UnifiedSequenceParallel(PipelineUnit): def __init__(self): super().__init__(input_params=()) def process(self, pipe: WanVideoPipeline): if hasattr(pipe, "use_unified_sequence_parallel"): if pipe.use_unified_sequence_parallel: return {"use_unified_sequence_parallel": True} return {} class WanVideoUnit_TeaCache(PipelineUnit): def __init__(self): super().__init__( seperate_cfg=True, input_params_posi={"num_inference_steps": "num_inference_steps", "tea_cache_l1_thresh": "tea_cache_l1_thresh", "tea_cache_model_id": "tea_cache_model_id"}, input_params_nega={"num_inference_steps": "num_inference_steps", "tea_cache_l1_thresh": "tea_cache_l1_thresh", "tea_cache_model_id": "tea_cache_model_id"}, ) def process(self, pipe: WanVideoPipeline, num_inference_steps, tea_cache_l1_thresh, tea_cache_model_id): if tea_cache_l1_thresh is None: return {} return {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id)} class WanVideoUnit_ShotEmbedder(PipelineUnit): def __init__(self): super().__init__(input_params=("shot_cut_frames", "num_frames")) def process(self, pipe: WanVideoHoloCinePipeline, shot_cut_frames, num_frames): if shot_cut_frames is None: return {} num_latent_frames = (num_frames - 1) // 4 + 1 # Convert frame cut indices to latent cut indices shot_cut_latents = [0] for frame_idx in sorted(shot_cut_frames): if frame_idx > 0: latent_idx = (frame_idx - 1) // 4 + 1 if latent_idx < num_latent_frames: shot_cut_latents.append(latent_idx) cuts = sorted(list(set(shot_cut_latents))) + [num_latent_frames] shot_indices = torch.zeros(num_latent_frames, dtype=torch.long) for i in range(len(cuts) - 1): start_latent, end_latent = cuts[i], cuts[i+1] shot_indices[start_latent:end_latent] = i shot_indices = shot_indices.unsqueeze(0).to(device=pipe.device) return {"shot_indices": shot_indices} class WanVideoUnit_CfgMerger(PipelineUnit): def __init__(self): super().__init__(take_over=True) self.concat_tensor_names = ["context", "clip_feature", "y", "reference_latents"] def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): if not inputs_shared["cfg_merge"]: return inputs_shared, inputs_posi, inputs_nega for name in self.concat_tensor_names: tensor_posi = inputs_posi.get(name) tensor_nega = inputs_nega.get(name) tensor_shared = inputs_shared.get(name) if tensor_posi is not None and tensor_nega is not None: inputs_shared[name] = torch.concat((tensor_posi, tensor_nega), dim=0) elif tensor_shared is not None: inputs_shared[name] = torch.concat((tensor_shared, tensor_shared), dim=0) inputs_posi.clear() inputs_nega.clear() return inputs_shared, inputs_posi, inputs_nega class TeaCache: def __init__(self, num_inference_steps, rel_l1_thresh, model_id): self.num_inference_steps = num_inference_steps self.step = 0 self.accumulated_rel_l1_distance = 0 self.previous_modulated_input = None self.rel_l1_thresh = rel_l1_thresh self.previous_residual = None self.previous_hidden_states = None self.coefficients_dict = { "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02], "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01], "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01], "Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02], } if model_id not in self.coefficients_dict: supported_model_ids = ", ".join([i for i in self.coefficients_dict]) raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).") self.coefficients = self.coefficients_dict[model_id] def check(self, dit: WanModel, x, t_mod): modulated_inp = t_mod.clone() if self.step == 0 or self.step == self.num_inference_steps - 1: should_calc = True self.accumulated_rel_l1_distance = 0 else: coefficients = self.coefficients rescale_func = np.poly1d(coefficients) self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) if self.accumulated_rel_l1_distance < self.rel_l1_thresh: should_calc = False else: should_calc = True self.accumulated_rel_l1_distance = 0 self.previous_modulated_input = modulated_inp self.step += 1 if self.step == self.num_inference_steps: self.step = 0 if should_calc: self.previous_hidden_states = x.clone() return not should_calc def store(self, hidden_states): self.previous_residual = hidden_states - self.previous_hidden_states self.previous_hidden_states = None def update(self, hidden_states): hidden_states = hidden_states + self.previous_residual return hidden_states class TemporalTiler_BCTHW: def __init__(self): pass def build_1d_mask(self, length, left_bound, right_bound, border_width): x = torch.ones((length,)) if border_width == 0: return x shift = 0.5 if not left_bound: x[:border_width] = (torch.arange(border_width) + shift) / border_width if not right_bound: x[-border_width:] = torch.flip((torch.arange(border_width) + shift) / border_width, dims=(0,)) return x def build_mask(self, data, is_bound, border_width): _, _, T, _, _ = data.shape t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0]) mask = repeat(t, "T -> 1 1 T 1 1") return mask def run(self, model_fn, sliding_window_size, sliding_window_stride, computation_device, computation_dtype, model_kwargs, tensor_names, batch_size=None): tensor_names = [tensor_name for tensor_name in tensor_names if model_kwargs.get(tensor_name) is not None] tensor_dict = {tensor_name: model_kwargs[tensor_name] for tensor_name in tensor_names} B, C, T, H, W = tensor_dict[tensor_names[0]].shape if batch_size is not None: B *= batch_size data_device, data_dtype = tensor_dict[tensor_names[0]].device, tensor_dict[tensor_names[0]].dtype value = torch.zeros((B, C, T, H, W), device=data_device, dtype=data_dtype) weight = torch.zeros((1, 1, T, 1, 1), device=data_device, dtype=data_dtype) for t in range(0, T, sliding_window_stride): if t - sliding_window_stride >= 0 and t - sliding_window_stride + sliding_window_size >= T: continue t_ = min(t + sliding_window_size, T) model_kwargs.update({ tensor_name: tensor_dict[tensor_name][:, :, t: t_:, :].to(device=computation_device, dtype=computation_dtype) \ for tensor_name in tensor_names }) model_output = model_fn(**model_kwargs).to(device=data_device, dtype=data_dtype) mask = self.build_mask( model_output, is_bound=(t == 0, t_ == T), border_width=(sliding_window_size - sliding_window_stride,) ).to(device=data_device, dtype=data_dtype) value[:, :, t: t_, :, :] += model_output * mask weight[:, :, t: t_, :, :] += mask value /= weight model_kwargs.update(tensor_dict) return value def model_fn_wan_video( args, dit: WanModel, motion_controller: WanMotionControllerModel = None, vace: VaceWanModel = None, latents: torch.Tensor = None, timestep: torch.Tensor = None, context: torch.Tensor = None, clip_feature: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None, reference_latents = None, vace_context = None, vace_scale = 1.0, audio_input: Optional[torch.Tensor] = None, motion_latents: Optional[torch.Tensor] = None, pose_cond: Optional[torch.Tensor] = None, tea_cache: TeaCache = None, use_unified_sequence_parallel: bool = False, motion_bucket_id: Optional[torch.Tensor] = None, sliding_window_size: Optional[int] = None, sliding_window_stride: Optional[int] = None, cfg_merge: bool = False, use_gradient_checkpointing: bool = False, use_gradient_checkpointing_offload: bool = False, control_camera_latents_input = None, fuse_vae_embedding_in_latents: bool = False, num_ref_images=None, prev_video_latents: Optional[torch.Tensor] = None, **kwargs, ): if sliding_window_size is not None and sliding_window_stride is not None: model_kwargs = dict( dit=dit, motion_controller=motion_controller, vace=vace, latents=latents, timestep=timestep, context=context, clip_feature=clip_feature, y=y, reference_latents=reference_latents, vace_context=vace_context, vace_scale=vace_scale, tea_cache=tea_cache, use_unified_sequence_parallel=use_unified_sequence_parallel, motion_bucket_id=motion_bucket_id, ) return TemporalTiler_BCTHW().run( model_fn_wan_video, sliding_window_size, sliding_window_stride, latents.device, latents.dtype, model_kwargs=model_kwargs, tensor_names=["latents", "y"], batch_size=2 if cfg_merge else 1 ) if use_unified_sequence_parallel: import torch.distributed as dist from xfuser.core.distributed import (get_sequence_parallel_rank, get_sequence_parallel_world_size, get_sp_group) # Timestep if dit.seperated_timestep and fuse_vae_embedding_in_latents: timestep = torch.concat([ torch.ones((latents.shape[2] - num_ref_images, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device) * timestep, torch.zeros((num_ref_images, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device) ]).flatten() t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0)) if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: t_chunks = torch.chunk(t, get_sequence_parallel_world_size(), dim=1) t_chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, t_chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in t_chunks] t = t_chunks[get_sequence_parallel_rank()] t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim)) else: t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)) # Motion Controller if motion_bucket_id is not None and motion_controller is not None: t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim)) context = dit.text_embedding(context) x = latents # Merged cfg if x.shape[0] != context.shape[0]: x = torch.concat([x] * context.shape[0], dim=0) if timestep.shape[0] != context.shape[0]: timestep = torch.concat([timestep] * context.shape[0], dim=0) # Image Embedding if y is not None and dit.require_vae_embedding: x = torch.cat([x, y], dim=1) if clip_feature is not None and dit.require_clip_embedding: clip_embdding = dit.img_emb(clip_feature) context = torch.cat([clip_embdding, context], dim=1) # Add camera control x, (f, h, w) = dit.patchify(x, control_camera_latents_input) # Reference image if reference_latents is not None: if len(reference_latents.shape) == 5: reference_latents = reference_latents[:, :, 0] reference_latents = dit.ref_conv(reference_latents).flatten(2).transpose(1, 2) x = torch.concat([reference_latents, x], dim=1) f += 1 if args.shot_rope: device = dit.shot_freqs[0].device freq_s, freq_f, freq_h, freq_w = dit.shot_freqs # (end, dim_*/2) complex shots_nums_batch = [ [20, 20, 20, 3, 3], [20, 20, 20, 3, 3], ] batch_freqs = [] # ⭐ 每个 sample 一个 freqs for shots_nums in shots_nums_batch: # loop over batch sample_freqs = [] # 当前 sample 的所有 shot freqs for shot_index, num_frames in enumerate(shots_nums): f = num_frames rope_s = freq_s[shot_index] \ .view(1, 1, 1, -1) \ .expand(f, h, w, -1) rope_f = freq_f[:f] \ .view(f, 1, 1, -1) \ .expand(f, h, w, -1) rope_h = freq_h[:h] \ .view(1, h, 1, -1) \ .expand(f, h, w, -1) rope_w = freq_w[:w] \ .view(1, 1, w, -1) \ .expand(f, h, w, -1) freqs = torch.cat( [rope_s, rope_f, rope_h, rope_w], dim=-1 ) # (f, h, w, dim/2) complex freqs = freqs.reshape(f * h * w, 1, -1) sample_freqs.append(freqs) # 拼一个 sample 内所有 shot sample_freqs = torch.cat(sample_freqs, dim=0) # (N, 1, dim/2) batch_freqs.append(sample_freqs) # ⭐ stack 成 batch batch_freqs = torch.stack(batch_freqs, dim=0).to(x.device) # shape: (B, N, 1, dim/2) if args.split_rope: device = dit.freqs[0].device freq_f, freq_h, freq_w = dit.freqs # 预先计算好的 1D rope freqs # ============================== # 1) Video 的 RoPE 位置 # ============================== f_video = torch.arange(f - num_ref_images, device=device) h_video = torch.arange(h, device=device) w_video = torch.arange(w, device=device) rope_f_video = freq_f[f_video].view(f - num_ref_images, 1, 1, -1).expand(f - num_ref_images, h, w, -1) rope_h_video = freq_h[h_video].view(1, h, 1, -1).expand(f - num_ref_images, h, w, -1) rope_w_video = freq_w[w_video].view(1, 1, w, -1).expand(f - num_ref_images, h, w, -1) rope_video = torch.cat([rope_f_video, rope_h_video, rope_w_video], dim=-1) rope_video = rope_video.reshape((f - num_ref_images) * h * w, 1, -1).to(x.device) # ============================== # 2) Reference Images 的 RoPE 位置(全部偏移) # ============================== # f 维: ref 占用 [offset ... offset + num_ref_images - 1] offset=f - num_ref_images + 10 if args.split1: # method 1: f h w 全 offset f_ref = torch.arange(num_ref_images, device=device) + offset # h/w 全部偏移 offset h_ref = torch.arange(h, device=device) + offset w_ref = torch.arange(w, device=device) + offset elif args.split2: # method 2: f offset f_ref = torch.arange(num_ref_images, device=device) + offset # h/w 全部偏移 offset h_ref = torch.arange(h, device=device) w_ref = torch.arange(w, device=device) elif args.split3: # method 3: f offset but same h w offset f_ref = torch.tensor([0, 0, 0], device=device) + offset # h/w 全部偏移 offset h_ref = torch.arange(h, device=device) + offset w_ref = torch.arange(w, device=device) + offset rope_f_ref = freq_f[f_ref].view(num_ref_images, 1, 1, -1).expand(num_ref_images, h, w, -1) rope_h_ref = freq_h[h_ref].view(1, h, 1, -1).expand(num_ref_images, h, w, -1) rope_w_ref = freq_w[w_ref].view(1, 1, w, -1).expand(num_ref_images, h, w, -1) rope_ref = torch.cat([rope_f_ref, rope_h_ref, rope_w_ref], dim=-1) rope_ref = rope_ref.reshape(num_ref_images * h * w, 1, -1).to(x.device) # ============================== # 3) 拼接 video + ref-image # ============================== freqs = torch.cat([rope_video, rope_ref], dim=0) else: freqs = torch.cat([ dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) # TeaCache if tea_cache is not None: tea_cache_update = tea_cache.check(dit, x, t_mod) else: tea_cache_update = False if vace_context is not None: vace_hints = vace(x, vace_context, context, t_mod, freqs) ## 构造一个 attention mask,使得每个 video token 只能 attend 自己所属 shot 的 text tokens,其它全部强制屏蔽。在 cross attention 过程中 use_attn_mask = True if use_attn_mask: shot_ranges = [ (s0, e0), # shot 0 的 text (s1, e1), # shot 1 的 text ] try: B, S_q = x.shape[0], x.shape[1] L_text_ctx = context.shape[1] shot_ranges = text_cut_positions['shots'] S_shots = len(shot_ranges) device, dtype = x.device, x.dtype # -------------------------------------------------- # 1. 构建 shot_table: (S_shots, L_text_ctx) # -------------------------------------------------- shot_table = torch.zeros( S_shots, L_text_ctx, dtype=torch.bool, device=device ) for sid, (s0, s1) in enumerate(shot_ranges): s0 = int(s0) s1 = int(s1) shot_table[sid, s0:s1 + 1] = True # -------------------------------------------------- # 2. video token -> shot id # shot_indices: (B, T) # expand to (B, T*h*w) = (B, S_q) # shot_indices 是表示每个video token 属于哪一个shot 的索引 # -------------------------------------------------- vid_shot = shot_indices.repeat_interleave(h * w, dim=1) # sanity check(强烈建议保留) max_shot_id = int(vid_shot.max()) assert max_shot_id < S_shots, \ f"shot index out of bounds: max={max_shot_id}, S_shots={S_shots}" # -------------------------------------------------- # 3. allow mask: (B, S_q, L_text_ctx) # -------------------------------------------------- allow = shot_table[vid_shot] # -------------------------------------------------- # 4. 构建 attention bias # -------------------------------------------------- block_value = -1e4 bias = torch.zeros( B, S_q, L_text_ctx, dtype=dtype, device=device ) bias = bias.masked_fill(~allow, block_value) # attn_mask shape: (B, 1, S_q, L_text_ctx) attn_mask = bias.unsqueeze(1) except Exception as e: print("!!!!!! ERROR FOUND IN SHOT ATTENTION MASK !!!!!!!") raise e else: attn_mask = None use_sparse_self_attn = getattr(dit, 'use_sparse_self_attn', False) if use_sparse_self_attn: shot_latent_indices = shot_indices.repeat_interleave(h * w, dim=1) shot_latent_indices = labels_to_cuts(shot_latent_indices) else: shot_latent_indices = None # blocks if use_unified_sequence_parallel: if dist.is_initialized() and dist.get_world_size() > 1: chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1) pad_shape = chunks[0].shape[1] - chunks[-1].shape[1] chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks] x = chunks[get_sequence_parallel_rank()] if tea_cache_update: x = tea_cache.update(x) else: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward for block_id, block in enumerate(dit.blocks): if use_gradient_checkpointing_offload: with torch.autograd.graph.save_on_cpu(): x = torch.utils.checkpoint.checkpoint( create_custom_forward(block), x, context, t_mod, freqs, use_reentrant=False, ) elif use_gradient_checkpointing: x = torch.utils.checkpoint.checkpoint( create_custom_forward(block), x, context, t_mod, freqs, use_reentrant=False, ) else: x = block(x, context, t_mod, freqs) if vace_context is not None and block_id in vace.vace_layers_mapping: current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]] if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0) x = x + current_vace_hint * vace_scale if tea_cache is not None: tea_cache.store(x) x = dit.head(x, t) if use_unified_sequence_parallel: if dist.is_initialized() and dist.get_world_size() > 1: x = get_sp_group().all_gather(x, dim=1) x = x[:, :-pad_shape] if pad_shape > 0 else x # Remove reference latents if reference_latents is not None: x = x[:, reference_latents.shape[1]:] f -= 1 x = dit.unpatchify(x, (f, h, w)) return x def labels_to_cuts(batch_labels: torch.Tensor): assert batch_labels.dim() == 2, "expect [b, s]" b, s = batch_labels.shape labs = batch_labels.to(torch.long) diffs = torch.zeros((b, s), dtype=torch.bool, device=labs.device) diffs[:, 1:] = labs[:, 1:] != labs[:, :-1] cuts_list = [] for i in range(b): change_pos = torch.nonzero(diffs[i], as_tuple=False).flatten() cuts = [0] cuts.extend(change_pos.tolist()) if cuts[-1] != s: cuts.append(s) cuts_list.append(cuts) return cuts_list