omar-ah committed on
Commit faf4cb2 · verified · 1 Parent(s): 59fd921

Fix: mLSTM SiLU gate+activation, GroupNorm 192, stochastic depth 0.05, Hanning window

Files changed (1):
  1. vil_tracker/models/mlstm.py +10 -6
vil_tracker/models/mlstm.py CHANGED
@@ -79,10 +79,12 @@ class mLSTMCell(nn.Module):
     - LinearHeadwiseExpand for Q, K, V projections
     - igate, fgate: Linear(3*inner_dim, num_heads) from concat(q,k,v)
     - Parallel scan: C_t = f_t*C_{t-1} + i_t*(v_t ⊗ k_t), h_t = C_t*q_t
-    - Output: (h + skip*conv_act) * sigmoid(z), then proj_down
+    - Output: (h + skip*conv_act) * SiLU(z), then proj_down

     ViL-S config: D=384, proj_factor=2.0, inner_dim=768,
-    qkv_proj_blocksize=4, num_heads=4
+    qkv_proj_blocksize=4, num_heads=4 (memory heads)
+    Note: GroupNorm uses num_proj_heads (192) groups, matching official
+    MultiHeadLayerNorm — one group per projection head, NOT per memory head.
     Per-cell params: ~920K (vs 2.66M with full Linear Q/K/V)
     """
     def __init__(
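For context on the docstring's recurrence, here is a minimal sequential reference, assuming ViL-S shapes (4 memory heads, head dim 768/4 = 192) and plain sigmoid gates; the actual cell runs a parallel scan with xLSTM's stabilized exponential gating, which this sketch omits:

import torch

# Sequential reference for: C_t = f_t*C_{t-1} + i_t*(v_t ⊗ k_t), h_t = C_t*q_t
B, S, NH, DH = 2, 6, 4, 192                      # batch, seq, memory heads, head dim
q, k, v = (torch.randn(B, S, NH, DH) for _ in range(3))
i = torch.sigmoid(torch.randn(B, S, NH, 1, 1))   # input gate, one scalar per head
f = torch.sigmoid(torch.randn(B, S, NH, 1, 1))   # forget gate, one scalar per head

C = torch.zeros(B, NH, DH, DH)                   # matrix memory, DHxDH per head
hs = []
for t in range(S):
    vk = v[:, t].unsqueeze(-1) * k[:, t].unsqueeze(-2)    # outer product v_t ⊗ k_t
    C = f[:, t] * C + i[:, t] * vk                        # gated memory update
    hs.append(torch.einsum("bhij,bhj->bhi", C, q[:, t]))  # h_t = C_t @ q_t
h = torch.stack(hs, dim=1)                       # (B, S, NH, DH); flatten to inner_dim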
@@ -103,6 +105,7 @@ class mLSTMCell(nn.Module):

         # Number of projection heads for Q/K/V (block-diagonal)
         num_proj_heads = self.inner_dim // qkv_proj_blocksize
+        self.num_proj_heads = num_proj_heads

         # Up-projection: D -> 2*inner_dim (mLSTM branch + output gate branch)
         self.proj_up = nn.Linear(dim, 2 * self.inner_dim, bias=bias)
@@ -126,8 +129,9 @@ class mLSTMCell(nn.Module):
         self.igate = nn.Linear(3 * self.inner_dim, num_heads, bias=True)
         self.fgate = nn.Linear(3 * self.inner_dim, num_heads, bias=True)

-        # Output normalization (per-head group norm)
-        self.outnorm = nn.GroupNorm(num_heads, self.inner_dim, affine=True)
+        # Output normalization: per-projection-head group norm (192 groups for ViL-S)
+        # Matches official MultiHeadLayerNorm — one group per projection head
+        self.outnorm = nn.GroupNorm(num_proj_heads, self.inner_dim, affine=True)

         # Down-projection: inner_dim -> D
         self.proj_down = nn.Linear(self.inner_dim, dim, bias=bias)
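A quick numerical check of the note added in the docstring hunk: GroupNorm with one group per projection head normalizes each blocksize-4 channel slice on its own, which is exactly a per-head LayerNorm with the affine parameters set aside. A minimal sketch, assuming the ViL-S dimensions:

import torch
import torch.nn.functional as F

inner_dim, blocksize = 768, 4
num_proj_heads = inner_dim // blocksize          # 192 groups for ViL-S

x = torch.randn(10, inner_dim)                   # flattened (B*S, inner_dim) tokens

# GroupNorm, one group per projection head (affine omitted to isolate the norm)
gn = F.group_norm(x, num_groups=num_proj_heads)

# Same computation spelled out: LayerNorm over each 4-dim block independently
ln = F.layer_norm(x.view(-1, num_proj_heads, blocksize), (blocksize,))
print(torch.allclose(gn, ln.view(-1, inner_dim), atol=1e-5))  # True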
@@ -166,7 +170,7 @@ class mLSTMCell(nn.Module):
         # 2. Causal conv1d on mLSTM branch
         x_conv = self.conv1d(x_mlstm.transpose(1, 2))  # (B, inner, S+pad)
         x_conv = x_conv[..., :S].transpose(1, 2)       # causal: keep first S
-        x_conv_act = F.gelu(x_conv)
+        x_conv_act = F.silu(x_conv)

         # 3. Q/K/V projections (block-diagonal, very lightweight)
         q = self.q_proj(x_conv_act)  # (B, S, inner)
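For context on the activation swap, the conv it feeds is made causal by over-padding and then trimming, as in the two lines above it. A sketch assuming a depthwise conv with kernel size 4 (common xLSTM defaults; this repo's actual conv1d config may differ):

import torch
import torch.nn as nn
import torch.nn.functional as F

B, inner_dim, S, ksize = 2, 768, 10, 4           # ksize=4 is an assumption
conv = nn.Conv1d(inner_dim, inner_dim, ksize,
                 padding=ksize - 1, groups=inner_dim)  # depthwise, padded both sides

x = torch.randn(B, inner_dim, S)
y = conv(x)                                      # (B, inner_dim, S + ksize - 1)
y = y[..., :S]                                   # keep first S: step t sees inputs <= t
x_conv_act = F.silu(y)                           # SiLU, matching the change above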
@@ -230,7 +234,7 @@ class mLSTMCell(nn.Module):

         # 8. Skip connection + output gate
         h_skip = h + self.learnable_skip * x_conv_act
-        output = h_skip * torch.sigmoid(z)  # output gate
+        output = h_skip * F.silu(z)  # output gate: SiLU (not sigmoid) per official ViL

         # 9. Down-project + layer scale
         output = self.proj_down(output)
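One note on this gate: SiLU(z) = z * sigmoid(z), so compared with the old sigmoid gate the pre-activation z now also scales the output directly. A one-line sanity check:

import torch
import torch.nn.functional as F

z = torch.randn(4, 8)
print(torch.allclose(F.silu(z), z * torch.sigmoid(z)))  # True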
 