| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
| from typing import Tuple, Optional, Dict, Any |
| from einops import rearrange |
| from .wan_video_camera_controller import SimpleAdapter |
|
|
# Optional attention backends, probed once at import time. flash_attention()
# consults these flags to pick the fastest available implementation.

try:
    import flash_attn_interface  # FlashAttention-3 bindings
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False


try:
    import flash_attn  # FlashAttention-2
    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False


try:
    from sageattention import sageattn  # SageAttention kernel
    SAGE_ATTN_AVAILABLE = True
except ModuleNotFoundError:
    SAGE_ATTN_AVAILABLE = False
|
|
|
|
def flash_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    num_heads: int,
    compatibility_mode: bool = False
):
    """Multi-head attention dispatcher.

    q, k, v and the returned tensor are all laid out as
    (batch, seq, num_heads * head_dim). The first available backend is used in
    the order FlashAttention-3 -> FlashAttention-2 -> SageAttention, unless
    compatibility_mode forces the pure-PyTorch SDPA fallback.
    """
    def heads_first(t):
        # (b, s, n*d) -> (b, n, s, d)
        return t.unflatten(2, (num_heads, -1)).transpose(1, 2)

    def heads_last(t):
        # (b, s, n*d) -> (b, s, n, d)
        return t.unflatten(2, (num_heads, -1))

    if compatibility_mode:
        out = F.scaled_dot_product_attention(heads_first(q), heads_first(k), heads_first(v))
        return out.transpose(1, 2).flatten(2)

    if FLASH_ATTN_3_AVAILABLE:
        out = flash_attn_interface.flash_attn_func(heads_last(q), heads_last(k), heads_last(v))
        # Some flash-attn-3 versions return (output, lse); keep the output only.
        if isinstance(out, tuple):
            out = out[0]
        return out.flatten(2)

    if FLASH_ATTN_2_AVAILABLE:
        out = flash_attn.flash_attn_func(heads_last(q), heads_last(k), heads_last(v))
        return out.flatten(2)

    if SAGE_ATTN_AVAILABLE:
        out = sageattn(heads_first(q), heads_first(k), heads_first(v))
        return out.transpose(1, 2).flatten(2)

    # No accelerated backend installed: plain SDPA.
    out = F.scaled_dot_product_attention(heads_first(q), heads_first(k), heads_first(v))
    return out.transpose(1, 2).flatten(2)
|
|
|
|
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
    """AdaLN modulation: scale x by (1 + scale), then add shift (broadcasted)."""
    scaled = x * (scale + 1)
    return scaled + shift
|
|
|
|
def sinusoidal_embedding_1d(dim, position):
    """Classic sinusoidal timestep embedding.

    position: 1D tensor of positions. Returns (len(position), dim) with cosine
    features in the first half and sine features in the second half, cast back
    to position's dtype. Angles are computed in float64 for precision.
    """
    half = dim // 2
    inv_freq = torch.pow(
        10000,
        -torch.arange(half, dtype=torch.float64, device=position.device) / half,
    )
    angles = position.to(torch.float64)[:, None] * inv_freq[None, :]
    embedding = torch.cat([angles.cos(), angles.sin()], dim=1)
    return embedding.to(position.dtype)
|
|
|
|
def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
    """Rotary frequency tables for 3D (frame, height, width) positions.

    The head dimension is split into roughly equal thirds: height and width
    each get dim // 3 channels and the frame axis takes the remainder.
    """
    spatial = dim // 3
    temporal = dim - 2 * spatial
    return (
        precompute_freqs_cis(temporal, end, theta),
        precompute_freqs_cis(spatial, end, theta),
        precompute_freqs_cis(spatial, end, theta),
    )
|
|
|
|
def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
    """Complex RoPE table of shape (end, dim // 2): cis(position * inv_freq)."""
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].double() / dim
    inv_freq = 1.0 / torch.pow(theta, exponents)
    angles = torch.outer(torch.arange(end, device=inv_freq.device), inv_freq)
    return torch.polar(torch.ones_like(angles), angles)
