asdf98 committed
Commit 2c6f96a · verified · 1 Parent(s): ef19514

Add microforge/pipeline.py

Files changed (1)
  1. microforge/pipeline.py +335 -0
microforge/pipeline.py ADDED
@@ -0,0 +1,335 @@
"""
MicroForge Pipeline: End-to-End Generation and Editing
=======================================================

Unified pipeline for:
- Text-to-image generation
- Image-to-image editing (spatial concat, DreamLite-style)
- Inpainting (masked spatial concat)
- Super-resolution (low-res spatial concat)

The key insight (from DreamLite): spatial concatenation preserves generation
priors when adding editing capabilities. The same backbone handles all tasks
by varying what goes into the "context" panel:
- Generation: context = blank (zeros)
- Editing: context = source image latent
- Inpainting: context = masked source image latent
- Super-res: context = upsampled low-res latent
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, List, Tuple, Union

from .vae import MicroForgeVAE
from .backbone import MicroForgeBackbone
from .planner import RecurrentLatentPlanner
from .training import FlowMatchingScheduler


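# The module docstring above describes a single backbone driving every task by
# swapping what goes into the "context" panel that is concatenated width-wise with
# the target latent. The sketch below only illustrates that convention on toy
# tensors; the channel count and spatial size are made-up values, not the
# pipeline's real latent shapes.
def _context_panel_demo() -> Dict[str, torch.Tensor]:
    c, h, w = 16, 8, 8
    target = torch.randn(1, c, h, w)                # noised target latent
    source = torch.randn(1, c, h, w)                # clean source latent
    mask = (torch.rand(1, 1, h, w) > 0.5).float()   # 1 = region to inpaint
    low_res = F.interpolate(
        torch.randn(1, c, h // 2, w // 2), size=(h, w), mode='nearest'
    )

    contexts = {
        'generation': torch.zeros_like(source),  # blank panel
        'editing': source,                       # source image latent
        'inpainting': source * (1 - mask),       # masked source latent
        'super_res': low_res,                    # upsampled low-res latent
    }
    # Width-wise concatenation: [target | context] doubles the last dimension.
    return {name: torch.cat([target, ctx], dim=-1) for name, ctx in contexts.items()}

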
class SimpleTextEncoder(nn.Module):
    """
    Lightweight text encoder for prototyping.
    In production, replace with CLIP-L or a small LLM (Gemma-2B).

    This uses a small transformer on learned token embeddings.
    For the prototype, we support:
    1. Random projection (for testing)
    2. Simple learned embedding (for small-scale training)
    """

    def __init__(
        self,
        vocab_size: int = 8192,
        max_seq_len: int = 77,
        embed_dim: int = 768,
        num_heads: int = 8,
        num_layers: int = 4,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len

        self.token_embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Parameter(torch.randn(1, max_seq_len, embed_dim) * 0.02)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.final_norm = nn.LayerNorm(embed_dim)
        self.pool_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, token_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            token_ids: [B, L] integer token IDs

        Returns:
            text_emb: [B, L, D] token-level embeddings
            text_pooled: [B, D] pooled embedding
        """
        x = self.token_embed(token_ids) + self.pos_embed[:, :token_ids.shape[1], :]
        x = self.encoder(x)
        x = self.final_norm(x)

        # Pool: mean of all tokens
        pooled = x.mean(dim=1)
        pooled = self.pool_proj(pooled)

        return x, pooled

    def encode_text_simple(self, batch_size: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generate random text embeddings for testing.
        """
        text_emb = torch.randn(batch_size, self.max_seq_len, self.embed_dim, device=device)
        text_pooled = torch.randn(batch_size, self.embed_dim, device=device)
        return text_emb, text_pooled


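# SimpleTextEncoder consumes integer token IDs, but this module does not define a
# tokenizer. The helper below is a minimal, hypothetical stand-in for prototyping:
# it hashes whitespace-separated words into the encoder's vocabulary and pads or
# truncates to a fixed length. Note that Python's hash() is not stable across runs,
# so any real setup would swap in a proper tokenizer.
def hash_tokenize(texts: List[str], vocab_size: int = 8192, max_seq_len: int = 77) -> torch.Tensor:
    """Map a batch of strings to [B, max_seq_len] token IDs via word hashing."""
    batch = []
    for text in texts:
        # Reserve ID 0 for padding; hash each word into [1, vocab_size - 1].
        ids = [1 + (hash(word) % (vocab_size - 1)) for word in text.lower().split()]
        ids = ids[:max_seq_len] + [0] * max(0, max_seq_len - len(ids))
        batch.append(ids)
    return torch.tensor(batch, dtype=torch.long)

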
class MicroForgePipeline:
    """
    End-to-end MicroForge pipeline.

    Supports:
    - text2img: Generate image from text
    - img2img: Edit image with text guidance
    - inpaint: Fill masked region with text guidance
    - super_res: Upscale image with text guidance

    All tasks use the same backbone via spatial concatenation.
    """

    def __init__(
        self,
        vae: MicroForgeVAE,
        backbone: MicroForgeBackbone,
        text_encoder: SimpleTextEncoder,
        planner: Optional[RecurrentLatentPlanner] = None,
        device: str = 'cpu',
    ):
        self.vae = vae.eval()
        self.backbone = backbone.eval()
        self.text_encoder = text_encoder.eval()
        self.planner = planner.eval() if planner is not None else None
        self.device = torch.device(device)
        self.scheduler = FlowMatchingScheduler()

        # Move to device
        self.vae.to(self.device)
        self.backbone.to(self.device)
        self.text_encoder.to(self.device)
        if self.planner is not None:
            self.planner.to(self.device)

    @torch.no_grad()
    def text2img(
        self,
        text_tokens: torch.Tensor,
        height: int = 256,
        width: int = 256,
        num_steps: int = 20,
        cfg_scale: float = 7.5,
        seed: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Generate an image from text.

        Args:
            text_tokens: [B, L] token IDs
            height, width: output image size (must be divisible by 32)
            num_steps: denoising steps
            cfg_scale: classifier-free guidance scale
            seed: random seed

        Returns:
            images: [B, 3, H, W] generated images in [-1, 1]
        """
        if seed is not None:
            torch.manual_seed(seed)

        B = text_tokens.shape[0]

        # Encode text
        text_emb, text_pooled = self.text_encoder(text_tokens.to(self.device))

        # Latent dimensions (32x spatial compression)
        latent_h = height // 32
        latent_w = width // 32
        latent_c = self.vae.latent_channels

        # Sample noise
        noise = torch.randn(B, latent_c, latent_h, latent_w, device=self.device)

        # Denoise
        z_0 = self.scheduler.sample(
            self.backbone, noise, text_emb, text_pooled,
            num_steps=num_steps, cfg_scale=cfg_scale,
            planner=self.planner,
        )

        # Decode
        images = self.vae.decode(z_0)
        return images.clamp(-1, 1)

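    # text2img delegates the denoising loop to FlowMatchingScheduler.sample, which
    # lives in .training. As a rough, assumed illustration of what one such step
    # typically does under the (1 - t) * data + t * noise convention used by img2img
    # below: combine the conditional and unconditional velocity predictions with the
    # standard classifier-free-guidance formula, then take an Euler step toward
    # t_next. This helper is a self-contained sketch, not the scheduler's actual code.
    @staticmethod
    def _sketch_cfg_euler_step(
        z_t: torch.Tensor,
        v_cond: torch.Tensor,
        v_uncond: torch.Tensor,
        t: float,
        t_next: float,
        cfg_scale: float,
    ) -> torch.Tensor:
        # Classifier-free guidance: push the velocity away from the unconditional one.
        v = v_uncond + cfg_scale * (v_cond - v_uncond)
        # Euler update along the flow from t to t_next (t decreases toward 0).
        return z_t + (t_next - t) * v
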
    @torch.no_grad()
    def img2img(
        self,
        source_image: torch.Tensor,
        text_tokens: torch.Tensor,
        strength: float = 0.7,
        num_steps: int = 20,
        cfg_scale: float = 7.5,
    ) -> torch.Tensor:
        """
        Edit an image with text guidance using spatial concatenation.
        The source image latent is concatenated width-wise with the target latent.

        Args:
            source_image: [B, 3, H, W] source image
            text_tokens: [B, L] edit instruction tokens
            strength: how much to change (0 = no change, 1 = full regeneration)
            num_steps: denoising steps
            cfg_scale: guidance scale (accepted for API symmetry; classifier-free
                guidance is not applied in this manual loop)

        Returns:
            edited_images: [B, 3, H, W]
        """
        B = source_image.shape[0]

        # Encode text and source
        text_emb, text_pooled = self.text_encoder(text_tokens.to(self.device))
        source_latent = self.vae.get_latent(source_image.to(self.device))

        # Create noised target: interpolate between source and noise at t = strength
        noise = torch.randn_like(source_latent)
        t_start = torch.tensor([strength], device=self.device)
        z_t = (1 - t_start) * source_latent + t_start * noise

        # Spatial concatenation: [target | source] along width doubles the latent
        # width; the backbone processes both panels together.
        timesteps = torch.linspace(strength, 0, num_steps + 1, device=self.device)

        for i in range(num_steps):
            t = timesteps[i]
            t_next = timesteps[i + 1]
            t_batch = torch.full((B,), t.item(), device=self.device)

            # Concat: [target_noised | source_clean]
            z_concat = torch.cat([z_t, source_latent], dim=-1)  # Width concat

            v_pred = self.backbone(z_concat, t_batch, text_emb, text_pooled)

            # Only take the target half of the prediction
            v_target = v_pred[..., :z_t.shape[-1]]

            z_t = self.scheduler.euler_step(z_t, v_target, t.item(), t_next.item())

        images = self.vae.decode(z_t)
        return images.clamp(-1, 1)

    @torch.no_grad()
    def inpaint(
        self,
        image: torch.Tensor,
        mask: torch.Tensor,
        text_tokens: torch.Tensor,
        num_steps: int = 20,
        cfg_scale: float = 7.5,
    ) -> torch.Tensor:
        """
        Inpaint the masked region.

        Args:
            image: [B, 3, H, W] source image
            mask: [B, 1, H, W] binary mask (1 = inpaint region)
            text_tokens: [B, L] description of what to fill
            num_steps: denoising steps
            cfg_scale: guidance scale (accepted for API symmetry; classifier-free
                guidance is not applied in this manual loop)

        Returns:
            inpainted: [B, 3, H, W]
        """
        B = image.shape[0]

        text_emb, text_pooled = self.text_encoder(text_tokens.to(self.device))
        source_latent = self.vae.get_latent(image.to(self.device))

        # Downsample mask to latent size
        latent_mask = F.interpolate(mask.to(self.device).float(), size=source_latent.shape[2:], mode='nearest')

        # Masked source: zero out the inpaint region
        masked_source = source_latent * (1 - latent_mask)

        # Generate in the masked region, starting from pure noise
        noise = torch.randn_like(source_latent)
        z_t = noise

        timesteps = torch.linspace(1, 0, num_steps + 1, device=self.device)

        for i in range(num_steps):
            t = timesteps[i]
            t_next = timesteps[i + 1]
            t_batch = torch.full((B,), t.item(), device=self.device)

            # Concat masked source as context
            z_concat = torch.cat([z_t, masked_source], dim=-1)
            v_pred = self.backbone(z_concat, t_batch, text_emb, text_pooled)
            v_target = v_pred[..., :z_t.shape[-1]]

            z_t = self.scheduler.euler_step(z_t, v_target, t.item(), t_next.item())

        # Replace unmasked region with source
        z_t = z_t * latent_mask + source_latent * (1 - latent_mask)

        images = self.vae.decode(z_t)
        return images.clamp(-1, 1)

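    # The class docstring lists a super_res task, but this commit does not include an
    # implementation. The method below is a minimal sketch of that path following the
    # same spatial-concatenation recipe as img2img/inpaint above (context = upsampled
    # low-res latent); treat it as an assumed illustration, not a tuned implementation.
    @torch.no_grad()
    def super_res(
        self,
        low_res_image: torch.Tensor,
        text_tokens: torch.Tensor,
        scale: int = 2,
        num_steps: int = 20,
    ) -> torch.Tensor:
        """
        Sketch of super-resolution via spatial concatenation: upsample the low-res
        image, encode it as the context panel, and generate the target latent from
        noise conditioned on that context.
        """
        B = low_res_image.shape[0]

        text_emb, text_pooled = self.text_encoder(text_tokens.to(self.device))

        # Upsample the low-res image to the target size, then encode it as context
        upsampled = F.interpolate(
            low_res_image.to(self.device), scale_factor=scale, mode='bilinear', align_corners=False
        )
        context_latent = self.vae.get_latent(upsampled)

        # Generate the high-res latent from pure noise, conditioned on the context panel
        z_t = torch.randn_like(context_latent)
        timesteps = torch.linspace(1, 0, num_steps + 1, device=self.device)

        for i in range(num_steps):
            t, t_next = timesteps[i], timesteps[i + 1]
            t_batch = torch.full((B,), t.item(), device=self.device)

            z_concat = torch.cat([z_t, context_latent], dim=-1)  # [target | context]
            v_pred = self.backbone(z_concat, t_batch, text_emb, text_pooled)
            v_target = v_pred[..., :z_t.shape[-1]]

            z_t = self.scheduler.euler_step(z_t, v_target, t.item(), t_next.item())

        images = self.vae.decode(z_t)
        return images.clamp(-1, 1)
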
    def get_memory_estimate(self, height: int = 512, width: int = 512) -> Dict[str, float]:
        """
        Estimate memory usage in MB for a given resolution (assumes fp32 weights,
        i.e. 4 bytes per parameter).
        """
        # Model parameters
        vae_params = sum(p.numel() for p in self.vae.parameters()) * 4 / 1e6
        backbone_params = sum(p.numel() for p in self.backbone.parameters()) * 4 / 1e6
        text_params = sum(p.numel() for p in self.text_encoder.parameters()) * 4 / 1e6
        planner_params = 0
        if self.planner is not None:
            planner_params = sum(p.numel() for p in self.planner.parameters()) * 4 / 1e6

        # Activation memory (rough estimate: a single latent tensor)
        latent_h = height // 32
        latent_w = width // 32
        latent_size = latent_h * latent_w * self.vae.latent_channels * 4 / 1e6  # MB

        return {
            'vae_params_mb': vae_params,
            'backbone_params_mb': backbone_params,
            'text_encoder_params_mb': text_params,
            'planner_params_mb': planner_params,
            'total_params_mb': vae_params + backbone_params + text_params + planner_params,
            'latent_size_mb': latent_size,
            'estimated_inference_mb': (vae_params + backbone_params + text_params + planner_params) * 1.3,  # +30% overhead
        }

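    # Worked example of the arithmetic above (illustrative numbers, not real model sizes):
    # a hypothetical 25M-parameter backbone at fp32 occupies 25e6 * 4 / 1e6 = 100 MB, and
    # a 512x512 image maps to a (512 / 32) x (512 / 32) = 16 x 16 latent grid, so with 16
    # latent channels the latent tensor itself is 16 * 16 * 16 * 4 / 1e6 ≈ 0.016 MB.
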
    def count_parameters(self) -> Dict[str, int]:
        """Count parameters per module."""
        counts = {
            'vae': sum(p.numel() for p in self.vae.parameters()),
            'backbone': sum(p.numel() for p in self.backbone.parameters()),
            'text_encoder': sum(p.numel() for p in self.text_encoder.parameters()),
            'planner': sum(p.numel() for p in self.planner.parameters()) if self.planner is not None else 0,
        }
        counts['total'] = sum(counts.values())
        return counts
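

if __name__ == "__main__":
    # Minimal smoke test (illustrative only): run with `python -m microforge.pipeline`
    # so the relative imports above resolve. It exercises SimpleTextEncoder, the one
    # component fully defined in this file; the VAE, backbone, planner, and scheduler
    # live in sibling modules, so the full pipeline is not constructed here.
    encoder = SimpleTextEncoder().eval()
    token_ids = torch.randint(0, 8192, (2, 77))  # stand-in for real tokenizer output
    with torch.no_grad():
        text_emb, text_pooled = encoder(token_ids)
    print("token embeddings:", tuple(text_emb.shape))     # expected: (2, 77, 768)
    print("pooled embedding:", tuple(text_pooled.shape))  # expected: (2, 768)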