File size: 12,577 Bytes
523bea4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
faf4cb2
523bea4
 
faf4cb2
 
 
523bea4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
faf4cb2
523bea4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
faf4cb2
 
 
523bea4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
faf4cb2
523bea4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
faf4cb2
523bea4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
"""
mLSTM Cell and Block for Vision-LSTM (ViL) backbone.

Architecture follows the official NX-AI ViL-S implementation:
- LinearHeadwiseExpand for Q/K/V projections (block-diagonal, ~3K params each)
- Depthwise causal Conv1d on the mLSTM branch
- Gates (igate, fgate) take concatenated [q, k, v] as input
- Output gate from second half of proj_up output
- Parallel mLSTM scan with matrix memory

Reference: Beck et al., "xLSTM: Extended Long Short-Term Memory" (arXiv:2405.04517)
           Alkin et al., "Vision-LSTM: xLSTM as Generic Vision Backbone" (arXiv:2406.04303)
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, einsum


class LinearHeadwiseExpand(nn.Module):
    """Block-diagonal linear projection: each head has its own small weight matrix.

    Instead of a full Linear(inner_dim, inner_dim) with inner_dim^2 params,
    this uses num_heads independent (head_dim, head_dim) matrices.
    For inner_dim=768, num_heads=192, head_dim=4:
      Full linear: 768*768 = 589,824 params
      Headwise:    192*4*4 = 3,072 params  (192x fewer!)

    Args:
        in_features: size of the last input dimension; the output has the same size.
        num_heads: number of independent blocks. Must be positive and divide
            ``in_features`` evenly.
        bias: if True, a learnable bias of shape (in_features,) is added to the output.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide ``in_features``.
    """
    def __init__(self, in_features: int, num_heads: int, bias: bool = False):
        super().__init__()
        # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
        # and num_heads == 0 would otherwise surface as a ZeroDivisionError.
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if in_features % num_heads != 0:
            raise ValueError(f"{in_features} not divisible by {num_heads}")
        self.num_heads = num_heads
        self.head_dim = in_features // num_heads
        self.in_features = in_features

        # Weight: (num_heads, head_dim_out, head_dim_in)
        self.weight = nn.Parameter(torch.empty(num_heads, self.head_dim, self.head_dim))
        self.bias = nn.Parameter(torch.zeros(in_features)) if bias else None
        self._reset_parameters()

    def _reset_parameters(self) -> None:
        """Draw weights from N(0, 2 / (5 * head_dim)) and zero the bias, if any."""
        nn.init.normal_(self.weight, std=math.sqrt(2.0 / (5.0 * self.head_dim)))
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the per-head projection over the last dimension.

        Args:
            x: tensor of shape (..., in_features).

        Returns:
            Tensor of shape (..., in_features); output head ``h`` depends only on
            input head ``h``.
        """
        # (..., in_features) -> (..., num_heads, head_dim)
        x = x.unflatten(-1, (self.num_heads, self.head_dim))
        # Per head: out[..., h, o] = sum_d x[..., h, d] * weight[h, o, d]
        x = torch.einsum('...hd,hod->...ho', x, self.weight)
        # (..., num_heads, head_dim) -> (..., in_features)
        x = x.flatten(-2)
        if self.bias is not None:
            x = x + self.bias
        return x


class StochasticDepth(nn.Module):
    """Drop entire residual path with probability `drop_prob` during training.

    One Bernoulli draw is made per sample (first dimension) and broadcast over
    all remaining dimensions. Kept samples are rescaled by ``1 / (1 - drop_prob)``.
    In eval mode, or when ``drop_prob`` is 0, the input is returned unchanged.

    Args:
        drop_prob: probability of zeroing a sample's path; must lie in [0, 1].

    Raises:
        ValueError: if ``drop_prob`` is outside [0, 1].
    """
    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        if not 0.0 <= drop_prob <= 1.0:
            raise ValueError(f"drop_prob must be in [0, 1], got {drop_prob}")
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.drop_prob == 0.0:
            return x
        keep_prob = 1 - self.drop_prob
        if keep_prob <= 0.0:
            # Every sample is dropped. Rescaling by 1 / keep_prob would compute
            # 0 / 0 and fill the output with NaN, so return zeros directly.
            return torch.zeros_like(x)
        # One mask value per sample, broadcast across every other dimension.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = torch.bernoulli(torch.full(shape, keep_prob, device=x.device, dtype=x.dtype))
        return x * mask / keep_prob


