| """ |
| LiRA Model: Full Architecture |
| |
| Architecture Overview (Denoising Network): |
| ========================================== |
| |
| Input: z_t (noisy latent, B x C x H x W) + t (timestep) + text_features |
| | |
| v |
| βββββββββββββββββββββββββββ |
| β Patch Embedding β Conv2d(C_lat, D, 1x1) - patchify |
| β + Freq Decomposition β Optional: Haar wavelet split |
| ββββββββββββββ¬βββββββββββββ |
| β |
| v |
| βββββββββββββββββββββββββββ |
| β Latent Reasoning Loop β 2-8 adaptive steps (learned) |
| β (generates reasoning β β produces reasoning conditioning |
| β conditioning vector) β Only ~128 dims, very cheap |
| ββββββββββββββ¬βββββββββββββ |
| β reasoning_cond + timestep_embed + text_pooled |
| β β combined conditioning vector |
| v |
| βββββββββββββββββββββββββββ |
| β N x LiRA Blocks β Each block: |
| β (with HyperConnections)β 1. AdaLN conditioning |
| β β 2. Bidirectional SSM (4-dir scan) |
| β Every K blocks: β 3. Mix-FFN (DWConv + GLU) |
| β β GatedCrossStateFusionβ 4. Hyper-connection routing |
| ββββββββββββββ¬βββββββββββββ |
| β |
| v |
| βββββββββββββββββββββββββββ |
| β Final Norm + Proj β LayerNorm β Linear(D, C_lat) |
| β β velocity prediction β Predicts v = Ξ΅ - x_0 |
| βββββββββββββββββββββββββββ |
| |
| Model Sizes: |
| - LiRA-Tiny: D=384, N=12, ~50M params (for testing) |
| - LiRA-Small: D=512, N=20, ~120M params (mobile-optimized) |
| - LiRA-Base: D=768, N=28, ~300M params (quality-optimized) |
| - LiRA-Large: D=1024, N=36, ~600M params (maximum quality) |
| """ |
|
|
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Dict, Tuple
from einops import rearrange

from .core_modules import (
    LiRABlock,
    GatedCrossStateFusion,
    LatentReasoningLoop,
    TimestepEmbedding,
    TextProjection,
    HyperConnection,
)
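

# Illustrative only: the sequence lengths implied by the module docstring for a
# 1024px image. `_example_token_counts` is a hypothetical helper, not part of the
# model; the numbers simply restate the f32/f8 cases described in
# LatentPatchEmbed's docstring below.
def _example_token_counts(img_size: int = 1024) -> Dict[str, int]:
    f32_side = img_size // 32      # DC-AE f32: 1024 / 32 = 32  -> 32x32 = 1024 tokens
    f8_side = img_size // 8        # SD3/FLUX f8: 1024 / 8 = 128 -> 128x128 = 16384 tokens
    return {
        'f32_tokens': f32_side * f32_side,               # 1024
        'f8_tokens': f8_side * f8_side,                  # 16384
        'f8_tokens_2x2_patched': (f8_side // 2) ** 2,    # 4096 with the optional 2x2 patching
    }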
class LatentPatchEmbed(nn.Module):
    """
    Embeds latent-space patches into the model dimension.

    For DC-AE f32:   latent is 32x32 for a 1024px image, with 32 channels.
    For SD3/FLUX f8: latent is 128x128 for a 1024px image, with 16 channels.

    We use a simple 1x1 conv (no spatial patchify) since the VAE already
    provides heavy spatial compression; additional patching would lose
    spatial resolution in the latent space.

    However, for f8 VAEs (128x128 = 16384 tokens), we optionally use
    2x2 patches to reduce the sequence to 64x64 = 4096 tokens.
    """

    def __init__(self, in_channels: int, d_model: int, patch_size: int = 1):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(
            in_channels, d_model,
            kernel_size=patch_size, stride=patch_size
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
        """
        x: (B, C, H, W) latent features
        Returns: (B, H'*W', D), H', W'
        """
        x = self.proj(x)
        B, D, H, W = x.shape
        x = rearrange(x, 'b d h w -> b (h w) d')
        x = self.norm(x)
        return x, H, W

class LatentUnpatch(nn.Module):
    """Reverse of LatentPatchEmbed: project back to latent channels and reshape."""

    def __init__(self, d_model: int, out_channels: int, patch_size: int = 1):
        super().__init__()
        self.patch_size = patch_size
        self.out_channels = out_channels
        self.norm = nn.LayerNorm(d_model)

        if patch_size > 1:
            # Predict patch_size^2 sub-pixels per token, then pixel-shuffle them back
            self.proj = nn.Linear(d_model, out_channels * patch_size * patch_size)
        else:
            self.proj = nn.Linear(d_model, out_channels)

    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        """
        x: (B, H'*W', D)
        Returns: (B, C, H_orig, W_orig)
        """
        x = self.norm(x)
        x = self.proj(x)

        x = rearrange(x, 'b (h w) d -> b d h w', h=H, w=W)

        if self.patch_size > 1:
            x = F.pixel_shuffle(x, self.patch_size)

        return x
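

# Illustrative shape check, not used by the model: a sketch of how the two
# projections above are expected to round-trip, including the optional 2x2-patch
# case for an f8-style latent. Tensor sizes are example values only.
def _example_patch_roundtrip() -> None:
    # DC-AE f32 style: 32-channel 32x32 latent, no extra patching
    embed = LatentPatchEmbed(in_channels=32, d_model=384, patch_size=1)
    unpatch = LatentUnpatch(d_model=384, out_channels=32, patch_size=1)
    z = torch.randn(2, 32, 32, 32)
    tokens, H, W = embed(z)              # -> (2, 1024, 384), H = W = 32
    assert unpatch(tokens, H, W).shape == z.shape

    # f8 style: 16-channel 128x128 latent, 2x2 patching halves each spatial side
    embed8 = LatentPatchEmbed(in_channels=16, d_model=384, patch_size=2)
    unpatch8 = LatentUnpatch(d_model=384, out_channels=16, patch_size=2)
    z8 = torch.randn(2, 16, 128, 128)
    tokens8, H8, W8 = embed8(z8)         # -> (2, 4096, 384), H8 = W8 = 64
    assert unpatch8(tokens8, H8, W8).shape == z8.shape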
class LiRAModel(nn.Module):
    """
    LiRA: Liquid Reasoning Artisan - Main Denoising Network

    Novel architecture combining:
    1. State-space backbone (O(N) complexity)
    2. Latent reasoning loop (adaptive compute)
    3. Hyper-connections (dynamic layer arrangement)
    4. Gated cross-state text fusion (efficient cross-modal conditioning)
    5. Mix-FFN (local feature enhancement)

    Designed for mobile deployment:
    - No quadratic attention anywhere
    - All operations are O(N) in sequence length
    - Compact parameter count (<400M for Base)
    - Native 1024px via f32 VAE (32x32 = 1024 tokens)
    """

    CONFIGS = {
        'tiny': {
            'd_model': 384, 'n_blocks': 12, 'd_state': 8,
            'd_reason': 96, 'max_reason_steps': 4,
            'ffn_expand': 2.0, 'cross_every': 4,
            'hc_expansion': 2, 'num_heads': 6,
        },
        'small': {
            'd_model': 512, 'n_blocks': 20, 'd_state': 16,
            'd_reason': 128, 'max_reason_steps': 6,
            'ffn_expand': 2.5, 'cross_every': 4,
            'hc_expansion': 2, 'num_heads': 8,
        },
        'base': {
            'd_model': 768, 'n_blocks': 28, 'd_state': 16,
            'd_reason': 192, 'max_reason_steps': 8,
            'ffn_expand': 2.5, 'cross_every': 4,
            'hc_expansion': 2, 'num_heads': 12,
        },
        'large': {
            'd_model': 1024, 'n_blocks': 36, 'd_state': 16,
            'd_reason': 256, 'max_reason_steps': 8,
            'ffn_expand': 3.0, 'cross_every': 4,
            'hc_expansion': 2, 'num_heads': 16,
        },
    }

    def __init__(
        self,
        config_name: str = 'small',
        in_channels: int = 32,
        d_text: int = 768,
        patch_size: int = 1,
        **kwargs
    ):
        super().__init__()

        # Named config, with optional per-key overrides via kwargs
        if config_name in self.CONFIGS:
            config = {**self.CONFIGS[config_name], **kwargs}
        else:
            config = kwargs

        self.d_model = config['d_model']
        self.n_blocks = config['n_blocks']
        self.d_state = config['d_state']
        self.d_reason = config['d_reason']
        self.cross_every = config['cross_every']
        self.in_channels = in_channels

        d_cond = self.d_model

        # Latent <-> token projections
        self.patch_embed = LatentPatchEmbed(in_channels, self.d_model, patch_size)
        self.unpatch = LatentUnpatch(self.d_model, in_channels, patch_size)

        # Conditioning inputs
        self.time_embed = TimestepEmbedding(self.d_model)
        self.text_proj = TextProjection(d_text, self.d_model)

        # Combine timestep, pooled text, and reasoning conditioning into one vector
        self.cond_combine = nn.Sequential(
            nn.Linear(self.d_model * 3, self.d_model * 2),
            nn.SiLU(),
            nn.Linear(self.d_model * 2, self.d_model)
        )

        # Adaptive-compute latent reasoning loop
        self.reasoning = LatentReasoningLoop(
            self.d_model, config['d_reason'], config['max_reason_steps']
        )

        # Backbone: N LiRA blocks, with gated cross-state text fusion every K blocks
        self.blocks = nn.ModuleList()
        self.cross_fusions = nn.ModuleDict()

        for i in range(self.n_blocks):
            self.blocks.append(LiRABlock(
                d_model=self.d_model,
                d_cond=d_cond,
                d_state=self.d_state,
                ffn_expand=config['ffn_expand'],
                hc_expansion=config['hc_expansion'],
            ))

            if (i + 1) % self.cross_every == 0:
                self.cross_fusions[str(i)] = GatedCrossStateFusion(
                    self.d_model, self.d_model, self.d_state, config['num_heads']
                )

        # U-Net-style long skip connections: features from the first half of the
        # blocks are saved and fused back into the second half
        self.n_skip = self.n_blocks // 2
        self.skip_projs = nn.ModuleList([
            nn.Linear(self.d_model * 2, self.d_model)
            for _ in range(self.n_skip)
        ])

        # Output head with AdaLN-zero style modulation
        self.final_norm = nn.LayerNorm(self.d_model)
        self.final_adaln = nn.Sequential(
            nn.SiLU(),
            nn.Linear(d_cond, 2 * self.d_model)
        )

        # General init first, then zero-init the final modulation so the zeros are
        # not overwritten (AdaLN-zero: conditioning starts with no effect on the head)
        self._init_weights()
        nn.init.zeros_(self.final_adaln[1].weight)
        nn.init.zeros_(self.final_adaln[1].bias)

    def _init_weights(self):
        """Careful weight initialization for training stability"""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.Conv2d, nn.Conv1d)):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(
        self,
        z_t: torch.Tensor,
        t: torch.Tensor,
        text_features: torch.Tensor,
        text_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Forward pass: predicts the velocity v_t = ε - x_0

        Returns:
            v_pred: (B, C, H, W) predicted velocity
            info: dict with reasoning stats
        """
        B = z_t.shape[0]

        # Embed latent, timestep, and text
        x, H, W = self.patch_embed(z_t)
        t_emb = self.time_embed(t)
        text_tokens, text_pooled = self.text_proj(text_features, text_mask)

        # Latent reasoning loop produces a compact conditioning vector
        reason_cond, reason_info = self.reasoning(x)

        # Combined conditioning: timestep + pooled text + reasoning
        cond = self.cond_combine(torch.cat([t_emb, text_pooled, reason_cond], dim=-1))

        # Backbone with U-Net-style long skips
        skip_features = []

        for i, block in enumerate(self.blocks):
            # First half: stash features for the long skips
            if i < self.n_skip:
                skip_features.append(x)

            x = block(x, cond, H, W)

            # Gated cross-state text fusion every `cross_every` blocks
            if str(i) in self.cross_fusions:
                x = self.cross_fusions[str(i)](x, text_tokens)

            # Second half: fuse the mirrored skip feature (block i pairs with n_blocks-1-i)
            if i >= self.n_skip:
                skip_idx = self.n_blocks - 1 - i
                if skip_idx < len(skip_features):
                    x = self.skip_projs[skip_idx](
                        torch.cat([x, skip_features[skip_idx]], dim=-1)
                    )

        # Final AdaLN modulation + projection back to latent space
        shift, scale = self.final_adaln(cond).unsqueeze(1).chunk(2, dim=-1)
        x = self.final_norm(x) * (1 + scale) + shift

        v_pred = self.unpatch(x, H, W)

        return v_pred, reason_info

    @torch.no_grad()
    def count_parameters(self) -> Dict[str, int]:
        """Count parameters by component"""
        counts = {}
        counts['patch_embed'] = sum(p.numel() for p in self.patch_embed.parameters())
        counts['unpatch'] = sum(p.numel() for p in self.unpatch.parameters())
        counts['time_embed'] = sum(p.numel() for p in self.time_embed.parameters())
        counts['text_proj'] = sum(p.numel() for p in self.text_proj.parameters())
        counts['reasoning'] = sum(p.numel() for p in self.reasoning.parameters())
        counts['blocks'] = sum(p.numel() for p in self.blocks.parameters())
        counts['cross_fusions'] = sum(p.numel() for p in self.cross_fusions.parameters())
        counts['skip_projs'] = sum(p.numel() for p in self.skip_projs.parameters())
        counts['conditioning'] = sum(p.numel() for p in self.cond_combine.parameters())
        counts['output'] = (
            sum(p.numel() for p in self.final_norm.parameters()) +
            sum(p.numel() for p in self.final_adaln.parameters())
        )
        counts['total'] = sum(p.numel() for p in self.parameters())
        return counts
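

# Illustrative usage only (nothing here runs on import): build the smallest config
# and run one denoising step on dummy tensors. The latent shape follows the DC-AE
# f32 assumption above (32 channels, 32x32 for a 1024px image); the text-feature
# shape (77 tokens x 768 dims) is an arbitrary example, and the exact timestep
# format is whatever TimestepEmbedding in core_modules expects.
def _example_denoiser_forward() -> None:
    model = LiRAModel(config_name='tiny', in_channels=32, d_text=768)
    z_t = torch.randn(1, 32, 32, 32)        # noisy latent
    t = torch.rand(1)                       # timestep (assumed continuous in [0, 1])
    text = torch.randn(1, 77, 768)          # frozen text-encoder features
    v_pred, info = model(z_t, t, text)
    assert v_pred.shape == z_t.shape        # velocity prediction matches the latent shape
    print(model.count_parameters()['total'], info)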
class TinyVAEDecoder(nn.Module):
    """
    Ultra-lightweight VAE decoder inspired by SnapGen's tiny decoder.

    Key optimizations:
    1. NO attention layers (saves massive memory)
    2. Depthwise separable convolutions instead of full convolutions
    3. Minimal GroupNorm (only where needed to prevent color shift)
    4. PixelShuffle for upsampling (more efficient than transposed conv)

    For f32 VAE: 32x32 latent → 1024x1024 image (5 upsampling stages)
    For f8 VAE:  128x128 latent → 1024x1024 image (3 upsampling stages)

    Target: ~1.5M parameters, <5MB on disk
    """

    def __init__(
        self,
        in_channels: int = 32,
        out_channels: int = 3,
        spatial_compression: int = 32,
        base_channels: int = 64,
    ):
        super().__init__()

        num_upsample = int(math.log2(spatial_compression))

        layers = []

        # Stem: project latent channels to the base width
        layers.append(nn.Conv2d(in_channels, base_channels, 3, padding=1))
        layers.append(nn.SiLU())

        # Upsampling stages: halve the channel width each stage (floor of 16),
        # refine with a separable conv block, then upsample 2x via PixelShuffle
        current_ch = base_channels
        for i in range(num_upsample):
            target_ch = max(base_channels // (2 ** i), 16)

            layers.append(SepConvBlock(current_ch, target_ch))
            current_ch = target_ch

            layers.append(nn.Conv2d(current_ch, current_ch * 4, 3, padding=1))
            layers.append(nn.PixelShuffle(2))
            layers.append(nn.SiLU())

        # Output head: project to RGB in [-1, 1]
        layers.append(nn.Conv2d(current_ch, out_channels, 3, padding=1))
        layers.append(nn.Tanh())

        self.decoder = nn.Sequential(*layers)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        z: (B, C_lat, H_lat, W_lat) latent
        Returns: (B, 3, H_img, W_img) decoded image
        """
        return self.decoder(z)

class SepConvBlock(nn.Module):
    """Depthwise separable convolution block"""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.dwconv = nn.Conv2d(in_ch, in_ch, 3, padding=1, groups=in_ch)
        self.pwconv = nn.Conv2d(in_ch, out_ch, 1)
        self.norm = nn.GroupNorm(min(8, out_ch), out_ch)
        self.act = nn.SiLU()
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x):
        residual = self.skip(x)
        x = self.dwconv(x)
        x = self.pwconv(x)
        x = self.norm(x)
        x = self.act(x)
        return x + residual
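

# Illustrative only: expected decoder I/O shapes for the f32 setting described in
# TinyVAEDecoder's docstring (32-channel 32x32 latent -> 1024x1024 RGB image).
def _example_tiny_decoder() -> None:
    decoder = TinyVAEDecoder(in_channels=32, spatial_compression=32)
    z = torch.randn(1, 32, 32, 32)
    img = decoder(z)
    assert img.shape == (1, 3, 1024, 1024)   # 5 PixelShuffle(2) stages: 32 * 2**5 = 1024
    n_params = sum(p.numel() for p in decoder.parameters())
    print(f"TinyVAEDecoder params: {n_params / 1e6:.2f}M")   # target is ~1.5M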
class LiRAPipeline(nn.Module):
    """
    Complete LiRA pipeline combining:
    1. Pretrained VAE encoder (frozen) - for encoding images to latent space
    2. LiRA denoising network - the novel architecture
    3. Tiny VAE decoder - for mobile deployment

    During training:
        image → VAE_encoder → z_0 → add_noise(z_0, t) → z_t → LiRA → v_pred

    During inference:
        noise → iterative_denoise(LiRA) → z_0 → TinyVAEDecoder → image
    """

    def __init__(
        self,
        config_name: str = 'small',
        latent_channels: int = 32,
        spatial_compression: int = 32,
        d_text: int = 768,
        patch_size: int = 1,
    ):
        super().__init__()

        self.spatial_compression = spatial_compression
        self.latent_channels = latent_channels

        # The denoising network
        self.denoiser = LiRAModel(
            config_name=config_name,
            in_channels=latent_channels,
            d_text=d_text,
            patch_size=patch_size,
        )

        # Lightweight decoder for on-device decoding
        self.tiny_decoder = TinyVAEDecoder(
            in_channels=latent_channels,
            spatial_compression=spatial_compression,
        )

    def forward(self, *args, **kwargs):
        return self.denoiser(*args, **kwargs)

    def count_parameters(self):
        counts = self.denoiser.count_parameters()
        counts['tiny_decoder'] = sum(p.numel() for p in self.tiny_decoder.parameters())
        counts['total_with_decoder'] = counts['total'] + counts['tiny_decoder']
        return counts
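

# Illustrative sketches only - neither function is the project's training or
# sampling code. They just spell out the two paths in LiRAPipeline's docstring,
# assuming a rectified-flow interpolation z_t = (1 - t) * z_0 + t * eps with
# velocity target eps - z_0 (the "v = ε - x_0" convention from the module
# docstring) and plain Euler integration at inference. The frozen VAE encoder is
# external, so z_0 is taken as given here.
def _example_training_loss(pipeline: LiRAPipeline, z_0: torch.Tensor,
                           text_features: torch.Tensor) -> torch.Tensor:
    B = z_0.shape[0]
    t = torch.rand(B, device=z_0.device)            # t ~ U(0, 1)
    eps = torch.randn_like(z_0)                     # Gaussian noise
    t_ = t.view(B, 1, 1, 1)
    z_t = (1 - t_) * z_0 + t_ * eps                 # noisy latent
    v_target = eps - z_0                            # velocity target
    v_pred, _ = pipeline(z_t, t, text_features)
    return F.mse_loss(v_pred, v_target)


@torch.no_grad()
def _example_euler_sampling(pipeline: LiRAPipeline, text_features: torch.Tensor,
                            latent_size: int = 32, num_steps: int = 20) -> torch.Tensor:
    B = text_features.shape[0]
    z = torch.randn(B, pipeline.latent_channels, latent_size, latent_size,
                    device=text_features.device)    # start from pure noise (t = 1)
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = torch.full((B,), 1.0 - i * dt, device=z.device)
        v, _ = pipeline(z, t, text_features)        # v ~ eps - x_0, i.e. dz/dt
        z = z - dt * v                              # Euler step toward t = 0
    return pipeline.tiny_decoder(z)                 # decode the clean latent to an image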
def estimate_memory_mb(model: nn.Module, batch_size: int = 1,
                       img_size: int = 1024, spatial_compression: int = 32,
                       latent_channels: int = 32, dtype_bytes: int = 2):
    """Estimate inference memory usage in MB (weights plus latent-sized buffers)"""
    # Weights
    param_bytes = sum(p.numel() * dtype_bytes for p in model.parameters())
    param_mb = param_bytes / (1024 ** 2)

    # One latent-sized tensor
    lat_h = img_size // spatial_compression
    lat_w = img_size // spatial_compression
    latent_bytes = batch_size * latent_channels * lat_h * lat_w * dtype_bytes

    # Very rough activation estimate: a few latent-sized working buffers
    activation_bytes = latent_bytes * 3

    total_mb = param_mb + (latent_bytes + activation_bytes) / (1024 ** 2)

    return {
        'params_mb': param_mb,
        'latent_mb': latent_bytes / (1024 ** 2),
        'activation_mb': activation_bytes / (1024 ** 2),
        'total_inference_mb': total_mb,
    }
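

# Illustrative only: memory estimate for the tiny decoder by itself in fp16
# (dtype_bytes=2). The numbers are rough by construction; see the heuristic above.
def _example_memory_estimate() -> None:
    decoder = TinyVAEDecoder(in_channels=32, spatial_compression=32)
    stats = estimate_memory_mb(decoder, batch_size=1, img_size=1024,
                               spatial_compression=32, latent_channels=32,
                               dtype_bytes=2)
    print({k: round(v, 2) for k, v in stats.items()})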