# diffsynth/models/wan_video_dit_slots.py
# NOTE: removed non-Python HuggingFace page-header residue ("PencilFolder /
# diffsynth /models /wan_video_dit_slots.py", "Upload folder using
# huggingface_hub", commit 1146a67) that would break parsing.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Tuple, Optional, Dict, Any
from einops import rearrange
from .wan_video_camera_controller import SimpleAdapter
try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_3_AVAILABLE = False
try:
import flash_attn
FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_2_AVAILABLE = False
try:
from sageattention import sageattn
SAGE_ATTN_AVAILABLE = True
except ModuleNotFoundError:
SAGE_ATTN_AVAILABLE = False
def flash_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    num_heads: int,
    compatibility_mode: bool = False
):
    """Multi-head attention over packed (batch, seq, num_heads * head_dim) tensors.

    Dispatches, in priority order, to FlashAttention-3, FlashAttention-2,
    SageAttention, or torch SDPA (also used when ``compatibility_mode=True``).
    Returns a tensor with the same (b, s, n*d) layout as the inputs.
    """
    def _heads_last(t):
        # (b, s, n*d) -> (b, s, n, d)
        return t.unflatten(2, (num_heads, -1))

    def _heads_first(t):
        # (b, s, n*d) -> (b, n, s, d)
        return t.unflatten(2, (num_heads, -1)).permute(0, 2, 1, 3)

    if compatibility_mode:
        out = F.scaled_dot_product_attention(_heads_first(q), _heads_first(k), _heads_first(v))
        return out.permute(0, 2, 1, 3).flatten(2)
    if FLASH_ATTN_3_AVAILABLE:
        out = flash_attn_interface.flash_attn_func(_heads_last(q), _heads_last(k), _heads_last(v))
        # Some flash-attn-3 versions return (out, lse); keep only the output.
        if isinstance(out, tuple):
            out = out[0]
        return out.flatten(2)
    if FLASH_ATTN_2_AVAILABLE:
        out = flash_attn.flash_attn_func(_heads_last(q), _heads_last(k), _heads_last(v))
        return out.flatten(2)
    if SAGE_ATTN_AVAILABLE:
        out = sageattn(_heads_first(q), _heads_first(k), _heads_first(v))
        return out.permute(0, 2, 1, 3).flatten(2)
    # Fallback: plain PyTorch scaled dot-product attention.
    out = F.scaled_dot_product_attention(_heads_first(q), _heads_first(k), _heads_first(v))
    return out.permute(0, 2, 1, 3).flatten(2)
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
    """adaLN modulation: scale-then-shift, i.e. ``x * (1 + scale) + shift``."""
    scaled = x * (scale + 1)
    return scaled + shift
def sinusoidal_embedding_1d(dim, position):
    """Classic transformer sinusoidal timestep embedding.

    position: 1-D tensor of positions; returns (len(position), dim) with the
    first dim//2 columns cos terms and the last dim//2 columns sin terms,
    computed in float64 and cast back to position's dtype.
    """
    half = dim // 2
    inv_freq = torch.pow(
        10000,
        -torch.arange(half, dtype=torch.float64, device=position.device).div(half),
    )
    angles = torch.outer(position.type(torch.float64), inv_freq)
    emb = torch.cat([angles.cos(), angles.sin()], dim=1)
    return emb.to(position.dtype)
def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
    """Split the head dim across (frame, height, width) RoPE tables.

    Height and width each take dim // 3 channels; the frame axis takes the
    remainder so the three parts always sum to ``dim``.
    """
    spatial = dim // 3
    temporal = dim - 2 * spatial
    return (
        precompute_freqs_cis(temporal, end, theta),
        precompute_freqs_cis(spatial, end, theta),
        precompute_freqs_cis(spatial, end, theta),
    )
def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
    """1-D RoPE table: (end, dim // 2) complex unit phasors exp(i * pos * freq)."""
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].double() / dim
    inv_freq = 1.0 / (theta ** exponents)
    angles = torch.outer(torch.arange(end, device=inv_freq.device), inv_freq)
    # Unit-magnitude complex numbers carrying only the rotation phase.
    return torch.polar(torch.ones_like(angles), angles)
