"""
LiRA Core Modules: Gated State-Space Backbone (GS3B)

Mathematical Foundation:
========================
Traditional transformers use self-attention: O_i = softmax(Q_i K^T / sqrt(d)) V
This costs O(N^2) in the sequence length N, which is prohibitive for high-resolution images.

Our approach combines three key innovations:

1. SELECTIVE STATE SPACE (from Mamba/S6):
   State evolution: h_t = A_t * h_{t-1} + B_t * x_t
   Output: y_t = C_t * h_t + D * x_t
   Where B_t, C_t and the step size Δ_t are INPUT-DEPENDENT (selective), and
   A_t = exp(Δ_t * A) inherits that dependence through Δ_t. This is the key
   insight from Mamba that makes SSMs competitive with attention.

2. BIDIRECTIONAL GATED SCANNING (from DiM + RWKV-7):
   Images are 2D, not 1D. We scan in 4 directions:
   - Horizontal L→R, R→L
   - Vertical T→B, B→T
   Each direction maintains its own state. A learned gate g = sigmoid(W x)
   blends the average of the four directions with the input:
   y = LayerNorm(g * (y_lr + y_rl + y_tb + y_bt) / 4 + (1 - g) * x)

   From RWKV-7 we take the generalized delta rule for state updates:
   S_t = S_{t-1} * (diag(w_t) - k_t^T (a_t ⊙ k_t)) + v_t^T k_t
   This gives us input-dependent decay with O(N) complexity.

3. FREQUENCY-AWARE PROCESSING (from DiMSUM):
   We apply lightweight wavelet decomposition to separate structure from detail,
   process each frequency band with appropriate granularity, then recombine.
   Low-freq (structure) → fewer tokens, heavier processing
   High-freq (detail) → more tokens, lighter processing

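Per-layer scan cost: O(N * D * d_state), where N = tokens, D = model width and
d_state = SSM state size, i.e. linear in N.
For 1024px images with an f32 VAE (32x spatial downsampling): N = 32*32 = 1024
tokens, versus N^2 ≈ 1M pairwise scores per head for full attention.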
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple
from einops import rearrange
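
# Illustrative sketches for the module docstring (hypothetical helpers, not used
# by the classes below). `delta_rule_step` is a single-head, simplified version
# of the RWKV-7-style update quoted above: it assumes one key vector k serves
# both the removal term and the write term (RWKV-7 uses separately normalized
# keys). `haar_dwt2d` / `haar_idwt2d` implement one level of an orthonormal 2D
# Haar transform, assumed here as the "lightweight wavelet decomposition"; the
# wavelet actually intended by the frequency-aware design may differ.
import torch


def delta_rule_step(S, w, k, a, v):
    """One state update S_t = S_{t-1} (diag(w) - k^T (a * k)) + v^T k.

    S: (d, d) state; w, k, a, v: (d,) vectors treated as rows.
    """
    transition = torch.diag(w) - torch.outer(k, a * k)  # decay minus in-context removal
    return S @ transition + torch.outer(v, k)           # write value v under key k


def haar_dwt2d(x):
    """x: (B, C, H, W) with even H, W -> (ll, lh, hl, hh), each (B, C, H/2, W/2)."""
    a = x[..., 0::2, 0::2]
    b = x[..., 0::2, 1::2]
    c = x[..., 1::2, 0::2]
    d = x[..., 1::2, 1::2]
    ll = (a + b + c + d) / 2  # low frequency: local average (structure)
    lh = (a - b + c - d) / 2  # left-right difference (detail)
    hl = (a + b - c - d) / 2  # top-bottom difference (detail)
    hh = (a - b - c + d) / 2  # diagonal difference (detail)
    return ll, lh, hl, hh


def haar_idwt2d(ll, lh, hl, hh):
    """Exact inverse of haar_dwt2d."""
    B, C, h, w = ll.shape
    x = ll.new_empty(B, C, 2 * h, 2 * w)
    x[..., 0::2, 0::2] = (ll + lh + hl + hh) / 2
    x[..., 0::2, 1::2] = (ll - lh + hl - hh) / 2
    x[..., 1::2, 0::2] = (ll + lh - hl - hh) / 2
    x[..., 1::2, 1::2] = (ll - lh - hl + hh) / 2
    return x

# With w = a = 1 and a unit-norm k, delta_rule_step replaces whatever value was
# stored under k: delta_rule_step(S, ones, k, ones, v) @ k equals v.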


# ============================================================================
# Core Building Block: Gated Selective State-Space Layer
# ============================================================================

class SelectiveStateSpace(nn.Module):
    """
    Selective State Space layer with input-dependent parameters.
    
    Mathematical formulation:
        h_t = exp(Δ_t * A) * h_{t-1} + Δ_t * B_t * x_t    (state transition)
        y_t = C_t * h_t + D * x_t                           (output + skip)
    
    Where Δ_t, B_t, C_t are computed from the input (selective/data-dependent)
    and A is a learned diagonal matrix. This selectivity is what allows SSMs
    to match transformer quality.
    
    Key insight: discretizing continuous dynamics with a learned, per-token
    step size Δ lets the layer model dependencies over a wide range of timescales.
    """
    
    def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        
        # Input projections for selectivity
        # We project to 2*d_model: one for the "gate" branch, one for the SSM branch
        self.in_proj = nn.Linear(d_model, 2 * d_model, bias=False)
        
        # Local convolution for capturing immediate neighbors (from Mamba)
        self.conv1d = nn.Conv1d(
            d_model, d_model, kernel_size=d_conv, 
            padding=d_conv - 1, groups=d_model, bias=True
        )
        
        # Selective parameters: ∆ (step size), B, C are input-dependent
        # A is a learnable diagonal matrix (log-space for stability)
        self.A_log = nn.Parameter(torch.log(torch.arange(1, d_state + 1, dtype=torch.float32).repeat(d_model, 1)))
        self.D = nn.Parameter(torch.ones(d_model))  # Skip connection
        
        # Input-dependent projections
        self.dt_proj = nn.Linear(d_model, d_model, bias=True)
        self.B_proj = nn.Linear(d_model, d_state, bias=False)
        self.C_proj = nn.Linear(d_model, d_state, bias=False)
        
        # Output projection
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        
        # softplus already keeps dt positive; initializing the pre-softplus bias
        # in [-4, -2] gives small initial step sizes of roughly 0.018 to 0.127
        nn.init.uniform_(self.dt_proj.bias, -4.0, -2.0)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, L, D) input sequence
        Returns: (B, L, D) output sequence
        """
        B, L, D = x.shape
        
        # Split into gate and SSM branches
        xz = self.in_proj(x)  # (B, L, 2D)
        x_ssm, z = xz.chunk(2, dim=-1)  # Each (B, L, D)
        
        # Local convolution (causal)
        x_conv = x_ssm.transpose(1, 2)  # (B, D, L)
        x_conv = self.conv1d(x_conv)[:, :, :L]  # Causal: trim to L
        x_conv = x_conv.transpose(1, 2)  # (B, L, D)
        x_conv = F.silu(x_conv)
        
        # Compute selective parameters
        dt = F.softplus(self.dt_proj(x_conv))  # (B, L, D) - step sizes
        B_sel = self.B_proj(x_conv)  # (B, L, N)
        C_sel = self.C_proj(x_conv)  # (B, L, N)
        
        # Continuous-time A (negative for stability); discretized inside the scan
        A = -torch.exp(self.A_log)  # (D, N)
        
        # Selective scan (sequential reference loop)
        y = self._selective_scan(x_conv, dt, A, B_sel, C_sel)  # (B, L, D)
        
        # Skip connection
        y = y + self.D.unsqueeze(0).unsqueeze(0) * x_conv
        
        # Gating (from Mamba - SiLU gate)
        y = y * F.silu(z)
        
        return self.out_proj(y)
    
    def _selective_scan(self, x, dt, A, B, C):
        """
        Sequential selective scan (reference implementation).
        
        Recurrence, with the input-dependent step size dt_t:
        h_t = exp(A * dt_t) * h_{t-1} + dt_t * B_t * x_t
        y_t = C_t * h_t
        
        This is a plain Python loop over the sequence: O(L) steps, but much
        slower than a fused parallel-scan kernel such as Mamba's.
        """
        B_batch, L, D = x.shape
        N = A.shape[1]
        
        # Discretized A: dA = exp(A * dt), shape (B, L, D, N)
        dt_expanded = dt.unsqueeze(-1)  # (B, L, D, 1)
        A_expanded = A.unsqueeze(0).unsqueeze(0)  # (1, 1, D, N)
        dA = torch.exp(dt_expanded * A_expanded)  # (B, L, D, N)
        
        # Discretized input term: dB * x, shape (B, L, D, N)
        dBx = dt_expanded * B.unsqueeze(2) * x.unsqueeze(-1)  # (B, L, D, N)
        
        # Recurrent scan; for moderate sequence lengths (e.g. 1024) a loop is acceptable
        h = torch.zeros(B_batch, D, N, device=x.device, dtype=x.dtype)
        ys = []
        for t in range(L):
            h = dA[:, t] * h + dBx[:, t]  # (B, D, N)
            ys.append((h * C[:, t].unsqueeze(1)).sum(-1))  # (B, D)
        
        y = torch.stack(ys, dim=1)  # (B, L, D)
        return y
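
# Illustrative sketch (hypothetical helpers, not called by the model): the
# recurrence in SelectiveStateSpace._selective_scan unrolls to
#     h_t = sum_{s <= t} exp(cum_t - cum_s) * dBx_s,  cum_t = sum_{r <= t} dt_r * A,
# so the same outputs can be computed without a Python loop from a cumulative
# sum of log-decays. This closed form builds an (L x L) decay tensor, so it only
# suits short sequences; long sequences would need a chunked or fused
# parallel-scan kernel instead.
import torch


def selective_scan_sequential(x, dt, A, B, C):
    """Loop form, same math as _selective_scan. x, dt: (b, L, D); A: (D, N); B, C: (b, L, N)."""
    dA = torch.exp(dt.unsqueeze(-1) * A)  # (b, L, D, N)
    dBx = dt.unsqueeze(-1) * B.unsqueeze(2) * x.unsqueeze(-1)  # (b, L, D, N)
    h = torch.zeros_like(dBx[:, 0])  # (b, D, N)
    ys = []
    for t in range(x.shape[1]):
        h = dA[:, t] * h + dBx[:, t]
        ys.append((h * C[:, t].unsqueeze(1)).sum(-1))  # (b, D)
    return torch.stack(ys, dim=1)  # (b, L, D)


def selective_scan_parallel(x, dt, A, B, C):
    """Loop-free form via log-space cumulative sums (O(L^2) memory)."""
    L = x.shape[1]
    cum = torch.cumsum(dt.unsqueeze(-1) * A, dim=1)  # (b, L, D, N), non-increasing over L
    dBx = dt.unsqueeze(-1) * B.unsqueeze(2) * x.unsqueeze(-1)  # (b, L, D, N)
    # diff[:, t, s] = cum_t - cum_s = log of the total decay from step s to step t
    diff = cum.unsqueeze(2) - cum.unsqueeze(1)  # (b, L_t, L_s, D, N)
    causal = torch.tril(torch.ones(L, L, dtype=torch.bool, device=x.device))
    diff = diff.masked_fill(~causal[None, :, :, None, None], float("-inf"))  # drop s > t
    h = (torch.exp(diff) * dBx.unsqueeze(1)).sum(dim=2)  # (b, L, D, N)
    return (h * C.unsqueeze(2)).sum(-1)  # (b, L, D)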


# ============================================================================
# Bidirectional Spatial Scanner  
# ============================================================================

class BidirectionalSpatialScanner(nn.Module):
    """
    Scans 2D spatial features in 4 directions to capture full spatial context.
    
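    Innovation: Instead of 4 separate SSMs (expensive), we use 2 SSMs, each run
    on the sequence and on its reversal, and fuse the results with a learned gate.
    
    Directions:
    1. Row-major L→R (horizontal forward)
    2. Row-major R→L (horizontal backward; reversing the whole raster also
       visits rows bottom to top)
    3. Col-major T→B (vertical forward)
    4. Col-major B→T (vertical backward; likewise visits columns right to left)
    
    The four outputs are averaged, and a sigmoid gate blends that average with
    the input per token and per channel. The gate is computed from content only
    and does not weight the four directions separately.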
    """
    
    def __init__(self, d_model: int, d_state: int = 16):
        super().__init__()
        
        # Only 2 SSM instances - we reverse inputs for bidirectional
        self.ssm_horizontal = SelectiveStateSpace(d_model, d_state)
        self.ssm_vertical = SelectiveStateSpace(d_model, d_state)
        
        # Spatial fusion gate - learns to weight directions
        self.fusion_gate = nn.Sequential(
            nn.Linear(d_model, d_model, bias=False),
            nn.Sigmoid()
        )
        
        # Norm for stability
        self.norm = nn.LayerNorm(d_model)
    
    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        """
        x: (B, H*W, D) flattened spatial features
        Returns: (B, H*W, D) with full spatial context
        """
        B, L, D = x.shape
        
        # Horizontal scanning (row-major order)
        x_fwd = self.ssm_horizontal(x)
        x_bwd = self._reverse_scan(x, self.ssm_horizontal)
        
        # Vertical scanning (column-major order)
        x_col = rearrange(x, 'b (h w) d -> b (w h) d', h=H, w=W)
        x_top_down = self.ssm_vertical(x_col)
        x_top_down = rearrange(x_top_down, 'b (w h) d -> b (h w) d', h=H, w=W)
        
        x_bot_up = self._reverse_scan(x_col, self.ssm_vertical)
        x_bot_up = rearrange(x_bot_up, 'b (w h) d -> b (h w) d', h=H, w=W)
        
        # Learned fusion: average the four directions, then gate against the input
        combined = (x_fwd + x_bwd + x_top_down + x_bot_up) / 4.0
        gate = self.fusion_gate(x)
        
        out = gate * combined + (1 - gate) * x
        return self.norm(out)
    
    def _reverse_scan(self, x, ssm):
        """Run `ssm` over the reversed sequence and restore the original order."""
        x_rev = x.flip(dims=[1])
        y_rev = ssm(x_rev)
        return y_rev.flip(dims=[1])
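
# Illustrative sketch (hypothetical helpers, not used by the model) of the four
# scan orders as index permutations of a row-major H x W grid. It makes explicit
# that the scanner's `flip` reverses the whole raster, so "R→L" also visits rows
# bottom to top, and "B→T" visits columns right to left.
import torch


def raster_scan_orders(H: int, W: int) -> dict:
    """Token index visited at each scan step, for each of the four directions."""
    row_major = torch.arange(H * W)
    col_major = row_major.view(H, W).t().reshape(-1)  # column by column, top to bottom
    return {"lr": row_major, "rl": row_major.flip(0),
            "tb": col_major, "bt": col_major.flip(0)}


def scan_in_order(x, order, fn):
    """x: (B, H*W, D). Visit tokens in `order`, apply fn, scatter back to row-major."""
    y = fn(x[:, order])
    out = torch.empty_like(y)
    out[:, order] = y
    return out

# scan_in_order(x, raster_scan_orders(H, W)["tb"], ssm) matches the
# rearrange-based vertical scan in BidirectionalSpatialScanner.forward.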


# ============================================================================
# Mix-FFN with Depthwise Convolution (from SANA)
# ============================================================================

class MixFFN(nn.Module):
    """
    Feed-forward network with depthwise convolution for local feature mixing.
    
    From SANA: "depth-wise convolution enhances the model's ability to capture 
    local information, compensating for the weaker local information-capturing 
    ability of linear attention"
    
    Architecture: Linear → split into (value, gate);
    value → DWConv3x3 → LayerNorm; out = value * GELU(gate) → Linear
    This is an inverted bottleneck with GLU-style gating.
    """
    
    def __init__(self, d_model: int, expand_ratio: float = 2.5):
        super().__init__()
        d_inner = int(d_model * expand_ratio)
        
        # Inverted bottleneck with gating
        self.fc1 = nn.Linear(d_model, d_inner * 2)  # *2 for gating
        self.dwconv = nn.Conv2d(
            d_inner, d_inner, kernel_size=3, padding=1, 
            groups=d_inner, bias=True
        )
        self.fc2 = nn.Linear(d_inner, d_model)
        self.norm = nn.LayerNorm(d_inner)
        
    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        """
        x: (B, H*W, D)
        Returns: (B, H*W, D)
        """
        B, L, D = x.shape
        
        # Split into value and gate
        xg = self.fc1(x)
        x_val, x_gate = xg.chunk(2, dim=-1)  # Each (B, L, d_inner)
        
        # Depthwise conv on value branch (needs 2D reshape)
        x_val = rearrange(x_val, 'b (h w) d -> b d h w', h=H, w=W)
        x_val = self.dwconv(x_val)
        x_val = rearrange(x_val, 'b d h w -> b (h w) d')
        
        # GLU gating
        x_val = self.norm(x_val)
        x_out = x_val * F.gelu(x_gate)
        
        return self.fc2(x_out)
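
# Illustrative check (hypothetical helper, not used by the model) of the locality
# property the Mix-FFN relies on: with kernel_size=3, padding=1 and
# groups=channels (as in MixFFN.dwconv), perturbing one input pixel changes only
# the 3x3 output neighbourhood of that pixel, and only in the same channel.
# Note that it seeds the global RNG so the check is repeatable.
import torch
import torch.nn as nn


def dwconv_changed_positions(H=5, W=5, channels=4, row=2, col=2, seed=0):
    """Return a (channels, H, W) boolean map of outputs affected by a one-pixel change."""
    torch.manual_seed(seed)
    conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1, groups=channels)
    x = torch.randn(1, channels, H, W)
    x_pert = x.clone()
    x_pert[0, 0, row, col] += 1.0  # perturb channel 0 only
    with torch.no_grad():
        delta = conv(x_pert) - conv(x)
    return delta[0].abs() > 1e-5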


# ============================================================================
# Hyper-Connection Module (from the Hyper-Connections paper)
# ============================================================================

class HyperConnection(nn.Module):
    """
    Hyper-connections generalize residual connections.
    
    Instead of fixed: y = x + F(x)
    We learn a connection matrix HC that can represent any blend of
    sequential and parallel layer arrangements.
    
    For expansion rate n (stream indices i, j = 1..n):
        Input: n streams [x_1, ..., x_n], initialized as copies of x
        HC matrix is (n+1) x (n+1), learnable (HC[0, 0] is unused)
        layer_input = LayerNorm(sum_j HC[0, j] * x_j)
        x_i'        = HC[i, 0] * F(layer_input) + sum_j HC[i, j] * x_j
    
    This subsumes both Pre-Norm and Post-Norm residual connections,
    and can learn arrangements that are neither purely sequential nor parallel.
    """
    
    def __init__(self, d_model: int, expansion_rate: int = 2):
        super().__init__()
        self.n = expansion_rate
        self.d_model = d_model
        
        # HC matrix: (n+1) x (n+1) 
        # Initialize close to residual connection
        init_matrix = torch.zeros(self.n + 1, self.n + 1)
        # Standard residual: input goes through, output adds
        init_matrix[0, 1] = 1.0  # layer input comes from first stream
        for i in range(1, self.n + 1):
            init_matrix[i, i] = 1.0  # identity for skip
            init_matrix[i, 0] = 1.0 / self.n  # add layer output
        
        self.hc_matrix = nn.Parameter(init_matrix)
        self.norm = nn.LayerNorm(d_model)
    
    def pre_forward(self, x_streams: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        x_streams: (B, L, n*D) - n parallel streams concatenated
        Returns: (layer_input, x_streams)
        """
        B, L, _ = x_streams.shape
        
        # Split into streams
        streams = x_streams.chunk(self.n, dim=-1)  # List of (B, L, D)
        
        # Compute layer input from the first row of the HC matrix (HC[0, 1:])
        layer_input = sum(self.hc_matrix[0, i + 1] * streams[i] for i in range(self.n))
        layer_input = self.norm(layer_input)
        
        return layer_input, x_streams
    
    def post_forward(self, layer_output: torch.Tensor, x_streams: torch.Tensor) -> torch.Tensor:
        """
        Combine layer output with streams using HC matrix.
        """
        streams = x_streams.chunk(self.n, dim=-1)
        
        new_streams = []
        for i in range(self.n):
            new_stream = self.hc_matrix[i + 1, 0] * layer_output
            for j in range(self.n):
                new_stream = new_stream + self.hc_matrix[i + 1, j + 1] * streams[j]
            new_streams.append(new_stream)
        
        return torch.cat(new_streams, dim=-1)
    
    def init_streams(self, x: torch.Tensor) -> torch.Tensor:
        """Initialize n streams from single input"""
        return x.repeat(1, 1, self.n)
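
# Illustrative sketch (hypothetical helpers, not used by the model) of the
# HyperConnection update written as two matrix products over streams stacked on
# a leading axis (the class concatenates them along channels instead). The
# LayerNorm applied in pre_forward is omitted. With the class's initialization,
# every stream receives x + F(x) / n, i.e. a scaled residual connection.
import torch


def hyper_connection_init(n: int) -> torch.Tensor:
    """Same initial matrix as HyperConnection.__init__."""
    hc = torch.zeros(n + 1, n + 1)
    hc[0, 1] = 1.0              # layer input = first stream
    for i in range(1, n + 1):
        hc[i, i] = 1.0          # keep each stream (identity skip)
        hc[i, 0] = 1.0 / n      # add a 1/n share of the layer output
    return hc


def hyper_connection_step(hc, streams, layer_fn):
    """streams: (n, B, L, D) -> updated streams (n, B, L, D)."""
    layer_input = torch.einsum("j,jbld->bld", hc[0, 1:], streams)  # row 0 mixes streams in
    out = layer_fn(layer_input)
    mixed = torch.einsum("ij,jbld->ibld", hc[1:, 1:], streams)      # stream-to-stream mixing
    return mixed + hc[1:, 0][:, None, None, None] * out             # column 0 distributes F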


# ============================================================================
# AdaLN-Zero Conditioning (from DiT, its best-performing conditioning variant)
# ============================================================================

class AdaLNZero(nn.Module):
    """
    Adaptive Layer Normalization with zero initialization.
    
    Conditions each layer on timestep and text embeddings.
    From DiT: "regresses dimensionwise scale and shift parameters 
    from the sum of the embedding vectors"
    
    Zero initialization ensures the network acts as identity at init,
    critical for training stability.
    """
    
    def __init__(self, d_model: int, d_cond: int):
        super().__init__()
        self.norm = nn.LayerNorm(d_model, elementwise_affine=False)
        
        # Predict shift (β), scale (γ) and gate (α) for each of the two sub-blocks: 6 * d_model outputs
        self.proj = nn.Sequential(
            nn.SiLU(),
            nn.Linear(d_cond, 6 * d_model)
        )
        
        # Zero-initialize the projection
        nn.init.zeros_(self.proj[1].weight)
        nn.init.zeros_(self.proj[1].bias)
    
    def forward(self, x: torch.Tensor, cond: torch.Tensor):
        """
        x: (B, L, D)
        cond: (B, d_cond)
        Returns: shift1, scale1, gate1, shift2, scale2, gate2
        """
        params = self.proj(cond)  # (B, 6D)
        params = params.unsqueeze(1)  # (B, 1, 6D)
        shift1, scale1, gate1, shift2, scale2, gate2 = params.chunk(6, dim=-1)
        return shift1, scale1, gate1, shift2, scale2, gate2
    
    def modulate(self, x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
        return self.norm(x) * (1 + scale) + shift
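
# Illustrative sketch (hypothetical helper, not used by the model) of one
# adaLN-Zero residual sub-block, mirroring AdaLNZero.modulate and the gated
# residuals in LiRABlock.forward: with a zero-initialized projection, shift,
# scale and gate are all zero, so the sub-block returns its input unchanged.
import torch
import torch.nn as nn
import torch.nn.functional as F


def adaln_zero_residual(x, cond, proj, branch):
    """x: (B, L, D); cond: (B, d_cond); proj: Linear(d_cond, 3 * D); branch: D -> D."""
    shift, scale, gate = proj(F.silu(cond)).unsqueeze(1).chunk(3, dim=-1)  # each (B, 1, D)
    h = F.layer_norm(x, x.shape[-1:]) * (1 + scale) + shift
    return x + gate * branch(h)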


# ============================================================================
# LiRA Block: The Core Processing Unit
# ============================================================================

class LiRABlock(nn.Module):
    """
    One LiRA block = Bidirectional SSM + Mix-FFN, with:
    - AdaLN-Zero conditioning
    - Hyper-connections for dynamic layer arrangement (not applied in
      forward(); hc_expansion is currently unused)
    
    This replaces transformer blocks with O(N) complexity, aiming to retain
    the quality of O(N^2) attention through:
    1. Selective state spaces (content-aware)
    2. Bidirectional scanning (full spatial context)
    3. Mix-FFN (local feature enhancement via DWConv)
    """
    
    def __init__(self, d_model: int, d_cond: int, d_state: int = 16, 
                 ffn_expand: float = 2.5, hc_expansion: int = 2):
        super().__init__()
        
        # Conditioning
        self.adaln = AdaLNZero(d_model, d_cond)
        
        # Bidirectional State-Space Scanner
        self.scanner = BidirectionalSpatialScanner(d_model, d_state)
        
        # Mix-FFN for local features
        self.ffn = MixFFN(d_model, ffn_expand)
        
        # Layer norms (unused in forward(); AdaLNZero.modulate supplies the pre-norm)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
    
    def forward(self, x: torch.Tensor, cond: torch.Tensor, H: int, W: int) -> torch.Tensor:
        """
        x: (B, H*W, D)
        cond: (B, d_cond) - conditioning vector (timestep + text)
        Returns: (B, H*W, D)
        """
        # Get conditioning parameters
        shift1, scale1, gate1, shift2, scale2, gate2 = self.adaln(x, cond)
        
        # SSM branch with AdaLN conditioning
        x_mod = self.adaln.modulate(x, shift1, scale1)
        x_ssm = self.scanner(x_mod, H, W)
        x = x + gate1 * x_ssm
        
        # FFN branch with AdaLN conditioning
        x_mod = self.adaln.modulate(x, shift2, scale2)
        x_ffn = self.ffn(x_mod, H, W)
        x = x + gate2 * x_ffn
        
        return x


# ============================================================================
# Cross-Modal Fusion: Text → Image conditioning via Gated Cross-State
# ============================================================================

class GatedCrossStateFusion(nn.Module):
    """
    Novel cross-modal fusion inspired by CrossWKV (from RWKV-7 paper).
    
    Instead of expensive cross-attention (O(N*M) where N=image, M=text tokens),
    we use a state-based cross-modal mechanism:
    
    1. Compress text into a fixed-size per-head state matrix S_text = K^T V / M
       (an averaged key-value outer product, as in linear attention)
    2. Read S_text with per-token image queries and add the result to the
       image features through a learned gate
    3. This gives O(M + N) complexity instead of O(N*M)
    
    Mathematical formulation (per head h, head dimension d_h):
        S_h = (1/M) * sum_m k_{m,h}^T v_{m,h}      → (d_h, d_h) state matrix
        For each image token x_i:
            c_i = q_{i,h} S_h, merged across heads
            out_i = LayerNorm(x_i + G_i * c_i),  G_i = σ(W_g [x_i; c_i])
        Where G_i is a learned, content-dependent gate.
    """
    
    def __init__(self, d_model: int, d_text: int, d_state: int = 16, num_heads: int = 8):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        
        # Text state compression
        self.text_proj = nn.Linear(d_text, d_model)
        self.text_key = nn.Linear(d_model, d_model, bias=False)
        self.text_value = nn.Linear(d_model, d_model, bias=False)
        
        # Image query
        self.image_query = nn.Linear(d_model, d_model, bias=False)
        
        # Gating mechanism
        self.gate = nn.Sequential(
            nn.Linear(d_model * 2, d_model),
            nn.Sigmoid()
        )
        
        # Output projection
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.norm = nn.LayerNorm(d_model)
    
    def forward(self, x_image: torch.Tensor, x_text: torch.Tensor) -> torch.Tensor:
        """
        x_image: (B, N, D) - image features
        x_text: (B, M, D_text) - text features 
        Returns: (B, N, D) - text-conditioned image features
        """
        B, N, D = x_image.shape
        
        # Project text to model dimension
        text_feat = self.text_proj(x_text)  # (B, M, D)
        
        # Summarize the text as a per-head K^T V state averaged over tokens,
        # so its size does not depend on the number of text tokens
        text_k = self.text_key(text_feat)  # (B, M, D)
        text_v = self.text_value(text_feat)  # (B, M, D)
        
        # Reshape to heads
        text_k = rearrange(text_k, 'b m (h d) -> b h m d', h=self.num_heads)
        text_v = rearrange(text_v, 'b m (h d) -> b h m d', h=self.num_heads)
        
        # Compute text state: S = K^T V / M (compressed representation)
        # This is O(M * d^2) which is very small for typical M (77 tokens)
        text_state = torch.einsum('bhmd,bhmk->bhdk', text_k, text_v) / text_k.shape[2]
        
        # Image queries
        img_q = self.image_query(x_image)  # (B, N, D)
        img_q = rearrange(img_q, 'b n (h d) -> b h n d', h=self.num_heads)
        
        # Query the text state: y = Q * S, then merge heads and project
        cross_out = torch.einsum('bhnd,bhdk->bhnk', img_q, text_state)
        cross_out = rearrange(cross_out, 'b h n d -> b n (h d)')
        cross_out = self.out_proj(cross_out)  # (B, N, D)
        
        # Gated fusion
        gate = self.gate(torch.cat([x_image, cross_out], dim=-1))
        out = x_image + gate * cross_out
        
        return self.norm(out)
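
# Illustrative sketch (hypothetical helpers, not used by the model): because
# GatedCrossStateFusion has no softmax, associativity gives Q (K^T V) / M =
# (Q K^T) V / M. The first form builds a (d_h x d_h) state in O(M * d_h^2) and
# reads it in O(N * d_h^2) per head, never forming the N x M score matrix that
# the second form (and softmax cross-attention) needs. Note that forward() takes
# no text mask, so padded text tokens also contribute to the average.
import torch


def cross_state_readout(q, k, v):
    """q: (B, H, N, d); k, v: (B, H, M, d) -> (B, H, N, d), via the K^T V state."""
    state = torch.einsum("bhmd,bhmk->bhdk", k, v) / k.shape[2]  # (B, H, d, d)
    return torch.einsum("bhnd,bhdk->bhnk", q, state)


def cross_pairwise_readout(q, k, v):
    """Same result via the (N x M) score matrix, for comparison only."""
    scores = torch.einsum("bhnd,bhmd->bhnm", q, k) / k.shape[2]  # (B, H, N, M)
    return torch.einsum("bhnm,bhmk->bhnk", scores, v)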


# ============================================================================
# Latent Reasoning Loop (The Novel Core Innovation)
# ============================================================================

class LatentReasoningLoop(nn.Module):
    """
    NOVEL CONTRIBUTION: Iterative reasoning in latent space for image generation.
    
    Inspired by Liquid Reasoning Transformer (LRT), but adapted for generative models.
    
    Key insight: Image generation benefits from iterative refinement. Instead of
    a fixed number of denoising steps (expensive), we add a CHEAP inner reasoning
    loop that refines the latent representation before final prediction.
    
    How it works:
    1. A "reasoning state" r_t evolves over T_think iterations
    2. Each iteration applies a lightweight SSM + FFN to refine r_t
    3. A DISCARD GATE filters bad updates (prevents error accumulation)
    4. A STOP GATE halts early for easy inputs (adaptive compute)
    5. The final r_T is used to condition the denoising prediction
    
    This gives the model "thinking time" proportional to input difficulty:
    - Simple prompts / high noise levels → few reasoning steps
    - Complex prompts / fine detail refinement → more reasoning steps
    
    Mathematical formulation:
        r_0 = MLP(mean_L(x))                            # x: latent tokens + conditioning
        For t in 1..T_max:
            p_t = SSM_think(LN(r_{t-1}))                # r_{t-1} repeated at all L positions
            u_t = MLP(mean_L(p_t))                       # candidate update
            d_t = σ(W_d [r_{t-1}; u_t])                   # discard gate
            r_t = (1-d_t) * u_t + d_t * r_{t-1}           # filtered update
            s_t = σ(W_s r_t)                               # stop gate
            if not training and s_t > τ for every sample: break
    
    Cost: T_think iterations of a SMALL network (1/10th of main backbone)
    T_think: up to max_steps (default 8); training always runs every step, while
    inference halts once each sample's stop gate exceeds τ.
    """
    
    def __init__(self, d_model: int, d_reason: int = 128, max_steps: int = 8):
        super().__init__()
        self.d_reason = d_reason
        self.max_steps = max_steps
        
        # Initialize reasoning state from input
        self.state_init = nn.Sequential(
            nn.Linear(d_model, d_reason * 2),
            nn.GELU(),
            nn.Linear(d_reason * 2, d_reason)
        )
        
        # Lightweight reasoning block (intentionally small)
        self.reason_ssm = SelectiveStateSpace(d_reason, d_state=8, d_conv=3)
        self.reason_ffn = nn.Sequential(
            nn.Linear(d_reason, d_reason * 2),
            nn.GELU(),
            nn.Linear(d_reason * 2, d_reason)
        )
        self.reason_norm = nn.LayerNorm(d_reason)
        
        # Discard gate: reject bad updates
        self.discard_gate = nn.Sequential(
            nn.Linear(d_reason * 2, d_reason),
            nn.Sigmoid()
        )
        
        # Stop gate: halt when converged
        self.stop_gate = nn.Sequential(
            nn.Linear(d_reason, 1),
            nn.Sigmoid()
        )
        self.stop_threshold = 0.8  # fixed inference-time halting threshold (not learned)
        
        # Project reasoning state back to condition the main network
        self.reason_proj = nn.Linear(d_reason, d_model)
    
    def forward(self, x: torch.Tensor, return_steps: bool = False) -> Tuple[torch.Tensor, dict]:
        """
        x: (B, L, D) - input features (latent tokens + conditioning)
        Returns: (B, D_model) reasoning conditioning vector, info dict
        """
        B = x.shape[0]
        
        # Initialize reasoning state from global average of input
        x_global = x.mean(dim=1)  # (B, D)
        r = self.state_init(x_global)  # (B, d_reason)
        
        info = {'steps': [], 'discard_rates': [], 'stop_values': []}
        
        # Iterative reasoning loop
        total_steps = 0
        for step in range(self.max_steps):
            # Expand reasoning state and process with SSM
            r_expanded = r.unsqueeze(1).expand(-1, x.shape[1], -1)  # (B, L, d_reason)
            
            # Lightweight processing
            r_processed = self.reason_ssm(self.reason_norm(r_expanded))
            r_proposal = self.reason_ffn(r_processed.mean(dim=1))  # (B, d_reason)
            
            # Discard gate
            d = self.discard_gate(torch.cat([r, r_proposal], dim=-1))
            r_new = d * r + (1 - d) * r_proposal
            
            # Stop gate
            s = self.stop_gate(r_new).squeeze(-1)  # (B,)
            
            info['discard_rates'].append(d.mean().item())
            info['stop_values'].append(s.mean().item())
            
            r = r_new
            total_steps += 1
            
            # In inference, stop if all batch elements want to stop
            if not self.training and (s > self.stop_threshold).all():
                break
        
        info['total_steps'] = total_steps
        
        # Project to conditioning dimension
        cond = self.reason_proj(r)  # (B, D_model)
        return cond, info
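
# Illustrative sketch (hypothetical helper, not used by the model) of the
# control flow in LatentReasoningLoop.forward with the learned parts replaced by
# callables: the discard gate d keeps the old state where d is near 1, and the
# stop gate halts early only at inference, once every sample exceeds the threshold.
import torch


def gated_refinement(r, propose, discard, stop, max_steps=8, threshold=0.8, training=False):
    """r: (B, d). propose(r) -> u; discard(r, u) -> d in [0, 1]; stop(r) -> (B,) in [0, 1]."""
    steps = 0
    for _ in range(max_steps):
        u = propose(r)
        d = discard(r, u)
        r = d * r + (1 - d) * u  # d near 1 rejects the proposal
        s = stop(r)
        steps += 1
        if not training and bool((s > threshold).all()):
            break
    return r, steps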


# ============================================================================
# Timestep + Text Embedding
# ============================================================================

class TimestepEmbedding(nn.Module):
    """
    Sinusoidal timestep embedding with MLP projection.
    Standard approach from DDPM, except that the continuous t in [0, 1] used in
    flow matching is multiplied by 1000 so it spans the same range as DDPM's
    integer timesteps (otherwise the low-frequency features barely vary).
    """
    
    def __init__(self, d_model: int, max_period: int = 10000):
        super().__init__()
        self.d_model = d_model
        self.max_period = max_period
        
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.SiLU(),
            nn.Linear(d_model * 4, d_model)
        )
    
    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """
        t: (B,) timestep values in [0, 1]
        Returns: (B, d_model)
        """
        half_dim = self.d_model // 2
        freqs = torch.exp(
            -math.log(self.max_period) * torch.arange(half_dim, device=t.device).float() / half_dim
        )
        args = t.unsqueeze(1) * freqs.unsqueeze(0) * 1000  # Scale for better range
        embedding = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
        
        if self.d_model % 2:
            embedding = F.pad(embedding, (0, 1))
        
        return self.mlp(embedding)
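
# Illustrative sketch (hypothetical helper, not used by the model) of the
# sinusoidal features TimestepEmbedding computes before its MLP. With t in
# [0, 1] and no rescaling, the lowest frequencies barely move (for example
# sin(1e-3 * t) stays near 0), which is why forward() multiplies t by 1000.
import math
import torch
import torch.nn.functional as F


def sinusoidal_embedding(t, dim, max_period=10000, t_scale=1000.0):
    """t: (B,) -> (B, dim). Same features as TimestepEmbedding before its MLP."""
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t.float().unsqueeze(1) * freqs.unsqueeze(0) * t_scale  # (B, half)
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
    if dim % 2:
        emb = F.pad(emb, (0, 1))  # zero column keeps the output width equal to dim
    return emb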


class TextProjection(nn.Module):
    """
    Projects text encoder outputs to model dimension.
    Supports variable-length text with a pooled global + per-token output.
    """
    
    def __init__(self, d_text: int, d_model: int):
        super().__init__()
        self.proj = nn.Linear(d_text, d_model)
        self.pool_proj = nn.Linear(d_text, d_model)
        self.norm = nn.LayerNorm(d_model)
    
    def forward(self, text_features: torch.Tensor, text_mask: Optional[torch.Tensor] = None):
        """
        text_features: (B, M, D_text)
        text_mask: (B, M) boolean mask
        Returns: per_token (B, M, D), pooled (B, D)
        """
        per_token = self.norm(self.proj(text_features))
        
        if text_mask is not None:
            # Masked mean pooling
            mask = text_mask.unsqueeze(-1).float()
            pooled = (text_features * mask).sum(1) / mask.sum(1).clamp(min=1)
        else:
            pooled = text_features.mean(dim=1)
        
        pooled = self.pool_proj(pooled)
        return per_token, pooled