| """ |
| ViL (Vision-LSTM) Backbone for single object tracking. |
| |
| Architecture: |
| - Patch embedding (Conv2d) for template + search region |
| - Stack of mLSTM blocks with bidirectional scanning (even=L→R, odd=R→L) |
| - FiLM temporal modulation integrated BETWEEN blocks (at interval=6) |
| - Optional TMoE-MLP in last N blocks (dense routing, frozen shared expert) |
| - Outputs concatenated template+search features for head processing |
| |
| ViL-S config: dim=384, depth=24, patch_size=16, ~23M backbone params |
| """ |

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from .mlstm import mLSTMBlock, mLSTMCell, SwiGLUMLP, StochasticDepth


class PatchEmbed(nn.Module):
    """Convert image patches to token embeddings using Conv2d."""
    def __init__(self, patch_size: int = 16, in_channels: int = 3, dim: int = 384):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, C, H, W) image tensor
        Returns:
            (B, N, D) patch token embeddings, N = (H/P)*(W/P)
        """
        x = self.proj(x)
        x = rearrange(x, 'b d h w -> b (h w) d')
        x = self.norm(x)
        return x


class TMoEMLP(nn.Module):
    """Temporal Mixture-of-Experts MLP.

    Uses dense routing with a shared expert (frozen after Phase 1) and
    K specialized experts. Output = shared_out + sum(gate_k * expert_k_out).

    For tracking: experts specialize on different temporal dynamics
    (fast motion, occlusion recovery, scale change).
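
    Example (shape sketch; output shape matches input):
        mlp = TMoEMLP(dim=384, num_experts=4)
        y = mlp(torch.randn(2, 832, 384))  # -> (2, 832, 384)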
| """ |
| def __init__( |
| self, |
| dim: int = 384, |
| mlp_ratio: float = 4.0, |
| num_experts: int = 4, |
| bias: bool = False, |
| ): |
| super().__init__() |
| self.num_experts = num_experts |
| hidden_dim = int(dim * mlp_ratio) |
| |
| |
| self.shared_expert = SwiGLUMLP(dim=dim, mlp_ratio=mlp_ratio, bias=bias) |
| |
| |
| small_ratio = mlp_ratio / 2 |
| self.experts = nn.ModuleList([ |
| SwiGLUMLP(dim=dim, mlp_ratio=small_ratio, bias=bias) |
| for _ in range(num_experts) |
| ]) |
| |
| |
| self.router = nn.Linear(dim, num_experts, bias=True) |

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The shared expert always contributes, independent of the router.
        shared_out = self.shared_expert(x)

        # Dense per-token gates over the specialized experts: (..., num_experts).
        gates = F.softmax(self.router(x), dim=-1)

        # Dense routing: every expert processes every token, weighted by its gate.
        expert_out = torch.zeros_like(shared_out)
        for i, expert in enumerate(self.experts):
            expert_out = expert_out + gates[..., i:i+1] * expert(x)

        return shared_out + expert_out

    def freeze_shared_expert(self):
        """Freeze the shared expert for Phase 2 training."""
        for p in self.shared_expert.parameters():
            p.requires_grad = False


class mLSTMBlockWithTMoE(nn.Module):
    """mLSTM block with a TMoE MLP in place of the standard SwiGLU MLP."""
    def __init__(
        self,
        dim: int = 384,
        proj_factor: float = 2.0,
        qkv_proj_blocksize: int = 4,
        num_heads: int = 4,
        conv_kernel: int = 4,
        mlp_ratio: float = 4.0,
        drop_path: float = 0.0,
        num_experts: int = 4,
        bias: bool = False,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, bias=False)
        self.mlstm = mLSTMCell(
            dim=dim,
            proj_factor=proj_factor,
            qkv_proj_blocksize=qkv_proj_blocksize,
            num_heads=num_heads,
            conv_kernel=conv_kernel,
            bias=bias,
        )
        self.norm2 = nn.LayerNorm(dim, bias=False)
        self.mlp = TMoEMLP(dim=dim, mlp_ratio=mlp_ratio, num_experts=num_experts, bias=bias)
        self.drop_path = StochasticDepth(drop_path)

    def forward(self, x: torch.Tensor, reverse: bool = False) -> torch.Tensor:
        # Pre-norm residual blocks; `reverse` flips the mLSTM scan direction.
        x = x + self.drop_path(self.mlstm(self.norm1(x), reverse=reverse))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

    def freeze_shared_expert(self):
        self.mlp.freeze_shared_expert()


class ViLBackbone(nn.Module):
    """Vision-LSTM backbone for tracking with sequential multi-frame processing.

    Processes template + K search frames as one long mLSTM sequence:
        [template_tokens | search_1_tokens | search_2_tokens | ... | search_K_tokens]

    The mLSTM memory state C carries information across frames:
    - Template tokens establish the target appearance in memory
    - Search_1 tokens are processed with template context in memory
    - Search_2 tokens are processed with template + search_1 context, etc.

    This is the core advantage over ViT: temporal information accumulates
    in the recurrent memory state, not through attention over all tokens.

    Token counts:
        Template: 128x128 → 8x8 = 64 tokens
        Each search: 256x256 → 16x16 = 256 tokens
        K=3 sequence: 64 + 3×256 = 832 tokens

    Bidirectional scanning: even blocks L→R, odd blocks R→L.
    FiLM modulation: applied between blocks at interval=6.
    TMoE: last `tmoe_blocks` blocks.
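
    Example (shape sketch; see `forward` for the full contract):
        backbone = ViLBackbone(dim=384, depth=24)
        t_feat, s_feats = backbone(
            torch.randn(2, 3, 128, 128),     # template
            torch.randn(2, 3, 3, 256, 256),  # K=3 search frames
        )
        # t_feat: (2, 64, 384), s_feats: (2, 3, 256, 384)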
| """ |
| def __init__( |
| self, |
| dim: int = 384, |
| depth: int = 24, |
| patch_size: int = 16, |
| in_channels: int = 3, |
| proj_factor: float = 2.0, |
| qkv_proj_blocksize: int = 4, |
| num_heads: int = 4, |
| conv_kernel: int = 4, |
| mlp_ratio: float = 4.0, |
| drop_path_rate: float = 0.05, |
| tmoe_blocks: int = 2, |
| num_experts: int = 4, |
| bias: bool = False, |
| film_interval: int = 6, |
| ): |
| super().__init__() |
| self.dim = dim |
| self.depth = depth |
| self.patch_size = patch_size |
| self.film_interval = film_interval |
| |
| |
| self.patch_embed = PatchEmbed(patch_size=patch_size, in_channels=in_channels, dim=dim) |
| |
| |
| |
| |
| self.template_pos = nn.Parameter(torch.randn(1, 64, dim) * 0.02) |
| self.search_pos = nn.Parameter(torch.randn(1, 256, dim) * 0.02) |
| |
| |
| self.template_type = nn.Parameter(torch.randn(1, 1, dim) * 0.02) |
| self.search_type = nn.Parameter(torch.randn(1, 1, dim) * 0.02) |
| |
| |
| dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] |
| |
| |
| self.blocks = nn.ModuleList() |
| for i in range(depth): |
| if i >= depth - tmoe_blocks: |
| block = mLSTMBlockWithTMoE( |
| dim=dim, proj_factor=proj_factor, |
| qkv_proj_blocksize=qkv_proj_blocksize, |
| num_heads=num_heads, conv_kernel=conv_kernel, |
| mlp_ratio=mlp_ratio, drop_path=dpr[i], |
| num_experts=num_experts, bias=bias, |
| ) |
| else: |
| block = mLSTMBlock( |
| dim=dim, proj_factor=proj_factor, |
| qkv_proj_blocksize=qkv_proj_blocksize, |
| num_heads=num_heads, conv_kernel=conv_kernel, |
| mlp_ratio=mlp_ratio, drop_path=dpr[i], bias=bias, |
| ) |
| self.blocks.append(block) |
| |
| |
| self.norm = nn.LayerNorm(dim, bias=False) |

    def forward(
        self,
        template: torch.Tensor,
        searches: torch.Tensor,
        temporal_mod_manager=None,
    ) -> tuple:
        """
        Process template + K search frames as one mLSTM sequence.

        Args:
            template: (B, 3, 128, 128) template image
            searches: (B, K, 3, 256, 256) K consecutive search frames
                OR (B, 3, 256, 256) single search frame (backward compat)
            temporal_mod_manager: optional TemporalModulationManager for FiLM
        Returns:
            template_feat: (B, 64, D) template features
            search_feats: (B, K, 256, D) per-frame search features
                OR (B, 256, D) if single search frame input
        """
        B = template.shape[0]
        single_frame = (searches.ndim == 4)

        # Promote a single search frame to K=1 so both cases share one path.
        if single_frame:
            searches = searches.unsqueeze(1)

        K = searches.shape[1]

        # Template tokens: (B, 64, D), tagged with positional + type embeddings.
        t_tokens = self.patch_embed(template)
        t_tokens = t_tokens + self.template_pos + self.template_type
        n_template = t_tokens.shape[1]

        # Search tokens: embed all K frames in one batched pass, then restore
        # the frame axis. Positional embeddings are shared across frames.
        s_flat = searches.reshape(B * K, *searches.shape[2:])
        s_tokens_flat = self.patch_embed(s_flat)
        s_tokens = s_tokens_flat.reshape(B, K, -1, self.dim)
        s_tokens = s_tokens + self.search_pos.unsqueeze(1) + self.search_type
        n_search = s_tokens.shape[2]

        # One long sequence [template | search_1 | ... | search_K], so the
        # mLSTM memory carries template context into every search frame.
        s_tokens_concat = s_tokens.reshape(B, K * n_search, self.dim)
        tokens = torch.cat([t_tokens, s_tokens_concat], dim=1)

        # Bidirectional scanning: even blocks scan L→R, odd blocks R→L.
        for i, block in enumerate(self.blocks):
            reverse = (i % 2 == 1)
            tokens = block(tokens, reverse=reverse)

            # FiLM between blocks; the manager is expected to apply it only at
            # its configured block interval (film_interval=6 by default).
            if temporal_mod_manager is not None:
                tokens = temporal_mod_manager.modulate(tokens, i)

        tokens = self.norm(tokens)

        if temporal_mod_manager is not None:
            temporal_mod_manager.update_temporal_context(tokens)

        # Split the sequence back into template and per-frame search features.
        template_feat = tokens[:, :n_template]
        search_tokens = tokens[:, n_template:]
        search_feats = search_tokens.reshape(B, K, n_search, self.dim)

        if single_frame:
            return template_feat, search_feats.squeeze(1)

        return template_feat, search_feats

    def freeze_shared_experts(self):
        """Freeze shared experts in TMoE blocks for Phase 2 training."""
        for block in self.blocks:
            if hasattr(block, 'freeze_shared_expert'):
                block.freeze_shared_expert()
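

# Hedged usage sketch (assumption: run within the package so the relative
# `.mlstm` import resolves, e.g. `python -m package.this_module`; the package
# name is a placeholder). Checks the shape contract documented in
# ViLBackbone.forward and the Phase 2 freezing flow from the TMoEMLP docstring.
if __name__ == "__main__":
    backbone = ViLBackbone(dim=384, depth=24, patch_size=16)
    template = torch.randn(2, 3, 128, 128)      # (B, C, H, W) template crop
    searches = torch.randn(2, 3, 3, 256, 256)   # (B, K, C, H, W), K=3 frames

    t_feat, s_feats = backbone(template, searches)
    print(t_feat.shape)    # expected: torch.Size([2, 64, 384])
    print(s_feats.shape)   # expected: torch.Size([2, 3, 256, 384])

    # Phase 2 setup: freeze the shared experts inside the TMoE blocks, then
    # build the optimizer over the remaining trainable parameters.
    backbone.freeze_shared_experts()
    trainable = sum(p.numel() for p in backbone.parameters() if p.requires_grad)
    print(f"trainable params after freeze: {trainable / 1e6:.1f}M")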