def rope_apply(x, freqs, num_heads):
    """Apply rotary position embedding to packed (b, s, n*d) features.

    Adjacent channel pairs are viewed as complex numbers, rotated by the
    (broadcastable) complex ``freqs`` table in float64, then flattened back.
    """
    b, s, _ = x.shape
    heads = x.unflatten(2, (num_heads, -1))
    pairs = heads.to(torch.float64).reshape(b, s, num_heads, -1, 2)
    rotated = torch.view_as_complex(pairs) * freqs
    out = torch.view_as_real(rotated).flatten(2)
    return out.to(x.dtype)
class RMSNorm(nn.Module):
    """Root-mean-square LayerNorm (no mean subtraction), computed in fp32
    and rescaled by a learned per-channel weight (initialized to ones)."""
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
    def norm(self, x):
        # x / sqrt(mean(x^2) + eps) over the last dim.
        rms_inv = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * rms_inv
    def forward(self, x):
        original_dtype = x.dtype
        normalized = self.norm(x.float()).to(original_dtype)
        return normalized * self.weight
class AttentionModule(nn.Module):
    """Thin wrapper that routes q/k/v through the backend-dispatching
    ``flash_attention`` with a fixed head count."""
    def __init__(self, num_heads):
        super().__init__()
        self.num_heads = num_heads
    def forward(self, q, k, v):
        return flash_attention(q=q, k=k, v=v, num_heads=self.num_heads)
class SelfAttention(nn.Module):
    """Self-attention with RoPE applied to q/k, used for video patch tokens."""
    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Projections plus RMS norms on q/k (QK-norm stabilization).
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)
        self.attn = AttentionModule(self.num_heads)
    def forward(self, x, freqs):
        """x: (b, s, dim); freqs: complex RoPE table broadcastable to q/k."""
        query = rope_apply(self.norm_q(self.q(x)), freqs, self.num_heads)
        key = rope_apply(self.norm_k(self.k(x)), freqs, self.num_heads)
        value = self.v(x)
        return self.o(self.attn(query, key, value))
class SelfAttentionNoRoPE(nn.Module):
    """Self-attention without RoPE, used for slot tokens (slots carry no
    stable grid position, so omitting positional rotation is more stable)."""
    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)
        self.attn = AttentionModule(self.num_heads)
    def forward(self, x):
        """x: (b, n_slots, dim) -> (b, n_slots, dim)."""
        query = self.norm_q(self.q(x))
        key = self.norm_k(self.k(x))
        value = self.v(x)
        return self.o(self.attn(query, key, value))
class CrossAttention(nn.Module):
    """Cross-attention from query tokens to a context sequence.

    With ``has_image_input=True`` the first 257 context tokens are treated as
    a separate image stream (presumably CLIP image embeddings prepended by the
    caller -- confirm against the code that builds ``context``) with dedicated
    key/value projections; the two attention results are summed before the
    output projection.
    """
    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, has_image_input: bool = False):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)
        self.has_image_input = has_image_input
        if has_image_input:
            # Separate k/v path (with its own key norm) for the image tokens.
            self.k_img = nn.Linear(dim, dim)
            self.v_img = nn.Linear(dim, dim)
            self.norm_k_img = RMSNorm(dim, eps=eps)
        self.attn = AttentionModule(self.num_heads)
    def forward(self, x: torch.Tensor, y: torch.Tensor):
        """
        x: queries, (b, s, dim)
        y: keys/values (context); when has_image_input, y[:, :257] is the
           image stream and y[:, 257:] the remaining (text) stream
        """
        if self.has_image_input:
            img = y[:, :257]
            ctx = y[:, 257:]
        else:
            ctx = y
        q = self.norm_q(self.q(x))
        k = self.norm_k(self.k(ctx))
        v = self.v(ctx)
        x_out = self.attn(q, k, v)
        if self.has_image_input:
            k_img = self.norm_k_img(self.k_img(img))
            v_img = self.v_img(img)
            y_img = flash_attention(q, k_img, v_img, num_heads=self.num_heads)
            # Sum text- and image-conditioned results before projecting out.
            x_out = x_out + y_img
        return self.o(x_out)
