# PencilFolder / diffsynth/pipelines/wan_video_comp_attn.py
# Uploaded by PencilHu using huggingface_hub (commit 1146a67, verified).
from typing import Optional
from .wan_video import WanVideoPipeline, ModelConfig
from ..models.comp_attn_model import (
CompAttnConfig,
CompAttnMergeUnit,
CompAttnUnit,
patch_cross_attention,
wrap_model_fn,
)
from .wan_video import WanVideoUnit_PromptEmbedder, WanVideoUnit_CfgMerger
def attach_comp_attn(pipe: WanVideoPipeline) -> WanVideoPipeline:
    """Install the Comp-Attn units and hooks on ``pipe``, at most once.

    A ``CompAttnUnit`` is inserted directly after the first
    ``WanVideoUnit_PromptEmbedder`` in ``pipe.units``, and a
    ``CompAttnMergeUnit`` is inserted at the list position one past the
    index at which the first ``WanVideoUnit_CfgMerger`` was found. When an
    anchor unit is missing, the corresponding new unit is appended to the
    end instead. ``patch_cross_attention`` and ``wrap_model_fn`` are then
    applied, and ``pipe._comp_attn_attached`` is set so that later calls
    return immediately.

    Args:
        pipe: Pipeline to modify in place.

    Returns:
        The same ``pipe`` object that was passed in.
    """
    already_attached = getattr(pipe, "_comp_attn_attached", False)
    if already_attached:
        return pipe

    # Both anchor positions are resolved before anything is inserted.
    embedder_pos = next(
        (
            position
            for position, unit in enumerate(pipe.units)
            if isinstance(unit, WanVideoUnit_PromptEmbedder)
        ),
        None,
    )
    merger_pos = next(
        (
            position
            for position, unit in enumerate(pipe.units)
            if isinstance(unit, WanVideoUnit_CfgMerger)
        ),
        None,
    )

    if embedder_pos is None:
        pipe.units.append(CompAttnUnit())
    else:
        pipe.units.insert(embedder_pos + 1, CompAttnUnit())

    # NOTE(review): merger_pos predates the insertion above. If the prompt
    # embedder sits before the cfg merger, the merger has shifted by one and
    # the merge unit ends up immediately *before* it rather than after it.
    # Confirm which placement the merge unit relies on before changing this.
    if merger_pos is None:
        pipe.units.append(CompAttnMergeUnit())
    else:
        pipe.units.insert(merger_pos + 1, CompAttnMergeUnit())

    patch_cross_attention(pipe)
    wrap_model_fn(pipe)

    # Marker consulted by the guard at the top of this function.
    pipe._comp_attn_attached = True
    return pipe
class WanVideoCompAttnPipeline:
    """Video generation pipeline wrapper with Comp-Attn enabled.

    Two annotation modes are supported:

    1. Explicit marker mode (recommended):
       Mark subjects in the prompt with ``<N>subject</N>``; the marker index
       corresponds to the entry in ``bboxes``.
       ```python
       prompt = "A <0>red car</0> drives left, a <1>blue bicycle</1> rides right"
       comp_attn = CompAttnConfig(
           bboxes=[car_bboxes, bike_bboxes],  # matched to markers <0>, <1> by index
       )
       ```
    2. Implicit search mode (compatible with the older interface):
       Provide a ``subjects`` list; matches are searched for in the prompt
       automatically.
       ```python
       prompt = "A red car drives left, a blue bicycle rides right"
       comp_attn = CompAttnConfig(
           subjects=["red car", "blue bicycle"],
           bboxes=[car_bboxes, bike_bboxes],
       )
       ```
    """

    def __init__(self, pipe: WanVideoPipeline):
        """Wrap ``pipe`` after passing it through ``attach_comp_attn``."""
        self.pipe = attach_comp_attn(pipe)

    def __getattr__(self, name):
        """Delegate attributes not found on the wrapper to the wrapped pipeline.

        Raises:
            AttributeError: If ``pipe`` itself has not been set yet, or the
                wrapped pipeline does not have ``name``.
        """
        # __getattr__ only runs when normal lookup fails. If "pipe" is the
        # missing name (instance created via __new__, copy, or unpickling),
        # reading self.pipe below would re-enter this method forever.
        if name == "pipe":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self.pipe, name)

    @classmethod
    def from_pretrained(
        cls,
        torch_dtype=None,
        device="cuda",
        model_configs: Optional[list[ModelConfig]] = None,
        tokenizer_config: Optional[ModelConfig] = None,
        audio_processor_config: Optional[ModelConfig] = None,
        redirect_common_files: bool = True,
        use_usp: bool = False,
        vram_limit: Optional[float] = None,
    ):
        """Build a ``WanVideoPipeline`` and wrap it with Comp-Attn support.

        All arguments are forwarded by keyword to
        ``WanVideoPipeline.from_pretrained``; ``model_configs`` is replaced
        by an empty list when it is ``None`` or empty.

        Returns:
            An instance of the class this method is called on, wrapping the
            newly created pipeline.
        """
        # NOTE(review): None values are forwarded explicitly, so they take
        # precedence over whatever defaults the wrapped constructor declares
        # -- confirm that is intended for torch_dtype / tokenizer_config.
        pipe = WanVideoPipeline.from_pretrained(
            torch_dtype=torch_dtype,
            device=device,
            model_configs=model_configs or [],
            tokenizer_config=tokenizer_config,
            audio_processor_config=audio_processor_config,
            redirect_common_files=redirect_common_files,
            use_usp=use_usp,
            vram_limit=vram_limit,
        )
        # cls instead of the hard-coded class name so subclasses get their
        # own type back.
        return cls(pipe)

    def __call__(
        self,
        prompt: str,
        negative_prompt: str = "",
        comp_attn: Optional[CompAttnConfig] = None,
        **kwargs,
    ):
        """Record the Comp-Attn inputs on the wrapped pipeline and run it.

        Args:
            prompt: Positive prompt, forwarded to the pipeline.
            negative_prompt: Negative prompt, forwarded to the pipeline.
            comp_attn: Comp-Attn configuration, stored on the pipeline as
                ``_comp_attn_config`` (``None`` is stored as-is).
            **kwargs: Passed through unchanged to the wrapped pipeline.

        Returns:
            Whatever the wrapped pipeline call returns.
        """
        num_frames = kwargs.get("num_frames")
        # Only an explicitly supplied num_frames is recorded; when it is
        # omitted, a value stored by an earlier call is left untouched.
        if num_frames is not None:
            self.pipe._comp_attn_num_frames = num_frames
        self.pipe._comp_attn_config = comp_attn
        self.pipe._comp_attn_last_prompt = prompt
        self.pipe._comp_attn_last_negative_prompt = negative_prompt
        return self.pipe(prompt=prompt, negative_prompt=negative_prompt, **kwargs)