# LiRA - lira/model.py
# Added by asdf98 in verified commit 18ce5a6 ("Add lira/model.py").
"""
LiRA Model: Full Architecture
Architecture Overview (Denoising Network):
==========================================
Input: z_t (noisy latent, B x C x H x W) + t (timestep) + text_features
             |
             v
┌─────────────────────────┐
│ Patch Embedding         │  Conv2d(C_lat, D, p x p), p = patch_size (default 1)
│ + Freq Decomposition    │  Optional Haar wavelet split (not implemented in LatentPatchEmbed)
└────────────┬────────────┘
             │
             v
┌─────────────────────────┐
│ Latent Reasoning Loop   │  Up to max_reason_steps (4-8) adaptive steps
│ (generates reasoning    │  → produces the reasoning conditioning
│  conditioning vector)   │  d_reason = 96-256 dims, very cheap
└────────────┬────────────┘
             │  reasoning_cond + timestep_embed + text_pooled
             │  → combined conditioning vector (cond_combine MLP)
             v
┌─────────────────────────┐
│ N x LiRA Blocks         │  Each block:
│ (with HyperConnections) │    1. AdaLN conditioning
│                         │    2. Bidirectional SSM (4-dir scan)
│ Every K blocks:         │    3. Mix-FFN (DWConv + GLU)
│ → GatedCrossStateFusion │    4. Hyper-connection routing
│                         │  Plus U-ViT long skips between mirrored blocks
└────────────┬────────────┘
             │
             v
┌─────────────────────────┐
│ Final Norm + Proj       │  AdaLN-modulated LayerNorm → LatentUnpatch (Linear(D, C_lat))
│ → velocity prediction   │  Predicts v = ε - x_0
└─────────────────────────┘
Model Sizes:
- LiRA-Tiny: D=384, N=12, ~50M params (for testing)
- LiRA-Small: D=512, N=20, ~120M params (mobile-optimized)
- LiRA-Base: D=768, N=28, ~300M params (quality-optimized)
- LiRA-Large: D=1024, N=36, ~600M params (maximum quality)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Dict, Tuple
from einops import rearrange
from .core_modules import (
LiRABlock,
GatedCrossStateFusion,
LatentReasoningLoop,
TimestepEmbedding,
TextProjection,
HyperConnection,
)
# ============================================================================
# Patch Embedding for Latent Space
# ============================================================================
class LatentPatchEmbed(nn.Module):
    """
    Embed a latent feature map as a sequence of model-dimension tokens.

    For DC-AE f32: latent is 32x32 for a 1024px image, with 32 channels.
    For SD3/FLUX f8: latent is 128x128 for 1024px, with 16 channels.

    By default a 1x1 conv is used (no spatial patchify): the VAE already
    provides heavy spatial compression, and additional patching would lose
    spatial resolution in the latent space. For f8 VAEs (128x128 = 16384
    tokens), ``patch_size=2`` reduces the grid to 64x64 = 4096 tokens.

    The projection is a Conv2d with kernel = stride = patch_size, followed by
    a LayerNorm over the model dimension of every token.
    """

    def __init__(self, in_channels: int, d_model: int, patch_size: int = 1):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(
            in_channels, d_model,
            kernel_size=patch_size, stride=patch_size
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
        """
        Args:
            x: (B, C, H, W) latent features. H and W must be multiples of
                ``patch_size``.

        Returns:
            tokens: (B, H'*W', D), ordered row-major over the token grid.
            H', W': token-grid size (H // patch_size, W // patch_size).

        Raises:
            ValueError: if ``x`` is not 4-D or H/W is not divisible by
                ``patch_size``.
        """
        if x.dim() != 4:
            raise ValueError(
                f"LatentPatchEmbed expects a (B, C, H, W) tensor, "
                f"got shape {tuple(x.shape)}"
            )
        height, width = x.shape[-2:]
        if height % self.patch_size or width % self.patch_size:
            # A strided conv would silently drop the trailing rows/columns,
            # and LatentUnpatch would then return a smaller map than the input.
            raise ValueError(
                f"latent size {height}x{width} is not divisible by "
                f"patch_size={self.patch_size}"
            )
        x = self.proj(x)                      # (B, D, H', W')
        H, W = x.shape[-2:]
        x = x.flatten(2).transpose(1, 2)      # (B, H'*W', D), row-major grid
        x = self.norm(x)
        return x, H, W
class LatentUnpatch(nn.Module):
    """
    Inverse of ``LatentPatchEmbed``: turn tokens back into a latent map.

    Every token is layer-normed and linearly projected to
    ``out_channels * patch_size**2`` values (just ``out_channels`` when
    ``patch_size <= 1``). The sequence is then folded back into its (H, W)
    token grid. For ``patch_size > 1``, PixelShuffle moves the extra channels
    into a ``patch_size``-times larger spatial grid.
    """

    def __init__(self, d_model: int, out_channels: int, patch_size: int = 1):
        super().__init__()
        self.patch_size = patch_size
        self.out_channels = out_channels
        self.norm = nn.LayerNorm(d_model)
        # With patching, each token carries a p x p block of pixels that
        # pixel_shuffle unpacks; without it, one token == one latent pixel.
        out_features = (
            out_channels * patch_size * patch_size if patch_size > 1 else out_channels
        )
        self.proj = nn.Linear(d_model, out_features)

    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        """
        Args:
            x: (B, H*W, D) tokens, row-major over the token grid.
            H, W: token-grid height and width.

        Returns:
            (B, out_channels, H * patch_size, W * patch_size) latent map.
        """
        tokens = self.proj(self.norm(x))                 # (B, H*W, C*p*p)
        grid = tokens.transpose(1, 2).unflatten(2, (H, W))  # (B, C*p*p, H, W)
        if self.patch_size > 1:
            grid = F.pixel_shuffle(grid, self.patch_size)
        return grid
# ============================================================================
# LiRA Denoising Network
# ============================================================================
class LiRAModel(nn.Module):
    """
    LiRA: Liquid Reasoning Artisan - main denoising network.

    Predicts the velocity v = ε - x_0 for a noisy latent z_t, conditioned on
    a timestep and on text-encoder features.

    Building blocks (implemented in ``core_modules``, not in this file):
      1. ``LiRABlock`` backbone (state-space based, per the module docstring)
      2. ``LatentReasoningLoop`` producing an extra conditioning vector
      3. Hyper-connections (``hc_expansion`` is forwarded to every block)
      4. ``GatedCrossStateFusion`` text fusion after every ``cross_every``-th block
    This class also adds U-ViT style long skips between mirrored blocks and
    an adaLN-modulated output head.

    The authors' design goals for mobile deployment depend on
    ``core_modules`` internals and cannot be checked from this file:
      - no quadratic attention; all operations O(N) in sequence length
      - compact parameter count (<400M for Base)
      - native 1024px via an f32 VAE (32x32 = 1024 tokens)

    Args:
        config_name: key into ``CONFIGS``. If it is not a preset, every key in
            ``_CONFIG_KEYS`` must be supplied through ``**kwargs``.
        in_channels: latent channels (32 for DC-AE f32c32).
        d_text: width of the text-encoder features.
        patch_size: spatial patch size applied to the latent grid.
        **kwargs: overrides for any preset hyper-parameter.

    Raises:
        KeyError: if required hyper-parameters are missing.
    """

    # Predefined configurations
    CONFIGS = {
        'tiny': {
            'd_model': 384, 'n_blocks': 12, 'd_state': 8,
            'd_reason': 96, 'max_reason_steps': 4,
            'ffn_expand': 2.0, 'cross_every': 4,
            'hc_expansion': 2, 'num_heads': 6,
        },
        'small': {
            'd_model': 512, 'n_blocks': 20, 'd_state': 16,
            'd_reason': 128, 'max_reason_steps': 6,
            'ffn_expand': 2.5, 'cross_every': 4,
            'hc_expansion': 2, 'num_heads': 8,
        },
        'base': {
            'd_model': 768, 'n_blocks': 28, 'd_state': 16,
            'd_reason': 192, 'max_reason_steps': 8,
            'ffn_expand': 2.5, 'cross_every': 4,
            'hc_expansion': 2, 'num_heads': 12,
        },
        'large': {
            'd_model': 1024, 'n_blocks': 36, 'd_state': 16,
            'd_reason': 256, 'max_reason_steps': 8,
            'ffn_expand': 3.0, 'cross_every': 4,
            'hc_expansion': 2, 'num_heads': 16,
        },
    }

    # Hyper-parameters every configuration must define.
    _CONFIG_KEYS = (
        'd_model', 'n_blocks', 'd_state', 'd_reason', 'max_reason_steps',
        'ffn_expand', 'cross_every', 'hc_expansion', 'num_heads',
    )

    def __init__(
        self,
        config_name: str = 'small',
        in_channels: int = 32,  # DC-AE f32c32 latent channels
        d_text: int = 768,      # Text encoder dimension (CLIP or small LLM)
        patch_size: int = 1,    # Patch size for latent tokens
        **kwargs
    ):
        super().__init__()
        # Resolve config: preset values, overridden by any explicit kwargs.
        if config_name in self.CONFIGS:
            config = {**self.CONFIGS[config_name], **kwargs}
        else:
            config = kwargs
        missing = [key for key in self._CONFIG_KEYS if key not in config]
        if missing:
            raise KeyError(
                f"LiRAModel: config_name {config_name!r} is not a preset "
                f"({sorted(self.CONFIGS)}) and these hyper-parameters were "
                f"not given: {missing}"
            )

        self.d_model = config['d_model']
        self.n_blocks = config['n_blocks']
        self.d_state = config['d_state']
        self.d_reason = config['d_reason']
        self.cross_every = config['cross_every']
        self.in_channels = in_channels
        d_cond = self.d_model  # Conditioning dimension

        # ====== Input Processing ======
        self.patch_embed = LatentPatchEmbed(in_channels, self.d_model, patch_size)
        self.unpatch = LatentUnpatch(self.d_model, in_channels, patch_size)

        # ====== Conditioning ======
        self.time_embed = TimestepEmbedding(self.d_model)
        self.text_proj = TextProjection(d_text, self.d_model)
        # Combine timestep + text pooled + reasoning into single conditioning vector
        self.cond_combine = nn.Sequential(
            nn.Linear(self.d_model * 3, self.d_model * 2),
            nn.SiLU(),
            nn.Linear(self.d_model * 2, self.d_model)
        )

        # ====== Latent Reasoning Loop ======
        self.reasoning = LatentReasoningLoop(
            self.d_model, config['d_reason'], config['max_reason_steps']
        )

        # ====== Main Backbone: LiRA Blocks ======
        self.blocks = nn.ModuleList()
        self.cross_fusions = nn.ModuleDict()
        for i in range(self.n_blocks):
            self.blocks.append(LiRABlock(
                d_model=self.d_model,
                d_cond=d_cond,
                d_state=self.d_state,
                ffn_expand=config['ffn_expand'],
                hc_expansion=config['hc_expansion'],
            ))
            # Cross-modal fusion after every cross_every-th block; keyed by
            # the block index so forward() can look it up.
            if (i + 1) % self.cross_every == 0:
                self.cross_fusions[str(i)] = GatedCrossStateFusion(
                    self.d_model, self.d_model, self.d_state, config['num_heads']
                )

        # ====== Long Skip Connection (from U-ViT / DiM) ======
        # Block i in the second half is fused with the *input* of mirrored
        # block (n_blocks - 1 - i) via a learned Linear(2D -> D).
        self.n_skip = self.n_blocks // 2
        self.skip_projs = nn.ModuleList([
            nn.Linear(self.d_model * 2, self.d_model)
            for _ in range(self.n_skip)
        ])

        # ====== Output ======
        self.final_norm = nn.LayerNorm(self.d_model)
        self.final_adaln = nn.Sequential(
            nn.SiLU(),
            nn.Linear(d_cond, 2 * self.d_model)
        )

        # Initialize weights (also applies the adaLN-zero init of final_adaln)
        self._init_weights()

    def _init_weights(self):
        """
        Initialize weights for training stability.

        Every nn.Linear / nn.Conv1d / nn.Conv2d reachable through
        ``self.modules()`` gets a truncated-normal(std=0.02) weight and a zero
        bias. The output modulation ``final_adaln`` is zero-initialized
        afterwards (adaLN-zero), so that the output head starts as a plain
        LayerNorm.
        """
        # NOTE(review): self.modules() recurses into the core_modules blocks
        # (LiRABlock, LatentReasoningLoop, ...), so any custom init those
        # modules apply to their Linear/Conv layers is overwritten here --
        # confirm this is intended.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.Conv2d, nn.Conv1d)):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        # Must run AFTER the generic pass: previously the zero init happened
        # first and trunc_normal_ above silently overwrote the zero weight.
        nn.init.zeros_(self.final_adaln[1].weight)
        nn.init.zeros_(self.final_adaln[1].bias)

    def forward(
        self,
        z_t: torch.Tensor,                          # (B, C, H, W) noisy latent
        t: torch.Tensor,                            # (B,) timestep in [0, 1]
        text_features: torch.Tensor,                # (B, M, D_text) text encoder output
        text_mask: Optional[torch.Tensor] = None,   # (B, M) mask
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Predict the velocity v_t = ε - x_0 for the noisy latent z_t.

        Returns:
            v_pred: (B, C, H, W) predicted velocity.
            info: the second output of ``LatentReasoningLoop`` (reasoning stats).
        """
        # ====== Embed inputs ======
        x, H, W = self.patch_embed(z_t)  # (B, N, D) tokens on an H x W grid
        t_emb = self.time_embed(t)  # (B, D)
        # NOTE(review): text_mask is only given to text_proj; the cross fusions
        # below receive text_tokens without it. Confirm padded tokens are
        # neutralized inside TextProjection or GatedCrossStateFusion.
        text_tokens, text_pooled = self.text_proj(text_features, text_mask)  # (B, M, D), (B, D)

        # ====== Latent Reasoning ======
        reason_cond, reason_info = self.reasoning(x)  # (B, D)

        # ====== Combine conditioning ======
        cond = self.cond_combine(
            torch.cat([t_emb, text_pooled, reason_cond], dim=-1)
        )  # (B, D)

        # ====== Main backbone with long skip connections ======
        skip_features = []
        for i, block in enumerate(self.blocks):
            # First half: remember the input of block i for its mirror.
            if i < self.n_skip:
                skip_features.append(x)

            x = block(x, cond, H, W)

            key = str(i)
            if key in self.cross_fusions:
                x = self.cross_fusions[key](x, text_tokens)

            # Second half: fuse with the mirrored first-half input. With an
            # odd n_blocks the middle block has no partner and is skipped.
            if i >= self.n_skip:
                skip_idx = self.n_blocks - 1 - i
                if skip_idx < len(skip_features):
                    x = self.skip_projs[skip_idx](
                        torch.cat([x, skip_features[skip_idx]], dim=-1)
                    )

        # ====== Output projection (adaLN-modulated LayerNorm) ======
        shift, scale = self.final_adaln(cond).unsqueeze(1).chunk(2, dim=-1)
        x = self.final_norm(x) * (1 + scale) + shift
        v_pred = self.unpatch(x, H, W)  # (B, C, H_orig, W_orig)
        return v_pred, reason_info

    @torch.no_grad()
    def count_parameters(self) -> Dict[str, int]:
        """Return parameter counts per component plus the overall total."""
        def numel(module: nn.Module) -> int:
            return sum(p.numel() for p in module.parameters())

        return {
            'patch_embed': numel(self.patch_embed),
            'unpatch': numel(self.unpatch),
            'time_embed': numel(self.time_embed),
            'text_proj': numel(self.text_proj),
            'reasoning': numel(self.reasoning),
            'blocks': numel(self.blocks),
            'cross_fusions': numel(self.cross_fusions),
            'skip_projs': numel(self.skip_projs),
            'conditioning': numel(self.cond_combine),
            'output': numel(self.final_norm) + numel(self.final_adaln),
            'total': numel(self),
        }
# ============================================================================
# Tiny VAE Decoder for Mobile Deployment
# ============================================================================
class TinyVAEDecoder(nn.Module):
    """
    Ultra-lightweight, attention-free VAE decoder inspired by SnapGen's tiny decoder.

    Layout:
      - 3x3 conv: in_channels -> base_channels, SiLU
      - log2(spatial_compression) upsampling stages, each made of:
          SepConvBlock (depthwise-separable residual block, changes channels),
          3x3 conv to 4x channels, PixelShuffle(2) (2x spatial), SiLU.
        Channels halve every stage (base, base/2, base/4, ...) with a floor of 16.
      - 3x3 conv to out_channels, Tanh (output in [-1, 1])

    Only the SepConvBlocks are depthwise-separable. The input, upsampling and
    output convolutions are full 3x3 convs. GroupNorm appears only inside the
    SepConvBlocks.

    For f32 VAE: 32x32 latent → 1024x1024 image (5 upsampling stages)
    For f8 VAE: 128x128 latent → 1024x1024 image (3 upsampling stages)

    With the defaults (in_channels=32, f32, base_channels=64) this builds about
    0.24M parameters, well under the ~1.5M budget originally targeted. Most of
    them sit in the first upsampling conv (64 -> 256 channels).

    Raises:
        ValueError: if ``spatial_compression`` is not a positive power of two.
    """

    def __init__(
        self,
        in_channels: int = 32,
        out_channels: int = 3,
        spatial_compression: int = 32,  # 32 for f32, 8 for f8
        base_channels: int = 64,
    ):
        super().__init__()
        # Each stage upsamples exactly 2x, so only powers of two are
        # representable. The old int(log2(...)) silently truncated
        # (e.g. 24 -> 16x).
        factor = int(spatial_compression)
        if factor != spatial_compression or factor < 1 or factor & (factor - 1):
            raise ValueError(
                f"spatial_compression must be a positive power of two, "
                f"got {spatial_compression!r}"
            )
        num_upsample = factor.bit_length() - 1  # 5 for f32, 3 for f8

        layers = []
        # Initial projection
        layers.append(nn.Conv2d(in_channels, base_channels, 3, padding=1))
        layers.append(nn.SiLU())

        current_ch = base_channels
        for i in range(num_upsample):
            # Gradually reduce channels in later (higher-res) stages
            target_ch = max(base_channels // (2 ** i), 16)
            # Depthwise separable residual block
            layers.append(SepConvBlock(current_ch, target_ch))
            current_ch = target_ch
            # PixelShuffle upsample (2x): needs ch*4 input, outputs ch
            layers.append(nn.Conv2d(current_ch, current_ch * 4, 3, padding=1))
            layers.append(nn.PixelShuffle(2))  # ch*4 → ch, spatial 2x
            layers.append(nn.SiLU())

        # Final output
        layers.append(nn.Conv2d(current_ch, out_channels, 3, padding=1))
        layers.append(nn.Tanh())  # Output in [-1, 1]
        self.decoder = nn.Sequential(*layers)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        z: (B, C_lat, H_lat, W_lat) latent
        Returns: (B, out_channels, H_lat * spatial_compression,
                 W_lat * spatial_compression) decoded image in [-1, 1]
        """
        return self.decoder(z)
class SepConvBlock(nn.Module):
"""Depthwise separable convolution block"""
def __init__(self, in_ch, out_ch):
super().__init__()
self.dwconv = nn.Conv2d(in_ch, in_ch, 3, padding=1, groups=in_ch)
self.pwconv = nn.Conv2d(in_ch, out_ch, 1)
self.norm = nn.GroupNorm(min(8, out_ch), out_ch)
self.act = nn.SiLU()
self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()
def forward(self, x):
residual = self.skip(x)
x = self.dwconv(x)
x = self.pwconv(x)
x = self.norm(x)
x = self.act(x)
return x + residual
# ============================================================================
# Complete LiRA Pipeline
# ============================================================================
class LiRAPipeline(nn.Module):
    """
    Container pairing the LiRA denoiser with a tiny decoder for mobile use.

    Sub-modules:
      * ``denoiser``: a ``LiRAModel``. ``forward`` simply delegates to it.
      * ``tiny_decoder``: a ``TinyVAEDecoder`` mapping latents to images.

    This class does NOT contain a VAE encoder. During training the caller
    must encode images to z_0 with a (frozen) pretrained VAE and build z_t
    itself. The intended data flow is:
        training:  image → VAE_encoder → z_0 → add_noise(z_0, t) → z_t → LiRA → v_pred
        inference: noise → iterative_denoise(LiRA) → z_0 → TinyVAEDecoder → image
    Neither the noising nor the sampling loop is implemented here.
    """

    def __init__(
        self,
        config_name: str = 'small',
        latent_channels: int = 32,
        spatial_compression: int = 32,
        d_text: int = 768,
        patch_size: int = 1,
    ):
        super().__init__()
        self.spatial_compression = spatial_compression
        self.latent_channels = latent_channels

        # Denoising network (registered first: fixes parameter/state_dict order)
        self.denoiser = LiRAModel(
            config_name=config_name,
            in_channels=latent_channels,
            d_text=d_text,
            patch_size=patch_size,
        )
        # Lightweight decoder used at inference time on device
        self.tiny_decoder = TinyVAEDecoder(
            in_channels=latent_channels,
            spatial_compression=spatial_compression,
        )

    def forward(self, *args, **kwargs):
        """Run the denoiser; arguments are passed through unchanged."""
        return self.denoiser(*args, **kwargs)

    def count_parameters(self):
        """Denoiser breakdown plus the decoder size and the combined total."""
        report = self.denoiser.count_parameters()
        decoder_params = sum(p.numel() for p in self.tiny_decoder.parameters())
        report['tiny_decoder'] = decoder_params
        report['total_with_decoder'] = report['total'] + decoder_params
        return report
# ============================================================================
# Helper: Estimate memory usage
# ============================================================================
def estimate_memory_mb(model: nn.Module, batch_size: int = 1,
                       img_size: int = 1024, spatial_compression: int = 32,
                       latent_channels: int = 32, dtype_bytes: int = 2,
                       *, activation_factor: float = 3.0):
    """
    Roughly estimate inference memory usage in MB (1 MB = 1024**2 bytes).

    Args:
        model: module whose parameters are counted. Buffers are not included.
        batch_size: number of latents processed at once.
        img_size: square image side in pixels.
        spatial_compression: VAE downsampling factor (latent side =
            img_size // spatial_compression).
        latent_channels: channels of the latent tensor.
        dtype_bytes: bytes per element (2 for fp16/bf16, 4 for fp32).
        activation_factor: intermediate activations are modelled as this
            multiple of the latent tensor size (default 3, the original
            heuristic).

    Returns:
        dict with ``params_mb``, ``latent_mb``, ``activation_mb`` and
        ``total_inference_mb``.

    Note:
        The activation term scales the *latent* tensor. A backbone that works
        at a width much larger than ``latent_channels`` (e.g. d_model=512 vs.
        32 latent channels) probably needs far more activation memory, so
        treat the result as a lower-bound style heuristic.
    """
    mib = 1024 ** 2

    # Model parameters
    param_bytes = sum(p.numel() for p in model.parameters()) * dtype_bytes
    param_mb = param_bytes / mib

    # Latent size
    lat_h = img_size // spatial_compression
    lat_w = img_size // spatial_compression
    latent_bytes = batch_size * latent_channels * lat_h * lat_w * dtype_bytes

    # Intermediate activations (rough estimate: activation_factor x latent)
    activation_bytes = latent_bytes * activation_factor

    total_mb = param_mb + (latent_bytes + activation_bytes) / mib
    return {
        'params_mb': param_mb,
        'latent_mb': latent_bytes / mib,
        'activation_mb': activation_bytes / mib,
        'total_inference_mb': total_mb,
    }