| from typing import Optional |
|
|
| from .wan_video import WanVideoPipeline, ModelConfig |
| from ..models.comp_attn_model import ( |
| CompAttnConfig, |
| CompAttnMergeUnit, |
| CompAttnUnit, |
| patch_cross_attention, |
| wrap_model_fn, |
| ) |
| from .wan_video import WanVideoUnit_PromptEmbedder, WanVideoUnit_CfgMerger |
|
|
|
|
def attach_comp_attn(pipe: WanVideoPipeline) -> WanVideoPipeline:
    """Attach Comp-Attn units and model patches to an existing pipeline.

    Inserts a ``CompAttnUnit`` right after the prompt embedder unit and a
    ``CompAttnMergeUnit`` right after the CFG merger unit (appending to the
    end when the respective anchor unit is absent), then patches the
    model's cross-attention and forward function.

    Idempotent: a second call on the same ``pipe`` is a no-op, guarded by
    the ``_comp_attn_attached`` flag set on the pipeline instance.

    Args:
        pipe: The pipeline to augment in place.

    Returns:
        The same ``pipe`` instance, for call chaining.
    """
    if getattr(pipe, "_comp_attn_attached", False):
        return pipe
    prompt_idx = None
    cfg_idx = None
    for idx, unit in enumerate(pipe.units):
        if prompt_idx is None and isinstance(unit, WanVideoUnit_PromptEmbedder):
            prompt_idx = idx
        if cfg_idx is None and isinstance(unit, WanVideoUnit_CfgMerger):
            cfg_idx = idx
    if prompt_idx is not None:
        pipe.units.insert(prompt_idx + 1, CompAttnUnit())
        # Bug fix: the insertion above shifts every later unit one slot to
        # the right. Without this adjustment the merge unit would be placed
        # *before* the CfgMerger rather than after it.
        if cfg_idx is not None and prompt_idx < cfg_idx:
            cfg_idx += 1
    else:
        pipe.units.append(CompAttnUnit())
    if cfg_idx is not None:
        pipe.units.insert(cfg_idx + 1, CompAttnMergeUnit())
    else:
        pipe.units.append(CompAttnMergeUnit())
    patch_cross_attention(pipe)
    wrap_model_fn(pipe)
    pipe._comp_attn_attached = True
    return pipe
|
|
|
|
class WanVideoCompAttnPipeline:
    """Comp-Attn enhanced video generation pipeline.

    Thin wrapper around :class:`WanVideoPipeline` that attaches the
    Comp-Attn units/patches on construction and forwards the per-call
    subject/bbox configuration before delegating generation.

    Two annotation modes are supported:

    1. Explicit-marker mode (recommended):
       Mark subjects in the prompt with ``<N>subject</N>`` tags; the tag
       index selects the matching entry in ``bboxes``.

       ```python
       prompt = "A <0>red car</0> drives left, a <1>blue bicycle</1> rides right"
       comp_attn = CompAttnConfig(
           bboxes=[car_bboxes, bike_bboxes],  # matched by tag index <0>, <1>
       )
       ```

    2. Implicit-search mode (backward compatible):
       Provide a ``subjects`` list; each subject string is searched for in
       the prompt automatically.

       ```python
       prompt = "A red car drives left, a blue bicycle rides right"
       comp_attn = CompAttnConfig(
           subjects=["red car", "blue bicycle"],
           bboxes=[car_bboxes, bike_bboxes],
       )
       ```
    """

    def __init__(self, pipe: WanVideoPipeline):
        """Wrap *pipe*, attaching Comp-Attn units/patches (idempotent)."""
        self.pipe = attach_comp_attn(pipe)

    def __getattr__(self, name):
        """Delegate unknown attribute access to the wrapped pipeline.

        Looks ``pipe`` up via ``__dict__`` instead of ``self.pipe``:
        if ``pipe`` has not been set yet (e.g. during unpickling, or a
        failed ``__init__``), attribute access would otherwise re-enter
        ``__getattr__`` and recurse infinitely.
        """
        try:
            pipe = self.__dict__["pipe"]
        except KeyError:
            raise AttributeError(name) from None
        return getattr(pipe, name)

    @staticmethod
    def from_pretrained(
        torch_dtype=None,
        device="cuda",
        model_configs: Optional[list[ModelConfig]] = None,
        tokenizer_config: Optional[ModelConfig] = None,
        audio_processor_config: Optional[ModelConfig] = None,
        redirect_common_files: bool = True,
        use_usp: bool = False,
        vram_limit: Optional[float] = None,
    ):
        """Build the underlying :class:`WanVideoPipeline` and wrap it.

        All arguments are forwarded to ``WanVideoPipeline.from_pretrained``
        unchanged, except that a ``None`` ``model_configs`` becomes ``[]``.

        Returns:
            A ready-to-use :class:`WanVideoCompAttnPipeline`.
        """
        pipe = WanVideoPipeline.from_pretrained(
            torch_dtype=torch_dtype,
            device=device,
            model_configs=model_configs or [],
            tokenizer_config=tokenizer_config,
            audio_processor_config=audio_processor_config,
            redirect_common_files=redirect_common_files,
            use_usp=use_usp,
            vram_limit=vram_limit,
        )
        return WanVideoCompAttnPipeline(pipe)

    def __call__(
        self,
        prompt: str,
        negative_prompt: str = "",
        comp_attn: Optional[CompAttnConfig] = None,
        **kwargs,
    ):
        """Run generation, stashing Comp-Attn state on the wrapped pipeline.

        Args:
            prompt: Positive prompt, optionally containing ``<N>...</N>``
                subject markers (see class docstring).
            negative_prompt: Negative prompt forwarded unchanged.
            comp_attn: Per-call Comp-Attn configuration; ``None`` disables
                subject/bbox conditioning for this call.
            **kwargs: Forwarded verbatim to the wrapped pipeline.

        Returns:
            Whatever the wrapped ``WanVideoPipeline`` returns.
        """
        num_frames = kwargs.get("num_frames")
        # NOTE(review): _comp_attn_num_frames is only updated when the caller
        # passes num_frames, so a value from a previous call may persist —
        # presumably the patched model falls back to it; confirm downstream.
        if num_frames is not None:
            self.pipe._comp_attn_num_frames = num_frames

        self.pipe._comp_attn_config = comp_attn
        self.pipe._comp_attn_last_prompt = prompt
        self.pipe._comp_attn_last_negative_prompt = negative_prompt
        return self.pipe(prompt=prompt, negative_prompt=negative_prompt, **kwargs)
|
|