class GateModule(nn.Module):
    """Gated residual add (adaLN-style): ``out = x + gate * residual``."""
    def __init__(self):
        super().__init__()
    def forward(self, x, gate, residual):
        gated = gate * residual
        return x + gated
# --------------------------
# ROI gather/scatter helpers
# --------------------------
def gather_tokens(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Select a per-batch subset of tokens.

    x:   (b, s, c)
    idx: (b, m) long indices into the token axis
    returns (b, m, c)
    """
    channels = x.shape[-1]
    expanded = idx.unsqueeze(-1).expand(*idx.shape, channels)
    return x.gather(1, expanded)
def scatter_tokens(x: torch.Tensor, idx: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Write tokens back into a copy of x at the given positions.

    x:   (b, s, c)
    idx: (b, m) long indices into the token axis
    y:   (b, m, c) replacement tokens
    returns a new (b, s, c) tensor; the input x is left untouched
    """
    channels = x.shape[-1]
    expanded = idx.unsqueeze(-1).expand(*idx.shape, channels)
    result = x.clone()
    result.scatter_(1, expanded, y)
    return result
def bbox_to_mask(bbox_xyxy: torch.Tensor, H: int, W: int) -> torch.Tensor:
    """Rasterize per-sample xyxy boxes into binary masks.

    bbox_xyxy: (b, 4) float/int, coordinates in [0, W) x [0, H)
    returns (b, 1, H, W) float32 with ones inside each (half-open) box
    """
    batch = bbox_xyxy.shape[0]
    out = torch.zeros((batch, 1, H, W), device=bbox_xyxy.device, dtype=torch.float32)
    # Clamp left/top to valid pixel range, right/bottom to the exclusive bound.
    lefts = bbox_xyxy[:, 0].clamp(0, W - 1).long()
    tops = bbox_xyxy[:, 1].clamp(0, H - 1).long()
    rights = bbox_xyxy[:, 2].clamp(0, W).long()
    bottoms = bbox_xyxy[:, 3].clamp(0, H).long()
    for i in range(batch):
        left, top = int(lefts[i]), int(tops[i])
        right, bottom = int(rights[i]), int(bottoms[i])
        out[i, 0, top:bottom, left:right] = 1.0
    return out
@torch.no_grad()
def mask_to_roi_idx(
    mask: torch.Tensor,
    grid_fhw: Tuple[int, int, int],
    *,
    frame_index: int = 0,
    roi_token_budget: int = 256,
    mode: str = "topk",
) -> torch.Tensor:
    """Downsample a single-frame mask to the patch grid and return a
    fixed-length ``roi_idx`` usable with gather_tokens / scatter_tokens.

    mask: (b, H, W) or (b, 1, H, W) -- the mask of the single interaction frame
    grid_fhw: (f, h, w) as returned by patchify
    frame_index: which frame the interaction happens on (default 0)
    roi_token_budget: fixed m, avoiding variable-length flash-attn issues
    mode: "topk" (highest mask response) or "random" (sample non-zero cells)

    Returns (b, roi_token_budget) long indices into the flattened (f, h, w)
    token sequence.

    Raises ValueError for an out-of-range frame_index or unknown mode.
    (Changed from ``assert`` so validation survives ``python -O``.)
    """
    if mask.dim() == 3:
        mask = mask[:, None]
    b, _, H, W = mask.shape
    f, h, w = grid_fhw
    if not (0 <= frame_index < f):
        raise ValueError(f"frame_index {frame_index} out of range f={f}")
    # Downsample to the patch grid (h, w).
    m_small = F.interpolate(mask.float(), size=(h, w), mode="bilinear", align_corners=False)  # (b,1,h,w)
    flat = m_small[:, 0].reshape(b, h * w)  # (b, h*w)
    # Convert to global token indices: tokens are flattened in (f, h, w) order.
    base = frame_index * (h * w)
    if mode == "topk":
        k = min(roi_token_budget, h * w)
        _, topi = torch.topk(flat, k=k, dim=1)
        if k < roi_token_budget:
            # Fewer grid cells than the budget: pad by repeating the best cell.
            pad = topi[:, :1].expand(b, roi_token_budget - k)
            topi = torch.cat([topi, pad], dim=1)
        idx = topi[:, :roi_token_budget] + base
        return idx.long()
    if mode == "random":
        # Sample uniformly from non-zero cells; repeat to fill the budget.
        idx_list = []
        for bi in range(b):
            nz = torch.nonzero(flat[bi] > 0.01, as_tuple=False).flatten()
            if nz.numel() == 0:
                # Empty mask: degrade to sampling from the whole grid.
                nz = torch.arange(h * w, device=mask.device)
            if nz.numel() >= roi_token_budget:
                sel = nz[torch.randperm(nz.numel(), device=mask.device)[:roi_token_budget]]
            else:
                sel = nz[torch.randint(0, nz.numel(), (roi_token_budget,), device=mask.device)]
            idx_list.append(sel)
        idx = torch.stack(idx_list, dim=0) + base
        return idx.long()
    raise ValueError(f"Unknown mode={mode}")
# --------------------------
# Slots-enabled DiT block
# --------------------------
class DiTBlockWithSlots(nn.Module):
    """
    Extends the original DiTBlock with:
    - text -> slots
    - roi_patches -> slots
    - slots self-attn (interaction between instances)
    - slots -> roi_patches (write-back)
    The original global patch <- text cross-attn is disabled by default
    (to avoid polluting the whole frame).
    """
    def __init__(
        self,
        has_image_input: bool,
        dim: int,
        num_heads: int,
        ffn_dim: int,
        eps: float = 1e-6,
        enable_patch_text_cross_attn: bool = False,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim
        self.enable_patch_text_cross_attn = enable_patch_text_cross_attn
        # patch path (keeps the original DiTBlock logic)
        self.self_attn = SelfAttention(dim, num_heads, eps)
        self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(ffn_dim, dim),
        )
        # 6 learned adaLN base vectors: (shift, scale, gate) for attn and FFN.
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
        self.gate = GateModule()
        # (optional) patch <- text (the original cross-attn)
        if enable_patch_text_cross_attn:
            self.cross_attn = CrossAttention(dim, num_heads, eps, has_image_input=has_image_input)
            self.norm3 = nn.LayerNorm(dim, eps=eps)
        # slot modules
        self.slot_norm = nn.LayerNorm(dim, eps=eps)
        self.slot_text_attn = CrossAttention(dim, num_heads, eps, has_image_input=has_image_input)  # slots <- context
        self.slot_from_patch = CrossAttention(dim, num_heads, eps, has_image_input=False)  # slots <- roi_patches
        self.slot_self = SelfAttentionNoRoPE(dim, num_heads, eps)  # slots <-> slots
        self.patch_from_slot = CrossAttention(dim, num_heads, eps, has_image_input=False)  # roi_patches <- slots
    def forward(
        self,
        x: torch.Tensor,  # (b, s, dim) patch tokens
        slots: torch.Tensor,  # (b, n_slots, dim)
        context: torch.Tensor,  # (b, n_ctx, dim)
        t_mod: torch.Tensor,  # (b, 6, dim)
        freqs: torch.Tensor,  # (s, 1, rope_dim) complex
        roi_idx: Optional[torch.Tensor] = None,  # (b, m)
    ):
        """Run one block; returns the updated ``(x, slots)`` pair.

        NOTE: the slot updates below are order-sensitive
        (text -> slots -> roi -> slots -> self-attn -> write-back).
        """
        # ---- original patch self-attn + gated adaLN ----
        has_seq = len(t_mod.shape) == 4
        chunk_dim = 2 if has_seq else 1
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod
        ).chunk(6, dim=chunk_dim)
        if has_seq:
            # Per-token modulation carries an extra axis; drop it after chunking.
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                shift_msa.squeeze(2),
                scale_msa.squeeze(2),
                gate_msa.squeeze(2),
                shift_mlp.squeeze(2),
                scale_mlp.squeeze(2),
                gate_mlp.squeeze(2),
            )
        input_x = modulate(self.norm1(x), shift_msa, scale_msa)
        x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
        # ---- ROI selection (spatially local); None means use all tokens ----
        if roi_idx is None:
            x_roi = x
        else:
            x_roi = gather_tokens(x, roi_idx)
        # ---- text -> slots (instances conditioned on text/instructions) ----
        slots = slots + self.slot_text_attn(self.slot_norm(slots), context)
        # ---- video(ROI) -> slots (instances read local state/appearance) ----
        slots = slots + self.slot_from_patch(self.slot_norm(slots), x_roi)
        # ---- slots self-attn (interaction between instances) ----
        slots = slots + self.slot_self(self.slot_norm(slots))
        # ---- slots -> video(ROI) (write interaction/state back locally) ----
        x_roi = x_roi + self.patch_from_slot(self.slot_norm(x_roi), slots)
        if roi_idx is None:
            x = x_roi
        else:
            x = scatter_tokens(x, roi_idx, x_roi)
        # ---- (optional) patch <- text (not recommended by default) ----
        if self.enable_patch_text_cross_attn:
            x = x + self.cross_attn(self.norm3(x), context)
        # ---- original FFN ----
        input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = self.gate(x, gate_mlp, self.ffn(input_x))
        return x, slots