|
|
|
|
def rope_apply(x, freqs, num_heads):
    """Apply rotary position embedding to x.

    x: (b, s, num_heads * head_dim) real tensor.
    freqs: complex tensor broadcastable against (s, 1, head_dim // 2)
    (as built in WanModel.forward). Returns a tensor of x's shape and dtype.
    """
    heads = x.unflatten(2, (num_heads, -1))
    # Pair up adjacent channels and rotate them in the complex plane (fp64).
    pairs = heads.to(torch.float64).reshape(heads.shape[0], heads.shape[1], heads.shape[2], -1, 2)
    rotated = torch.view_as_complex(pairs) * freqs
    return torch.view_as_real(rotated).flatten(2).to(x.dtype)
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square layer norm: no mean subtraction, no bias, learned scale."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def norm(self, x):
        # Divide each vector by the RMS of its last dimension.
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps)

    def forward(self, x):
        # Normalize in fp32 for numerical stability, cast back, then scale.
        original_dtype = x.dtype
        normalized = self.norm(x.float()).to(original_dtype)
        return normalized * self.weight
|
|
|
|
class AttentionModule(nn.Module):
    """Thin wrapper that routes q/k/v through the shared flash_attention dispatcher."""

    def __init__(self, num_heads):
        super().__init__()
        self.num_heads = num_heads

    def forward(self, q, k, v):
        return flash_attention(q=q, k=k, v=v, num_heads=self.num_heads)
|
|
|
|
class SelfAttention(nn.Module):
    """Multi-head self-attention with RoPE; used on video patch tokens."""

    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        # QK-norm stabilizes attention logits.
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)

        self.attn = AttentionModule(self.num_heads)

    def forward(self, x, freqs):
        # Normalize q/k, then rotate them with the positional frequencies;
        # values are left untouched.
        queries = rope_apply(self.norm_q(self.q(x)), freqs, self.num_heads)
        keys = rope_apply(self.norm_k(self.k(x)), freqs, self.num_heads)
        values = self.v(x)
        attended = self.attn(queries, keys, values)
        return self.o(attended)
|
|
|
|
class SelfAttentionNoRoPE(nn.Module):
    """Self-attention without RoPE, used on slot tokens (slots carry no stable
    grid position, so skipping positional rotation is more robust)."""

    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)

        self.attn = AttentionModule(self.num_heads)

    def forward(self, x):
        queries = self.norm_q(self.q(x))
        keys = self.norm_k(self.k(x))
        attended = self.attn(queries, keys, self.v(x))
        return self.o(attended)
|
|
|
|
class CrossAttention(nn.Module):
    """Cross-attention from queries x to context y.

    With has_image_input, the first 257 tokens of y are treated as CLIP image
    tokens and attended through a separate K/V projection; the result is added
    to the text-attention output.
    """

    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, has_image_input: bool = False):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)
        self.has_image_input = has_image_input
        if has_image_input:
            # Dedicated projections for the CLIP image tokens.
            self.k_img = nn.Linear(dim, dim)
            self.v_img = nn.Linear(dim, dim)
            self.norm_k_img = RMSNorm(dim, eps=eps)

        self.attn = AttentionModule(self.num_heads)

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        """x: queries; y: keys/values (context)."""
        if self.has_image_input:
            # Leading 257 tokens are image tokens; the remainder is text.
            img, ctx = y[:, :257], y[:, 257:]
        else:
            ctx = y

        q = self.norm_q(self.q(x))
        out = self.attn(q, self.norm_k(self.k(ctx)), self.v(ctx))

        if self.has_image_input:
            img_keys = self.norm_k_img(self.k_img(img))
            img_values = self.v_img(img)
            out = out + flash_attention(q, img_keys, img_values, num_heads=self.num_heads)

        return self.o(out)
