Wan Video Instance Control:模型设计说明(bbox + per-frame state weights)
本文档描述当前 DiffSynth-Studio 的 Wan Video Statemachine + Instance Control 设计:仅使用
instance_ids(区分同类不同个体)instance_class_text(每个实例的 tag/class 文本)instance_state_texts(每个实例的 固定 state 文本集合)instance_state_weights(逐帧 state 权重,允许软融合)instance_bboxes(逐帧 2D bbox,xyxy 像素坐标)
来驱动 DiT 中的 instance-aware cross attention。除以上输入外,其它 instance 相关字段(class_id/state_id/mask/state_a/b/progress 等)不再使用。
1. 入口 API 与张量约定
入口在 diffsynth/pipelines/wan_video_statemachine.py 的 WanVideoPipeline.__call__。
1.1 必需字段(启用 instance control 时)
当你传入以下任意一个字段不为 None 时,pipeline 视为启用 instance control,并要求 全部提供:
instance_ids:Tensor,形状(B, N)或(N,),dtypelongN:实例数(objects)
instance_class_text:List[str](长度N)或str(单实例)- 每个实例一个 class/tag,例如
"egg","dog","person"…
- 每个实例一个 class/tag,例如
instance_state_texts:List[List[str]](形状N × S)或List[str](单实例)- 每个实例有一个 固定大小 的 state 候选集合(
S个 state 文本) - 例如单实例:
["raw", "cooked"];多实例:[["idle","run"], ["open","close"]] - 约束:所有实例的
S必须相同(当前实现强制)。
- 每个实例有一个 固定大小 的 state 候选集合(
instance_state_weights:Tensor/list,形状(B, N, F, S)或(N, F, S),dtypefloatF:逐帧权重的时间长度(推荐等于输入视频帧数num_frames,但允许不同,后续会映射/下采样到 patch-time)S:state 数量,必须等于instance_state_texts的 state 数- 语义:对每个
(b,n,f),给出S个 state 的权重(可 one-hot,也可软融合)
instance_bboxes:Tensor,形状(B, N, F, 4)或(N, F, 4),dtypefloat- bbox 是
xyxy,单位为像素坐标,坐标系必须与推理时的height/width对齐 - 约束:
instance_bboxes.shape[2]必须等于instance_state_weights.shape[2](同一个F)
- bbox 是
1.2 推荐的常见配置
- 单实例(N=1)+ 两状态(S=2):
instance_class_text="egg"instance_state_texts=["raw","cooked"]instance_state_weights.shape=(1,1,F,2)instance_bboxes.shape=(1,1,F,4)
- 多实例(N>1):
instance_class_text长度必须与N相同instance_state_texts外层长度必须与N相同
2. Pipeline 数据流(从输入到 model_fn)
对应代码:
diffsynth/pipelines/wan_video_statemachine.pyWanVideoPipeline.__call__WanVideoUnit_InstanceStateTextEmbeddermodel_fn_wan_video
2.1 参数归一化与校验(__call__)
在 __call__ 中会把输入转为 Tensor,并补齐 batch 维:
instance_ids:若输入为(N,)会补成(1,N)instance_bboxes:若输入为(N,F,4)会补成(1,N,F,4)instance_state_weights:若输入为(N,F,S)会补成(1,N,F,S)
启用 instance control 时会做关键校验:
- 5 个输入必须同时存在:
ids/class_text/state_texts/state_weights/bboxes state_weights与bboxes的F必须一致
2.2 文本编码(WanVideoUnit_InstanceStateTextEmbedder)
该 unit 负责把 (class_text, state_texts) 变成可供 DiT 使用的 state phrase embedding:
- 先构造短语:
- 对每个实例
n,对每个 states:- phrase =
"<class_text[n]> is <state_texts[n][s]>"
- phrase =
- 对每个实例
- 使用 T5 encoder 编码短语序列,并做 mask-aware mean pooling 得到每个短语的 pooled embedding:
- 输出
instance_state_text_embeds_multi,形状(1, N, S, text_dim)
- 输出
注意:
- 这里不使用
instance_state_weights做融合;融合在 DiT 内根据逐帧权重完成。 - unit 只产出
instance_state_text_embeds_multi,并且 pipeline 在 unit 之后会把instance_class_text/instance_state_texts从inputs_shared中移除,确保下游 model_fn 只接收张量(最小化接口)。
3. DiT 内部设计(instance tokens + bbox mask-guided attention)
对应代码:
diffsynth/models/wan_video_dit_instance.pyInstanceFeatureExtractorMaskGuidedCrossAttentionDiTBlock.forward(..., instance_tokens, instance_masks)
3.1 从“逐帧权重”生成“按 patch-time 的 instance tokens”
核心目标:把 per-frame 的 state 权重变成与 DiT patch token 的时间轴一致的 instance tokens,再对每个 patch 做 masked attention。
输入
state_text_embeds_multi:(B, N, S, text_dim)
每个 state 对应短语"<class> is <state>"的 pooled embeddingstate_weights:(B, N, F, S)
每帧对S个 state 的权重instance_ids:(B, N)
用于区分同类个体num_time_patches = f
DiT patchify 后的时间 patch 数(由patch_embedding决定)
步骤
- 文本投影到 hidden_dim
sem_multi = text_proj(state_text_embeds_multi)→(B, N, S, H)
- 权重截断与时间下采样
weights = clamp(state_weights, min=0)- 若
F != f:把(B,N,F,S)平均池化到(B,N,f,S)- 映射规则:
pt = floor(t * f / F)
- 映射规则:
- 按权重对 state 语义做逐时间融合
sem_time[b,n,t] = sum_s( sem_multi[b,n,s] * w[b,n,t,s] ) / sum_s(w)- 得到
(B, N, f, H)
- 融合 instance_id embedding
i_feat = Emb(instance_ids)→(B, N, H),并广播到时间维(B, N, f, H)- 拼接并通过 fusion MLP:
token_time[b,n,t] = fusion( concat(sem_time[b,n,t], i_feat[b,n]) )
- 输出
inst_tokens:(B, f, N, D)(注意转置后时间维在前)
3.2 bbox → patch mask(每个 patch 是否被某实例覆盖)
WanModel.process_masks 将 instance_bboxes 投影到 patch token 网格,返回 inst_mask_flat:
- 输入 bbox:
(B, N, F, 4),xyxy像素坐标 - patch 网格:
(f_p, h_p, w_p) - 输出 mask:
(B, N, L),其中L = f_p * h_p * w_p
关键映射规则:
- 空间缩放:
px = x * (w_p / W_img)py = y * (h_p / H_img)
- 时间映射:
pt = floor(t * f_p / F_bbox)
最终每个 (b,n,pt) 上把 bbox 覆盖到的 (py0:py1, px0:px1) patch 置 1。
3.3 MaskGuidedCrossAttention(log-mask trick)
每个 DiT block 都包含一个 instance cross attention:
- Q:来自 patch tokens
x(形状(B, L, D)) - K/V:来自 instance tokens(按时间对齐后使用)
- Mask:
(B, N, L)
attention logits 里加入 log(mask) 作为 bias:
sim = (q · k) / sqrt(d)sim = sim + log(mask.clamp(min=1e-6))
这样 mask=0 的位置会得到接近 -inf 的 bias,从而 softmax 后强制为 0,实现 只让每个 patch 关注覆盖它的实例。
3.4 时间对齐方式(per-time tokens vs per-token tokens)
MaskGuidedCrossAttention 支持三种形状:
(B, N, D):整段序列共享同一组 instance tokens(当前不用)(B, T, N, D):按 patch-time 切分(默认路径)- 假设序列按时间展开:
L = T * (h*w),按时间分段计算 attention
- 假设序列按时间展开:
(B, L, N, D):按 token 位置提供 instance tokens(用于 Unified Sequence Parallel)
在 model_fn_wan_video 开启 USP 时会将 inst_tokens (B,T,N,D) 转换成当前 rank 的 chunk 对应的 (B, chunk_len, N, D):
- 先计算每个 token 在全局序列里的位置
global_pos time_index = global_pos // (h*w)inst_tokens_chunk = inst_tokens[:, time_index]
并对 padding 部分置 0,避免污染。
3.5 Reference latents 的处理
当 pipeline 使用 reference_latents 拼到序列前面时:
- patch token 序列会多出 1 个时间片(
f += 1) inst_mask_flat会在序列前补 0(reference 部分不属于任何 instance)inst_tokens也会在时间维前补 0(reference 时间片不注入 instance 语义)
4. 重要限制与注意事项
- 必须给每个实例提供相同数量的 state 文本(S 必须一致)
instance_state_weights与instance_bboxes的时间长度F必须一致- bbox 的像素坐标必须与推理时的
height/width对齐- 如果 pipeline 会 resize 输入图像/视频,你需要用 resize 后的坐标系提供 bbox
- sliding window 不支持 instance control
model_fn_wan_video在sliding_window_size/stride与 instance 输入同时存在时直接报错
5. 最小可用示例(伪代码)
F = num_frames
N = 1
S = 3
instance_ids = torch.tensor([[1]]) # (1,1)
instance_class_text = ["egg"] # len=1
instance_state_texts = [["raw", "half", "cooked"]] # (N,S)
# 逐帧权重 (1,1,F,3):例如线性从 raw -> cooked
w = torch.zeros((1,1,F,S), dtype=torch.float32)
t = torch.linspace(0, 1, F)
w[0,0,:,0] = (1 - t) # raw
w[0,0,:,2] = t # cooked
w[0,0,:,1] = 0.0 # half (可选)
# bbox (1,1,F,4):每帧一个 bbox,xyxy
b = torch.zeros((1,1,F,4), dtype=torch.float32)
b[0,0,:,0] = 100; b[0,0,:,1] = 120
b[0,0,:,2] = 260; b[0,0,:,3] = 320
video = pipe(
prompt="...",
height=H, width=W, num_frames=F,
instance_ids=instance_ids,
instance_class_text=instance_class_text,
instance_state_texts=instance_state_texts,
instance_state_weights=w,
instance_bboxes=b,
)
6. 代码入口索引
- Pipeline API / 文本编码:
diffsynth/pipelines/wan_video_statemachine.pyWanVideoPipeline.__call__WanVideoUnit_InstanceStateTextEmbeddermodel_fn_wan_video
- Instance-aware DiT:
diffsynth/models/wan_video_dit_instance.pyInstanceFeatureExtractorMaskGuidedCrossAttentionWanModel.forward(... instance_*)