class MLP(torch.nn.Module):
    """LayerNorm-bounded two-layer projection head.

    Optionally adds a learned (1, 514, 1280) positional embedding to the
    input before projecting (presumably for CLIP image token sequences --
    confirm against the caller).
    """
    def __init__(self, in_dim, out_dim, has_pos_emb=False):
        super().__init__()
        layers = [
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, in_dim),
            nn.GELU(),
            nn.Linear(in_dim, out_dim),
            nn.LayerNorm(out_dim),
        ]
        self.proj = torch.nn.Sequential(*layers)
        self.has_pos_emb = has_pos_emb
        if has_pos_emb:
            self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280)))
    def forward(self, x):
        if self.has_pos_emb:
            pos = self.emb_pos.to(dtype=x.dtype, device=x.device)
            x = x + pos
        return self.proj(x)
class Head(nn.Module):
    """Final adaLN-modulated linear head mapping token features to
    ``out_dim * prod(patch_size)`` values per token.

    Fix over the original: in the 2-D ``t_mod`` branch, ``t_mod`` is now
    unsqueezed to (b, 1, dim) before being added to the (1, 2, dim) modulation
    table. The previous code broadcast (b, dim) directly against the table's
    shift/scale axis, which silently conflated batch with shift/scale for
    b == 2 and raised for any other b > 1; behavior for b == 1 is unchanged.
    """
    def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float):
        super().__init__()
        self.dim = dim
        self.patch_size = patch_size
        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
        # Learned base (shift, scale) pair added to the timestep modulation.
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
    def forward(self, x, t_mod):
        """
        x: (b, s, dim); t_mod: (b, dim) or per-token (b, s, dim).
        """
        if len(t_mod.shape) == 3:
            # Per-token modulation: split shift/scale along the table axis.
            shift, scale = (
                self.modulation.unsqueeze(0).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(2)
            ).chunk(2, dim=2)
            x = self.head(self.norm(x) * (1 + scale.squeeze(2)) + shift.squeeze(2))
        else:
            # Per-sample modulation: (1, 2, dim) + (b, 1, dim) -> (b, 2, dim).
            shift, scale = (
                self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(1)
            ).chunk(2, dim=1)
            x = self.head(self.norm(x) * (1 + scale) + shift)
        return x
