"""
LiRA Model: Full Architecture

Architecture Overview (Denoising Network):
==========================================

Input: z_t (noisy latent, B x C x H x W) + t (timestep) + text_features
    |
    v
┌─────────────────────────┐
│  Patch Embedding        │  Conv2d(C_lat, D, 1x1) - patchify
│  + Freq Decomposition   │  Optional: Haar wavelet split
└────────────┬────────────┘
             │
             v
┌─────────────────────────┐
│  Latent Reasoning Loop  │  2-8 adaptive steps (learned)
│  (generates reasoning   │  → produces reasoning conditioning
│   conditioning vector)  │  Only ~128 dims, very cheap
└────────────┬────────────┘
             │ reasoning_cond + timestep_embed + text_pooled
             │ → combined conditioning vector
             v
┌─────────────────────────┐
│  N x LiRA Blocks        │  Each block:
│  (with HyperConnections)│    1. AdaLN conditioning
│                         │    2. Bidirectional SSM (4-dir scan)
│  Every K blocks:        │    3. Mix-FFN (DWConv + GLU)
│  → GatedCrossStateFusion│    4. Hyper-connection routing
└────────────┬────────────┘
             │
             v
┌─────────────────────────┐
│  Final Norm + Proj      │  LayerNorm → Linear(D, C_lat)
│  → velocity prediction  │  Predicts v = ε - x_0
└─────────────────────────┘

Model Sizes:
  - LiRA-Tiny:  D=384,  N=12, ~50M params  (for testing)
  - LiRA-Small: D=512,  N=20, ~120M params (mobile-optimized)
  - LiRA-Base:  D=768,  N=28, ~300M params (quality-optimized)
  - LiRA-Large: D=1024, N=36, ~600M params (maximum quality)

NOTE(review): the optional Haar wavelet split shown above is not
implemented in this file; LatentPatchEmbed is a plain strided Conv2d.
The output head in LiRAModel.forward is: final LayerNorm → AdaLN
shift/scale from the conditioning vector → LatentUnpatch (its own
LayerNorm → Linear → reshape / pixel-shuffle).
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Dict, Tuple
from einops import rearrange

from .core_modules import (
    LiRABlock,
    GatedCrossStateFusion,
    LatentReasoningLoop,
    TimestepEmbedding,
    TextProjection,
    HyperConnection,
)


# ============================================================================
# Patch Embedding for Latent Space
# ============================================================================

class LatentPatchEmbed(nn.Module):
    """
    Embeds latent space patches into model dimension.

    For DC-AE f32: latent is 32x32 for 1024px image, with 32 channels
    For SD3/FLUX f8: latent is 128x128 for 1024px, with 16 channels

    We use simple 1x1 conv (no spatial patchify) since the VAE already
    provides heavy spatial compression. Additional patching would lose
    spatial resolution in the latent space. However, for f8 VAEs
    (128x128 = 16384 tokens), we optionally use 2x2 patches to reduce
    to 64x64 = 4096 tokens.

    Args:
        in_channels: number of latent channels C.
        d_model: width D of each output token.
        patch_size: side length p of the square, non-overlapping patches
            (implemented as a Conv2d with kernel_size = stride = p).
    """

    def __init__(self, in_channels: int, d_model: int, patch_size: int = 1):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(
            in_channels, d_model, kernel_size=patch_size, stride=patch_size
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
        """
        Args:
            x: (B, C, H, W) latent features; H and W must be multiples of
                ``patch_size``.

        Returns:
            ``(tokens, H', W')`` where tokens is (B, H'*W', D),
            H' = H // patch_size and W' = W // patch_size.

        Raises:
            ValueError: if H or W is not divisible by ``patch_size``.
        """
        h_in, w_in = x.shape[-2:]
        # A strided conv silently drops leftover rows/cols; LatentUnpatch would
        # then return a smaller tensor than the latent that was fed in.
        if h_in % self.patch_size or w_in % self.patch_size:
            raise ValueError(
                f"latent spatial size ({h_in}x{w_in}) must be divisible by "
                f"patch_size={self.patch_size}"
            )
        x = self.proj(x)  # (B, D, H', W')
        H, W = x.shape[-2:]
        x = rearrange(x, 'b d h w -> b (h w) d')
        x = self.norm(x)
        return x, H, W


class LatentUnpatch(nn.Module):
    """
    Reverse of LatentPatchEmbed: project tokens back to latent channels and
    reshape to a (B, C, H, W) map.

    Args:
        d_model: token width D.
        out_channels: number of latent channels C to produce.
        patch_size: patch side p; when p > 1 each token is projected to
            C*p*p values and spread over a p x p block via pixel shuffle.
    """

    def __init__(self, d_model: int, out_channels: int, patch_size: int = 1):
        super().__init__()
        self.patch_size = patch_size
        self.out_channels = out_channels
        # NOTE(review): LiRAModel already applies final_norm + AdaLN modulation
        # before calling this module, so this LayerNorm re-normalizes the
        # modulated features per token. Kept for checkpoint compatibility;
        # confirm it is intended.
        self.norm = nn.LayerNorm(d_model)
        # For patch_size == 1 this is simply Linear(d_model, out_channels);
        # for patch_size > 1 the extra p*p factor feeds the pixel shuffle.
        self.proj = nn.Linear(d_model, out_channels * patch_size * patch_size)

    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        """
        Args:
            x: (B, H'*W', D) tokens.
            H, W: token-grid height/width H', W' returned by LatentPatchEmbed.

        Returns:
            (B, C, H'*patch_size, W'*patch_size) latent map.
        """
        x = self.norm(x)
        x = self.proj(x)  # (B, H'*W', C*p*p)
        x = rearrange(x, 'b (h w) d -> b d h w', h=H, w=W)
        if self.patch_size > 1:
            x = F.pixel_shuffle(x, self.patch_size)
        return x


# ============================================================================
# LiRA Denoising Network
# ============================================================================

class LiRAModel(nn.Module):
    """
    LiRA: Liquid Reasoning Artisan - Main Denoising Network

    Novel architecture combining:
    1. State-space backbone (O(N) complexity)
    2. Latent reasoning loop (adaptive compute)
    3. Hyper-connections (dynamic layer arrangement)
    4. Gated cross-state text fusion (efficient cross-modal)
    5. Mix-FFN (local feature enhancement)

    Designed for mobile deployment:
    - No quadratic attention anywhere
    - All operations are O(N) in sequence length
    - Compact parameter count (<400M for Base)
    - Native 1024px via f32 VAE (32x32 = 1024 tokens)

    Args:
        config_name: one of ``CONFIGS`` ('tiny', 'small', 'base', 'large').
            Any other name means the full configuration must be supplied
            via ``**kwargs``.
        in_channels: latent channels of the VAE.
        d_text: width of the text-encoder features.
        patch_size: latent patch size (see LatentPatchEmbed).
        **kwargs: overrides for individual config entries.

    Raises:
        KeyError: if the resulting configuration lacks a required key.
    """

    # Predefined configurations
    CONFIGS = {
        'tiny': {
            'd_model': 384, 'n_blocks': 12, 'd_state': 8, 'd_reason': 96,
            'max_reason_steps': 4, 'ffn_expand': 2.0, 'cross_every': 4,
            'hc_expansion': 2, 'num_heads': 6,
        },
        'small': {
            'd_model': 512, 'n_blocks': 20, 'd_state': 16, 'd_reason': 128,
            'max_reason_steps': 6, 'ffn_expand': 2.5, 'cross_every': 4,
            'hc_expansion': 2, 'num_heads': 8,
        },
        'base': {
            'd_model': 768, 'n_blocks': 28, 'd_state': 16, 'd_reason': 192,
            'max_reason_steps': 8, 'ffn_expand': 2.5, 'cross_every': 4,
            'hc_expansion': 2, 'num_heads': 12,
        },
        'large': {
            'd_model': 1024, 'n_blocks': 36, 'd_state': 16, 'd_reason': 256,
            'max_reason_steps': 8, 'ffn_expand': 3.0, 'cross_every': 4,
            'hc_expansion': 2, 'num_heads': 16,
        },
    }

    # Keys every configuration (preset or custom) must provide.
    _REQUIRED_CONFIG_KEYS = (
        'd_model', 'n_blocks', 'd_state', 'd_reason', 'max_reason_steps',
        'ffn_expand', 'cross_every', 'hc_expansion', 'num_heads',
    )

    def __init__(
        self,
        config_name: str = 'small',
        in_channels: int = 32,  # DC-AE f32c32 latent channels
        d_text: int = 768,  # Text encoder dimension (CLIP or small LLM)
        patch_size: int = 1,  # Patch size for latent tokens
        **kwargs
    ):
        super().__init__()

        # Preset (if the name is known) overridden by explicit kwargs;
        # an unknown name means kwargs alone must form the whole config.
        config = {**self.CONFIGS.get(config_name, {}), **kwargs}
        missing = [k for k in self._REQUIRED_CONFIG_KEYS if k not in config]
        if missing:
            raise KeyError(
                f"config_name {config_name!r} is not one of "
                f"{sorted(self.CONFIGS)} and kwargs are missing required "
                f"config keys: {missing}"
            )

        self.d_model = config['d_model']
        self.n_blocks = config['n_blocks']
        self.d_state = config['d_state']
        self.d_reason = config['d_reason']
        self.cross_every = config['cross_every']
        self.in_channels = in_channels

        d_cond = self.d_model  # Conditioning dimension

        # ====== Input Processing ======
        self.patch_embed = LatentPatchEmbed(in_channels, self.d_model, patch_size)
        self.unpatch = LatentUnpatch(self.d_model, in_channels, patch_size)

        # ====== Conditioning ======
        self.time_embed = TimestepEmbedding(self.d_model)
        self.text_proj = TextProjection(d_text, self.d_model)

        # Combine timestep + text pooled + reasoning into single conditioning vector
        self.cond_combine = nn.Sequential(
            nn.Linear(self.d_model * 3, self.d_model * 2),
            nn.SiLU(),
            nn.Linear(self.d_model * 2, self.d_model),
        )

        # ====== Latent Reasoning Loop ======
        self.reasoning = LatentReasoningLoop(
            self.d_model, config['d_reason'], config['max_reason_steps']
        )

        # ====== Main Backbone: LiRA Blocks ======
        self.blocks = nn.ModuleList()
        # ModuleDict needs string keys: key str(i) = fusion applied after block i.
        self.cross_fusions = nn.ModuleDict()

        for i in range(self.n_blocks):
            self.blocks.append(LiRABlock(
                d_model=self.d_model,
                d_cond=d_cond,
                d_state=self.d_state,
                ffn_expand=config['ffn_expand'],
                hc_expansion=config['hc_expansion'],
            ))

            # Add cross-modal fusion every K blocks
            if (i + 1) % self.cross_every == 0:
                self.cross_fusions[str(i)] = GatedCrossStateFusion(
                    self.d_model, self.d_model, self.d_state, config['num_heads']
                )

        # ====== Long Skip Connection (from U-ViT / DiM) ======
        # Connect block i with block (n_blocks - 1 - i) via learned projection
        self.n_skip = self.n_blocks // 2
        self.skip_projs = nn.ModuleList([
            nn.Linear(self.d_model * 2, self.d_model)
            for _ in range(self.n_skip)
        ])

        # ====== Output ======
        self.final_norm = nn.LayerNorm(self.d_model)
        self.final_adaln = nn.Sequential(
            nn.SiLU(),
            nn.Linear(d_cond, 2 * self.d_model),
        )

        # Initialize weights
        self._init_weights()

        # Zero-init the output modulation so the head starts as plain
        # LayerNorm (shift = scale = 0). This must run AFTER _init_weights(),
        # which re-initializes every nn.Linear and previously overwrote it.
        nn.init.zeros_(self.final_adaln[1].weight)
        nn.init.zeros_(self.final_adaln[1].bias)

    def _init_weights(self):
        """Careful weight initialization for training stability.

        Every Linear/Conv1d/Conv2d weight gets trunc_normal(std=0.02) and
        every bias is zeroed.

        NOTE(review): ``self.modules()`` also reaches every submodule built
        by core_modules, so any special initialization those modules do in
        their own ``__init__`` is overwritten here. TODO confirm intended.
        """
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv1d)):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(
        self,
        z_t: torch.Tensor,  # (B, C, H, W) noisy latent
        t: torch.Tensor,  # (B,) timestep in [0, 1]
        text_features: torch.Tensor,  # (B, M, D_text) text encoder output
        text_mask: Optional[torch.Tensor] = None,  # (B, M) mask
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Forward pass: predicts velocity v_t = ε - x_0

        Returns:
            v_pred: (B, C, H, W) predicted velocity
            info: dict with reasoning stats (as returned by the reasoning loop)
        """
        # ====== Embed inputs ======
        x, H, W = self.patch_embed(z_t)  # (B, N, D)
        t_emb = self.time_embed(t)  # (B, D)
        text_tokens, text_pooled = self.text_proj(text_features, text_mask)  # (B, M, D), (B, D)

        # ====== Latent Reasoning ======
        reason_cond, reason_info = self.reasoning(x)  # (B, D)

        # ====== Combine conditioning ======
        cond = self.cond_combine(
            torch.cat([t_emb, text_pooled, reason_cond], dim=-1)
        )  # (B, D)

        # ====== Main backbone with long skip connections ======
        skip_features = []
        for i, block in enumerate(self.blocks):
            # Store block *inputs* of the first half for the long skips.
            if i < self.n_skip:
                skip_features.append(x)

            # Apply LiRA block
            x = block(x, cond, H, W)

            # Apply cross-modal fusion
            fusion_key = str(i)
            if fusion_key in self.cross_fusions:
                # NOTE(review): text_mask is only used by text_proj; fusion
                # sees every text token, including any padding. Confirm
                # whether GatedCrossStateFusion should receive the mask.
                x = self.cross_fusions[fusion_key](x, text_tokens)

            # Apply skip connections (second half): block i pairs with the
            # input of block n_blocks-1-i. For odd n_blocks the middle block
            # has no partner and is skipped by the bounds check.
            if i >= self.n_skip:
                skip_idx = self.n_blocks - 1 - i
                if skip_idx < len(skip_features):
                    x = self.skip_projs[skip_idx](
                        torch.cat([x, skip_features[skip_idx]], dim=-1)
                    )

        # ====== Output projection ======
        shift, scale = self.final_adaln(cond).unsqueeze(1).chunk(2, dim=-1)
        x = self.final_norm(x) * (1 + scale) + shift
        v_pred = self.unpatch(x, H, W)  # (B, C, H_orig, W_orig)

        return v_pred, reason_info

    @torch.no_grad()
    def count_parameters(self) -> Dict[str, int]:
        """Count parameters by component (plus an overall 'total')."""

        def numel(*modules: nn.Module) -> int:
            # Sum of parameter element counts over the given modules.
            return sum(p.numel() for m in modules for p in m.parameters())

        return {
            'patch_embed': numel(self.patch_embed),
            'unpatch': numel(self.unpatch),
            'time_embed': numel(self.time_embed),
            'text_proj': numel(self.text_proj),
            'reasoning': numel(self.reasoning),
            'blocks': numel(self.blocks),
            'cross_fusions': numel(self.cross_fusions),
            'skip_projs': numel(self.skip_projs),
            'conditioning': numel(self.cond_combine),
            'output': numel(self.final_norm, self.final_adaln),
            'total': numel(self),
        }


