""" ViL (Vision-LSTM) Backbone for single object tracking. Architecture: - Patch embedding (Conv2d) for template + search region - Stack of mLSTM blocks with bidirectional scanning (even=L→R, odd=R→L) - FiLM temporal modulation integrated BETWEEN blocks (at interval=6) - Optional TMoE-MLP in last N blocks (dense routing, frozen shared expert) - Outputs concatenated template+search features for head processing ViL-S config: dim=384, depth=24, patch_size=16, ~23M backbone params """ import math import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange from .mlstm import mLSTMBlock, SwiGLUMLP, StochasticDepth class PatchEmbed(nn.Module): """Convert image patches to token embeddings using Conv2d.""" def __init__(self, patch_size: int = 16, in_channels: int = 3, dim: int = 384): super().__init__() self.patch_size = patch_size self.proj = nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size) self.norm = nn.LayerNorm(dim, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: x: (B, C, H, W) image tensor Returns: (B, N, D) patch token embeddings, N = (H/P)*(W/P) """ x = self.proj(x) # (B, D, H/P, W/P) x = rearrange(x, 'b d h w -> b (h w) d') x = self.norm(x) return x class TMoEMLP(nn.Module): """Temporal Mixture-of-Experts MLP. Uses dense routing with a shared expert (frozen after Phase 1) and K specialized experts. Output = shared_out + sum(gate_k * expert_k_out). For tracking: experts specialize on different temporal dynamics (fast motion, occlusion recovery, scale change). """ def __init__( self, dim: int = 384, mlp_ratio: float = 4.0, num_experts: int = 4, bias: bool = False, ): super().__init__() self.num_experts = num_experts hidden_dim = int(dim * mlp_ratio) # Shared expert (frozen after Phase 1 training) self.shared_expert = SwiGLUMLP(dim=dim, mlp_ratio=mlp_ratio, bias=bias) # Specialized experts (smaller: mlp_ratio/2) small_ratio = mlp_ratio / 2 self.experts = nn.ModuleList([ SwiGLUMLP(dim=dim, mlp_ratio=small_ratio, bias=bias) for _ in range(num_experts) ]) # Dense router: soft gating over experts self.router = nn.Linear(dim, num_experts, bias=True) def forward(self, x: torch.Tensor) -> torch.Tensor: # Shared expert output (always contributes) shared_out = self.shared_expert(x) # Router logits and softmax gates gates = F.softmax(self.router(x), dim=-1) # (B, S, num_experts) # Expert outputs, weighted by gates expert_out = torch.zeros_like(shared_out) for i, expert in enumerate(self.experts): expert_out = expert_out + gates[..., i:i+1] * expert(x) return shared_out + expert_out def freeze_shared_expert(self): """Freeze the shared expert for Phase 2 training.""" for p in self.shared_expert.parameters(): p.requires_grad = False class mLSTMBlockWithTMoE(nn.Module): """mLSTM block with TMoE MLP instead of standard SwiGLU MLP.""" def __init__( self, dim: int = 384, proj_factor: float = 2.0, qkv_proj_blocksize: int = 4, num_heads: int = 4, conv_kernel: int = 4, mlp_ratio: float = 4.0, drop_path: float = 0.0, num_experts: int = 4, bias: bool = False, ): super().__init__() from .mlstm import mLSTMCell self.norm1 = nn.LayerNorm(dim, bias=False) self.mlstm = mLSTMCell( dim=dim, proj_factor=proj_factor, qkv_proj_blocksize=qkv_proj_blocksize, num_heads=num_heads, conv_kernel=conv_kernel, bias=bias, ) self.norm2 = nn.LayerNorm(dim, bias=False) self.mlp = TMoEMLP(dim=dim, mlp_ratio=mlp_ratio, num_experts=num_experts, bias=bias) self.drop_path = StochasticDepth(drop_path) def forward(self, x: torch.Tensor, reverse: bool = False) 
class mLSTMBlockWithTMoE(nn.Module):
    """mLSTM block with a TMoE MLP instead of the standard SwiGLU MLP."""

    def __init__(
        self,
        dim: int = 384,
        proj_factor: float = 2.0,
        qkv_proj_blocksize: int = 4,
        num_heads: int = 4,
        conv_kernel: int = 4,
        mlp_ratio: float = 4.0,
        drop_path: float = 0.0,
        num_experts: int = 4,
        bias: bool = False,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, bias=False)
        self.mlstm = mLSTMCell(
            dim=dim,
            proj_factor=proj_factor,
            qkv_proj_blocksize=qkv_proj_blocksize,
            num_heads=num_heads,
            conv_kernel=conv_kernel,
            bias=bias,
        )
        self.norm2 = nn.LayerNorm(dim, bias=False)
        self.mlp = TMoEMLP(dim=dim, mlp_ratio=mlp_ratio, num_experts=num_experts, bias=bias)
        self.drop_path = StochasticDepth(drop_path)

    def forward(self, x: torch.Tensor, reverse: bool = False) -> torch.Tensor:
        # Pre-norm residual sublayers: mLSTM token mixing, then TMoE MLP
        x = x + self.drop_path(self.mlstm(self.norm1(x), reverse=reverse))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

    def freeze_shared_expert(self):
        self.mlp.freeze_shared_expert()
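# --- Hedged sketch of the two-phase TMoE schedule (illustration only) ---
# Phase 1 trains every parameter; Phase 2 freezes the shared expert and keeps
# training only the router and specialized experts, per the docstrings above.
def _phase2_freeze_demo() -> None:
    block = mLSTMBlockWithTMoE(dim=384)
    block.freeze_shared_expert()
    trainable = [n for n, p in block.named_parameters() if p.requires_grad]
    # Shared-expert weights no longer receive gradients; experts/router still do
    assert not any('shared_expert' in n for n in trainable)
    assert any('experts' in n for n in trainable)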
class ViLBackbone(nn.Module):
    """Vision-LSTM backbone for tracking with sequential multi-frame processing.

    Processes template + K search frames as one long mLSTM sequence:

        [template_tokens | search_1_tokens | search_2_tokens | ... | search_K_tokens]

    The mLSTM memory state C carries information across frames:
    - Template tokens establish the target appearance in memory
    - Search_1 tokens are processed with template context in memory
    - Search_2 tokens are processed with template + search_1 context, etc.

    This is the core advantage over ViT: temporal information accumulates in
    the recurrent memory state, not through attention over all tokens.

    Token counts:
        Template:     128x128 → 8x8   = 64 tokens
        Each search:  256x256 → 16x16 = 256 tokens
        K=3 sequence: 64 + 3×256 = 832 tokens

    Bidirectional scanning: even blocks L→R, odd blocks R→L.
    FiLM modulation: applied between blocks at interval=6.
    TMoE: last `tmoe_blocks` blocks.
    """

    def __init__(
        self,
        dim: int = 384,
        depth: int = 24,
        patch_size: int = 16,
        in_channels: int = 3,
        proj_factor: float = 2.0,
        qkv_proj_blocksize: int = 4,
        num_heads: int = 4,
        conv_kernel: int = 4,
        mlp_ratio: float = 4.0,
        drop_path_rate: float = 0.05,
        tmoe_blocks: int = 2,
        num_experts: int = 4,
        bias: bool = False,
        film_interval: int = 6,
    ):
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.patch_size = patch_size
        self.film_interval = film_interval

        # Patch embedding
        self.patch_embed = PatchEmbed(patch_size=patch_size, in_channels=in_channels, dim=dim)

        # Positional embeddings, derived from the fixed 128x128 template and
        # 256x256 search resolutions (64 and 256 tokens at patch_size=16)
        n_template = (128 // patch_size) ** 2
        n_search = (256 // patch_size) ** 2
        self.template_pos = nn.Parameter(torch.randn(1, n_template, dim) * 0.02)
        self.search_pos = nn.Parameter(torch.randn(1, n_search, dim) * 0.02)

        # Token type embeddings (template vs search)
        self.template_type = nn.Parameter(torch.randn(1, 1, dim) * 0.02)
        self.search_type = nn.Parameter(torch.randn(1, 1, dim) * 0.02)

        # Stochastic depth rates (linearly increasing)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        # Build blocks: the last `tmoe_blocks` use the TMoE MLP
        self.blocks = nn.ModuleList()
        for i in range(depth):
            if i >= depth - tmoe_blocks:
                block = mLSTMBlockWithTMoE(
                    dim=dim,
                    proj_factor=proj_factor,
                    qkv_proj_blocksize=qkv_proj_blocksize,
                    num_heads=num_heads,
                    conv_kernel=conv_kernel,
                    mlp_ratio=mlp_ratio,
                    drop_path=dpr[i],
                    num_experts=num_experts,
                    bias=bias,
                )
            else:
                block = mLSTMBlock(
                    dim=dim,
                    proj_factor=proj_factor,
                    qkv_proj_blocksize=qkv_proj_blocksize,
                    num_heads=num_heads,
                    conv_kernel=conv_kernel,
                    mlp_ratio=mlp_ratio,
                    drop_path=dpr[i],
                    bias=bias,
                )
            self.blocks.append(block)

        # Final norm
        self.norm = nn.LayerNorm(dim, bias=False)

    def forward(
        self,
        template: torch.Tensor,
        searches: torch.Tensor,
        temporal_mod_manager=None,
    ) -> tuple:
        """Process template + K search frames as one mLSTM sequence.

        Args:
            template: (B, 3, 128, 128) template image
            searches: (B, K, 3, 256, 256) K consecutive search frames
                OR (B, 3, 256, 256) single search frame (backward compat)
            temporal_mod_manager: optional TemporalModulationManager for FiLM

        Returns:
            template_feat: (B, 64, D) template features
            search_feats: (B, K, 256, D) per-frame search features
                OR (B, 256, D) if a single search frame was given
        """
        B = template.shape[0]
        single_frame = (searches.ndim == 4)  # (B, 3, H, W) vs (B, K, 3, H, W)
        if single_frame:
            searches = searches.unsqueeze(1)  # (B, 1, 3, H, W)
        K = searches.shape[1]

        # Patch embed template
        t_tokens = self.patch_embed(template)  # (B, 64, D)
        t_tokens = t_tokens + self.template_pos + self.template_type
        n_template = t_tokens.shape[1]  # 64

        # Patch embed all search frames:
        # reshape (B, K, 3, H, W) → (B*K, 3, H, W) for batched patch embedding
        s_flat = searches.reshape(B * K, *searches.shape[2:])
        s_tokens_flat = self.patch_embed(s_flat)  # (B*K, 256, D)
        s_tokens = s_tokens_flat.reshape(B, K, -1, self.dim)  # (B, K, 256, D)
        s_tokens = s_tokens + self.search_pos.unsqueeze(1) + self.search_type
        n_search = s_tokens.shape[2]  # 256

        # Build the full sequence: [template | search_1 | search_2 | ... | search_K].
        # The mLSTM memory carries information across this entire sequence.
        s_tokens_concat = s_tokens.reshape(B, K * n_search, self.dim)  # (B, K*256, D)
        tokens = torch.cat([t_tokens, s_tokens_concat], dim=1)  # (B, 64 + K*256, D)

        # Process through bidirectional mLSTM blocks
        for i, block in enumerate(self.blocks):
            reverse = (i % 2 == 1)  # even blocks scan L→R, odd blocks R→L
            tokens = block(tokens, reverse=reverse)

            # FiLM modulation between blocks; the manager applies it only at
            # its configured interval (every `film_interval` blocks)
            if temporal_mod_manager is not None:
                tokens = temporal_mod_manager.modulate(tokens, i)

        tokens = self.norm(tokens)

        if temporal_mod_manager is not None:
            temporal_mod_manager.update_temporal_context(tokens)

        # Split: template features + per-frame search features
        template_feat = tokens[:, :n_template]  # (B, 64, D)
        search_tokens = tokens[:, n_template:]  # (B, K*256, D)
        search_feats = search_tokens.reshape(B, K, n_search, self.dim)  # (B, K, 256, D)

        if single_frame:
            return template_feat, search_feats.squeeze(1)  # (B, 256, D)
        return template_feat, search_feats

    def freeze_shared_experts(self):
        """Freeze shared experts in TMoE blocks for Phase 2 training."""
        for block in self.blocks:
            if hasattr(block, 'freeze_shared_expert'):
                block.freeze_shared_expert()
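# --- Hedged end-to-end usage sketch (illustration, not a test fixture) ---
# Instantiates the ViL-S configuration from the module docstring and runs a
# template plus K=3 search frames through the backbone; input resolutions and
# expected output shapes follow the token-count table above (sequence length
# 64 + 3*256 = 832). Assumes the .mlstm primitives are importable.
def _vil_backbone_demo() -> None:
    backbone = ViLBackbone(dim=384, depth=24, patch_size=16)
    template = torch.randn(2, 3, 128, 128)     # → 64 template tokens
    searches = torch.randn(2, 3, 3, 256, 256)  # (B, K=3, 3, 256, 256) → 3×256 tokens
    template_feat, search_feats = backbone(template, searches)
    assert template_feat.shape == (2, 64, 384)
    assert search_feats.shape == (2, 3, 256, 384)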