asdf98 committed
Commit ef19514 · verified · 1 Parent(s): e80ddea

Add microforge/training.py

Files changed (1)
  1. microforge/training.py +385 -0
microforge/training.py ADDED
@@ -0,0 +1,385 @@
+ """
+ MicroForge Training: Rectified Flow + Consistency Distillation
+ ===============================================================
+ 
+ Training objectives:
+ 1. Rectified Flow (primary): learn the velocity v(z_t, t) = epsilon - z_0
+ 2. Consistency Distillation (secondary): for few-step inference
+ 3. VAE losses: L1 recon + KL + perceptual (LPIPS placeholder)
+ 
+ Rectified Flow formulation:
+     z_t = (1 - t) * z_0 + t * epsilon    (linear interpolation)
+     v_target = epsilon - z_0             (velocity)
+     L_flow = ||v_theta(z_t, t) - v_target||^2
+ 
+ Logit-normal timestep sampling (from SnapGen/SD3):
+     t ~ sigmoid(Normal(mean, std)) with mean = 0, std = 1,
+     which puts more weight on intermediate timesteps.
+ 
+ Staged curriculum (from DreamLite + SnapGen):
+     Stage 1: Low-res composition (128-256 px)
+     Stage 2: Texture refinement (256-512 px)
+     Stage 3: High-res detail (512-1024 px)
+     Stage 4: Editing tasks (with spatial concat)
+     Stage 5: Step distillation (LADD or consistency)
+ """
+ 
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import Optional, Dict, Tuple
+ 
+ 
+ class FlowMatchingScheduler:
+     """
+     Rectified Flow / Flow Matching schedule.
+ 
+     Forward process: z_t = (1 - t) * z_0 + t * epsilon
+     Velocity: v = epsilon - z_0
+     At t=0: z_t = z_0 (clean)
+     At t=1: z_t = epsilon (noise)
+ 
+     Timestep sampling: logit-normal distribution.
+     """
+ 
+     def __init__(
+         self,
+         logit_mean: float = 0.0,
+         logit_std: float = 1.0,
+         time_shift: float = 3.0,
+     ):
+         self.logit_mean = logit_mean
+         self.logit_std = logit_std
+         self.time_shift = time_shift
+ 
+     def sample_timesteps(self, batch_size: int, device: torch.device) -> torch.Tensor:
+         """
+         Sample timesteps from a logit-normal distribution.
+         Returns t in [0, 1].
+         """
+         u = torch.randn(batch_size, device=device) * self.logit_std + self.logit_mean
+         t = torch.sigmoid(u)
+ 
+         # Dynamic time shifting (from FLUX/DreamLite)
+         if self.time_shift != 1.0:
+             t = self.time_shift * t / (1 + (self.time_shift - 1) * t)
+ 
+         return t
+ 
+     def add_noise(
+         self,
+         z_0: torch.Tensor,
+         noise: torch.Tensor,
+         t: torch.Tensor,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Create noised sample and target velocity.
+ 
+         z_t = (1 - t) * z_0 + t * epsilon
+         v_target = epsilon - z_0
+ 
+         Args:
+             z_0: [B, C, H, W] clean latent
+             noise: [B, C, H, W] standard normal noise
+             t: [B] timesteps
+ 
+         Returns:
+             z_t: [B, C, H, W] noised latent
+             v_target: [B, C, H, W] target velocity
+         """
+         t_expanded = t[:, None, None, None]  # [B, 1, 1, 1]
+         z_t = (1 - t_expanded) * z_0 + t_expanded * noise
+         v_target = noise - z_0
+         return z_t, v_target
+ 
+     @torch.no_grad()
+     def euler_step(
+         self,
+         z_t: torch.Tensor,
+         v_pred: torch.Tensor,
+         t: float,
+         t_next: float,
+     ) -> torch.Tensor:
+         """
+         Single Euler step for ODE sampling.
+         z_{t_next} = z_t + (t_next - t) * v_pred
+         """
+         dt = t_next - t
+         return z_t + dt * v_pred
+ 
+     @torch.no_grad()
+     def sample(
+         self,
+         model,
+         noise: torch.Tensor,
+         text_emb: torch.Tensor,
+         text_pooled: torch.Tensor,
+         num_steps: int = 20,
+         cfg_scale: float = 7.5,
+         planner=None,
+     ) -> torch.Tensor:
+         """
+         Full sampling loop using the Euler ODE solver.
+ 
+         Args:
+             model: MicroForgeBackbone
+             noise: [B, C, H, W] initial noise
+             text_emb: [B, M, D] text embeddings
+             text_pooled: [B, D] pooled text
+             num_steps: number of denoising steps
+             cfg_scale: classifier-free guidance scale
+             planner: optional RecurrentLatentPlanner
+ 
+         Returns:
+             z_0: [B, C, H, W] generated clean latent
+         """
+         timesteps = torch.linspace(1, 0, num_steps + 1, device=noise.device)
+         z_t = noise
+         plan = None
+ 
+         for i in range(num_steps):
+             t = timesteps[i]
+             t_next = timesteps[i + 1]
+             t_batch = torch.full((noise.shape[0],), t.item(), device=noise.device)
+ 
+             planner_tokens = None
+             if planner is not None:
+                 # Initialize or update the plan; the latent is flattened
+                 # into a token sequence for the planner input.
+                 B, C, H, W = z_t.shape
+                 img_tokens = z_t.reshape(B, C, -1).permute(0, 2, 1)
+ 
+                 plan = planner.initialize_plan(text_pooled, B, plan)
+                 t_emb = model.time_embed(t_batch)
+                 plan, planner_tokens = planner(img_tokens, plan, t_emb)
+ 
+             # Classifier-free guidance
+             if cfg_scale > 1.0:
+                 # Conditional prediction
+                 v_cond = model(z_t, t_batch, text_emb, text_pooled, planner_tokens)
+                 # Unconditional prediction (empty text)
+                 null_text = torch.zeros_like(text_emb)
+                 null_pooled = torch.zeros_like(text_pooled)
+                 v_uncond = model(z_t, t_batch, null_text, null_pooled, None)
+                 # CFG: extrapolate from unconditional toward conditional
+                 v_pred = v_uncond + cfg_scale * (v_cond - v_uncond)
+             else:
+                 v_pred = model(z_t, t_batch, text_emb, text_pooled, planner_tokens)
+ 
+             z_t = self.euler_step(z_t, v_pred, t.item(), t_next.item())
+ 
+         return z_t
+ 
+ 
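+ def _flow_matching_sanity_check():
+     """Minimal sketch (illustration only, not used in training) of the
+     rectified-flow identities documented above: z_t interpolates linearly
+     between data and noise, and because the true velocity is constant
+     along the straight path, one exact step from any t recovers z_0."""
+     sched = FlowMatchingScheduler(time_shift=1.0)  # plain logit-normal, no shift
+     z_0 = torch.randn(4, 4, 8, 8)
+     noise = torch.randn_like(z_0)
+     t = sched.sample_timesteps(4, z_0.device)
+     z_t, v_target = sched.add_noise(z_0, noise, t)
+     # z_t - t * (epsilon - z_0) = (1 - t) * z_0 + t * epsilon - t * (epsilon - z_0) = z_0
+     z_rec = z_t - t[:, None, None, None] * v_target
+     assert torch.allclose(z_rec, z_0, atol=1e-5)
+ 
+ 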
+ class MicroForgeLoss(nn.Module):
+     """
+     Combined loss function for MicroForge training.
+ 
+     L_total = L_flow + lambda_kl * L_kl + lambda_recon * L_recon
+ 
+     For distillation stages, additional losses are added.
+     """
+ 
+     def __init__(
+         self,
+         lambda_kl: float = 1e-6,
+         lambda_recon: float = 1.0,
+     ):
+         super().__init__()
+         self.lambda_kl = lambda_kl
+         self.lambda_recon = lambda_recon
+ 
+     def flow_matching_loss(
+         self,
+         v_pred: torch.Tensor,
+         v_target: torch.Tensor,
+         t: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         """
+         Flow matching loss with optional timestep weighting.
+         L = ||v_pred - v_target||^2
+ 
+         Optional: t-scaling (from SnapGen) to prioritize perceptually
+         important timesteps.
+         """
+         loss = F.mse_loss(v_pred, v_target, reduction='none')
+ 
+         if t is not None:
+             # T-scaling: weight intermediate timesteps more. The weight
+             # peaks at 1.0 for t=0.5 and falls to 0.5 at t=0 and t=1.
+             weight = 1.0 / (1.0 + torch.abs(2 * t - 1))
+             weight = weight[:, None, None, None]
+             loss = loss * weight
+ 
+         return loss.mean()
+ 
+     def vae_loss(
+         self,
+         x_recon: torch.Tensor,
+         x: torch.Tensor,
+         mu: torch.Tensor,
+         logvar: torch.Tensor,
+     ) -> Dict[str, torch.Tensor]:
+         """VAE training loss: L1 reconstruction + KL."""
+         l_recon = F.l1_loss(x_recon, x)
+         l_kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
+ 
+         total = self.lambda_recon * l_recon + self.lambda_kl * l_kl
+         return {
+             'total': total,
+             'recon': l_recon,
+             'kl': l_kl,
+         }
+ 
+     def forward(
+         self,
+         v_pred: torch.Tensor,
+         v_target: torch.Tensor,
+         t: Optional[torch.Tensor] = None,
+     ) -> Dict[str, torch.Tensor]:
+         """Compute the flow matching loss (main training objective)."""
+         l_flow = self.flow_matching_loss(v_pred, v_target, t)
+         return {'total': l_flow, 'flow': l_flow}
+ 
+ 
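+ def _loss_weighting_example():
+     """Minimal sketch (illustration only) of the optional t-weighting in
+     flow_matching_loss: the weight 1 / (1 + |2t - 1|) is 1.0 at t = 0.5
+     and 0.5 at the endpoints, and since every weight is <= 1 the weighted
+     loss never exceeds the unweighted one."""
+     loss_fn = MicroForgeLoss()
+     v_pred = torch.randn(3, 4, 8, 8)
+     v_target = torch.zeros_like(v_pred)
+     t = torch.tensor([0.0, 0.5, 1.0])
+     unweighted = loss_fn.flow_matching_loss(v_pred, v_target)
+     weighted = loss_fn.flow_matching_loss(v_pred, v_target, t)
+     assert weighted <= unweighted
+ 
+ 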
+ class MicroForgeTrainer:
+     """
+     Training orchestrator for MicroForge.
+ 
+     Implements the staged curriculum:
+         Stage 1: VAE training (or use pretrained DC-AE)
+         Stage 2: Backbone training with flow matching at low-res
+         Stage 3: Progressive resolution increase
+         Stage 4: Editing task joint training
+         Stage 5: Step distillation (consistency or LADD)
+ 
+     Memory optimization for a 16 GB GPU:
+         - Gradient checkpointing
+         - Mixed precision (fp16/bf16)
+         - Small batch + gradient accumulation
+         - Freeze VAE during backbone training
+     """
+ 
+     def __init__(
+         self,
+         vae,
+         backbone,
+         planner=None,
+         lr: float = 1e-4,
+         weight_decay: float = 0.01,
+         grad_clip: float = 2.0,
+         use_ema: bool = True,
+         ema_decay: float = 0.9999,
+     ):
+         self.vae = vae
+         self.backbone = backbone
+         self.planner = planner
+         self.scheduler = FlowMatchingScheduler()
+         self.loss_fn = MicroForgeLoss()
+         self.grad_clip = grad_clip
+ 
+         # Setup optimizer over the backbone (and planner, if present)
+         params = list(backbone.parameters())
+         if planner is not None:
+             params += list(planner.parameters())
+         self.params = params
+ 
+         self.optimizer = torch.optim.AdamW(
+             params, lr=lr, weight_decay=weight_decay,
+             betas=(0.9, 0.999),
+         )
+ 
+         # EMA
+         self.use_ema = use_ema
+         self.ema_decay = ema_decay
+         if use_ema:
+             self.ema_backbone = self._create_ema(backbone)
+ 
+     def _create_ema(self, model):
+         """Create a frozen EMA copy of the model."""
+         import copy
+         ema = copy.deepcopy(model)
+         for p in ema.parameters():
+             p.data = p.data.clone()
+             p.requires_grad_(False)
+         return ema
+ 
+     @torch.no_grad()
+     def _update_ema(self):
+         """Update EMA weights: p_ema <- decay * p_ema + (1 - decay) * p."""
+         if not self.use_ema:
+             return
+         for p_ema, p_model in zip(self.ema_backbone.parameters(), self.backbone.parameters()):
+             p_ema.data.mul_(self.ema_decay).add_(p_model.data, alpha=1 - self.ema_decay)
+ 
+     def train_step(
+         self,
+         images: torch.Tensor,
+         text_emb: torch.Tensor,
+         text_pooled: torch.Tensor,
+     ) -> Dict[str, float]:
+         """
+         Single training step.
+ 
+         Args:
+             images: [B, 3, H, W] input images
+             text_emb: [B, M, text_dim] text embeddings
+             text_pooled: [B, text_dim] pooled text
+ 
+         Returns:
+             dict of loss values
+         """
+         device = images.device
+ 
+         # Encode to latent (VAE frozen)
+         with torch.no_grad():
+             z_0 = self.vae.get_latent(images)
+ 
+         # Sample timesteps and noise
+         B = z_0.shape[0]
+         t = self.scheduler.sample_timesteps(B, device)
+         noise = torch.randn_like(z_0)
+ 
+         # Create noised latent and target velocity
+         z_t, v_target = self.scheduler.add_noise(z_0, noise, t)
+ 
+         # Optional: planner
+         planner_tokens = None
+         if self.planner is not None:
+             img_tokens = z_t.reshape(B, z_t.shape[1], -1).permute(0, 2, 1)
+             plan = self.planner.initialize_plan(text_pooled, B)
+             t_emb = self.backbone.time_embed(t)
+             _, planner_tokens = self.planner(img_tokens, plan, t_emb)
+ 
+         # Predict velocity
+         v_pred = self.backbone(z_t, t, text_emb, text_pooled, planner_tokens)
+ 
+         # Compute loss
+         losses = self.loss_fn(v_pred, v_target, t)
+ 
+         # Backward + optimize (clip all trained params, incl. planner)
+         self.optimizer.zero_grad()
+         losses['total'].backward()
+         torch.nn.utils.clip_grad_norm_(self.params, self.grad_clip)
+         self.optimizer.step()
+ 
+         # Update EMA
+         self._update_ema()
+ 
+         return {k: v.item() for k, v in losses.items()}
+ 
+     def train_vae_step(
+         self,
+         images: torch.Tensor,
+         vae_optimizer: torch.optim.Optimizer,
+     ) -> Dict[str, float]:
+         """Training step for the VAE (Stage 1)."""
+         x_recon, mu, logvar = self.vae(images)
+         losses = self.loss_fn.vae_loss(x_recon, images, mu, logvar)
+ 
+         vae_optimizer.zero_grad()
+         losses['total'].backward()
+         torch.nn.utils.clip_grad_norm_(self.vae.parameters(), self.grad_clip)
+         vae_optimizer.step()
+ 
+         return {k: v.item() for k, v in losses.items()}
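+ 
+ 
+ if __name__ == "__main__":
+     # Minimal smoke test, a sketch only: _ToyVAE and _ToyBackbone are
+     # hypothetical stand-ins that implement just the interfaces
+     # train_step() touches (vae.get_latent and backbone.__call__); the
+     # real modules live elsewhere in the package.
+     class _ToyVAE(nn.Module):
+         def __init__(self):
+             super().__init__()
+             self.proj = nn.Conv2d(3, 4, kernel_size=1)
+ 
+         def get_latent(self, x):
+             return self.proj(x)
+ 
+     class _ToyBackbone(nn.Module):
+         def __init__(self):
+             super().__init__()
+             self.net = nn.Conv2d(4, 4, kernel_size=3, padding=1)
+ 
+         def forward(self, z_t, t, text_emb, text_pooled, planner_tokens=None):
+             return self.net(z_t)  # ignores conditioning; shape-only check
+ 
+     _flow_matching_sanity_check()
+     _loss_weighting_example()
+ 
+     trainer = MicroForgeTrainer(_ToyVAE(), _ToyBackbone())
+     images = torch.randn(2, 3, 32, 32)
+     text_emb = torch.randn(2, 8, 16)
+     text_pooled = torch.randn(2, 16)
+     print(trainer.train_step(images, text_emb, text_pooled))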