# ============================================================================
# Tiny VAE Decoder for Mobile Deployment
# ============================================================================

class TinyVAEDecoder(nn.Module):
    """
    Ultra-lightweight VAE decoder inspired by SnapGen's tiny decoder.

    Key optimizations:
    1. NO attention layers (saves massive memory)
    2. Depthwise separable convolutions instead of full convolutions
    3. Minimal GroupNorm (only where needed to prevent color shift)
    4. PixelShuffle for upsampling (more efficient than transposed conv)

    For f32 VAE: 32x32 latent → 1024x1024 image (5 upsampling stages)
    For f8 VAE: 128x128 latent → 1024x1024 image (3 upsampling stages)

    Target: ~1.5M parameters, <5MB on disk
    NOTE(review): with the defaults (in_channels=32, spatial_compression=32,
    base_channels=64) this builds roughly 0.24M parameters; reaching the
    stated target would need a larger base_channels.

    Args:
        in_channels: latent channels.
        out_channels: image channels (3 for RGB).
        spatial_compression: VAE downsampling factor; must be a power of two.
        base_channels: channel width of the first stage.

    Raises:
        ValueError: if spatial_compression is not a power of two.
    """

    def __init__(
        self,
        in_channels: int = 32,
        out_channels: int = 3,
        spatial_compression: int = 32,  # 32 for f32, 8 for f8
        base_channels: int = 64,
    ):
        super().__init__()

        num_upsample = int(math.log2(spatial_compression))  # 5 for f32, 3 for f8
        # Each stage upsamples by exactly 2x, so only powers of two are
        # reachable; anything else would silently give the wrong output size.
        if 2 ** num_upsample != spatial_compression:
            raise ValueError(
                f"spatial_compression must be a power of two, "
                f"got {spatial_compression}"
            )

        layers = []

        # Initial projection
        layers.append(nn.Conv2d(in_channels, base_channels, 3, padding=1))
        layers.append(nn.SiLU())

        # Upsampling stages - track channels carefully
        current_ch = base_channels
        for i in range(num_upsample):
            # Halve channels each stage (higher resolution), floor at 16
            target_ch = max(base_channels // (2 ** i), 16)

            # Depthwise separable residual block
            layers.append(SepConvBlock(current_ch, target_ch))
            current_ch = target_ch

            # PixelShuffle upsample (2x): needs ch*4 input, outputs ch
            layers.append(nn.Conv2d(current_ch, current_ch * 4, 3, padding=1))
            layers.append(nn.PixelShuffle(2))  # ch*4 → ch, spatial 2x
            layers.append(nn.SiLU())
            # After PixelShuffle, channels stay at current_ch

        # Final output
        layers.append(nn.Conv2d(current_ch, out_channels, 3, padding=1))
        layers.append(nn.Tanh())  # Output in [-1, 1]

        self.decoder = nn.Sequential(*layers)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        z: (B, C_lat, H_lat, W_lat) latent
        Returns: (B, 3, H_img, W_img) decoded image
        """
        return self.decoder(z)


class SepConvBlock(nn.Module):
    """
    Depthwise separable convolution block with a residual path:
    out = SiLU(GroupNorm(pw(dw(x)))) + skip(x), where skip is a 1x1 conv
    when the channel count changes and identity otherwise.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.dwconv = nn.Conv2d(in_ch, in_ch, 3, padding=1, groups=in_ch)
        self.pwconv = nn.Conv2d(in_ch, out_ch, 1)
        # GroupNorm requires num_groups to divide out_ch. Use the largest
        # divisor <= 8; this equals min(8, out_ch) whenever that already
        # divides out_ch, and avoids a crash for widths like 20 or 40.
        num_groups = next(
            g for g in range(min(8, out_ch), 0, -1) if out_ch % g == 0
        )
        self.norm = nn.GroupNorm(num_groups, out_ch)
        self.act = nn.SiLU()
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x):
        residual = self.skip(x)
        x = self.dwconv(x)
        x = self.pwconv(x)
        x = self.norm(x)
        x = self.act(x)
        return x + residual


