""" LiqMamba: Liquid-Mamba Image Generator Complete architecture that combines: 1. SDXL VAE for encoding/decoding (pretrained, frozen) 2. CfC-gated Mamba-2 SSD backbone with multi-directional 2D scans 3. Flow matching objective for stable training 4. Lipshitz regularization (physics-informed) to prevent collapse Configurations: - LiqMamba-Tiny: ~8M params (extreme lightweight) - LiqMamba-Small: ~25M params (Colab/Kaggle free tier target) - LiqMamba-Base: ~85M params (higher quality) """ import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional import math from .mamba2_ssd import MultiDirectionalScan, Mamba2SSDBlock from .cfc import CfCLayer class PatchEmbed(nn.Module): """Convert image latents to patch tokens.""" def __init__(self, in_channels=4, dim=256, patch_size=1): super().__init__() self.proj = nn.Conv2d(in_channels, dim, patch_size, patch_size) def forward(self, x): # x: (B, C, H, W) -> (B, dim, H, W) -> (B, H*W, dim) x = self.proj(x) B, C, H, W = x.shape x = x.flatten(2).transpose(1, 2) return x, H, W class Unpatchify(nn.Module): """Convert patch tokens back to image latents.""" def __init__(self, dim=256, out_channels=4, patch_size=1): super().__init__() self.proj = nn.Conv2d(dim, out_channels, patch_size, patch_size) def forward(self, x, H, W): B, L, D = x.shape x = x.transpose(1, 2).view(B, D, H, W) return self.proj(x) class AdaLNModulation(nn.Module): """ Adaptive Layer Norm modulation (from DiT). Injects timestep and optional class conditioning. """ def __init__(self, dim, cond_dim=256): super().__init__() self.norm = nn.LayerNorm(dim, elementwise_affine=False) self.scale_shift = nn.Sequential( nn.SiLU(), nn.Linear(cond_dim, dim * 6) # scale, shift, gate x 2 ) def forward(self, x, c): # x: (B, L, D), c: (B, cond_dim) params = self.scale_shift(c) # (B, D*6) scale1, shift1, gate1, scale2, shift2, gate2 = params.chunk(6, dim=-1) # Modulate x = self.norm(x) * (1 + scale1.unsqueeze(1)) + shift1.unsqueeze(1) x = x * gate1.unsqueeze(1) return x class TimestepEmbedding(nn.Module): """Sinusoidal timestep embedding.""" def __init__(self, dim, max_period=10000): super().__init__() self.dim = dim self.max_period = max_period self.mlp = nn.Sequential( nn.Linear(dim, dim * 4), nn.SiLU(), nn.Linear(dim * 4, dim), ) def forward(self, t): # t: (B,) float timesteps in [0,1] half = self.dim // 2 freqs = torch.exp(-math.log(self.max_period) * torch.arange(0, half, device=t.device).float() / half) args = t.unsqueeze(-1) * freqs.unsqueeze(0) embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if self.dim % 2: embedding = F.pad(embedding, (0, 1)) return self.mlp(embedding) class LiqMambaBlock(nn.Module): """ Core LiqMamba block combining: - AdaLN-Zero modulation - Multi-directional Mamba-2 SSD scan - CfC liquid state modulation - Feed-forward with CfC gating """ def __init__(self, dim, cond_dim=256, d_state=16, expand=2, scan_pattern="row_fwd", use_ffn=True): super().__init__() self.dim = dim self.scan_pattern = scan_pattern # AdaLN self.adaLN = AdaLNModulation(dim, cond_dim) # Multi-directional scan self.scan = MultiDirectionalScan(dim, pattern=scan_pattern, d_state=d_state, expand=expand) # CfC liquid layer (replaces FFN for adaptive computation) if use_ffn: self.cfc_ffn = CfCLayer(dim, expansion_factor=2) else: self.cfc_ffn = nn.Identity() # AdaLN for FFN self.adaLN_ffn = AdaLNModulation(dim, cond_dim) if use_ffn else None def forward(self, x, c, H, W): # x: (B, H*W, dim), c: (B, cond_dim) # Scan with conditioning x_mod = self.adaLN(x, c) x = x + self.scan(x_mod, H, W) # CfC FFN with conditioning if self.adaLN_ffn is not None: x_mod2 = self.adaLN_ffn(x, c) x = x + self.cfc_ffn(x_mod2) return x class LiqMamba(nn.Module): """ LiqMamba Image Generator — Liquid Neural Network + Mamba-2 SSD Architecture: 1. Patch embed: latent (4, H, W) → tokens (H*W, dim) 2. Timestep + condition embedding 3. N stacked LiqMambaBlocks with alternating scan directions 4. Unpatchify: tokens → latent (4, H, W) Config presets: - Tiny: dim=128, depth=4 → ~8M params - Small: dim=256, depth=8 → ~25M params - Base: dim=512, depth=12 → ~85M params """ def __init__( self, in_channels: int = 4, # VAE latent channels out_channels: int = 4, dim: int = 256, # Hidden dimension depth: int = 8, # Number of blocks cond_dim: int = 256, # Conditioning dimension d_state: int = 16, # SSM state dimension expand: int = 2, # SSD expansion factor patch_size: int = 1, scan_patterns: list[str] | None = None, ): super().__init__() self.dim = dim self.depth = depth # Scan pattern rotation (matching DiM's 4-pattern cycle) if scan_patterns is None: scan_patterns = ["row_fwd", "row_rev", "col_fwd", "col_rev"] self.scan_patterns = scan_patterns # Patch embedding self.patch_embed = PatchEmbed(in_channels, dim, patch_size) self.unpatchify = Unpatchify(dim, out_channels, patch_size) # Timestep embedding self.time_embed = TimestepEmbedding(cond_dim) # Optional class embedding self.class_embed = nn.Embedding(1000, cond_dim) # Initial CfC layer for liquid state initialization self.cfc_init = CfCLayer(dim, expansion_factor=2) # LiqMamba blocks self.blocks = nn.ModuleList() for i in range(depth): pattern = scan_patterns[i % len(scan_patterns)] use_ffn = (i % 2 == 0) # FFN every other block for efficiency self.blocks.append( LiqMambaBlock( dim=dim, cond_dim=cond_dim, d_state=d_state, expand=expand, scan_pattern=pattern, use_ffn=use_ffn, ) ) # Final CfC refinement layer self.cfc_final = CfCLayer(dim, expansion_factor=2) # Final projection self.final_norm = nn.LayerNorm(dim) # Initialize weights self._init_weights() def _init_weights(self): """Initialize with small values for stable training.""" for module in self.modules(): if isinstance(module, nn.Linear): nn.init.normal_(module.weight, std=0.02) if module.bias is not None: nn.init.zeros_(module.bias) def forward(self, x, t, class_labels=None, return_dict=False): """ Args: x: (B, C, H, W) latent images t: (B,) float timesteps in [0, 1] class_labels: (B,) optional class indices Returns: velocity field v(x, t) used for flow matching """ B, C, H, W = x.shape # Patch embed x, H_p, W_p = self.patch_embed(x) # (B, H*W, dim) # Timestep conditioning c = self.time_embed(t) # (B, cond_dim) if class_labels is not None: c = c + self.class_embed(class_labels) # Initial liquid state x = self.cfc_init(x) # LiqMamba blocks for block in self.blocks: x = block(x, c, H_p, W_p) # Final refinement x = self.final_norm(x) x = self.cfc_final(x) # Unpatchify x = self.unpatchify(x, H_p, W_p) if return_dict: return {"velocity": x} return x def get_num_params(self): return sum(p.numel() for p in self.parameters() if p.requires_grad) def liqmamba_tiny(**kwargs): """Tiny variant: ~8M params, extreme lightweight.""" return LiqMamba(dim=128, depth=4, d_state=8, expand=2, **kwargs) def liqmamba_small(**kwargs): """Small variant: ~25M params, Colab/Kaggle free tier target.""" return LiqMamba(dim=256, depth=8, d_state=16, expand=2, **kwargs) def liqmamba_base(**kwargs): """Base variant: ~85M params, higher quality.""" return LiqMamba(dim=512, depth=12, d_state=16, expand=2, **kwargs)