class mLSTMCell(nn.Module):
    """Parallel mLSTM cell with matrix memory.

    Official architecture from xLSTM/ViL:
    - proj_up: Linear(D, 2*inner_dim) → split into mLSTM branch + output gate branch
    - CausalConv1d on mLSTM branch (depthwise, k=4)
    - LinearHeadwiseExpand for Q, K, V projections
    - igate, fgate: Linear(3*inner_dim, num_heads) from concat(q,k,v)
    - Parallel scan: C_t = f_t*C_{t-1} + i_t*(v_t ⊗ k_t), h_t = C_t*q_t
    - Output: (h + skip*conv_act) * SiLU(z), then proj_down

    ViL-S config: D=384, proj_factor=2.0, inner_dim=768,
                  qkv_proj_blocksize=4, num_heads=4 (memory heads)
    Note: GroupNorm uses num_proj_heads (192) groups — one group per projection
          head, NOT per memory head. It is applied per token, so statistics never
          mix sequence positions.
    Per-cell params: ~920K (vs 2.66M with full Linear Q/K/V)

    Args:
        dim: model (input/output) feature size D.
        proj_factor: expansion factor; inner_dim is ``proj_factor * dim`` rounded
            up to a multiple of 64.
        qkv_proj_blocksize: block size of the block-diagonal Q/K/V projections.
        num_heads: number of memory heads; must divide inner_dim.
        conv_kernel: kernel size of the depthwise causal convolution.
        bias: whether proj_up, proj_down and the Q/K/V projections carry a bias.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide inner_dim,
            or if inner_dim is smaller than ``dim``.
    """
    def __init__(
        self,
        dim: int = 384,
        proj_factor: float = 2.0,
        qkv_proj_blocksize: int = 4,
        num_heads: int = 4,
        conv_kernel: int = 4,
        bias: bool = False,
    ):
        super().__init__()
        self.dim = dim
        # inner_dim rounded up to multiple of 64
        self.inner_dim = math.ceil(proj_factor * dim / 64) * 64
        # Fail at construction for configurations that could only crash inside forward().
        if num_heads <= 0 or self.inner_dim % num_heads != 0:
            raise ValueError(
                f"inner_dim ({self.inner_dim}) must be divisible by a positive num_heads ({num_heads})"
            )
        if self.inner_dim < dim:
            # forward() slices layerscale[:dim], which needs at least `dim` entries.
            raise ValueError(
                f"inner_dim ({self.inner_dim}) must be >= dim ({dim}); increase proj_factor"
            )
        self.num_heads = num_heads
        self.head_dim = self.inner_dim // num_heads  # 768/4 = 192

        # Number of projection heads for Q/K/V (block-diagonal)
        num_proj_heads = self.inner_dim // qkv_proj_blocksize
        self.num_proj_heads = num_proj_heads

        # Up-projection: D -> 2*inner_dim (mLSTM branch + output gate branch)
        self.proj_up = nn.Linear(dim, 2 * self.inner_dim, bias=bias)

        # Depthwise causal conv1d on mLSTM branch.
        # nn.Conv1d pads both ends; forward() keeps only the first S outputs,
        # which is what makes the convolution causal.
        self.conv1d = nn.Conv1d(
            self.inner_dim, self.inner_dim,
            kernel_size=conv_kernel,
            padding=conv_kernel - 1,
            groups=self.inner_dim,
            bias=True,
        )
        self.conv_kernel = conv_kernel

        # Block-diagonal Q/K/V projections
        self.q_proj = LinearHeadwiseExpand(self.inner_dim, num_proj_heads, bias=bias)
        self.k_proj = LinearHeadwiseExpand(self.inner_dim, num_proj_heads, bias=bias)
        self.v_proj = LinearHeadwiseExpand(self.inner_dim, num_proj_heads, bias=bias)

        # Gates: take concat(q, k, v) as input
        self.igate = nn.Linear(3 * self.inner_dim, num_heads, bias=True)
        self.fgate = nn.Linear(3 * self.inner_dim, num_heads, bias=True)

        # Output normalization: per-projection-head group norm (192 groups for ViL-S)
        self.outnorm = nn.GroupNorm(num_proj_heads, self.inner_dim, affine=True)

        # Down-projection: inner_dim -> D
        self.proj_down = nn.Linear(self.inner_dim, dim, bias=bias)

        # Learnable skip connection and layer scale.
        # layerscale is allocated with inner_dim entries but only the first `dim`
        # are used in forward(); the shape is kept for checkpoint compatibility.
        self.learnable_skip = nn.Parameter(torch.ones(self.inner_dim))
        self.layerscale = nn.Parameter(torch.ones(self.inner_dim))

        self._reset_gate_bias()

    def _reset_gate_bias(self):
        """Initialize forget gate bias high (encourages remembering) and input gate bias to zero."""
        with torch.no_grad():
            nn.init.zeros_(self.igate.bias)
            # Forget gate bias: initialize to encourage remembering
            nn.init.constant_(self.fgate.bias, 3.0)

    def forward(self, x: torch.Tensor, reverse: bool = False) -> torch.Tensor:
        """
        Args:
            x: (B, S, D) input sequence
            reverse: if True, process sequence right-to-left (for bidirectional scanning)
        Returns:
            (B, S, D) output
        """
        B, S, _ = x.shape

        if reverse:
            x = x.flip(1)

        # 1. Up-project to 2*inner_dim
        up = self.proj_up(x)  # (B, S, 2*inner)
        x_mlstm = up[..., :self.inner_dim]   # mLSTM branch
        z = up[..., self.inner_dim:]          # output gate branch

        # 2. Causal conv1d on mLSTM branch
        x_conv = self.conv1d(x_mlstm.transpose(1, 2))  # (B, inner, S+pad)
        x_conv = x_conv[..., :S].transpose(1, 2)       # causal: keep first S
        x_conv_act = F.silu(x_conv)

        # 3. Q/K/V projections (block-diagonal, very lightweight)
        q = self.q_proj(x_conv_act)   # (B, S, inner)
        k = self.k_proj(x_conv_act)   # (B, S, inner)
        v = self.v_proj(x_mlstm)      # V from pre-conv branch

        # 4. Gates from concat(q, k, v)
        qkv_cat = torch.cat([q, k, v], dim=-1)  # (B, S, 3*inner)
        i_gate = self.igate(qkv_cat)    # (B, S, num_heads)
        f_gate = self.fgate(qkv_cat)    # (B, S, num_heads)

        # The input gate is exponential, so its log is the pre-activation itself.
        # Using it directly avoids exp() overflowing to inf for large activations
        # (the floor matches the previous clamp(min=1e-6) on exp(i_gate)).
        log_i = i_gate.clamp(min=math.log(1e-6))                        # (B, S, H)
        # Forget gate is a sigmoid; clamp before log to avoid log(0).
        log_f = torch.log(torch.sigmoid(f_gate).clamp(min=1e-6))        # (B, S, H)

        # 5. Reshape Q/K/V for multi-head matrix memory
        q = rearrange(q, 'b s (h d) -> b h s d', h=self.num_heads)  # (B, H, S, D_h)
        k = rearrange(k, 'b s (h d) -> b h s d', h=self.num_heads)
        v = rearrange(v, 'b s (h d) -> b h s d', h=self.num_heads)

        # 6. Parallel mLSTM computation (log-space stabilized)
        # Cumulative sum of log forget gates for parallel scan
        log_f_cumsum = torch.cumsum(log_f.permute(0, 2, 1), dim=-1)  # (B, H, S)

        # Entry [t, s] = cumsum[t] - cumsum[s] = log of the product of forget
        # gates from step s+1 to step t.
        log_decay = log_f_cumsum.unsqueeze(-1) - log_f_cumsum.unsqueeze(-2)  # (B, H, S, S)

        # Causal mask: only attend to past positions
        causal_mask = torch.tril(torch.ones(S, S, device=x.device, dtype=torch.bool))
        log_decay = log_decay.masked_fill(~causal_mask, -1e9)

        # Add input gate contribution of the source step (broadcast over queries)
        log_decay = log_decay + log_i.permute(0, 2, 1).unsqueeze(-2)

        # Stabilize with max trick
        max_log_decay = log_decay.max(dim=-1, keepdim=True).values.clamp(min=-10)
        decay = torch.exp(log_decay - max_log_decay)  # (B, H, S, S)
        decay = decay.masked_fill(~causal_mask, 0.0)

        # Scale queries
        q_scaled = q / math.sqrt(self.head_dim)

        # Attention scores: (q @ k^T) * decay
        attn = torch.matmul(q_scaled, k.transpose(-1, -2)) * decay  # (B, H, S, S)

        # Normalizer: max(|sum|, 1). The absolute value matters because q·k can be
        # negative; without it, large negative sums would never be normalized.
        # NOTE(review): the floor of 1.0 is applied in max-stabilized space; the
        # reference implementation appears to use exp(-max_log_decay) — confirm intent.
        normalizer = attn.sum(dim=-1, keepdim=True).abs().clamp(min=1.0)
        attn = attn / normalizer

        # Output
        h = torch.matmul(attn, v)  # (B, H, S, D_h)

        # 7. Output norm, applied per token: fold (B, S) into the batch axis so
        # GroupNorm statistics are computed over each group's channels only.
        # Feeding (B, C, S) would pool statistics across the whole sequence,
        # leaking future tokens and making the output depend on S.
        h = rearrange(h, 'b h s d -> (b s) (h d)')
        h = self.outnorm(h).reshape(B, S, self.inner_dim)

        # 8. Skip connection + output gate
        h_skip = h + self.learnable_skip * x_conv_act
        output = h_skip * F.silu(z)  # output gate: SiLU (not sigmoid) per official ViL

        # 9. Down-project + layer scale
        output = self.proj_down(output)
        output = output * self.layerscale[:self.dim]  # Note: layerscale is inner_dim sized, we need dim

        if reverse:
            output = output.flip(1)

        return output


