"""
LiquidFlow Block — Hybrid CfC + Mamba-2 SSD architecture.

The core innovation: combine Liquid Neural Network dynamics (CfC)
with Mamba-2's efficient linear-time state space model.

Architecture per block:
    Input → [Mamba-2 SSD → CfC Gate → Gated Mix] → FFN → Output
                               ↑
                    Adaptive liquid gating

The CfC provides:
    - Time-continuous adaptive gating (what to process/ignore)
    - An adaptive, content-dependent mix of SSM features and the residual stream
    
The Mamba-2 SSD provides:
    - Efficient O(N) sequence processing
    - Content-aware selection mechanism
    - Parallelizable computation (no sequential bottleneck)

Together they create a "Liquid State Space Model" (LSSM):
    h_t = σ(-f(x_t;θ_f)·t) ⊙ SSM(x_t, h_{t-1}) + (1-σ(...)) ⊙ h(x_t;θ_h)

Where SSM is the Mamba-2 selective state space model and the
CfC time-gates control how much the SSM output influences state.

This is inspired by:
- LNNs: adaptive time constants for state evolution
- Mamba-2: efficient selective state space models
- DiMSUM: multi-scan architecture for 2D images
- Gated SSM: gating mechanism from CfC applied to SSM
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .cfc_cell import CfCCell
from .mamba2_ssd import Mamba2SSD
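
# Minimal sketch of the LSSM update from the module docstring, written with plain
# tensors. `f_net`, `h_net`, and `ssm` are illustrative placeholders, not modules
# defined in this repo:
#
#     gate = torch.sigmoid(-f_net(x_t) * t)                    # liquid time-constant gate
#     h_t = gate * ssm(x_t, h_prev) + (1 - gate) * h_net(x_t)  # gated mix of SSM and CfC paths
#
# LiquidMambaBlock below realizes this idea with a learned CfC gate applied to the
# Mamba-2 output rather than a literal per-timestep recurrence.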


class LiquidMambaBlock(nn.Module):
    """
    LiquidMamba: CfC-gated Mamba-2 SSD block.
    
    The CfC cell acts as a learned gate on the Mamba-2 output,
    creating a liquid time-constant mechanism for the SSM:
    
    1. Input goes through Mamba-2 SSD (multi-directional scan for 2D inputs)
    2. The CfC cell receives the normalized SSM output plus the residual input
    3. A sigmoid gate derived from the CfC output blends the SSM features with
       the residual: σ(g)·SSM_out + (1-σ(g))·input, plus the CfC output itself
    4. The CfC's liquid dynamics thus adaptively mix SSM features with raw input
    This creates a "content-aware gating" that the CfC learns to
    control based on both the input and the SSM's processed features.
    """
    
    def __init__(self, dim, d_state=16, d_conv=4, expand=2, dropout=0.0):
        super().__init__()
        self.dim = dim
        
        # LayerNorms
        self.norm_in = nn.LayerNorm(dim)
        self.norm_mamba = nn.LayerNorm(dim)
        self.norm_out = nn.LayerNorm(dim)
        
        # Mamba-2 SSD for efficient sequence processing
        self.mamba = Mamba2SSD(dim=dim, d_state=d_state, d_conv=d_conv, expand=expand)
        
        # CfC gate: controls the flow between Mamba output and residual
        self.cfc_gate = CfCCell(dim=dim, backbone_dropout=dropout, use_conv=True)
        
        # Feed-forward
        ff_dim = dim * expand
        self.ff = nn.Sequential(
            nn.Linear(dim, ff_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, dim),
            nn.Dropout(dropout),
        )
        
        # Learnable mixing ratio init
        self.gate_scale = nn.Parameter(torch.ones(1) * 0.5)
    
    def forward(self, x):
        """
        Args:
            x: [B, C, H, W] (2D) or [B, L, C] (1D seq)
        Returns:
            Same shape as input
        """
        is_2d = x.dim() == 4
        
        if is_2d:
            B, C, H, W = x.shape
            L = H * W
            x_flat = x.flatten(2).transpose(1, 2)  # [B, HW, C]
        else:
            B, L, C = x.shape
            x_flat = x
        
        residual = x_flat
        x_norm = self.norm_in(x_flat)
        
        # Mamba-2 SSD processing with multi-directional scan
        if is_2d:
            # Reshape for 2D scanning
            x_2d = x_norm.transpose(1, 2).reshape(B, C, H, W)
            mamba_out = self._mamba_2d_scan(x_2d)
            mamba_out = mamba_out.flatten(2).transpose(1, 2)  # [B, HW, C]
        else:
            mamba_out = self.mamba(x_norm)
        
        # CfC gating: liquid dynamics control the mix
        mamba_norm = self.norm_mamba(mamba_out)
        
        # CfC receives both the Mamba output and the residual
        # This lets it learn when to trust the SSM vs the original signal
        cfc_input = mamba_norm + residual
        cfc_out = self.cfc_gate(cfc_input)
        
        # Gated mix: CfC controls the blend
        gate = torch.sigmoid(self.gate_scale * (cfc_out - mamba_out))
        mixed = gate * mamba_out + (1 - gate) * residual + cfc_out
        
        # Feed-forward + residual
        out_norm = self.norm_out(mixed)
        out = mixed + self.ff(out_norm)
        
        if is_2d:
            out = out.transpose(1, 2).reshape(B, C, H, W)
        
        return out
    
    def _mamba_2d_scan(self, x):
        """
        Multi-directional Mamba-2 scan for 2D images.
        
        Scans in forward and backward raster directions, then merges.
        This preserves 2D spatial structure better than single-direction scan.
        """
        B, C, H, W = x.shape
        device = x.device
        
        # Forward raster: left→right, top→bottom
        fwd = x.flatten(2)  # [B, C, HW]
        fwd_seq = fwd.transpose(1, 2)  # [B, HW, C]
        fwd_out = self.mamba(fwd_seq)
        
        # Backward raster: right→left, bottom→top  
        bwd = torch.flip(x.flatten(2), dims=[-1])  # [B, C, HW]
        bwd_seq = bwd.transpose(1, 2)
        bwd_out = self.mamba(bwd_seq)
        bwd_out = torch.flip(bwd_out, dims=[1])  # Flip back
        
        # Merge both directions
        merged = (fwd_out + bwd_out) / 2
        merged = merged.transpose(1, 2).reshape(B, C, H, W)
        
        return merged
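
# Example (sketch) of using LiquidMambaBlock on a 2D feature map. The sizes below
# are illustrative only; shapes follow the forward() docstring above.
#
#     block = LiquidMambaBlock(dim=64, d_state=16, expand=2)
#     feats = torch.randn(2, 64, 16, 16)   # [B, C, H, W]
#     out = block(feats)                   # -> [2, 64, 16, 16], same shape as input
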


class LiquidFlowStage(nn.Module):
    """
    A stage in LiquidFlow: multiple LiquidMamba blocks at the same resolution.
    
    Architecture:
        [LiquidMamba Block] × num_blocks
        [Optional Downsample/Upsample]
    
    This mirrors the hierarchical design from DiT/DiMSUM but with
    liquid neural network dynamics in every block.
    """
    
    def __init__(self, dim, num_blocks=4, d_state=16, expand=2, dropout=0.0):
        super().__init__()
        self.dim = dim
        
        self.blocks = nn.ModuleList([
            LiquidMambaBlock(dim=dim, d_state=d_state, expand=expand, dropout=dropout)
            for _ in range(num_blocks)
        ])
    
    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


class LiquidFlowBackbone(nn.Module):
    """
    Complete LiquidFlow backbone for image generation.
    
    Architecture:
        Input (noisy latent) [B, C, H, W]

        [Patch Embed + Positional Encoding]

        [LiquidMamba Stages × N]  (at uniform resolution)

        [Output Head] → predicted noise
    
    This is designed as a DiT-style noise predictor for diffusion models.
    
    Args:
        in_channels: Input channels (latent dim from VAE)
        hidden_dim: Hidden dimension
        num_stages: Number of processing stages
        blocks_per_stage: Number of blocks per stage
        d_state: SSM state dimension
        expand: Expansion factor
        dropout: Dropout rate
    """
    
    def __init__(
        self,
        in_channels=4,
        hidden_dim=256,
        num_stages=4,
        blocks_per_stage=4,
        d_state=16,
        expand=2,
        dropout=0.0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_dim = hidden_dim
        self.num_stages = num_stages
        
        # Input embedding: non-overlapping patch embedding
        self.patch_size = 2  # Fixed patch size
        self.in_proj = nn.Conv2d(
            in_channels, hidden_dim,
            kernel_size=self.patch_size, stride=self.patch_size,
        )
        
        # Time embedding (for diffusion timestep)
        self.time_embed = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.SiLU(),
            nn.Linear(hidden_dim * 4, hidden_dim),
        )
        
        # Learnable positional encoding
        # For 128×128 with patch_size=2: 64×64 = 4096 positions
        self.pos_embed = nn.Parameter(torch.randn(1, 4096, hidden_dim) * 0.02)
        
        # LiquidFlow stages
        self.stages = nn.ModuleList([
            LiquidFlowStage(
                dim=hidden_dim, 
                num_blocks=blocks_per_stage,
                d_state=d_state,
                expand=expand,
                dropout=dropout,
            )
            for _ in range(num_stages)
        ])
        
        # Output head
        self.out_norm = nn.LayerNorm(hidden_dim)
        self.out_proj = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, in_channels * self.patch_size * self.patch_size),
        )
        
        # Timestep conditioner (adaLN-style scale and shift)
        self.t_conditioner = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim * 2),  # scale, shift
        )
    
    def _get_timestep_embedding(self, timesteps, dim, max_period=10000):
        """Sinusoidal timestep embedding (from DiT)."""
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(timesteps.device)
        args = timesteps.float().unsqueeze(-1) * freqs.unsqueeze(0)
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding
    
    def forward(self, x, t):
        """
        Args:
            x: Noisy latent [B, C, H, W]
            t: Diffusion timesteps [B]
        
        Returns:
            Predicted noise [B, C, H, W]
        """
        B, C, H, W = x.shape
        Hp, Wp = H // self.patch_size, W // self.patch_size
        L = Hp * Wp
        
        # Patch embedding
        x = self.in_proj(x)  # [B, hidden_dim, Hp, Wp]
        
        # Flatten to a token sequence
        x_flat = x.flatten(2).transpose(1, 2)  # [B, L, hidden_dim]
        
        # Time embedding
        t_emb = self._get_timestep_embedding(t, self.hidden_dim)
        t_emb = self.time_embed(t_emb)  # [B, hidden_dim]
        
        # Apply time conditioning as a per-channel scale and shift
        t_cond = self.t_conditioner(t_emb)  # [B, hidden_dim * 2]
        t_scale, t_shift = t_cond.chunk(2, dim=-1)
        x_flat = x_flat * (1 + t_scale.unsqueeze(1)) + t_shift.unsqueeze(1)
        
        # Add positional encoding
        x_flat = x_flat + self.pos_embed[:, :L, :]
        
        # Reshape back to 2D (patched resolution) for processing
        x_2d = x_flat.transpose(1, 2).reshape(B, self.hidden_dim, Hp, Wp)
        
        # Process through all stages
        for stage in self.stages:
            x_2d = stage(x_2d)
        
        # Output head
        x_out = x_2d.flatten(2).transpose(1, 2)  # [B, L, hidden_dim]
        x_out = self.out_norm(x_out)
        x_out = self.out_proj(x_out)  # [B, L, C * patch²]
        
        # Unpatchify back to the input resolution
        x_out = x_out.reshape(B, Hp, Wp, C, self.patch_size, self.patch_size)
        x_out = x_out.permute(0, 3, 1, 4, 2, 5).reshape(B, C, H, W)
        
        return x_out


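# Minimal smoke test (illustrative sketch, not part of the model). It assumes the
# CfCCell and Mamba2SSD modules imported above accept the constructor arguments
# used in this file and preserve the sequence shape of their inputs.
if __name__ == "__main__":
    model = LiquidFlowBackbone(
        in_channels=4, hidden_dim=64, num_stages=1, blocks_per_stage=1
    )
    latents = torch.randn(2, 4, 16, 16)        # [B, C, H, W] noisy latents
    timesteps = torch.randint(0, 1000, (2,))   # [B] diffusion timesteps
    noise_pred = model(latents, timesteps)
    print("predicted noise shape:", tuple(noise_pred.shape))  # expected: (2, 4, 16, 16)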