asdf98 commited on
Commit
c6e7340
·
verified ·
1 Parent(s): bc114d4

Upload luminars/model.py

Browse files
Files changed (1) hide show
  1. luminars/model.py +265 -0
luminars/model.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LuminaRS -- Lightweight Latent Recursive Diffusion.
3
+ A small UNet+iterative-refinement model (~110M params) for art/illustration generation.
4
+ Uses: pretrained VAE, pretrained CLIP text encoder (both frozen), custom lightweight UNet.
5
+ """
6
+ import math
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from einops import rearrange
11
+
12
+
13
+ # ---------------------------------------------------------------------------
14
+ # Utilities
15
+ # ---------------------------------------------------------------------------
16
+ def timestep_embedding(t, dim, max_period=10000):
17
+ """Create sinusoidal timestep embeddings."""
18
+ half = dim // 2
19
+ freqs = torch.exp(-math.log(max_period) * torch.arange(0, half, dtype=torch.float32, device=t.device) / half)
20
+ args = t[:, None] * freqs[None]
21
+ emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
22
+ if dim % 2:
23
+ emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
24
+ return emb
25
+
26
+
27
+ class RMSNorm(nn.Module):
28
+ def __init__(self, dim, eps=1e-6):
29
+ super().__init__()
30
+ self.eps = eps
31
+ self.g = nn.Parameter(torch.ones(dim))
32
+ def forward(self, x):
33
+ return self.g * x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
34
+
35
+
36
+ # ---------------------------------------------------------------------------
37
+ # Multi-Query Attention (MQA) -- faster than MHA on mobile
38
+ # ---------------------------------------------------------------------------
39
+ class MQAttention(nn.Module):
40
+ def __init__(self, dim, n_heads=8):
41
+ super().__init__()
42
+ assert dim % n_heads == 0
43
+ self.n_heads = n_heads
44
+ self.dh = dim // n_heads
45
+ self.scale = self.dh ** -0.5
46
+ self.q_proj = nn.Linear(dim, dim)
47
+ self.k_proj = nn.Linear(dim, dim)
48
+ self.v_proj = nn.Linear(dim, dim)
49
+ self.out_proj = nn.Linear(dim, dim)
50
+ def forward(self, x, context=None):
51
+ B, L, C = x.shape
52
+ if context is None:
53
+ context = x
54
+ q = self.q_proj(x).view(B, L, self.n_heads, self.dh).transpose(1, 2)
55
+ k = self.k_proj(context).view(B, -1, self.n_heads, self.dh).transpose(1, 2)
56
+ v = self.v_proj(context).view(B, -1, self.n_heads, self.dh).transpose(1, 2)
57
+ attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
58
+ attn = attn.softmax(dim=-1)
59
+ out = torch.matmul(attn, v).transpose(1, 2).reshape(B, L, C)
60
+ return self.out_proj(out)
61
+
62
+
63
+ # ---------------------------------------------------------------------------
64
+ # ConvNeXt-like Block (depthwise + pointwise + GELU)
65
+ # ---------------------------------------------------------------------------
66
+ class ConvNeXtBlock(nn.Module):
67
+ def __init__(self, dim, drop_path=0.0, text_dim=None):
68
+ super().__init__()
69
+ self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
70
+ self.norm = nn.GroupNorm(1, dim)
71
+ self.pwconv1 = nn.Linear(dim, dim * 4)
72
+ self.act = nn.GELU()
73
+ self.pwconv2 = nn.Linear(dim * 4, dim)
74
+ self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1)) if drop_path == 0.0 else None
75
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
76
+
77
+ # Optional cross-attention for text conditioning
78
+ self.text_attn = None
79
+ if text_dim is not None:
80
+ self.text_norm = RMSNorm(dim)
81
+ self.text_attn = MQAttention(dim)
82
+ self.text_proj = nn.Linear(text_dim, dim)
83
+ def forward(self, x, text_emb=None):
84
+ shortcut = x
85
+ x = self.dwconv(x)
86
+ x = self.norm(x)
87
+ # pointwise via 1x1 conv (channel mixer)
88
+ x = x.permute(0, 2, 3, 1) # (B, H, W, C)
89
+ x = self.pwconv1(x)
90
+ x = self.act(x)
91
+ x = self.pwconv2(x)
92
+ x = x.permute(0, 3, 1, 2) # (B, C, H, W)
93
+ if self.gamma is not None:
94
+ x = x * self.gamma
95
+ x = shortcut + self.drop_path(x)
96
+
97
+ if self.text_attn is not None and text_emb is not None:
98
+ B, C, H, W = x.shape
99
+ x_flat = x.view(B, C, H * W).transpose(1, 2) # (B, HW, C)
100
+ x_flat = x_flat + self.text_attn(
101
+ self.text_norm(x_flat),
102
+ self.text_proj(text_emb)
103
+ )
104
+ x = x_flat.transpose(1, 2).view(B, C, H, W)
105
+ return x
106
+
107
+
108
+ class DropPath(nn.Module):
109
+ """Stochastic depth (drop path)."""
110
+ def __init__(self, drop_prob=0.0):
111
+ super().__init__()
112
+ self.drop_prob = drop_prob
113
+ def forward(self, x):
114
+ if self.drop_prob == 0.0 or not self.training:
115
+ return x
116
+ keep_prob = 1 - self.drop_prob
117
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
118
+ return x * keep_prob + x * torch.zeros(shape, device=x.device).bernoulli_(keep_prob)
119
+
120
+
121
+ # ---------------------------------------------------------------------------
122
+ # Down/Up blocks
123
+ # ---------------------------------------------------------------------------
124
+ class DownBlock(nn.Module):
125
+ def __init__(self, in_ch, out_ch, n_blocks=2, text_dim=None, drop_path=0.0):
126
+ super().__init__()
127
+ self.blocks = nn.ModuleList([
128
+ ConvNeXtBlock(in_ch if i == 0 else out_ch, drop_path=drop_path, text_dim=text_dim)
129
+ for i in range(n_blocks)
130
+ ])
131
+ self.down = nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=2, padding=1)
132
+ def forward(self, x, text_emb=None):
133
+ for blk in self.blocks:
134
+ x = blk(x, text_emb)
135
+ x = self.down(x)
136
+ return x
137
+
138
+ class UpBlock(nn.Module):
139
+ def __init__(self, in_ch, out_ch, n_blocks=2, text_dim=None, drop_path=0.0):
140
+ super().__init__()
141
+ self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
142
+ self.blocks = nn.ModuleList([
143
+ ConvNeXtBlock(out_ch, drop_path=drop_path, text_dim=text_dim)
144
+ for _ in range(n_blocks)
145
+ ])
146
+ def forward(self, x, skip, text_emb=None):
147
+ x = self.up(x)
148
+ x = x + skip
149
+ for blk in self.blocks:
150
+ x = blk(x, text_emb)
151
+ return x
152
+
153
+
154
+ # ---------------------------------------------------------------------------
155
+ # Time Embedder
156
+ # ---------------------------------------------------------------------------
157
+ class TimeEmbed(nn.Module):
158
+ def __init__(self, t_dim=256, out_dim=256):
159
+ super().__init__()
160
+ self.mlp = nn.Sequential(
161
+ nn.Linear(t_dim, out_dim),
162
+ nn.SiLU(),
163
+ nn.Linear(out_dim, out_dim),
164
+ )
165
+ def forward(self, t):
166
+ return self.mlp(timestep_embedding(t, self.mlp[0].in_features))
167
+
168
+
169
+ # ---------------------------------------------------------------------------
170
+ # MAIN MODEL: LuminaRS
171
+ # ---------------------------------------------------------------------------
172
+ class LuminaRS(nn.Module):
173
+ """
174
+ Lightweight latent diffusion model with iterative refinement.
175
+
176
+ Architecture (1024x1024 target, 32x32x16 latent):
177
+ - Encoder: 16 -> 32 -> 64 -> 128 -> 256 (channels at each scale)
178
+ - Bottleneck: 256-ch blocks
179
+ - Decoder: 256 -> 128 -> 64 -> 32 -> 16 (with skip)
180
+ - Cross-attention at every block (MQA)
181
+ - Shared weights applied recursively T times per denoising step (like TRM/HRM)
182
+ """
183
+ def __init__(self, cfg):
184
+ super().__init__()
185
+ self.cfg = cfg
186
+ chs = cfg.channels
187
+ self.time_embed = TimeEmbed(cfg.t_embed_dim, cfg.channels[0] * 4)
188
+
189
+ # Project time into each scale
190
+ self.time_projs = nn.ModuleList([nn.Linear(cfg.channels[0] * 4, c) for c in chs])
191
+
192
+ # Text conditioning (use frozen CLIP text encoder externally)
193
+ self.text_proj = nn.Linear(cfg.text_embed_dim, cfg.channels[0])
194
+
195
+ # --- Encoder ---
196
+ self.in_conv = nn.Conv2d(cfg.latent_dim, chs[0], kernel_size=3, padding=1)
197
+ self.enc_blocks = nn.ModuleList()
198
+ for i in range(len(chs) - 1):
199
+ self.enc_blocks.append(DownBlock(chs[i], chs[i+1], n_blocks=2,
200
+ text_dim=cfg.channels[0], drop_path=cfg.drop_path))
201
+
202
+ # --- Bottleneck ---
203
+ self.bottleneck = nn.ModuleList([
204
+ ConvNeXtBlock(chs[-1], drop_path=cfg.drop_path, text_dim=cfg.channels[0])
205
+ for _ in range(cfg.n_bottleneck)
206
+ ])
207
+
208
+ # --- Decoder ---
209
+ self.dec_blocks = nn.ModuleList()
210
+ for i in range(len(chs) - 1, 0, -1):
211
+ self.dec_blocks.append(UpBlock(chs[i], chs[i-1], n_blocks=2,
212
+ text_dim=cfg.channels[0], drop_path=cfg.drop_path))
213
+
214
+ self.out_conv = nn.Conv2d(chs[0], cfg.latent_dim, kernel_size=1)
215
+
216
+ # --- Iterative Refinement (recursive depth like TRM) ---
217
+ self.n_recurse = cfg.n_recurse # T: number of shared-weight passes
218
+
219
+ def forward(self, z, text_emb, t):
220
+ """
221
+ z: (B, latent_dim, H, W) -- noisy latent
222
+ text_emb: (B, L, text_embed_dim) -- CLIP text embeddings
223
+ t: (B,) -- timestep (0=noise, 1=clean for flow matching)
224
+ Returns: (B, latent_dim, H, W) -- predicted velocity / noise
225
+ """
226
+ B = z.shape[0]
227
+
228
+ # Time embedding
229
+ t_emb = self.time_embed(t) # (B, C0*4)
230
+
231
+ # Text projection
232
+ text_cond = self.text_proj(text_emb) # (B, L, C0)
233
+
234
+ # --- RECURSIVE REFINEMENT (TRM-style shared-weight loops) ---
235
+ x = self.in_conv(z)
236
+
237
+ for _ in range(self.n_recurse):
238
+ # Encoder
239
+ skips = []
240
+ h = x
241
+ for i, down in enumerate(self.enc_blocks):
242
+ t_scale = self.time_projs[i](t_emb)[:, :, None, None]
243
+ h = h + t_scale
244
+ h = down(h, text_cond)
245
+ skips.append(h)
246
+
247
+ # Bottleneck
248
+ for blk in self.bottleneck:
249
+ h = blk(h, text_cond)
250
+
251
+ # Decoder
252
+ for i, up in enumerate(self.dec_blocks):
253
+ t_scale = self.time_projs[len(self.enc_blocks) - i](t_emb)[:, :, None, None]
254
+ h = h + t_scale
255
+ skip = skips[len(skips) - 1 - i]
256
+ h = up(h, skip, text_cond)
257
+
258
+ x = x + h # residual update (like TRM iterative refinement)
259
+
260
+ return self.out_conv(x)
261
+
262
+ def count_params(self):
263
+ total = sum(p.numel() for p in self.parameters())
264
+ train = sum(p.numel() for p in self.parameters() if p.requires_grad)
265
+ return total, train