# vil_tracker/models/backbone.py
# Provenance: uploaded by omar-ah to the Hugging Face Hub
# (vil-tracker repo, commit 9556ef9 verified, file size 8.9 kB).
"""
ViL (Vision-LSTM) Backbone for single object tracking.
Architecture:
- Patch embedding (Conv2d) for template + search region
- Stack of mLSTM blocks with bidirectional scanning (even=L→R, odd=R→L)
- Optional TMoE-MLP in last N blocks (dense routing, frozen shared expert)
- Outputs concatenated template+search features for head processing
ViL-S config: dim=384, depth=24, patch_size=16, ~23M backbone params
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from .mlstm import mLSTMBlock, SwiGLUMLP, StochasticDepth
class PatchEmbed(nn.Module):
"""Convert image patches to token embeddings using Conv2d."""
def __init__(self, patch_size: int = 16, in_channels: int = 3, dim: int = 384):
super().__init__()
self.patch_size = patch_size
self.proj = nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size)
self.norm = nn.LayerNorm(dim, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: (B, C, H, W) image tensor
Returns:
(B, N, D) patch token embeddings, N = (H/P)*(W/P)
"""
x = self.proj(x) # (B, D, H/P, W/P)
x = rearrange(x, 'b d h w -> b (h w) d')
x = self.norm(x)
return x
class TMoEMLP(nn.Module):
    """Temporal Mixture-of-Experts MLP.

    Uses dense (soft) routing with a shared expert (frozen after Phase 1
    via ``freeze_shared_expert``) and K specialized experts at half the
    shared expert's hidden width. Per token:

        output = shared_out + sum_k gate_k * expert_k_out

    where the gates are a softmax over a linear router.

    For tracking: experts specialize on different temporal dynamics
    (fast motion, occlusion recovery, scale change).
    """

    def __init__(
        self,
        dim: int = 384,
        mlp_ratio: float = 4.0,
        num_experts: int = 4,
        bias: bool = False,
    ):
        """
        Args:
            dim: token embedding dimension.
            mlp_ratio: hidden expansion ratio of the shared expert; the
                specialized experts use half this ratio.
            num_experts: number of specialized experts K.
            bias: whether the expert MLPs use bias terms (the router
                always has a bias).
        """
        super().__init__()
        self.num_experts = num_experts
        # Shared expert (frozen after Phase 1 training)
        self.shared_expert = SwiGLUMLP(dim=dim, mlp_ratio=mlp_ratio, bias=bias)
        # Specialized experts (smaller: mlp_ratio/2)
        small_ratio = mlp_ratio / 2
        self.experts = nn.ModuleList([
            SwiGLUMLP(dim=dim, mlp_ratio=small_ratio, bias=bias)
            for _ in range(num_experts)
        ])
        # Dense router: soft gating over experts
        self.router = nn.Linear(dim, num_experts, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, S, dim) token sequence.
        Returns:
            (B, S, dim) combined shared + gated expert output.
        """
        # Shared expert output (always contributes, unit weight)
        shared_out = self.shared_expert(x)
        # Router logits and per-token softmax gates
        gates = F.softmax(self.router(x), dim=-1)  # (B, S, num_experts)
        # Dense routing: every expert runs on every token, weighted by its gate
        expert_out = torch.zeros_like(shared_out)
        for i, expert in enumerate(self.experts):
            expert_out = expert_out + gates[..., i:i+1] * expert(x)
        return shared_out + expert_out

    def freeze_shared_expert(self):
        """Freeze the shared expert's parameters for Phase 2 training."""
        for p in self.shared_expert.parameters():
            p.requires_grad = False
class mLSTMBlockWithTMoE(nn.Module):
    """Pre-norm mLSTM block whose channel MLP is a TMoE mixture.

    Identical residual layout to the standard mLSTM block —
    ``x + DropPath(mLSTM(LN(x)))`` then ``x + DropPath(MLP(LN(x)))`` —
    but with :class:`TMoEMLP` in place of the plain SwiGLU MLP.
    """

    def __init__(
        self,
        dim: int = 384,
        proj_factor: float = 2.0,
        qkv_proj_blocksize: int = 4,
        num_heads: int = 4,
        conv_kernel: int = 4,
        mlp_ratio: float = 4.0,
        drop_path: float = 0.0,
        num_experts: int = 4,
        bias: bool = False,
    ):
        super().__init__()
        from .mlstm import mLSTMCell
        # Sequence-mixing branch: pre-norm + mLSTM cell.
        self.norm1 = nn.LayerNorm(dim, bias=False)
        self.mlstm = mLSTMCell(
            dim=dim,
            proj_factor=proj_factor,
            qkv_proj_blocksize=qkv_proj_blocksize,
            num_heads=num_heads,
            conv_kernel=conv_kernel,
            bias=bias,
        )
        # Channel-mixing branch: pre-norm + mixture-of-experts MLP.
        self.norm2 = nn.LayerNorm(dim, bias=False)
        self.mlp = TMoEMLP(dim=dim, mlp_ratio=mlp_ratio, num_experts=num_experts, bias=bias)
        self.drop_path = StochasticDepth(drop_path)

    def forward(self, x: torch.Tensor, reverse: bool = False) -> torch.Tensor:
        """Apply both residual branches; ``reverse=True`` scans R→L."""
        seq_mixed = self.mlstm(self.norm1(x), reverse=reverse)
        x = x + self.drop_path(seq_mixed)
        channel_mixed = self.mlp(self.norm2(x))
        x = x + self.drop_path(channel_mixed)
        return x

    def freeze_shared_expert(self):
        """Delegate Phase-2 shared-expert freezing to the TMoE MLP."""
        self.mlp.freeze_shared_expert()
