Fix: mLSTM SiLU gate+activation, GroupNorm 192, stochastic depth 0.05, Hanning window
vil_tracker/models/mlstm.py  (+10, -6)
@@ -79,10 +79,12 @@ class mLSTMCell(nn.Module):
     - LinearHeadwiseExpand for Q, K, V projections
     - igate, fgate: Linear(3*inner_dim, num_heads) from concat(q,k,v)
     - Parallel scan: C_t = f_t*C_{t-1} + i_t*(v_t ⊗ k_t), h_t = C_t*q_t
-    - Output: (h + skip*conv_act) *
+    - Output: (h + skip*conv_act) * SiLU(z), then proj_down
 
     ViL-S config: D=384, proj_factor=2.0, inner_dim=768,
-        qkv_proj_blocksize=4, num_heads=4
+        qkv_proj_blocksize=4, num_heads=4 (memory heads)
+    Note: GroupNorm uses num_proj_heads (192) groups, matching official
+        MultiHeadLayerNorm — one group per projection head, NOT per memory head.
     Per-cell params: ~920K (vs 2.66M with full Linear Q/K/V)
     """
     def __init__(
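For reference, a minimal sequential sketch of the scan named in the docstring above (C_t = f_t*C_{t-1} + i_t*(v_t ⊗ k_t), h_t = C_t*q_t). This is not the repository's parallel-scan code: the function name, the (B, S, NH, DH) head layout, and the exp/sigmoid gate activations are assumptions, and the stabilizer and normalizer states of the full mLSTM are omitted for brevity.

import torch

def mlstm_recurrence_reference(q, k, v, i_logits, f_logits):
    """q, k, v: (B, S, NH, DH); i_logits, f_logits: (B, S, NH) gate pre-activations."""
    B, S, NH, DH = q.shape
    C = q.new_zeros(B, NH, DH, DH)                    # matrix memory C_{t-1}, one per head
    i = torch.exp(i_logits)                           # input gate (unstabilized exp, assumed)
    f = torch.sigmoid(f_logits)                       # forget gate (sigmoid, assumed)
    outs = []
    for t in range(S):
        kv = torch.einsum("bhd,bhe->bhde", v[:, t], k[:, t])    # v_t outer k_t
        C = f[:, t, :, None, None] * C + i[:, t, :, None, None] * kv
        outs.append(torch.einsum("bhde,bhe->bhd", C, q[:, t]))  # h_t = C_t * q_t
    return torch.stack(outs, dim=1)                   # (B, S, NH, DH)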
@@ -103,6 +105,7 @@ class mLSTMCell(nn.Module):
 
         # Number of projection heads for Q/K/V (block-diagonal)
         num_proj_heads = self.inner_dim // qkv_proj_blocksize
+        self.num_proj_heads = num_proj_heads
 
         # Up-projection: D -> 2*inner_dim (mLSTM branch + output gate branch)
         self.proj_up = nn.Linear(dim, 2 * self.inner_dim, bias=bias)
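A quick sanity check of the docstring's parameter claim, using the ViL-S sizes above. Biases, the conv1d, norms, and layer scale are ignored, and the 192 blocks of 4x4 weights per Q/K/V projection are an assumption about LinearHeadwiseExpand's block-diagonal layout.

# Rough per-cell parameter count for ViL-S (D=384, inner_dim=768, blocksize=4, 4 gate heads).
D, inner, blocksize, num_heads = 384, 768, 4, 4
num_proj_heads = inner // blocksize                      # 192

proj_up       = D * 2 * inner                            # 589,824
proj_down     = inner * D                                # 294,912
qkv_blockdiag = 3 * num_proj_heads * blocksize ** 2      # 9,216 (block-diagonal Q/K/V)
gates         = 2 * (3 * inner) * num_heads              # 18,432 (igate + fgate)
print(proj_up + proj_down + qkv_blockdiag + gates)       # 912,384  -> "~920K"

qkv_full = 3 * inner * inner                             # 1,769,472 with full nn.Linear Q/K/V
print(proj_up + proj_down + qkv_full + gates)            # 2,672,640 -> "~2.66M"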
@@ -126,8 +129,9 @@ class mLSTMCell(nn.Module):
         self.igate = nn.Linear(3 * self.inner_dim, num_heads, bias=True)
         self.fgate = nn.Linear(3 * self.inner_dim, num_heads, bias=True)
 
-        # Output normalization
-
+        # Output normalization: per-projection-head group norm (192 groups for ViL-S)
+        # Matches official MultiHeadLayerNorm — one group per projection head
+        self.outnorm = nn.GroupNorm(num_proj_heads, self.inner_dim, affine=True)
 
         # Down-projection: inner_dim -> D
         self.proj_down = nn.Linear(self.inner_dim, dim, bias=bias)
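A standalone sketch (not the repository's forward code) of why GroupNorm with num_proj_heads groups behaves like the official MultiHeadLayerNorm: each 4-channel projection-head block is normalized on its own, with a shared per-channel affine. Applying the norm per token via a (B*S, C) reshape is an assumption about how outnorm is used.

import torch
import torch.nn as nn
import torch.nn.functional as F

B, S, inner_dim, num_proj_heads = 2, 16, 768, 192
outnorm = nn.GroupNorm(num_proj_heads, inner_dim, affine=True)

h = torch.randn(B, S, inner_dim)
# Per-token application: flatten (B, S) so GroupNorm sees (N, C) with no spatial dims.
h_gn = outnorm(h.reshape(B * S, inner_dim)).reshape(B, S, inner_dim)

# The same computation written as a multi-head LayerNorm: normalize each
# 4-channel projection head independently, then apply the per-channel affine.
head_dim = inner_dim // num_proj_heads                   # 4
h_heads = h.reshape(B * S, num_proj_heads, head_dim)
h_ln = F.layer_norm(h_heads, (head_dim,)).reshape(B * S, inner_dim)
h_ln = (h_ln * outnorm.weight + outnorm.bias).reshape(B, S, inner_dim)

assert torch.allclose(h_gn, h_ln, atol=1e-5)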
@@ -166,7 +170,7 @@ class mLSTMCell(nn.Module):
         # 2. Causal conv1d on mLSTM branch
         x_conv = self.conv1d(x_mlstm.transpose(1, 2))  # (B, inner, S+pad)
         x_conv = x_conv[..., :S].transpose(1, 2)  # causal: keep first S
-        x_conv_act = F.
+        x_conv_act = F.silu(x_conv)
 
         # 3. Q/K/V projections (block-diagonal, very lightweight)
         q = self.q_proj(x_conv_act)  # (B, S, inner)
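A small standalone check of the causal-slice comment in step 2. The depthwise layout and kernel size of conv1d are assumptions; what matters is that with symmetric padding of kernel_size - 1, keeping only the first S outputs means position t never sees inputs beyond t, and the fix applies SiLU to that trimmed result.

import torch
import torch.nn as nn
import torch.nn.functional as F

B, S, inner_dim, kernel_size = 1, 16, 768, 4
conv1d = nn.Conv1d(inner_dim, inner_dim, kernel_size,
                   padding=kernel_size - 1, groups=inner_dim)  # depthwise + padding: assumptions

x_mlstm = torch.randn(B, S, inner_dim)
x_conv = conv1d(x_mlstm.transpose(1, 2))       # (B, inner, S + kernel_size - 1)
x_conv = x_conv[..., :S].transpose(1, 2)       # keep first S -> output t depends on inputs <= t
x_conv_act = F.silu(x_conv)                    # conv activation per this fix

# Causality check: perturbing the last input step must not change earlier outputs.
x_pert = x_mlstm.clone()
x_pert[:, -1] += 1.0
y_pert = F.silu(conv1d(x_pert.transpose(1, 2))[..., :S].transpose(1, 2))
assert torch.allclose(x_conv_act[:, :-1], y_pert[:, :-1])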
@@ -230,7 +234,7 @@ class mLSTMCell(nn.Module):
 
         # 8. Skip connection + output gate
         h_skip = h + self.learnable_skip * x_conv_act
-        output = h_skip *
+        output = h_skip * F.silu(z)  # output gate: SiLU (not sigmoid) per official ViL
 
         # 9. Down-project + layer scale
         output = self.proj_down(output)
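Finally, a self-contained sketch of how steps 8 and 9 compose after this change, with dummy tensors; the per-channel shape of learnable_skip and the bias-free proj_down are assumptions, the rest mirrors the diff.

import torch
import torch.nn as nn
import torch.nn.functional as F

B, S, dim, inner_dim = 2, 16, 384, 768
h          = torch.randn(B, S, inner_dim)               # normalized scan output (after outnorm)
x_conv_act = torch.randn(B, S, inner_dim)               # SiLU-activated conv branch from step 2
z          = torch.randn(B, S, inner_dim)               # output-gate branch from proj_up
learnable_skip = nn.Parameter(torch.ones(inner_dim))    # per-channel skip scale (assumed shape)
proj_down      = nn.Linear(inner_dim, dim, bias=False)

h_skip = h + learnable_skip * x_conv_act                # step 8: learnable skip from conv branch
output = h_skip * F.silu(z)                             # step 8: SiLU output gate (this fix)
output = proj_down(output)                              # step 9: inner_dim -> D, i.e. (B, S, 384)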