|
|
|
|
class GateModule(nn.Module):
    """Gated residual update: returns x + gate * residual (broadcasted)."""

    def __init__(self):
        super().__init__()

    def forward(self, x, gate, residual):
        gated = gate * residual
        return x + gated
|
|
|
|
| |
| |
| |
def gather_tokens(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Select tokens by index along the sequence axis.

    x: (b, s, c); idx: (b, m) long. Returns (b, m, c).
    """
    channels = x.shape[-1]
    index = idx.unsqueeze(-1).expand(-1, -1, channels)
    return torch.gather(x, 1, index)
|
|
|
|
def scatter_tokens(x: torch.Tensor, idx: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Write y back into a copy of x at the given token positions.

    x: (b, s, c); idx: (b, m) long; y: (b, m, c).
    Returns a new (b, s, c) tensor; x itself is not modified.
    """
    channels = x.shape[-1]
    index = idx.unsqueeze(-1).expand(idx.shape[0], idx.shape[1], channels)
    result = x.clone()
    result.scatter_(1, index, y)
    return result
|
|
|
|
def bbox_to_mask(bbox_xyxy: torch.Tensor, H: int, W: int) -> torch.Tensor:
    """Rasterize per-sample boxes into binary masks.

    bbox_xyxy: (b, 4) float/int boxes in [0, W) x [0, H).
    Returns (b, 1, H, W) float32 masks with 1.0 inside each box.
    Coordinates are clamped and truncated to integers, matching half-open
    slicing semantics (x1 <= x < x2, y1 <= y < y2).
    """
    x1 = bbox_xyxy[:, 0].clamp(0, W - 1).long()
    y1 = bbox_xyxy[:, 1].clamp(0, H - 1).long()
    x2 = bbox_xyxy[:, 2].clamp(0, W).long()
    y2 = bbox_xyxy[:, 3].clamp(0, H).long()
    rows = torch.arange(H, device=bbox_xyxy.device)
    cols = torch.arange(W, device=bbox_xyxy.device)
    # Broadcasted row/column membership tests replace the per-sample loop.
    in_rows = (rows[None, :] >= y1[:, None]) & (rows[None, :] < y2[:, None])
    in_cols = (cols[None, :] >= x1[:, None]) & (cols[None, :] < x2[:, None])
    mask = in_rows[:, :, None] & in_cols[:, None, :]
    return mask[:, None].to(torch.float32)
|
|
|
|
@torch.no_grad()
def mask_to_roi_idx(
    mask: torch.Tensor,
    grid_fhw: Tuple[int, int, int],
    *,
    frame_index: int = 0,
    roi_token_budget: int = 256,
    mode: str = "topk",
) -> torch.Tensor:
    """Downsample a single-frame mask to the patch grid and return fixed-length
    ROI token indices (suitable for gather_tokens / scatter_tokens).

    Args:
        mask: (b, H, W) or (b, 1, H, W) mask for one frame.
        grid_fhw: (f, h, w) patch-grid shape returned by patchify.
        frame_index: which latent frame the interaction happens on (default 0).
        roi_token_budget: fixed number m of indices returned per sample — kept
            fixed so flash-attn never sees variable sequence lengths.
        mode: "topk" picks the m highest-response cells; "random" samples m
            cells from the active (> 0.01) region (with replacement if needed).

    Returns:
        LongTensor (b, roi_token_budget) of indices into the flattened
        f*h*w token sequence.

    Raises:
        ValueError: if mode is not "topk" or "random".
    """
    if mask.dim() == 3:
        mask = mask[:, None]
    b, _, H, W = mask.shape
    f, h, w = grid_fhw
    assert 0 <= frame_index < f, f"frame_index {frame_index} out of range f={f}"

    # Resize the mask onto the patch grid; each cell holds a soft response.
    m_small = F.interpolate(mask.float(), size=(h, w), mode="bilinear", align_corners=False)
    flat = m_small[:, 0].reshape(b, h * w)

    # Offset that maps per-frame cell indices into the flattened f*h*w sequence.
    base = frame_index * (h * w)

    if mode == "topk":
        k = min(roi_token_budget, h * w)
        _, top_idx = torch.topk(flat, k=k, dim=1)
        if k < roi_token_budget:
            # Grid smaller than the budget: pad by repeating the strongest cell.
            pad = top_idx[:, :1].expand(b, roi_token_budget - k)
            top_idx = torch.cat([top_idx, pad], dim=1)
        return (top_idx[:, :roi_token_budget] + base).long()

    if mode == "random":
        idx_list = []
        for bi in range(b):
            active = torch.nonzero(flat[bi] > 0.01, as_tuple=False).flatten()
            if active.numel() == 0:
                # Empty mask: fall back to sampling over the whole grid.
                active = torch.arange(h * w, device=mask.device)
            if active.numel() >= roi_token_budget:
                perm = torch.randperm(active.numel(), device=mask.device)
                sel = active[perm[:roi_token_budget]]
            else:
                # Fewer active cells than the budget: sample with replacement.
                sel = active[torch.randint(0, active.numel(), (roi_token_budget,), device=mask.device)]
            idx_list.append(sel)
        return (torch.stack(idx_list, dim=0) + base).long()

    raise ValueError(f"Unknown mode={mode}")
|
|
|
|
| |
| |
| |
class DiTBlockWithSlots(nn.Module):
    """DiT transformer block extended with instance "slots".

    On top of the original DiTBlock, each layer adds:
      - text -> slots cross-attention
      - roi_patches -> slots cross-attention (read)
      - slots self-attention (instance interaction)
      - slots -> roi_patches cross-attention (write-back)
    The original global patch <- text cross-attention is disabled by default
    to avoid contaminating all patches with global text information.
    """
    def __init__(
        self,
        has_image_input: bool,
        dim: int,
        num_heads: int,
        ffn_dim: int,
        eps: float = 1e-6,
        enable_patch_text_cross_attn: bool = False,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim
        self.enable_patch_text_cross_attn = enable_patch_text_cross_attn

        # Original DiT components: AdaLN-modulated self-attention + gated FFN.
        self.self_attn = SelfAttention(dim, num_heads, eps)
        self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(ffn_dim, dim),
        )
        # Learned per-layer base for the 6 AdaLN shift/scale/gate vectors.
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
        self.gate = GateModule()

        # Optional legacy patch <- text cross-attention (off by default).
        if enable_patch_text_cross_attn:
            self.cross_attn = CrossAttention(dim, num_heads, eps, has_image_input=has_image_input)
            self.norm3 = nn.LayerNorm(dim, eps=eps)

        # Slot pathway. slot_norm is shared by all four slot residual branches.
        self.slot_norm = nn.LayerNorm(dim, eps=eps)
        self.slot_text_attn = CrossAttention(dim, num_heads, eps, has_image_input=has_image_input)
        self.slot_from_patch = CrossAttention(dim, num_heads, eps, has_image_input=False)
        self.slot_self = SelfAttentionNoRoPE(dim, num_heads, eps)
        self.patch_from_slot = CrossAttention(dim, num_heads, eps, has_image_input=False)

    def forward(
        self,
        x: torch.Tensor,
        slots: torch.Tensor,
        context: torch.Tensor,
        t_mod: torch.Tensor,
        freqs: torch.Tensor,
        roi_idx: Optional[torch.Tensor] = None,
    ):
        # t_mod is (b, 6, dim), or (b, s, 6, dim) for per-token modulation;
        # pick the chunk axis accordingly.
        has_seq = len(t_mod.shape) == 4
        chunk_dim = 2 if has_seq else 1

        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod
        ).chunk(6, dim=chunk_dim)

        if has_seq:
            # Drop the singleton chunk axis left behind by chunk(dim=2).
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                shift_msa.squeeze(2),
                scale_msa.squeeze(2),
                gate_msa.squeeze(2),
                shift_mlp.squeeze(2),
                scale_mlp.squeeze(2),
                gate_mlp.squeeze(2),
            )

        # Gated self-attention over all patch tokens (with RoPE).
        input_x = modulate(self.norm1(x), shift_msa, scale_msa)
        x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))

        # Restrict the slot exchange to ROI tokens when an index is provided.
        if roi_idx is None:
            x_roi = x
        else:
            x_roi = gather_tokens(x, roi_idx)

        # 1) Slots read from the text context.
        slots = slots + self.slot_text_attn(self.slot_norm(slots), context)

        # 2) Slots read from the ROI patch tokens.
        slots = slots + self.slot_from_patch(self.slot_norm(slots), x_roi)

        # 3) Slots interact with each other.
        slots = slots + self.slot_self(self.slot_norm(slots))

        # 4) Write slot information back into the ROI patches.
        # NOTE(review): slot_norm is reused to normalize x_roi here —
        # presumably intentional weight sharing; confirm.
        x_roi = x_roi + self.patch_from_slot(self.slot_norm(x_roi), slots)

        if roi_idx is None:
            x = x_roi
        else:
            x = scatter_tokens(x, roi_idx, x_roi)

        # Optional legacy global text cross-attention over all patches.
        if self.enable_patch_text_cross_attn:
            x = x + self.cross_attn(self.norm3(x), context)

        # Gated FFN.
        input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = self.gate(x, gate_mlp, self.ffn(input_x))

        return x, slots