# ============================================================================
# Complete LiRA Pipeline
# ============================================================================

class LiRAPipeline(nn.Module):
    """
    Complete LiRA pipeline combining:
    1. Pretrained VAE encoder (frozen) - for encoding images to latent space
    2. LiRA denoising network - the novel architecture
    3. Tiny VAE decoder - for mobile deployment

    During training:
        image → VAE_encoder → z_0 → add_noise(z_0, t) → z_t → LiRA → v_pred

    During inference:
        noise → iterative_denoise(LiRA) → z_0 → TinyVAEDecoder → image

    NOTE(review): this module only owns the denoiser and the tiny decoder;
    the VAE encoder and the sampling loop are not part of it.
    """

    def __init__(
        self,
        config_name: str = 'small',
        latent_channels: int = 32,
        spatial_compression: int = 32,
        d_text: int = 768,
        patch_size: int = 1,
    ):
        super().__init__()
        self.spatial_compression = spatial_compression
        self.latent_channels = latent_channels

        # Denoising network
        self.denoiser = LiRAModel(
            config_name=config_name,
            in_channels=latent_channels,
            d_text=d_text,
            patch_size=patch_size,
        )

        # Tiny decoder for mobile inference
        self.tiny_decoder = TinyVAEDecoder(
            in_channels=latent_channels,
            spatial_compression=spatial_compression,
        )

    def forward(self, *args, **kwargs):
        """Delegate directly to the denoiser (see LiRAModel.forward)."""
        return self.denoiser(*args, **kwargs)

    def count_parameters(self):
        """Denoiser breakdown plus 'tiny_decoder' and 'total_with_decoder'."""
        counts = self.denoiser.count_parameters()
        counts['tiny_decoder'] = sum(p.numel() for p in self.tiny_decoder.parameters())
        counts['total_with_decoder'] = counts['total'] + counts['tiny_decoder']
        return counts


