asdf98 committed on
Commit 85accf4 · verified · 1 Parent(s): 18ce5a6

Add lira/training.py

Files changed (1)
  1. lira/training.py +382 -0
lira/training.py ADDED
"""
LiRA Training Pipeline

Training Strategy:
==================
1. Flow Matching with v-prediction (from SANA/SD3)
   - More stable than epsilon prediction near t=T
   - Better gradients throughout the diffusion process

2. Laplace Noise Schedule (from "Improved Noise Schedule for Diffusion")
   - Concentrates sampling around logSNR=0
   - Better FID than cosine/linear schedules

3. Progressive Resolution Training (from SANA)
   - Start at 256px → 512px → 1024px
   - Each stage uses the previous as initialization

4. Curriculum Learning (from "Curriculum Learning for Diffusion")
   - Easy timesteps first (high noise), hard timesteps later (low noise)

5. EMA with post-hoc tuning (from EDM2)
   - EMA decay 0.9999 during training
   - Post-hoc search for optimal EMA length

Training Stability:
===================
- Gradient clipping (max_norm=1.0)
- AdamW with weight decay 0.01
- Warmup + cosine decay learning rate
- AdaLN-Zero initialization (network acts as identity at start)
- Loss scaling: velocity prediction is naturally bounded
- Mixed precision (bf16) with gradient scaling
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Dict, Tuple
from dataclasses import dataclass, field


@dataclass
class LiRATrainingConfig:
    """Training configuration with sensible defaults for Colab-friendly training"""

    # Model
    model_config: str = 'tiny'  # Start small for testing
    latent_channels: int = 4  # SD1.x/SDXL VAE
    spatial_compression: int = 8
    d_text: int = 768
    patch_size: int = 2  # 2x2 patches for f8 VAE (128x128 → 64x64 tokens)

    # Training
    batch_size: int = 8
    learning_rate: float = 1e-4
    weight_decay: float = 0.01
    warmup_steps: int = 1000
    max_steps: int = 100000
    grad_clip: float = 1.0

    # EMA
    ema_decay: float = 0.9999

    # Flow matching
    prediction_target: str = 'velocity'  # 'velocity' or 'epsilon'
    noise_schedule: str = 'laplace'  # 'laplace', 'logit_normal', or 'uniform'

    # Progressive resolution
    progressive_stages: list = field(default_factory=lambda: [
        {'resolution': 256, 'steps': 50000},
        {'resolution': 512, 'steps': 30000},
        {'resolution': 1024, 'steps': 20000},
    ])

    # Curriculum
    use_curriculum: bool = True
    curriculum_warmup: int = 10000  # Steps before full timestep range

    # Logging
    log_every: int = 100
    save_every: int = 5000
    sample_every: int = 2500

    # Hardware
    mixed_precision: str = 'bf16'  # 'bf16', 'fp16', or 'no'
    compile_model: bool = False  # torch.compile (if available)

    # Data
    dataset_name: str = ''
    num_workers: int = 4

    # Output
    output_dir: str = './lira_output'
    hub_model_id: str = ''
    push_to_hub: bool = True


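# Illustrative usage sketch (not referenced elsewhere in this file): the
# defaults above target small Colab-scale runs, and individual fields are meant
# to be overridden per experiment. The values below are arbitrary examples.
def _example_config() -> LiRATrainingConfig:
    return LiRATrainingConfig(model_config='tiny', batch_size=4, max_steps=1000)

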
class FlowMatchingScheduler:
    """
    Flow Matching noise scheduler with Laplace distribution.

    Flow matching interpolation:
        z_t = (1 - t) * z_0 + t * ε   where ε ~ N(0, I)
        v_t = ε - z_0                 (velocity)

    Laplace noise schedule (from "Improved Noise Schedule"):
        t ~ Laplace(μ=0, b=1), squashed to (0, 1) via sigmoid.
        This concentrates samples around t=0.5, where learning is most effective.
    """

    def __init__(self, schedule: str = 'laplace', shift: float = 1.0):
        self.schedule = schedule
        self.shift = shift  # For resolution-dependent shifting (from SD3)

    def sample_timesteps(self, batch_size: int, device: torch.device,
                         curriculum_progress: float = 1.0) -> torch.Tensor:
        """
        Sample timesteps from the noise schedule.

        curriculum_progress: 0→1 over training. At 0, only easy timesteps
        (near 1.0) are sampled; at 1.0, the full range is used.
        """
        if self.schedule == 'laplace':
            # Inverse Laplace CDF with μ=0, b=1:
            #   t = μ - b * sign(u - 0.5) * log(1 - 2|u - 0.5|)
            u = torch.rand(batch_size, device=device)
            t = -torch.sign(u - 0.5) * torch.log(1 - 2 * torch.abs(u - 0.5) + 1e-8)
            # Squash from (-inf, inf) to (0, 1); the mode at 0 maps to t=0.5
            t = torch.sigmoid(t)

        elif self.schedule == 'logit_normal':
            # Logit-normal (from SD3): sample from N(0, 1), then sigmoid
            t = torch.sigmoid(torch.randn(batch_size, device=device))

        else:  # uniform
            t = torch.rand(batch_size, device=device)

        # Apply resolution-dependent shift (from SD3):
        #   t' = shift * t / (1 + (shift - 1) * t)
        # Higher shift → more weight on higher noise levels
        if self.shift != 1.0:
            t = t * self.shift / (1 + (self.shift - 1) * t)

        # Curriculum: restrict to easier (high-noise) timesteps early in training
        if curriculum_progress < 1.0:
            min_t = 0.5 * (1 - curriculum_progress)  # Start from t > 0.5, expand to t > 0
            t = min_t + t * (1 - min_t)

        # Clamp for numerical stability
        t = t.clamp(1e-5, 1 - 1e-5)

        return t

    def add_noise(self, z_0: torch.Tensor, t: torch.Tensor,
                  noise: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Flow matching interpolation: z_t = (1-t)*z_0 + t*ε

        Returns: (z_t, noise)
        """
        if noise is None:
            noise = torch.randn_like(z_0)

        t = t.view(-1, 1, 1, 1)  # Broadcast over spatial dims
        z_t = (1 - t) * z_0 + t * noise

        return z_t, noise

    def get_velocity(self, z_0: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        """Compute velocity target: v = ε - z_0"""
        return noise - z_0

    def predict_z0(self, z_t: torch.Tensor, v_pred: torch.Tensor,
                   t: torch.Tensor) -> torch.Tensor:
        """Recover z_0 from z_t and predicted velocity"""
        t = t.view(-1, 1, 1, 1)
        # z_t = (1-t)*z_0 + t*ε and v = ε - z_0, so ε = z_0 + v and
        #   z_t = (1-t)*z_0 + t*(z_0 + v) = z_0 + t*v
        # hence z_0 = z_t - t*v.
        return z_t - t * v_pred


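# Hedged sanity-check sketch (not part of the training pipeline): numerically
# verifies the algebra in predict_z0 above. Shapes are arbitrary; run manually.
def _check_flow_roundtrip():
    sched = FlowMatchingScheduler()
    z_0 = torch.randn(2, 4, 32, 32)
    t = sched.sample_timesteps(2, torch.device('cpu'))
    z_t, noise = sched.add_noise(z_0, t)
    v = sched.get_velocity(z_0, noise)
    # z_t - t*v = (1-t)*z_0 + t*ε - t*(ε - z_0) = z_0, for any t
    assert torch.allclose(sched.predict_z0(z_t, v, t), z_0, atol=1e-5)

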
class EMAModel:
    """Exponential Moving Average of model parameters"""

    def __init__(self, model: nn.Module, decay: float = 0.9999):
        self.decay = decay
        self.shadow = {}
        self.backup = {}

        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    @torch.no_grad()
    def update(self, model: nn.Module):
        for name, param in model.named_parameters():
            if param.requires_grad and name in self.shadow:
                self.shadow[name] = (
                    self.decay * self.shadow[name] + (1 - self.decay) * param.data
                )

    def apply_shadow(self, model: nn.Module):
        """Replace model params with EMA params"""
        for name, param in model.named_parameters():
            if param.requires_grad and name in self.shadow:
                self.backup[name] = param.data
                param.data = self.shadow[name]

    def restore(self, model: nn.Module):
        """Restore original model params"""
        for name, param in model.named_parameters():
            if param.requires_grad and name in self.backup:
                param.data = self.backup[name]
        self.backup = {}

    def state_dict(self):
        return self.shadow

    def load_state_dict(self, state_dict):
        self.shadow = state_dict


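# Hedged usage sketch (illustrative): the EMA weights are updated after every
# optimizer step and swapped in only for evaluation or sampling. `model` stands
# for any nn.Module with trainable parameters.
def _example_ema_usage(model: nn.Module):
    ema = EMAModel(model, decay=0.9999)
    # ... after each optimizer.step() during training:
    ema.update(model)
    # At evaluation/sampling time:
    ema.apply_shadow(model)   # model now holds the EMA weights
    # ... run evaluation or sampling here ...
    ema.restore(model)        # back to the live training weights

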
def compute_loss(
    model: nn.Module,
    z_0: torch.Tensor,
    text_features: torch.Tensor,
    scheduler: FlowMatchingScheduler,
    config: LiRATrainingConfig,
    global_step: int = 0,
    text_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Dict]:
    """
    Compute training loss.

    Loss = ||v_pred - v_target||^2 (MSE on velocity prediction)

    With optional:
    - Reasoning regularization (encourage adaptive compute)
    - Frequency-weighted loss (higher weight on low-frequency errors)
    """
    device = z_0.device
    B = z_0.shape[0]

    # Curriculum progress
    if config.use_curriculum:
        curriculum_progress = min(1.0, global_step / config.curriculum_warmup)
    else:
        curriculum_progress = 1.0

    # Sample timesteps
    t = scheduler.sample_timesteps(B, device, curriculum_progress)

    # Add noise
    z_t, noise = scheduler.add_noise(z_0, t)

    # Get velocity target
    v_target = scheduler.get_velocity(z_0, noise)

    # Forward pass
    v_pred, reason_info = model(z_t, t, text_features, text_mask)

    # MSE loss on velocity
    mse_loss = F.mse_loss(v_pred, v_target)
    loss = mse_loss

    # Reasoning regularization: encourage variable thinking steps.
    # Small penalty to discourage always using the maximum number of steps.
    if reason_info.get('total_steps', 0) > 0 and len(reason_info.get('stop_values', [])) > 0:
        avg_stop = sum(reason_info['stop_values']) / len(reason_info['stop_values'])
        # Encourage the stop gate to actually stop sometimes
        reason_reg = 0.01 * (1.0 - avg_stop)
        loss = loss + reason_reg

    info = {
        'loss': loss.item(),
        'mse_loss': mse_loss.item(),
        'reason_steps': reason_info.get('total_steps', 0),
    }

    return loss, info


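# Hedged training-step sketch (illustrative): wires compute_loss together with
# the stability measures listed in the module docstring (gradient clipping,
# AdamW via the optimizer argument, EMA). `model` is assumed to be a LiRA model
# whose forward returns (v_pred, reason_info); mixed precision is omitted for
# brevity.
def _example_train_step(model, optimizer, lr_sched, ema, scheduler, config,
                        z_0, text_features, global_step, text_mask=None):
    loss, info = compute_loss(model, z_0, text_features, scheduler, config,
                              global_step, text_mask)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
    optimizer.step()
    lr_sched.step()
    ema.update(model)  # EMA tracks the post-update weights
    return info

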
def get_lr_scheduler(optimizer, config: LiRATrainingConfig):
    """Warmup + cosine decay learning rate schedule"""

    def lr_lambda(step):
        if step < config.warmup_steps:
            return step / config.warmup_steps
        else:
            progress = (step - config.warmup_steps) / (config.max_steps - config.warmup_steps)
            return 0.5 * (1 + math.cos(math.pi * progress))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


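# Hedged usage sketch (illustrative): builds the AdamW optimizer described in
# the module docstring and attaches the warmup + cosine schedule. Betas are
# left at the PyTorch defaults; only lr and weight_decay come from the config.
def _example_optimizer(model: nn.Module, config: LiRATrainingConfig):
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=config.learning_rate,
                                  weight_decay=config.weight_decay)
    lr_sched = get_lr_scheduler(optimizer, config)
    return optimizer, lr_sched

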
# ============================================================================
# DPM-Solver for fast sampling (from SANA's Flow-DPM-Solver)
# ============================================================================

class FlowDPMSolver:
    """
    DPM-Solver adapted for flow matching.

    Standard Euler:  z_{t-dt} = z_t - dt * v(z_t, t)
    DPM-Solver-2:    second-order correction for fewer steps

    From SANA: "Flow-DPM-Solver converges at 14-20 steps vs 28-50 for Euler"
    """

    def __init__(self, num_steps: int = 20, order: int = 2):
        self.num_steps = num_steps
        self.order = min(order, 2)

    @torch.no_grad()
    def sample(
        self,
        model: nn.Module,
        shape: Tuple[int, ...],
        text_features: torch.Tensor,
        text_mask: Optional[torch.Tensor] = None,
        cfg_scale: float = 4.0,
        device: torch.device = torch.device('cpu'),
    ) -> torch.Tensor:
        """
        Generate samples using DPM-Solver.

        Args:
            model: LiRA model
            shape: (B, C, H, W) latent shape
            text_features: (B, M, D) text features
            cfg_scale: classifier-free guidance scale
        """
        B = shape[0]

        # Start from pure noise (t=1)
        z = torch.randn(shape, device=device)

        # Time steps from 1 → 0
        timesteps = torch.linspace(1, 0, self.num_steps + 1, device=device)

        prev_v = None

        for i in range(self.num_steps):
            t_cur = timesteps[i]
            t_next = timesteps[i + 1]
            dt = t_next - t_cur  # Negative (going from 1 to 0)

            t_batch = t_cur.expand(B)

            # Predict velocity (with CFG if scale > 1)
            if cfg_scale > 1.0:
                v_pred = self._cfg_predict(model, z, t_batch, text_features, text_mask, cfg_scale)
            else:
                v_pred, _ = model(z, t_batch, text_features, text_mask)

            if self.order == 1 or prev_v is None:
                # Euler step
                z = z + dt * v_pred
            else:
                # DPM-Solver-2 (second-order correction):
                # uses the previous velocity for a better approximation
                z = z + dt * (1.5 * v_pred - 0.5 * prev_v)

            prev_v = v_pred

        return z

    def _cfg_predict(self, model, z, t, text_features, text_mask, cfg_scale):
        """Classifier-free guidance"""
        # Unconditional prediction (zero text)
        null_text = torch.zeros_like(text_features)
        v_uncond, _ = model(z, t, null_text, text_mask)

        # Conditional prediction
        v_cond, _ = model(z, t, text_features, text_mask)

        # CFG: v = v_uncond + scale * (v_cond - v_uncond)
        return v_uncond + cfg_scale * (v_cond - v_uncond)
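

# Hedged end-to-end sampling sketch (illustrative): draws latents with the
# solver above using EMA weights. `model` is assumed to be a LiRA model whose
# forward returns (v_pred, reason_info); the (4, 64, 64) latent shape matches a
# 512px image under the f8, 4-channel VAE from the config. Decoding latents to
# pixels with the VAE is outside the scope of this file.
def _example_sampling(model: nn.Module, ema: EMAModel,
                      text_features: torch.Tensor, device: torch.device):
    solver = FlowDPMSolver(num_steps=20, order=2)
    ema.apply_shadow(model)  # sample with the EMA weights
    latents = solver.sample(model,
                            shape=(text_features.shape[0], 4, 64, 64),
                            text_features=text_features,
                            cfg_scale=4.0,
                            device=device)
    ema.restore(model)
    return latents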