|
|
|
|
class MLP(torch.nn.Module):
    """Projection MLP: LayerNorm -> Linear -> GELU -> Linear -> LayerNorm.

    When has_pos_emb is set, a learned positional table of fixed shape
    (1, 514, 1280) is added to the input before projection.
    NOTE(review): the hard-coded (514, 1280) shape presumably matches the
    stacked CLIP image-token layout fed to img_emb — confirm with the caller.
    """

    def __init__(self, in_dim, out_dim, has_pos_emb=False):
        super().__init__()
        layers = [
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, in_dim),
            nn.GELU(),
            nn.Linear(in_dim, out_dim),
            nn.LayerNorm(out_dim),
        ]
        self.proj = torch.nn.Sequential(*layers)
        self.has_pos_emb = has_pos_emb
        if has_pos_emb:
            self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280)))

    def forward(self, x):
        if self.has_pos_emb:
            positional = self.emb_pos.to(dtype=x.dtype, device=x.device)
            x = x + positional
        return self.proj(x)
|
|
|
|
class Head(nn.Module):
    """Final projection from token features to patch-flattened output channels,
    with AdaLN-style shift/scale modulation from the timestep embedding."""

    def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float):
        super().__init__()
        self.dim = dim
        self.patch_size = patch_size
        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
        # Learned base for the (shift, scale) modulation pair.
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)

    def forward(self, x, t_mod):
        table = self.modulation.to(dtype=t_mod.dtype, device=t_mod.device)
        if t_mod.dim() == 3:
            # Per-token timestep embedding: broadcast the (2, dim) table
            # against t_mod along a fresh axis, then drop it.
            shift, scale = (table.unsqueeze(0) + t_mod.unsqueeze(2)).chunk(2, dim=2)
            shift, scale = shift.squeeze(2), scale.squeeze(2)
        else:
            # Global timestep embedding, broadcast over the sequence.
            shift, scale = (table + t_mod).chunk(2, dim=1)
        return self.head(self.norm(x) * (1 + scale) + shift)
