"""
LuminaRS -- Lightweight Latent Recursive Diffusion for Art/Illustration.
~100M params, designed for 2-4GB VRAM mobile deployment.
Uses pretrained VAE + CLIP text encoder (both frozen).
"""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
# ── helpers ──────────────────────────────────────────────────────────────
def timestep_embedding(t, dim, max_period=10000):
    """Sinusoidal embedding of timesteps: t (B,) -> (B, dim), DDPM-style."""
half = dim // 2
freqs = torch.exp(-math.log(max_period) *
torch.arange(half, device=t.device, dtype=torch.float32) / half)
args = t[:, None].float() * freqs[None]
emb = torch.cat([torch.cos(args), torch.sin(args)], -1)
if dim % 2:
emb = F.pad(emb, (0, 1))
return emb
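# Shape check (illustrative): with t = torch.randint(0, 1000, (8,)) and
# dim=256, timestep_embedding(t, 256) returns an (8, 256) tensor of
# cos/sin features at geometrically spaced frequencies.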
class DropPath(nn.Module):
    """Stochastic depth: randomly drop a whole residual branch per sample."""
    def __init__(self, p=0.0):
        super().__init__(); self.p = p
def forward(self, x):
if self.p == 0. or not self.training: return x
mask = torch.zeros(x.shape[0],1,1,1, device=x.device).bernoulli_(1-self.p)
return x * mask / (1 - self.p)
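# Worked example: with p=0.1 at train time, each sample's branch is zeroed
# with probability 0.1 and survivors are rescaled by 1/0.9, so the expected
# output matches eval mode, where the branch always passes through.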
class AdaLN(nn.Module):
"""Adaptive LayerNorm: shift+scale from time embedding."""
def __init__(self, dim, cond_dim):
super().__init__()
self.norm = nn.GroupNorm(1, dim)
self.s = nn.Linear(cond_dim, dim)
self.b = nn.Linear(cond_dim, dim)
def forward(self, x, c):
"""x: (B,C,H,W) c: (B,cond_dim)"""
h = self.norm(x)
s = self.s(c)[:,:,None,None]
b = self.b(c)[:,:,None,None]
return h * (1+s) + b
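# Broadcast sketch: for x (B,C,H,W) and c (B,cond_dim), s and b come out as
# (B,C,1,1), so the time embedding sets a per-channel gain (1+s) and bias b.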
class MQAttention(nn.Module):
"""Multi-Query Attention: one K,V shared across heads."""
    def __init__(self, dim, n_heads=4):
        super().__init__()
        assert dim % n_heads == 0
        self.nh = n_heads; self.dh = dim // n_heads; self.scale = self.dh**-0.5
        self.q = nn.Linear(dim, dim)
        # single shared K/V head (dh wide), broadcast across all query heads
        self.k = nn.Linear(dim, self.dh)
        self.v = nn.Linear(dim, self.dh)
        self.o = nn.Linear(dim, dim)
    def forward(self, x, ctx=None):
        B, L, C = x.shape
        ctx = x if ctx is None else ctx           # self-attention if no context
        q = self.q(x).view(B, L, self.nh, self.dh).transpose(1, 2)   # (B,nh,L,dh)
        k = self.k(ctx).view(B, -1, 1, self.dh).transpose(1, 2)      # (B,1,L',dh)
        v = self.v(ctx).view(B, -1, 1, self.dh).transpose(1, 2)      # (B,1,L',dh)
        a = (q @ k.transpose(-2, -1)) * self.scale                   # broadcasts over heads
        return self.o((a.softmax(-1) @ v).transpose(1, 2).reshape(B, L, C))
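# Why MQA here: K/V project to a single dh-wide head instead of dim, e.g. for
# dim=256, n_heads=4 the K and V weight matrices shrink 4x (256*64 vs 256*256
# each), which also shrinks activation memory for cross-attention on small GPUs.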
# ── ConvNeXt + cross-attn block ──────────────────────────────────────────
class ResBlock(nn.Module):
def __init__(self, dim, cond_dim, text_dim=None, drop_path=0.0):
super().__init__()
self.adaln = AdaLN(dim, cond_dim)
self.dw = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)
self.pw1 = nn.Linear(dim, dim*2)
self.act = nn.GELU()
self.pw2 = nn.Linear(dim*2, dim)
self.drop = DropPath(drop_path)
self.xattn = None
if text_dim is not None:
self.xattn = MQAttention(dim)
self.tproj = nn.Linear(text_dim, dim)
self.tnorm = nn.GroupNorm(1, dim)
    def forward(self, x, cond, text=None):
        r = x
        x = self.adaln(x, cond)
        x = self.dw(x)
        x = x.permute(0, 2, 3, 1)                # NCHW -> NHWC for the pointwise MLP
        x = self.pw2(self.act(self.pw1(x)))
        x = x.permute(0, 3, 1, 2)
        x = r + self.drop(x)
        if self.xattn is not None and text is not None:
            B, C, H, W = x.shape
            h = self.tnorm(x).view(B, C, H * W).transpose(1, 2)   # pre-norm, (B,HW,C)
            x = x + self.xattn(h, self.tproj(text)).transpose(1, 2).view(B, C, H, W)
        return x
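# Data flow per block: AdaLN(time) -> 7x7 depthwise conv -> NHWC pointwise MLP
# (ConvNeXt-style, 2x expansion) -> DropPath residual, then optional pre-norm
# cross-attention from the flattened spatial tokens to the projected text tokens.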
# ── Down / Up ────────────────────────────────────────────────────────────
class Down(nn.Module):
    def __init__(self, cin, cout, cond_dim, text_dim=None, n=2, dp=0.0):
        super().__init__()
        # ResBlock preserves channel count, so change cin -> cout once up front
        self.proj = nn.Conv2d(cin, cout, 1) if cin != cout else nn.Identity()
        self.blocks = nn.ModuleList([
            ResBlock(cout, cond_dim, text_dim, dp) for _ in range(n)
        ])
        self.down = nn.Conv2d(cout, cout, 3, stride=2, padding=1)
    def forward(self, x, cond, text=None):
        x = self.proj(x)
        for b in self.blocks:
            x = b(x, cond, text)
        return self.down(x), x   # skip connection taken before downsampling
class Up(nn.Module):
    def __init__(self, cin, cout, cond_dim, text_dim=None, n=2, dp=0.0):
        super().__init__()
        self.up = nn.ConvTranspose2d(cin, cin, 2, stride=2)
        # the encoder skip has cin channels at the target resolution
        self.proj = nn.Conv2d(cin, cout, 1) if cin != cout else nn.Identity()
        self.blocks = nn.ModuleList([
            ResBlock(cout, cond_dim, text_dim, dp) for _ in range(n)
        ])
    def forward(self, x, skip, cond, text=None):
        x = self.up(x) + skip
        x = self.proj(x)         # change channels once, before the ResBlocks
        for b in self.blocks:
            x = b(x, cond, text)
        return x
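# Shape trace for chs=(64,128,256,384) on a 32x32 latent (3 Down stages, so
# H and W must be divisible by 8):
#   enc: 64@32 -> 128@16 -> 256@8 -> 384@4   (skips: 128@32, 256@16, 384@8)
#   dec: 384@4 -> 256@8  -> 128@16 -> 64@32  (each Up consumes one skip)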
# ── Main Model ───────────────────────────────────────────────────────────
class LuminaRS(nn.Module):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
chs = cfg.channels # e.g. (64, 128, 256, 384)
t_dim = cfg.t_embed_dim
cond_dim = chs[-1] * 4
self.time_mlp = nn.Sequential(
nn.Linear(t_dim, cond_dim), nn.SiLU(), nn.Linear(cond_dim, cond_dim))
self.text_proj = nn.Linear(cfg.text_embed_dim, cfg.text_proj_dim)
self.in_conv = nn.Conv2d(cfg.latent_dim, chs[0], 3, padding=1)
# Encoder
self.enc = nn.ModuleList()
for i in range(len(chs)-1):
self.enc.append(Down(chs[i], chs[i+1], cond_dim, cfg.text_proj_dim, n=2, dp=cfg.drop_path))
# Bottleneck
self.bot = nn.ModuleList([
ResBlock(chs[-1], cond_dim, cfg.text_proj_dim, dp=cfg.drop_path)
for _ in range(cfg.n_bottleneck)])
# Decoder
self.dec = nn.ModuleList()
for i in range(len(chs)-1, 0, -1):
self.dec.append(Up(chs[i], chs[i-1], cond_dim, cfg.text_proj_dim, n=2, dp=cfg.drop_path))
self.out_conv = nn.Conv2d(chs[0], cfg.latent_dim, 1)
self.n_recurse = cfg.n_recurse
    def forward(self, z, text_emb, t):
        """z: (B, latent_dim, H, W) noisy latents, t: (B,) timesteps,
        text_emb: (B, n_tokens, text_embed_dim). Output has the shape of z."""
        t_emb = self.time_mlp(timestep_embedding(t, self.cfg.t_embed_dim))
        text = self.text_proj(text_emb)
        x = self.in_conv(z)
x = self.in_conv(z)
for _ in range(self.n_recurse):
h = x; skips = []
for down in self.enc:
h, sk = down(h, t_emb, text)
skips.append(sk)
for blk in self.bot:
h = blk(h, t_emb, text)
for up in self.dec:
sk = skips.pop()
h = up(h, sk, t_emb, text)
x = x + h
return self.out_conv(x)
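# ── smoke test ───────────────────────────────────────────────────────────
# Minimal sanity check, not a training entry point. The cfg values below are
# illustrative assumptions (latent_dim=4 matches an SD-style VAE, 77x768 text
# tokens a CLIP ViT-L text encoder); the real config may differ.
if __name__ == "__main__":
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        channels=(64, 128, 256, 384), t_embed_dim=256, text_embed_dim=768,
        text_proj_dim=256, latent_dim=4, drop_path=0.1, n_bottleneck=2,
        n_recurse=2)
    model = LuminaRS(cfg)
    z = torch.randn(2, 4, 32, 32)       # (B, latent_dim, H, W), H/W divisible by 8
    text = torch.randn(2, 77, 768)      # (B, tokens, text_embed_dim)
    t = torch.randint(0, 1000, (2,))
    out = model(z, text, t)
    assert out.shape == z.shape
    print("ok:", tuple(out.shape),
          f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M params")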