from ..models import ModelManager
from ..models.wan_video_dit import WanModel
from ..models.wan_video_dit_infer import WanModel_infer
from ..models.wan_video_text_encoder import WanTextEncoder
from ..models.wan_video_vae import WanVideoVAE
from ..models.wan_video_image_encoder import WanImageEncoder
from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline
from ..prompters import WanPrompter
import torch
import os
from einops import rearrange
import numpy as np
from PIL import Image
from tqdm import tqdm
from typing import Optional
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d
from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample


class WanVideoAstraPipeline(BasePipeline):
    """Wan video pipeline that denoises target latents conditioned on source-video latents.

    At every denoising step the encoded source video (condition) is concatenated
    in front of the target latents along the frame axis, the DiT is run on the
    joint sequence together with a camera embedding, and only the target part
    is advanced by the scheduler.
    """

    def __init__(self, device="cuda", torch_dtype=torch.float16, tokenizer_path=None, condition_frames=None, target_frames=None):
        """Create an empty pipeline; models are attached later via `fetch_models`.

        Args:
            device: Device used for computation.
            torch_dtype: Dtype used for computation.
            tokenizer_path: Path forwarded to `WanPrompter`.
            condition_frames: Number of condition frames kept from the encoded
                source video (may also be given per call).
            target_frames: Number of frames in the target latent noise (may
                also be given per call).
        """
        super().__init__(device=device, torch_dtype=torch_dtype)
        self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
        self.prompter = WanPrompter(tokenizer_path=tokenizer_path)
        self.text_encoder: WanTextEncoder = None
        self.image_encoder: WanImageEncoder = None
        self.dit: WanModel = None
        self.vae: WanVideoVAE = None
        self.model_names = ['text_encoder', 'dit', 'vae']
        self.height_division_factor = 16
        self.width_division_factor = 16
        self.condition_frames = condition_frames
        self.target_frames = target_frames

    def enable_vram_management(self, num_persistent_param_in_dit=None):
        """Wrap the sub-models with VRAM-management modules and enable CPU offload.

        Each model's own parameter dtype is used for offload/onload; computation
        runs on `self.device`.

        Args:
            num_persistent_param_in_dit: Passed as `max_num_param` for the DiT;
                modules beyond that budget use the overflow configuration
                (onload on CPU).
        """
        dtype = next(iter(self.text_encoder.parameters())).dtype
        enable_vram_management(
            self.text_encoder,
            module_map = {
                torch.nn.Linear: AutoWrappedLinear,
                torch.nn.Embedding: AutoWrappedModule,
                T5RelativeEmbedding: AutoWrappedModule,
                T5LayerNorm: AutoWrappedModule,
            },
            module_config = dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device="cpu",
                computation_dtype=self.torch_dtype,
                computation_device=self.device,
            ),
        )
        dtype = next(iter(self.dit.parameters())).dtype
        enable_vram_management(
            self.dit,
            module_map = {
                torch.nn.Linear: AutoWrappedLinear,
                torch.nn.Conv3d: AutoWrappedModule,
                torch.nn.LayerNorm: AutoWrappedModule,
                RMSNorm: AutoWrappedModule,
            },
            module_config = dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device=self.device,
                computation_dtype=self.torch_dtype,
                computation_device=self.device,
            ),
            max_num_param=num_persistent_param_in_dit,
            overflow_module_config = dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device="cpu",
                computation_dtype=self.torch_dtype,
                computation_device=self.device,
            ),
        )
        dtype = next(iter(self.vae.parameters())).dtype
        enable_vram_management(
            self.vae,
            module_map = {
                torch.nn.Linear: AutoWrappedLinear,
                torch.nn.Conv2d: AutoWrappedModule,
                RMS_norm: AutoWrappedModule,
                CausalConv3d: AutoWrappedModule,
                Upsample: AutoWrappedModule,
                torch.nn.SiLU: AutoWrappedModule,
                torch.nn.Dropout: AutoWrappedModule,
            },
            module_config = dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device=self.device,
                computation_dtype=self.torch_dtype,
                computation_device=self.device,
            ),
        )
        if self.image_encoder is not None:
            dtype = next(iter(self.image_encoder.parameters())).dtype
            enable_vram_management(
                self.image_encoder,
                module_map = {
                    torch.nn.Linear: AutoWrappedLinear,
                    torch.nn.Conv2d: AutoWrappedModule,
                    torch.nn.LayerNorm: AutoWrappedModule,
                },
                module_config = dict(
                    offload_dtype=dtype,
                    offload_device="cpu",
                    onload_dtype=dtype,
                    onload_device="cpu",
                    # Unlike the other models, the image encoder computes in its own dtype.
                    computation_dtype=dtype,
                    computation_device=self.device,
                ),
            )
        self.enable_cpu_offload()

    def fetch_models(self, model_manager: ModelManager):
        """Attach text encoder, DiT, VAE and image encoder from `model_manager`.

        When a text encoder is available, the prompter is bound to it and the
        tokenizer is loaded from `google/umt5-xxl` next to the encoder file.
        """
        text_encoder_model_and_path = model_manager.fetch_model("wan_video_text_encoder", require_model_path=True)
        if text_encoder_model_and_path is not None:
            self.text_encoder, tokenizer_path = text_encoder_model_and_path
            self.prompter.fetch_models(self.text_encoder)
            self.prompter.fetch_tokenizer(os.path.join(os.path.dirname(tokenizer_path), "google/umt5-xxl"))
        self.dit = model_manager.fetch_model("wan_video_dit")
        self.vae = model_manager.fetch_model("wan_video_vae")
        self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")

    @staticmethod
    def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None):
        """Build a pipeline from a model manager.

        `device` and `torch_dtype` default to the manager's own settings.
        """
        if device is None:
            device = model_manager.device
        if torch_dtype is None:
            torch_dtype = model_manager.torch_dtype
        pipe = WanVideoAstraPipeline(device=device, torch_dtype=torch_dtype)
        pipe.fetch_models(model_manager)
        return pipe

    def denoising_model(self):
        """Return the DiT used for denoising."""
        return self.dit

    def encode_prompt(self, prompt, positive=True):
        """Encode `prompt` with the prompter and return it as ``{"context": emb}``."""
        prompt_emb = self.prompter.encode_prompt(prompt, positive=positive)
        return {"context": prompt_emb}

    def encode_image(self, image, num_frames, height, width):
        """Encode a conditioning image into CLIP features and a masked VAE latent.

        Args:
            image: Image object exposing `resize` (presumably a PIL image).
            num_frames: Number of frames of the video the image conditions.
            height: Target pixel height.
            width: Target pixel width.

        Returns:
            Dict with `clip_feature` (image-encoder output) and `y` (mask
            concatenated with the VAE latent, with a leading batch axis).
        """
        image = self.preprocess_image(image.resize((width, height))).to(self.device)
        clip_context = self.image_encoder.encode_image([image])
        # Mask is 1 only for the first frame; that frame is repeated 4 times so the
        # frame axis can be folded into groups of 4.
        msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
        msk[:, 1:] = 0
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
        msk = msk.transpose(1, 2)[0]

        # VAE input: the image followed by `num_frames - 1` all-zero frames.
        vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
        y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0]
        y = torch.concat([msk, y])
        y = y.unsqueeze(0)
        clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device)
        y = y.to(dtype=self.torch_dtype, device=self.device)
        return {"clip_feature": clip_context, "y": y}

    def tensor2video(self, frames):
        """Convert a ``C T H W`` tensor into a list of PIL images.

        Values are mapped with ``(x + 1) * 127.5`` and clipped to ``[0, 255]``.
        """
        frames = rearrange(frames, "C T H W -> T H W C")
        frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
        frames = [Image.fromarray(frame) for frame in frames]
        return frames

    def prepare_extra_input(self, latents=None):
        """Return extra keyword inputs for the model; currently always empty."""
        return {}

    def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        """Encode `input_video` into latents with the VAE on `self.device`."""
        latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return latents

    def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        """Decode `latents` into video frames with the VAE on `self.device`."""
        frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return frames

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt="",
        source_video=None,
        target_camera=None,
        input_image=None,
        input_video=None,
        denoising_strength=1.0,
        seed=None,
        rand_device="cpu",
        height=480,
        width=832,
        num_frames=81,
        cfg_scale=5.0,
        num_inference_steps=50,
        sigma_shift=5.0,
        tiled=True,
        tile_size=(30, 52),
        tile_stride=(15, 26),
        tea_cache_l1_thresh=None,
        tea_cache_model_id="",
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
        condition_frames=None,  # added parameter
        target_frames=53,  # added parameter
    ):
        """Generate a video conditioned on a source video and a target camera.

        Args:
            prompt: Positive text prompt.
            negative_prompt: Negative prompt; only encoded when `cfg_scale != 1.0`.
            source_video: Tensor whose axis 2 is the frame axis; encoded by the
                VAE and used as the condition. Required.
            target_camera: Camera embedding tensor; truncated along axis 1 to
                `target_frames` when its length differs. Required.
            input_image: Optional image; encoded only when an image encoder is loaded.
            input_video: Optional frames; when given, their noised latents
                replace pure noise as the initial target latents.
            denoising_strength: Forwarded to the scheduler.
            seed: Seed for noise generation.
            rand_device: Device on which noise is generated.
            height: Requested height (adjusted by `check_resize_height_width`).
            width: Requested width (adjusted by `check_resize_height_width`).
            num_frames: Only validated/rounded here; the noise shape uses
                `target_frames`.
            cfg_scale: Classifier-free guidance scale; 1.0 disables guidance.
            num_inference_steps: Number of scheduler steps.
            sigma_shift: Forwarded to the scheduler as `shift`.
            tiled: Whether the VAE uses tiling.
            tile_size: VAE tile size.
            tile_stride: VAE tile stride.
            tea_cache_l1_thresh: When not None, `TeaCache` objects are created
                with this threshold.
            tea_cache_model_id: Model id for the TeaCache coefficients.
            progress_bar_cmd: Callable wrapping the timestep iterable.
            progress_bar_st: Unused in this method.
            condition_frames: Overrides `self.condition_frames` when not None.
            target_frames: Overrides `self.target_frames` when not None. Note
                that the default (53) therefore overrides the constructor value.

        Returns:
            List of PIL images decoded from the concatenated
            [condition | target] latents.

        Raises:
            ValueError: If `source_video` or `target_camera` is missing, if the
                frame counts are not configured, or if the source video has
                fewer frames than `condition_frames`.
        """
        # Frame-count configuration
        if condition_frames is not None:
            self.condition_frames = condition_frames
        if target_frames is not None:
            self.target_frames = target_frames

        # Fail early with a clear message instead of an opaque AttributeError/TypeError later.
        if source_video is None:
            raise ValueError("`source_video` is required: it provides the condition frames.")
        if target_camera is None:
            raise ValueError("`target_camera` is required: it provides the camera embedding.")
        if self.condition_frames is None or self.target_frames is None:
            raise ValueError("Both `condition_frames` and `target_frames` must be set, either in the constructor or in this call.")

        # Parameter check
        height, width = self.check_resize_height_width(height, width)
        # Fall back to target_frames when num_frames is not given
        if num_frames is None:
            num_frames = self.target_frames
        if num_frames % 4 != 1:
            num_frames = (num_frames + 2) // 4 * 4 + 1
            print(f"Only `num_frames % 4 == 1` is acceptable. We round it up to {num_frames}.")

        # Tiler parameters
        tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # Scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)

        # Initialize noise for the target frames only
        target_noise_shape = (1, 16, self.target_frames, height//8, width//8)
        noise = self.generate_noise(target_noise_shape, seed=seed, device=rand_device, dtype=torch.float32)
        noise = noise.to(dtype=self.torch_dtype, device=self.device)

        if input_video is not None:
            self.load_models_to_device(['vae'])
            input_video = self.preprocess_images(input_video)
            input_video = torch.stack(input_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
            latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
        else:
            # Target latents start from pure noise
            latents = noise

        # Encode source video (condition frames)
        self.load_models_to_device(['vae'])
        source_video = source_video.to(dtype=self.torch_dtype, device=self.device)

        # Make sure the source video has enough frames to serve as the condition
        if source_video.shape[2] < self.condition_frames:
            raise ValueError(f"Source video has {source_video.shape[2]} frames, need at least {self.condition_frames}")

        # Keep the first `condition_frames` latent frames as the condition
        source_latents = self.encode_video(source_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
        source_latents = source_latents[:, :, :self.condition_frames, :, :]
        # The slice above may yield fewer than `condition_frames` frames; split the
        # joint sequence at the real condition length so the target part is never
        # truncated. Equals `self.condition_frames` whenever enough frames exist.
        num_condition_latents = source_latents.shape[2]

        # Make sure the camera embedding length matches target_frames
        cam_emb = target_camera.to(dtype=self.torch_dtype, device=self.device)
        if cam_emb.shape[1] != self.target_frames:
            cam_emb = cam_emb[:, :self.target_frames, :]

        # Encode prompts
        self.load_models_to_device(["text_encoder"])
        prompt_emb_posi = self.encode_prompt(prompt, positive=True)
        if cfg_scale != 1.0:
            prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)

        # Encode image
        if input_image is not None and self.image_encoder is not None:
            self.load_models_to_device(["image_encoder", "vae"])
            image_emb = self.encode_image(input_image, self.target_frames, height, width)  # uses target_frames
        else:
            image_emb = {}

        # Extra input
        extra_input = self.prepare_extra_input(latents)

        # TeaCache
        tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
        tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}

        # Denoise loop
        self.load_models_to_device(["dit"])
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)

            # Build the model input as [condition | target] along the frame axis
            latents_input = torch.cat([source_latents, latents], dim=2)

            # Inference
            noise_pred_posi = model_fn_wan_video(
                self.dit, latents_input, timestep=timestep, cam_emb=cam_emb,
                **prompt_emb_posi, **image_emb, **extra_input, **tea_cache_posi
            )
            if cfg_scale != 1.0:
                noise_pred_nega = model_fn_wan_video(
                    self.dit, latents_input, timestep=timestep, cam_emb=cam_emb,
                    **prompt_emb_nega, **image_emb, **extra_input, **tea_cache_nega
                )
                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
            else:
                noise_pred = noise_pred_posi

            # Scheduler step: only the target part is advanced
            latents = self.scheduler.step(
                noise_pred[:, :, num_condition_latents:, :, :],
                self.scheduler.timesteps[progress_id],
                latents_input[:, :, num_condition_latents:, :, :],
            )

        # Decode the condition and the generated target together
        self.load_models_to_device(['vae'])
        frames = self.decode_video(torch.cat([source_latents, latents], dim=2), **tiler_kwargs)
        self.load_models_to_device([])
        frames = self.tensor2video(frames[0])

        return frames