|
|
|
|
class WanModel(torch.nn.Module):
    """Wan video diffusion transformer with optional instance-slot conditioning.

    With enable_slots=True every layer is a DiTBlockWithSlots that exchanges
    information between a fixed set of slot vectors and (optionally ROI-gated)
    patch tokens; with enable_slots=False the model falls back to a plain
    DiTBlock stack.
    """
    def __init__(
        self,
        dim: int,
        in_dim: int,
        ffn_dim: int,
        out_dim: int,
        text_dim: int,
        freq_dim: int,
        eps: float,
        patch_size: Tuple[int, int, int],
        num_heads: int,
        num_layers: int,
        has_image_input: bool,
        has_image_pos_emb: bool = False,
        has_ref_conv: bool = False,
        add_control_adapter: bool = False,
        in_dim_control_adapter: int = 24,
        seperated_timestep: bool = False,
        require_vae_embedding: bool = True,
        require_clip_embedding: bool = True,
        fuse_vae_embedding_in_latents: bool = False,
        # Slot-related options.
        enable_slots: bool = True,
        num_slots: int = 16,
        instance_state_dim: int = 0,
        state_head_dim: int = 0,
        enable_patch_text_cross_attn: bool = False,
    ):
        super().__init__()
        self.dim = dim
        self.in_dim = in_dim
        self.freq_dim = freq_dim
        self.has_image_input = has_image_input
        self.patch_size = patch_size
        # NOTE: "seperated" spelling kept as-is for checkpoint/caller compatibility.
        self.seperated_timestep = seperated_timestep
        self.require_vae_embedding = require_vae_embedding
        self.require_clip_embedding = require_clip_embedding
        self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents

        self.enable_slots = enable_slots
        self.num_slots = num_slots
        self.instance_state_dim = instance_state_dim
        self.state_head_dim = state_head_dim

        # Non-overlapping 3D conv turns latents into patch tokens.
        self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)

        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(dim, dim),
        )
        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
        )
        # Produces the six per-block AdaLN shift/scale/gate vectors.
        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))

        if enable_slots:
            self.blocks = nn.ModuleList([
                DiTBlockWithSlots(
                    has_image_input=has_image_input,
                    dim=dim,
                    num_heads=num_heads,
                    ffn_dim=ffn_dim,
                    eps=eps,
                    enable_patch_text_cross_attn=enable_patch_text_cross_attn,
                )
                for _ in range(num_layers)
            ])
        else:
            # DiTBlock is defined later in this module; that is fine because
            # the name is only resolved when __init__ actually runs.
            self.blocks = nn.ModuleList([
                DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps)
                for _ in range(num_layers)
            ])

        self.head = Head(dim, out_dim, patch_size, eps)
        head_dim = dim // num_heads
        # CPU-resident complex RoPE tables; moved to x.device in forward().
        self.freqs = precompute_freqs_cis_3d(head_dim)

        if has_image_input:
            # 1280 = width of the incoming CLIP image features (clip_feature).
            self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb)
        if has_ref_conv:
            self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2))
        self.has_image_pos_emb = has_image_pos_emb
        self.has_ref_conv = has_ref_conv

        if add_control_adapter:
            self.control_adapter = SimpleAdapter(
                in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:]
            )
        else:
            self.control_adapter = None

        if enable_slots:
            # Learned base slot embeddings, shared across the batch.
            self.slot_base = nn.Parameter(torch.randn(1, num_slots, dim) / dim**0.5)

        # Projects external per-instance state vectors into slot space.
        self.instance_proj = None
        if instance_state_dim > 0:
            self.instance_proj = nn.Sequential(
                nn.LayerNorm(instance_state_dim),
                nn.Linear(instance_state_dim, dim),
                nn.GELU(),
                nn.Linear(dim, dim),
            )

        # Optional readout predicting instance state from the final slots.
        self.state_head = None
        if state_head_dim > 0:
            self.state_head = nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, dim),
                nn.GELU(),
                nn.Linear(dim, state_head_dim),
            )

    def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None):
        """Convert latents into a token sequence.

        return:
            tokens: (b, f*h*w, dim)
            grid: (f, h, w) — token-grid shape after the patch conv
        """
        x = self.patch_embedding(x)

        if self.control_adapter is not None and control_camera_latents_input is not None:
            y_camera = self.control_adapter(control_camera_latents_input)
            if isinstance(y_camera, (list, tuple)):
                # Per-sample list of camera features.
                # NOTE(review): this branch keeps only the first sample and
                # re-adds a batch dim — it appears to assume batch size 1; confirm.
                x = [u + v for u, v in zip(x, y_camera)]
                x = x[0].unsqueeze(0)
            else:
                x = x + y_camera

        f, h, w = x.shape[2], x.shape[3], x.shape[4]
        x = rearrange(x, "b c f h w -> b (f h w) c")
        return x, (f, h, w)

    def unpatchify(self, x: torch.Tensor, grid_size: Tuple[int, int, int]):
        """Inverse of patchify: (b, f*h*w, out_dim*prod(patch)) -> (b, c, F, H, W)."""
        return rearrange(
            x,
            "b (f h w) (x y z c) -> b c (f x) (h y) (w z)",
            f=grid_size[0],
            h=grid_size[1],
            w=grid_size[2],
            x=self.patch_size[0],
            y=self.patch_size[1],
            z=self.patch_size[2],
        )

    def _init_slots(
        self,
        batch_size: int,
        *,
        instance_state: Optional[torch.Tensor] = None,
        state_override: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """Build the initial slot tensor (b, num_slots, dim).

        instance_state:
            - training: preferably (b, num_slots, state_dim), already
              slot-aligned / tracked;
            - alternatively pass (b, n_inst, state_dim) and do the matching to
              slots outside this module.

        state_override (for inference-time interaction):
            {
                "slot_ids": LongTensor (b,) or (b, k),
                "state": Tensor (..., state_dim),
                "alpha": float (default 1.0),
                "hard": bool (default False)
            }
        """
        slots = self.slot_base.expand(batch_size, -1, -1)

        if (instance_state is not None) and (self.instance_proj is not None):
            # Additive conditioning: project external state onto the base slots.
            slots = slots + self.instance_proj(instance_state)

        if state_override is not None and self.instance_proj is not None:
            slot_ids = state_override["slot_ids"]
            target_state = state_override["state"]
            alpha = float(state_override.get("alpha", 1.0))
            hard = bool(state_override.get("hard", False))

            if slot_ids.dim() == 1:
                slot_ids = slot_ids[:, None]

            b = slots.shape[0]
            k = slot_ids.shape[1]

            # A single (b, state_dim) vector is broadcast across the k slots.
            if target_state.dim() == 2:
                target_state = target_state[:, None, :].expand(b, k, -1)

            delta = self.instance_proj(target_state)

            if hard:
                # Hard override: rebuild the targeted slots from base + delta,
                # discarding any instance_state contribution added above.
                base = self.slot_base.expand(b, -1, -1)
                slots = slots.clone()
                slots.scatter_(1, slot_ids[..., None].expand(b, k, slots.shape[-1]),
                               gather_tokens(base, slot_ids) + delta)
            else:
                # Soft override: nudge the current slots by alpha * delta.
                slots = slots.clone()
                cur = gather_tokens(slots, slot_ids)
                new = cur + alpha * delta
                slots = scatter_tokens(slots, slot_ids, new)

        return slots

    def forward(
        self,
        x: torch.Tensor,
        timestep: torch.Tensor,
        context: torch.Tensor,
        clip_feature: Optional[torch.Tensor] = None,
        y: Optional[torch.Tensor] = None,
        use_gradient_checkpointing: bool = False,
        use_gradient_checkpointing_offload: bool = False,
        # Slot-conditioning inputs (all optional).
        instance_state: Optional[torch.Tensor] = None,
        roi_idx: Optional[torch.Tensor] = None,
        state_override: Optional[Dict[str, Any]] = None,
        return_state_pred: bool = False,
        **kwargs,
    ):
        """Denoise latents x at the given timestep, conditioned on text context
        (and optionally CLIP image features / concatenated VAE latents y).

        Returns the predicted latents, or (out, state_pred, slots) when
        return_state_pred is set, slots are enabled, and a state head exists.
        """
        # Timestep embedding -> per-block AdaLN parameters (6 * dim).
        t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep).to(x.dtype))
        t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
        context = self.text_embedding(context)

        if self.has_image_input:
            # Image-conditioned mode: concat VAE latents along channels and
            # prepend CLIP image tokens to the text context.
            x = torch.cat([x, y], dim=1)
            clip_embdding = self.img_emb(clip_feature)
            context = torch.cat([clip_embdding, context], dim=1)

        x, (f, h, w) = self.patchify(x, control_camera_latents_input=kwargs.get("control_camera_latents_input", None))

        # Assemble the 3D RoPE table: one frequency block per axis (f, h, w),
        # flattened to match the (f*h*w) token order.
        freqs = torch.cat(
            [
                self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
                self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
                self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
            ],
            dim=-1,
        ).reshape(f * h * w, 1, -1).to(x.device)

        slots = None
        if self.enable_slots:
            slots = self._init_slots(
                batch_size=x.shape[0],
                instance_state=instance_state,
                state_override=state_override,
            )

        def create_custom_forward(module):
            # Wrapper so checkpoint() receives a plain callable.
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward

        for block in self.blocks:
            if self.training and use_gradient_checkpointing:
                if use_gradient_checkpointing_offload:
                    # Offload saved activations to CPU to trade speed for memory.
                    with torch.autograd.graph.save_on_cpu():
                        if self.enable_slots:
                            x, slots = torch.utils.checkpoint.checkpoint(
                                create_custom_forward(block),
                                x, slots, context, t_mod, freqs, roi_idx,
                                use_reentrant=False,
                            )
                        else:
                            x = torch.utils.checkpoint.checkpoint(
                                create_custom_forward(block),
                                x, context, t_mod, freqs,
                                use_reentrant=False,
                            )
                else:
                    if self.enable_slots:
                        x, slots = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(block),
                            x, slots, context, t_mod, freqs, roi_idx,
                            use_reentrant=False,
                        )
                    else:
                        x = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(block),
                            x, context, t_mod, freqs,
                            use_reentrant=False,
                        )
            else:
                if self.enable_slots:
                    x, slots = block(x, slots, context, t_mod, freqs, roi_idx=roi_idx)
                else:
                    x = block(x, context, t_mod, freqs)

        out = self.head(x, t)
        out = self.unpatchify(out, (f, h, w))

        if return_state_pred and self.enable_slots and (self.state_head is not None):
            state_pred = self.state_head(slots)
            return out, state_pred, slots

        return out
