PencilFolder / docs /wan_video_instance_control_design.md
PencilHu's picture
Upload folder using huggingface_hub
1146a67 verified

Wan Video Instance Control:模型设计说明(bbox + per-frame state weights)

本文档描述当前 DiffSynth-Studio 的 Wan Video Statemachine + Instance Control 设计:仅使用

  • instance_ids(区分同类不同个体)
  • instance_class_text(每个实例的 tag/class 文本)
  • instance_state_texts(每个实例的 固定 state 文本集合)
  • instance_state_weights逐帧 state 权重,允许软融合)
  • instance_bboxes逐帧 2D bbox,xyxy 像素坐标)

来驱动 DiT 中的 instance-aware cross attention。除以上输入外,其它 instance 相关字段(class_id/state_id/mask/state_a/b/progress 等)不再使用。


1. 入口 API 与张量约定

入口在 diffsynth/pipelines/wan_video_statemachine.pyWanVideoPipeline.__call__

1.1 必需字段(启用 instance control 时)

当你传入以下任意一个字段不为 None 时,pipeline 视为启用 instance control,并要求 全部提供

  • instance_ids: Tensor,形状 (B, N)(N,),dtype long
    • N:实例数(objects)
  • instance_class_text: List[str](长度 N)或 str(单实例)
    • 每个实例一个 class/tag,例如 "egg", "dog", "person"
  • instance_state_texts: List[List[str]](形状 N × S)或 List[str](单实例)
    • 每个实例有一个 固定大小 的 state 候选集合(S 个 state 文本)
    • 例如单实例:["raw", "cooked"];多实例:[["idle","run"], ["open","close"]]
    • 约束:所有实例的 S 必须相同(当前实现强制)。
  • instance_state_weights: Tensor/list,形状 (B, N, F, S)(N, F, S),dtype float
    • F:逐帧权重的时间长度(推荐等于输入视频帧数 num_frames,但允许不同,后续会映射/下采样到 patch-time)
    • S:state 数量,必须等于 instance_state_texts 的 state 数
    • 语义:对每个 (b,n,f),给出 S 个 state 的权重(可 one-hot,也可软融合)
  • instance_bboxes: Tensor,形状 (B, N, F, 4)(N, F, 4),dtype float
    • bbox 是 xyxy,单位为像素坐标,坐标系必须与推理时的 height/width 对齐
    • 约束:instance_bboxes.shape[2] 必须等于 instance_state_weights.shape[2](同一个 F

1.2 推荐的常见配置

  • 单实例(N=1)+ 两状态(S=2):
    • instance_class_text="egg"
    • instance_state_texts=["raw","cooked"]
    • instance_state_weights.shape=(1,1,F,2)
    • instance_bboxes.shape=(1,1,F,4)
  • 多实例(N>1):
    • instance_class_text 长度必须与 N 相同
    • instance_state_texts 外层长度必须与 N 相同

2. Pipeline 数据流(从输入到 model_fn)

对应代码:

  • diffsynth/pipelines/wan_video_statemachine.py
    • WanVideoPipeline.__call__
    • WanVideoUnit_InstanceStateTextEmbedder
    • model_fn_wan_video

2.1 参数归一化与校验(__call__)

__call__ 中会把输入转为 Tensor,并补齐 batch 维:

  • instance_ids:若输入为 (N,) 会补成 (1,N)
  • instance_bboxes:若输入为 (N,F,4) 会补成 (1,N,F,4)
  • instance_state_weights:若输入为 (N,F,S) 会补成 (1,N,F,S)

启用 instance control 时会做关键校验:

  • 5 个输入必须同时存在:ids/class_text/state_texts/state_weights/bboxes
  • state_weightsbboxesF 必须一致

2.2 文本编码(WanVideoUnit_InstanceStateTextEmbedder)

该 unit 负责把 (class_text, state_texts) 变成可供 DiT 使用的 state phrase embedding:

  1. 先构造短语:
    • 对每个实例 n,对每个 state s
      • phrase = "<class_text[n]> is <state_texts[n][s]>"
  2. 使用 T5 encoder 编码短语序列,并做 mask-aware mean pooling 得到每个短语的 pooled embedding:
    • 输出 instance_state_text_embeds_multi,形状 (1, N, S, text_dim)

注意:

  • 这里不使用 instance_state_weights 做融合;融合在 DiT 内根据逐帧权重完成。
  • unit 只产出 instance_state_text_embeds_multi,并且 pipeline 在 unit 之后会把 instance_class_text/instance_state_textsinputs_shared 中移除,确保下游 model_fn 只接收张量(最小化接口)。

3. DiT 内部设计(instance tokens + bbox mask-guided attention)

对应代码:

  • diffsynth/models/wan_video_dit_instance.py
    • InstanceFeatureExtractor
    • MaskGuidedCrossAttention
    • DiTBlock.forward(..., instance_tokens, instance_masks)

3.1 从“逐帧权重”生成“按 patch-time 的 instance tokens”

核心目标:把 per-frame 的 state 权重变成与 DiT patch token 的时间轴一致的 instance tokens,再对每个 patch 做 masked attention。

输入

  • state_text_embeds_multi: (B, N, S, text_dim)
    每个 state 对应短语 "<class> is <state>" 的 pooled embedding
  • state_weights: (B, N, F, S)
    每帧对 S 个 state 的权重
  • instance_ids: (B, N)
    用于区分同类个体
  • num_time_patches = f
    DiT patchify 后的时间 patch 数(由 patch_embedding 决定)

步骤

  1. 文本投影到 hidden_dim
    • sem_multi = text_proj(state_text_embeds_multi)(B, N, S, H)
  2. 权重截断与时间下采样
    • weights = clamp(state_weights, min=0)
    • F != f:把 (B,N,F,S) 平均池化到 (B,N,f,S)
      • 映射规则:pt = floor(t * f / F)
  3. 按权重对 state 语义做逐时间融合
    • sem_time[b,n,t] = sum_s( sem_multi[b,n,s] * w[b,n,t,s] ) / sum_s(w)
    • 得到 (B, N, f, H)
  4. 融合 instance_id embedding
    • i_feat = Emb(instance_ids)(B, N, H),并广播到时间维 (B, N, f, H)
    • 拼接并通过 fusion MLP:
      • token_time[b,n,t] = fusion( concat(sem_time[b,n,t], i_feat[b,n]) )
    • 输出 inst_tokens(B, f, N, D)(注意转置后时间维在前)

3.2 bbox → patch mask(每个 patch 是否被某实例覆盖)

WanModel.process_masksinstance_bboxes 投影到 patch token 网格,返回 inst_mask_flat

  • 输入 bbox:(B, N, F, 4)xyxy 像素坐标
  • patch 网格:(f_p, h_p, w_p)
  • 输出 mask:(B, N, L),其中 L = f_p * h_p * w_p

关键映射规则:

  • 空间缩放:
    • px = x * (w_p / W_img)
    • py = y * (h_p / H_img)
  • 时间映射:
    • pt = floor(t * f_p / F_bbox)

最终每个 (b,n,pt) 上把 bbox 覆盖到的 (py0:py1, px0:px1) patch 置 1。

3.3 MaskGuidedCrossAttention(log-mask trick)

每个 DiT block 都包含一个 instance cross attention:

  • Q:来自 patch tokens x(形状 (B, L, D)
  • K/V:来自 instance tokens(按时间对齐后使用)
  • Mask:(B, N, L)

attention logits 里加入 log(mask) 作为 bias:

  • sim = (q · k) / sqrt(d)
  • sim = sim + log(mask.clamp(min=1e-6))

这样 mask=0 的位置会得到接近 -inf 的 bias,从而 softmax 后强制为 0,实现 只让每个 patch 关注覆盖它的实例

3.4 时间对齐方式(per-time tokens vs per-token tokens)

MaskGuidedCrossAttention 支持三种形状:

  • (B, N, D):整段序列共享同一组 instance tokens(当前不用)
  • (B, T, N, D):按 patch-time 切分(默认路径)
    • 假设序列按时间展开:L = T * (h*w),按时间分段计算 attention
  • (B, L, N, D):按 token 位置提供 instance tokens(用于 Unified Sequence Parallel)

model_fn_wan_video 开启 USP 时会将 inst_tokens (B,T,N,D) 转换成当前 rank 的 chunk 对应的 (B, chunk_len, N, D)

  • 先计算每个 token 在全局序列里的位置 global_pos
  • time_index = global_pos // (h*w)
  • inst_tokens_chunk = inst_tokens[:, time_index]

并对 padding 部分置 0,避免污染。

3.5 Reference latents 的处理

当 pipeline 使用 reference_latents 拼到序列前面时:

  • patch token 序列会多出 1 个时间片(f += 1
  • inst_mask_flat 会在序列前补 0(reference 部分不属于任何 instance)
  • inst_tokens 也会在时间维前补 0(reference 时间片不注入 instance 语义)

4. 重要限制与注意事项

  1. 必须给每个实例提供相同数量的 state 文本(S 必须一致)
  2. instance_state_weightsinstance_bboxes 的时间长度 F 必须一致
  3. bbox 的像素坐标必须与推理时的 height/width 对齐
    • 如果 pipeline 会 resize 输入图像/视频,你需要用 resize 后的坐标系提供 bbox
  4. sliding window 不支持 instance control
    • model_fn_wan_videosliding_window_size/stride 与 instance 输入同时存在时直接报错

5. 最小可用示例(伪代码)

F = num_frames
N = 1
S = 3

instance_ids = torch.tensor([[1]])                 # (1,1)
instance_class_text = ["egg"]                      # len=1
instance_state_texts = [["raw", "half", "cooked"]] # (N,S)

# 逐帧权重 (1,1,F,3):例如线性从 raw -> cooked
w = torch.zeros((1,1,F,S), dtype=torch.float32)
t = torch.linspace(0, 1, F)
w[0,0,:,0] = (1 - t)  # raw
w[0,0,:,2] = t        # cooked
w[0,0,:,1] = 0.0      # half (可选)

# bbox (1,1,F,4):每帧一个 bbox,xyxy
b = torch.zeros((1,1,F,4), dtype=torch.float32)
b[0,0,:,0] = 100; b[0,0,:,1] = 120
b[0,0,:,2] = 260; b[0,0,:,3] = 320

video = pipe(
    prompt="...",
    height=H, width=W, num_frames=F,
    instance_ids=instance_ids,
    instance_class_text=instance_class_text,
    instance_state_texts=instance_state_texts,
    instance_state_weights=w,
    instance_bboxes=b,
)

6. 代码入口索引

  • Pipeline API / 文本编码:
    • diffsynth/pipelines/wan_video_statemachine.py
      • WanVideoPipeline.__call__
      • WanVideoUnit_InstanceStateTextEmbedder
      • model_fn_wan_video
  • Instance-aware DiT:
    • diffsynth/models/wan_video_dit_instance.py
      • InstanceFeatureExtractor
      • MaskGuidedCrossAttention
      • WanModel.forward(... instance_*)