class SwiGLUMLP(nn.Module):
    """Gated feed-forward layer in the SwiGLU style, as used in ViL blocks.

    Computes ``drop(w2(SiLU(v(x)) * w1(x)))``: the input is projected twice to a
    hidden width of ``int(dim * mlp_ratio)``, one projection is passed through
    SiLU and used to gate the other, and the product is projected back to ``dim``.

    Args:
        dim: input and output feature size.
        mlp_ratio: hidden width as a multiple of ``dim``.
        bias: whether the three linear layers carry a bias.
        drop: dropout probability applied to the final output.
    """
    def __init__(self, dim: int, mlp_ratio: float = 4.0, bias: bool = False, drop: float = 0.0):
        super().__init__()
        width = int(dim * mlp_ratio)
        # Registration order (w1, w2, v, drop) is kept stable for state_dict layout
        # and seeded initialization.
        self.w1 = nn.Linear(dim, width, bias=bias)   # value branch: dim -> hidden
        self.w2 = nn.Linear(width, dim, bias=bias)   # output projection: hidden -> dim
        self.v = nn.Linear(dim, width, bias=bias)    # gate branch: dim -> hidden
        self.drop = nn.Dropout(drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP over the last dimension of ``x``."""
        gate = F.silu(self.v(x))
        value = self.w1(x)
        projected = self.w2(gate * value)
        return self.drop(projected)


class mLSTMBlock(nn.Module):
    """One ViL block: pre-norm mLSTM cell wrapped in a residual connection.

    The block computes ``x + drop_path(mlstm(norm1(x)))``. There is no separate
    MLP/FFN sub-layer: per the original design notes, the up-projection, gating
    and down-projection inside the mLSTMCell already provide the expansion,
    nonlinearity and projection, giving roughly 0.92M parameters per block for
    the ViL-S configuration (24 blocks ≈ 22M backbone).

    Args:
        dim: feature size of the block input and output.
        proj_factor: expansion factor forwarded to the mLSTM cell.
        qkv_proj_blocksize: Q/K/V block size forwarded to the mLSTM cell.
        num_heads: number of memory heads forwarded to the mLSTM cell.
        conv_kernel: causal convolution kernel size forwarded to the mLSTM cell.
        mlp_ratio: accepted for API compatibility only; not used by this block.
        drop_path: stochastic-depth probability for the residual branch.
        bias: bias flag forwarded to the mLSTM cell.
    """
    def __init__(
        self,
        dim: int = 384,
        proj_factor: float = 2.0,
        qkv_proj_blocksize: int = 4,
        num_heads: int = 4,
        conv_kernel: int = 4,
        mlp_ratio: float = 4.0,  # kept for API compat but unused in standard blocks
        drop_path: float = 0.0,
        bias: bool = False,
    ):
        super().__init__()
        cell_kwargs = dict(
            dim=dim,
            proj_factor=proj_factor,
            qkv_proj_blocksize=qkv_proj_blocksize,
            num_heads=num_heads,
            conv_kernel=conv_kernel,
            bias=bias,
        )
        self.norm1 = nn.LayerNorm(dim, bias=False)
        self.mlstm = mLSTMCell(**cell_kwargs)
        self.drop_path = StochasticDepth(drop_path)

    def forward(self, x: torch.Tensor, reverse: bool = False) -> torch.Tensor:
        """Run the block on ``x``; ``reverse`` is passed through to the mLSTM cell."""
        normed = self.norm1(x)
        branch = self.mlstm(normed, reverse=reverse)
        return x + self.drop_path(branch)