PencilFolder / docs /comp_attn_design.md
PencilHu's picture
Upload folder using huggingface_hub
1146a67 verified

Comp-Attn 设计文档

基于论文 "Comp-Attn: Present-and-Align Attention for Compositional Video Generation" 的实现总结

概述

Comp-Attn 是一种 composition-aware cross-attention 变体,采用 "Present-and-Align" 范式,用于解决多主体视频生成中的两个核心问题:

挑战 描述 解决方案
Subject Presence 并非所有主体都能在视频中呈现 SCI(条件层面)
Inter-subject Relations 主体间的交互和空间关系错位 LAM(注意力层面)

核心组件

1. Subject-aware Condition Interpolation (SCI)

SCI 在 条件编码阶段 增强每个主体的语义表示,确保所有主体都能被"召回"。

工作原理

┌─────────────────────────────────────────────────────────────┐
│  输入: prompt 中的主体 tokens + 独立编码的 anchor embeddings │
├─────────────────────────────────────────────────────────────┤
│  Step 1: 计算语义显著性 (Saliency)                           │
│          - 比较 prompt 中主体 token 与 anchor 的余弦相似度    │
│          - 使用 softmax(τ=0.2) 归一化得到显著性分数            │
│          - 低显著性 = 主体在 prompt 上下文中被"淹没"          │
├─────────────────────────────────────────────────────────────┤
│  Step 2: 计算语义差异 (Delta)                                │
│          - delta_i = anchor_i * N - Σ(anchors)              │
│          - 表示每个主体相对于其他主体的独特语义               │
├─────────────────────────────────────────────────────────────┤
│  Step 3: 自适应插值                                          │
│          - scale = ω * (1 - saliency)                       │
│          - context' = context + delta * scale               │
│          - 显著性越低,增强越多;早期时间步增强越强           │
└─────────────────────────────────────────────────────────────┘

关键代码

# 显著性计算
def compute_saliency(prompt_vecs, anchor_vecs, tau=0.2):
    cosine = cosine_similarity(prompt_vecs, anchor_vecs)
    scores = exp(cosine / tau)
    return scores.diagonal() / scores.sum(dim=1)

# 应用 SCI
omega = 1.0 - (timestep / 1000.0)  # 时间步调度
scale = omega * (1.0 - saliency)
context = context + delta * scale

2. Layout-forcing Attention Modulation (LAM)

LAM 在 注意力计算阶段 动态调制注意力分布,使其与预定义的空间布局对齐。

工作原理

┌─────────────────────────────────────────────────────────────┐
│  输入: Q (视频 tokens), K/V (文本 tokens), Layout (bbox)     │
├─────────────────────────────────────────────────────────────┤
│  Step 1: 计算原始 attention scores                           │
│          attn_scores = Q @ K^T / sqrt(d)                    │
├─────────────────────────────────────────────────────────────┤
│  Step 2: 构建调制函数                                        │
│          g_plus  = max(scores) - scores  (增强调制)         │
│          g_minus = min(scores) - scores  (抑制调制)         │
├─────────────────────────────────────────────────────────────┤
│  Step 3: 计算 IOU 引导的强度                                 │
│          - adapt_mask: 当前注意力分布 > mean 的区域          │
│          - layout_mask: 目标 bbox 区域                       │
│          - iou = intersection / union                       │
│          - strength = 1 - iou (IOU 越低,调制越强)          │
├─────────────────────────────────────────────────────────────┤
│  Step 4: 应用动态调制                                        │
│          - 在 bbox 内增强对应主体 token (g_plus)             │
│          - 抑制其他主体 token (g_minus)                      │
│          - bias = g_k * strength * layout_mask              │
│          - final_scores = attn_scores + bias                │
└─────────────────────────────────────────────────────────────┘

LAM vs 传统 Layout Control

方法 机制 缺点
硬掩码 (Hard Mask) 强制注意力只在 bbox 内 无法适应多样的物体形状
LAM (Ours) IOU 引导的动态软调制 ✅ 灵活适应不同形状

3. 关键帧插值

遵循论文附录 F 的设计,支持用 4 个关键帧描述运动轨迹,然后线性插值到所有帧:

# 输入: 4 个关键帧的 bbox
keyframe_bboxes = [frame_0, frame_1, frame_2, frame_3]

# 输出: 81 帧的 bbox(线性插值)
all_bboxes = F.interpolate(keyframe_bboxes, size=81, mode='linear')

架构设计

┌─────────────────────────────────────────────────────────────┐
│                    WanVideoCompAttnPipeline                 │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  ┌──────────────────┐    ┌──────────────────┐               │
│  │  CompAttnUnit    │───▶│  CompAttnMerge   │               │
│  │  (SCI 预处理)    │    │  (CFG 合并)      │               │
│  └──────────────────┘    └──────────────────┘               │
│           │                                                 │
│           ▼                                                 │
│  ┌──────────────────────────────────────────────────────┐  │
│  │  model_fn_wrapper (动态注入)                          │  │
│  │  - apply_sci(): 修改 context                         │  │
│  │  - build_layout_mask(): 构建空间掩码                  │  │
│  └──────────────────────────────────────────────────────┘  │
│           │                                                 │
│           ▼                                                 │
│  ┌──────────────────────────────────────────────────────┐  │
│  │  patched cross_attention                              │  │
│  │  - lam_attention(): IOU 引导的注意力调制              │  │
│  └──────────────────────────────────────────────────────┘  │
│                                                             │
└─────────────────────────────────────────────────────────────┘

Prompt-BBox 绑定机制

使用变量拼接(推荐)

使用 Python 变量定义主体,通过 f-string 拼接 prompt,代码层面天然表达绑定关系:

# 定义主体变量
subject0 = "red car"
subject1 = "blue bicycle"

# 使用变量拼接 prompt
prompt = f"A {subject0} drives left, a {subject1} rides right"

# subjects 和 bboxes 顺序一一对应
subjects = [subject0, subject1]
bboxes = [bbox0, bbox1]  # bbox0 对应 subject0, bbox1 对应 subject1

绑定原理

┌─────────────────────────────────────────────────────────────────────────┐
│  1. 用户定义 subjects 和 bboxes(顺序一一对应)                          │
│                                                                         │
│     subject0 = "red car"                                                │
│     subject1 = "blue bicycle"                                           │
│     subjects = [subject0, subject1]                                     │
│     bboxes = [bbox0, bbox1]                                             │
├─────────────────────────────────────────────────────────────────────────┤
│  2. Tokenize prompt,搜索每个 subject 的 token 位置                      │
│                                                                         │
│     prompt = f"A {subject0} drives while a {subject1} rides..."        │
│                  ↑↑↑↑↑↑↑                  ↑↑↑↑↑↑↑↑↑↑↑↑                 │
│                  token indices            token indices                 │
├─────────────────────────────────────────────────────────────────────────┤
│  3. 建立 subject_token_mask (关联 token 位置与 bbox)                     │
│                                                                         │
│     subject_token_mask[0, ...] = True  # subject0 tokens                │
│     subject_token_mask[1, ...] = True  # subject1 tokens                │
├─────────────────────────────────────────────────────────────────────────┤
│  4. 在推理时应用:                                                       │
│     - SCI: 根据 mask 增强对应 token 的语义                               │
│     - LAM: 在 bbox 区域内增强对应 token 的注意力                          │
└─────────────────────────────────────────────────────────────────────────┘

重要约束

  1. subjects 必须在 prompt 中出现:使用变量拼接确保这一点
  2. 顺序一一对应bboxes[i] 对应 subjects[i]
  3. 精确匹配:subject 字符串需要能被分词器识别为完整的 token 序列

使用方法

基本用法

from diffsynth.pipelines.wan_video_comp_attn import WanVideoCompAttnPipeline
from diffsynth.models.comp_attn_model import CompAttnConfig

# 创建 pipeline
pipe = WanVideoCompAttnPipeline.from_pretrained(...)

# 1. 定义主体变量
subject0 = "red car"
subject1 = "blue bicycle"

