asdf98 committed
Commit 18ce5a6 · verified · 1 Parent(s): f5c1b06

Add lira/model.py

Files changed (1): lira/model.py (+527, -0)
lira/model.py ADDED
@@ -0,0 +1,527 @@
"""
LiRA Model: Full Architecture

Architecture Overview (Denoising Network):
==========================================

Input: z_t (noisy latent, B x C x H x W) + t (timestep) + text_features
             |
             v
┌─────────────────────────┐
│ Patch Embedding         │  Conv2d(C_lat, D, 1x1) - patchify
│ + Freq Decomposition    │  Optional: Haar wavelet split
└────────────┬────────────┘
             │
             v
┌─────────────────────────┐
│ Latent Reasoning Loop   │  2-8 adaptive steps (learned)
│ (generates reasoning    │  → produces reasoning conditioning
│  conditioning vector)   │  Only ~128 dims, very cheap
└────────────┬────────────┘
             │  reasoning_cond + timestep_embed + text_pooled
             │  → combined conditioning vector
             v
┌─────────────────────────┐
│ N x LiRA Blocks         │  Each block:
│ (with HyperConnections) │   1. AdaLN conditioning
│                         │   2. Bidirectional SSM (4-dir scan)
│ Every K blocks:         │   3. Mix-FFN (DWConv + GLU)
│ → GatedCrossStateFusion │   4. Hyper-connection routing
└────────────┬────────────┘
             │
             v
┌─────────────────────────┐
│ Final Norm + Proj       │  LayerNorm → Linear(D, C_lat)
│ → velocity prediction   │  Predicts v = ε - x_0
└─────────────────────────┘

Model Sizes:
- LiRA-Tiny:  D=384,  N=12, ~50M params  (for testing)
- LiRA-Small: D=512,  N=20, ~120M params (mobile-optimized)
- LiRA-Base:  D=768,  N=28, ~300M params (quality-optimized)
- LiRA-Large: D=1024, N=36, ~600M params (maximum quality)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Dict, Tuple
from einops import rearrange

from .core_modules import (
    LiRABlock,
    GatedCrossStateFusion,
    LatentReasoningLoop,
    TimestepEmbedding,
    TextProjection,
    HyperConnection,
)


# ============================================================================
# Patch Embedding for Latent Space
# ============================================================================

class LatentPatchEmbed(nn.Module):
    """
    Embeds latent space patches into model dimension.

    For DC-AE f32: latent is 32x32 for 1024px image, with 32 channels
    For SD3/FLUX f8: latent is 128x128 for 1024px, with 16 channels

    We use simple 1x1 conv (no spatial patchify) since the VAE already
    provides heavy spatial compression. Additional patching would lose
    spatial resolution in the latent space.

    However, for f8 VAEs (128x128 = 16384 tokens), we optionally use
    2x2 patches to reduce to 64x64 = 4096 tokens.
    """

    def __init__(self, in_channels: int, d_model: int, patch_size: int = 1):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(
            in_channels, d_model,
            kernel_size=patch_size, stride=patch_size
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
        """
        x: (B, C, H, W) latent features
        Returns: (B, H'*W', D), H', W'
        """
        x = self.proj(x)  # (B, D, H', W')
        B, D, H, W = x.shape
        x = rearrange(x, 'b d h w -> b (h w) d')
        x = self.norm(x)
        return x, H, W


class LatentUnpatch(nn.Module):
    """Reverse of LatentPatchEmbed: project back and reshape"""

    def __init__(self, d_model: int, out_channels: int, patch_size: int = 1):
        super().__init__()
        self.patch_size = patch_size
        self.out_channels = out_channels
        self.norm = nn.LayerNorm(d_model)

        if patch_size > 1:
            # Use pixel shuffle for upsampling
            self.proj = nn.Linear(d_model, out_channels * patch_size * patch_size)
        else:
            self.proj = nn.Linear(d_model, out_channels)

    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        """
        x: (B, H'*W', D)
        Returns: (B, C, H_orig, W_orig)
        """
        x = self.norm(x)
        x = self.proj(x)  # (B, H'*W', C*p*p)

        x = rearrange(x, 'b (h w) d -> b d h w', h=H, w=W)

        if self.patch_size > 1:
            x = F.pixel_shuffle(x, self.patch_size)

        return x
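

# ----------------------------------------------------------------------------
# Usage sketch (illustrative, hypothetical helper): round-trips a DC-AE
# f32c32-style latent through LatentPatchEmbed / LatentUnpatch to show how
# token and spatial shapes line up. Tensor values are random placeholders.
def _example_patch_roundtrip():
    embed = LatentPatchEmbed(in_channels=32, d_model=512, patch_size=1)
    unpatch = LatentUnpatch(d_model=512, out_channels=32, patch_size=1)
    z = torch.randn(2, 32, 32, 32)      # (B, C_lat, H_lat, W_lat)
    tokens, H, W = embed(z)             # tokens: (2, 1024, 512), H = W = 32
    z_rec = unpatch(tokens, H, W)       # back to (2, 32, 32, 32)
    assert z_rec.shape == z.shape
    return tokens.shape, z_rec.shape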


# ============================================================================
# LiRA Denoising Network
# ============================================================================

class LiRAModel(nn.Module):
    """
    LiRA: Liquid Reasoning Artisan - Main Denoising Network

    Novel architecture combining:
    1. State-space backbone (O(N) complexity)
    2. Latent reasoning loop (adaptive compute)
    3. Hyper-connections (dynamic layer arrangement)
    4. Gated cross-state text fusion (efficient cross-modal)
    5. Mix-FFN (local feature enhancement)

    Designed for mobile deployment:
    - No quadratic attention anywhere
    - All operations are O(N) in sequence length
    - Compact parameter count (<400M for Base)
    - Native 1024px via f32 VAE (32x32 = 1024 tokens)
    """

    # Predefined configurations
    CONFIGS = {
        'tiny': {
            'd_model': 384, 'n_blocks': 12, 'd_state': 8,
            'd_reason': 96, 'max_reason_steps': 4,
            'ffn_expand': 2.0, 'cross_every': 4,
            'hc_expansion': 2, 'num_heads': 6,
        },
        'small': {
            'd_model': 512, 'n_blocks': 20, 'd_state': 16,
            'd_reason': 128, 'max_reason_steps': 6,
            'ffn_expand': 2.5, 'cross_every': 4,
            'hc_expansion': 2, 'num_heads': 8,
        },
        'base': {
            'd_model': 768, 'n_blocks': 28, 'd_state': 16,
            'd_reason': 192, 'max_reason_steps': 8,
            'ffn_expand': 2.5, 'cross_every': 4,
            'hc_expansion': 2, 'num_heads': 12,
        },
        'large': {
            'd_model': 1024, 'n_blocks': 36, 'd_state': 16,
            'd_reason': 256, 'max_reason_steps': 8,
            'ffn_expand': 3.0, 'cross_every': 4,
            'hc_expansion': 2, 'num_heads': 16,
        },
    }

    def __init__(
        self,
        config_name: str = 'small',
        in_channels: int = 32,   # DC-AE f32c32 latent channels
        d_text: int = 768,       # Text encoder dimension (CLIP or small LLM)
        patch_size: int = 1,     # Patch size for latent tokens
        **kwargs
    ):
        super().__init__()

        # Get config
        if config_name in self.CONFIGS:
            config = {**self.CONFIGS[config_name], **kwargs}
        else:
            config = kwargs

        self.d_model = config['d_model']
        self.n_blocks = config['n_blocks']
        self.d_state = config['d_state']
        self.d_reason = config['d_reason']
        self.cross_every = config['cross_every']
        self.in_channels = in_channels

        d_cond = self.d_model  # Conditioning dimension

        # ====== Input Processing ======
        self.patch_embed = LatentPatchEmbed(in_channels, self.d_model, patch_size)
        self.unpatch = LatentUnpatch(self.d_model, in_channels, patch_size)

        # ====== Conditioning ======
        self.time_embed = TimestepEmbedding(self.d_model)
        self.text_proj = TextProjection(d_text, self.d_model)

        # Combine timestep + text pooled + reasoning into single conditioning vector
        self.cond_combine = nn.Sequential(
            nn.Linear(self.d_model * 3, self.d_model * 2),
            nn.SiLU(),
            nn.Linear(self.d_model * 2, self.d_model)
        )

        # ====== Latent Reasoning Loop ======
        self.reasoning = LatentReasoningLoop(
            self.d_model, config['d_reason'], config['max_reason_steps']
        )

        # ====== Main Backbone: LiRA Blocks ======
        self.blocks = nn.ModuleList()
        self.cross_fusions = nn.ModuleDict()

        for i in range(self.n_blocks):
            self.blocks.append(LiRABlock(
                d_model=self.d_model,
                d_cond=d_cond,
                d_state=self.d_state,
                ffn_expand=config['ffn_expand'],
                hc_expansion=config['hc_expansion'],
            ))

            # Add cross-modal fusion every K blocks
            if (i + 1) % self.cross_every == 0:
                self.cross_fusions[str(i)] = GatedCrossStateFusion(
                    self.d_model, self.d_model, self.d_state, config['num_heads']
                )

        # ====== Long Skip Connection (from U-ViT / DiM) ======
        # Connect block i with block (n_blocks - 1 - i) via learned projection
        self.n_skip = self.n_blocks // 2
        self.skip_projs = nn.ModuleList([
            nn.Linear(self.d_model * 2, self.d_model)
            for _ in range(self.n_skip)
        ])

        # ====== Output ======
        self.final_norm = nn.LayerNorm(self.d_model)
        self.final_adaln = nn.Sequential(
            nn.SiLU(),
            nn.Linear(d_cond, 2 * self.d_model)
        )

        # Initialize weights
        self._init_weights()

        # Zero-init the final AdaLN projection *after* the global init, so that
        # _init_weights() does not overwrite it with trunc_normal_ values.
        nn.init.zeros_(self.final_adaln[1].weight)
        nn.init.zeros_(self.final_adaln[1].bias)

    def _init_weights(self):
        """Careful weight initialization for training stability"""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.Conv2d, nn.Conv1d)):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(
        self,
        z_t: torch.Tensor,                          # (B, C, H, W) noisy latent
        t: torch.Tensor,                            # (B,) timestep in [0, 1]
        text_features: torch.Tensor,                # (B, M, D_text) text encoder output
        text_mask: Optional[torch.Tensor] = None,   # (B, M) mask
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Forward pass: predicts velocity v_t = ε - x_0

        Returns:
            v_pred: (B, C, H, W) predicted velocity
            info: dict with reasoning stats
        """
        B = z_t.shape[0]

        # ====== Embed inputs ======
        x, H, W = self.patch_embed(z_t)  # (B, N, D)
        t_emb = self.time_embed(t)       # (B, D)
        text_tokens, text_pooled = self.text_proj(text_features, text_mask)  # (B, M, D), (B, D)

        # ====== Latent Reasoning ======
        reason_cond, reason_info = self.reasoning(x)  # (B, D)

        # ====== Combine conditioning ======
        cond = self.cond_combine(torch.cat([t_emb, text_pooled, reason_cond], dim=-1))  # (B, D)

        # ====== Main backbone with long skip connections ======
        skip_features = []

        for i, block in enumerate(self.blocks):
            # Store features for skip connections (first half)
            if i < self.n_skip:
                skip_features.append(x)

            # Apply LiRA block
            x = block(x, cond, H, W)

            # Apply cross-modal fusion
            if str(i) in self.cross_fusions:
                x = self.cross_fusions[str(i)](x, text_tokens)

            # Apply skip connections (second half)
            if i >= self.n_skip:
                skip_idx = self.n_blocks - 1 - i
                if skip_idx < len(skip_features):
                    x = self.skip_projs[skip_idx](
                        torch.cat([x, skip_features[skip_idx]], dim=-1)
                    )

        # ====== Output projection ======
        shift, scale = self.final_adaln(cond).unsqueeze(1).chunk(2, dim=-1)
        x = self.final_norm(x) * (1 + scale) + shift

        v_pred = self.unpatch(x, H, W)  # (B, C, H_orig, W_orig)

        return v_pred, reason_info

    @torch.no_grad()
    def count_parameters(self) -> Dict[str, int]:
        """Count parameters by component"""
        counts = {}
        counts['patch_embed'] = sum(p.numel() for p in self.patch_embed.parameters())
        counts['unpatch'] = sum(p.numel() for p in self.unpatch.parameters())
        counts['time_embed'] = sum(p.numel() for p in self.time_embed.parameters())
        counts['text_proj'] = sum(p.numel() for p in self.text_proj.parameters())
        counts['reasoning'] = sum(p.numel() for p in self.reasoning.parameters())
        counts['blocks'] = sum(p.numel() for p in self.blocks.parameters())
        counts['cross_fusions'] = sum(p.numel() for p in self.cross_fusions.parameters())
        counts['skip_projs'] = sum(p.numel() for p in self.skip_projs.parameters())
        counts['conditioning'] = sum(p.numel() for p in self.cond_combine.parameters())
        counts['output'] = (
            sum(p.numel() for p in self.final_norm.parameters()) +
            sum(p.numel() for p in self.final_adaln.parameters())
        )
        counts['total'] = sum(p.numel() for p in self.parameters())
        return counts
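

# ----------------------------------------------------------------------------
# Usage sketch (illustrative, hypothetical helper): a forward-pass smoke test
# for LiRAModel at 1024px with a DC-AE f32c32 latent (32 channels, 32x32) and
# CLIP-sized text features. Assumes lira.core_modules provides the imported
# blocks with the interfaces used above; all tensor values are dummies.
def _example_denoiser_forward():
    model = LiRAModel(config_name='tiny', in_channels=32, d_text=768)
    z_t = torch.randn(1, 32, 32, 32)    # noisy latent (B, C, H, W)
    t = torch.rand(1)                   # timestep in [0, 1]
    text = torch.randn(1, 77, 768)      # (B, M, D_text) text encoder output
    v_pred, info = model(z_t, t, text)  # v_pred: (1, 32, 32, 32)
    total = model.count_parameters()['total']
    return v_pred.shape, info, total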


# ============================================================================
# Tiny VAE Decoder for Mobile Deployment
# ============================================================================

class TinyVAEDecoder(nn.Module):
    """
    Ultra-lightweight VAE decoder inspired by SnapGen's tiny decoder.

    Key optimizations:
    1. NO attention layers (saves massive memory)
    2. Depthwise separable convolutions instead of full convolutions
    3. Minimal GroupNorm (only where needed to prevent color shift)
    4. PixelShuffle for upsampling (more efficient than transposed conv)

    For f32 VAE: 32x32 latent → 1024x1024 image (5 upsampling stages)
    For f8 VAE:  128x128 latent → 1024x1024 image (3 upsampling stages)

    Target: ~1.5M parameters, <5MB on disk
    """

    def __init__(
        self,
        in_channels: int = 32,
        out_channels: int = 3,
        spatial_compression: int = 32,  # 32 for f32, 8 for f8
        base_channels: int = 64,
    ):
        super().__init__()

        num_upsample = int(math.log2(spatial_compression))  # 5 for f32, 3 for f8

        layers = []

        # Initial projection
        layers.append(nn.Conv2d(in_channels, base_channels, 3, padding=1))
        layers.append(nn.SiLU())

        # Upsampling stages - track channels carefully
        current_ch = base_channels
        for i in range(num_upsample):
            # Gradually reduce channels in later (higher-res) stages
            target_ch = max(base_channels // (2 ** max(0, i)), 16)

            # Depthwise separable residual block
            layers.append(SepConvBlock(current_ch, target_ch))
            current_ch = target_ch

            # PixelShuffle upsample (2x): needs ch*4 input, outputs ch
            layers.append(nn.Conv2d(current_ch, current_ch * 4, 3, padding=1))
            layers.append(nn.PixelShuffle(2))  # ch*4 → ch, spatial 2x
            layers.append(nn.SiLU())
            # After PixelShuffle, channels stay at current_ch

        # Final output
        layers.append(nn.Conv2d(current_ch, out_channels, 3, padding=1))
        layers.append(nn.Tanh())  # Output in [-1, 1]

        self.decoder = nn.Sequential(*layers)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        z: (B, C_lat, H_lat, W_lat) latent
        Returns: (B, 3, H_img, W_img) decoded image
        """
        return self.decoder(z)


class SepConvBlock(nn.Module):
    """Depthwise separable convolution block"""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.dwconv = nn.Conv2d(in_ch, in_ch, 3, padding=1, groups=in_ch)
        self.pwconv = nn.Conv2d(in_ch, out_ch, 1)
        self.norm = nn.GroupNorm(min(8, out_ch), out_ch)
        self.act = nn.SiLU()
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x):
        residual = self.skip(x)
        x = self.dwconv(x)
        x = self.pwconv(x)
        x = self.norm(x)
        x = self.act(x)
        return x + residual
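

# ----------------------------------------------------------------------------
# Usage sketch (illustrative, hypothetical helper): decoding an f32 latent
# with TinyVAEDecoder. With spatial_compression=32 the decoder stacks five 2x
# PixelShuffle stages, so a 32x32 latent comes out as a 1024x1024 RGB image.
def _example_tiny_decode():
    decoder = TinyVAEDecoder(in_channels=32, spatial_compression=32)
    z = torch.randn(1, 32, 32, 32)  # (B, C_lat, 32, 32)
    img = decoder(z)                # (1, 3, 1024, 1024), values in [-1, 1]
    return img.shape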


# ============================================================================
# Complete LiRA Pipeline
# ============================================================================

class LiRAPipeline(nn.Module):
    """
    Complete LiRA pipeline combining:
    1. Pretrained VAE encoder (frozen) - encodes images to latent space
       (only needed at training time; not instantiated on this module)
    2. LiRA denoising network - the novel architecture
    3. Tiny VAE decoder - for mobile deployment

    During training:
        image → VAE_encoder → z_0 → add_noise(z_0, t) → z_t → LiRA → v_pred

    During inference:
        noise → iterative_denoise(LiRA) → z_0 → TinyVAEDecoder → image
    """

    def __init__(
        self,
        config_name: str = 'small',
        latent_channels: int = 32,
        spatial_compression: int = 32,
        d_text: int = 768,
        patch_size: int = 1,
    ):
        super().__init__()

        self.spatial_compression = spatial_compression
        self.latent_channels = latent_channels

        # Denoising network
        self.denoiser = LiRAModel(
            config_name=config_name,
            in_channels=latent_channels,
            d_text=d_text,
            patch_size=patch_size,
        )

        # Tiny decoder for mobile inference
        self.tiny_decoder = TinyVAEDecoder(
            in_channels=latent_channels,
            spatial_compression=spatial_compression,
        )

    def forward(self, *args, **kwargs):
        return self.denoiser(*args, **kwargs)

    def count_parameters(self):
        counts = self.denoiser.count_parameters()
        counts['tiny_decoder'] = sum(p.numel() for p in self.tiny_decoder.parameters())
        counts['total_with_decoder'] = counts['total'] + counts['tiny_decoder']
        return counts
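

# ----------------------------------------------------------------------------
# Usage sketch (illustrative, hypothetical helper): building the full pipeline
# and printing the per-component parameter breakdown from count_parameters().
# Assumes lira.core_modules is importable; the numbers depend on its modules.
def _example_pipeline_params():
    pipe = LiRAPipeline(config_name='small', latent_channels=32,
                        spatial_compression=32)
    for name, n in pipe.count_parameters().items():
        print(f"{name:>20s}: {n / 1e6:.2f}M")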


# ============================================================================
# Helper: Estimate memory usage
# ============================================================================

def estimate_memory_mb(model: nn.Module, batch_size: int = 1,
                       img_size: int = 1024, spatial_compression: int = 32,
                       latent_channels: int = 32, dtype_bytes: int = 2):
    """Estimate inference memory usage in MB (weights plus latent-sized buffers)"""
    # Model parameters
    param_bytes = sum(p.numel() * dtype_bytes for p in model.parameters())
    param_mb = param_bytes / (1024 ** 2)

    # Latent size
    lat_h = img_size // spatial_compression
    lat_w = img_size // spatial_compression
    latent_bytes = batch_size * latent_channels * lat_h * lat_w * dtype_bytes

    # Intermediate activations (very rough: 3x latent; backbone activations at
    # d_model width are not counted, so treat the result as a lower bound)
    activation_bytes = latent_bytes * 3

    total_mb = param_mb + (latent_bytes + activation_bytes) / (1024 ** 2)

    return {
        'params_mb': param_mb,
        'latent_mb': latent_bytes / (1024 ** 2),
        'activation_mb': activation_bytes / (1024 ** 2),
        'total_inference_mb': total_mb,
    }
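

# ----------------------------------------------------------------------------
# Usage sketch (illustrative, hypothetical helper): estimating fp16 inference
# memory for the small pipeline at 1024px. As noted above, the activation term
# only counts latent-sized buffers, so the result should be read as a floor.
def _example_memory_estimate():
    pipe = LiRAPipeline(config_name='small')
    est = estimate_memory_mb(pipe, batch_size=1, img_size=1024,
                             spatial_compression=32, latent_channels=32,
                             dtype_bytes=2)
    for key, mb in est.items():
        print(f"{key}: {mb:.1f} MB")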