class TeaCache:
    """Step-skipping cache driven by the relative L1 change of the modulated input.

    `check` decides per step whether the blocks must be recomputed; `store`
    records the residual produced by a full computation; `update` re-applies
    the stored residual when a step is skipped.
    """

    def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
        """Set up the cache state and pick the rescaling polynomial for `model_id`.

        Raises:
            ValueError: If `model_id` has no coefficients registered.
        """
        self.num_inference_steps = num_inference_steps
        self.step = 0
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.rel_l1_thresh = rel_l1_thresh
        self.previous_residual = None
        self.previous_hidden_states = None

        # Polynomial coefficients (highest degree first, as consumed by np.poly1d).
        self.coefficients_dict = {
            "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
            "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
            "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
            "Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
        }
        if model_id not in self.coefficients_dict:
            supported_model_ids = ", ".join(self.coefficients_dict)
            raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
        self.coefficients = self.coefficients_dict[model_id]

    def check(self, dit: WanModel, x, t_mod):
        """Advance the step counter and report whether this step can be skipped.

        The first and last step are always computed. Otherwise the rescaled
        relative L1 distance between the current and previous `t_mod` is
        accumulated, and the step is skipped while the sum stays below
        `rel_l1_thresh`.

        Returns:
            True when the cached residual should be used (skip), False when the
            blocks must be computed. In the latter case a copy of `x` is kept
            for `store`.
        """
        modulated_inp = t_mod.clone()
        if self.step == 0 or self.step == self.num_inference_steps - 1:
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            coefficients = self.coefficients
            rescale_func = np.poly1d(coefficients)
            self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
        self.step += 1
        # Wrap around so the same cache object can serve another full run.
        if self.step == self.num_inference_steps:
            self.step = 0
        if should_calc:
            self.previous_hidden_states = x.clone()
        return not should_calc

    def store(self, hidden_states):
        """Record the residual between `hidden_states` and the input saved by `check`."""
        self.previous_residual = hidden_states - self.previous_hidden_states
        self.previous_hidden_states = None

    def update(self, hidden_states):
        """Return `hidden_states` plus the last stored residual."""
        hidden_states = hidden_states + self.previous_residual
        return hidden_states