# ============================================================================
# Helper: Estimate memory usage
# ============================================================================

def estimate_memory_mb(model: nn.Module, batch_size: int = 1, img_size: int = 1024,
                       spatial_compression: int = 32, latent_channels: int = 32,
                       dtype_bytes: int = 2) -> Dict[str, float]:
    """Estimate inference memory usage in MB.

    Args:
        model: module whose parameters are counted (buffers are not).
        batch_size: number of latents processed at once.
        img_size: square image side length in pixels.
        spatial_compression: VAE downsampling factor.
        latent_channels: latent channel count.
        dtype_bytes: bytes per element (2 for fp16/bf16).

    Returns:
        dict with 'params_mb', 'latent_mb', 'activation_mb' and
        'total_inference_mb'.

    NOTE(review): activations are approximated as 3x the latent tensor, while
    the denoiser's hidden states are (B, N, d_model), which is much wider
    than the latent; treat 'activation_mb' as a lower bound.
    """
    # Model parameters
    param_bytes = sum(p.numel() * dtype_bytes for p in model.parameters())
    param_mb = param_bytes / (1024 ** 2)

    # Latent size
    lat_h = img_size // spatial_compression
    lat_w = img_size // spatial_compression
    latent_bytes = batch_size * latent_channels * lat_h * lat_w * dtype_bytes

    # Intermediate activations (rough estimate: 3x latent)
    activation_bytes = latent_bytes * 3

    total_mb = param_mb + (latent_bytes + activation_bytes) / (1024 ** 2)

    return {
        'params_mb': param_mb,
        'latent_mb': latent_bytes / (1024 ** 2),
        'activation_mb': activation_bytes / (1024 ** 2),
        'total_inference_mb': total_mb,
    }