asdf98 committed on
Commit
6e28f73
·
verified ·
1 Parent(s): b0bf58d

Add microforge/vae.py

Browse files
Files changed (1) hide show
  1. microforge/vae.py +337 -0
microforge/vae.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MicroForge VAE: Deep Compression Autoencoder
3
+ =============================================
4
+
5
+ Inspired by DC-AE (arxiv:2410.10733) and TinyVAE (DreamLite).
6
+ Key innovations for mobile:
7
+ - 32x spatial compression (512px -> 16x16 latent grid)
8
+ - Residual autoencoding with space-to-channel shortcuts
9
+ - Lightweight decoder (<3M params) for mobile deployment
10
+ - KL-regularized continuous latent space
11
+
12
+ Architecture:
13
+ Encoder: [3,H,W] -> conv_in -> DownBlock x4 (stride 2 each, 16x) -> extra stride-2 downsample -> [C_latent, H/32, W/32]
14
+ Each DownBlock: ResBlock + optional Attention (only at lowest res) + Downsample
15
+ Residual shortcut: space_to_channel rearrange on skip connections
16
+ Decoder: Mirror of encoder with PixelShuffle upsampling
17
+
18
+ For 512px input:
19
+ Latent = [32, 16, 16] = 8192 values (vs SD-VAE's 16384)
20
+ Spatial tokens for backbone = 256 (16x16) = 16x fewer than SD-VAE's 4096
21
+ """
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.nn.functional as F
26
+ from typing import Optional, Tuple
27
+
28
+
29
class ResBlock(nn.Module):
    """Pre-activation residual block with two 3x3 convolutions.

    The main path is ``GroupNorm -> SiLU -> Conv3x3`` applied twice; the result
    is added to a skip path. The skip path is a 1x1 convolution when the
    channel count changes and an identity mapping otherwise. Spatial size is
    preserved (3x3 kernels with padding 1).

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        groups: Group count used by both GroupNorm layers.
    """

    def __init__(self, in_ch: int, out_ch: int, groups: int = 8):
        super().__init__()
        # Registration order is kept fixed so state_dict keys and seeded
        # initialisation stay stable.
        self.norm1 = nn.GroupNorm(groups, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(groups, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        if in_ch == out_ch:
            self.skip = nn.Identity()
        else:
            # 1x1 projection so the skip path matches the new channel count.
            self.skip = nn.Conv2d(in_ch, out_ch, kernel_size=1)
        self.act = nn.SiLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``main_path(x) + skip(x)``."""
        # The in-place SiLU only overwrites the freshly created norm outputs,
        # never ``x`` itself, so ``x`` is still intact for the skip path.
        out = self.conv1(self.act(self.norm1(x)))
        out = self.conv2(self.act(self.norm2(out)))
        return out + self.skip(x)
46
+
47
+
48
class ExpandedSeparableConv(nn.Module):
    """Residual expanded separable convolution (described as UIB-style, from SnapGen).

    Pipeline: ``GroupNorm(8) -> depthwise 3x3 -> pointwise expand -> SiLU ->
    pointwise project``, added back onto the input. Channel count and spatial
    size are unchanged.

    Args:
        channels: Input/output channel count (normalised with 8 groups).
        expansion: Width multiplier for the hidden pointwise layer.
    """

    def __init__(self, channels: int, expansion: int = 2):
        super().__init__()
        hidden = channels * expansion
        # groups=channels makes this a depthwise convolution.
        self.dw = nn.Conv2d(channels, channels, kernel_size=3, padding=1, groups=channels)
        self.pw_expand = nn.Conv2d(channels, hidden, kernel_size=1)
        self.act = nn.SiLU(inplace=True)
        self.pw_project = nn.Conv2d(hidden, channels, kernel_size=1)
        self.norm = nn.GroupNorm(8, channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the separable-conv branch computed from it."""
        branch = self.pw_expand(self.dw(self.norm(x)))
        branch = self.pw_project(self.act(branch))
        return branch + x
69
+
70
+
71
class SpaceToChannel(nn.Module):
    """Non-parametric space-to-channel rearrangement (DC-AE style shortcut).

    Folds each ``factor x factor`` spatial patch into the channel dimension:
    ``[B, C, H, W] -> [B, C * factor**2, H / factor, W / factor]``.

    Output channel ``c * factor**2 + i * factor + j`` holds the pixel at
    offset ``(i, j)`` inside each patch of input channel ``c``. This is
    exactly ``torch.nn.functional.pixel_unshuffle``, which is used here in
    place of a hand-written reshape/permute/reshape.

    Args:
        factor: Spatial downscale factor; ``H`` and ``W`` must be divisible by it.

    Raises:
        RuntimeError: From ``pixel_unshuffle`` if ``H`` or ``W`` is not
            divisible by ``factor``.
    """

    def __init__(self, factor: int = 2):
        super().__init__()
        self.factor = factor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Rearrange ``x`` from ``(..., C, H, W)`` to ``(..., C*f*f, H/f, W/f)``."""
        return F.pixel_unshuffle(x, self.factor)
89
+
90
+
91
class ChannelToSpace(nn.Module):
    """Non-parametric channel-to-space rearrangement (inverse of space-to-channel).

    Unfolds groups of ``factor**2`` channels into ``factor x factor`` spatial
    patches: ``[B, C, H, W] -> [B, C / factor**2, H * factor, W * factor]``.

    Input channel ``c * factor**2 + i * factor + j`` is written to offset
    ``(i, j)`` of each output patch of channel ``c``. This is exactly
    ``torch.nn.functional.pixel_shuffle``, which is used here in place of a
    hand-written reshape/permute/reshape.

    Args:
        factor: Spatial upscale factor; ``C`` must be divisible by ``factor**2``.

    Raises:
        RuntimeError: From ``pixel_shuffle`` if the channel count is not
            divisible by ``factor**2``.
    """

    def __init__(self, factor: int = 2):
        super().__init__()
        self.factor = factor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Rearrange ``x`` from ``(..., C, H, W)`` to ``(..., C/(f*f), H*f, W*f)``."""
        return F.pixel_shuffle(x, self.factor)
105
+
106
+
107
class EncoderBlock(nn.Module):
    """Encoder stage: ResBlocks -> separable conv -> optional attention -> 2x downsample.

    A parallel residual shortcut rearranges the block input with a
    space-to-channel op (factor 2), projects it to ``out_ch`` with a 1x1
    convolution, and adds it to the downsampled main path.

    Input ``[B, in_ch, H, W]`` maps to output ``[B, out_ch, H/2, W/2]``.
    ``H`` and ``W`` must be even so that the shortcut and the stride-2
    convolution produce the same spatial size.

    Args:
        in_ch: Input channel count.
        out_ch: Output channel count (GroupNorm with 8 groups and 4-head
            attention are applied on this width).
        num_res: Number of ResBlocks; at least one is always created.
        use_attn: If True, apply residual self-attention over spatial
            positions. Attention runs before the downsample, i.e. at the
            block's input resolution.
    """

    def __init__(self, in_ch: int, out_ch: int, num_res: int = 2, use_attn: bool = False):
        super().__init__()
        # The first ResBlock changes the width; the remaining ones keep it.
        self.res_blocks = nn.ModuleList()
        self.res_blocks.append(ResBlock(in_ch, out_ch))
        for _ in range(num_res - 1):
            self.res_blocks.append(ResBlock(out_ch, out_ch))

        self.sep_conv = ExpandedSeparableConv(out_ch)

        # Optional self-attention; the owning model decides which stage gets it.
        self.use_attn = use_attn
        if use_attn:
            self.attn_norm = nn.GroupNorm(8, out_ch)
            self.attn = nn.MultiheadAttention(out_ch, num_heads=4, batch_first=True)

        self.downsample = nn.Conv2d(out_ch, out_ch, 3, stride=2, padding=1)

        # Residual shortcut: space-to-channel quadruples the channels, then a
        # 1x1 conv projects them to the block's output width.
        self.space_to_channel = SpaceToChannel(factor=2)
        self.shortcut_proj = nn.Conv2d(in_ch * 4, out_ch, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the downsampled features with the space-to-channel shortcut added.

        Args:
            x: Tensor of shape ``[B, in_ch, H, W]``.

        Returns:
            A single tensor of shape ``[B, out_ch, H/2, W/2]``.
        """
        # The shortcut is computed from the untouched block input.
        shortcut = self.shortcut_proj(self.space_to_channel(x))

        for res in self.res_blocks:
            x = res(x)
        x = self.sep_conv(x)

        if self.use_attn:
            B, C, H, W = x.shape
            # Flatten the spatial grid into a token sequence: [B, H*W, C].
            tokens = self.attn_norm(x).reshape(B, C, -1).permute(0, 2, 1)
            attended, _ = self.attn(tokens, tokens, tokens)
            x = x + attended.permute(0, 2, 1).reshape(B, C, H, W)

        x = self.downsample(x)
        return x + shortcut  # residual autoencoding
147
+
148
+
149
class DecoderBlock(nn.Module):
    """Decoder stage: 2x upsample -> ResBlocks -> separable conv -> optional attention.

    The main path upsamples with a 3x3 convolution to ``4 * in_ch`` channels
    followed by ``PixelShuffle(2)``. A parallel residual shortcut rearranges
    the block input with a channel-to-space op (factor 2, giving
    ``in_ch // 4`` channels), projects it to ``out_ch`` with a 1x1 convolution
    when the widths differ, and is added to the main path at the end.

    Input ``[B, in_ch, H, W]`` maps to output ``[B, out_ch, 2H, 2W]``.

    Args:
        in_ch: Input channel count.
        out_ch: Output channel count.
        num_res: Number of ResBlocks; at least one is always created.
        use_attn: If True, apply residual self-attention over spatial
            positions at the upsampled resolution.
    """

    def __init__(self, in_ch: int, out_ch: int, num_res: int = 2, use_attn: bool = False):
        super().__init__()
        # Learned upsampling: widen 4x, then fold channels into a 2x larger grid.
        self.upsample = nn.Sequential(
            nn.Conv2d(in_ch, in_ch * 4, kernel_size=3, padding=1),
            nn.PixelShuffle(2),
        )

        # Shortcut path: channel-to-space leaves in_ch // 4 channels.
        self.channel_to_space = ChannelToSpace(factor=2)
        shortcut_ch = in_ch // 4
        if shortcut_ch == out_ch:
            self.shortcut_proj = nn.Identity()
        else:
            self.shortcut_proj = nn.Conv2d(shortcut_ch, out_ch, kernel_size=1)

        # The first ResBlock changes the width; the remaining ones keep it.
        stage_inputs = [in_ch] + [out_ch] * (num_res - 1)
        self.res_blocks = nn.ModuleList([ResBlock(width, out_ch) for width in stage_inputs])

        self.sep_conv = ExpandedSeparableConv(out_ch)

        self.use_attn = use_attn
        if use_attn:
            self.attn_norm = nn.GroupNorm(8, out_ch)
            self.attn = nn.MultiheadAttention(out_ch, num_heads=4, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the upsampled features with the channel-to-space shortcut added."""
        # The shortcut is computed from the untouched block input.
        residual = self.shortcut_proj(self.channel_to_space(x))

        feat = self.upsample(x)
        for block in self.res_blocks:
            feat = block(feat)
        feat = self.sep_conv(feat)

        if self.use_attn:
            batch, chans, height, width = feat.shape
            # Flatten the spatial grid into a token sequence: [B, H*W, C].
            tokens = self.attn_norm(feat).reshape(batch, chans, -1).permute(0, 2, 1)
            attended, _ = self.attn(tokens, tokens, tokens)
            feat = feat + attended.permute(0, 2, 1).reshape(batch, chans, height, width)

        return feat + residual
192
+
193
+
194
class MicroForgeVAE(nn.Module):
    """MicroForge VAE: a deep-compression, KL-regularised autoencoder.

    The encoder applies ``conv_in``, one ``EncoderBlock`` per entry of the
    preset's ``enc_channels`` (each halves the resolution), and one extra
    stride-2 stage. With the four-entry presets below that is 32x spatial
    compression in total, so a 512px input yields a 16x16 latent grid. The
    decoder mirrors this with PixelShuffle-based upsampling.

    Input height and width should be multiples of 32; otherwise the
    reconstruction does not come back at the input size.

    Presets (parameter counts are the original author's estimates and are
    not verified here):
        - ``'tiny'``:  mobile decode, ~2.5M-parameter decoder
        - ``'small'``: training, ~12M parameters total
        - ``'base'``:  full quality, ~25M parameters total

    Args:
        in_channels: Channel count of the input (and reconstructed) image.
        config: Key into ``CONFIGS`` selecting the preset.
        latent_channels: Overrides the preset's latent channel count when
            given (a falsy value falls back to the preset).

    Raises:
        KeyError: If ``config`` is not a key of ``CONFIGS``.
    """

    # Per-preset encoder widths, latent width and ResBlocks per stage.
    CONFIGS = {
        'tiny': {
            'enc_channels': [32, 64, 128, 256],
            'latent_channels': 16,
            'num_res_blocks': 1,
        },
        'small': {
            'enc_channels': [64, 128, 256, 512],
            'latent_channels': 32,
            'num_res_blocks': 2,
        },
        'base': {
            'enc_channels': [128, 256, 512, 512],
            'latent_channels': 32,
            'num_res_blocks': 2,
        },
    }

    def __init__(
        self,
        in_channels: int = 3,
        config: str = 'small',
        latent_channels: Optional[int] = None,
    ):
        super().__init__()
        preset = self.CONFIGS[config]
        widths = preset['enc_channels']
        self.latent_channels = latent_channels or preset['latent_channels']
        depth = preset['num_res_blocks']
        bottleneck = widths[-1]
        last_stage = len(widths) - 1

        # ---- Encoder -----------------------------------------------------
        self.conv_in = nn.Conv2d(in_channels, widths[0], 3, padding=1)

        # Stage i consumes the previous stage's width; the first stage
        # consumes conv_in's width. Attention only on the last stage.
        enc_inputs = [widths[0], *widths[:-1]]
        self.encoder_blocks = nn.ModuleList([
            EncoderBlock(c_in, c_out, depth, idx == last_stage)
            for idx, (c_in, c_out) in enumerate(zip(enc_inputs, widths))
        ])

        # One additional 2x stage after the encoder blocks.
        self.extra_down = nn.Sequential(
            ResBlock(bottleneck, bottleneck),
            nn.Conv2d(bottleneck, bottleneck, 3, stride=2, padding=1),
        )

        # Heads producing the posterior mean and log-variance.
        self.to_mu = nn.Conv2d(bottleneck, self.latent_channels, 1)
        self.to_logvar = nn.Conv2d(bottleneck, self.latent_channels, 1)

        # ---- Decoder -----------------------------------------------------
        self.from_latent = nn.Conv2d(self.latent_channels, bottleneck, 1)

        # Mirror of extra_down: ResBlock, then conv + PixelShuffle 2x upsample.
        self.extra_up = nn.Sequential(
            ResBlock(bottleneck, bottleneck),
            nn.Conv2d(bottleneck, bottleneck * 4, 3, padding=1),
            nn.PixelShuffle(2),
        )

        # Encoder widths in reverse; attention only on the first decoder stage.
        dec_widths = list(reversed(widths))
        dec_inputs = [dec_widths[0], *dec_widths[:-1]]
        self.decoder_blocks = nn.ModuleList([
            DecoderBlock(c_in, c_out, depth, idx == 0)
            for idx, (c_in, c_out) in enumerate(zip(dec_inputs, dec_widths))
        ])

        self.conv_out = nn.Sequential(
            nn.GroupNorm(8, dec_widths[-1]),
            nn.SiLU(),
            nn.Conv2d(dec_widths[-1], in_channels, 3, padding=1),
        )

        self._init_weights()

    def _init_weights(self):
        """Re-initialise every Conv2d: Kaiming-normal weights (fan_out), zero biases."""
        for module in self.modules():
            if not isinstance(module, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Map an image batch to ``(mu, logvar)`` of the latent posterior.

        ``logvar`` is clamped to ``[-30, 20]`` for numerical stability.
        """
        feat = self.conv_in(x)
        for stage in self.encoder_blocks:
            feat = stage(feat)
        feat = self.extra_down(feat)
        mean = self.to_mu(feat)
        log_variance = self.to_logvar(feat).clamp(-30.0, 20.0)
        return mean, log_variance

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Sample ``mu + eps * std`` in training mode; return ``mu`` unchanged in eval mode."""
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std)
        return mu + noise * std

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Map a latent tensor back to image space."""
        feat = self.extra_up(self.from_latent(z))
        for stage in self.decoder_blocks:
            feat = stage(feat)
        return self.conv_out(feat)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode, sample and decode; return ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x)
        latent = self.reparameterize(mu, logvar)
        return self.decode(latent), mu, logvar

    def get_latent(self, x: torch.Tensor) -> torch.Tensor:
        """Return the deterministic latent (posterior mean only), e.g. for inference."""
        return self.encode(x)[0]

    @staticmethod
    def kl_loss(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """KL divergence to a standard normal, averaged over all elements."""
        per_element = 1 + logvar - mu.pow(2) - logvar.exp()
        return -0.5 * torch.mean(per_element)

    @staticmethod
    def recon_loss(x_recon: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Reconstruction loss: mean L1 distance (no perceptual term is implemented here)."""
        return F.l1_loss(x_recon, x)