# 2. 定义运动轨迹 (4 个关键帧)
bbox0 = [(100, 250, 220, 380), (200, 250, 320, 380), (300, 250, 420, 380), (400, 250, 520, 380)]
bbox1 = [(350, 260, 410, 400), (450, 260, 510, 400), (550, 260, 610, 400), (650, 260, 710, 400)]

# 3. 使用变量拼接 prompt
prompt = f"A {subject0} drives from left to center while a {subject1} rides to the right"

# 4. 配置 Comp-Attn(顺序一一对应)
comp_attn = CompAttnConfig(
    subjects=[subject0, subject1],  # 变量列表
    bboxes=[bbox0, bbox1],          # 对应的 bbox 列表
    enable_sci=True,
    enable_lam=True,
    interpolate=True,
)

# 5. 生成视频
video = pipe(prompt=prompt, comp_attn=comp_attn)

运动轨迹辅助函数

def create_moving_bbox(start_x, end_x, y_center, box_width, box_height, num_keyframes=4):
    """创建从 start_x 移动到 end_x 的关键帧 bbox 序列"""
    keyframes = []
    for i in range(num_keyframes):
        progress = i / (num_keyframes - 1)
        center_x = start_x + (end_x - start_x) * progress
        left = center_x - box_width / 2
        right = center_x + box_width / 2
        top = y_center - box_height / 2
        bottom = y_center + box_height / 2
        keyframes.append((left, top, right, bottom))
    return keyframes

# 使用示例
car_trajectory = create_moving_bbox(
    start_x=100, end_x=500,
    y_center=300,
    box_width=120, box_height=80,
)

配置参数

参数 类型 默认值 说明
subjects List[str] - 主体名称列表(必须出现在 prompt 中)
bboxes List None 每个主体每个关键帧的 bbox
enable_sci bool True 是否启用 SCI
enable_lam bool True 是否启用 LAM
temperature float 0.2 显著性计算的温度参数 τ
apply_to_negative bool False 是否对负样本 prompt 应用
interpolate bool False 是否对关键帧 bbox 进行插值

BBox 格式

# 格式: (left, top, right, bottom) 像素坐标
# 支持多种输入形式:

# 1. 静态布局(所有帧相同)
bbox = (100, 200, 300, 400)

# 2. 关键帧布局(4帧,会被插值)
bboxes = [
    (100, 200, 300, 400),  # 关键帧 0
    (150, 200, 350, 400),  # 关键帧 1
    (200, 200, 400, 400),  # 关键帧 2
    (250, 200, 450, 400),  # 关键帧 3
]

与论文的对应关系

论文章节 实现位置 状态
Sec 3.2 SCI compute_saliency(), compute_delta(), apply_sci()
Sec 3.3 LAM lam_attention(), build_layout_mask_from_bboxes()
Appendix D 关键帧插值 interpolate_bboxes()
Training-free 集成 patch_cross_attention(), wrap_model_fn()

性能特点

根据论文数据:

  • T2V-CompBench 性能提升: +15.7% (Wan2.1-14B), +11.7% (Wan2.2-A14B)
  • 推理时间增加: 仅 ~5%
  • 兼容性: Wan, CogVideoX, VideoCrafter2, FLUX

注意事项

  1. 主体名称必须精确匹配: subjects 中的字符串必须能在 prompt 中找到对应的 token
  2. BBox 使用像素坐标: 不是归一化坐标
  3. 关键帧数量: 推荐使用 4 个关键帧描述运动轨迹
  4. 温度参数: τ 过小会导致显著性估计不稳定,过大会削弱增强效果
  5. State tokens: per-frame state control adds extra context tokens and attention bias. Keep the number of states small to reduce overhead.

Per-frame State Control

You can inject per-frame instance states (e.g., "running" -> "idle") with:

  • state_texts: list of state names per subject
  • state_weights: per-frame weights (M, F, S) or (B, M, F, S)
  • state_scale: bias strength for state tokens
  • state_template: default "{subject} is {state}"

The implementation appends state tokens to the context and applies a per-frame attention bias based on the current token time index.

参考

  • 论文: "Comp-Attn: Present-and-Align Attention for Compositional Video Generation"
  • 作者: Hongyu Zhang, Yufan Deng, et al. (Peking University, Tsinghua University)