class WanModel(torch.nn.Module):
    """Wan video DiT backbone with optional instance-slot conditioning.

    When ``enable_slots`` is True, every transformer block is a
    ``DiTBlockWithSlots`` that exchanges information between patch tokens and
    a fixed set of learned instance slots; otherwise the original
    ``DiTBlock`` stack is used.
    """
    def __init__(
        self,
        dim: int,
        in_dim: int,
        ffn_dim: int,
        out_dim: int,
        text_dim: int,
        freq_dim: int,
        eps: float,
        patch_size: Tuple[int, int, int],
        num_heads: int,
        num_layers: int,
        has_image_input: bool,
        has_image_pos_emb: bool = False,
        has_ref_conv: bool = False,
        add_control_adapter: bool = False,
        in_dim_control_adapter: int = 24,
        seperated_timestep: bool = False,
        require_vae_embedding: bool = True,
        require_clip_embedding: bool = True,
        fuse_vae_embedding_in_latents: bool = False,
        # -------- slot args (new) --------
        enable_slots: bool = True,
        num_slots: int = 16,
        instance_state_dim: int = 0,  # dimension of the InstanceCap state vector
        state_head_dim: int = 0,  # if > 0, emit slots -> state_pred for supervision
        enable_patch_text_cross_attn: bool = False,  # not recommended to enable by default
    ):
        super().__init__()
        self.dim = dim
        self.in_dim = in_dim
        self.freq_dim = freq_dim
        self.has_image_input = has_image_input
        self.patch_size = patch_size
        self.seperated_timestep = seperated_timestep
        self.require_vae_embedding = require_vae_embedding
        self.require_clip_embedding = require_clip_embedding
        self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents
        self.enable_slots = enable_slots
        self.num_slots = num_slots
        self.instance_state_dim = instance_state_dim
        self.state_head_dim = state_head_dim
        self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(dim, dim),
        )
        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
        )
        # Projects the timestep embedding to the 6 adaLN vectors per block.
        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
        # blocks
        if enable_slots:
            self.blocks = nn.ModuleList([
                DiTBlockWithSlots(
                    has_image_input=has_image_input,
                    dim=dim,
                    num_heads=num_heads,
                    ffn_dim=ffn_dim,
                    eps=eps,
                    enable_patch_text_cross_attn=enable_patch_text_cross_attn,
                )
                for _ in range(num_layers)
            ])
        else:
            # Fall back to the original DiTBlock (if needed).
            self.blocks = nn.ModuleList([
                DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps)
                for _ in range(num_layers)
            ])
        self.head = Head(dim, out_dim, patch_size, eps)
        head_dim = dim // num_heads
        self.freqs = precompute_freqs_cis_3d(head_dim)
        if has_image_input:
            self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb)  # clip_feature_dim = 1280
        if has_ref_conv:
            self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2))
        self.has_image_pos_emb = has_image_pos_emb
        self.has_ref_conv = has_ref_conv
        if add_control_adapter:
            self.control_adapter = SimpleAdapter(
                in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:]
            )
        else:
            self.control_adapter = None
        # ---- slot parameters (new) ----
        if enable_slots:
            # Learned base embedding shared by all samples, one row per slot.
            self.slot_base = nn.Parameter(torch.randn(1, num_slots, dim) / dim**0.5)
            self.instance_proj = None
            if instance_state_dim > 0:
                self.instance_proj = nn.Sequential(
                    nn.LayerNorm(instance_state_dim),
                    nn.Linear(instance_state_dim, dim),
                    nn.GELU(),
                    nn.Linear(dim, dim),
                )
            self.state_head = None
            if state_head_dim > 0:
                self.state_head = nn.Sequential(
                    nn.LayerNorm(dim),
                    nn.Linear(dim, dim),
                    nn.GELU(),
                    nn.Linear(dim, state_head_dim),
                )
    def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None):
        """
        return:
          tokens: (b, f*h*w, dim)
          grid: (f, h, w)
        """
        x = self.patch_embedding(x)  # (b, dim, f, h, w)
        if self.control_adapter is not None and control_camera_latents_input is not None:
            y_camera = self.control_adapter(control_camera_latents_input)
            # Handle y_camera being either a list/tuple or a tensor.
            if isinstance(y_camera, (list, tuple)):
                # If the adapter returned a list of (b, dim, f, h, w) tensors.
                # NOTE(review): this branch keeps only the first summed entry
                # and re-adds a batch axis -- appears to assume batch size 1;
                # confirm against the adapter's return contract.
                x = [u + v for u, v in zip(x, y_camera)]
                x = x[0].unsqueeze(0)
            else:
                x = x + y_camera
        f, h, w = x.shape[2], x.shape[3], x.shape[4]
        x = rearrange(x, "b c f h w -> b (f h w) c")
        return x, (f, h, w)
    def unpatchify(self, x: torch.Tensor, grid_size: Tuple[int, int, int]):
        """Inverse of patchify: (b, f*h*w, out_dim*prod(patch)) -> video latent."""
        return rearrange(
            x,
            "b (f h w) (x y z c) -> b c (f x) (h y) (w z)",
            f=grid_size[0],
            h=grid_size[1],
            w=grid_size[2],
            x=self.patch_size[0],
            y=self.patch_size[1],
            z=self.patch_size[2],
        )
    def _init_slots(
        self,
        batch_size: int,
        *,
        instance_state: Optional[torch.Tensor] = None,
        state_override: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """
        Initialize per-sample slot tokens.

        instance_state:
            - training: preferably (b, num_slots, state_dim) (already
              slot-aligned / tracked)
            - alternatively pass (b, n_inst, state_dim) and do the matching
              to slots externally
        state_override (for interactive inference):
            {
                "slot_ids": LongTensor (b,) or (b,k),
                "state": Tensor (..., state_dim),
                "alpha": float (default 1.0),
                "hard": bool (default False)
            }
        """
        slots = self.slot_base.expand(batch_size, -1, -1)  # (b, num_slots, dim)
        if (instance_state is not None) and (self.instance_proj is not None):
            # instance_state is expected as (b, num_slots, state_dim).
            slots = slots + self.instance_proj(instance_state)
        # At inference: inject a target state into selected slots ("state control").
        if state_override is not None and self.instance_proj is not None:
            slot_ids = state_override["slot_ids"]
            target_state = state_override["state"]
            alpha = float(state_override.get("alpha", 1.0))
            hard = bool(state_override.get("hard", False))
            # Normalize shape: slot_ids -> (b, k).
            if slot_ids.dim() == 1:
                slot_ids = slot_ids[:, None]  # (b,1)
            b = slots.shape[0]
            k = slot_ids.shape[1]
            # target_state supports two layouts:
            #   (b, state_dim)    -> broadcast to (b, k, state_dim)
            #   (b, k, state_dim) -> per-slot
            if target_state.dim() == 2:
                target_state = target_state[:, None, :].expand(b, k, -1)
            delta = self.instance_proj(target_state)  # (b,k,dim)
            if hard:
                # hard: slot = base + proj(state)
                base = self.slot_base.expand(b, -1, -1)
                # Write base into the selected slots first, then add delta.
                slots = slots.clone()
                slots.scatter_(1, slot_ids[..., None].expand(b, k, slots.shape[-1]),
                               gather_tokens(base, slot_ids) + delta)
            else:
                # soft: slot += alpha * proj(state)
                slots = slots.clone()
                cur = gather_tokens(slots, slot_ids)
                new = cur + alpha * delta
                slots = scatter_tokens(slots, slot_ids, new)
        return slots
    def forward(
        self,
        x: torch.Tensor,
        timestep: torch.Tensor,
        context: torch.Tensor,
        clip_feature: Optional[torch.Tensor] = None,
        y: Optional[torch.Tensor] = None,
        use_gradient_checkpointing: bool = False,
        use_gradient_checkpointing_offload: bool = False,
        # ---- new: instance state / interaction inputs ----
        instance_state: Optional[torch.Tensor] = None,  # (b, num_slots, state_dim) (training may be per-frame)
        roi_idx: Optional[torch.Tensor] = None,  # (b, m) (at inference: from SAM mask -> roi_idx)
        state_override: Optional[Dict[str, Any]] = None,  # interactive inference: force state of selected slots
        return_state_pred: bool = False,
        **kwargs,
    ):
        """Denoise one step.

        Returns the predicted latent, or ``(out, state_pred, slots)`` when
        ``return_state_pred`` is True and a state head is configured.
        """
        # time + text
        t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep).to(x.dtype))
        t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
        context = self.text_embedding(context)
        # optional image input
        if self.has_image_input:
            x = torch.cat([x, y], dim=1)  # (b, c_x + c_y, f, h, w)
            clip_embdding = self.img_emb(clip_feature)
            # Image tokens are prepended; CrossAttention splits them back off.
            context = torch.cat([clip_embdding, context], dim=1)
        # patchify -> tokens
        x, (f, h, w) = self.patchify(x, control_camera_latents_input=kwargs.get("control_camera_latents_input", None))
        # rope freqs for patch self-attn
        freqs = torch.cat(
            [
                self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
                self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
                self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
            ],
            dim=-1,
        ).reshape(f * h * w, 1, -1).to(x.device)
        # init slots
        slots = None
        if self.enable_slots:
            slots = self._init_slots(
                batch_size=x.shape[0],
                instance_state=instance_state,
                state_override=state_override,
            )
        def create_custom_forward(module):
            # Wrapper required by torch.utils.checkpoint.
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward
        # blocks
        for block in self.blocks:
            if self.training and use_gradient_checkpointing:
                if use_gradient_checkpointing_offload:
                    # Offload checkpointed activations to CPU to save VRAM.
                    with torch.autograd.graph.save_on_cpu():
                        if self.enable_slots:
                            x, slots = torch.utils.checkpoint.checkpoint(
                                create_custom_forward(block),
                                x, slots, context, t_mod, freqs, roi_idx,
                                use_reentrant=False,
                            )
                        else:
                            x = torch.utils.checkpoint.checkpoint(
                                create_custom_forward(block),
                                x, context, t_mod, freqs,
                                use_reentrant=False,
                            )
                else:
                    if self.enable_slots:
                        x, slots = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(block),
                            x, slots, context, t_mod, freqs, roi_idx,
                            use_reentrant=False,
                        )
                    else:
                        x = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(block),
                            x, context, t_mod, freqs,
                            use_reentrant=False,
                        )
            else:
                if self.enable_slots:
                    x, slots = block(x, slots, context, t_mod, freqs, roi_idx=roi_idx)
                else:
                    x = block(x, context, t_mod, freqs)
        # output head
        out = self.head(x, t)
        out = self.unpatchify(out, (f, h, w))
        if return_state_pred and self.enable_slots and (self.state_head is not None):
            state_pred = self.state_head(slots)  # (b, num_slots, state_head_dim)
            return out, state_pred, slots
        return out
