| """ |
| LuminaRS -- Lightweight Latent Recursive Diffusion for Art/Illustration. |
| ~100M params, designed for 2-4GB VRAM mobile deployment. |
| Uses pretrained VAE + CLIP text encoder (both frozen). |
| """ |
| import math, torch, torch.nn as nn, torch.nn.functional as F |
|
|
| |
|
|
def timestep_embedding(t, dim, max_period=10000):
    """Sinusoidal timestep embedding: (B,) tensor of timesteps -> (B, dim)."""
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) *
                      torch.arange(half, device=t.device, dtype=torch.float32) / half)
    args = t[:, None].float() * freqs[None]
    emb = torch.cat([torch.cos(args), torch.sin(args)], -1)
    if dim % 2:  # odd dim: pad one zero channel
        emb = F.pad(emb, (0, 1))
    return emb
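
# Illustrative usage (assumed, not from the training code): for a batch of
# eight integer timesteps, timestep_embedding(torch.arange(8), 256) returns
# an (8, 256) tensor, which time_mlp below maps to the conditioning vector.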
class DropPath(nn.Module):
    """Stochastic depth: randomly drop the whole residual branch per sample."""
    def __init__(self, p=0.0):
        super().__init__()
        self.p = p

    def forward(self, x):
        if self.p == 0. or not self.training:
            return x
        # Per-sample keep mask, rescaled so the expected value is unchanged.
        mask = torch.zeros(x.shape[0], 1, 1, 1, device=x.device).bernoulli_(1 - self.p)
        return x * mask / (1 - self.p)
class AdaLN(nn.Module):
    """Adaptive LayerNorm: per-channel shift+scale predicted from the conditioning vector."""
    def __init__(self, dim, cond_dim):
        super().__init__()
        self.norm = nn.GroupNorm(1, dim)  # one group == LayerNorm over (C,H,W)
        self.s = nn.Linear(cond_dim, dim)
        self.b = nn.Linear(cond_dim, dim)

    def forward(self, x, c):
        """x: (B,C,H,W)  c: (B,cond_dim)"""
        h = self.norm(x)
        s = self.s(c)[:, :, None, None]
        b = self.b(c)[:, :, None, None]
        return h * (1 + s) + b
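
# Note: h * (1 + s) + b reduces to the identity when s = b = 0, so some AdaLN
# implementations zero-initialize the two linear layers to make every block
# start as a plain normalization. That init is a common refinement, not
# something this file does.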
class MQAttention(nn.Module):
    """Multi-Query Attention: a single K/V head is shared across all query heads."""
    def __init__(self, dim, n_heads=4):
        super().__init__()
        assert dim % n_heads == 0
        self.nh = n_heads
        self.dh = dim // n_heads
        self.scale = self.dh ** -0.5
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, self.dh)  # one shared key head (not dim->dim)
        self.v = nn.Linear(dim, self.dh)  # one shared value head
        self.o = nn.Linear(dim, dim)

    def forward(self, x, ctx=None):
        B, L, C = x.shape
        ctx = x if ctx is None else ctx
        q = self.q(x).view(B, L, self.nh, self.dh).transpose(1, 2)  # (B, nh, L, dh)
        k = self.k(ctx).unsqueeze(1)  # (B, 1, Lc, dh), broadcast across heads
        v = self.v(ctx).unsqueeze(1)
        a = (q @ k.transpose(-2, -1)) * self.scale  # (B, nh, L, Lc)
        return self.o((a.softmax(-1) @ v).transpose(1, 2).reshape(B, L, C))
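
# Illustrative sizing (assumed dim=256, n_heads=4): the K and V projections
# are 256->64 here instead of the 256->256 of standard multi-head attention,
# i.e. 4x fewer K/V weights and a 4x smaller K/V cache -- the usual MQA trade
# for tight VRAM budgets.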
class ResBlock(nn.Module):
    """ConvNeXt-style block (depthwise 7x7 + pointwise MLP), AdaLN-conditioned,
    with optional cross-attention to text tokens."""
    def __init__(self, dim, cond_dim, text_dim=None, drop_path=0.0):
        super().__init__()
        self.adaln = AdaLN(dim, cond_dim)
        self.dw = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)
        self.pw1 = nn.Linear(dim, dim * 2)
        self.act = nn.GELU()
        self.pw2 = nn.Linear(dim * 2, dim)
        self.drop = DropPath(drop_path)
        self.xattn = None
        if text_dim is not None:
            self.xattn = MQAttention(dim)
            self.tproj = nn.Linear(text_dim, dim)
            self.tnorm = nn.GroupNorm(1, dim)

    def forward(self, x, cond, text=None):
        r = x
        x = self.adaln(x, cond)
        x = self.dw(x)
        x = x.permute(0, 2, 3, 1)  # channels-last for the pointwise MLP
        x = self.pw2(self.act(self.pw1(x)))
        x = x.permute(0, 3, 1, 2)
        x = r + self.drop(x)
        if self.xattn is not None and text is not None:
            B, C, H, W = x.shape
            # Pre-norm, flatten to (B, HW, C), cross-attend to the projected
            # text tokens, and add back as a residual.
            xf = self.tnorm(x).view(B, C, H * W).transpose(1, 2)
            x = x + self.xattn(xf, self.tproj(text)).transpose(1, 2).view(B, C, H, W)
        return x
class Down(nn.Module):
    def __init__(self, cin, cout, cond_dim, text_dim=None, n=2, dp=0.0):
        super().__init__()
        # ResBlocks preserve channel count, so all n blocks run at cin; the
        # cin->cout change happens in the strided conv and the skip projection.
        self.blocks = nn.ModuleList([
            ResBlock(cin, cond_dim, text_dim, dp) for _ in range(n)])
        self.skip_conv = nn.Conv2d(cin, cout, 1) if cin != cout else nn.Identity()
        self.down = nn.Conv2d(cin, cout, 3, stride=2, padding=1)

    def forward(self, x, cond, text=None):
        for b in self.blocks:
            x = b(x, cond, text)
        return self.down(x), self.skip_conv(x)
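
# Each Down returns (downsampled features, full-resolution skip). The skip is
# already projected to cout channels, so the matching Up below can fuse it by
# addition rather than concatenation, avoiding the wider concat buffers of a
# classic U-Net decoder.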
class Up(nn.Module):
    def __init__(self, cin, cout, cond_dim, text_dim=None, n=2, dp=0.0):
        super().__init__()
        self.up = nn.ConvTranspose2d(cin, cin, 2, stride=2)
        self.blocks = nn.ModuleList([
            ResBlock(cin, cond_dim, text_dim, dp) for _ in range(n)])
        self.skip_conv = nn.Conv2d(cin, cout, 1) if cin != cout else nn.Identity()

    def forward(self, x, skip, cond, text=None):
        x = self.up(x) + skip  # additive skip fusion at the matching resolution
        for b in self.blocks:
            x = b(x, cond, text)
        return self.skip_conv(x)  # project down to the next stage's width
class LuminaRS(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        chs = cfg.channels
        t_dim = cfg.t_embed_dim
        cond_dim = chs[-1] * 4

        # Shared conditioning vector from the diffusion timestep.
        self.time_mlp = nn.Sequential(
            nn.Linear(t_dim, cond_dim), nn.SiLU(), nn.Linear(cond_dim, cond_dim))

        # Project frozen CLIP token embeddings into the model's text width.
        self.text_proj = nn.Linear(cfg.text_embed_dim, cfg.text_proj_dim)

        self.in_conv = nn.Conv2d(cfg.latent_dim, chs[0], 3, padding=1)

        # Encoder: one Down stage per channel transition.
        self.enc = nn.ModuleList()
        for i in range(len(chs) - 1):
            self.enc.append(Down(chs[i], chs[i + 1], cond_dim,
                                 cfg.text_proj_dim, n=2, dp=cfg.drop_path))

        # Bottleneck blocks at the lowest resolution.
        self.bot = nn.ModuleList([
            ResBlock(chs[-1], cond_dim, cfg.text_proj_dim, drop_path=cfg.drop_path)
            for _ in range(cfg.n_bottleneck)])

        # Decoder mirrors the encoder.
        self.dec = nn.ModuleList()
        for i in range(len(chs) - 1, 0, -1):
            self.dec.append(Up(chs[i], chs[i - 1], cond_dim,
                               cfg.text_proj_dim, n=2, dp=cfg.drop_path))

        self.out_conv = nn.Conv2d(chs[0], cfg.latent_dim, 1)
        self.n_recurse = cfg.n_recurse

    def forward(self, z, text_emb, t):
        t_emb = self.time_mlp(timestep_embedding(t, self.cfg.t_embed_dim))
        text = self.text_proj(text_emb)
        x = self.in_conv(z)
        # Weight-tied recursion: the same U-Net refines its own output
        # n_recurse times, trading extra compute for parameter count.
        for _ in range(self.n_recurse):
            h = x
            skips = []
            for down in self.enc:
                h, sk = down(h, t_emb, text)
                skips.append(sk)
            for blk in self.bot:
                h = blk(h, t_emb, text)
            for up in self.dec:
                h = up(h, skips.pop(), t_emb, text)
            x = x + h
        return self.out_conv(x)
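
# ---------------------------------------------------------------------------
# Minimal smoke test. The config values below are illustrative assumptions for
# a small variant, NOT the shipped LuminaRS hyperparameters; adjust to taste.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        channels=[64, 128, 256],  # assumed stage widths
        t_embed_dim=256,
        latent_dim=4,             # e.g. a 4-channel VAE latent
        text_embed_dim=512,       # e.g. CLIP ViT-B/32 text width
        text_proj_dim=256,
        drop_path=0.0,
        n_bottleneck=2,
        n_recurse=2,
    )
    model = LuminaRS(cfg).eval()
    z = torch.randn(2, cfg.latent_dim, 32, 32)     # (B, C_latent, H/8, W/8)
    text = torch.randn(2, 77, cfg.text_embed_dim)  # (B, tokens, text_dim)
    t = torch.randint(0, 1000, (2,))
    with torch.no_grad():
        out = model(z, text, t)
    assert out.shape == z.shape, out.shape
    print("output:", tuple(out.shape),
          "params:", sum(p.numel() for p in model.parameters()))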