from typing import Optional

from .wan_video import WanVideoPipeline, ModelConfig
from ..models.comp_attn_model import (
    CompAttnConfig,
    CompAttnMergeUnit,
    CompAttnUnit,
    patch_cross_attention,
    wrap_model_fn,
)
from .wan_video import WanVideoUnit_PromptEmbedder, WanVideoUnit_CfgMerger


def attach_comp_attn(pipe: WanVideoPipeline) -> WanVideoPipeline:
    """Install the Comp-Attn units and hooks on ``pipe`` (in place).

    The call is idempotent: once ``pipe._comp_attn_attached`` is truthy the
    pipeline is returned untouched.

    Args:
        pipe: Pipeline whose ``units`` list is extended with a
            ``CompAttnUnit`` and a ``CompAttnMergeUnit``.

    Returns:
        The same ``pipe`` object that was passed in.
    """
    if getattr(pipe, "_comp_attn_attached", False):
        return pipe

    # Locate the first prompt embedder and the first CFG merger, if present.
    prompt_idx = None
    cfg_idx = None
    for idx, unit in enumerate(pipe.units):
        if prompt_idx is None and isinstance(unit, WanVideoUnit_PromptEmbedder):
            prompt_idx = idx
        if cfg_idx is None and isinstance(unit, WanVideoUnit_CfgMerger):
            cfg_idx = idx

    # CompAttnUnit goes right after the prompt embedder (or at the end when
    # no prompt embedder was found).
    if prompt_idx is not None:
        pipe.units.insert(prompt_idx + 1, CompAttnUnit())
    else:
        pipe.units.append(CompAttnUnit())

    # NOTE(review): cfg_idx was computed before the insertion above. When the
    # CFG merger sits after the prompt embedder, it has shifted by one, so the
    # merge unit ends up immediately *before* the CFG merger rather than after
    # it. Behavior is kept as-is; confirm which ordering is intended.
    if cfg_idx is not None:
        pipe.units.insert(cfg_idx + 1, CompAttnMergeUnit())
    else:
        pipe.units.append(CompAttnMergeUnit())

    patch_cross_attention(pipe)
    wrap_model_fn(pipe)
    pipe._comp_attn_attached = True
    return pipe


class WanVideoCompAttnPipeline:
    """Video generation pipeline augmented with Comp-Attn.

    Two annotation modes are supported:

    1. Explicit marker mode (recommended): use subject markers in the prompt;
       marker indices correspond to the entries of ``bboxes``.
       ```python
       prompt = "A <0>red car drives left, a <1>blue bicycle rides right"
       comp_attn = CompAttnConfig(
           bboxes=[car_bboxes, bike_bboxes],  # matched by marker index <0>, <1>
       )
       ```

    2. Implicit search mode (backward compatible): provide a ``subjects``
       list; matches are searched for in the prompt automatically.
       ```python
       prompt = "A red car drives left, a blue bicycle rides right"
       comp_attn = CompAttnConfig(
           subjects=["red car", "blue bicycle"],
           bboxes=[car_bboxes, bike_bboxes],
       )
       ```
    """

    def __init__(self, pipe: WanVideoPipeline):
        """Wrap ``pipe`` after attaching the Comp-Attn units to it."""
        self.pipe = attach_comp_attn(pipe)

    def __getattr__(self, name):
        """Forward attribute lookups that miss on the wrapper to ``self.pipe``.

        Raises:
            AttributeError: If ``pipe`` itself has not been set on the
                instance, or the wrapped pipeline lacks ``name``.
        """
        # __getattr__ only runs when normal lookup fails. If "pipe" is the
        # missing attribute (e.g. an instance created without __init__ by
        # copy/pickle), evaluating self.pipe below would re-enter this method
        # and recurse until RecursionError.
        if name == "pipe":
            raise AttributeError(name)
        return getattr(self.pipe, name)

    @staticmethod
    def from_pretrained(
        torch_dtype=None,
        device="cuda",
        model_configs: Optional[list[ModelConfig]] = None,
        tokenizer_config: Optional[ModelConfig] = None,
        audio_processor_config: Optional[ModelConfig] = None,
        redirect_common_files: bool = True,
        use_usp: bool = False,
        vram_limit: Optional[float] = None,
    ):
        """Build a ``WanVideoPipeline`` and wrap it with Comp-Attn support.

        All arguments are forwarded to ``WanVideoPipeline.from_pretrained``;
        a missing ``model_configs`` is passed on as an empty list.

        Returns:
            A ``WanVideoCompAttnPipeline`` wrapping the loaded pipeline.
        """
        pipe = WanVideoPipeline.from_pretrained(
            torch_dtype=torch_dtype,
            device=device,
            model_configs=model_configs or [],
            tokenizer_config=tokenizer_config,
            audio_processor_config=audio_processor_config,
            redirect_common_files=redirect_common_files,
            use_usp=use_usp,
            vram_limit=vram_limit,
        )
        return WanVideoCompAttnPipeline(pipe)

    def __call__(
        self,
        prompt: str,
        negative_prompt: str = "",
        comp_attn: Optional[CompAttnConfig] = None,
        **kwargs,
    ):
        """Run the wrapped pipeline with an optional Comp-Attn configuration.

        The configuration and both prompts are stored on the wrapped pipeline
        as ``_comp_attn_*`` attributes before it is invoked.

        Args:
            prompt: Positive prompt, forwarded to the wrapped pipeline.
            negative_prompt: Negative prompt, forwarded to the wrapped pipeline.
            comp_attn: Comp-Attn configuration; ``None`` is stored as-is.
            **kwargs: Forwarded unchanged to the wrapped pipeline.

        Returns:
            Whatever the wrapped pipeline call returns.
        """
        num_frames = kwargs.get("num_frames")
        # NOTE(review): when num_frames is omitted, any value stored by an
        # earlier call is left in place — confirm this is intended.
        if num_frames is not None:
            self.pipe._comp_attn_num_frames = num_frames
        self.pipe._comp_attn_config = comp_attn
        self.pipe._comp_attn_last_prompt = prompt
        self.pipe._comp_attn_last_negative_prompt = negative_prompt
        return self.pipe(prompt=prompt, negative_prompt=negative_prompt, **kwargs)