class ViLBackbone(nn.Module):
    """Vision-LSTM backbone for tracking.

    Concatenates template + search patches into a single sequence,
    processes through bidirectional mLSTM blocks, then separates outputs.
    Template: 128x128 → 8x8 = 64 tokens
    Search: 256x256 → 16x16 = 256 tokens
    Total sequence: 320 tokens
    Bidirectional scanning: even blocks L→R, odd blocks R→L.
    Last `tmoe_blocks` blocks use TMoE MLP for temporal specialization.
    """

    def __init__(
        self,
        dim: int = 384,
        depth: int = 24,
        patch_size: int = 16,
        in_channels: int = 3,
        proj_factor: float = 2.0,
        qkv_proj_blocksize: int = 4,
        num_heads: int = 4,
        conv_kernel: int = 4,
        mlp_ratio: float = 4.0,
        drop_path_rate: float = 0.1,
        tmoe_blocks: int = 2,
        num_experts: int = 4,
        bias: bool = False,
    ):
        """
        Args:
            dim: token embedding dimension.
            depth: number of mLSTM blocks.
            patch_size: patch side length for the Conv2d patch embedding.
            in_channels: input image channels.
            proj_factor, qkv_proj_blocksize, num_heads, conv_kernel:
                forwarded to the mLSTM cells.
            mlp_ratio: hidden expansion of each block's MLP.
            drop_path_rate: max stochastic-depth rate (linearly scheduled
                from 0 at the first block to this value at the last).
            tmoe_blocks: how many trailing blocks use the TMoE MLP.
            num_experts: specialized experts per TMoE block.
            bias: whether linear layers use bias terms.
        """
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.patch_size = patch_size
        # Patch embedding shared by template and search inputs
        self.patch_embed = PatchEmbed(patch_size=patch_size, in_channels=in_channels, dim=dim)
        # Learned positional embeddings, sized for the fixed input resolutions:
        # template 128/16 = 8x8 = 64 tokens; search 256/16 = 16x16 = 256 tokens.
        self.template_pos = nn.Parameter(torch.randn(1, 64, dim) * 0.02)
        self.search_pos = nn.Parameter(torch.randn(1, 256, dim) * 0.02)
        # Token type embeddings distinguishing template vs search tokens
        self.template_type = nn.Parameter(torch.randn(1, 1, dim) * 0.02)
        self.search_type = nn.Parameter(torch.randn(1, 1, dim) * 0.02)
        # Per-block stochastic depth rates, linearly increasing with depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        # Build blocks: the last `tmoe_blocks` use the TMoE MLP variant
        self.blocks = nn.ModuleList()
        for i in range(depth):
            if i >= depth - tmoe_blocks:
                block = mLSTMBlockWithTMoE(
                    dim=dim, proj_factor=proj_factor,
                    qkv_proj_blocksize=qkv_proj_blocksize,
                    num_heads=num_heads, conv_kernel=conv_kernel,
                    mlp_ratio=mlp_ratio, drop_path=dpr[i],
                    num_experts=num_experts, bias=bias,
                )
            else:
                block = mLSTMBlock(
                    dim=dim, proj_factor=proj_factor,
                    qkv_proj_blocksize=qkv_proj_blocksize,
                    num_heads=num_heads, conv_kernel=conv_kernel,
                    mlp_ratio=mlp_ratio, drop_path=dpr[i], bias=bias,
                )
            self.blocks.append(block)
        # Final norm
        self.norm = nn.LayerNorm(dim, bias=False)

    def forward(
        self,
        template: torch.Tensor,
        search: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            template: (B, 3, 128, 128) template image
            search: (B, 3, 256, 256) search region image
        Returns:
            template_feat: (B, 64, D) template features
            search_feat: (B, 256, D) search features
        """
        # Patch embed both inputs with the shared embedding layer
        t_tokens = self.patch_embed(template)  # (B, 64, D)
        s_tokens = self.patch_embed(search)    # (B, 256, D)
        # Add positional + token-type embeddings
        t_tokens = t_tokens + self.template_pos + self.template_type
        s_tokens = s_tokens + self.search_pos + self.search_type
        # Concatenate into one sequence: [template | search]
        tokens = torch.cat([t_tokens, s_tokens], dim=1)  # (B, 320, D)
        # Alternate scan direction per block: even blocks L→R, odd blocks R→L
        for i, block in enumerate(self.blocks):
            reverse = (i % 2 == 1)
            tokens = block(tokens, reverse=reverse)
        tokens = self.norm(tokens)
        # Split the normalized sequence back into template/search features
        n_template = t_tokens.shape[1]
        template_feat = tokens[:, :n_template]
        search_feat = tokens[:, n_template:]
        return template_feat, search_feat

    def freeze_shared_experts(self):
        """Freeze shared experts in all TMoE blocks for Phase 2 training."""
        for block in self.blocks:
            if hasattr(block, 'freeze_shared_expert'):
                block.freeze_shared_expert()