Comp-Attn 设计文档
基于论文 "Comp-Attn: Present-and-Align Attention for Compositional Video Generation" 的实现总结
概述
Comp-Attn 是一种 composition-aware cross-attention 变体,采用 "Present-and-Align" 范式,用于解决多主体视频生成中的两个核心问题:
| 挑战 | 描述 | 解决方案 |
|---|---|---|
| Subject Presence | 并非所有主体都能在视频中呈现 | SCI(条件层面) |
| Inter-subject Relations | 主体间的交互和空间关系错位 | LAM(注意力层面) |
核心组件
1. Subject-aware Condition Interpolation (SCI)
SCI 在 条件编码阶段 增强每个主体的语义表示,确保所有主体都能被"召回"。
工作原理
┌─────────────────────────────────────────────────────────────┐
│ 输入: prompt 中的主体 tokens + 独立编码的 anchor embeddings │
├─────────────────────────────────────────────────────────────┤
│ Step 1: 计算语义显著性 (Saliency) │
│ - 比较 prompt 中主体 token 与 anchor 的余弦相似度 │
│ - 使用 softmax(τ=0.2) 归一化得到显著性分数 │
│ - 低显著性 = 主体在 prompt 上下文中被"淹没" │
├─────────────────────────────────────────────────────────────┤
│ Step 2: 计算语义差异 (Delta) │
│ - delta_i = anchor_i * N - Σ(anchors) │
│ - 表示每个主体相对于其他主体的独特语义 │
├─────────────────────────────────────────────────────────────┤
│ Step 3: 自适应插值 │
│ - scale = ω * (1 - saliency) │
│ - context' = context + delta * scale │
│ - 显著性越低,增强越多;早期时间步增强越强 │
└─────────────────────────────────────────────────────────────┘
关键代码
# 显著性计算
def compute_saliency(prompt_vecs, anchor_vecs, tau=0.2):
cosine = cosine_similarity(prompt_vecs, anchor_vecs)
scores = exp(cosine / tau)
return scores.diagonal() / scores.sum(dim=1)
# 应用 SCI
omega = 1.0 - (timestep / 1000.0) # 时间步调度
scale = omega * (1.0 - saliency)
context = context + delta * scale
2. Layout-forcing Attention Modulation (LAM)
LAM 在 注意力计算阶段 动态调制注意力分布,使其与预定义的空间布局对齐。
工作原理
┌─────────────────────────────────────────────────────────────┐
│ 输入: Q (视频 tokens), K/V (文本 tokens), Layout (bbox) │
├─────────────────────────────────────────────────────────────┤
│ Step 1: 计算原始 attention scores │
│ attn_scores = Q @ K^T / sqrt(d) │
├─────────────────────────────────────────────────────────────┤
│ Step 2: 构建调制函数 │
│ g_plus = max(scores) - scores (增强调制) │
│ g_minus = min(scores) - scores (抑制调制) │
├─────────────────────────────────────────────────────────────┤
│ Step 3: 计算 IOU 引导的强度 │
│ - adapt_mask: 当前注意力分布 > mean 的区域 │
│ - layout_mask: 目标 bbox 区域 │
│ - iou = intersection / union │
│ - strength = 1 - iou (IOU 越低,调制越强) │
├─────────────────────────────────────────────────────────────┤
│ Step 4: 应用动态调制 │
│ - 在 bbox 内增强对应主体 token (g_plus) │
│ - 抑制其他主体 token (g_minus) │
│ - bias = g_k * strength * layout_mask │
│ - final_scores = attn_scores + bias │
└─────────────────────────────────────────────────────────────┘
LAM vs 传统 Layout Control
| 方法 | 机制 | 缺点 |
|---|---|---|
| 硬掩码 (Hard Mask) | 强制注意力只在 bbox 内 | 无法适应多样的物体形状 |
| LAM (Ours) | IOU 引导的动态软调制 | ✅ 灵活适应不同形状 |
3. 关键帧插值
遵循论文附录 F 的设计,支持用 4 个关键帧描述运动轨迹,然后线性插值到所有帧:
# 输入: 4 个关键帧的 bbox
keyframe_bboxes = [frame_0, frame_1, frame_2, frame_3]
# 输出: 81 帧的 bbox(线性插值)
all_bboxes = F.interpolate(keyframe_bboxes, size=81, mode='linear')
架构设计
┌─────────────────────────────────────────────────────────────┐
│ WanVideoCompAttnPipeline │
├─────────────────────────────────────────────────────────────┤
│ │
│ ┌──────────────────┐ ┌──────────────────┐ │
│ │ CompAttnUnit │───▶│ CompAttnMerge │ │
│ │ (SCI 预处理) │ │ (CFG 合并) │ │
│ └──────────────────┘ └──────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────┐ │
│ │ model_fn_wrapper (动态注入) │ │
│ │ - apply_sci(): 修改 context │ │
│ │ - build_layout_mask(): 构建空间掩码 │ │
│ └──────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────┐ │
│ │ patched cross_attention │ │
│ │ - lam_attention(): IOU 引导的注意力调制 │ │
│ └──────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘
Prompt-BBox 绑定机制
使用变量拼接(推荐)
使用 Python 变量定义主体,通过 f-string 拼接 prompt,代码层面天然表达绑定关系:
# 定义主体变量
subject0 = "red car"
subject1 = "blue bicycle"
# 使用变量拼接 prompt
prompt = f"A {subject0} drives left, a {subject1} rides right"
# subjects 和 bboxes 顺序一一对应
subjects = [subject0, subject1]
bboxes = [bbox0, bbox1] # bbox0 对应 subject0, bbox1 对应 subject1
绑定原理
┌─────────────────────────────────────────────────────────────────────────┐
│ 1. 用户定义 subjects 和 bboxes(顺序一一对应) │
│ │
│ subject0 = "red car" │
│ subject1 = "blue bicycle" │
│ subjects = [subject0, subject1] │
│ bboxes = [bbox0, bbox1] │
├─────────────────────────────────────────────────────────────────────────┤
│ 2. Tokenize prompt,搜索每个 subject 的 token 位置 │
│ │
│ prompt = f"A {subject0} drives while a {subject1} rides..." │
│ ↑↑↑↑↑↑↑ ↑↑↑↑↑↑↑↑↑↑↑↑ │
│ token indices token indices │
├─────────────────────────────────────────────────────────────────────────┤
│ 3. 建立 subject_token_mask (关联 token 位置与 bbox) │
│ │
│ subject_token_mask[0, ...] = True # subject0 tokens │
│ subject_token_mask[1, ...] = True # subject1 tokens │
├─────────────────────────────────────────────────────────────────────────┤
│ 4. 在推理时应用: │
│ - SCI: 根据 mask 增强对应 token 的语义 │
│ - LAM: 在 bbox 区域内增强对应 token 的注意力 │
└─────────────────────────────────────────────────────────────────────────┘
重要约束
- subjects 必须在 prompt 中出现:使用变量拼接确保这一点
- 顺序一一对应:
bboxes[i]对应subjects[i] - 精确匹配:subject 字符串需要能被分词器识别为完整的 token 序列
使用方法
基本用法
from diffsynth.pipelines.wan_video_comp_attn import WanVideoCompAttnPipeline
from diffsynth.models.comp_attn_model import CompAttnConfig
# 创建 pipeline
pipe = WanVideoCompAttnPipeline.from_pretrained(...)
# 1. 定义主体变量
subject0 = "red car"
subject1 = "blue bicycle"
# 2. 定义运动轨迹 (4 个关键帧)
bbox0 = [(100, 250, 220, 380), (200, 250, 320, 380), (300, 250, 420, 380), (400, 250, 520, 380)]
bbox1 = [(350, 260, 410, 400), (450, 260, 510, 400), (550, 260, 610, 400), (650, 260, 710, 400)]
# 3. 使用变量拼接 prompt
prompt = f"A {subject0} drives from left to center while a {subject1} rides to the right"
# 4. 配置 Comp-Attn(顺序一一对应)
comp_attn = CompAttnConfig(
subjects=[subject0, subject1], # 变量列表
bboxes=[bbox0, bbox1], # 对应的 bbox 列表
enable_sci=True,
enable_lam=True,
interpolate=True,
)
# 5. 生成视频
video = pipe(prompt=prompt, comp_attn=comp_attn)
运动轨迹辅助函数
def create_moving_bbox(start_x, end_x, y_center, box_width, box_height, num_keyframes=4):
"""创建从 start_x 移动到 end_x 的关键帧 bbox 序列"""
keyframes = []
for i in range(num_keyframes):
progress = i / (num_keyframes - 1)
center_x = start_x + (end_x - start_x) * progress
left = center_x - box_width / 2
right = center_x + box_width / 2
top = y_center - box_height / 2
bottom = y_center + box_height / 2
keyframes.append((left, top, right, bottom))
return keyframes
# 使用示例
car_trajectory = create_moving_bbox(
start_x=100, end_x=500,
y_center=300,
box_width=120, box_height=80,
)
配置参数
| 参数 | 类型 | 默认值 | 说明 |
|---|---|---|---|
subjects |
List[str] | - | 主体名称列表(必须出现在 prompt 中) |
bboxes |
List | None | 每个主体每个关键帧的 bbox |
enable_sci |
bool | True | 是否启用 SCI |
enable_lam |
bool | True | 是否启用 LAM |
temperature |
float | 0.2 | 显著性计算的温度参数 τ |
apply_to_negative |
bool | False | 是否对负样本 prompt 应用 |
interpolate |
bool | False | 是否对关键帧 bbox 进行插值 |
BBox 格式
# 格式: (left, top, right, bottom) 像素坐标
# 支持多种输入形式:
# 1. 静态布局(所有帧相同)
bbox = (100, 200, 300, 400)
# 2. 关键帧布局(4帧,会被插值)
bboxes = [
(100, 200, 300, 400), # 关键帧 0
(150, 200, 350, 400), # 关键帧 1
(200, 200, 400, 400), # 关键帧 2
(250, 200, 450, 400), # 关键帧 3
]
与论文的对应关系
| 论文章节 | 实现位置 | 状态 |
|---|---|---|
| Sec 3.2 SCI | compute_saliency(), compute_delta(), apply_sci() |
✅ |
| Sec 3.3 LAM | lam_attention(), build_layout_mask_from_bboxes() |
✅ |
| Appendix D 关键帧插值 | interpolate_bboxes() |
✅ |
| Training-free 集成 | patch_cross_attention(), wrap_model_fn() |
✅ |
性能特点
根据论文数据:
- T2V-CompBench 性能提升: +15.7% (Wan2.1-14B), +11.7% (Wan2.2-A14B)
- 推理时间增加: 仅 ~5%
- 兼容性: Wan, CogVideoX, VideoCrafter2, FLUX
注意事项
- 主体名称必须精确匹配:
subjects中的字符串必须能在 prompt 中找到对应的 token - BBox 使用像素坐标: 不是归一化坐标
- 关键帧数量: 推荐使用 4 个关键帧描述运动轨迹
- 温度参数: τ 过小会导致显著性估计不稳定,过大会削弱增强效果
- State tokens: per-frame state control adds extra context tokens and attention bias. Keep the number of states small to reduce overhead.
Per-frame State Control
You can inject per-frame instance states (e.g., "running" -> "idle") with:
state_texts: list of state names per subjectstate_weights: per-frame weights(M, F, S)or(B, M, F, S)state_scale: bias strength for state tokensstate_template: default"{subject} is {state}"
The implementation appends state tokens to the context and applies a per-frame attention bias based on the current token time index.
参考
- 论文: "Comp-Attn: Present-and-Align Attention for Compositional Video Generation"
- 作者: Hongyu Zhang, Yufan Deng, et al. (Peking University, Tsinghua University)