omar-ah committed
Commit 9556ef9 · verified · 1 Parent(s): 523bea4

Upload vil_tracker/models/backbone.py with huggingface_hub

Files changed (1)
  1. vil_tracker/models/backbone.py +252 -0
vil_tracker/models/backbone.py ADDED
@@ -0,0 +1,252 @@
"""
ViL (Vision-LSTM) backbone for single-object tracking.

Architecture:
- Patch embedding (Conv2d) for template + search region
- Stack of mLSTM blocks with bidirectional scanning (even = L→R, odd = R→L)
- Optional TMoE-MLP in the last N blocks (dense routing, frozen shared expert)
- Outputs concatenated template+search features for head processing

ViL-S config: dim=384, depth=24, patch_size=16, ~23M backbone params
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from .mlstm import mLSTMBlock, mLSTMCell, SwiGLUMLP, StochasticDepth


class PatchEmbed(nn.Module):
    """Convert image patches to token embeddings using Conv2d."""

    def __init__(self, patch_size: int = 16, in_channels: int = 3, dim: int = 384):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, C, H, W) image tensor
        Returns:
            (B, N, D) patch token embeddings, N = (H/P) * (W/P)
        """
        x = self.proj(x)  # (B, D, H/P, W/P)
        x = rearrange(x, 'b d h w -> b (h w) d')  # flatten spatial grid to a token sequence
        x = self.norm(x)
        return x


class TMoEMLP(nn.Module):
    """Temporal Mixture-of-Experts MLP.

    Uses dense routing with a shared expert (frozen after Phase 1) and
    K specialized experts. Output = shared_out + sum(gate_k * expert_k_out).

    For tracking: experts specialize on different temporal dynamics
    (fast motion, occlusion recovery, scale change).
    """

    def __init__(
        self,
        dim: int = 384,
        mlp_ratio: float = 4.0,
        num_experts: int = 4,
        bias: bool = False,
    ):
        super().__init__()
        self.num_experts = num_experts

        # Shared expert (frozen after Phase 1 training)
        self.shared_expert = SwiGLUMLP(dim=dim, mlp_ratio=mlp_ratio, bias=bias)

        # Specialized experts (smaller: mlp_ratio / 2)
        small_ratio = mlp_ratio / 2
        self.experts = nn.ModuleList([
            SwiGLUMLP(dim=dim, mlp_ratio=small_ratio, bias=bias)
            for _ in range(num_experts)
        ])

        # Dense router: soft gating over experts
        self.router = nn.Linear(dim, num_experts, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Shared expert output (always contributes)
        shared_out = self.shared_expert(x)

        # Router logits -> softmax gates, one distribution per token
        gates = F.softmax(self.router(x), dim=-1)  # (B, S, num_experts)

        # Expert outputs, weighted by their gates (dense: every expert runs)
        expert_out = torch.zeros_like(shared_out)
        for i, expert in enumerate(self.experts):
            expert_out = expert_out + gates[..., i:i + 1] * expert(x)

        return shared_out + expert_out

    def freeze_shared_expert(self):
        """Freeze the shared expert for Phase 2 training."""
        for p in self.shared_expert.parameters():
            p.requires_grad = False


class mLSTMBlockWithTMoE(nn.Module):
    """mLSTM block with a TMoE MLP in place of the standard SwiGLU MLP."""

    def __init__(
        self,
        dim: int = 384,
        proj_factor: float = 2.0,
        qkv_proj_blocksize: int = 4,
        num_heads: int = 4,
        conv_kernel: int = 4,
        mlp_ratio: float = 4.0,
        drop_path: float = 0.0,
        num_experts: int = 4,
        bias: bool = False,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, bias=False)
        self.mlstm = mLSTMCell(
            dim=dim,
            proj_factor=proj_factor,
            qkv_proj_blocksize=qkv_proj_blocksize,
            num_heads=num_heads,
            conv_kernel=conv_kernel,
            bias=bias,
        )
        self.norm2 = nn.LayerNorm(dim, bias=False)
        self.mlp = TMoEMLP(dim=dim, mlp_ratio=mlp_ratio, num_experts=num_experts, bias=bias)
        self.drop_path = StochasticDepth(drop_path)

    def forward(self, x: torch.Tensor, reverse: bool = False) -> torch.Tensor:
        # Pre-norm residual blocks: mLSTM token mixing, then TMoE MLP
        x = x + self.drop_path(self.mlstm(self.norm1(x), reverse=reverse))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

    def freeze_shared_expert(self):
        self.mlp.freeze_shared_expert()


class ViLBackbone(nn.Module):
    """Vision-LSTM backbone for tracking.

    Concatenates template + search patches into a single sequence,
    processes it through bidirectional mLSTM blocks, then separates the outputs.

    Template: 128x128 → 8x8 = 64 tokens
    Search: 256x256 → 16x16 = 256 tokens
    Total sequence: 320 tokens

    Bidirectional scanning: even blocks L→R, odd blocks R→L.
    The last `tmoe_blocks` blocks use a TMoE MLP for temporal specialization.
    """

    def __init__(
        self,
        dim: int = 384,
        depth: int = 24,
        patch_size: int = 16,
        in_channels: int = 3,
        proj_factor: float = 2.0,
        qkv_proj_blocksize: int = 4,
        num_heads: int = 4,
        conv_kernel: int = 4,
        mlp_ratio: float = 4.0,
        drop_path_rate: float = 0.1,
        tmoe_blocks: int = 2,
        num_experts: int = 4,
        bias: bool = False,
    ):
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.patch_size = patch_size

        # Patch embedding (shared between template and search)
        self.patch_embed = PatchEmbed(patch_size=patch_size, in_channels=in_channels, dim=dim)

        # Positional embeddings for template and search regions
        # Template: 128/16 = 8, so 8x8 = 64 tokens
        # Search: 256/16 = 16, so 16x16 = 256 tokens
        self.template_pos = nn.Parameter(torch.randn(1, 64, dim) * 0.02)
        self.search_pos = nn.Parameter(torch.randn(1, 256, dim) * 0.02)

        # Token type embeddings (template vs search)
        self.template_type = nn.Parameter(torch.randn(1, 1, dim) * 0.02)
        self.search_type = nn.Parameter(torch.randn(1, 1, dim) * 0.02)

        # Stochastic depth rates (linearly increasing with depth)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        # Build blocks: the last `tmoe_blocks` use the TMoE MLP
        self.blocks = nn.ModuleList()
        for i in range(depth):
            if i >= depth - tmoe_blocks:
                block = mLSTMBlockWithTMoE(
                    dim=dim, proj_factor=proj_factor,
                    qkv_proj_blocksize=qkv_proj_blocksize,
                    num_heads=num_heads, conv_kernel=conv_kernel,
                    mlp_ratio=mlp_ratio, drop_path=dpr[i],
                    num_experts=num_experts, bias=bias,
                )
            else:
                block = mLSTMBlock(
                    dim=dim, proj_factor=proj_factor,
                    qkv_proj_blocksize=qkv_proj_blocksize,
                    num_heads=num_heads, conv_kernel=conv_kernel,
                    mlp_ratio=mlp_ratio, drop_path=dpr[i], bias=bias,
                )
            self.blocks.append(block)

        # Final norm
        self.norm = nn.LayerNorm(dim, bias=False)

    def forward(
        self,
        template: torch.Tensor,
        search: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            template: (B, 3, 128, 128) template image
            search: (B, 3, 256, 256) search region image
        Returns:
            template_feat: (B, 64, D) template features
            search_feat: (B, 256, D) search features
        """
        # Patch embed
        t_tokens = self.patch_embed(template)  # (B, 64, D)
        s_tokens = self.patch_embed(search)  # (B, 256, D)

        # Add positional + type embeddings
        t_tokens = t_tokens + self.template_pos + self.template_type
        s_tokens = s_tokens + self.search_pos + self.search_type

        # Concatenate: [template | search]
        tokens = torch.cat([t_tokens, s_tokens], dim=1)  # (B, 320, D)

        # Process through bidirectional mLSTM blocks
        for i, block in enumerate(self.blocks):
            reverse = (i % 2 == 1)  # odd blocks scan R→L
            tokens = block(tokens, reverse=reverse)

        tokens = self.norm(tokens)

        # Split back into template and search features
        n_template = t_tokens.shape[1]
        template_feat = tokens[:, :n_template]
        search_feat = tokens[:, n_template:]

        return template_feat, search_feat

    def freeze_shared_experts(self):
        """Freeze shared experts in TMoE blocks for Phase 2 training."""
        for block in self.blocks:
            if hasattr(block, 'freeze_shared_expert'):
                block.freeze_shared_expert()
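
For reference, a minimal smoke test of the interface this file exposes. This is an illustrative sketch, not part of the commit: it assumes the package is importable as `vil_tracker` and that the sibling `mlstm` module supplies the `mLSTMBlock`, `mLSTMCell`, `SwiGLUMLP`, and `StochasticDepth` classes imported above.

import torch

from vil_tracker.models.backbone import ViLBackbone

# ViL-S configuration from the module docstring: dim=384, depth=24,
# patch_size=16, with the last two blocks using the TMoE MLP.
backbone = ViLBackbone(dim=384, depth=24, patch_size=16, tmoe_blocks=2)

template = torch.randn(2, 3, 128, 128)  # (B, C, H, W) template crop
search = torch.randn(2, 3, 256, 256)  # (B, C, H, W) search region

template_feat, search_feat = backbone(template, search)
assert template_feat.shape == (2, 64, 384)  # 8x8 template tokens
assert search_feat.shape == (2, 256, 384)  # 16x16 search tokens

# Phase 2: freeze the shared experts inside the TMoE blocks; the routers
# and specialized experts stay trainable.
backbone.freeze_shared_experts()
n_trainable = sum(p.numel() for p in backbone.parameters() if p.requires_grad)
print(f"Trainable parameters after freeze: {n_trainable / 1e6:.1f}M")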