asdf98 committed
Commit ce2ad4d · verified · 1 Parent(s): 635ef78

Add optimized training v2 with latent pre-caching for Colab
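The key idea in this commit: run the VAE over the dataset once, cache the resulting latents to disk, and train purely on those tensors so no VAE occupies VRAM during the training loop. A minimal sketch of the pattern (illustrative only; the helper name cache_latents is hypothetical, and the commit's actual entry points are precache_latents() and CachedLatentDataset in the diff below):

import torch
from diffusers import AutoencoderKL

@torch.no_grad()
def cache_latents(images, vae, scaling, shift, path, batch_size=16):
    # images: float tensor [N, 3, H, W] scaled to [0, 1]
    chunks = []
    for i in range(0, len(images), batch_size):
        px = images[i:i + batch_size].to(vae.device, dtype=vae.dtype) * 2 - 1
        lat = vae.encode(px).latent_dist.sample()        # [B, C, H/8, W/8]
        chunks.append(((lat - shift) * scaling).cpu().float())
    torch.save({"latents": torch.cat(chunks)}, path)     # train from this file

# vae = AutoencoderKL.from_pretrained(...); after caching, `del vae` frees VRAM.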

Files changed (1):
1. train.py +379 -380
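The training objective stated in the module docstring below is standard velocity-prediction flow matching; a self-contained toy version of the loss plus the Euler update it implies (a sketch: model is any callable taking (x_t, t), not the repo's LiquidGen):

import torch
import torch.nn.functional as F

def fm_loss(model, x0, t):                  # x0: clean latents [B, C, H, W]
    noise = torch.randn_like(x0)
    t_ = t.view(-1, 1, 1, 1)
    x_t = (1 - t_) * x0 + t_ * noise        # linear interpolation
    v_target = noise - x0                   # velocity target
    return F.mse_loss(model(x_t, t), v_target)

# At inference, integrate from t=1 (noise) toward t=0 (data) with Euler steps:
#     x = x - dt * model(x, t)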
train.py CHANGED
@@ -1,95 +1,136 @@
  """
- LiquidGen Training Pipeline

  Flow Matching training objective (velocity prediction):
- - Forward: x_t = (1 - t) * x_0 + t * ε (linear interpolation)
- - Target: v = ε - x_0 (velocity)
  - Loss: MSE(model(x_t, t), v)
-
- At inference: solve ODE from t=1 (noise) to t=0 (clean) using Euler steps.
-
- Dataset loading: Uses STREAMING mode by default — no full download needed!
- For small datasets (<500MB), set use_streaming=False for faster epoch iteration.
  """

  import torch
  import torch.nn as nn
  import torch.nn.functional as F
- from torch.utils.data import DataLoader, Dataset, IterableDataset
  from torch.amp import autocast, GradScaler
  import math
  import os
  import json
  import time
- from pathlib import Path
- from typing import Optional, Dict, Any
- from dataclasses import dataclass, field, asdict


  @dataclass
  class TrainConfig:
-     """Training configuration with sensible defaults for Colab free tier."""
      # Model
-     model_size: str = "small"
-     num_classes: int = 0
      class_drop_prob: float = 0.1
-
      # Data
-     image_size: int = 256
-     dataset_name: str = "huggan/wikiart"
-     dataset_config: str = ""
-     image_column: str = "image"
-     label_column: str = ""
-     use_streaming: bool = True  # KEY: streaming mode, no full download
-     max_samples: int = 0  # 0 = use all (only for non-streaming)
-     streaming_buffer: int = 1000  # Shuffle buffer for streaming
-
-     # VAE (SDXL VAE - open access, no login needed, fp16-safe)
-     vae_id: str = "madebyollin/sdxl-vae-fp16-fix"
-     vae_subfolder: str = ""
-     vae_dtype: str = "float16"
-     vae_scaling_factor: float = 0.13025
-     vae_shift_factor: float = 0.0  # SDXL VAE has no shift
-
      # Training
-     batch_size: int = 8
-     gradient_accumulation_steps: int = 4
      learning_rate: float = 1e-4
      weight_decay: float = 0.01
-     max_grad_norm: float = 2.0
-     max_steps: int = 50000  # Train by steps, not epochs (better for streaming)
-     warmup_steps: int = 1000
      ema_decay: float = 0.9999
      mixed_precision: bool = True
-
      # Flow matching
      min_timestep: float = 0.001
      max_timestep: float = 0.999
-
      # Saving
      output_dir: str = "./outputs"
-     save_every_n_steps: int = 5000
-     sample_every_n_steps: int = 1000
-     log_every_n_steps: int = 50
-
      # Sampling
      num_sample_steps: int = 50
-     cfg_scale: float = 1.5
      num_samples: int = 4
-
      # System
      seed: int = 42
-     num_workers: int = 0  # 0 for streaming (required)
-     pin_memory: bool = True
      compile_model: bool = False
-
      # Hub
      push_to_hub: bool = False
      hub_model_id: str = ""


- def get_model_config(size: str, num_classes: int = 0, class_drop_prob: float = 0.1) -> dict:
-     """Get model kwargs for a given size preset."""
      configs = {
          "small": dict(embed_dim=512, depth=12, spatial_kernel=7, scan_kernel=31,
                        expand_ratio=2.0, mlp_ratio=3.0),
@@ -106,384 +147,342 @@ def get_model_config(size: str, num_classes: int = 0, class_drop_prob: float = 0


  # =============================================================================
- # Dataset Loaders
  # =============================================================================

- class StreamingImageDataset(IterableDataset):
-     """
-     Streaming dataset — loads images on-the-fly from HuggingFace Hub.
-     NO full download needed. Starts training immediately.
-
-     Perfect for large datasets (WikiArt, LAION, etc.) on Colab free tier.
-     """
-     def __init__(self, dataset_name, image_column="image", label_column="",
-                  image_size=256, split="train", dataset_config="",
-                  buffer_size=1000, seed=42):
-         super().__init__()
-         self.dataset_name = dataset_name
-         self.image_column = image_column
-         self.label_column = label_column
-         self.split = split
-         self.dataset_config = dataset_config
-         self.buffer_size = buffer_size
-         self.seed = seed
-
-         from torchvision import transforms
-         self.transform = transforms.Compose([
-             transforms.Resize(image_size, interpolation=transforms.InterpolationMode.LANCZOS),
-             transforms.CenterCrop(image_size),
-             transforms.RandomHorizontalFlip(),
-             transforms.ToTensor(),
-         ])
-
-     def _get_stream(self):
-         from datasets import load_dataset
-         kwargs = {}
-         if self.dataset_config:
-             kwargs["name"] = self.dataset_config
-         ds = load_dataset(self.dataset_name, split=self.split, streaming=True, **kwargs)
-         ds = ds.shuffle(seed=self.seed, buffer_size=self.buffer_size)
-         return iter(ds)
-
-     def __iter__(self):
-         stream = self._get_stream()
-         for item in stream:
-             try:
-                 img = item[self.image_column]
-                 if img.mode != "RGB":
-                     img = img.convert("RGB")
-                 img_tensor = self.transform(img)
-                 label = -1
-                 if self.label_column and self.label_column in item:
-                     label = item[self.label_column]
-                 yield img_tensor, label
-             except Exception:
-                 continue
-
-
- class MapImageDataset(Dataset):
-     """
-     Standard map-style dataset for small datasets that fit in memory.
-     Downloads once, then fast random access.
-
-     Good for: Pokemon (95MB), Flowers (330MB), few-shot-art (510MB)
-     """
-     def __init__(self, dataset_name, image_column="image", label_column="",
-                  image_size=256, split="train", dataset_config="", max_samples=0):
-         super().__init__()
-         self.image_column = image_column
-         self.label_column = label_column
-
-         from datasets import load_dataset
-         from torchvision import transforms
-
-         kwargs = {}
-         if dataset_config:
-             kwargs["name"] = dataset_config
-
-         print(f"Downloading {dataset_name}...")
-         self.dataset = load_dataset(dataset_name, split=split, **kwargs)
-         if max_samples > 0:
-             self.dataset = self.dataset.select(range(min(max_samples, len(self.dataset))))
-         print(f"  {len(self.dataset)} images loaded")
-
-         self.transform = transforms.Compose([
-             transforms.Resize(image_size, interpolation=transforms.InterpolationMode.LANCZOS),
-             transforms.CenterCrop(image_size),
-             transforms.RandomHorizontalFlip(),
-             transforms.ToTensor(),
-         ])
-
      def __len__(self):
-         return len(self.dataset)
-
      def __getitem__(self, idx):
-         item = self.dataset[idx]
-         img = item[self.image_column]
          if img.mode != "RGB":
              img = img.convert("RGB")
-         img = self.transform(img)
-         label = item[self.label_column] if self.label_column and self.label_column in item else -1
-         return img, label


  # =============================================================================
- # Training Utilities
  # =============================================================================

  class EMAModel:
-     """Exponential Moving Average of model parameters."""
-     def __init__(self, model: nn.Module, decay: float = 0.9999):
          self.decay = decay
-         self.shadow = {name: p.clone().detach() for name, p in model.named_parameters() if p.requires_grad}
-
      @torch.no_grad()
-     def update(self, model: nn.Module):
-         for name, p in model.named_parameters():
-             if p.requires_grad and name in self.shadow:
-                 self.shadow[name].mul_(self.decay).add_(p.data, alpha=1 - self.decay)
-
-     def apply(self, model: nn.Module):
-         self.backup = {name: p.data.clone() for name, p in model.named_parameters() if p.requires_grad}
-         for name, p in model.named_parameters():
-             if p.requires_grad and name in self.shadow:
-                 p.data.copy_(self.shadow[name])
-
-     def restore(self, model: nn.Module):
-         for name, p in model.named_parameters():
-             if p.requires_grad and name in self.backup:
-                 p.data.copy_(self.backup[name])
          self.backup = {}
-
-     def state_dict(self):
-         return self.shadow
-
-     def load_state_dict(self, state_dict):
-         self.shadow = state_dict


  class FlowMatchingScheduler:
-     """Flow Matching: x_t = (1-t)*x_0 + t*ε, v_target = ε - x_0"""
      def __init__(self, min_t=0.001, max_t=0.999):
          self.min_t, self.max_t = min_t, max_t
-
-     def sample_timesteps(self, batch_size, device):
-         return torch.rand(batch_size, device=device) * (self.max_t - self.min_t) + self.min_t
-
      def add_noise(self, x0, noise, t):
-         t = t.view(-1, 1, 1, 1)
-         return (1 - t) * x0 + t * noise
-
      def get_velocity_target(self, x0, noise):
          return noise - x0
-
      @torch.no_grad()
-     def sample(self, model, shape, device, num_steps=50, class_labels=None,
-                cfg_scale=1.0, dtype=torch.float32):
-         model.eval()
-         x = torch.randn(shape, device=device, dtype=dtype)
          dt = 1.0 / num_steps
-         for t_val in torch.linspace(1.0, dt, num_steps, device=device):
-             t = torch.full((shape[0],), t_val.item(), device=device, dtype=dtype)
-             if cfg_scale > 1.0 and class_labels is not None:
-                 with torch.amp.autocast('cuda', enabled=(dtype != torch.float32)):
-                     v_cond = model(x, t, class_labels)
-                     v_uncond = model(x, t, torch.zeros_like(class_labels))
-                 v = v_uncond + cfg_scale * (v_cond - v_uncond)
-             else:
-                 with torch.amp.autocast('cuda', enabled=(dtype != torch.float32)):
-                     v = model(x, t, class_labels)
-             x = x - dt * v
          return x


- def get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps):
-     def lr_lambda(step):
-         if step < warmup_steps:
-             return step / max(1, warmup_steps)
-         progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
-         return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
-     return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
-
-
- @torch.no_grad()
- def encode_images_with_vae(images, vae, scaling_factor, shift_factor):
-     images = images * 2.0 - 1.0
-     latents = vae.encode(images).latent_dist.sample()
-     return (latents - shift_factor) * scaling_factor
-
-
- @torch.no_grad()
- def decode_latents_with_vae(latents, vae, scaling_factor, shift_factor):
-     latents = latents / scaling_factor + shift_factor
-     images = vae.decode(latents).sample
-     return ((images + 1.0) / 2.0).clamp(0, 1)


  # =============================================================================
  # Main Training Loop
  # =============================================================================

- def train(config: TrainConfig):
-     """Main training loop with streaming dataset support."""
      from model import LiquidGen
-
      torch.manual_seed(config.seed)
      device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
      print(f"Device: {device}")
-
      os.makedirs(config.output_dir, exist_ok=True)
-     os.makedirs(os.path.join(config.output_dir, "samples"), exist_ok=True)
-     os.makedirs(os.path.join(config.output_dir, "checkpoints"), exist_ok=True)
-
-     with open(os.path.join(config.output_dir, "config.json"), "w") as f:
          json.dump(asdict(config), f, indent=2)
-
-     # Load VAE (frozen)
-     print("Loading VAE...")
-     from diffusers import AutoencoderKL
-     vae_dtype = torch.float16 if config.vae_dtype == "float16" else torch.bfloat16
-     vae_kwargs = {"torch_dtype": vae_dtype}
-     if config.vae_subfolder:
-         vae_kwargs["subfolder"] = config.vae_subfolder
-     vae = AutoencoderKL.from_pretrained(
-         config.vae_id, **vae_kwargs
-     ).to(device).eval()
-     for p in vae.parameters():
-         p.requires_grad_(False)
-     print(f"VAE: {sum(p.numel() for p in vae.parameters())/1e6:.1f}M params (frozen)")
-
-     # Load Dataset
-     print(f"Loading dataset: {config.dataset_name} (streaming={config.use_streaming})")
-     if config.use_streaming:
-         train_dataset = StreamingImageDataset(
-             dataset_name=config.dataset_name,
-             image_column=config.image_column,
-             label_column=config.label_column,
-             image_size=config.image_size,
-             dataset_config=config.dataset_config,
-             buffer_size=config.streaming_buffer,
-             seed=config.seed,
-         )
-         train_loader = DataLoader(
-             train_dataset, batch_size=config.batch_size,
-             num_workers=0,  # Required for streaming
-             pin_memory=config.pin_memory,
-         )
-         print("  Streaming mode — no full download, starts immediately!")
-     else:
-         train_dataset = MapImageDataset(
-             dataset_name=config.dataset_name,
-             image_column=config.image_column,
-             label_column=config.label_column,
-             image_size=config.image_size,
-             dataset_config=config.dataset_config,
-             max_samples=config.max_samples,
-         )
-         train_loader = DataLoader(
-             train_dataset, batch_size=config.batch_size, shuffle=True,
-             num_workers=2, pin_memory=config.pin_memory, drop_last=True,
-         )
-
-     # Create Model
-     model_kwargs = get_model_config(config.model_size, config.num_classes, config.class_drop_prob)
-     model = LiquidGen(**model_kwargs).to(device)
-     print(f"LiquidGen-{config.model_size}: {model.count_params() / 1e6:.1f}M params")
-
      if config.compile_model and hasattr(torch, "compile"):
          model = torch.compile(model)
-
-     optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate,
-                                   weight_decay=config.weight_decay, betas=(0.9, 0.999))
-     scheduler = get_cosine_schedule_with_warmup(optimizer, config.warmup_steps, config.max_steps)
-     ema = EMAModel(model, decay=config.ema_decay)
-     scaler = GradScaler('cuda', enabled=config.mixed_precision)
-     fm = FlowMatchingScheduler(min_t=config.min_timestep, max_t=config.max_timestep)
-
-     print(f"\nTraining for {config.max_steps} steps")
-     print(f"Effective batch size: {config.batch_size * config.gradient_accumulation_steps}")
-
-     # Step-based training loop (works for both streaming and map datasets)
-     global_step = 0
-     loss_accum = 0.0
-     accum_count = 0
-     model.train()
      t_start = time.time()
-
-     while global_step < config.max_steps:
-         for images, labels in train_loader:
-             if global_step >= config.max_steps:
-                 break
-
-             images = images.to(device)
-             labels = labels.to(device) if config.num_classes > 0 else None
-
-             # Encode to latents
-             with torch.no_grad():
-                 latents = encode_images_with_vae(
-                     images.to(vae_dtype), vae, config.vae_scaling_factor, config.vae_shift_factor
-                 ).float()
-
-             # Flow matching
-             t = fm.sample_timesteps(latents.shape[0], device)
-             noise = torch.randn_like(latents)
-             x_t = fm.add_noise(latents, noise, t)
-             v_target = fm.get_velocity_target(latents, noise)
-
-             with autocast('cuda', enabled=config.mixed_precision):
-                 v_pred = model(x_t, t, labels)
-                 loss = F.mse_loss(v_pred, v_target) / config.gradient_accumulation_steps
-
              scaler.scale(loss).backward()
-             loss_accum += loss.item()
-             accum_count += 1
-
-             if accum_count % config.gradient_accumulation_steps == 0:
-                 scaler.unscale_(optimizer)
-                 grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
-                 scaler.step(optimizer)
-                 scaler.update()
-                 optimizer.zero_grad()
-                 scheduler.step()
-                 ema.update(model)
-                 global_step += 1
-
-                 # Logging
-                 if global_step % config.log_every_n_steps == 0:
-                     avg_loss = loss_accum / config.log_every_n_steps
-                     lr = optimizer.param_groups[0]["lr"]
-                     elapsed = time.time() - t_start
-                     steps_per_sec = global_step / max(elapsed, 1)
-                     print(f"step={global_step} | loss={avg_loss:.4f} | "
-                           f"grad_norm={grad_norm:.2f} | lr={lr:.2e} | "
-                           f"steps/s={steps_per_sec:.2f} | elapsed={elapsed:.0f}s")
-                     loss_accum = 0.0
-
-                     if math.isnan(avg_loss) or avg_loss > 100:
-                         print("⚠️ Training diverged!")
-                         return
-
-                 # Sample
-                 if global_step % config.sample_every_n_steps == 0:
-                     ema.apply(model)
-                     model.eval()
-                     latent_size = config.image_size // 8
-                     sample_labels = None
-                     if config.num_classes > 0:
-                         sample_labels = torch.randint(0, config.num_classes, (config.num_samples,), device=device)
-                     latent_ch = vae.config.latent_channels  # 4 for SDXL, 16 for Flux
-                     sampled = fm.sample(model, (config.num_samples, latent_ch, latent_size, latent_size),
-                                         device, config.num_sample_steps, sample_labels, config.cfg_scale)
-                     sample_imgs = decode_latents_with_vae(sampled.to(vae_dtype), vae,
-                                                           config.vae_scaling_factor, config.vae_shift_factor).float()
                      from torchvision.utils import save_image
-                     save_image(sample_imgs, os.path.join(config.output_dir, "samples", f"step_{global_step:07d}.png"), nrow=2)
-                     print(f"  📸 Saved samples: step_{global_step:07d}.png")
-                     ema.restore(model)
-                     model.train()
-
-                 # Checkpoint
-                 if global_step % config.save_every_n_steps == 0:
-                     ckpt_path = os.path.join(config.output_dir, "checkpoints", f"step_{global_step:07d}.pt")
-                     torch.save({
-                         "model": model.state_dict(), "ema": ema.state_dict(),
-                         "optimizer": optimizer.state_dict(), "scheduler": scheduler.state_dict(),
-                         "global_step": global_step, "config": asdict(config),
-                     }, ckpt_path)
-                     print(f"  💾 Checkpoint: {ckpt_path}")
-
-     # Final save
-     final_path = os.path.join(config.output_dir, "checkpoints", "final.pt")
-     torch.save({"model": model.state_dict(), "ema": ema.state_dict(),
-                 "config": asdict(config), "global_step": global_step}, final_path)
-     elapsed = time.time() - t_start
-     print(f"\n🎉 Training complete! {global_step} steps in {elapsed/60:.1f} min")
-     print(f"  Final model: {final_path}")


  if __name__ == "__main__":
      config = TrainConfig(
-         model_size="small", image_size=256, batch_size=4,
-         max_steps=100, use_streaming=True,
      )
      train(config)

  """
+ LiquidGen Training Pipeline v2
+
+ Optimized for Colab free tier:
+ - Latent pre-caching: encode images with VAE once, save to disk, train on pure tensors
+ - No VAE needed during training loop → saves ~1GB VRAM + faster iterations
+ - Streaming support for large datasets
+ - Multiple small dataset presets

  Flow Matching training objective (velocity prediction):
+ - Forward: x_t = (1 - t) * x_0 + t * ε
+ - Target: v = ε - x_0
  - Loss: MSE(model(x_t, t), v)
  """

  import torch
  import torch.nn as nn
  import torch.nn.functional as F
+ from torch.utils.data import DataLoader, Dataset
  from torch.amp import autocast, GradScaler
  import math
  import os
  import json
  import time
+ from typing import Optional
+ from dataclasses import dataclass, asdict
+
+
+ # =============================================================================
+ # Dataset Presets (all verified, fast to download)
+ # =============================================================================
+
+ DATASET_PRESETS = {
+     "paintings_mini": {
+         "name": "keremberke/painting-style-classification",
+         "config": "mini",
+         "image_column": "image",
+         "label_column": "labels",
+         "num_classes": 27,
+         "description": "~200 painting samples, 27 styles, 1.7MB — instant smoke test",
+     },
+     "paintings": {
+         "name": "keremberke/painting-style-classification",
+         "config": "full",
+         "image_column": "image",
+         "label_column": "labels",
+         "num_classes": 27,
+         "description": "~8K paintings, 27 styles, 204MB — best for style-conditional training",
+     },
+     "cartoon": {
+         "name": "Norod78/cartoon-blip-captions",
+         "config": "",
+         "image_column": "image",
+         "label_column": "",
+         "num_classes": 0,
+         "description": "~2.5K cartoon/anime, unconditional, 181MB",
+     },
+     "flowers": {
+         "name": "huggan/flowers-102-categories",
+         "config": "",
+         "image_column": "image",
+         "label_column": "",
+         "num_classes": 0,
+         "description": "~8K flower photos, unconditional, 331MB",
+     },
+     "wikiart_stream": {
+         "name": "huggan/wikiart",
+         "config": "",
+         "image_column": "image",
+         "label_column": "style",
+         "num_classes": 27,
+         "streaming": True,
+         "description": "~80K paintings, 27 styles, STREAMING (0 disk) — use max_images to limit",
+     },
+ }


  @dataclass
  class TrainConfig:
+     """Training configuration optimized for Colab free tier (T4 16GB)."""
      # Model
+     model_size: str = "small"  # small (~55M), base (~140M), large (~280M)
+     num_classes: int = 27
      class_drop_prob: float = 0.1
+
      # Data
+     dataset_preset: str = "paintings"  # key from DATASET_PRESETS
+     image_size: int = 256  # 256 or 512
+     max_images: int = 0  # 0 = use all, >0 = limit (for streaming/testing)
+
+     # VAE (for pre-caching only — NOT loaded during training)
+     vae_id: str = "black-forest-labs/FLUX.1-schnell"
+     vae_subfolder: str = "vae"
+     vae_scaling_factor: float = 0.3611
+     vae_shift_factor: float = 0.1159
+
      # Training
+     batch_size: int = 32  # Can be large since training on cached tensors!
+     gradient_accumulation_steps: int = 1
      learning_rate: float = 1e-4
      weight_decay: float = 0.01
+     max_grad_norm: float = 2.0  # Critical for stability (ZigMa paper)
+     num_epochs: int = 100
+     warmup_steps: int = 500
      ema_decay: float = 0.9999
      mixed_precision: bool = True
+
      # Flow matching
      min_timestep: float = 0.001
      max_timestep: float = 0.999
+
      # Saving
      output_dir: str = "./outputs"
+     save_every_n_steps: int = 2000
+     sample_every_n_steps: int = 500
+     log_every_n_steps: int = 25
+
      # Sampling
      num_sample_steps: int = 50
+     cfg_scale: float = 2.0
      num_samples: int = 4
+
      # System
      seed: int = 42
+     num_workers: int = 2
      compile_model: bool = False
+
      # Hub
      push_to_hub: bool = False
      hub_model_id: str = ""


+ def get_model_config(size, num_classes=0, class_drop_prob=0.1):
      configs = {
          "small": dict(embed_dim=512, depth=12, spatial_kernel=7, scan_kernel=31,
                        expand_ratio=2.0, mlp_ratio=3.0),


  # =============================================================================
+ # Latent Pre-Caching (the key optimization for Colab)
  # =============================================================================

+ class CachedLatentDataset(Dataset):
+     """Training dataset from pre-encoded VAE latents on disk."""
+
+     def __init__(self, cache_path):
+         data = torch.load(cache_path, map_location="cpu", weights_only=True)
+         self.latents = data["latents"]
+         self.labels = data.get("labels", None)
+         print(f"Loaded {len(self.latents)} cached latents from {cache_path}")
+         print(f"  Shape: {self.latents.shape}, dtype: {self.latents.dtype}")
+         if self.labels is not None:
+             print(f"  Labels: unique={self.labels.unique().shape[0]}")
+
      def __len__(self):
+         return len(self.latents)
+
      def __getitem__(self, idx):
+         lat = self.latents[idx]
+         label = self.labels[idx] if self.labels is not None else -1
+         return lat, label
+
+
+ def precache_latents(config, cache_path=None):
+     """
+     Encode all images to VAE latents once, save to disk.
+
+     After caching:
+     - VAE unloaded → frees ~1GB VRAM
+     - Training loads pure tensors → much faster iterations
+     - Larger batch sizes possible (no VAE memory overhead)
+
+     Returns path to cache file.
+     """
+     if cache_path is None:
+         cache_path = os.path.join(config.output_dir, "cached_latents.pt")
+
+     if os.path.exists(cache_path):
+         print(f"✅ Cache exists: {cache_path}")
+         data = torch.load(cache_path, map_location="cpu", weights_only=True)
+         print(f"  {data['latents'].shape[0]} latents, shape {data['latents'].shape[1:]}")
+         return cache_path
+
+     os.makedirs(os.path.dirname(cache_path) if os.path.dirname(cache_path) else ".", exist_ok=True)
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     # Load VAE temporarily
+     print("Loading VAE for encoding...")
+     from diffusers import AutoencoderKL
+     vae = AutoencoderKL.from_pretrained(
+         config.vae_id, subfolder=config.vae_subfolder, torch_dtype=torch.float16
+     ).to(device).eval()
+     for p in vae.parameters():
+         p.requires_grad_(False)
+
+     # Load dataset
+     preset = DATASET_PRESETS[config.dataset_preset]
+     print(f"Loading dataset: {preset['name']} ({preset['description']})")
+
+     from datasets import load_dataset
+     from torchvision import transforms
+
+     is_streaming = preset.get("streaming", False)
+     ds_kwargs = {"split": "train"}
+     if preset["config"]:
+         ds_kwargs["name"] = preset["config"]
+     if is_streaming:
+         ds_kwargs["streaming"] = True
+
+     dataset = load_dataset(preset["name"], **ds_kwargs)
+
+     transform = transforms.Compose([
+         transforms.Resize(config.image_size, interpolation=transforms.InterpolationMode.LANCZOS),
+         transforms.CenterCrop(config.image_size),
+         transforms.ToTensor(),
+     ])
+
+     all_latents = []
+     all_labels = []
+     batch_pixels = []
+     batch_labels = []
+     encode_bs = 16
+     count = 0
+     max_imgs = config.max_images if config.max_images > 0 else float("inf")
+     img_col = preset["image_column"]
+     lbl_col = preset["label_column"]
+
+     print("Encoding images to latents...")
+     t0 = time.time()
+
+     for item in dataset:
+         if count >= max_imgs:
+             break
+         img = item[img_col]
          if img.mode != "RGB":
              img = img.convert("RGB")
+         batch_pixels.append(transform(img))
+         if lbl_col and lbl_col in item:
+             batch_labels.append(item[lbl_col])
+         else:
+             batch_labels.append(-1)
+         count += 1
+
+         if len(batch_pixels) >= encode_bs:
+             with torch.no_grad():
+                 px = torch.stack(batch_pixels).to(device, dtype=torch.float16) * 2 - 1
+                 lat = vae.encode(px).latent_dist.sample()
+                 lat = (lat - config.vae_shift_factor) * config.vae_scaling_factor
+                 all_latents.append(lat.cpu().float())
+             all_labels.extend(batch_labels)
+             batch_pixels, batch_labels = [], []
+             if count % 500 == 0:
+                 print(f"  {count} images encoded ({time.time()-t0:.0f}s)")
+
+     if batch_pixels:
+         with torch.no_grad():
+             px = torch.stack(batch_pixels).to(device, dtype=torch.float16) * 2 - 1
+             lat = vae.encode(px).latent_dist.sample()
+             lat = (lat - config.vae_shift_factor) * config.vae_scaling_factor
+             all_latents.append(lat.cpu().float())
+         all_labels.extend(batch_labels)
+
+     all_latents = torch.cat(all_latents, dim=0)
+     all_labels = torch.tensor(all_labels, dtype=torch.long)
+     torch.save({"latents": all_latents, "labels": all_labels}, cache_path)
+
+     elapsed = time.time() - t0
+     mb = os.path.getsize(cache_path) / 1024**2
+     print(f"\n✅ Cached {count} latents → {cache_path}")
+     print(f"  Shape: {all_latents.shape}, Size: {mb:.1f}MB, Time: {elapsed:.0f}s")
+
+     del vae
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+     print("  VAE unloaded, VRAM freed\n")
+     return cache_path


  # =============================================================================
+ # EMA, FlowMatching, Scheduler
  # =============================================================================

  class EMAModel:
+     def __init__(self, model, decay=0.9999):
          self.decay = decay
+         self.shadow = {n: p.clone().detach() for n, p in model.named_parameters() if p.requires_grad}
+
      @torch.no_grad()
+     def update(self, model):
+         for n, p in model.named_parameters():
+             if p.requires_grad and n in self.shadow:
+                 self.shadow[n].mul_(self.decay).add_(p.data, alpha=1 - self.decay)
+
+     def apply(self, model):
+         self.backup = {n: p.data.clone() for n, p in model.named_parameters() if p.requires_grad}
+         for n, p in model.named_parameters():
+             if p.requires_grad and n in self.shadow:
+                 p.data.copy_(self.shadow[n])
+
+     def restore(self, model):
+         for n, p in model.named_parameters():
+             if p.requires_grad and n in self.backup:
+                 p.data.copy_(self.backup[n])
          self.backup = {}


  class FlowMatchingScheduler:
      def __init__(self, min_t=0.001, max_t=0.999):
          self.min_t, self.max_t = min_t, max_t
+
+     def sample_timesteps(self, bs, dev):
+         return torch.rand(bs, device=dev) * (self.max_t - self.min_t) + self.min_t
+
      def add_noise(self, x0, noise, t):
+         t = t.view(-1, 1, 1, 1)
+         return (1 - t) * x0 + t * noise
+
      def get_velocity_target(self, x0, noise):
          return noise - x0
+
      @torch.no_grad()
+     def sample(self, model, shape, dev, num_steps=50, labels=None, cfg=1.0):
+         model.eval()
+         x = torch.randn(shape, device=dev)
          dt = 1.0 / num_steps
+         for tv in torch.linspace(1.0, dt, num_steps, device=dev):
+             t = torch.full((shape[0],), tv.item(), device=dev)
+             with torch.amp.autocast("cuda"):
+                 if cfg > 1.0 and labels is not None:
+                     vc = model(x, t, labels)
+                     vu = model(x, t, torch.zeros_like(labels))
+                     v = vu + cfg * (vc - vu)
+                 else:
+                     v = model(x, t, labels)
+             x = x - dt * v.float()
          return x


+ def cosine_schedule(opt, warmup, total):
+     def lr(s):
+         if s < warmup:
+             return s / max(1, warmup)
+         return max(0, 0.5 * (1 + math.cos(math.pi * (s - warmup) / max(1, total - warmup))))
+     return torch.optim.lr_scheduler.LambdaLR(opt, lr)


  # =============================================================================
  # Main Training Loop
  # =============================================================================

+ def train(config):
      from model import LiquidGen
+
      torch.manual_seed(config.seed)
      device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
      print(f"Device: {device}")
+     if torch.cuda.is_available():
+         print(f"GPU: {torch.cuda.get_device_name(0)} "
+               f"({torch.cuda.get_device_properties(0).total_memory/1024**3:.1f} GB)")
+
      os.makedirs(config.output_dir, exist_ok=True)
+     os.makedirs(f"{config.output_dir}/samples", exist_ok=True)
+     os.makedirs(f"{config.output_dir}/checkpoints", exist_ok=True)
+
+     with open(f"{config.output_dir}/config.json", "w") as f:
          json.dump(asdict(config), f, indent=2)
+
+     # Step 1: Pre-cache latents
+     cache_path = precache_latents(config)
+
+     # Step 2: Dataset from cache
+     train_ds = CachedLatentDataset(cache_path)
+     train_dl = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True,
+                           num_workers=config.num_workers, pin_memory=True, drop_last=True)
+
+     # Step 3: Model
+     mcfg = get_model_config(config.model_size, config.num_classes, config.class_drop_prob)
+     model = LiquidGen(**mcfg).to(device)
+     print(f"LiquidGen-{config.model_size}: {model.count_params()/1e6:.1f}M params")
+
      if config.compile_model and hasattr(torch, "compile"):
          model = torch.compile(model)
+
+     # Step 4: Training setup
+     opt = torch.optim.AdamW(model.parameters(), lr=config.learning_rate,
+                             weight_decay=config.weight_decay, betas=(0.9, 0.999))
+     total_steps = len(train_dl) * config.num_epochs // config.gradient_accumulation_steps
+     sched = cosine_schedule(opt, config.warmup_steps, total_steps)
+     ema = EMAModel(model, config.ema_decay)
+     scaler = GradScaler("cuda", enabled=config.mixed_precision and torch.cuda.is_available())
+     fm = FlowMatchingScheduler(config.min_timestep, config.max_timestep)
+     lat_size = config.image_size // 8
+
+     print(f"\nTotal steps: {total_steps}, Batch: {config.batch_size}×{config.gradient_accumulation_steps}")
+     print("No VAE during training → max VRAM for model")
+     if torch.cuda.is_available():
+         print(f"VRAM: {torch.cuda.memory_allocated()/1024**3:.1f} / "
+               f"{torch.cuda.get_device_properties(0).total_memory/1024**3:.1f} GB")
+
+     # Step 5: Train!
+     gs = 0
+     la = 0.0
+     vae = None
+     vae_loaded = False
+     print(f"\n{'='*60}\n🚀 Training!\n{'='*60}\n")
      t_start = time.time()
+
+     for epoch in range(config.num_epochs):
+         model.train()
+         et = time.time()
+         for bi, (lats, lbls) in enumerate(train_dl):
+             lats = lats.to(device)
+             lbls = lbls.to(device) if config.num_classes > 0 else None
+
+             t = fm.sample_timesteps(lats.shape[0], device)
+             noise = torch.randn_like(lats)
+             xt = fm.add_noise(lats, noise, t)
+             vtgt = fm.get_velocity_target(lats, noise)
+
+             with autocast("cuda", enabled=config.mixed_precision and torch.cuda.is_available()):
+                 vp = model(xt, t, lbls)
+                 loss = F.mse_loss(vp, vtgt) / config.gradient_accumulation_steps
+
              scaler.scale(loss).backward()
+             la += loss.item()
+
+             if (bi + 1) % config.gradient_accumulation_steps == 0:
+                 scaler.unscale_(opt)
+                 gn = torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
+                 scaler.step(opt)
+                 scaler.update()
+                 opt.zero_grad()
+                 sched.step()
+                 ema.update(model)
+                 gs += 1
+
+                 if gs % config.log_every_n_steps == 0:
+                     al = la / config.log_every_n_steps
+                     lr = opt.param_groups[0]["lr"]
+                     vram = torch.cuda.memory_allocated()/1024**3 if torch.cuda.is_available() else 0
+                     sps = gs / max(time.time() - t_start, 1)
+                     print(f"step={gs:>6d} | ep={epoch} | loss={al:.4f} | gn={gn:.2f} | "
+                           f"lr={lr:.2e} | vram={vram:.1f}G | {sps:.1f} st/s")
+                     la = 0.0
+                     if math.isnan(al) or al > 50:
+                         print("💥 Diverged!")
+                         return
+
+                 if gs % config.sample_every_n_steps == 0:
+                     if not vae_loaded:
+                         from diffusers import AutoencoderKL
+                         vae = AutoencoderKL.from_pretrained(
+                             config.vae_id, subfolder=config.vae_subfolder,
+                             torch_dtype=torch.float16).to(device).eval()
+                         for p in vae.parameters():
+                             p.requires_grad_(False)
+                         vae_loaded = True
+                     ema.apply(model)
+                     model.eval()
+                     sl = torch.randint(0, max(1, config.num_classes), (config.num_samples,),
+                                        device=device) if config.num_classes > 0 else None
+                     samp = fm.sample(model, (config.num_samples, 16, lat_size, lat_size),
+                                      device, config.num_sample_steps, sl, config.cfg_scale)
+                     with torch.no_grad():
+                         dec = samp.half() / config.vae_scaling_factor + config.vae_shift_factor
+                         imgs = ((vae.decode(dec).sample + 1) / 2).clamp(0, 1).float()
                      from torchvision.utils import save_image
+                     sp = f"{config.output_dir}/samples/step_{gs:07d}.png"
+                     save_image(imgs, sp, nrow=2)
+                     print(f"  📸 {sp}")
+                     ema.restore(model)
+                     model.train()
+
+                 if gs % config.save_every_n_steps == 0:
+                     cp = f"{config.output_dir}/checkpoints/step_{gs:07d}.pt"
+                     torch.save({"model": model.state_dict(), "ema": ema.shadow,
+                                 "optimizer": opt.state_dict(), "scheduler": sched.state_dict(),
+                                 "step": gs, "epoch": epoch, "model_config": mcfg}, cp)
+                     print(f"  💾 {cp}")
+
+         print(f"Epoch {epoch} | {time.time()-et:.0f}s\n")
+
+     final = f"{config.output_dir}/checkpoints/final.pt"
+     torch.save({"model": model.state_dict(), "ema": ema.shadow,
+                 "model_config": mcfg, "step": gs}, final)
+     print(f"\n🎉 Done! {gs} steps, {(time.time()-t_start)/60:.1f}min → {final}")


  if __name__ == "__main__":
      config = TrainConfig(
+         model_size="small", dataset_preset="paintings_mini",
+         image_size=256, batch_size=8, num_epochs=5,
+         log_every_n_steps=5, sample_every_n_steps=99999,
      )
      train(config)
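
For a quick smoke test of the v2 pipeline after this commit, something like the following should work (a sketch, assuming the file is importable as the train module, model.py exposes LiquidGen, and the paintings_mini preset is still reachable on the Hub):

from train import TrainConfig, precache_latents, train

config = TrainConfig(
    model_size="small",
    dataset_preset="paintings_mini",  # ~200 images, so caching takes seconds
    image_size=256,
    batch_size=8,
    num_epochs=5,
)
# precache_latents(config) runs automatically inside train(); reruns reuse
# outputs/cached_latents.pt instead of re-encoding the dataset.
train(config)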