| """ |
| Mamba-2 SSD (State Space Duality) — Linear-time attention replacement. |
| |
| From: "Transformers are SSMs: Generalized Models and Efficient Algorithms |
| Through Structured State Space Duality" (Dao & Gu, 2024) |
| |
| Key insight: SSMs and linear attention are the SAME computation. |
| Mamba-2's SSD can be computed in two modes: |
| 1. Linear recurrence mode (like Mamba-1): O(N) time, O(N) memory |
| 2. Matrix multiply mode (like attention): O(N²) for short sequences |
| |
| The scalar-A formulation enables chunk-scan parallelism: split sequence |
| into chunks, compute SSM within each chunk via matmul, then combine |
| with parallel associative scan across chunks. |
| |
| For our lightweight image generator, we implement the core SSD algorithm |
| in pure PyTorch without needing the mamba-ssm CUDA kernels. This makes |
| it portable to any device (CPU, GPU, mobile) and compatible with |
| ONNX/CoreML export. |
| |
| Reference implementation: tommyip/mamba2-minimal |
| Reference paper: arXiv:2405.21060 |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
|
|
|
|
def segsum(x):
    """Numerically stable segment sum (from mamba2-minimal).

    Returns a [..., T, T] tensor where entry (i, j) is sum(x[j+1..i]) for
    j <= i and -inf above the diagonal (future positions are masked out).
    """
    T = x.size(-1)
    csum = x.cumsum(dim=-1)
    # Pairwise difference of cumulative sums gives every segment sum at once.
    diff = csum[..., :, None] - csum[..., None, :]
    keep = torch.ones(T, T, device=x.device, dtype=torch.bool).tril()
    return diff.masked_fill(~keep, -torch.inf)
|
|
|
|
class Mamba2SSD(nn.Module):
    """
    Mamba-2 SSD (State Space Duality) module.

    Implements the scalar-A selective SSM in pure PyTorch — no CUDA
    kernels needed, so it stays portable (CPU/GPU/mobile, ONNX/CoreML).

    The SSM is defined as:
        h_t = A_t * h_{t-1} + B_t * x_t   (state update)
        y_t = C_t^T * h_t                 (output)

    With scalar, input-dependent A the recurrence is associative, so it is
    computed here with a parallel prefix scan (see `_parallel_scan`).

    Args:
        dim: Input/output dimension.
        d_state: State dimension (default 16, as in the Mamba paper).
        d_conv: Conv1d kernel size for local preprocessing.
        expand: Expansion factor for the inner dimension.
        chunk_size: Size for chunk-scan parallelization (stored for future
            chunked variants; not used by the current scan path).
    """

    def __init__(self, dim, d_state=16, d_conv=4, expand=2, chunk_size=64):
        super().__init__()
        self.dim = dim
        self.d_state = d_state
        self.chunk_size = chunk_size

        inner_dim = dim * expand

        # Joint projection to (x, z): x feeds the SSM path, z gates the output.
        self.in_proj = nn.Linear(dim, inner_dim * 2)

        # Depthwise causal conv; left-pads by d_conv-1 and the extra right
        # side is trimmed in `_process_sequence` to keep length L.
        self.conv1d = nn.Conv1d(
            inner_dim, inner_dim,
            kernel_size=d_conv, padding=d_conv - 1,
            groups=inner_dim, bias=False,
        )

        # Input-dependent SSM parameters: dt (1 scalar) + B and C (d_state each).
        self.x_proj = nn.Linear(inner_dim, d_state * 2 + 1)

        dt_min = 0.001
        dt_max = 0.1
        self.dt_bias = nn.Parameter(torch.empty(inner_dim))
        nn.init.uniform_(self.dt_bias, dt_min, dt_max)

        # A is kept in log space so A = -exp(A_log) is always negative,
        # which keeps the discretized transition exp(dt * A) in (0, 1).
        A = torch.empty(inner_dim, dtype=torch.float32).uniform_(1, 16)
        self.A_log = nn.Parameter(torch.log(A))

        # D: learned per-channel skip connection.
        self.D = nn.Parameter(torch.ones(inner_dim))

        self.out_proj = nn.Linear(inner_dim, dim)

        self.norm = nn.LayerNorm(inner_dim)

    def _selective_scan(self, u, delta, A, B, C, D):
        """
        Selective scan: the core SSM recurrence.

        Args:
            u: input [B, L, inner_dim]
            delta: timestep [B, L, inner_dim], positive (softplus'd)
            A: state matrix parameter [inner_dim], negative
            B: input projection [B, L, d_state]
            C: output projection [B, L, d_state]
            D: skip connection [inner_dim]

        Returns:
            y: output [B, L, inner_dim]
        """
        # Discretize. deltaA in (0, 1) since A < 0 and delta > 0.
        deltaA = torch.exp(delta * A.view(1, 1, -1))                       # [B, L, D_inner]

        # FIX: B must be broadcast across the channel dim with an explicit
        # unsqueeze(2); the previous `delta.unsqueeze(-1) * B` aligned
        # [B,L,D,1] against [B,L,S] and raised a broadcast error (the
        # reference computes 'b l d, b l n, b l d -> b l d n').
        deltaB_u = delta.unsqueeze(-1) * B.unsqueeze(2) * u.unsqueeze(-1)  # [B, L, D_inner, d_state]

        # Run the linear recurrence h_t = deltaA_t * h_{t-1} + deltaB_u_t
        # and contract with C in one pass.
        y = self._parallel_scan(deltaA, deltaB_u, C)

        # Per-channel skip connection.
        y = y + u * D.view(1, 1, -1)

        return y

    def _parallel_scan(self, A, Bu, C):
        """
        Parallel inclusive scan for h_t = A_t * h_{t-1} + Bu_t.

        The recurrence is associative under the combine
            (a1, b1) ∘ (a2, b2) = (a1 * a2, b1 * a2 + b2)
        so it is evaluated with the Hillis-Steele doubling scheme: at step d,
        every position i absorbs the partial result at i - 2^d. This does
        O(L log L) work in O(log L) sequential steps, is fully out-of-place
        (autograd-safe, never mutates the caller's tensors), and needs no
        power-of-two padding.

        NOTE: this replaces a Blelloch-style implementation whose down-sweep
        re-applied the up-sweep combine at the same indices, double-counting
        prefixes (and mutating grad-tracked tensors in place).

        Args:
            A: [B, L, D_inner] — scalar transition values (already discretized)
            Bu: [B, L, D_inner, d_state] — discretized B * u
            C: [B, L, d_state] — output matrix

        Returns:
            y: [B, L, D_inner]
        """
        Bsz, L, D_inner = A.shape

        # Broadcast A against the state dimension of Bu.
        A = A.unsqueeze(-1)  # [B, L, D_inner, 1]

        shift = 1
        while shift < L:
            # Positions i < shift combine with the identity element (a=1, b=0),
            # supplied by padding the shifted tensors at the front.
            A_prev = F.pad(A[:, :-shift], (0, 0, 0, 0, shift, 0), value=1.0)
            Bu_prev = F.pad(Bu[:, :-shift], (0, 0, 0, 0, shift, 0), value=0.0)
            # x_i <- x_{i-shift} ∘ x_i
            Bu = A * Bu_prev + Bu
            A = A_prev * A
            shift *= 2

        # Bu now holds h_t; contract the state dimension with C_t.
        y = (Bu * C.unsqueeze(2)).sum(dim=-1)

        return y

    def forward(self, x):
        """
        Args:
            x: [B, L, dim] sequence or [B, C, H, W] image (C must equal dim)

        Returns:
            output: same shape as the input
        """
        is_2d = x.dim() == 4

        if is_2d:
            # Flatten the spatial grid to a row-major token sequence.
            B, C, H, W = x.shape
            x = x.flatten(2).transpose(1, 2)  # [B, H*W, C]

        output = self._process_sequence(x)

        if is_2d:
            output = output.transpose(1, 2).reshape(B, C, H, W)

        return output

    def _process_sequence(self, x):
        """Run one 1D sequence [B, L, dim] through the gated SSM pipeline."""
        Bsz, L, D = x.shape

        # Split the joint projection into SSM input and gate.
        xz = self.in_proj(x)
        x_ssm, z = xz.chunk(2, dim=-1)

        # Causal depthwise conv (trim the right padding back to length L) + SiLU.
        x_conv = self.conv1d(x_ssm.transpose(1, 2))[:, :, :L]
        x_conv = F.silu(x_conv.transpose(1, 2))

        # Input-dependent SSM parameters (renamed to avoid shadowing the
        # batch size with the SSM's B matrix).
        x_dbl = self.x_proj(x_conv)
        d_state = self.d_state
        dt, B_ssm, C_ssm = torch.split(x_dbl, [1, d_state, d_state], dim=-1)

        # Scalar dt broadcasts across channels via the per-channel bias;
        # softplus keeps the timestep positive. (The old trailing
        # squeeze(-1) was a no-op after broadcasting and is dropped.)
        dt = F.softplus(dt + self.dt_bias.view(1, 1, -1))  # [B, L, inner_dim]

        # A < 0 guarantees a decaying, stable state transition.
        A = -torch.exp(self.A_log)

        y = self._selective_scan(x_conv, dt, A, B_ssm, C_ssm, self.D)
        y = self.norm(y)

        # Gated output, as in Mamba: y * SiLU(z).
        y = y * F.silu(z)

        return self.out_proj(y)
|
|
|
|
class Mamba2Block(nn.Module):
    """
    Mamba-2 block with multi-directional scanning for 2D images.

    Following VMamba's Cross-Scan (SS2D) strategy: the feature map is read
    as four 1D sequences (row-major forward/backward and column-major
    forward/backward), each through its own SSD branch; the four outputs are
    realigned to the original token order and merged with a linear layer.

    This is critical for image generation — a single 1D scan loses
    important 2D spatial structure.

    Args:
        dim: Channel dimension of the input feature map.
        d_state: SSM state size, forwarded to each Mamba2SSD.
        d_conv: Depthwise conv kernel size, forwarded to each Mamba2SSD.
        expand: Expansion factor (also sizes the feed-forward hidden layer).
        dropout: Dropout rate inside the feed-forward sub-block.
    """

    def __init__(self, dim, d_state=16, d_conv=4, expand=2, dropout=0.0):
        super().__init__()
        self.dim = dim

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        # One independent SSD branch per scan direction.
        self.ssd_fwd = Mamba2SSD(dim, d_state, d_conv, expand)
        self.ssd_bwd = Mamba2SSD(dim, d_state, d_conv, expand)
        self.ssd_horiz_fwd = Mamba2SSD(dim, d_state, d_conv, expand)
        self.ssd_vert_fwd = Mamba2SSD(dim, d_state, d_conv, expand)

        # Fuse the four directional outputs back down to `dim` channels.
        self.merge_proj = nn.Linear(dim * 4, dim)

        ff_dim = dim * expand
        self.ff = nn.Sequential(
            nn.Linear(dim, ff_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """
        Args:
            x: [B, C, H, W] image features, or [B, L, dim] sequence.

        Returns:
            Tensor of the same shape as the input.
        """
        if x.dim() == 3:
            return self._forward_seq(x)

        B, C, H, W = x.shape
        residual = x

        # Pre-norm over channels (LayerNorm acts on the last dim).
        x_flat = x.flatten(2).transpose(1, 2)  # [B, HW, C]
        x_norm = self.norm1(x_flat).transpose(1, 2).reshape(B, C, H, W)

        # Direction 1: row-major forward.
        scan1 = x_norm.flatten(2).transpose(1, 2)
        out1 = self.ssd_fwd._process_sequence(scan1)
        out1 = out1.transpose(1, 2).reshape(B, C, H, W)

        # Direction 2: row-major backward; flip back to restore token order.
        scan2 = x_norm.flatten(2).flip(-1).transpose(1, 2)
        out2 = self.ssd_bwd._process_sequence(scan2)
        out2 = out2.transpose(1, 2).flip(-1).reshape(B, C, H, W)

        # Direction 3: column-major forward (swap H/W before flattening).
        scan3 = x_norm.transpose(2, 3).flatten(2).transpose(1, 2)
        out3 = self.ssd_horiz_fwd._process_sequence(scan3)
        out3 = out3.transpose(1, 2).reshape(B, C, W, H).transpose(2, 3)

        # Direction 4: column-major backward.
        # FIX: this branch previously re-used scan2, so two of the four
        # "directions" saw the identical row-major-backward sequence,
        # defeating the cross-scan design.
        scan4 = x_norm.transpose(2, 3).flatten(2).flip(-1).transpose(1, 2)
        out4 = self.ssd_vert_fwd._process_sequence(scan4)
        out4 = out4.transpose(1, 2).flip(-1).reshape(B, C, W, H).transpose(2, 3)

        # Merge the four token-aligned outputs.
        merged = torch.cat([
            out1.flatten(2).transpose(1, 2),
            out2.flatten(2).transpose(1, 2),
            out3.flatten(2).transpose(1, 2),
            out4.flatten(2).transpose(1, 2),
        ], dim=-1)
        merged = self.merge_proj(merged)
        merged = merged.transpose(1, 2).reshape(B, C, H, W)

        # Residual + SSD mix, then a pre-norm feed-forward with its own residual.
        x_out = residual + merged
        x_ff = self.norm2(x_out.flatten(2).transpose(1, 2))
        x_ff = self.ff(x_ff).transpose(1, 2).reshape(B, C, H, W)

        # FIX: previously returned `x_out + merged`, which added the SSD
        # output twice and silently discarded the feed-forward result
        # (compare the correct pattern in `_forward_seq`).
        return x_out + x_ff

    def _forward_seq(self, x):
        """Plain pre-norm residual path for 1D sequence input [B, L, dim]."""
        x_out = x + self.ssd_fwd._process_sequence(self.norm1(x))
        return x_out + self.ff(self.norm2(x_out))
|
|