|
|
|
|
| |
class DiTBlock(nn.Module):
    """Standard DiT transformer block: AdaLN-gated self-attention, text
    cross-attention, and an AdaLN-gated feed-forward network."""

    def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim

        self.self_attn = SelfAttention(dim, num_heads, eps)
        self.cross_attn = CrossAttention(dim, num_heads, eps, has_image_input=has_image_input)
        self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm3 = nn.LayerNorm(dim, eps=eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(ffn_dim, dim),
        )
        # Learned base for the six AdaLN shift/scale/gate vectors.
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
        self.gate = GateModule()

    def forward(self, x, context, t_mod, freqs):
        # t_mod is (b, 6, dim), or (b, s, 6, dim) for per-token modulation.
        per_token = len(t_mod.shape) == 4
        table = self.modulation.to(dtype=t_mod.dtype, device=t_mod.device)
        params = (table + t_mod).chunk(6, dim=2 if per_token else 1)
        if per_token:
            # Drop the singleton axis left behind by chunk(dim=2).
            params = tuple(p.squeeze(2) for p in params)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = params

        attn_in = modulate(self.norm1(x), shift_msa, scale_msa)
        x = self.gate(x, gate_msa, self.self_attn(attn_in, freqs))
        x = x + self.cross_attn(self.norm3(x), context)
        ffn_in = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = self.gate(x, gate_mlp, self.ffn(ffn_in))
        return x
|
|