# --------- Original DiTBlock (kept for compatibility with enable_slots=False) ---------
class DiTBlock(nn.Module):
    """Original Wan DiT block: RoPE self-attn, text cross-attn, and FFN,
    all modulated by adaLN parameters derived from the timestep embedding."""
    def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim
        self.self_attn = SelfAttention(dim, num_heads, eps)
        self.cross_attn = CrossAttention(dim, num_heads, eps, has_image_input=has_image_input)
        self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm3 = nn.LayerNorm(dim, eps=eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(ffn_dim, dim),
        )
        # 6 learned adaLN base vectors: (shift, scale, gate) for attn and FFN.
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
        self.gate = GateModule()
    def forward(self, x, context, t_mod, freqs):
        """
        x: (b, s, dim) patch tokens; context: (b, n_ctx, dim);
        t_mod: (b, 6, dim) or with a per-token sequence axis; freqs: RoPE table.
        """
        # t_mod may carry a per-token sequence axis; chunk along the right dim.
        has_seq = len(t_mod.shape) == 4
        chunk_dim = 2 if has_seq else 1
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod
        ).chunk(6, dim=chunk_dim)
        if has_seq:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                shift_msa.squeeze(2),
                scale_msa.squeeze(2),
                gate_msa.squeeze(2),
                shift_mlp.squeeze(2),
                scale_mlp.squeeze(2),
                gate_mlp.squeeze(2),
            )
        input_x = modulate(self.norm1(x), shift_msa, scale_msa)
        x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
        x = x + self.cross_attn(self.norm3(x), context)
        input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = self.gate(x, gate_mlp, self.ffn(input_x))
        return x