""" LuminaRS -- Lightweight Latent Recursive Diffusion for Art/Illustration. ~100M params, designed for 2-4GB VRAM mobile deployment. Uses pretrained VAE + CLIP text encoder (both frozen). """ import math, torch, torch.nn as nn, torch.nn.functional as F # ── helpers ────────────────────────────────────────────────────────────── def timestep_embedding(t, dim, max_period=10000): half = dim // 2 freqs = torch.exp(-math.log(max_period) * torch.arange(half, device=t.device, dtype=torch.float32) / half) args = t[:, None].float() * freqs[None] emb = torch.cat([torch.cos(args), torch.sin(args)], -1) if dim % 2: emb = F.pad(emb, (0, 1)) return emb class DropPath(nn.Module): def __init__(self, p=0.0): super().__init__(); self.p = p def forward(self, x): if self.p == 0. or not self.training: return x mask = torch.zeros(x.shape[0],1,1,1, device=x.device).bernoulli_(1-self.p) return x * mask / (1 - self.p) class AdaLN(nn.Module): """Adaptive LayerNorm: shift+scale from time embedding.""" def __init__(self, dim, cond_dim): super().__init__() self.norm = nn.GroupNorm(1, dim) self.s = nn.Linear(cond_dim, dim) self.b = nn.Linear(cond_dim, dim) def forward(self, x, c): """x: (B,C,H,W) c: (B,cond_dim)""" h = self.norm(x) s = self.s(c)[:,:,None,None] b = self.b(c)[:,:,None,None] return h * (1+s) + b class MQAttention(nn.Module): """Multi-Query Attention: one K,V shared across heads.""" def __init__(self, dim, n_heads=4): super().__init__() assert dim % n_heads == 0 self.nh = n_heads; self.dh = dim // n_heads; self.scale = self.dh**-0.5 self.q = nn.Linear(dim, dim) self.k = nn.Linear(dim, dim) self.v = nn.Linear(dim, dim) self.o = nn.Linear(dim, dim) def forward(self, x, ctx=None): B, L, C = x.shape ctx = x if ctx is None else ctx q = self.q(x).view(B, L, self.nh, self.dh).transpose(1,2) k = self.k(ctx).view(B,-1, self.nh, self.dh).transpose(1,2) v = self.v(ctx).view(B,-1, self.nh, self.dh).transpose(1,2) a = (q @ k.transpose(-2,-1)) * self.scale return self.o((a.softmax(-1) @ v).transpose(1,2).reshape(B,L,C)) # ── ConvNeXt + cross-attn block ───────────────────────────────────────── class ResBlock(nn.Module): def __init__(self, dim, cond_dim, text_dim=None, drop_path=0.0): super().__init__() self.adaln = AdaLN(dim, cond_dim) self.dw = nn.Conv2d(dim, dim, 7, padding=3, groups=dim) self.pw1 = nn.Linear(dim, dim*2) self.act = nn.GELU() self.pw2 = nn.Linear(dim*2, dim) self.drop = DropPath(drop_path) self.xattn = None if text_dim is not None: self.xattn = MQAttention(dim) self.tproj = nn.Linear(text_dim, dim) self.tnorm = nn.GroupNorm(1, dim) def forward(self, x, cond, text=None): r = x x = self.adaln(x, cond) x = self.dw(x) x = x.permute(0,2,3,1) x = self.pw2(self.act(self.pw1(x))) x = x.permute(0,3,1,2) x = r + self.drop(x) if self.xattn is not None and text is not None: B,C,H,W = x.shape xf = x.view(B,C,H*W).transpose(1,2) xf = xf + self.xattn(xf, self.tproj(text)) x = xf.transpose(1,2).view(B,C,H,W) return x # ── Down / Up ──────────────────────────────────────────────────────────── class Down(nn.Module): def __init__(self, cin, cout, cond_dim, text_dim=None, n=2, dp=0.0): super().__init__() self.blocks = nn.ModuleList([ ResBlock(cin if i==0 else cout, cond_dim, text_dim, dp) for i in range(n) ]) if cin != cout: self.skip_conv = nn.Conv2d(cin, cout, 1) else: self.skip_conv = None self.down = nn.Conv2d(cout, cout, 3, stride=2, padding=1) def forward(self, x, cond, text=None): for b in self.blocks: x = b(x, cond, text) skip = x if self.skip_conv is not None: skip = self.skip_conv(skip) return self.down(x), skip 
class Up(nn.Module):
    def __init__(self, cin, cout, cond_dim, text_dim=None, n=2, dp=0.0):
        super().__init__()
        self.up = nn.ConvTranspose2d(cin, cin, 2, stride=2)
        # fuse the encoder skip at cin channels, then narrow to cout so every
        # (channel-preserving) ResBlock runs at this stage's width
        self.proj = nn.Conv2d(cin, cout, 1) if cin != cout else nn.Identity()
        self.blocks = nn.ModuleList([
            ResBlock(cout, cond_dim, text_dim, dp) for _ in range(n)])

    def forward(self, x, skip, cond, text=None):
        x = self.up(x) + skip
        x = self.proj(x)
        for b in self.blocks:
            x = b(x, cond, text)
        return x


# ── Main Model ───────────────────────────────────────────────────────────
class LuminaRS(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        chs = cfg.channels            # e.g. (64, 128, 256, 384)
        t_dim = cfg.t_embed_dim
        cond_dim = chs[-1] * 4
        self.time_mlp = nn.Sequential(
            nn.Linear(t_dim, cond_dim), nn.SiLU(),
            nn.Linear(cond_dim, cond_dim))
        self.text_proj = nn.Linear(cfg.text_embed_dim, cfg.text_proj_dim)
        self.in_conv = nn.Conv2d(cfg.latent_dim, chs[0], 3, padding=1)
        # Encoder
        self.enc = nn.ModuleList()
        for i in range(len(chs) - 1):
            self.enc.append(Down(chs[i], chs[i + 1], cond_dim,
                                 cfg.text_proj_dim, n=2, dp=cfg.drop_path))
        # Bottleneck
        self.bot = nn.ModuleList([
            ResBlock(chs[-1], cond_dim, cfg.text_proj_dim, drop_path=cfg.drop_path)
            for _ in range(cfg.n_bottleneck)])
        # Decoder
        self.dec = nn.ModuleList()
        for i in range(len(chs) - 1, 0, -1):
            self.dec.append(Up(chs[i], chs[i - 1], cond_dim,
                               cfg.text_proj_dim, n=2, dp=cfg.drop_path))
        self.out_conv = nn.Conv2d(chs[0], cfg.latent_dim, 1)
        self.n_recurse = cfg.n_recurse

    def forward(self, z, text_emb, t):
        t_emb = self.time_mlp(timestep_embedding(t, self.cfg.t_embed_dim))
        text = self.text_proj(text_emb)
        x = self.in_conv(z)
        # weight-tied recursion: the same U-Net repeatedly refines its input
        for _ in range(self.n_recurse):
            h = x
            skips = []
            for down in self.enc:
                h, sk = down(h, t_emb, text)
                skips.append(sk)
            for blk in self.bot:
                h = blk(h, t_emb, text)
            for up in self.dec:
                h = up(h, skips.pop(), t_emb, text)
            x = x + h
        return self.out_conv(x)
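
# ── usage sketch ─────────────────────────────────────────────────────────
# Minimal smoke test. The config fields below are exactly the ones the model
# reads, but the concrete values (stage widths, CLIP hidden size, latent
# channels) are illustrative assumptions, and SimpleNamespace merely stands
# in for whatever config object the training code actually uses.
if __name__ == "__main__":
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        channels=(64, 128, 256, 384),   # per-stage widths (docstring example)
        t_embed_dim=256,                # sinusoidal timestep embedding size
        text_embed_dim=512,             # CLIP text hidden size (assumed)
        text_proj_dim=256,              # projected text width for cross-attn
        latent_dim=4,                   # VAE latent channels (assumed)
        drop_path=0.1,
        n_bottleneck=2,
        n_recurse=2,
    )
    model = LuminaRS(cfg)
    z = torch.randn(2, cfg.latent_dim, 32, 32)      # a batch of VAE latents
    text = torch.randn(2, 77, cfg.text_embed_dim)   # CLIP token hidden states
    t = torch.randint(0, 1000, (2,))                # diffusion timesteps
    out = model(z, text, t)
    assert out.shape == z.shape                     # model predicts in latent space
    print(f"params: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")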