asdf98 commited on
Commit
c4858e4
·
verified ·
1 Parent(s): 4f46baa

Add training pipeline

Browse files
Files changed (1) hide show
  1. train.py +378 -0
train.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LiquidGen Training Pipeline
3
+
4
+ Flow Matching training objective (velocity prediction):
5
+ - Forward: x_t = (1 - t) * x_0 + t * ε (linear interpolation)
6
+ - Target: v = ε - x_0 (velocity)
7
+ - Loss: MSE(model(x_t, t), v)
8
+
9
+ At inference: solve ODE from t=1 (noise) to t=0 (clean) using Euler steps.
10
+ """
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from torch.utils.data import DataLoader, Dataset
16
+ from torch.amp import autocast, GradScaler
17
+ import math
18
+ import os
19
+ import json
20
+ import time
21
+ from pathlib import Path
22
+ from typing import Optional, Dict, Any
23
+ from dataclasses import dataclass, field, asdict
24
+
25
+
26
+ @dataclass
27
+ class TrainConfig:
28
+ """Training configuration with sensible defaults for Colab free tier."""
29
+ # Model
30
+ model_size: str = "small"
31
+ num_classes: int = 0
32
+ class_drop_prob: float = 0.1
33
+
34
+ # Data
35
+ image_size: int = 256
36
+ dataset_name: str = "huggan/wikiart"
37
+ dataset_config: str = ""
38
+ image_column: str = "image"
39
+ label_column: str = ""
40
+
41
+ # VAE
42
+ vae_id: str = "black-forest-labs/FLUX.1-schnell"
43
+ vae_subfolder: str = "vae"
44
+ vae_dtype: str = "float16"
45
+ vae_scaling_factor: float = 0.3611
46
+ vae_shift_factor: float = 0.1159
47
+
48
+ # Training
49
+ batch_size: int = 8
50
+ gradient_accumulation_steps: int = 4
51
+ learning_rate: float = 1e-4
52
+ weight_decay: float = 0.01
53
+ max_grad_norm: float = 2.0
54
+ num_epochs: int = 100
55
+ warmup_steps: int = 1000
56
+ ema_decay: float = 0.9999
57
+ mixed_precision: bool = True
58
+
59
+ # Flow matching
60
+ min_timestep: float = 0.001
61
+ max_timestep: float = 0.999
62
+
63
+ # Saving
64
+ output_dir: str = "./outputs"
65
+ save_every_n_steps: int = 5000
66
+ sample_every_n_steps: int = 1000
67
+ log_every_n_steps: int = 50
68
+
69
+ # Sampling
70
+ num_sample_steps: int = 50
71
+ cfg_scale: float = 1.5
72
+ num_samples: int = 4
73
+
74
+ # System
75
+ seed: int = 42
76
+ num_workers: int = 2
77
+ pin_memory: bool = True
78
+ compile_model: bool = False
79
+
80
+ # Hub
81
+ push_to_hub: bool = False
82
+ hub_model_id: str = ""
83
+
84
+
85
+ def get_model_config(size: str, num_classes: int = 0, class_drop_prob: float = 0.1) -> dict:
86
+ """Get model kwargs for a given size preset."""
87
+ configs = {
88
+ "small": dict(embed_dim=512, depth=12, spatial_kernel=7, scan_kernel=31,
89
+ expand_ratio=2.0, mlp_ratio=3.0),
90
+ "base": dict(embed_dim=640, depth=18, spatial_kernel=7, scan_kernel=31,
91
+ expand_ratio=2.0, mlp_ratio=4.0),
92
+ "large": dict(embed_dim=768, depth=24, spatial_kernel=7, scan_kernel=31,
93
+ expand_ratio=2.5, mlp_ratio=4.0),
94
+ }
95
+ cfg = configs[size]
96
+ cfg["num_classes"] = num_classes
97
+ cfg["class_drop_prob"] = class_drop_prob
98
+ cfg["use_zigzag"] = True
99
+ return cfg
100
+
101
+
102
+ class EMAModel:
103
+ """Exponential Moving Average of model parameters."""
104
+
105
+ def __init__(self, model: nn.Module, decay: float = 0.9999):
106
+ self.decay = decay
107
+ self.shadow = {name: p.clone().detach() for name, p in model.named_parameters() if p.requires_grad}
108
+
109
+ @torch.no_grad()
110
+ def update(self, model: nn.Module):
111
+ for name, p in model.named_parameters():
112
+ if p.requires_grad and name in self.shadow:
113
+ self.shadow[name].mul_(self.decay).add_(p.data, alpha=1 - self.decay)
114
+
115
+ def apply(self, model: nn.Module):
116
+ self.backup = {name: p.data.clone() for name, p in model.named_parameters() if p.requires_grad}
117
+ for name, p in model.named_parameters():
118
+ if p.requires_grad and name in self.shadow:
119
+ p.data.copy_(self.shadow[name])
120
+
121
+ def restore(self, model: nn.Module):
122
+ for name, p in model.named_parameters():
123
+ if p.requires_grad and name in self.backup:
124
+ p.data.copy_(self.backup[name])
125
+ self.backup = {}
126
+
127
+ def state_dict(self):
128
+ return self.shadow
129
+
130
+ def load_state_dict(self, state_dict):
131
+ self.shadow = state_dict
132
+
133
+
134
+ class FlowMatchingScheduler:
135
+ """
136
+ Flow Matching scheduler for training and sampling.
137
+
138
+ Training: x_t = (1-t)*x_0 + t*ε, v_target = ε - x_0
139
+ Sampling: Euler ODE from t=1 (noise) to t=0 (clean)
140
+ """
141
+
142
+ def __init__(self, min_t: float = 0.001, max_t: float = 0.999):
143
+ self.min_t = min_t
144
+ self.max_t = max_t
145
+
146
+ def sample_timesteps(self, batch_size: int, device: torch.device) -> torch.Tensor:
147
+ return torch.rand(batch_size, device=device) * (self.max_t - self.min_t) + self.min_t
148
+
149
+ def add_noise(self, x0: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
150
+ t_expand = t.view(-1, 1, 1, 1)
151
+ return (1 - t_expand) * x0 + t_expand * noise
152
+
153
+ def get_velocity_target(self, x0: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
154
+ return noise - x0
155
+
156
+ @torch.no_grad()
157
+ def sample(
158
+ self, model: nn.Module, shape: tuple, device: torch.device,
159
+ num_steps: int = 50, class_labels: Optional[torch.Tensor] = None,
160
+ cfg_scale: float = 1.0, dtype: torch.dtype = torch.float32,
161
+ ) -> torch.Tensor:
162
+ model.eval()
163
+ x = torch.randn(shape, device=device, dtype=dtype)
164
+ dt = 1.0 / num_steps
165
+ times = torch.linspace(1.0, dt, num_steps, device=device)
166
+
167
+ for t_val in times:
168
+ t = torch.full((shape[0],), t_val.item(), device=device, dtype=dtype)
169
+
170
+ if cfg_scale > 1.0 and class_labels is not None:
171
+ with torch.amp.autocast('cuda', enabled=(dtype != torch.float32)):
172
+ v_cond = model(x, t, class_labels)
173
+ v_uncond = model(x, t, torch.zeros_like(class_labels))
174
+ v = v_uncond + cfg_scale * (v_cond - v_uncond)
175
+ else:
176
+ with torch.amp.autocast('cuda', enabled=(dtype != torch.float32)):
177
+ v = model(x, t, class_labels)
178
+
179
+ x = x - dt * v
180
+
181
+ return x
182
+
183
+
184
+ def get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps):
185
+ """Cosine LR schedule with linear warmup."""
186
+ def lr_lambda(current_step):
187
+ if current_step < warmup_steps:
188
+ return float(current_step) / float(max(1, warmup_steps))
189
+ progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
190
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
191
+ return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
192
+
193
+
194
+ @torch.no_grad()
195
+ def encode_images_with_vae(images, vae, scaling_factor, shift_factor):
196
+ """Encode pixel images to VAE latents."""
197
+ images = images * 2.0 - 1.0
198
+ latents = vae.encode(images).latent_dist.sample()
199
+ latents = (latents - shift_factor) * scaling_factor
200
+ return latents
201
+
202
+
203
+ @torch.no_grad()
204
+ def decode_latents_with_vae(latents, vae, scaling_factor, shift_factor):
205
+ """Decode VAE latents to pixel images."""
206
+ latents = latents / scaling_factor + shift_factor
207
+ images = vae.decode(latents).sample
208
+ images = (images + 1.0) / 2.0
209
+ return images.clamp(0, 1)
210
+
211
+
212
+ def train(config: TrainConfig):
213
+ """Main training loop."""
214
+ from model import LiquidGen
215
+
216
+ torch.manual_seed(config.seed)
217
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
218
+ print(f"Device: {device}")
219
+
220
+ os.makedirs(config.output_dir, exist_ok=True)
221
+ os.makedirs(os.path.join(config.output_dir, "samples"), exist_ok=True)
222
+ os.makedirs(os.path.join(config.output_dir, "checkpoints"), exist_ok=True)
223
+
224
+ with open(os.path.join(config.output_dir, "config.json"), "w") as f:
225
+ json.dump(asdict(config), f, indent=2)
226
+
227
+ # Load VAE
228
+ print("Loading VAE...")
229
+ from diffusers import AutoencoderKL
230
+ vae_dtype = torch.float16 if config.vae_dtype == "float16" else torch.bfloat16
231
+ vae = AutoencoderKL.from_pretrained(
232
+ config.vae_id, subfolder=config.vae_subfolder, torch_dtype=vae_dtype
233
+ ).to(device).eval()
234
+ for p in vae.parameters():
235
+ p.requires_grad_(False)
236
+
237
+ # Load Dataset
238
+ print(f"Loading dataset: {config.dataset_name}")
239
+ from datasets import load_dataset
240
+ from torchvision import transforms
241
+
242
+ ds_kwargs = {}
243
+ if config.dataset_config:
244
+ ds_kwargs["name"] = config.dataset_config
245
+ dataset = load_dataset(config.dataset_name, split="train", **ds_kwargs)
246
+
247
+ transform = transforms.Compose([
248
+ transforms.Resize(config.image_size, interpolation=transforms.InterpolationMode.LANCZOS),
249
+ transforms.CenterCrop(config.image_size),
250
+ transforms.RandomHorizontalFlip(),
251
+ transforms.ToTensor(),
252
+ ])
253
+
254
+ class ImageDataset(Dataset):
255
+ def __init__(self, hf_dataset, transform, image_col, label_col=""):
256
+ self.dataset = hf_dataset
257
+ self.transform = transform
258
+ self.image_col = image_col
259
+ self.label_col = label_col
260
+
261
+ def __len__(self):
262
+ return len(self.dataset)
263
+
264
+ def __getitem__(self, idx):
265
+ item = self.dataset[idx]
266
+ img = item[self.image_col]
267
+ if img.mode != "RGB":
268
+ img = img.convert("RGB")
269
+ img = self.transform(img)
270
+ label = item[self.label_col] if self.label_col and self.label_col in item else -1
271
+ return img, label
272
+
273
+ train_dataset = ImageDataset(dataset, transform, config.image_column, config.label_column)
274
+ train_loader = DataLoader(
275
+ train_dataset, batch_size=config.batch_size, shuffle=True,
276
+ num_workers=config.num_workers, pin_memory=config.pin_memory, drop_last=True,
277
+ )
278
+
279
+ # Create Model
280
+ model_kwargs = get_model_config(config.model_size, config.num_classes, config.class_drop_prob)
281
+ model = LiquidGen(**model_kwargs).to(device)
282
+ print(f"LiquidGen-{config.model_size}: {model.count_params() / 1e6:.1f}M params")
283
+
284
+ if config.compile_model and hasattr(torch, "compile"):
285
+ model = torch.compile(model)
286
+
287
+ optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate,
288
+ weight_decay=config.weight_decay, betas=(0.9, 0.999))
289
+ total_steps = len(train_loader) * config.num_epochs // config.gradient_accumulation_steps
290
+ scheduler = get_cosine_schedule_with_warmup(optimizer, config.warmup_steps, total_steps)
291
+ ema = EMAModel(model, decay=config.ema_decay)
292
+ scaler = GradScaler('cuda', enabled=config.mixed_precision)
293
+ fm = FlowMatchingScheduler(min_t=config.min_timestep, max_t=config.max_timestep)
294
+
295
+ print(f"\nTraining: {total_steps} steps, effective batch {config.batch_size * config.gradient_accumulation_steps}")
296
+
297
+ global_step = 0
298
+ loss_accum = 0.0
299
+
300
+ for epoch in range(config.num_epochs):
301
+ model.train()
302
+ t_start = time.time()
303
+
304
+ for batch_idx, (images, labels) in enumerate(train_loader):
305
+ images = images.to(device)
306
+ labels = labels.to(device) if config.num_classes > 0 else None
307
+
308
+ with torch.no_grad():
309
+ latents = encode_images_with_vae(
310
+ images.to(vae_dtype), vae, config.vae_scaling_factor, config.vae_shift_factor
311
+ ).float()
312
+
313
+ t = fm.sample_timesteps(latents.shape[0], device)
314
+ noise = torch.randn_like(latents)
315
+ x_t = fm.add_noise(latents, noise, t)
316
+ v_target = fm.get_velocity_target(latents, noise)
317
+
318
+ with autocast('cuda', enabled=config.mixed_precision):
319
+ v_pred = model(x_t, t, labels)
320
+ loss = F.mse_loss(v_pred, v_target) / config.gradient_accumulation_steps
321
+
322
+ scaler.scale(loss).backward()
323
+ loss_accum += loss.item()
324
+
325
+ if (batch_idx + 1) % config.gradient_accumulation_steps == 0:
326
+ scaler.unscale_(optimizer)
327
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
328
+ scaler.step(optimizer)
329
+ scaler.update()
330
+ optimizer.zero_grad()
331
+ scheduler.step()
332
+ ema.update(model)
333
+ global_step += 1
334
+
335
+ if global_step % config.log_every_n_steps == 0:
336
+ avg_loss = loss_accum / config.log_every_n_steps
337
+ lr = optimizer.param_groups[0]["lr"]
338
+ print(f"step={global_step} | epoch={epoch} | loss={avg_loss:.4f} | "
339
+ f"grad_norm={grad_norm:.2f} | lr={lr:.2e}")
340
+ loss_accum = 0.0
341
+
342
+ if math.isnan(avg_loss) or avg_loss > 100:
343
+ print("⚠️ Training diverged!")
344
+ return
345
+
346
+ if global_step % config.sample_every_n_steps == 0:
347
+ ema.apply(model)
348
+ model.eval()
349
+ latent_size = config.image_size // 8
350
+ sample_labels = None
351
+ if config.num_classes > 0:
352
+ sample_labels = torch.randint(0, config.num_classes, (config.num_samples,), device=device)
353
+ sampled = fm.sample(model, (config.num_samples, 16, latent_size, latent_size),
354
+ device, config.num_sample_steps, sample_labels, config.cfg_scale)
355
+ sample_imgs = decode_latents_with_vae(sampled.to(vae_dtype), vae,
356
+ config.vae_scaling_factor, config.vae_shift_factor).float()
357
+ from torchvision.utils import save_image
358
+ save_image(sample_imgs, os.path.join(config.output_dir, "samples", f"step_{global_step:07d}.png"), nrow=2)
359
+ ema.restore(model)
360
+ model.train()
361
+
362
+ if global_step % config.save_every_n_steps == 0:
363
+ torch.save({
364
+ "model": model.state_dict(), "ema": ema.state_dict(),
365
+ "optimizer": optimizer.state_dict(), "scheduler": scheduler.state_dict(),
366
+ "global_step": global_step, "epoch": epoch, "config": asdict(config),
367
+ }, os.path.join(config.output_dir, "checkpoints", f"step_{global_step:07d}.pt"))
368
+
369
+ print(f"Epoch {epoch} complete | time={time.time()-t_start:.0f}s")
370
+
371
+ torch.save({"model": model.state_dict(), "ema": ema.state_dict(), "config": asdict(config),
372
+ "global_step": global_step}, os.path.join(config.output_dir, "checkpoints", "final.pt"))
373
+ print(f"Training complete! Final model saved.")
374
+
375
+
376
+ if __name__ == "__main__":
377
+ config = TrainConfig(model_size="small", image_size=256, batch_size=4, num_epochs=2)
378
+ train(config)