import json
import types

import numpy as np
import torch
import safetensors.torch
from PIL import Image
from einops import rearrange, repeat
from tqdm import tqdm
from typing import Optional, Union
from typing_extensions import Literal
from transformers import Wav2Vec2Processor

from ..core import ModelConfig, gradient_checkpoint_forward
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
from ..models.wan_video_dit import WanModel, sinusoidal_embedding_1d
from ..models.wan_video_dit_s2v import rope_precompute
from ..models.wan_video_text_encoder import WanTextEncoder, HuggingfaceTokenizer
from ..models.wan_video_vae import WanVideoVAE
from ..models.wan_video_image_encoder import WanImageEncoder
from ..models.wan_video_vace import VaceWanModel
from ..models.wan_video_motion_controller import WanMotionControllerModel
from ..models.wan_video_animate_adapter import WanAnimateAdapter
from ..models.wan_video_mot import MotWanModel
from ..models.wav2vec import WanS2VAudioEncoder
from ..models.longcat_video_dit import LongCatVideoTransformer3DModel
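
# Thin wrapper around safetensors: load a full checkpoint onto CPU as a plain dict.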
def load_file(path):
    state_dict = safetensors.torch.load_file(path, device="cpu")
    return dict(state_dict)
class WanVideoPipeline(BasePipeline):

    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
        super().__init__(
            device=device, torch_dtype=torch_dtype,
            height_division_factor=16, width_division_factor=16,
            time_division_factor=4, time_division_remainder=1,
        )
        self.scheduler = FlowMatchScheduler("Wan")
        self.tokenizer: HuggingfaceTokenizer = None
        self.audio_processor: Wav2Vec2Processor = None
        self.text_encoder: WanTextEncoder = None
        self.image_encoder: WanImageEncoder = None
        self.dit: WanModel = None
        self.dit2: WanModel = None
        self.vae: WanVideoVAE = None
        self.motion_controller: WanMotionControllerModel = None
        self.vace: VaceWanModel = None
        self.vace2: VaceWanModel = None
        self.vap: MotWanModel = None
        self.animate_adapter: WanAnimateAdapter = None
        self.audio_encoder: WanS2VAudioEncoder = None
        self.in_iteration_models = ("dit", "motion_controller", "vace", "animate_adapter", "vap")
        self.in_iteration_models_2 = ("dit2", "motion_controller", "vace2", "animate_adapter", "vap")
        self.units = [
            WanVideoUnit_ShapeChecker(),
            WanVideoUnit_NoiseInitializer(),
            WanVideoUnit_PromptEmbedder(),
            WanVideoUnit_S2V(),
            WanVideoUnit_InputVideoEmbedder(),
            WanVideoUnit_ImageEmbedderVAE(),
            WanVideoUnit_ImageEmbedderCLIP(),
            WanVideoUnit_ImageEmbedderFused(),
            WanVideoUnit_FunControl(),
            WanVideoUnit_FunReference(),
            WanVideoUnit_FunCameraControl(),
            WanVideoUnit_SpeedControl(),
            WanVideoUnit_VACE(),
            WanVideoUnit_AnimateVideoSplit(),
            WanVideoUnit_AnimatePoseLatents(),
            WanVideoUnit_AnimateFacePixelValues(),
            WanVideoUnit_AnimateInpaint(),
            WanVideoUnit_VAP(),
            WanVideoUnit_UnifiedSequenceParallel(),
            WanVideoUnit_TeaCache(),
            WanVideoUnit_CfgMerger(),
            WanVideoUnit_LongCatVideo(),
        ]
        self.post_units = [
            WanVideoPostUnit_S2V(),
        ]
        self.model_fn = model_fn_wan_video

    def enable_usp(self):
        from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward, usp_dit_forward
        for block in self.dit.blocks:
            block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
        self.dit.forward = types.MethodType(usp_dit_forward, self.dit)
        if self.dit2 is not None:
            for block in self.dit2.blocks:
                block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
            self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2)
        self.sp_size = get_sequence_parallel_world_size()
        self.use_unified_sequence_parallel = True
    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device] = get_device_type(),
        model_configs: list[ModelConfig] = [],
        tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
        audio_processor_config: ModelConfig = None,
        redirect_common_files: bool = True,
        use_usp: bool = False,
        vram_limit: float = None,
        wan_paths: list[str] = [],
        wan_config_path: str = None,
    ):
        # Redirect model paths so shared files are downloaded only once
        if redirect_common_files:
            redirect_dict = {
                "models_t5_umt5-xxl-enc-bf16.pth": ("DiffSynth-Studio/Wan-Series-Converted-Safetensors", "models_t5_umt5-xxl-enc-bf16.safetensors"),
                "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth": ("DiffSynth-Studio/Wan-Series-Converted-Safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14.safetensors"),
                "Wan2.1_VAE.pth": ("DiffSynth-Studio/Wan-Series-Converted-Safetensors", "Wan2.1_VAE.safetensors"),
                "Wan2.2_VAE.pth": ("DiffSynth-Studio/Wan-Series-Converted-Safetensors", "Wan2.2_VAE.safetensors"),
            }
            for model_config in model_configs:
                if model_config.origin_file_pattern is None or model_config.model_id is None:
                    continue
                if model_config.origin_file_pattern in redirect_dict and model_config.model_id != redirect_dict[model_config.origin_file_pattern][0]:
                    print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to {redirect_dict[model_config.origin_file_pattern]}. You can use `redirect_common_files=False` to disable file redirection.")
                    model_config.model_id = redirect_dict[model_config.origin_file_pattern][0]
                    model_config.origin_file_pattern = redirect_dict[model_config.origin_file_pattern][1]
        if use_usp:
            from ..utils.xfuser import initialize_usp
            initialize_usp(device)
            import torch.distributed as dist
            from ..core.device.npu_compatible_device import get_device_name
            if dist.is_available() and dist.is_initialized():
                device = get_device_name()
        # Initialize pipeline
        pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
        model_pool = pipe.download_and_load_models(model_configs, vram_limit)
        # Fetch models
        pipe.text_encoder = model_pool.fetch_model("wan_video_text_encoder")
        # Build the DiT from a local config and load its weights from `wan_paths`
        print("Loading Wan DiT config ...")
        with open(wan_config_path, "r") as f:
            config = json.load(f)
        dit = WanModel(**config)
        print("Loading Wan DiT weights ...")
        dit_state_dict = {}
        for each in wan_paths:
            dit_state_dict.update(load_file(each))
        missing, unexpected = dit.load_state_dict(dit_state_dict, strict=False)
        # Zero-initialize any parameters or buffers absent from the checkpoint
        with torch.no_grad():
            miss = set(missing)
            for name, p in dit.named_parameters():
                if name in miss:
                    p.zero_()
            for name, b in dit.named_buffers():
                if name in miss:
                    if b.is_floating_point() or b.is_complex():
                        b.zero_()
                    else:
                        b.fill_(0)
        print("Wan DiT weights loaded.")
        pipe.dit = dit.to(torch.bfloat16)
        # dit = model_pool.fetch_model("wan_video_dit", index=2)
        # if isinstance(dit, list):
        #     pipe.dit, pipe.dit2 = dit
        # else:
        #     pipe.dit = dit
        pipe.vae = model_pool.fetch_model("wan_video_vae")
        pipe.image_encoder = model_pool.fetch_model("wan_video_image_encoder")
        pipe.motion_controller = model_pool.fetch_model("wan_video_motion_controller")
        vace = model_pool.fetch_model("wan_video_vace", index=2)
        if isinstance(vace, list):
            pipe.vace, pipe.vace2 = vace
        else:
            pipe.vace = vace
        pipe.vap = model_pool.fetch_model("wan_video_vap")
        pipe.audio_encoder = model_pool.fetch_model("wans2v_audio_encoder")
        pipe.animate_adapter = model_pool.fetch_model("wan_video_animate_adapter")
        # Size division factor
        if pipe.vae is not None:
            pipe.height_division_factor = pipe.vae.upsampling_factor * 2
            pipe.width_division_factor = pipe.vae.upsampling_factor * 2
        # Initialize tokenizer and processor
        if tokenizer_config is not None:
            tokenizer_config.download_if_necessary()
            pipe.tokenizer = HuggingfaceTokenizer(name=tokenizer_config.path, seq_len=512, clean='whitespace')
        if audio_processor_config is not None:
            audio_processor_config.download_if_necessary()
            pipe.audio_processor = Wav2Vec2Processor.from_pretrained(audio_processor_config.path)
        # Unified Sequence Parallel
        if use_usp:
            pipe.enable_usp()
        # VRAM Management
        pipe.vram_management_enabled = pipe.check_vram_management_state()
        return pipe
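
    # A minimal usage sketch (illustrative only; the model ID, file names, and
    # prompt below are assumptions, not pinned by this file):
    #
    #   pipe = WanVideoPipeline.from_pretrained(
    #       torch_dtype=torch.bfloat16,
    #       model_configs=[ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth")],
    #       wan_config_path="config.json",
    #       wan_paths=["model-00001.safetensors"],
    #   )
    #   video = pipe(prompt="a cat surfing", num_frames=81, height=480, width=832)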
    def __call__(
        self,
        # Prompt
        prompt: str,
        negative_prompt: Optional[str] = "",
        # Image-to-video
        input_image: Optional[Image.Image] = None,
        # First-last-frame-to-video
        end_image: Optional[Image.Image] = None,
        # Source/target video pair
        src_video: Optional[list[Image.Image]] = None,
        tgt_video: Optional[list[Image.Image]] = None,
        # Video-to-video
        input_video: Optional[list[Image.Image]] = None,
        denoising_strength: Optional[float] = 1.0,
        # Speech-to-video
        input_audio: Optional[np.ndarray] = None,
        audio_embeds: Optional[torch.Tensor] = None,
        audio_sample_rate: Optional[int] = 16000,
        s2v_pose_video: Optional[list[Image.Image]] = None,
        s2v_pose_latents: Optional[torch.Tensor] = None,
        motion_video: Optional[list[Image.Image]] = None,
        # ControlNet
        control_video: Optional[list[Image.Image]] = None,
        reference_image: Optional[Image.Image] = None,
        # Camera control
        camera_control_direction: Optional[Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"]] = None,
        camera_control_speed: Optional[float] = 1/54,
        camera_control_origin: Optional[tuple] = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0),
        # VACE
        vace_video: Optional[list[Image.Image]] = None,
        vace_video_mask: Optional[Image.Image] = None,
        vace_reference_image: Optional[Image.Image] = None,
        vace_scale: Optional[float] = 1.0,
        # Animate
        animate_pose_video: Optional[list[Image.Image]] = None,
        animate_face_video: Optional[list[Image.Image]] = None,
        animate_inpaint_video: Optional[list[Image.Image]] = None,
        animate_mask_video: Optional[list[Image.Image]] = None,
        # VAP
        vap_video: Optional[list[Image.Image]] = None,
        vap_prompt: Optional[str] = " ",
        negative_vap_prompt: Optional[str] = " ",
        # Randomness
        seed: Optional[int] = None,
        rand_device: Optional[str] = "cpu",
        # Shape
        height: Optional[int] = 480,
        width: Optional[int] = 832,
        num_frames: Optional[int] = 81,
        # Classifier-free guidance
        cfg_scale: Optional[float] = 5.0,
        cfg_merge: Optional[bool] = False,
        # Boundary
        switch_DiT_boundary: Optional[float] = 0.875,
        # Scheduler
        num_inference_steps: Optional[int] = 50,
        sigma_shift: Optional[float] = 5.0,
        # Speed control
        motion_bucket_id: Optional[int] = None,
        # LongCat-Video
        longcat_video: Optional[list[Image.Image]] = None,
        # VAE tiling
        tiled: Optional[bool] = True,
        tile_size: Optional[tuple[int, int]] = (30, 52),
        tile_stride: Optional[tuple[int, int]] = (15, 26),
        # Sliding window
        sliding_window_size: Optional[int] = None,
        sliding_window_stride: Optional[int] = None,
        # TeaCache
        tea_cache_l1_thresh: Optional[float] = None,
        tea_cache_model_id: Optional[str] = "",
        # Progress bar
        progress_bar_cmd=tqdm,
        output_type: Optional[Literal["quantized", "floatpoint"]] = "quantized",
    ):
        # Scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
        # Inputs
        inputs_posi = {
            "prompt": prompt,
            "vap_prompt": vap_prompt,
            "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps,
        }
        inputs_nega = {
            "negative_prompt": negative_prompt,
            "negative_vap_prompt": negative_vap_prompt,
            "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps,
        }
        inputs_shared = {
            "input_image": input_image,
            "end_image": end_image,
            "src_video": src_video, "tgt_video": tgt_video,
            "input_video": input_video, "denoising_strength": denoising_strength,
            "control_video": control_video, "reference_image": reference_image,
            "camera_control_direction": camera_control_direction, "camera_control_speed": camera_control_speed, "camera_control_origin": camera_control_origin,
            "vace_video": vace_video, "vace_video_mask": vace_video_mask, "vace_reference_image": vace_reference_image, "vace_scale": vace_scale,
            "seed": seed, "rand_device": rand_device,
            "height": height, "width": width, "num_frames": num_frames,
            "cfg_scale": cfg_scale, "cfg_merge": cfg_merge,
            "sigma_shift": sigma_shift,
            "motion_bucket_id": motion_bucket_id,
            "longcat_video": longcat_video,
            "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
            "sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride,
            "input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "audio_embeds": audio_embeds, "s2v_pose_latents": s2v_pose_latents, "motion_video": motion_video,
            "animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video,
            "vap_video": vap_video,
        }
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
        # Denoise
        self.load_models_to_device(self.in_iteration_models)
        tgt_latent_length = inputs_shared["latents"].shape[2]
        models = {name: getattr(self, name) for name in self.in_iteration_models}
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            # Switch DiT if necessary
            if timestep.item() < switch_DiT_boundary * 1000 and self.dit2 is not None and models["dit"] is not self.dit2:
                self.load_models_to_device(self.in_iteration_models_2)
                models["dit"] = self.dit2
                models["vace"] = self.vace2
            # Timestep
            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
            inputs_shared["latents"] = torch.cat([inputs_shared["latents"], inputs_shared["input_latents"]], dim=2)
            # Inference
            noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep)
            if cfg_scale != 1.0:
                if cfg_merge:
                    noise_pred_posi, noise_pred_nega = noise_pred_posi.chunk(2, dim=0)
                else:
                    noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep)
                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
            else:
                noise_pred = noise_pred_posi
            # Scheduler
            inputs_shared["latents"] = self.scheduler.step(noise_pred[:, :, :tgt_latent_length, ...], self.scheduler.timesteps[progress_id], inputs_shared["latents"][:, :, :tgt_latent_length, ...])
            if "first_frame_latents" in inputs_shared:
                inputs_shared["latents"][:, :, 0:1] = inputs_shared["first_frame_latents"]
        # VACE (TODO: remove it)
        if vace_reference_image is not None or (animate_pose_video is not None and animate_face_video is not None):
            if vace_reference_image is not None and isinstance(vace_reference_image, list):
                f = len(vace_reference_image)
            else:
                f = 1
            inputs_shared["latents"] = inputs_shared["latents"][:, :, f:]
        # Post-denoising, pre-decoding processing logic
        for unit in self.post_units:
            inputs_shared, _, _ = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
        # Decode
        self.load_models_to_device(['vae'])
        video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        if output_type == "quantized":
            video = self.vae_output_to_video(video)
        elif output_type == "floatpoint":
            pass
        self.load_models_to_device([])
        return video
class WanVideoUnit_ShapeChecker(PipelineUnit):

    def __init__(self):
        super().__init__(
            input_params=("height", "width", "num_frames"),
            output_params=("height", "width", "num_frames"),
        )

    def process(self, pipe: WanVideoPipeline, height, width, num_frames):
        height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames)
        return {"height": height, "width": width, "num_frames": num_frames}
class WanVideoUnit_NoiseInitializer(PipelineUnit):

    def __init__(self):
        super().__init__(
            input_params=("height", "width", "num_frames", "seed", "rand_device", "vace_reference_image"),
            output_params=("noise",)
        )

    def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image):
        length = (num_frames - 1) // 4 + 1
        if vace_reference_image is not None:
            f = len(vace_reference_image) if isinstance(vace_reference_image, list) else 1
            length += f
        shape = (1, pipe.vae.model.z_dim, length, height // pipe.vae.upsampling_factor, width // pipe.vae.upsampling_factor)
        noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device)
        if vace_reference_image is not None:
            noise = torch.concat((noise[:, :, -f:], noise[:, :, :-f]), dim=2)
        return {"noise": noise}
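
# Latent shape arithmetic used above, for reference: the VAE compresses time
# 4x (plus the leading frame) and space by `upsampling_factor` per side. With
# the defaults num_frames=81, height=480, width=832 and upsampling_factor=8
# (implied by the 16-pixel division factor set in from_pretrained), the noise
# tensor is (1, z_dim, 21, 60, 104).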
class WanVideoUnit_InputVideoEmbedder(PipelineUnit):

    def __init__(self):
        super().__init__(
            input_params=("src_video", "tgt_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image"),
            output_params=("latents", "input_latents"),
            onload_model_names=("vae",)
        )

    def process(self, pipe: WanVideoPipeline, src_video, tgt_video, noise, tiled, tile_size, tile_stride, vace_reference_image):
        if src_video is None:
            return {"latents": noise}
        pipe.load_models_to_device(self.onload_model_names)
        src_video = pipe.preprocess_video(src_video)
        src_latents = pipe.vae.encode(src_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
        if tgt_video is not None:
            tgt_video = pipe.preprocess_video(tgt_video)
            tgt_latents = pipe.vae.encode(tgt_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
            input_latents = torch.concat([tgt_latents, src_latents], dim=2)
        if vace_reference_image is not None:
            if not isinstance(vace_reference_image, list):
                vace_reference_image = [vace_reference_image]
            vace_reference_image = pipe.preprocess_video(vace_reference_image)
            vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device)
            input_latents = torch.concat([vace_reference_latents, input_latents], dim=2)
        if pipe.scheduler.training:
            return {"latents": noise, "input_latents": input_latents}
        else:
            return {"latents": noise, "input_latents": src_latents}
class WanVideoUnit_PromptEmbedder(PipelineUnit):

    def __init__(self):
        super().__init__(
            seperate_cfg=True,
            input_params_posi={"prompt": "prompt", "positive": "positive"},
            input_params_nega={"prompt": "negative_prompt", "positive": "positive"},
            output_params=("context",),
            onload_model_names=("text_encoder",)
        )

    def encode_prompt(self, pipe: WanVideoPipeline, prompt):
        ids, mask = pipe.tokenizer(prompt, return_mask=True, add_special_tokens=True)
        ids = ids.to(pipe.device)
        mask = mask.to(pipe.device)
        seq_lens = mask.gt(0).sum(dim=1).long()
        prompt_emb = pipe.text_encoder(ids, mask)
        for i, v in enumerate(seq_lens):
            prompt_emb[:, v:] = 0
        return prompt_emb

    def process(self, pipe: WanVideoPipeline, prompt, positive) -> dict:
        pipe.load_models_to_device(self.onload_model_names)
        prompt_emb = self.encode_prompt(pipe, prompt)
        return {"context": prompt_emb}
class WanVideoUnit_ImageEmbedderCLIP(PipelineUnit):

    def __init__(self):
        super().__init__(
            input_params=("input_image", "end_image", "height", "width"),
            output_params=("clip_feature",),
            onload_model_names=("image_encoder",)
        )

    def process(self, pipe: WanVideoPipeline, input_image, end_image, height, width):
        if input_image is None or pipe.image_encoder is None or not pipe.dit.require_clip_embedding:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
        clip_context = pipe.image_encoder.encode_image([image])
        if end_image is not None:
            end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
            if pipe.dit.has_image_pos_emb:
                clip_context = torch.concat([clip_context, pipe.image_encoder.encode_image([end_image])], dim=1)
        clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"clip_feature": clip_context}
class WanVideoUnit_ImageEmbedderVAE(PipelineUnit):

    def __init__(self):
        super().__init__(
            input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
            output_params=("y",),
            onload_model_names=("vae",)
        )

    def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
        if input_image is None or not pipe.dit.require_vae_embedding:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
        msk = torch.ones(1, num_frames, height // 8, width // 8, device=pipe.device)
        msk[:, 1:] = 0
        if end_image is not None:
            end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
            vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames - 2, height, width).to(image.device), end_image.transpose(0, 1)], dim=1)
            msk[:, -1:] = 1
        else:
            vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames - 1, height, width).to(image.device)], dim=1)
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, height // 8, width // 8)
        msk = msk.transpose(1, 2)[0]
        y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
        y = torch.concat([msk, y])
        y = y.unsqueeze(0)
        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"y": y}
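
# Note on the mask construction above: repeating the first mask frame 4x makes
# the total frame count divisible by 4, so after the view/transpose the mask's
# temporal length matches the VAE latent length ((num_frames - 1) // 4 + 1)
# with 4 mask channels, mirroring the Wan I2V conditioning layout.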
class WanVideoUnit_ImageEmbedderFused(PipelineUnit):
    """
    Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B.
    """

    def __init__(self):
        super().__init__(
            input_params=("input_image", "latents", "height", "width", "tiled", "tile_size", "tile_stride"),
            output_params=("latents", "fuse_vae_embedding_in_latents", "first_frame_latents"),
            onload_model_names=("vae",)
        )

    def process(self, pipe: WanVideoPipeline, input_image, latents, height, width, tiled, tile_size, tile_stride):
        if input_image is None or not pipe.dit.fuse_vae_embedding_in_latents:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1)
        z = pipe.vae.encode([image], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        latents[:, :, 0:1] = z
        return {"latents": latents, "fuse_vae_embedding_in_latents": True, "first_frame_latents": z}
class WanVideoUnit_FunControl(PipelineUnit):

    def __init__(self):
        super().__init__(
            input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y", "latents"),
            output_params=("clip_feature", "y"),
            onload_model_names=("vae",)
        )

    def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y, latents):
        if control_video is None:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        control_video = pipe.preprocess_video(control_video)
        control_latents = pipe.vae.encode(control_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
        y_dim = pipe.dit.in_dim - control_latents.shape[1] - latents.shape[1]
        if clip_feature is None or y is None:
            clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device)
            y = torch.zeros((1, y_dim, (num_frames - 1) // 4 + 1, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device)
        else:
            y = y[:, -y_dim:]
        y = torch.concat([control_latents, y], dim=1)
        return {"clip_feature": clip_feature, "y": y}
class WanVideoUnit_FunReference(PipelineUnit):

    def __init__(self):
        super().__init__(
            input_params=("reference_image", "height", "width"),
            output_params=("reference_latents", "clip_feature"),
            onload_model_names=("vae", "image_encoder")
        )

    def process(self, pipe: WanVideoPipeline, reference_image, height, width):
        if reference_image is None:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        reference_image = reference_image.resize((width, height))
        reference_latents = pipe.preprocess_video([reference_image])
        reference_latents = pipe.vae.encode(reference_latents, device=pipe.device)
        if pipe.image_encoder is None:
            return {"reference_latents": reference_latents}
        clip_feature = pipe.preprocess_image(reference_image)
        clip_feature = pipe.image_encoder.encode_image([clip_feature])
        return {"reference_latents": reference_latents, "clip_feature": clip_feature}
class WanVideoUnit_FunCameraControl(PipelineUnit):

    def __init__(self):
        super().__init__(
            input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image", "tiled", "tile_size", "tile_stride"),
            output_params=("control_camera_latents_input", "y"),
            onload_model_names=("vae",)
        )

    def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, camera_control_speed, camera_control_origin, latents, input_image, tiled, tile_size, tile_stride):
        if camera_control_direction is None:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        camera_control_plucker_embedding = pipe.dit.control_adapter.process_camera_coordinates(
            camera_control_direction, num_frames, height, width, camera_control_speed, camera_control_origin)
        control_camera_video = camera_control_plucker_embedding[:num_frames].permute([3, 0, 1, 2]).unsqueeze(0)
        control_camera_latents = torch.concat(
            [
                torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2),
                control_camera_video[:, :, 1:]
            ], dim=2
        ).transpose(1, 2)
        b, f, c, h, w = control_camera_latents.shape
        control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
        control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
        control_camera_latents_input = control_camera_latents.to(device=pipe.device, dtype=pipe.torch_dtype)
        input_image = input_image.resize((width, height))
        input_latents = pipe.preprocess_video([input_image])
        input_latents = pipe.vae.encode(input_latents, device=pipe.device)
        y = torch.zeros_like(latents).to(pipe.device)
        y[:, :, :1] = input_latents
        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
        if y.shape[1] != pipe.dit.in_dim - latents.shape[1]:
            image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
            vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames - 1, height, width).to(image.device)], dim=1)
            y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
            y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
            msk = torch.ones(1, num_frames, height // 8, width // 8, device=pipe.device)
            msk[:, 1:] = 0
            msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
            msk = msk.view(1, msk.shape[1] // 4, 4, height // 8, width // 8)
            msk = msk.transpose(1, 2)[0]
            y = torch.cat([msk, y])
            y = y.unsqueeze(0)
            y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"control_camera_latents_input": control_camera_latents_input, "y": y}
class WanVideoUnit_SpeedControl(PipelineUnit):

    def __init__(self):
        super().__init__(
            input_params=("motion_bucket_id",),
            output_params=("motion_bucket_id",)
        )

    def process(self, pipe: WanVideoPipeline, motion_bucket_id):
        if motion_bucket_id is None:
            return {}
        motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"motion_bucket_id": motion_bucket_id}
class WanVideoUnit_VACE(PipelineUnit):

    def __init__(self):
        super().__init__(
            input_params=("vace_video", "vace_video_mask", "vace_reference_image", "vace_scale", "height", "width", "num_frames", "tiled", "tile_size", "tile_stride"),
            output_params=("vace_context", "vace_scale"),
            onload_model_names=("vae",)
        )

    def process(
        self,
        pipe: WanVideoPipeline,
        vace_video, vace_video_mask, vace_reference_image, vace_scale,
        height, width, num_frames,
        tiled, tile_size, tile_stride
    ):
        if vace_video is None and vace_video_mask is None and vace_reference_image is None:
            return {"vace_context": None, "vace_scale": vace_scale}
        pipe.load_models_to_device(["vae"])
        if vace_video is None:
            vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=pipe.torch_dtype, device=pipe.device)
        else:
            vace_video = pipe.preprocess_video(vace_video)
        if vace_video_mask is None:
            vace_video_mask = torch.ones_like(vace_video)
        else:
            vace_video_mask = pipe.preprocess_video(vace_video_mask, min_value=0, max_value=1)
        # Split the video into masked-out (inactive) and masked-in (reactive) parts
        inactive = vace_video * (1 - vace_video_mask)
        reactive = vace_video * vace_video_mask
        inactive = pipe.vae.encode(inactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
        reactive = pipe.vae.encode(reactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
        vace_video_latents = torch.concat((inactive, reactive), dim=1)
        vace_mask_latents = rearrange(vace_video_mask[0, 0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8)
        vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact')
        if vace_reference_image is not None:
            if not isinstance(vace_reference_image, list):
                vace_reference_image = [vace_reference_image]
            vace_reference_image = pipe.preprocess_video(vace_reference_image)
            bs, c, f, h, w = vace_reference_image.shape
            vace_reference_image = [vace_reference_image[0, :, j:j+1] for j in range(f)]
            vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
            vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1)
            vace_reference_latents = [u.unsqueeze(0) for u in vace_reference_latents]
            vace_video_latents = torch.concat((*vace_reference_latents, vace_video_latents), dim=2)
            vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :f]), vace_mask_latents), dim=2)
        vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1)
        return {"vace_context": vace_context, "vace_scale": vace_scale}
class WanVideoUnit_VAP(PipelineUnit):

    def __init__(self):
        super().__init__(
            take_over=True,
            onload_model_names=("text_encoder", "vae", "image_encoder"),
            input_params=("vap_video", "vap_prompt", "negative_vap_prompt", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
            output_params=("vap_clip_feature", "vap_hidden_state", "context_vap")
        )

    def encode_prompt(self, pipe: WanVideoPipeline, prompt):
        ids, mask = pipe.tokenizer(prompt, return_mask=True, add_special_tokens=True)
        ids = ids.to(pipe.device)
        mask = mask.to(pipe.device)
        seq_lens = mask.gt(0).sum(dim=1).long()
        prompt_emb = pipe.text_encoder(ids, mask)
        for i, v in enumerate(seq_lens):
            prompt_emb[:, v:] = 0
        return prompt_emb

    def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
        if inputs_shared.get("vap_video") is None:
            return inputs_shared, inputs_posi, inputs_nega
        # 1. Encode the VAP prompts
        pipe.load_models_to_device(["text_encoder"])
        vap_prompt, negative_vap_prompt = inputs_posi.get("vap_prompt", ""), inputs_nega.get("negative_vap_prompt", "")
        vap_prompt_emb = self.encode_prompt(pipe, vap_prompt)
        negative_vap_prompt_emb = self.encode_prompt(pipe, negative_vap_prompt)
        inputs_posi.update({"context_vap": vap_prompt_emb})
        inputs_nega.update({"context_vap": negative_vap_prompt_emb})
        # 2. Prepare the VAP image CLIP embedding
        pipe.load_models_to_device(["vae", "image_encoder"])
        vap_video, end_image = inputs_shared.get("vap_video"), inputs_shared.get("end_image")
        num_frames, height, width = inputs_shared.get("num_frames"), inputs_shared.get("height"), inputs_shared.get("width")
        image_vap = pipe.preprocess_image(vap_video[0].resize((width, height))).to(pipe.device)
        vap_clip_context = pipe.image_encoder.encode_image([image_vap])
        if end_image is not None:
            vap_end_image = pipe.preprocess_image(vap_video[-1].resize((width, height))).to(pipe.device)
            if pipe.dit.has_image_pos_emb:
                vap_clip_context = torch.concat([vap_clip_context, pipe.image_encoder.encode_image([vap_end_image])], dim=1)
        vap_clip_context = vap_clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
        inputs_shared.update({"vap_clip_feature": vap_clip_context})
        # 3. Prepare the VAP latents
        msk = torch.ones(1, num_frames, height // 8, width // 8, device=pipe.device)
        msk[:, 1:] = 0
        if end_image is not None:
            msk[:, -1:] = 1
            last_image_vap = pipe.preprocess_image(vap_video[-1].resize((width, height))).to(pipe.device)
            vae_input = torch.concat([image_vap.transpose(0, 1), torch.zeros(3, num_frames - 2, height, width).to(image_vap.device), last_image_vap.transpose(0, 1)], dim=1)
        else:
            vae_input = torch.concat([image_vap.transpose(0, 1), torch.zeros(3, num_frames - 1, height, width).to(image_vap.device)], dim=1)
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, height // 8, width // 8)
        msk = msk.transpose(1, 2)[0]
        tiled, tile_size, tile_stride = inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride")
        y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
        y = torch.concat([msk, y])
        y = y.unsqueeze(0)
        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
        vap_video = pipe.preprocess_video(vap_video)
        vap_latent = pipe.vae.encode(vap_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
        vap_latent = torch.concat([vap_latent, y], dim=1).to(dtype=pipe.torch_dtype, device=pipe.device)
        inputs_shared.update({"vap_hidden_state": vap_latent})
        return inputs_shared, inputs_posi, inputs_nega
class WanVideoUnit_UnifiedSequenceParallel(PipelineUnit):

    def __init__(self):
        super().__init__(input_params=(), output_params=("use_unified_sequence_parallel",))

    def process(self, pipe: WanVideoPipeline):
        if getattr(pipe, "use_unified_sequence_parallel", False):
            return {"use_unified_sequence_parallel": True}
        return {}
class WanVideoUnit_TeaCache(PipelineUnit):

    def __init__(self):
        super().__init__(
            seperate_cfg=True,
            input_params_posi={"num_inference_steps": "num_inference_steps", "tea_cache_l1_thresh": "tea_cache_l1_thresh", "tea_cache_model_id": "tea_cache_model_id"},
            input_params_nega={"num_inference_steps": "num_inference_steps", "tea_cache_l1_thresh": "tea_cache_l1_thresh", "tea_cache_model_id": "tea_cache_model_id"},
            output_params=("tea_cache",)
        )

    def process(self, pipe: WanVideoPipeline, num_inference_steps, tea_cache_l1_thresh, tea_cache_model_id):
        if tea_cache_l1_thresh is None:
            return {}
        return {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id)}
class WanVideoUnit_CfgMerger(PipelineUnit):

    def __init__(self):
        super().__init__(take_over=True)
        self.concat_tensor_names = ["context", "clip_feature", "y", "reference_latents"]

    def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
        if not inputs_shared["cfg_merge"]:
            return inputs_shared, inputs_posi, inputs_nega
        for name in self.concat_tensor_names:
            tensor_posi = inputs_posi.get(name)
            tensor_nega = inputs_nega.get(name)
            tensor_shared = inputs_shared.get(name)
            if tensor_posi is not None and tensor_nega is not None:
                inputs_shared[name] = torch.concat((tensor_posi, tensor_nega), dim=0)
            elif tensor_shared is not None:
                inputs_shared[name] = torch.concat((tensor_shared, tensor_shared), dim=0)
        inputs_posi.clear()
        inputs_nega.clear()
        return inputs_shared, inputs_posi, inputs_nega
class WanVideoUnit_S2V(PipelineUnit):

    def __init__(self):
        super().__init__(
            take_over=True,
            onload_model_names=("audio_encoder", "vae",),
            input_params=("input_audio", "audio_embeds", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "audio_sample_rate", "s2v_pose_video", "s2v_pose_latents", "motion_video"),
            output_params=("audio_embeds", "motion_latents", "drop_motion_frames", "s2v_pose_latents"),
        )

    def process_audio(self, pipe: WanVideoPipeline, input_audio, audio_sample_rate, num_frames, fps=16, audio_embeds=None, return_all=False):
        if audio_embeds is not None:
            return {"audio_embeds": audio_embeds}
        pipe.load_models_to_device(["audio_encoder"])
        audio_embeds = pipe.audio_encoder.get_audio_feats_per_inference(input_audio, audio_sample_rate, pipe.audio_processor, fps=fps, batch_frames=num_frames - 1, dtype=pipe.torch_dtype, device=pipe.device)
        if return_all:
            return audio_embeds
        else:
            return {"audio_embeds": audio_embeds[0]}

    def process_motion_latents(self, pipe: WanVideoPipeline, height, width, tiled, tile_size, tile_stride, motion_video=None):
        pipe.load_models_to_device(["vae"])
        motion_frames = 73
        kwargs = {}
        if motion_video is not None:
            assert motion_video.shape[2] == motion_frames, f"motion video must have {motion_frames} frames, but got {motion_video.shape[2]}"
            motion_latents = motion_video
            kwargs["drop_motion_frames"] = False
        else:
            motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device)
            kwargs["drop_motion_frames"] = True
        motion_latents = pipe.vae.encode(motion_latents, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
        kwargs.update({"motion_latents": motion_latents})
        return kwargs

    def process_pose_cond(self, pipe: WanVideoPipeline, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride, s2v_pose_latents=None, num_repeats=1, return_all=False):
        if s2v_pose_latents is not None:
            return {"s2v_pose_latents": s2v_pose_latents}
        if s2v_pose_video is None:
            return {"s2v_pose_latents": None}
        pipe.load_models_to_device(["vae"])
        infer_frames = num_frames - 1
        input_video = pipe.preprocess_video(s2v_pose_video)[:, :, :infer_frames * num_repeats]
        # Pad if there are not enough frames
        padding_frames = infer_frames * num_repeats - input_video.shape[2]
        input_video = torch.cat([input_video, -torch.ones(1, 3, padding_frames, height, width, device=input_video.device, dtype=input_video.dtype)], dim=2)
        input_videos = input_video.chunk(num_repeats, dim=2)
        pose_conds = []
        for r in range(num_repeats):
            cond = input_videos[r]
            cond = torch.cat([cond[:, :, 0:1].repeat(1, 1, 1, 1, 1), cond], dim=2)
            cond_latents = pipe.vae.encode(cond, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
            pose_conds.append(cond_latents[:, :, 1:])
        if return_all:
            return pose_conds
        else:
            return {"s2v_pose_latents": pose_conds[0]}

    def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
        if (inputs_shared.get("input_audio") is None and inputs_shared.get("audio_embeds") is None) or pipe.audio_encoder is None or pipe.audio_processor is None:
            return inputs_shared, inputs_posi, inputs_nega
        num_frames, height, width, tiled, tile_size, tile_stride = inputs_shared.get("num_frames"), inputs_shared.get("height"), inputs_shared.get("width"), inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride")
        input_audio, audio_embeds, audio_sample_rate = inputs_shared.pop("input_audio", None), inputs_shared.pop("audio_embeds", None), inputs_shared.get("audio_sample_rate", 16000)
        s2v_pose_video, s2v_pose_latents, motion_video = inputs_shared.pop("s2v_pose_video", None), inputs_shared.pop("s2v_pose_latents", None), inputs_shared.pop("motion_video", None)
        audio_input_positive = self.process_audio(pipe, input_audio, audio_sample_rate, num_frames, audio_embeds=audio_embeds)
        inputs_posi.update(audio_input_positive)
        inputs_nega.update({"audio_embeds": 0.0 * audio_input_positive["audio_embeds"]})
        inputs_shared.update(self.process_motion_latents(pipe, height, width, tiled, tile_size, tile_stride, motion_video))
        inputs_shared.update(self.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride, s2v_pose_latents=s2v_pose_latents))
        return inputs_shared, inputs_posi, inputs_nega
def pre_calculate_audio_pose(pipe: WanVideoPipeline, input_audio=None, audio_sample_rate=16000, s2v_pose_video=None, num_frames=81, height=448, width=832, fps=16, tiled=True, tile_size=(30, 52), tile_stride=(15, 26)):
    assert pipe.audio_encoder is not None and pipe.audio_processor is not None, "Please load the audio encoder and audio processor first."
    shapes = WanVideoUnit_ShapeChecker().process(pipe, height, width, num_frames)
    height, width, num_frames = shapes["height"], shapes["width"], shapes["num_frames"]
    unit = WanVideoUnit_S2V()
    audio_embeds = unit.process_audio(pipe, input_audio, audio_sample_rate, num_frames, fps, return_all=True)
    pose_latents = unit.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, num_repeats=len(audio_embeds), return_all=True, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
    pose_latents = None if s2v_pose_video is None else pose_latents
    return audio_embeds, pose_latents, len(audio_embeds)
class WanVideoPostUnit_S2V(PipelineUnit):

    def __init__(self):
        super().__init__(input_params=("latents", "motion_latents", "drop_motion_frames"))

    def process(self, pipe: WanVideoPipeline, latents, motion_latents, drop_motion_frames):
        if pipe.audio_encoder is None or motion_latents is None or drop_motion_frames:
            return {}
        latents = torch.cat([motion_latents, latents[:, :, 1:]], dim=2)
        return {"latents": latents}
class WanVideoUnit_AnimateVideoSplit(PipelineUnit):

    def __init__(self):
        super().__init__(
            input_params=("input_video", "animate_pose_video", "animate_face_video", "animate_inpaint_video", "animate_mask_video"),
            output_params=("animate_pose_video", "animate_face_video", "animate_inpaint_video", "animate_mask_video")
        )

    def process(self, pipe: WanVideoPipeline, input_video, animate_pose_video, animate_face_video, animate_inpaint_video, animate_mask_video):
        if input_video is None:
            return {}
        if animate_pose_video is not None:
            animate_pose_video = animate_pose_video[:len(input_video) - 4]
        if animate_face_video is not None:
            animate_face_video = animate_face_video[:len(input_video) - 4]
        if animate_inpaint_video is not None:
            animate_inpaint_video = animate_inpaint_video[:len(input_video) - 4]
        if animate_mask_video is not None:
            animate_mask_video = animate_mask_video[:len(input_video) - 4]
        return {"animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video}
class WanVideoUnit_AnimatePoseLatents(PipelineUnit):

    def __init__(self):
        super().__init__(
            input_params=("animate_pose_video", "tiled", "tile_size", "tile_stride"),
            output_params=("pose_latents",),
            onload_model_names=("vae",)
        )

    def process(self, pipe: WanVideoPipeline, animate_pose_video, tiled, tile_size, tile_stride):
        if animate_pose_video is None:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        animate_pose_video = pipe.preprocess_video(animate_pose_video)
        pose_latents = pipe.vae.encode(animate_pose_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"pose_latents": pose_latents}
class WanVideoUnit_AnimateFacePixelValues(PipelineUnit):

    def __init__(self):
        super().__init__(
            take_over=True,
            input_params=("animate_face_video",),
            output_params=("face_pixel_values",),
        )

    def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
        if inputs_shared.get("animate_face_video", None) is None:
            return inputs_shared, inputs_posi, inputs_nega
        inputs_posi["face_pixel_values"] = pipe.preprocess_video(inputs_shared["animate_face_video"])
        inputs_nega["face_pixel_values"] = torch.zeros_like(inputs_posi["face_pixel_values"]) - 1
        return inputs_shared, inputs_posi, inputs_nega
class WanVideoUnit_AnimateInpaint(PipelineUnit):

    def __init__(self):
        super().__init__(
            input_params=("animate_inpaint_video", "animate_mask_video", "input_image", "tiled", "tile_size", "tile_stride"),
            output_params=("y",),
            onload_model_names=("vae",)
        )

    def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device=get_device_type()):
        if mask_pixel_values is None:
            msk = torch.zeros(1, (lat_t - 1) * 4 + 1, lat_h, lat_w, device=device)
        else:
            msk = mask_pixel_values.clone()
        msk[:, :mask_len] = 1
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]
        return msk

    def process(self, pipe: WanVideoPipeline, animate_inpaint_video, animate_mask_video, input_image, tiled, tile_size, tile_stride):
        if animate_inpaint_video is None or animate_mask_video is None:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        bg_pixel_values = pipe.preprocess_video(animate_inpaint_video)
        y_reft = pipe.vae.encode(bg_pixel_values, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0].to(dtype=pipe.torch_dtype, device=pipe.device)
        _, lat_t, lat_h, lat_w = y_reft.shape
        ref_pixel_values = pipe.preprocess_video([input_image])
        ref_latents = pipe.vae.encode(ref_pixel_values, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
        mask_ref = self.get_i2v_mask(1, lat_h, lat_w, 1, device=pipe.device)
        y_ref = torch.concat([mask_ref, ref_latents[0]]).to(dtype=torch.bfloat16, device=pipe.device)
        mask_pixel_values = 1 - pipe.preprocess_video(animate_mask_video, max_value=1, min_value=0)
        mask_pixel_values = rearrange(mask_pixel_values, "b c t h w -> (b t) c h w")
        mask_pixel_values = torch.nn.functional.interpolate(mask_pixel_values, size=(lat_h, lat_w), mode='nearest')
        mask_pixel_values = rearrange(mask_pixel_values, "(b t) c h w -> b t c h w", b=1)[:, :, 0]
        msk_reft = self.get_i2v_mask(lat_t, lat_h, lat_w, 0, mask_pixel_values=mask_pixel_values, device=pipe.device)
        y_reft = torch.concat([msk_reft, y_reft]).to(dtype=torch.bfloat16, device=pipe.device)
        y = torch.concat([y_ref, y_reft], dim=1).unsqueeze(0)
        return {"y": y}
class WanVideoUnit_LongCatVideo(PipelineUnit):

    def __init__(self):
        super().__init__(
            input_params=("longcat_video",),
            output_params=("longcat_latents",),
            onload_model_names=("vae",)
        )

    def process(self, pipe: WanVideoPipeline, longcat_video):
        if longcat_video is None:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        longcat_video = pipe.preprocess_video(longcat_video)
        longcat_latents = pipe.vae.encode(longcat_video, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"longcat_latents": longcat_latents}
class TeaCache:

    def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
        self.num_inference_steps = num_inference_steps
        self.step = 0
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.rel_l1_thresh = rel_l1_thresh
        self.previous_residual = None
        self.previous_hidden_states = None
        # Per-model polynomial coefficients (highest degree first) that rescale
        # the relative L1 change of the modulated input.
        self.coefficients_dict = {
            "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
            "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
            "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
            "Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
        }
        if model_id not in self.coefficients_dict:
            supported_model_ids = ", ".join(self.coefficients_dict)
            raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
        self.coefficients = self.coefficients_dict[model_id]

    def check(self, dit: WanModel, x, t_mod):
        modulated_inp = t_mod.clone()
        if self.step == 0 or self.step == self.num_inference_steps - 1:
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            coefficients = self.coefficients
            rescale_func = np.poly1d(coefficients)
            self.accumulated_rel_l1_distance += rescale_func(((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
        self.step += 1
        if self.step == self.num_inference_steps:
            self.step = 0
        if should_calc:
            self.previous_hidden_states = x.clone()
        return not should_calc

    def store(self, hidden_states):
        self.previous_residual = hidden_states - self.previous_hidden_states
        self.previous_hidden_states = None

    def update(self, hidden_states):
        hidden_states = hidden_states + self.previous_residual
        return hidden_states
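
# Intended TeaCache cycle (cf. the `tea_cache` argument threaded through
# model_fn_wan_video below): call `check()` once per step; if it returns True,
# skip the DiT blocks and apply `update()` to reuse the cached residual,
# otherwise run the blocks and `store()` the new residual. A sketch, with
# `run_blocks` as a hypothetical stand-in for the DiT block loop:
#
#   if tea_cache is not None and tea_cache.check(dit, x, t_mod):
#       x = tea_cache.update(x)
#   else:
#       x = run_blocks(x)
#       if tea_cache is not None:
#           tea_cache.store(x)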
| class TemporalTiler_BCTHW: | |
| def __init__(self): | |
| pass | |
| def build_1d_mask(self, length, left_bound, right_bound, border_width): | |
| x = torch.ones((length,)) | |
| if border_width == 0: | |
| return x | |
| shift = 0.5 | |
| if not left_bound: | |
| x[:border_width] = (torch.arange(border_width) + shift) / border_width | |
| if not right_bound: | |
| x[-border_width:] = torch.flip((torch.arange(border_width) + shift) / border_width, dims=(0,)) | |
| return x | |
| def build_mask(self, data, is_bound, border_width): | |
| _, _, T, _, _ = data.shape | |
| t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0]) | |
| mask = repeat(t, "T -> 1 1 T 1 1") | |
| return mask | |
| def run(self, model_fn, sliding_window_size, sliding_window_stride, computation_device, computation_dtype, model_kwargs, tensor_names, batch_size=None): | |
| tensor_names = [tensor_name for tensor_name in tensor_names if model_kwargs.get(tensor_name) is not None] | |
| tensor_dict = {tensor_name: model_kwargs[tensor_name] for tensor_name in tensor_names} | |
| B, C, T, H, W = tensor_dict[tensor_names[0]].shape | |
| if batch_size is not None: | |
| B *= batch_size | |
| data_device, data_dtype = tensor_dict[tensor_names[0]].device, tensor_dict[tensor_names[0]].dtype | |
| value = torch.zeros((B, C, T, H, W), device=data_device, dtype=data_dtype) | |
| weight = torch.zeros((1, 1, T, 1, 1), device=data_device, dtype=data_dtype) | |
| for t in range(0, T, sliding_window_stride): | |
| if t - sliding_window_stride >= 0 and t - sliding_window_stride + sliding_window_size >= T: | |
| continue | |
| t_ = min(t + sliding_window_size, T) | |
| model_kwargs.update({ | |
| tensor_name: tensor_dict[tensor_name][:, :, t: t_:, :].to(device=computation_device, dtype=computation_dtype) \ | |
| for tensor_name in tensor_names | |
| }) | |
| model_output = model_fn(**model_kwargs).to(device=data_device, dtype=data_dtype) | |
| mask = self.build_mask( | |
| model_output, | |
| is_bound=(t == 0, t_ == T), | |
| border_width=(sliding_window_size - sliding_window_stride,) | |
| ).to(device=data_device, dtype=data_dtype) | |
| value[:, :, t: t_, :, :] += model_output * mask | |
| weight[:, :, t: t_, :, :] += mask | |
| value /= weight | |
| model_kwargs.update(tensor_dict) | |
| return value | |
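| # Single-step denoising entry point. It first dispatches to the special paths | |
| # (temporal tiling for long videos, LongCat-Video, Wan2.2 speech-to-video), then | |
| # runs the standard WanModel forward: timestep/text embedding, optional image, | |
| # camera and reference conditioning, the DiT blocks (with optional VACE hints, | |
| # the VAP stream, TeaCache skipping and sequence parallelism), and unpatchify. | |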
| def model_fn_wan_video( | |
| dit: WanModel, | |
| motion_controller: WanMotionControllerModel = None, | |
| vace: VaceWanModel = None, | |
| vap: MotWanModel = None, | |
| animate_adapter: WanAnimateAdapter = None, | |
| latents: torch.Tensor = None, | |
| timestep: torch.Tensor = None, | |
| context: torch.Tensor = None, | |
| clip_feature: Optional[torch.Tensor] = None, | |
| y: Optional[torch.Tensor] = None, | |
| reference_latents = None, | |
| vace_context = None, | |
| vace_scale = 1.0, | |
| audio_embeds: Optional[torch.Tensor] = None, | |
| motion_latents: Optional[torch.Tensor] = None, | |
| s2v_pose_latents: Optional[torch.Tensor] = None, | |
| vap_hidden_state = None, | |
| vap_clip_feature = None, | |
| context_vap = None, | |
| drop_motion_frames: bool = True, | |
| tea_cache: TeaCache = None, | |
| use_unified_sequence_parallel: bool = False, | |
| motion_bucket_id: Optional[torch.Tensor] = None, | |
| pose_latents=None, | |
| face_pixel_values=None, | |
| longcat_latents=None, | |
| sliding_window_size: Optional[int] = None, | |
| sliding_window_stride: Optional[int] = None, | |
| cfg_merge: bool = False, | |
| use_gradient_checkpointing: bool = False, | |
| use_gradient_checkpointing_offload: bool = False, | |
| control_camera_latents_input = None, | |
| fuse_vae_embedding_in_latents: bool = False, | |
| **kwargs, | |
| ): | |
| if sliding_window_size is not None and sliding_window_stride is not None: | |
| model_kwargs = dict( | |
| dit=dit, | |
| motion_controller=motion_controller, | |
| vace=vace, | |
| latents=latents, | |
| timestep=timestep, | |
| context=context, | |
| clip_feature=clip_feature, | |
| y=y, | |
| reference_latents=reference_latents, | |
| vace_context=vace_context, | |
| vace_scale=vace_scale, | |
| tea_cache=tea_cache, | |
| use_unified_sequence_parallel=use_unified_sequence_parallel, | |
| motion_bucket_id=motion_bucket_id, | |
| ) | |
| return TemporalTiler_BCTHW().run( | |
| model_fn_wan_video, | |
| sliding_window_size, sliding_window_stride, | |
| latents.device, latents.dtype, | |
| model_kwargs=model_kwargs, | |
| tensor_names=["latents", "y"], | |
| batch_size=2 if cfg_merge else 1 | |
| ) | |
| # LongCat-Video | |
| if isinstance(dit, LongCatVideoTransformer3DModel): | |
| return model_fn_longcat_video( | |
| dit=dit, | |
| latents=latents, | |
| timestep=timestep, | |
| context=context, | |
| longcat_latents=longcat_latents, | |
| use_gradient_checkpointing=use_gradient_checkpointing, | |
| use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, | |
| ) | |
| # wan2.2 s2v | |
| if audio_embeds is not None: | |
| return model_fn_wans2v( | |
| dit=dit, | |
| latents=latents, | |
| timestep=timestep, | |
| context=context, | |
| audio_embeds=audio_embeds, | |
| motion_latents=motion_latents, | |
| s2v_pose_latents=s2v_pose_latents, | |
| drop_motion_frames=drop_motion_frames, | |
| use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, | |
| use_gradient_checkpointing=use_gradient_checkpointing, | |
| use_unified_sequence_parallel=use_unified_sequence_parallel, | |
| ) | |
| if use_unified_sequence_parallel: | |
| import torch.distributed as dist | |
| from xfuser.core.distributed import (get_sequence_parallel_rank, | |
| get_sequence_parallel_world_size, | |
| get_sp_group) | |
| # Timestep | |
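| # With a fused VAE embedding, the first frame is already clean, so its tokens | |
| # get timestep 0 while the remaining frames share the sampled timestep | |
| # (one timestep value per token). | |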
| if dit.seperated_timestep and fuse_vae_embedding_in_latents: | |
| timestep = torch.concat([ | |
| torch.zeros((1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device), | |
| torch.ones((latents.shape[2] - 1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device) * timestep | |
| ]).flatten() | |
| t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0)) | |
| if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: | |
| t_chunks = torch.chunk(t, get_sequence_parallel_world_size(), dim=1) | |
| t_chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, t_chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in t_chunks] | |
| t = t_chunks[get_sequence_parallel_rank()] | |
| t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim)) | |
| else: | |
| t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) | |
| t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)) | |
| # Motion Controller | |
| if motion_bucket_id is not None and motion_controller is not None: | |
| t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim)) | |
| context = dit.text_embedding(context) | |
| x = latents | |
| # Merged cfg | |
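| # (positive and negative prompts are stacked in `context`; replicate latents | |
| # and timestep so both branches run in a single batched forward) | |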
| if x.shape[0] != context.shape[0]: | |
| x = torch.concat([x] * context.shape[0], dim=0) | |
| if timestep.shape[0] != context.shape[0]: | |
| timestep = torch.concat([timestep] * context.shape[0], dim=0) | |
| # Image Embedding | |
| if y is not None and dit.require_vae_embedding: | |
| x = torch.cat([x, y], dim=1) | |
| if clip_feature is not None and dit.require_clip_embedding: | |
| clip_embedding = dit.img_emb(clip_feature) | |
| context = torch.cat([clip_embedding, context], dim=1) | |
| # Camera control | |
| x = dit.patchify(x, control_camera_latents_input) | |
| # Animate | |
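| # The Animate adapter injects pose and face conditioning after patch embedding | |
| # and returns a motion vector that is re-applied after each transformer block. | |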
| if pose_latents is not None and face_pixel_values is not None: | |
| x, motion_vec = animate_adapter.after_patch_embedding(x, pose_latents, face_pixel_values) | |
| # Patchify | |
| f, h, w = x.shape[2:] | |
| first_frame_len = h * w | |
| x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous() | |
| # Reference image | |
| if reference_latents is not None: | |
| if len(reference_latents.shape) == 5: | |
| reference_latents = reference_latents[:, :, 0] | |
| reference_latents = dit.ref_conv(reference_latents).flatten(2).transpose(1, 2) | |
| x = torch.concat([reference_latents, x], dim=1) | |
| f += 1 | |
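| # 3D RoPE: per-axis frequency tables for frame/height/width are broadcast over | |
| # the token grid and concatenated along the channel dimension. | |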
| freqs = torch.cat([ | |
| dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), | |
| dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), | |
| dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) | |
| ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) | |
| # VAP | |
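| # VAP runs a parallel token stream through the MoT (MotWanModel) adapter: the | |
| # reference stream is patchified separately, given a constant timestep of 1 | |
| # (near-clean) and its own RoPE table, and interacts with the main stream | |
| # inside the mapped blocks below. | |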
| if vap is not None: | |
| # hidden state | |
| x_vap = vap_hidden_state | |
| x_vap = vap.patchify(x_vap) | |
| x_vap = rearrange(x_vap, 'b c f h w -> b (f h w) c').contiguous() | |
| # Timestep | |
| clean_timestep = torch.ones(timestep.shape, device=timestep.device).to(timestep.dtype) | |
| t = vap.time_embedding(sinusoidal_embedding_1d(vap.freq_dim, clean_timestep)) | |
| t_mod_vap = vap.time_projection(t).unflatten(1, (6, vap.dim)) | |
| # rope | |
| freqs_vap = vap.compute_freqs_mot(f, h, w).to(x.device) | |
| # context | |
| vap_clip_embedding = vap.img_emb(vap_clip_feature) | |
| context_vap = vap.text_embedding(context_vap) | |
| context_vap = torch.cat([vap_clip_embedding, context_vap], dim=1) | |
| # TeaCache | |
| if tea_cache is not None: | |
| tea_cache_update = tea_cache.check(dit, x, t_mod) | |
| else: | |
| tea_cache_update = False | |
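| # VACE computes per-layer control hints once up front; they are added (scaled | |
| # by vace_scale) to the hidden states after their mapped blocks in the loop. | |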
| if vace_context is not None: | |
| vace_hints = vace( | |
| x, vace_context, context, t_mod, freqs, | |
| use_gradient_checkpointing=use_gradient_checkpointing, | |
| use_gradient_checkpointing_offload=use_gradient_checkpointing_offload | |
| ) | |
| # blocks | |
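| # Under unified sequence parallelism the token sequence is split across ranks; | |
| # chunks are right-padded to equal length so all_gather can reassemble them. | |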
| if use_unified_sequence_parallel: | |
| if dist.is_initialized() and dist.get_world_size() > 1: | |
| chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1) | |
| pad_shape = chunks[0].shape[1] - chunks[-1].shape[1] | |
| chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks] | |
| x = chunks[get_sequence_parallel_rank()] | |
| if tea_cache_update: | |
| x = tea_cache.update(x) | |
| else: | |
| def create_custom_forward(module): | |
| def custom_forward(*inputs): | |
| return module(*inputs) | |
| return custom_forward | |
| def create_custom_forward_vap(block, vap): | |
| def custom_forward(*inputs): | |
| return vap(block, *inputs) | |
| return custom_forward | |
| for block_id, block in enumerate(dit.blocks): | |
| # Block | |
| if vap is not None and block_id in vap.mot_layers_mapping: | |
| if use_gradient_checkpointing_offload: | |
| with torch.autograd.graph.save_on_cpu(): | |
| x, x_vap = torch.utils.checkpoint.checkpoint( | |
| create_custom_forward_vap(block, vap), | |
| x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id, | |
| use_reentrant=False, | |
| ) | |
| elif use_gradient_checkpointing: | |
| x, x_vap = torch.utils.checkpoint.checkpoint( | |
| create_custom_forward_vap(block, vap), | |
| x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id, | |
| use_reentrant=False, | |
| ) | |
| else: | |
| x, x_vap = vap(block, x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id) | |
| else: | |
| if use_gradient_checkpointing_offload: | |
| with torch.autograd.graph.save_on_cpu(): | |
| x = torch.utils.checkpoint.checkpoint( | |
| create_custom_forward(block), | |
| x, context, t_mod, freqs, first_frame_len, | |
| use_reentrant=False, | |
| ) | |
| elif use_gradient_checkpointing: | |
| x = torch.utils.checkpoint.checkpoint( | |
| create_custom_forward(block), | |
| x, context, t_mod, freqs, first_frame_len, | |
| use_reentrant=False, | |
| ) | |
| else: | |
| x = block(x, context, t_mod, freqs, first_frame_len) | |
| # VACE | |
| if vace_context is not None and block_id in vace.vace_layers_mapping: | |
| current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]] | |
| if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: | |
| current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] | |
| current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0) | |
| x = x + current_vace_hint * vace_scale | |
| # Animate | |
| if pose_latents is not None and face_pixel_values is not None: | |
| x = animate_adapter.after_transformer_block(block_id, x, motion_vec) | |
| if tea_cache is not None: | |
| tea_cache.store(x) | |
| x = dit.head(x, t) | |
| if use_unified_sequence_parallel: | |
| if dist.is_initialized() and dist.get_world_size() > 1: | |
| x = get_sp_group().all_gather(x, dim=1) | |
| x = x[:, :-pad_shape] if pad_shape > 0 else x | |
| # Remove reference latents | |
| if reference_latents is not None: | |
| x = x[:, reference_latents.shape[1]:] | |
| f -= 1 | |
| x = dit.unpatchify(x, (f, h, w)) | |
| return x | |
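| # Forward path for LongCat-Video. Conditioning latents, when given, overwrite | |
| # the leading frames of `latents` and are flagged via num_cond_latents; the | |
| # attention mask marks non-zero (non-padded) text tokens. | |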
| def model_fn_longcat_video( | |
| dit: LongCatVideoTransformer3DModel, | |
| latents: torch.Tensor = None, | |
| timestep: torch.Tensor = None, | |
| context: torch.Tensor = None, | |
| longcat_latents: torch.Tensor = None, | |
| use_gradient_checkpointing=False, | |
| use_gradient_checkpointing_offload=False, | |
| ): | |
| if longcat_latents is not None: | |
| latents[:, :, :longcat_latents.shape[2]] = longcat_latents | |
| num_cond_latents = longcat_latents.shape[2] | |
| else: | |
| num_cond_latents = 0 | |
| context = context.unsqueeze(0) | |
| encoder_attention_mask = torch.any(context != 0, dim=-1)[:, 0].to(torch.int64) | |
| output = dit( | |
| latents, | |
| timestep, | |
| context, | |
| encoder_attention_mask, | |
| num_cond_latents=num_cond_latents, | |
| use_gradient_checkpointing=use_gradient_checkpointing, | |
| use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, | |
| ) | |
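| # Sign flip: LongCat appears to predict the velocity with the opposite sign | |
| # convention to the Wan flow-matching scheduler used by this pipeline. | |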
| output = -output | |
| output = output.to(latents.dtype) | |
| return output | |
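| # Forward path for Wan2.2 speech-to-video. The first latent frame holds the | |
| # reference image; the remaining frames are denoised conditioned on audio | |
| # embeddings, optional pose latents, and (optionally) previous motion frames. | |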
| def model_fn_wans2v( | |
| dit, | |
| latents, | |
| timestep, | |
| context, | |
| audio_embeds, | |
| motion_latents, | |
| s2v_pose_latents, | |
| drop_motion_frames=True, | |
| use_gradient_checkpointing_offload=False, | |
| use_gradient_checkpointing=False, | |
| use_unified_sequence_parallel=False, | |
| ): | |
| if use_unified_sequence_parallel: | |
| import torch.distributed as dist | |
| from xfuser.core.distributed import (get_sequence_parallel_rank, | |
| get_sequence_parallel_world_size, | |
| get_sp_group) | |
| origin_ref_latents = latents[:, :, 0:1] | |
| x = latents[:, :, 1:] | |
| # context embedding | |
| context = dit.text_embedding(context) | |
| # audio encode | |
| audio_emb_global, merged_audio_emb = dit.cal_audio_emb(audio_embeds) | |
| # x and s2v_pose_latents | |
| s2v_pose_latents = torch.zeros_like(x) if s2v_pose_latents is None else s2v_pose_latents | |
| x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(s2v_pose_latents)) | |
| seq_len_x = seq_len_x_global = x.shape[1] # global used for unified sequence parallel | |
| # reference image | |
| ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents)) | |
| grid_sizes = dit.get_grid_sizes((f, h, w), (rf, rh, rw)) | |
| x = torch.cat([x, ref_latents], dim=1) | |
| # mask | |
| mask = torch.cat([torch.zeros([1, seq_len_x]), torch.ones([1, ref_latents.shape[1]])], dim=1).to(torch.long).to(x.device) | |
| # freqs | |
| pre_compute_freqs = rope_precompute(x.detach().view(1, x.size(1), dit.num_heads, dit.dim // dit.num_heads), grid_sizes, dit.freqs, start=None) | |
| # motion | |
| x, pre_compute_freqs, mask = dit.inject_motion(x, pre_compute_freqs, mask, motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=2) | |
| x = x + dit.trainable_cond_mask(mask).to(x.dtype) | |
| # tmod | |
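| # A zero timestep is appended for the clean reference/motion tokens; the | |
| # matching embedding is dropped again at the head via t[:-1]. | |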
| timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)]) | |
| t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) | |
| t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)).unsqueeze(2).transpose(0, 2) | |
| if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: | |
| world_size, sp_rank = get_sequence_parallel_world_size(), get_sequence_parallel_rank() | |
| assert x.shape[1] % world_size == 0, f"the sequence length must be divisible by the sequence-parallel world size, but got {x.shape[1]} and {world_size}" | |
| x = torch.chunk(x, world_size, dim=1)[sp_rank] | |
| seg_idxs = [0] + list(torch.cumsum(torch.tensor([x.shape[1]] * world_size), dim=0).cpu().numpy()) | |
| seq_len_x_list = [min(max(0, seq_len_x - seg_idxs[i]), x.shape[1]) for i in range(len(seg_idxs)-1)] | |
| seq_len_x = seq_len_x_list[sp_rank] | |
| def create_custom_forward(module): | |
| def custom_forward(*inputs): | |
| return module(*inputs) | |
| return custom_forward | |
| for block_id, block in enumerate(dit.blocks): | |
| if use_gradient_checkpointing_offload: | |
| with torch.autograd.graph.save_on_cpu(): | |
| x = torch.utils.checkpoint.checkpoint( | |
| create_custom_forward(block), | |
| x, context, t_mod, seq_len_x, pre_compute_freqs[0], | |
| use_reentrant=False, | |
| ) | |
| x = torch.utils.checkpoint.checkpoint( | |
| create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), | |
| x, | |
| use_reentrant=False, | |
| ) | |
| elif use_gradient_checkpointing: | |
| x = torch.utils.checkpoint.checkpoint( | |
| create_custom_forward(block), | |
| x, context, t_mod, seq_len_x, pre_compute_freqs[0], | |
| use_reentrant=False, | |
| ) | |
| x = torch.utils.checkpoint.checkpoint( | |
| create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), | |
| x, | |
| use_reentrant=False, | |
| ) | |
| else: | |
| x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0]) | |
| x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x_global, use_unified_sequence_parallel) | |
| if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: | |
| x = get_sp_group().all_gather(x, dim=1) | |
| x = x[:, :seq_len_x_global] | |
| x = dit.head(x, t[:-1]) | |
| x = dit.unpatchify(x, (f, h, w)) | |
| # make compatible with wan video | |
| x = torch.cat([origin_ref_latents, x], dim=2) | |
| return x | |