"""ViL (Vision-LSTM) backbone for single object tracking.

Architecture:
    - Patch embedding (Conv2d), shared by the template and the search region.
    - A stack of mLSTM blocks with bidirectional scanning
      (even-indexed blocks scan L→R, odd-indexed blocks scan R→L).
    - Optional TMoE-MLP in the last N blocks (dense routing, with a shared
      expert that can be frozen for Phase 2 training).
    - Returns the template and search features as two separate tensors
      for head processing.

ViL-S config: dim=384, depth=24, patch_size=16, ~23M backbone params.
"""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from .mlstm import mLSTMBlock, mLSTMCell, SwiGLUMLP, StochasticDepth


class PatchEmbed(nn.Module):
    """Convert image patches to token embeddings using a strided Conv2d."""

    def __init__(self, patch_size: int = 16, in_channels: int = 3, dim: int = 384):
        super().__init__()
        self.patch_size = patch_size
        # kernel_size == stride: each non-overlapping P x P patch becomes one token.
        self.proj = nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, C, H, W) image tensor.

        Returns:
            (B, N, D) patch token embeddings with N = (H // P) * (W // P).
            Pixels that do not fill a whole patch are dropped by the strided
            convolution.
        """
        x = self.proj(x)  # (B, D, H/P, W/P)
        x = rearrange(x, 'b d h w -> b (h w) d')
        return self.norm(x)


class TMoEMLP(nn.Module):
    """Temporal Mixture-of-Experts MLP.

    Dense routing: every token goes through one shared expert and through
    all K specialized experts. A per-token softmax router mixes the
    specialized outputs:

        output = shared(x) + sum_k gate_k(x) * expert_k(x)

    The shared expert is intended to be frozen after Phase 1 training (see
    ``freeze_shared_expert``). For tracking, the specialized experts are meant
    to specialize on different temporal dynamics (e.g. fast motion, occlusion
    recovery, scale change).
    """

    def __init__(
        self,
        dim: int = 384,
        mlp_ratio: float = 4.0,
        num_experts: int = 4,
        bias: bool = False,
    ):
        super().__init__()
        self.num_experts = num_experts

        # Shared expert at the full MLP ratio (frozen after Phase 1 training).
        self.shared_expert = SwiGLUMLP(dim=dim, mlp_ratio=mlp_ratio, bias=bias)

        # Specialized experts use half the shared expert's mlp_ratio.
        small_ratio = mlp_ratio / 2
        self.experts = nn.ModuleList([
            SwiGLUMLP(dim=dim, mlp_ratio=small_ratio, bias=bias)
            for _ in range(num_experts)
        ])

        # Dense router: one logit per specialized expert, softmaxed per token.
        self.router = nn.Linear(dim, num_experts, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: token features with last dim == dim (presumably (B, S, dim) —
                TODO confirm against callers).

        Returns:
            Shared-expert output plus the gate-weighted sum of expert outputs.
        """
        # The shared expert always contributes, independent of the router.
        shared_out = self.shared_expert(x)

        gates = F.softmax(self.router(x), dim=-1)  # (..., num_experts)

        expert_out = torch.zeros_like(shared_out)
        for i, expert in enumerate(self.experts):
            # i:i+1 keeps a trailing singleton dim so the gate broadcasts
            # over the feature dimension.
            expert_out = expert_out + gates[..., i:i + 1] * expert(x)
        return shared_out + expert_out

    def freeze_shared_expert(self):
        """Freeze the shared expert's parameters for Phase 2 training."""
        for p in self.shared_expert.parameters():
            p.requires_grad = False


class mLSTMBlockWithTMoE(nn.Module):
    """Pre-norm residual mLSTM block that uses a TMoE MLP instead of SwiGLU.

    x = x + drop_path(mLSTM(norm1(x)))
    x = x + drop_path(TMoE(norm2(x)))
    """

    def __init__(
        self,
        dim: int = 384,
        proj_factor: float = 2.0,
        qkv_proj_blocksize: int = 4,
        num_heads: int = 4,
        conv_kernel: int = 4,
        mlp_ratio: float = 4.0,
        drop_path: float = 0.0,
        num_experts: int = 4,
        bias: bool = False,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, bias=False)
        self.mlstm = mLSTMCell(
            dim=dim,
            proj_factor=proj_factor,
            qkv_proj_blocksize=qkv_proj_blocksize,
            num_heads=num_heads,
            conv_kernel=conv_kernel,
            bias=bias,
        )
        self.norm2 = nn.LayerNorm(dim, bias=False)
        self.mlp = TMoEMLP(dim=dim, mlp_ratio=mlp_ratio, num_experts=num_experts, bias=bias)
        self.drop_path = StochasticDepth(drop_path)

    def forward(self, x: torch.Tensor, reverse: bool = False) -> torch.Tensor:
        """Apply the block; ``reverse`` is forwarded to the mLSTM cell."""
        x = x + self.drop_path(self.mlstm(self.norm1(x), reverse=reverse))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

    def freeze_shared_expert(self):
        """Freeze the TMoE shared expert (Phase 2 training)."""
        self.mlp.freeze_shared_expert()


class ViLBackbone(nn.Module):
    """Vision-LSTM backbone for tracking.

    Embeds the template and search images with a shared patch embedding,
    concatenates them into a single sequence ``[template | search]``,
    processes it with bidirectional mLSTM blocks, then splits the output
    back into template and search features.

    With the defaults (patch_size=16):
        Template: 128x128 → 8x8 = 64 tokens
        Search:   256x256 → 16x16 = 256 tokens
        Total sequence: 320 tokens

    Bidirectional scanning: even-indexed blocks scan L→R, odd-indexed R→L.
    The last ``tmoe_blocks`` blocks use a TMoE MLP (every block when
    ``tmoe_blocks >= depth``, none when ``tmoe_blocks <= 0``).

    Args:
        template_size: side length (pixels) of the square template image;
            sets the number of template positional embeddings.
        search_size: side length (pixels) of the square search image;
            sets the number of search positional embeddings.
        Other arguments configure the patch embedding and mLSTM blocks.
    """

    def __init__(
        self,
        dim: int = 384,
        depth: int = 24,
        patch_size: int = 16,
        in_channels: int = 3,
        proj_factor: float = 2.0,
        qkv_proj_blocksize: int = 4,
        num_heads: int = 4,
        conv_kernel: int = 4,
        mlp_ratio: float = 4.0,
        drop_path_rate: float = 0.1,
        tmoe_blocks: int = 2,
        num_experts: int = 4,
        bias: bool = False,
        template_size: int = 128,
        search_size: int = 256,
    ):
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.patch_size = patch_size
        self.template_size = template_size
        self.search_size = search_size

        # Token counts follow the strided-conv patching (floor division), so
        # positional embeddings stay consistent with any patch_size instead of
        # being hard-wired to patch_size=16 (64 / 256 tokens with defaults).
        self.num_template_tokens = (template_size // patch_size) ** 2
        self.num_search_tokens = (search_size // patch_size) ** 2

        # Patch embedding (shared between template and search)
        self.patch_embed = PatchEmbed(patch_size=patch_size, in_channels=in_channels, dim=dim)

        # Learned positional embeddings, one set per region.
        self.template_pos = nn.Parameter(torch.randn(1, self.num_template_tokens, dim) * 0.02)
        self.search_pos = nn.Parameter(torch.randn(1, self.num_search_tokens, dim) * 0.02)

        # Token type embeddings (template vs search), broadcast over all tokens.
        self.template_type = nn.Parameter(torch.randn(1, 1, dim) * 0.02)
        self.search_type = nn.Parameter(torch.randn(1, 1, dim) * 0.02)

        # Stochastic depth rates increase linearly from 0 to drop_path_rate.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        # Build blocks: the last `tmoe_blocks` use a TMoE MLP.
        first_tmoe = depth - tmoe_blocks
        self.blocks = nn.ModuleList()
        for i in range(depth):
            if i >= first_tmoe:
                block = mLSTMBlockWithTMoE(
                    dim=dim,
                    proj_factor=proj_factor,
                    qkv_proj_blocksize=qkv_proj_blocksize,
                    num_heads=num_heads,
                    conv_kernel=conv_kernel,
                    mlp_ratio=mlp_ratio,
                    drop_path=dpr[i],
                    num_experts=num_experts,
                    bias=bias,
                )
            else:
                block = mLSTMBlock(
                    dim=dim,
                    proj_factor=proj_factor,
                    qkv_proj_blocksize=qkv_proj_blocksize,
                    num_heads=num_heads,
                    conv_kernel=conv_kernel,
                    mlp_ratio=mlp_ratio,
                    drop_path=dpr[i],
                    bias=bias,
                )
            self.blocks.append(block)

        # Final norm
        self.norm = nn.LayerNorm(dim, bias=False)

    @staticmethod
    def _check_num_tokens(name: str, tokens: torch.Tensor, pos: torch.Tensor) -> None:
        """Raise if the patch count does not match the positional embedding.

        Without this check a mismatch either surfaces as an opaque broadcast
        error or, when a single patch is produced, silently broadcasts.
        """
        if tokens.shape[1] != pos.shape[1]:
            raise ValueError(
                f"{name} image produced {tokens.shape[1]} patch tokens, but the "
                f"positional embedding expects {pos.shape[1]}; check the input "
                f"resolution against {name}_size and patch_size."
            )

    def forward(
        self,
        template: torch.Tensor,
        search: torch.Tensor,
    ) -> tuple:
        """
        Args:
            template: (B, C, template_size, template_size) template image.
            search: (B, C, search_size, search_size) search region image.

        Returns:
            (template_feat, search_feat):
                template_feat: (B, num_template_tokens, D) — (B, 64, D) by default.
                search_feat: (B, num_search_tokens, D) — (B, 256, D) by default.

        Raises:
            ValueError: if an input's patch count does not match its
                positional embedding.
        """
        # Patch embed
        t_tokens = self.patch_embed(template)
        s_tokens = self.patch_embed(search)
        self._check_num_tokens("template", t_tokens, self.template_pos)
        self._check_num_tokens("search", s_tokens, self.search_pos)

        # Add positional + type embeddings
        t_tokens = t_tokens + self.template_pos + self.template_type
        s_tokens = s_tokens + self.search_pos + self.search_type

        # Concatenate: [template | search]
        tokens = torch.cat([t_tokens, s_tokens], dim=1)

        # Bidirectional scanning: odd-indexed blocks scan R→L.
        for i, block in enumerate(self.blocks):
            tokens = block(tokens, reverse=(i % 2 == 1))

        tokens = self.norm(tokens)

        # Split back into the two regions.
        n_template = t_tokens.shape[1]
        template_feat = tokens[:, :n_template]
        search_feat = tokens[:, n_template:]
        return template_feat, search_feat

    def freeze_shared_experts(self):
        """Freeze shared experts in TMoE blocks for Phase 2 training."""
        for block in self.blocks:
            if hasattr(block, 'freeze_shared_expert'):
                block.freeze_shared_expert()