"""
LuminaRS -- Lightweight Latent Recursive Diffusion for Art/Illustration.
~100M params, designed for 2-4GB VRAM mobile deployment.
Uses pretrained VAE + CLIP text encoder (both frozen).
"""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
# ── helpers ──────────────────────────────────────────────────────────────
def timestep_embedding(t, dim, max_period=10000):
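    """Sinusoidal timestep embedding (DDPM/ADM style): t of shape (B,) -> (B, dim)."""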
half = dim // 2
freqs = torch.exp(-math.log(max_period) *
torch.arange(half, device=t.device, dtype=torch.float32) / half)
args = t[:, None].float() * freqs[None]
emb = torch.cat([torch.cos(args), torch.sin(args)], -1)
if dim % 2:
emb = F.pad(emb, (0, 1))
return emb
class DropPath(nn.Module):
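    """Stochastic depth: randomly zeroes the whole residual branch per sample."""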
def __init__(self, p=0.0):
super().__init__(); self.p = p
def forward(self, x):
if self.p == 0. or not self.training: return x
mask = torch.zeros(x.shape[0],1,1,1, device=x.device).bernoulli_(1-self.p)
return x * mask / (1 - self.p)
class AdaLN(nn.Module):
"""Adaptive LayerNorm: shift+scale from time embedding."""
def __init__(self, dim, cond_dim):
super().__init__()
self.norm = nn.GroupNorm(1, dim)
self.s = nn.Linear(cond_dim, dim)
self.b = nn.Linear(cond_dim, dim)
def forward(self, x, c):
"""x: (B,C,H,W) c: (B,cond_dim)"""
h = self.norm(x)
s = self.s(c)[:,:,None,None]
b = self.b(c)[:,:,None,None]
return h * (1+s) + b
class MQAttention(nn.Module):
    """Multi-Query Attention: a single K/V head shared across all query heads."""
    def __init__(self, dim, n_heads=4):
        super().__init__()
        assert dim % n_heads == 0
        self.nh = n_heads; self.dh = dim // n_heads; self.scale = self.dh ** -0.5
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, self.dh)   # one shared key head
        self.v = nn.Linear(dim, self.dh)   # one shared value head
        self.o = nn.Linear(dim, dim)
    def forward(self, x, ctx=None):
        """x: (B, L, C) queries; ctx: (B, Lc, C) keys/values (defaults to x)."""
        B, L, C = x.shape
        ctx = x if ctx is None else ctx
        q = self.q(x).view(B, L, self.nh, self.dh).transpose(1, 2)   # (B, nh, L, dh)
        k = self.k(ctx).view(B, -1, 1, self.dh).transpose(1, 2)      # (B, 1, Lc, dh)
        v = self.v(ctx).view(B, -1, 1, self.dh).transpose(1, 2)      # (B, 1, Lc, dh)
        a = (q @ k.transpose(-2, -1)) * self.scale                   # broadcasts over heads
        return self.o((a.softmax(-1) @ v).transpose(1, 2).reshape(B, L, C))
# ── ConvNeXt + cross-attn block ─────────────────────────────────────────
class ResBlock(nn.Module):
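    """ConvNeXt-style residual block with AdaLN time conditioning and optional
    text cross-attention over flattened spatial positions."""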
def __init__(self, dim, cond_dim, text_dim=None, drop_path=0.0):
super().__init__()
self.adaln = AdaLN(dim, cond_dim)
self.dw = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)
self.pw1 = nn.Linear(dim, dim*2)
self.act = nn.GELU()
self.pw2 = nn.Linear(dim*2, dim)
self.drop = DropPath(drop_path)
self.xattn = None
if text_dim is not None:
self.xattn = MQAttention(dim)
self.tproj = nn.Linear(text_dim, dim)
self.tnorm = nn.GroupNorm(1, dim)
def forward(self, x, cond, text=None):
r = x
x = self.adaln(x, cond)
x = self.dw(x)
x = x.permute(0,2,3,1)
x = self.pw2(self.act(self.pw1(x)))
x = x.permute(0,3,1,2)
x = r + self.drop(x)
        if self.xattn is not None and text is not None:
            B, C, H, W = x.shape
            # Pre-norm cross-attention over flattened spatial positions.
            xf = self.tnorm(x).view(B, C, H * W).transpose(1, 2)   # (B, HW, C)
            x = x + self.xattn(xf, self.tproj(text)).transpose(1, 2).view(B, C, H, W)
return x
# ── Down / Up ────────────────────────────────────────────────────────────
class Down(nn.Module):
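    """Encoder stage: n ResBlocks at the input width, a stride-2 conv that changes
    channels, plus a 1x1-projected skip for the matching decoder stage."""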
def __init__(self, cin, cout, cond_dim, text_dim=None, n=2, dp=0.0):
super().__init__()
        # Blocks run at the input width; channels change in the strided conv below.
        self.blocks = nn.ModuleList([
            ResBlock(cin, cond_dim, text_dim, dp) for _ in range(n)
        ])
        # Project the skip to the width of the decoder stage it feeds into.
        self.skip_conv = nn.Conv2d(cin, cout, 1) if cin != cout else None
        self.down = nn.Conv2d(cin, cout, 3, stride=2, padding=1)
def forward(self, x, cond, text=None):
for b in self.blocks:
x = b(x, cond, text)
skip = x
if self.skip_conv is not None:
skip = self.skip_conv(skip)
return self.down(x), skip
class Up(nn.Module):
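    """Decoder stage: upsample, add the encoder skip, then n ResBlocks with a
    1x1 channel reduction applied after the first block."""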
def __init__(self, cin, cout, cond_dim, text_dim=None, n=2, dp=0.0):
super().__init__()
self.up = nn.ConvTranspose2d(cin, cin, 2, stride=2)
self.blocks = nn.ModuleList([
ResBlock(cin if i==0 else cout, cond_dim, text_dim, dp) for i in range(n)
])
self.skip_conv = nn.Conv2d(cin, cout, 1) if cin != cout else nn.Identity()
    def forward(self, x, skip, cond, text=None):
        x = self.up(x)
        x = x + skip
        for i, b in enumerate(self.blocks):
            x = b(x, cond, text)
            if i == 0:
                x = self.skip_conv(x)   # reduce cin -> cout after the first block
        return x
# ── Main Model ───────────────────────────────────────────────────────────
class LuminaRS(nn.Module):
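    """Recursive latent UNet denoiser: the weight-shared encoder/bottleneck/decoder
    stack is applied n_recurse times, each pass added residually to the running
    latent before the final 1x1 projection."""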
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
chs = cfg.channels # e.g. (64, 128, 256, 384)
t_dim = cfg.t_embed_dim
cond_dim = chs[-1] * 4
self.time_mlp = nn.Sequential(
nn.Linear(t_dim, cond_dim), nn.SiLU(), nn.Linear(cond_dim, cond_dim))
self.text_proj = nn.Linear(cfg.text_embed_dim, cfg.text_proj_dim)
self.in_conv = nn.Conv2d(cfg.latent_dim, chs[0], 3, padding=1)
# Encoder
self.enc = nn.ModuleList()
for i in range(len(chs)-1):
self.enc.append(Down(chs[i], chs[i+1], cond_dim, cfg.text_proj_dim, n=2, dp=cfg.drop_path))
# Bottleneck
self.bot = nn.ModuleList([
ResBlock(chs[-1], cond_dim, cfg.text_proj_dim, dp=cfg.drop_path)
for _ in range(cfg.n_bottleneck)])
# Decoder
self.dec = nn.ModuleList()
for i in range(len(chs)-1, 0, -1):
self.dec.append(Up(chs[i], chs[i-1], cond_dim, cfg.text_proj_dim, n=2, dp=cfg.drop_path))
self.out_conv = nn.Conv2d(chs[0], cfg.latent_dim, 1)
self.n_recurse = cfg.n_recurse
def forward(self, z, text_emb, t):
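        """z: (B, latent_dim, h, w) noisy latents; text_emb: (B, L, text_embed_dim)
        token embeddings from the frozen text encoder; t: (B,) diffusion timesteps.
        Returns a tensor shaped like z (noise or v prediction, depending on the
        training objective)."""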
t_emb = self.time_mlp(timestep_embedding(t, self.cfg.t_embed_dim))
text = self.text_proj(text_emb)
x = self.in_conv(z)
for _ in range(self.n_recurse):
h = x; skips = []
for down in self.enc:
h, sk = down(h, t_emb, text)
skips.append(sk)
for blk in self.bot:
h = blk(h, t_emb, text)
for up in self.dec:
sk = skips.pop()
h = up(h, sk, t_emb, text)
x = x + h
return self.out_conv(x)
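# ── smoke test ───────────────────────────────────────────────────────────
# Minimal shape check with an illustrative config. The field values below are
# assumptions for demonstration only, not the released training configuration.
if __name__ == "__main__":
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        channels=(64, 128, 256, 384),
        t_embed_dim=256,
        text_embed_dim=768,    # e.g. CLIP ViT-L/14 hidden size (assumed)
        text_proj_dim=256,
        latent_dim=4,          # e.g. SD-style VAE latent channels (assumed)
        drop_path=0.1,
        n_bottleneck=4,
        n_recurse=2,
    )
    model = LuminaRS(cfg).eval()
    z = torch.randn(2, cfg.latent_dim, 32, 32)        # noisy latents
    text = torch.randn(2, 77, cfg.text_embed_dim)     # frozen text-encoder tokens
    t = torch.randint(0, 1000, (2,))                  # diffusion timesteps
    with torch.no_grad():
        out = model(z, text, t)
    print(out.shape)                                  # torch.Size([2, 4, 32, 32])
    print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M params")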