"""
mLSTM Cell and Block for Vision-LSTM (ViL) backbone.

Architecture follows the official NX-AI ViL-S implementation:
- LinearHeadwiseExpand for Q/K/V projections (block-diagonal, ~3K params each)
- Depthwise causal Conv1d on the mLSTM branch
- Gates (igate, fgate) take concatenated [q, k, v] as input
- Output gate from second half of proj_up output
- Parallel mLSTM scan with matrix memory

References:
    Beck et al., "xLSTM: Extended Long Short-Term Memory" (arXiv:2405.04517)
    Alkin et al., "Vision-LSTM: xLSTM as Generic Vision Backbone" (arXiv:2406.04303)
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, einsum


class LinearHeadwiseExpand(nn.Module):
    """Block-diagonal linear projection: each head has its own small weight matrix.

    Instead of a full Linear(in_features, in_features) with in_features^2 params,
    this uses num_heads independent (head_dim, head_dim) matrices.
    For in_features=768 (the ViL-S inner_dim), num_heads=192, head_dim=4:
        Full linear: 768*768 = 589,824 params
        Headwise:    192*4*4 = 3,072 params (192x fewer!)
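
    Example (illustrative only; checks the parameter count and shapes described above):
        >>> proj = LinearHeadwiseExpand(in_features=768, num_heads=192)
        >>> sum(p.numel() for p in proj.parameters())
        3072
        >>> proj(torch.randn(2, 16, 768)).shape
        torch.Size([2, 16, 768])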
    """
    def __init__(self, in_features: int, num_heads: int, bias: bool = False):
        super().__init__()
        assert in_features % num_heads == 0, f"{in_features} not divisible by {num_heads}"
        self.num_heads = num_heads
        self.head_dim = in_features // num_heads
        self.in_features = in_features

        # One (head_dim, head_dim) weight matrix per head, stored as a single 3-D parameter.
        self.weight = nn.Parameter(torch.empty(num_heads, self.head_dim, self.head_dim))
        self.bias = nn.Parameter(torch.zeros(in_features)) if bias else None
        self._reset_parameters()

    def _reset_parameters(self):
        # "Small init": std = sqrt(2 / (5 * head_dim)).
        nn.init.normal_(self.weight, std=math.sqrt(2.0 / (5.0 * self.head_dim)))
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Split channels into heads, apply each head's own matrix, then flatten back.
        x = rearrange(x, '... (nh d) -> ... nh d', nh=self.num_heads)
        x = einsum(x, self.weight, '... nh d, nh od d -> ... nh od')
        x = rearrange(x, '... nh od -> ... (nh od)')
        if self.bias is not None:
            x = x + self.bias
        return x


class StochasticDepth(nn.Module):
    """Drop the entire residual path with probability `drop_prob` during training."""
    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.drop_prob == 0.0:
            return x
        # Per-sample Bernoulli mask, rescaled so the expected output is unchanged.
        keep_prob = 1 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = torch.bernoulli(torch.full(shape, keep_prob, device=x.device, dtype=x.dtype))
        return x * mask / keep_prob


class mLSTMCell(nn.Module):
    """Parallel mLSTM cell with matrix memory.

    Official architecture from xLSTM/ViL:
    - proj_up: Linear(D, 2*inner_dim) → split into mLSTM branch + output gate branch
    - CausalConv1d on mLSTM branch (depthwise, k=4)
    - LinearHeadwiseExpand for Q, K, V projections
    - igate, fgate: Linear(3*inner_dim, num_heads) from concat(q, k, v)
    - Parallel scan: C_t = f_t*C_{t-1} + i_t*(v_t ⊗ k_t), h_t = C_t*q_t
    - Output: (h + skip*conv_act) * SiLU(z), then proj_down

    ViL-S config: D=384, proj_factor=2.0, inner_dim=768,
                  qkv_proj_blocksize=4, num_heads=4 (memory heads)
    Note: GroupNorm uses num_proj_heads (192) groups, matching the official
    MultiHeadLayerNorm: one group per projection head, NOT per memory head.
    Per-cell params: ~920K (vs 2.66M with full Linear Q/K/V)
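
    Example (illustrative; shape check only, using the ViL-S defaults):
        >>> cell = mLSTMCell(dim=384)
        >>> cell(torch.randn(2, 196, 384)).shape
        torch.Size([2, 196, 384])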
    """
    def __init__(
        self,
        dim: int = 384,
        proj_factor: float = 2.0,
        qkv_proj_blocksize: int = 4,
        num_heads: int = 4,
        conv_kernel: int = 4,
        bias: bool = False,
    ):
        super().__init__()
        self.dim = dim
        # Round the inner (expanded) dimension up to a multiple of 64 (768 for ViL-S).
        self.inner_dim = math.ceil(proj_factor * dim / 64) * 64
        self.num_heads = num_heads
        self.head_dim = self.inner_dim // num_heads

        # Number of projection heads for the block-diagonal Q/K/V projections (192 for ViL-S).
        num_proj_heads = self.inner_dim // qkv_proj_blocksize
        self.num_proj_heads = num_proj_heads

        # Up-projection: first half feeds the mLSTM branch, second half becomes the output gate z.
        self.proj_up = nn.Linear(dim, 2 * self.inner_dim, bias=bias)

        # Depthwise Conv1d on the mLSTM branch; the extra padded outputs are trimmed in forward
        # to keep it causal.
        self.conv1d = nn.Conv1d(
            self.inner_dim, self.inner_dim,
            kernel_size=conv_kernel,
            padding=conv_kernel - 1,
            groups=self.inner_dim,
            bias=True,
        )
        self.conv_kernel = conv_kernel

        # Block-diagonal Q/K/V projections (~3K params each, see LinearHeadwiseExpand).
        self.q_proj = LinearHeadwiseExpand(self.inner_dim, num_proj_heads, bias=bias)
        self.k_proj = LinearHeadwiseExpand(self.inner_dim, num_proj_heads, bias=bias)
        self.v_proj = LinearHeadwiseExpand(self.inner_dim, num_proj_heads, bias=bias)

        # Scalar input/forget gates per memory head, computed from concat(q, k, v).
        self.igate = nn.Linear(3 * self.inner_dim, num_heads, bias=True)
        self.fgate = nn.Linear(3 * self.inner_dim, num_heads, bias=True)

        # Multi-head norm on the mLSTM output: one group per projection head.
        self.outnorm = nn.GroupNorm(num_proj_heads, self.inner_dim, affine=True)

        # Down-projection back to the model dimension.
        self.proj_down = nn.Linear(self.inner_dim, dim, bias=bias)

        # Learnable skip around the mLSTM output, and a per-channel scale on the final output.
        self.learnable_skip = nn.Parameter(torch.ones(self.inner_dim))
        self.layerscale = nn.Parameter(torch.ones(dim))

        self._reset_gate_bias()

    def _reset_gate_bias(self):
        """Initialize the forget-gate bias high (encourages remembering); the input-gate bias starts at zero."""
        with torch.no_grad():
            nn.init.zeros_(self.igate.bias)
            nn.init.constant_(self.fgate.bias, 3.0)
    def forward(self, x: torch.Tensor, reverse: bool = False) -> torch.Tensor:
        """
        Args:
            x: (B, S, D) input sequence
            reverse: if True, process the sequence right-to-left (for bidirectional scanning)
        Returns:
            (B, S, D) output
        """
        B, S, D = x.shape

        if reverse:
            x = x.flip(1)

        # Up-project and split into the mLSTM branch and the output-gate branch z.
        up = self.proj_up(x)
        x_mlstm = up[..., :self.inner_dim]
        z = up[..., self.inner_dim:]

        # Depthwise conv with symmetric padding; keeping only the first S outputs makes it causal.
        x_conv = self.conv1d(x_mlstm.transpose(1, 2))
        x_conv = x_conv[..., :S].transpose(1, 2)
        x_conv_act = F.silu(x_conv)

        # Q and K come from the convolved branch, V from the pre-conv branch.
        q = self.q_proj(x_conv_act)
        k = self.k_proj(x_conv_act)
        v = self.v_proj(x_mlstm)

        # Scalar input/forget gate pre-activations per memory head from concat(q, k, v).
        qkv_cat = torch.cat([q, k, v], dim=-1)
        i_gate = self.igate(qkv_cat)  # (B, S, H)
        f_gate = self.fgate(qkv_cat)  # (B, S, H)
        # Exponential input gate and sigmoid forget gate, kept in log space:
        # log i_t is simply the pre-activation, log f_t = logsigmoid(f_gate).
        # Staying in log space avoids overflowing exp() before the stabilization below.
        log_i = i_gate
        log_f = F.logsigmoid(f_gate)

        # Split channels into memory heads: (B, H, S, head_dim).
        q = rearrange(q, 'b s (h d) -> b h s d', h=self.num_heads)
        k = rearrange(k, 'b s (h d) -> b h s d', h=self.num_heads)
        v = rearrange(v, 'b s (h d) -> b h s d', h=self.num_heads)

        # Cumulative log forget gates F_t = sum_{j<=t} log f_j, shape (B, H, S).
        log_f_cumsum = torch.cumsum(log_f.permute(0, 2, 1), dim=-1)

        # Decay matrix: log_D[t, s] = F_t - F_s = sum_{j=s+1..t} log f_j.
        log_D = log_f_cumsum.unsqueeze(-1) - log_f_cumsum.unsqueeze(-2)

        # Mask out non-causal entries (s > t).
        causal_mask = torch.tril(torch.ones(S, S, device=x.device, dtype=torch.bool))
        log_D = log_D.masked_fill(~causal_mask, -1e9)

        # Add the input gate of the source position: log_D[t, s] += log i_s.
        log_i_perm = log_i.permute(0, 2, 1)
        log_D = log_D + log_i_perm.unsqueeze(-2)

        # Row-wise stabilization before exponentiating (entries of log_D can be large).
        max_log_D = log_D.max(dim=-1, keepdim=True).values.clamp(min=-10)
        D = torch.exp(log_D - max_log_D)
        D = D.masked_fill(~causal_mask, 0.0)

        # Scale queries by 1/sqrt(head_dim), as in attention.
        q_scaled = q / math.sqrt(self.head_dim)

        # Gate-weighted similarity matrix: (q k^T) elementwise-multiplied by D.
        attn = torch.matmul(q_scaled, k.transpose(-1, -2)) * D

        # Normalizer: row sums lower-bounded at 1, cf. the max(|n_t^T q_t|, 1) term in xLSTM.
        normalizer = attn.sum(dim=-1, keepdim=True).clamp(min=1.0)
        attn = attn / normalizer

        # h_t = sum_s attn[t, s] * v_s; merge memory heads back to (B, S, inner_dim).
        h = torch.matmul(attn, v)
        h = rearrange(h, 'b h s d -> b s (h d)')

        # Multi-head output norm, applied per token: fold (B, S) into the batch dimension
        # so GroupNorm statistics do not mix positions along the sequence.
        h = self.outnorm(h.reshape(B * S, self.inner_dim)).reshape(B, S, self.inner_dim)

        # Learnable skip from the conv branch, then gate with SiLU(z) (the output gate).
        h_skip = h + self.learnable_skip * x_conv_act
        output = h_skip * F.silu(z)

        # Project back to the model dimension and apply the per-channel output scale.
        output = self.proj_down(output)
        output = output * self.layerscale

        if reverse:
            output = output.flip(1)

        return output


class SwiGLUMLP(nn.Module):
    """SwiGLU MLP: a gated feed-forward layer.

    SwiGLU(x) = W2 · (Swish(V·x) ⊙ W1·x)

    Note: the default mLSTMBlock below has no separate MLP/FFN and does not use this module.
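
    Example (illustrative shape check):
        >>> mlp = SwiGLUMLP(dim=384, mlp_ratio=4.0)
        >>> mlp(torch.randn(2, 196, 384)).shape
        torch.Size([2, 196, 384])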
    """
    def __init__(self, dim: int, mlp_ratio: float = 4.0, bias: bool = False, drop: float = 0.0):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)

        self.w1 = nn.Linear(dim, hidden_dim, bias=bias)
        self.w2 = nn.Linear(hidden_dim, dim, bias=bias)
        self.v = nn.Linear(dim, hidden_dim, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Swish-gated linear branch, then project back down to dim.
        return self.drop(self.w2(F.silu(self.v(x)) * self.w1(x)))


class mLSTMBlock(nn.Module):
    """Single ViL block: LayerNorm → mLSTMCell → residual.

    Following the official ViL-S architecture, there is NO separate MLP/FFN layer.
    The gated output (proj_up → split → z-gate → proj_down) inside the mLSTMCell
    already performs the role of dimension expansion + nonlinearity + projection.

    This matches ViL-S: ~0.92M params per block, 24 blocks ≈ 22M backbone.
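
    Example (illustrative; the forward and flipped scans share the same weights and shapes):
        >>> block = mLSTMBlock(dim=384)
        >>> x = torch.randn(2, 196, 384)
        >>> block(x).shape == block(x, reverse=True).shape == x.shape
        True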
    """
    def __init__(
        self,
        dim: int = 384,
        proj_factor: float = 2.0,
        qkv_proj_blocksize: int = 4,
        num_heads: int = 4,
        conv_kernel: int = 4,
        mlp_ratio: float = 4.0,  # unused: this block has no FFN (see class docstring)
        drop_path: float = 0.0,
        bias: bool = False,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, bias=False)
        self.mlstm = mLSTMCell(
            dim=dim,
            proj_factor=proj_factor,
            qkv_proj_blocksize=qkv_proj_blocksize,
            num_heads=num_heads,
            conv_kernel=conv_kernel,
            bias=bias,
        )
        self.drop_path = StochasticDepth(drop_path)

    def forward(self, x: torch.Tensor, reverse: bool = False) -> torch.Tensor:
        # Pre-norm residual update; `reverse` flips the scan direction inside the cell.
        x = x + self.drop_path(self.mlstm(self.norm1(x), reverse=reverse))
        return x
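

if __name__ == "__main__":
    # Minimal smoke test (illustrative only, not part of the ViL reference implementation):
    # build one ViL-S-sized block, check output shapes for both scan directions, and print
    # the per-block parameter count (the docstring above estimates ~0.92M).
    block = mLSTMBlock(dim=384)
    x = torch.randn(2, 196, 384)
    assert block(x).shape == x.shape
    assert block(x, reverse=True).shape == x.shape
    print(f"params per block: {sum(p.numel() for p in block.parameters()):,}")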