"""
LiquidFlow Block — Hybrid CfC + Mamba-2 SSD architecture.

The core idea is to combine Liquid Neural Network dynamics (CfC) with
Mamba-2's efficient linear-time state space model (SSD).

Computation performed by one ``LiquidMambaBlock`` on tokens ``x`` (``[B, L, C]``)::

    m   = Mamba2SSD(LayerNorm(x))            # bidirectional raster scan for 2D inputs
    c   = CfC(LayerNorm(m) + x)
    g   = sigmoid(gate_scale * (c - m))
    y   = g * m + (1 - g) * x + c
    out = y + FFN(LayerNorm(y))

The CfC contributes time-continuous, input-dependent gating: it learns how
much to trust the SSM output versus the original signal.

The Mamba-2 SSD contributes:
  - O(N) sequence processing
  - a content-aware selection mechanism
  - parallelizable computation (no sequential bottleneck)

Conceptually this is a "Liquid State Space Model" (LSSM), motivated by the
CfC update rule

    h_t = σ(-f(x_t; θ_f)·t) ⊙ SSM(x_t, h_{t-1})
          + (1 - σ(-f(x_t; θ_f)·t)) ⊙ h(x_t; θ_h)

where SSM is the Mamba-2 selective state space model and the CfC time gate
controls how much the SSM output influences the state. The block in this
module is a practical gated approximation of that idea, not a literal
implementation of the equation.

Inspired by:
  - LNNs: adaptive time constants for state evolution
  - Mamba-2: efficient selective state space models
  - DiMSUM: multi-scan architecture for 2D images
  - Gated SSMs: CfC-style gating applied to an SSM
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .cfc_cell import CfCCell
from .mamba2_ssd import Mamba2SSD


class LiquidMambaBlock(nn.Module):
    """
    LiquidMamba: CfC-gated Mamba-2 SSD block.

    Steps:

    1. The normalized input goes through Mamba-2 SSD. 2D inputs use a
       forward and a reverse raster scan (see ``_mamba_2d_scan``).
    2. The CfC cell receives the normalized SSM output plus the original,
       un-normalized input.
    3. A sigmoid gate, computed from ``gate_scale * (CfC_out - SSM_out)``,
       blends the SSM output with the input. The CfC output is then added on
       top: ``g·SSM_out + (1 - g)·input + CfC_out``.
    4. A pre-norm feed-forward network with a residual connection follows.

    Args:
        dim: Channel / embedding dimension ``C``.
        d_state: SSM state dimension, forwarded to ``Mamba2SSD``.
        d_conv: Forwarded to ``Mamba2SSD`` (presumably its local conv kernel
            width — confirm against ``mamba2_ssd``).
        expand: Forwarded to ``Mamba2SSD``. Also the FFN hidden-width multiplier.
        dropout: Dropout rate for the CfC backbone and the FFN.
    """

    def __init__(self, dim, d_state=16, d_conv=4, expand=2, dropout=0.0):
        super().__init__()
        self.dim = dim

        # Pre-norms for the block input, the SSM output and the FFN input.
        self.norm_in = nn.LayerNorm(dim)
        self.norm_mamba = nn.LayerNorm(dim)
        self.norm_out = nn.LayerNorm(dim)

        # Mamba-2 SSD for linear-time sequence processing.
        self.mamba = Mamba2SSD(dim=dim, d_state=d_state, d_conv=d_conv, expand=expand)

        # CfC cell whose output drives the gate between SSM output and residual.
        self.cfc_gate = CfCCell(dim=dim, backbone_dropout=dropout, use_conv=True)

        # Position-wise feed-forward network.
        ff_dim = dim * expand
        self.ff = nn.Sequential(
            nn.Linear(dim, ff_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, dim),
            nn.Dropout(dropout),
        )

        # Learnable scale on the gate logit (initialized to 0.5).
        self.gate_scale = nn.Parameter(torch.ones(1) * 0.5)

    def forward(self, x):
        """
        Args:
            x: ``[B, C, H, W]`` (2D feature map) or ``[B, L, C]`` (1D sequence),
                with ``C == self.dim``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is neither 3D nor 4D.
        """
        is_2d = x.dim() == 4
        if is_2d:
            B, C, H, W = x.shape
            x_flat = x.flatten(2).transpose(1, 2)  # [B, HW, C]
        elif x.dim() == 3:
            x_flat = x  # [B, L, C]
        else:
            raise ValueError(
                f"Expected a [B, C, H, W] or [B, L, C] tensor, got shape {tuple(x.shape)}"
            )

        residual = x_flat
        x_norm = self.norm_in(x_flat)

        # Mamba-2 SSD. 2D inputs use the bidirectional raster scan.
        if is_2d:
            x_2d = x_norm.transpose(1, 2).reshape(B, C, H, W)
            mamba_out = self._mamba_2d_scan(x_2d)
            mamba_out = mamba_out.flatten(2).transpose(1, 2)  # [B, HW, C]
        else:
            mamba_out = self.mamba(x_norm)

        # The CfC sees both the normalized SSM output and the raw residual.
        # This lets it learn when to trust the SSM versus the original signal.
        cfc_input = self.norm_mamba(mamba_out) + residual
        cfc_out = self.cfc_gate(cfc_input)

        # Gated blend of SSM output and residual, plus the CfC output itself.
        gate = torch.sigmoid(self.gate_scale * (cfc_out - mamba_out))
        mixed = gate * mamba_out + (1 - gate) * residual + cfc_out

        # Pre-norm feed-forward with residual connection.
        out = mixed + self.ff(self.norm_out(mixed))

        if is_2d:
            out = out.transpose(1, 2).reshape(B, C, H, W)
        return out

    def _mamba_2d_scan(self, x):
        """
        Bidirectional raster Mamba-2 scan over a 2D feature map.

        Runs the flattened map through ``self.mamba`` twice with shared
        weights:
          - in forward raster order (left→right, top→bottom);
          - in reverse raster order.

        The reversed output is flipped back so positions line up, and the two
        outputs are averaged.

        Args:
            x: ``[B, C, H, W]``.

        Returns:
            ``[B, C, H, W]``. This assumes ``self.mamba`` maps
            ``[B, L, C] -> [B, L, C]`` (TODO confirm).
        """
        B, C, H, W = x.shape
        seq = x.flatten(2).transpose(1, 2)  # [B, HW, C]

        # Forward raster scan.
        fwd_out = self.mamba(seq)

        # Reverse raster scan: flip the token axis, scan, flip back to realign.
        bwd_out = torch.flip(self.mamba(torch.flip(seq, dims=[1])), dims=[1])

        merged = (fwd_out + bwd_out) / 2
        return merged.transpose(1, 2).reshape(B, C, H, W)


class LiquidFlowStage(nn.Module):
    """
    A stage in LiquidFlow: a stack of ``LiquidMambaBlock`` modules at one
    resolution.

    Architecture::

        [LiquidMambaBlock] × num_blocks

    The stage performs no downsampling or upsampling, so its output has the
    same shape as its input.

    Args:
        dim: Channel dimension.
        num_blocks: Number of stacked blocks.
        d_state: SSM state dimension for each block.
        expand: Expansion factor for each block.
        dropout: Dropout rate for each block.
    """

    def __init__(self, dim, num_blocks=4, d_state=16, expand=2, dropout=0.0):
        super().__init__()
        self.dim = dim
        self.blocks = nn.ModuleList([
            LiquidMambaBlock(dim=dim, d_state=d_state, expand=expand, dropout=dropout)
            for _ in range(num_blocks)
        ])

    def forward(self, x):
        """Apply every block in sequence. The shape is preserved."""
        for block in self.blocks:
            x = block(x)
        return x


class LiquidFlowBackbone(nn.Module):
    """
    Complete LiquidFlow backbone for image generation.

    Architecture::

        Input (noisy latent) [B, C, H, W]
          ↓
        Patch embed (p×p conv, stride p) → timestep scale/shift → + positional encoding
          ↓
        LiquidFlowStage × num_stages   (uniform resolution H/p × W/p)
          ↓
        Output head → unpatchify → predicted noise [B, C, H, W]

    Designed as a DiT-style noise predictor for diffusion models.

    Args:
        in_channels: Input channels (latent dim from VAE).
        hidden_dim: Hidden dimension.
        num_stages: Number of processing stages.
        blocks_per_stage: Number of blocks per stage.
        d_state: SSM state dimension.
        expand: Expansion factor.
        dropout: Dropout rate.
        patch_size: Side length ``p`` of the square, non-overlapping patches.
            The input ``H`` and ``W`` must be divisible by it.
        max_positions: Size of the learnable positional-embedding table. The
            token count ``(H // p) * (W // p)`` must not exceed it.
    """

    def __init__(
        self,
        in_channels=4,
        hidden_dim=256,
        num_stages=4,
        blocks_per_stage=4,
        d_state=16,
        expand=2,
        dropout=0.0,
        patch_size=2,
        max_positions=4096,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_dim = hidden_dim
        self.num_stages = num_stages
        self.patch_size = patch_size

        # Patch embedding: each non-overlapping p×p patch becomes one token.
        # The output head emits in_channels·p² values per token, so patchifying
        # by the same factor here is what makes the output shape [B, C, H, W].
        self.in_proj = nn.Conv2d(
            in_channels, hidden_dim, kernel_size=patch_size, stride=patch_size
        )

        # MLP applied to the sinusoidal diffusion-timestep embedding.
        self.time_embed = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.SiLU(),
            nn.Linear(hidden_dim * 4, hidden_dim),
        )

        # Learnable positional encoding. The default of 4096 positions covers
        # 64×64 tokens, e.g. a 128×128 input with patch_size=2.
        self.pos_embed = nn.Parameter(torch.randn(1, max_positions, hidden_dim) * 0.02)

        # LiquidFlow stages, all at the token-grid resolution.
        self.stages = nn.ModuleList([
            LiquidFlowStage(
                dim=hidden_dim,
                num_blocks=blocks_per_stage,
                d_state=d_state,
                expand=expand,
                dropout=dropout,
            )
            for _ in range(num_stages)
        ])

        # Output head: per-token projection to a p×p patch of in_channels values.
        self.out_norm = nn.LayerNorm(hidden_dim)
        self.out_proj = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, in_channels * self.patch_size * self.patch_size),
        )

        # Maps the timestep embedding to a per-sample (scale, shift) pair.
        self.t_conditioner = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim * 2),  # scale, shift
        )

    def _get_timestep_embedding(self, timesteps, dim, max_period=10000):
        """
        Sinusoidal timestep embedding (as in DiT).

        Args:
            timesteps: 1D tensor of timesteps ``[B]``.
            dim: Output embedding size. An odd ``dim`` is zero-padded by one
                column.
            max_period: Controls the lowest frequency.

        Returns:
            float32 tensor ``[B, dim]`` laid out as ``[cos | sin]``.
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device)
            / half
        )
        args = timesteps.float().unsqueeze(-1) * freqs.unsqueeze(0)
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, x, t):
        """
        Predict the noise in a noisy latent.

        Args:
            x: Noisy latent ``[B, C, H, W]``, with ``C == in_channels`` and
                ``H``, ``W`` divisible by ``patch_size``.
            t: Diffusion timesteps ``[B]``.

        Returns:
            Predicted noise ``[B, C, H, W]``.

        Raises:
            ValueError: If ``H`` or ``W`` is not divisible by ``patch_size``, or
                if the token count exceeds the positional-embedding table.
        """
        B, C, H, W = x.shape
        p = self.patch_size
        if H % p or W % p:
            raise ValueError(
                f"Input spatial size ({H}, {W}) must be divisible by patch_size={p}"
            )
        h, w = H // p, W // p
        num_tokens = h * w
        max_tokens = self.pos_embed.shape[1]
        if num_tokens > max_tokens:
            raise ValueError(
                f"{num_tokens} tokens exceed the positional-embedding table "
                f"({max_tokens} positions); increase max_positions"
            )

        # Patch embedding: [B, C, H, W] -> [B, hidden_dim, h, w] -> [B, h*w, hidden_dim].
        tokens = self.in_proj(x).flatten(2).transpose(1, 2)

        # Timestep conditioning as a per-sample affine modulation of every token.
        t_emb = self.time_embed(self._get_timestep_embedding(t, self.hidden_dim))  # [B, hidden_dim]
        t_scale, t_shift = self.t_conditioner(t_emb).chunk(2, dim=-1)
        tokens = tokens * (1 + t_scale.unsqueeze(1)) + t_shift.unsqueeze(1)

        # Add positional encoding.
        tokens = tokens + self.pos_embed[:, :num_tokens, :]

        # The stages operate on the 2D token grid.
        x_2d = tokens.transpose(1, 2).reshape(B, self.hidden_dim, h, w)
        for stage in self.stages:
            x_2d = stage(x_2d)

        # Output head: one C·p·p vector per token.
        out = x_2d.flatten(2).transpose(1, 2)  # [B, h*w, hidden_dim]
        out = self.out_proj(self.out_norm(out))  # [B, h*w, C*p*p]

        # Unpatchify: [B, h, w, C, p, p] -> [B, C, h, p, w, p] -> [B, C, H, W].
        out = out.reshape(B, h, w, C, p, p)
        out = out.permute(0, 3, 1, 4, 2, 5).reshape(B, C, H, W)
        return out