| """ |
| LiqMamba: Liquid-Mamba Image Generator |
| |
| Complete architecture that combines: |
| 1. SDXL VAE for encoding/decoding (pretrained, frozen) |
| 2. CfC-gated Mamba-2 SSD backbone with multi-directional 2D scans |
| 3. Flow matching objective for stable training |
| 4. Lipshitz regularization (physics-informed) to prevent collapse |
| |
| Configurations: |
| - LiqMamba-Tiny: ~8M params (extreme lightweight) |
| - LiqMamba-Small: ~25M params (Colab/Kaggle free tier target) |
| - LiqMamba-Base: ~85M params (higher quality) |
| """ |

import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .mamba2_ssd import MultiDirectionalScan, Mamba2SSDBlock
from .cfc import CfCLayer


class PatchEmbed(nn.Module):
    """Convert image latents to patch tokens."""

    def __init__(self, in_channels=4, dim=256, patch_size=1):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, dim, patch_size, patch_size)

    def forward(self, x):
        # (B, C, H, W) -> (B, dim, H', W') -> (B, H'*W', dim) token sequence
        x = self.proj(x)
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        return x, H, W


class Unpatchify(nn.Module):
    """Convert patch tokens back to image latents."""

    def __init__(self, dim=256, out_channels=4, patch_size=1):
        super().__init__()
        # Transposed convolution inverts the strided patch projection, so the
        # original latent resolution is recovered even when patch_size > 1.
        self.proj = nn.ConvTranspose2d(dim, out_channels, patch_size, patch_size)

    def forward(self, x, H, W):
        B, L, D = x.shape
        x = x.transpose(1, 2).view(B, D, H, W)
        return self.proj(x)


class AdaLNModulation(nn.Module):
    """
    Adaptive Layer Norm modulation (from DiT).
    Injects timestep and optional class conditioning.
    """

    def __init__(self, dim, cond_dim=256):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.scale_shift = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cond_dim, dim * 3),
        )

    def forward(self, x, c):
        # Predict per-channel scale, shift, and gate from the conditioning vector.
        scale, shift, gate = self.scale_shift(c).chunk(3, dim=-1)

        # gate * (norm(x) * (1 + scale) + shift)
        x = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        x = x * gate.unsqueeze(1)
        return x


class TimestepEmbedding(nn.Module):
    """Sinusoidal timestep embedding."""

    def __init__(self, dim, max_period=10000):
        super().__init__()
        self.dim = dim
        self.max_period = max_period
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, t):
        # Build the standard sinusoidal embedding, then refine it with a small MLP.
        half = self.dim // 2
        freqs = torch.exp(-math.log(self.max_period) *
                          torch.arange(0, half, device=t.device).float() / half)
        args = t.unsqueeze(-1) * freqs.unsqueeze(0)
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.dim % 2:
            embedding = F.pad(embedding, (0, 1))
        return self.mlp(embedding)


class LiqMambaBlock(nn.Module):
    """
    Core LiqMamba block combining:
    - AdaLN modulation
    - Multi-directional Mamba-2 SSD scan
    - CfC liquid state modulation
    - Feed-forward with CfC gating
    """

    def __init__(self, dim, cond_dim=256, d_state=16, expand=2,
                 scan_pattern="row_fwd", use_ffn=True):
        super().__init__()
        self.dim = dim
        self.scan_pattern = scan_pattern

        # Conditioning modulation for the scan branch.
        self.adaLN = AdaLNModulation(dim, cond_dim)

        # Multi-directional Mamba-2 SSD scan over the 2D token grid.
        self.scan = MultiDirectionalScan(dim, pattern=scan_pattern,
                                         d_state=d_state, expand=expand)

        # Optional CfC feed-forward branch.
        if use_ffn:
            self.cfc_ffn = CfCLayer(dim, expansion_factor=2)
        else:
            self.cfc_ffn = nn.Identity()

        # Separate conditioning modulation for the feed-forward branch.
        self.adaLN_ffn = AdaLNModulation(dim, cond_dim) if use_ffn else None

    def forward(self, x, c, H, W):
        # Scan branch with residual connection.
        x_mod = self.adaLN(x, c)
        x = x + self.scan(x_mod, H, W)

        # CfC feed-forward branch with residual connection.
        if self.adaLN_ffn is not None:
            x_mod2 = self.adaLN_ffn(x, c)
            x = x + self.cfc_ffn(x_mod2)

        return x


class LiqMamba(nn.Module):
    """
    LiqMamba Image Generator: Liquid Neural Network + Mamba-2 SSD

    Architecture:
    1. Patch embed: latent (4, H, W) -> tokens (H*W, dim)
    2. Timestep + condition embedding
    3. N stacked LiqMambaBlocks with alternating scan directions
    4. Unpatchify: tokens -> latent (4, H, W)

    Config presets:
    - Tiny: dim=128, depth=4 -> ~8M params
    - Small: dim=256, depth=8 -> ~25M params
    - Base: dim=512, depth=12 -> ~85M params
    """

    def __init__(
        self,
        in_channels: int = 4,
        out_channels: int = 4,
        dim: int = 256,
        depth: int = 8,
        cond_dim: int = 256,
        d_state: int = 16,
        expand: int = 2,
        patch_size: int = 1,
        scan_patterns: Optional[list[str]] = None,
    ):
        super().__init__()

        self.dim = dim
        self.depth = depth

        # Default: cycle through the four 2D scan directions across the depth.
        if scan_patterns is None:
            scan_patterns = ["row_fwd", "row_rev", "col_fwd", "col_rev"]
        self.scan_patterns = scan_patterns

        # Latent <-> token projections.
        self.patch_embed = PatchEmbed(in_channels, dim, patch_size)
        self.unpatchify = Unpatchify(dim, out_channels, patch_size)

        # Timestep conditioning.
        self.time_embed = TimestepEmbedding(cond_dim)

        # Class-label conditioning (1000-class embedding table).
        self.class_embed = nn.Embedding(1000, cond_dim)

        # CfC mixing applied to tokens before the backbone.
        self.cfc_init = CfCLayer(dim, expansion_factor=2)

        # Backbone: stacked LiqMamba blocks with alternating scan directions;
        # the CfC feed-forward branch is enabled in every other block.
        self.blocks = nn.ModuleList()
        for i in range(depth):
            pattern = scan_patterns[i % len(scan_patterns)]
            use_ffn = (i % 2 == 0)
            self.blocks.append(
                LiqMambaBlock(
                    dim=dim,
                    cond_dim=cond_dim,
                    d_state=d_state,
                    expand=expand,
                    scan_pattern=pattern,
                    use_ffn=use_ffn,
                )
            )

        # CfC mixing and normalization applied to tokens after the backbone.
        self.cfc_final = CfCLayer(dim, expansion_factor=2)
        self.final_norm = nn.LayerNorm(dim)

        self._init_weights()

    def _init_weights(self):
        """Initialize with small values for stable training."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x, t, class_labels=None, return_dict=False):
        """
        Args:
            x: (B, C, H, W) latent images
            t: (B,) float timesteps in [0, 1]
            class_labels: (B,) optional class indices
        Returns:
            velocity field v(x, t) used for flow matching
        """
        B, C, H, W = x.shape

        # Latents -> tokens.
        x, H_p, W_p = self.patch_embed(x)

        # Conditioning vector: timestep embedding plus optional class embedding.
        c = self.time_embed(t)
        if class_labels is not None:
            c = c + self.class_embed(class_labels)

        # Initial CfC mixing.
        x = self.cfc_init(x)

        # Backbone.
        for block in self.blocks:
            x = block(x, c, H_p, W_p)

        # Final normalization and CfC mixing.
        x = self.final_norm(x)
        x = self.cfc_final(x)

        # Tokens -> latents (predicted velocity field).
        x = self.unpatchify(x, H_p, W_p)

        if return_dict:
            return {"velocity": x}
        return x

    def get_num_params(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
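

# ---------------------------------------------------------------------------
# Illustrative training-objective sketches (assumptions, not the project's
# canonical implementation). `flow_matching_loss` shows one standard way to
# train the velocity field returned by LiqMamba.forward: rectified-flow /
# conditional flow matching along straight noise-to-data paths.
# `lipschitz_penalty` is one plausible reading of the "Lipschitz
# regularization" mentioned in the module docstring, using a power-iteration
# estimate of each Linear layer's spectral norm. Both may differ from the
# training code used elsewhere in the repository.
# ---------------------------------------------------------------------------
def flow_matching_loss(model: LiqMamba, latents: torch.Tensor,
                       class_labels: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Conditional flow-matching loss on clean VAE latents (sketch)."""
    B = latents.shape[0]
    x1 = latents                                  # data endpoint
    x0 = torch.randn_like(x1)                     # noise endpoint
    t = torch.rand(B, device=latents.device)      # uniform timesteps in [0, 1]
    t_ = t.view(B, 1, 1, 1)
    x_t = (1.0 - t_) * x0 + t_ * x1               # straight-line interpolation
    target_v = x1 - x0                            # constant target velocity
    pred_v = model(x_t, t, class_labels=class_labels)
    return F.mse_loss(pred_v, target_v)


def lipschitz_penalty(model: nn.Module, n_power_iters: int = 1) -> torch.Tensor:
    """Penalize Linear layers whose estimated spectral norm exceeds 1 (sketch)."""
    penalty = torch.zeros((), device=next(model.parameters()).device)
    for module in model.modules():
        if isinstance(module, nn.Linear):
            W = module.weight
            # Power iteration for the largest singular value of W.
            u = torch.randn(W.shape[0], device=W.device)
            v = torch.randn(W.shape[1], device=W.device)
            for _ in range(max(1, n_power_iters)):
                v = F.normalize(W.t() @ u, dim=0)
                u = F.normalize(W @ v, dim=0)
            sigma = torch.dot(u, W @ v)
            penalty = penalty + F.relu(sigma - 1.0) ** 2
    return penalty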


def liqmamba_tiny(**kwargs):
    """Tiny variant: ~8M params, extreme lightweight."""
    return LiqMamba(dim=128, depth=4, d_state=8, expand=2, **kwargs)


def liqmamba_small(**kwargs):
    """Small variant: ~25M params, Colab/Kaggle free tier target."""
    return LiqMamba(dim=256, depth=8, d_state=16, expand=2, **kwargs)


def liqmamba_base(**kwargs):
    """Base variant: ~85M params, higher quality."""
    return LiqMamba(dim=512, depth=12, d_state=16, expand=2, **kwargs)
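

# Minimal smoke test / usage sketch. It assumes the sibling modules
# (.mamba2_ssd, .cfc) provide MultiDirectionalScan and CfCLayer with the
# interfaces used above, and that this file lives in a package so it can be
# run as `python -m <package>.<this_module>` (exact module path not shown here).
if __name__ == "__main__":
    model = liqmamba_tiny()
    print(f"LiqMamba-Tiny parameters: {model.get_num_params() / 1e6:.2f}M")

    # A fake batch of SDXL-VAE latents: 4 channels, 32x32 spatial
    # (a 256x256 image downsampled 8x by the VAE encoder).
    x = torch.randn(2, 4, 32, 32)
    t = torch.rand(2)
    labels = torch.randint(0, 1000, (2,))

    with torch.no_grad():
        v = model(x, t, class_labels=labels)
    print(f"Velocity output shape: {tuple(v.shape)}")  # expected (2, 4, 32, 32)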