# Wan Video Instance Control:模型设计说明(bbox + per-frame state weights) 本文档描述当前 DiffSynth-Studio 的 **Wan Video Statemachine + Instance Control** 设计:仅使用 - `instance_ids`(区分同类不同个体) - `instance_class_text`(每个实例的 tag/class 文本) - `instance_state_texts`(每个实例的 **固定** state 文本集合) - `instance_state_weights`(**逐帧** state 权重,允许软融合) - `instance_bboxes`(**逐帧** 2D bbox,xyxy 像素坐标) 来驱动 DiT 中的 instance-aware cross attention。除以上输入外,其它 instance 相关字段(`class_id/state_id/mask/state_a/b/progress` 等)不再使用。 --- ## 1. 入口 API 与张量约定 入口在 `diffsynth/pipelines/wan_video_statemachine.py` 的 `WanVideoPipeline.__call__`。 ### 1.1 必需字段(启用 instance control 时) 当你传入以下任意一个字段不为 `None` 时,pipeline 视为启用 instance control,并要求 **全部提供**: - `instance_ids`: `Tensor`,形状 `(B, N)` 或 `(N,)`,dtype `long` - `N`:实例数(objects) - `instance_class_text`: `List[str]`(长度 `N`)或 `str`(单实例) - 每个实例一个 class/tag,例如 `"egg"`, `"dog"`, `"person"`… - `instance_state_texts`: `List[List[str]]`(形状 `N × S`)或 `List[str]`(单实例) - 每个实例有一个 **固定大小** 的 state 候选集合(`S` 个 state 文本) - 例如单实例:`["raw", "cooked"]`;多实例:`[["idle","run"], ["open","close"]]` - 约束:所有实例的 `S` 必须相同(当前实现强制)。 - `instance_state_weights`: `Tensor`/list,形状 `(B, N, F, S)` 或 `(N, F, S)`,dtype `float` - `F`:逐帧权重的时间长度(推荐等于输入视频帧数 `num_frames`,但允许不同,后续会映射/下采样到 patch-time) - `S`:state 数量,必须等于 `instance_state_texts` 的 state 数 - 语义:对每个 `(b,n,f)`,给出 `S` 个 state 的权重(可 one-hot,也可软融合) - `instance_bboxes`: `Tensor`,形状 `(B, N, F, 4)` 或 `(N, F, 4)`,dtype `float` - bbox 是 `xyxy`,单位为像素坐标,坐标系必须与推理时的 `height/width` 对齐 - 约束:`instance_bboxes.shape[2]` 必须等于 `instance_state_weights.shape[2]`(同一个 `F`) ### 1.2 推荐的常见配置 - **单实例**(N=1)+ 两状态(S=2): - `instance_class_text="egg"` - `instance_state_texts=["raw","cooked"]` - `instance_state_weights.shape=(1,1,F,2)` - `instance_bboxes.shape=(1,1,F,4)` - **多实例**(N>1): - `instance_class_text` 长度必须与 `N` 相同 - `instance_state_texts` 外层长度必须与 `N` 相同 --- ## 2. Pipeline 数据流(从输入到 model_fn) 对应代码: - `diffsynth/pipelines/wan_video_statemachine.py` - `WanVideoPipeline.__call__` - `WanVideoUnit_InstanceStateTextEmbedder` - `model_fn_wan_video` ### 2.1 参数归一化与校验(__call__) 在 `__call__` 中会把输入转为 Tensor,并补齐 batch 维: - `instance_ids`:若输入为 `(N,)` 会补成 `(1,N)` - `instance_bboxes`:若输入为 `(N,F,4)` 会补成 `(1,N,F,4)` - `instance_state_weights`:若输入为 `(N,F,S)` 会补成 `(1,N,F,S)` 启用 instance control 时会做关键校验: - 5 个输入必须同时存在:`ids/class_text/state_texts/state_weights/bboxes` - `state_weights` 与 `bboxes` 的 `F` 必须一致 ### 2.2 文本编码(WanVideoUnit_InstanceStateTextEmbedder) 该 unit 负责把 `(class_text, state_texts)` 变成可供 DiT 使用的 state phrase embedding: 1. 先构造短语: - 对每个实例 `n`,对每个 state `s`: - phrase = `" is "` 2. 使用 T5 encoder 编码短语序列,并做 mask-aware mean pooling 得到每个短语的 pooled embedding: - 输出 `instance_state_text_embeds_multi`,形状 `(1, N, S, text_dim)` 注意: - 这里不使用 `instance_state_weights` 做融合;融合在 DiT 内根据逐帧权重完成。 - unit 只产出 `instance_state_text_embeds_multi`,并且 pipeline 在 unit 之后会把 `instance_class_text/instance_state_texts` 从 `inputs_shared` 中移除,确保下游 model_fn 只接收张量(最小化接口)。 --- ## 3. DiT 内部设计(instance tokens + bbox mask-guided attention) 对应代码: - `diffsynth/models/wan_video_dit_instance.py` - `InstanceFeatureExtractor` - `MaskGuidedCrossAttention` - `DiTBlock.forward(..., instance_tokens, instance_masks)` ### 3.1 从“逐帧权重”生成“按 patch-time 的 instance tokens” 核心目标:把 per-frame 的 state 权重变成与 DiT patch token 的时间轴一致的 instance tokens,再对每个 patch 做 masked attention。 #### 输入 - `state_text_embeds_multi`: `(B, N, S, text_dim)` 每个 state 对应短语 `" is "` 的 pooled embedding - `state_weights`: `(B, N, F, S)` 每帧对 `S` 个 state 的权重 - `instance_ids`: `(B, N)` 用于区分同类个体 - `num_time_patches = f` DiT patchify 后的时间 patch 数(由 `patch_embedding` 决定) #### 步骤 1. **文本投影到 hidden_dim** - `sem_multi = text_proj(state_text_embeds_multi)` → `(B, N, S, H)` 2. **权重截断与时间下采样** - `weights = clamp(state_weights, min=0)` - 若 `F != f`:把 `(B,N,F,S)` 平均池化到 `(B,N,f,S)` - 映射规则:`pt = floor(t * f / F)` 3. **按权重对 state 语义做逐时间融合** - `sem_time[b,n,t] = sum_s( sem_multi[b,n,s] * w[b,n,t,s] ) / sum_s(w)` - 得到 `(B, N, f, H)` 4. **融合 instance_id embedding** - `i_feat = Emb(instance_ids)` → `(B, N, H)`,并广播到时间维 `(B, N, f, H)` - 拼接并通过 fusion MLP: - `token_time[b,n,t] = fusion( concat(sem_time[b,n,t], i_feat[b,n]) )` - 输出 `inst_tokens`:`(B, f, N, D)`(注意转置后时间维在前) ### 3.2 bbox → patch mask(每个 patch 是否被某实例覆盖) `WanModel.process_masks` 将 `instance_bboxes` 投影到 patch token 网格,返回 `inst_mask_flat`: - 输入 bbox:`(B, N, F, 4)`,`xyxy` 像素坐标 - patch 网格:`(f_p, h_p, w_p)` - 输出 mask:`(B, N, L)`,其中 `L = f_p * h_p * w_p` 关键映射规则: - 空间缩放: - `px = x * (w_p / W_img)` - `py = y * (h_p / H_img)` - 时间映射: - `pt = floor(t * f_p / F_bbox)` 最终每个 `(b,n,pt)` 上把 bbox 覆盖到的 `(py0:py1, px0:px1)` patch 置 1。 ### 3.3 MaskGuidedCrossAttention(log-mask trick) 每个 DiT block 都包含一个 instance cross attention: - Q:来自 patch tokens `x`(形状 `(B, L, D)`) - K/V:来自 instance tokens(按时间对齐后使用) - Mask:`(B, N, L)` attention logits 里加入 `log(mask)` 作为 bias: - `sim = (q · k) / sqrt(d)` - `sim = sim + log(mask.clamp(min=1e-6))` 这样 mask=0 的位置会得到接近 `-inf` 的 bias,从而 softmax 后强制为 0,实现 **只让每个 patch 关注覆盖它的实例**。 ### 3.4 时间对齐方式(per-time tokens vs per-token tokens) `MaskGuidedCrossAttention` 支持三种形状: - `(B, N, D)`:整段序列共享同一组 instance tokens(当前不用) - `(B, T, N, D)`:按 patch-time 切分(默认路径) - 假设序列按时间展开:`L = T * (h*w)`,按时间分段计算 attention - `(B, L, N, D)`:按 token 位置提供 instance tokens(用于 Unified Sequence Parallel) 在 `model_fn_wan_video` 开启 USP 时会将 `inst_tokens (B,T,N,D)` 转换成当前 rank 的 chunk 对应的 `(B, chunk_len, N, D)`: - 先计算每个 token 在全局序列里的位置 `global_pos` - `time_index = global_pos // (h*w)` - `inst_tokens_chunk = inst_tokens[:, time_index]` 并对 padding 部分置 0,避免污染。 ### 3.5 Reference latents 的处理 当 pipeline 使用 `reference_latents` 拼到序列前面时: - patch token 序列会多出 1 个时间片(`f += 1`) - `inst_mask_flat` 会在序列前补 0(reference 部分不属于任何 instance) - `inst_tokens` 也会在时间维前补 0(reference 时间片不注入 instance 语义) --- ## 4. 重要限制与注意事项 1. **必须给每个实例提供相同数量的 state 文本(S 必须一致)** 2. **`instance_state_weights` 与 `instance_bboxes` 的时间长度 `F` 必须一致** 3. **bbox 的像素坐标必须与推理时的 `height/width` 对齐** - 如果 pipeline 会 resize 输入图像/视频,你需要用 resize 后的坐标系提供 bbox 4. **sliding window 不支持 instance control** - `model_fn_wan_video` 在 `sliding_window_size/stride` 与 instance 输入同时存在时直接报错 --- ## 5. 最小可用示例(伪代码) ```python F = num_frames N = 1 S = 3 instance_ids = torch.tensor([[1]]) # (1,1) instance_class_text = ["egg"] # len=1 instance_state_texts = [["raw", "half", "cooked"]] # (N,S) # 逐帧权重 (1,1,F,3):例如线性从 raw -> cooked w = torch.zeros((1,1,F,S), dtype=torch.float32) t = torch.linspace(0, 1, F) w[0,0,:,0] = (1 - t) # raw w[0,0,:,2] = t # cooked w[0,0,:,1] = 0.0 # half (可选) # bbox (1,1,F,4):每帧一个 bbox,xyxy b = torch.zeros((1,1,F,4), dtype=torch.float32) b[0,0,:,0] = 100; b[0,0,:,1] = 120 b[0,0,:,2] = 260; b[0,0,:,3] = 320 video = pipe( prompt="...", height=H, width=W, num_frames=F, instance_ids=instance_ids, instance_class_text=instance_class_text, instance_state_texts=instance_state_texts, instance_state_weights=w, instance_bboxes=b, ) ``` --- ## 6. 代码入口索引 - Pipeline API / 文本编码: - `diffsynth/pipelines/wan_video_statemachine.py` - `WanVideoPipeline.__call__` - `WanVideoUnit_InstanceStateTextEmbedder` - `model_fn_wan_video` - Instance-aware DiT: - `diffsynth/models/wan_video_dit_instance.py` - `InstanceFeatureExtractor` - `MaskGuidedCrossAttention` - `WanModel.forward(... instance_*)`