"""
ViL (Vision-LSTM) Backbone for single object tracking.

Architecture:
- Patch embedding (Conv2d) for template + search region
- Stack of mLSTM blocks with bidirectional scanning (even=L→R, odd=R→L)
- FiLM temporal modulation integrated BETWEEN blocks (at interval=6)
- Optional TMoE-MLP in last N blocks (dense routing, frozen shared expert)
- Outputs concatenated template+search features for head processing

ViL-S config: dim=384, depth=24, patch_size=16, ~23M backbone params
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from .mlstm import mLSTMBlock, mLSTMCell, SwiGLUMLP, StochasticDepth


class PatchEmbed(nn.Module):
    """Convert image patches to token embeddings using Conv2d."""
    def __init__(self, patch_size: int = 16, in_channels: int = 3, dim: int = 384):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(dim, bias=False)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, C, H, W) image tensor
        Returns:
            (B, N, D) patch token embeddings, N = (H/P)*(W/P)
        """
        x = self.proj(x)  # (B, D, H/P, W/P)
        x = rearrange(x, 'b d h w -> b (h w) d')
        x = self.norm(x)
        return x


class TMoEMLP(nn.Module):
    """Temporal Mixture-of-Experts MLP.
    
    Uses dense routing with a shared expert (frozen after Phase 1) and 
    K specialized experts. Output = shared_out + sum(gate_k * expert_k_out).
    
    For tracking: experts specialize in different temporal dynamics
    (fast motion, occlusion recovery, scale change).
    """
    def __init__(
        self,
        dim: int = 384,
        mlp_ratio: float = 4.0,
        num_experts: int = 4,
        bias: bool = False,
    ):
        super().__init__()
        self.num_experts = num_experts
        hidden_dim = int(dim * mlp_ratio)
        
        # Shared expert (frozen after Phase 1 training)
        self.shared_expert = SwiGLUMLP(dim=dim, mlp_ratio=mlp_ratio, bias=bias)
        
        # Specialized experts (smaller: mlp_ratio/2)
        small_ratio = mlp_ratio / 2
        self.experts = nn.ModuleList([
            SwiGLUMLP(dim=dim, mlp_ratio=small_ratio, bias=bias)
            for _ in range(num_experts)
        ])
        
        # Dense router: soft gating over experts
        self.router = nn.Linear(dim, num_experts, bias=True)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Shared expert output (always contributes)
        shared_out = self.shared_expert(x)
        
        # Router logits and softmax gates
        gates = F.softmax(self.router(x), dim=-1)  # (B, S, num_experts)
        
        # Expert outputs, weighted by gates
        expert_out = torch.zeros_like(shared_out)
        for i, expert in enumerate(self.experts):
            expert_out = expert_out + gates[..., i:i+1] * expert(x)
        
        return shared_out + expert_out
    
    def freeze_shared_expert(self):
        """Freeze the shared expert for Phase 2 training."""
        for p in self.shared_expert.parameters():
            p.requires_grad = False

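# Usage sketch for TMoEMLP (shapes illustrative): dense routing keeps every
# expert active, so no tokens are dropped; the extra cost over a plain SwiGLU
# MLP is num_experts half-size experts plus a linear router.
#
#   tmoe = TMoEMLP(dim=384, mlp_ratio=4.0, num_experts=4)
#   y = tmoe(torch.randn(2, 832, 384))   # (B, S, D) -> (B, S, D)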

class mLSTMBlockWithTMoE(nn.Module):
    """mLSTM block with TMoE MLP instead of standard SwiGLU MLP."""
    def __init__(
        self,
        dim: int = 384,
        proj_factor: float = 2.0,
        qkv_proj_blocksize: int = 4,
        num_heads: int = 4,
        conv_kernel: int = 4,
        mlp_ratio: float = 4.0,
        drop_path: float = 0.0,
        num_experts: int = 4,
        bias: bool = False,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, bias=False)
        self.mlstm = mLSTMCell(
            dim=dim,
            proj_factor=proj_factor,
            qkv_proj_blocksize=qkv_proj_blocksize,
            num_heads=num_heads,
            conv_kernel=conv_kernel,
            bias=bias,
        )
        self.norm2 = nn.LayerNorm(dim, bias=False)
        self.mlp = TMoEMLP(dim=dim, mlp_ratio=mlp_ratio, num_experts=num_experts, bias=bias)
        self.drop_path = StochasticDepth(drop_path)
    
    def forward(self, x: torch.Tensor, reverse: bool = False) -> torch.Tensor:
        x = x + self.drop_path(self.mlstm(self.norm1(x), reverse=reverse))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
    
    def freeze_shared_expert(self):
        self.mlp.freeze_shared_expert()


class ViLBackbone(nn.Module):
    """Vision-LSTM backbone for tracking with sequential multi-frame processing.
    
    Processes template + K search frames as one long mLSTM sequence:
        [template_tokens | search_1_tokens | search_2_tokens | ... | search_K_tokens]
    
    The mLSTM memory state C carries information across frames:
    - Template tokens establish the target appearance in memory
    - Search_1 tokens are processed with template context in memory
    - Search_2 tokens are processed with template + search_1 context, etc.
    
    This is the core advantage over ViT: temporal information accumulates
    in the recurrent memory state, not through attention over all tokens.
    
    Token counts:
        Template: 128x128 → 8x8 = 64 tokens
        Each search: 256x256 → 16x16 = 256 tokens  
        K=3 sequence: 64 + 3×256 = 832 tokens
    
    Bidirectional scanning: even blocks L→R, odd blocks R→L.
    FiLM modulation: applied between blocks at interval=6.
    TMoE: last `tmoe_blocks` blocks.
    """
    def __init__(
        self,
        dim: int = 384,
        depth: int = 24,
        patch_size: int = 16,
        in_channels: int = 3,
        proj_factor: float = 2.0,
        qkv_proj_blocksize: int = 4,
        num_heads: int = 4,
        conv_kernel: int = 4,
        mlp_ratio: float = 4.0,
        drop_path_rate: float = 0.05,
        tmoe_blocks: int = 2,
        num_experts: int = 4,
        bias: bool = False,
        film_interval: int = 6,
    ):
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.patch_size = patch_size
        self.film_interval = film_interval
        
        # Patch embedding
        self.patch_embed = PatchEmbed(patch_size=patch_size, in_channels=in_channels, dim=dim)
        
        # Positional embeddings for template and search regions
        # Template: 128/16 = 8x8 = 64 tokens
        # Search: 256/16 = 16x16 = 256 tokens
        self.template_pos = nn.Parameter(torch.randn(1, 64, dim) * 0.02)
        self.search_pos = nn.Parameter(torch.randn(1, 256, dim) * 0.02)
        
        # Token type embeddings (template vs search)
        self.template_type = nn.Parameter(torch.randn(1, 1, dim) * 0.02)
        self.search_type = nn.Parameter(torch.randn(1, 1, dim) * 0.02)
        
        # Stochastic depth rates (linearly increasing)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        
        # Build blocks: last `tmoe_blocks` use TMoE MLP
        self.blocks = nn.ModuleList()
        for i in range(depth):
            if i >= depth - tmoe_blocks:
                block = mLSTMBlockWithTMoE(
                    dim=dim, proj_factor=proj_factor,
                    qkv_proj_blocksize=qkv_proj_blocksize,
                    num_heads=num_heads, conv_kernel=conv_kernel,
                    mlp_ratio=mlp_ratio, drop_path=dpr[i],
                    num_experts=num_experts, bias=bias,
                )
            else:
                block = mLSTMBlock(
                    dim=dim, proj_factor=proj_factor,
                    qkv_proj_blocksize=qkv_proj_blocksize,
                    num_heads=num_heads, conv_kernel=conv_kernel,
                    mlp_ratio=mlp_ratio, drop_path=dpr[i], bias=bias,
                )
            self.blocks.append(block)
        
        # Final norm
        self.norm = nn.LayerNorm(dim, bias=False)
    
    def forward(
        self,
        template: torch.Tensor,
        searches: torch.Tensor,
        temporal_mod_manager=None,
    ) -> tuple:
        """
        Process template + K search frames as one mLSTM sequence.
        
        Args:
            template: (B, 3, 128, 128) template image
            searches: (B, K, 3, 256, 256) K consecutive search frames
                      OR (B, 3, 256, 256) single search frame (backward compat)
            temporal_mod_manager: optional TemporalModulationManager for FiLM
        Returns:
            template_feat: (B, 64, D) template features
            search_feats: (B, K, 256, D) per-frame search features
                          OR (B, 256, D) if single search frame input
        """
        B = template.shape[0]
        single_frame = (searches.ndim == 4)  # (B, 3, H, W) vs (B, K, 3, H, W)
        
        if single_frame:
            searches = searches.unsqueeze(1)  # (B, 1, 3, H, W)
        
        K = searches.shape[1]
        
        # Patch embed template
        t_tokens = self.patch_embed(template)  # (B, 64, D)
        t_tokens = t_tokens + self.template_pos + self.template_type
        n_template = t_tokens.shape[1]  # 64
        
        # Patch embed all search frames
        # Reshape (B, K, 3, H, W) → (B*K, 3, H, W) for batch patch embedding
        s_flat = searches.reshape(B * K, *searches.shape[2:])
        s_tokens_flat = self.patch_embed(s_flat)  # (B*K, 256, D)
        s_tokens = s_tokens_flat.reshape(B, K, -1, self.dim)  # (B, K, 256, D)
        s_tokens = s_tokens + self.search_pos.unsqueeze(1) + self.search_type
        n_search = s_tokens.shape[2]  # 256
        
        # Build full sequence: [template | search_1 | search_2 | ... | search_K]
        # The mLSTM memory carries information across this entire sequence
        s_tokens_concat = s_tokens.reshape(B, K * n_search, self.dim)  # (B, K*256, D)
        tokens = torch.cat([t_tokens, s_tokens_concat], dim=1)  # (B, 64 + K*256, D)
        
        # Process through bidirectional mLSTM blocks
        for i, block in enumerate(self.blocks):
            reverse = (i % 2 == 1)
            tokens = block(tokens, reverse=reverse)
            
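            # FiLM hook (assumption): the manager receives every block index
            # and gates internally, so modulation fires only at its configured
            # interval (film_interval, default 6); see the class docstring.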
            if temporal_mod_manager is not None:
                tokens = temporal_mod_manager.modulate(tokens, i)
        
        tokens = self.norm(tokens)
        
        if temporal_mod_manager is not None:
            temporal_mod_manager.update_temporal_context(tokens)
        
        # Split: template features + per-frame search features
        template_feat = tokens[:, :n_template]  # (B, 64, D)
        search_tokens = tokens[:, n_template:]  # (B, K*256, D)
        search_feats = search_tokens.reshape(B, K, n_search, self.dim)  # (B, K, 256, D)
        
        if single_frame:
            return template_feat, search_feats.squeeze(1)  # (B, 256, D)
        
        return template_feat, search_feats
    
    def freeze_shared_experts(self):
        """Freeze shared experts in TMoE blocks for Phase 2 training."""
        for block in self.blocks:
            if hasattr(block, 'freeze_shared_expert'):
                block.freeze_shared_expert()
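

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the training pipeline).
    # Assumes the sibling .mlstm module provides mLSTMBlock, mLSTMCell,
    # SwiGLUMLP, and StochasticDepth with the signatures used above.
    # Shapes follow the docstring: template 128x128 -> 64 tokens, each
    # search frame 256x256 -> 256 tokens.
    torch.manual_seed(0)
    model = ViLBackbone(dim=384, depth=4, tmoe_blocks=1)  # shallow for speed
    template = torch.randn(2, 3, 128, 128)
    searches = torch.randn(2, 3, 3, 256, 256)  # B=2, K=3 search frames
    with torch.no_grad():
        t_feat, s_feats = model(template, searches)
    assert t_feat.shape == (2, 64, 384)
    assert s_feats.shape == (2, 3, 256, 384)

    # Phase 2 preparation: freeze the shared experts in the TMoE blocks and
    # verify their parameters no longer require gradients.
    model.freeze_shared_experts()
    for block in model.blocks:
        if isinstance(block, mLSTMBlockWithTMoE):
            assert not any(p.requires_grad
                           for p in block.mlp.shared_expert.parameters())
    print("smoke test passed:", tuple(t_feat.shape), tuple(s_feats.shape))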