def model_fn_wan_video(
    dit: WanModel,
    x: torch.Tensor,
    timestep: torch.Tensor,
    cam_emb: torch.Tensor,
    context: torch.Tensor,
    clip_feature: Optional[torch.Tensor] = None,
    y: Optional[torch.Tensor] = None,
    tea_cache: TeaCache = None,
    **kwargs,
):
    """Run one forward pass of the Wan DiT with a camera embedding.

    Args:
        dit: The DiT model.
        x: Latent input.
        timestep: Current timestep tensor.
        cam_emb: Camera embedding passed to every block.
        context: Text embedding, projected by `dit.text_embedding`.
        clip_feature: Image-encoder features; used when `dit.has_image_input`.
        y: Extra latent channels concatenated to `x` on axis 1 when
            `dit.has_image_input`.
        tea_cache: Accepted for interface compatibility; currently ignored
            (see the note in the body).
        **kwargs: Ignored.

    Returns:
        The unpatchified model output.
    """
    t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
    t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
    context = dit.text_embedding(context)

    if dit.has_image_input:
        x = torch.cat([x, y], dim=1)  # (b, c_x + c_y, f, h, w)
        clip_embdding = dit.img_emb(clip_feature)
        context = torch.cat([clip_embdding, context], dim=1)

    x, (f, h, w) = dit.patchify(x)

    # Positional frequencies for the frame, height and width axes, flattened per token.
    freqs = torch.cat([
        dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
        dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
        dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
    ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)

    # TeaCache
    # NOTE(review): the argument is overridden here, so TeaCache is always
    # disabled in this function — presumably deliberate; confirm before removing.
    tea_cache=None
    if tea_cache is not None:
        tea_cache_update = tea_cache.check(dit, x, t_mod)
    else:
        tea_cache_update = False

    if tea_cache_update:
        x = tea_cache.update(x)
    else:
        # blocks
        for block in dit.blocks:
            x = block(x, context, cam_emb, t_mod, freqs)
        if tea_cache is not None:
            tea_cache.store(x)

    x = dit.head(x, t)
    x = dit.unpatchify(x, (f, h, w))
    return x