asdf98 committed on
Commit
57090a0
·
verified ·
1 Parent(s): 193fbf7

Add gradient checkpointing + auto batch size to prevent OOM on T4

Browse files
Files changed (1) hide show
  1. train.py +106 -187
train.py CHANGED
@@ -4,13 +4,10 @@ LiquidGen Training Pipeline v2
4
  Optimized for Colab free tier:
5
  - Latent pre-caching: encode images with VAE once, save to disk, train on pure tensors
6
  - No VAE needed during training loop -> saves ~1GB VRAM + faster iterations
 
 
7
  - All datasets are pure parquet — no legacy loading scripts
8
  - Uses madebyollin/sdxl-vae-fp16-fix (fully open, no login, fp16 stable)
9
-
10
- Flow Matching training objective (velocity prediction):
11
- - Forward: x_t = (1 - t) * x_0 + t * eps
12
- - Target: v = eps - x_0
13
- - Loss: MSE(model(x_t, t), v)
14
  """
15
 
16
  import torch
@@ -26,10 +23,6 @@ from typing import Optional
26
  from dataclasses import dataclass, asdict
27
 
28
 
29
- # =============================================================================
30
- # Dataset Presets — ALL pure parquet, no loading scripts, no auth
31
- # =============================================================================
32
-
33
  DATASET_PRESETS = {
34
  "cartoon": {
35
  "name": "Norod78/cartoon-blip-captions",
@@ -52,7 +45,7 @@ DATASET_PRESETS = {
52
  "config": "",
53
  "image_column": "image",
54
  "label_column": "style",
55
- "num_classes": 0, # string labels, mapped to ints automatically
56
  "description": "~105K paintings with style labels, 1.6GB (use max_images to limit)",
57
  },
58
  "art_painting": {
@@ -66,26 +59,47 @@ DATASET_PRESETS = {
66
  }
67
 
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  @dataclass
70
  class TrainConfig:
71
- """Training configuration optimized for Colab free tier (T4 16GB)."""
72
- # Model
73
- model_size: str = "small" # small (~55M), base (~140M), large (~280M)
74
- num_classes: int = 0 # 0 = unconditional
75
  class_drop_prob: float = 0.1
76
-
77
- # Data
78
- dataset_preset: str = "cartoon" # key from DATASET_PRESETS
79
- image_size: int = 256 # 256 or 512
80
- max_images: int = 0 # 0 = use all, >0 = limit
81
-
82
- # VAE — fully open, no login needed
83
  vae_id: str = "madebyollin/sdxl-vae-fp16-fix"
84
  vae_scaling_factor: float = 0.13025
85
  latent_channels: int = 4
86
-
87
- # Training
88
- batch_size: int = 32
89
  gradient_accumulation_steps: int = 1
90
  learning_rate: float = 1e-4
91
  weight_decay: float = 0.01
@@ -94,28 +108,19 @@ class TrainConfig:
94
  warmup_steps: int = 500
95
  ema_decay: float = 0.9999
96
  mixed_precision: bool = True
97
-
98
- # Flow matching
99
  min_timestep: float = 0.001
100
  max_timestep: float = 0.999
101
-
102
- # Saving
103
  output_dir: str = "./outputs"
104
  save_every_n_steps: int = 2000
105
  sample_every_n_steps: int = 500
106
  log_every_n_steps: int = 25
107
-
108
- # Sampling
109
  num_sample_steps: int = 50
110
  cfg_scale: float = 2.0
111
  num_samples: int = 4
112
-
113
- # System
114
  seed: int = 42
115
  num_workers: int = 2
116
  compile_model: bool = False
117
-
118
- # Hub
119
  push_to_hub: bool = False
120
  hub_model_id: str = ""
121
 
@@ -136,38 +141,25 @@ def get_model_config(size, num_classes=0, class_drop_prob=0.1):
136
  return cfg
137
 
138
 
139
- # =============================================================================
140
- # Latent Pre-Caching
141
- # =============================================================================
142
-
143
  class CachedLatentDataset(Dataset):
144
- """Training dataset from pre-encoded VAE latents on disk."""
145
-
146
  def __init__(self, cache_path):
147
  data = torch.load(cache_path, map_location="cpu", weights_only=True)
148
  self.latents = data["latents"]
149
  self.labels = data.get("labels", None)
150
  print(f"Loaded {len(self.latents)} cached latents from {cache_path}")
151
- print(f" Shape: {self.latents.shape}, dtype: {self.latents.dtype}")
152
  if self.labels is not None and (self.labels >= 0).any():
153
- print(f" Labels: unique={self.labels[self.labels >= 0].unique().shape[0]}")
154
 
155
- def __len__(self):
156
- return len(self.latents)
157
 
158
  def __getitem__(self, idx):
159
- lat = self.latents[idx]
160
- label = self.labels[idx] if self.labels is not None else -1
161
- return lat, label
162
 
163
 
164
  def precache_latents(config, cache_path=None):
165
- """
166
- Encode all images to VAE latents once, save to disk.
167
- """
168
  if cache_path is None:
169
  cache_path = os.path.join(config.output_dir, "cached_latents.pt")
170
-
171
  if os.path.exists(cache_path):
172
  print(f"Cache exists: {cache_path}")
173
  data = torch.load(cache_path, map_location="cpu", weights_only=True)
@@ -177,167 +169,115 @@ def precache_latents(config, cache_path=None):
177
  os.makedirs(os.path.dirname(cache_path) if os.path.dirname(cache_path) else ".", exist_ok=True)
178
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
179
 
180
- # Load VAE
181
  print(f"Loading VAE: {config.vae_id}...")
182
  from diffusers import AutoencoderKL
183
- vae = AutoencoderKL.from_pretrained(
184
- config.vae_id, torch_dtype=torch.float16
185
- ).to(device).eval()
186
- for p in vae.parameters():
187
- p.requires_grad_(False)
188
  print(f" VAE: {sum(p.numel() for p in vae.parameters())/1e6:.0f}M params")
189
 
190
- # Load dataset
191
  preset = DATASET_PRESETS[config.dataset_preset]
192
  print(f"Loading: {preset['name']} ({preset['description']})")
193
-
194
  from datasets import load_dataset
195
  from torchvision import transforms
196
 
197
  ds_kwargs = {"split": "train"}
198
- if preset["config"]:
199
- ds_kwargs["name"] = preset["config"]
200
-
201
  dataset = load_dataset(preset["name"], **ds_kwargs)
202
 
203
  transform = transforms.Compose([
204
  transforms.Resize(config.image_size, interpolation=transforms.InterpolationMode.LANCZOS),
205
- transforms.CenterCrop(config.image_size),
206
- transforms.ToTensor(),
207
  ])
208
 
209
- # For Artificio/WikiArt: style is a string, map to int
210
- img_col = preset["image_column"]
211
- lbl_col = preset["label_column"]
212
  style_to_id = {}
213
-
214
- all_latents = []
215
- all_labels = []
216
- batch_pixels = []
217
- batch_labels = []
218
- encode_bs = 16
219
- count = 0
220
- max_imgs = config.max_images if config.max_images > 0 else float("inf")
221
-
222
- print(f"Encoding to VAE latents...")
223
  t0 = time.time()
224
 
225
  for item in dataset:
226
- if count >= max_imgs:
227
- break
228
  img = item[img_col]
229
- if img.mode != "RGB":
230
- img = img.convert("RGB")
231
- batch_pixels.append(transform(img))
232
-
233
- # Handle labels: int or string
234
  if lbl_col and lbl_col in item:
235
- raw_label = item[lbl_col]
236
- if isinstance(raw_label, str):
237
- if raw_label not in style_to_id:
238
- style_to_id[raw_label] = len(style_to_id)
239
- batch_labels.append(style_to_id[raw_label])
240
- elif isinstance(raw_label, int):
241
- batch_labels.append(raw_label)
242
- else:
243
- batch_labels.append(-1)
244
- else:
245
- batch_labels.append(-1)
246
  count += 1
247
-
248
- if len(batch_pixels) >= encode_bs:
249
  with torch.no_grad():
250
- px = torch.stack(batch_pixels).to(device, dtype=torch.float16) * 2 - 1
251
- lat = vae.encode(px).latent_dist.sample()
252
- lat = lat * config.vae_scaling_factor
253
  all_latents.append(lat.cpu().float())
254
- all_labels.extend(batch_labels)
255
- batch_pixels, batch_labels = [], []
256
- if count % 500 == 0:
257
- print(f" {count} images ({time.time()-t0:.0f}s)")
258
 
259
- if batch_pixels:
260
  with torch.no_grad():
261
- px = torch.stack(batch_pixels).to(device, dtype=torch.float16) * 2 - 1
262
- lat = vae.encode(px).latent_dist.sample()
263
- lat = lat * config.vae_scaling_factor
264
  all_latents.append(lat.cpu().float())
265
- all_labels.extend(batch_labels)
266
 
267
  all_latents = torch.cat(all_latents, dim=0)
268
  all_labels = torch.tensor(all_labels, dtype=torch.long)
269
-
270
  save_data = {"latents": all_latents, "labels": all_labels}
271
  if style_to_id:
272
  save_data["style_to_id"] = style_to_id
273
- print(f" Mapped {len(style_to_id)} style labels to class IDs")
274
  torch.save(save_data, cache_path)
275
-
276
- elapsed = time.time() - t0
277
  mb = os.path.getsize(cache_path) / 1024**2
278
- print(f"\nCached {count} latents -> {cache_path}")
279
- print(f" Shape: {all_latents.shape}, {mb:.1f}MB, {elapsed:.0f}s")
280
-
281
  del vae
282
- if torch.cuda.is_available():
283
- torch.cuda.empty_cache()
284
  print(" VAE unloaded\n")
285
  return cache_path
286
 
287
 
288
- # =============================================================================
289
- # EMA, FlowMatching, Scheduler
290
- # =============================================================================
291
-
292
  class EMAModel:
293
  def __init__(self, model, decay=0.9999):
294
  self.decay = decay
295
  self.shadow = {n: p.clone().detach() for n, p in model.named_parameters() if p.requires_grad}
296
-
297
  @torch.no_grad()
298
  def update(self, model):
299
  for n, p in model.named_parameters():
300
  if p.requires_grad and n in self.shadow:
301
  self.shadow[n].mul_(self.decay).add_(p.data, alpha=1 - self.decay)
302
-
303
  def apply(self, model):
304
  self.backup = {n: p.data.clone() for n, p in model.named_parameters() if p.requires_grad}
305
  for n, p in model.named_parameters():
306
- if p.requires_grad and n in self.shadow:
307
- p.data.copy_(self.shadow[n])
308
-
309
  def restore(self, model):
310
  for n, p in model.named_parameters():
311
- if p.requires_grad and n in self.backup:
312
- p.data.copy_(self.backup[n])
313
  self.backup = {}
314
 
315
 
316
  class FlowMatchingScheduler:
317
  def __init__(self, min_t=0.001, max_t=0.999):
318
  self.min_t, self.max_t = min_t, max_t
319
-
320
  def sample_timesteps(self, bs, dev):
321
  return torch.rand(bs, device=dev) * (self.max_t - self.min_t) + self.min_t
322
-
323
  def add_noise(self, x0, noise, t):
324
  t = t.view(-1, 1, 1, 1); return (1 - t) * x0 + t * noise
325
-
326
  def get_velocity_target(self, x0, noise):
327
  return noise - x0
328
-
329
  @torch.no_grad()
330
  def sample(self, model, shape, dev, num_steps=50, labels=None, cfg=1.0):
331
- model.eval(); x = torch.randn(shape, device=dev)
332
- dt = 1.0 / num_steps
333
  for tv in torch.linspace(1.0, dt, num_steps, device=dev):
334
  t = torch.full((shape[0],), tv.item(), device=dev)
335
  with torch.amp.autocast("cuda"):
336
  if cfg > 1.0 and labels is not None:
337
  vc = model(x, t, labels); vu = model(x, t, torch.zeros_like(labels))
338
  v = vu + cfg * (vc - vu)
339
- else:
340
- v = model(x, t, labels)
341
  x = x - dt * v.float()
342
  return x
343
 
@@ -349,29 +289,30 @@ def cosine_schedule(opt, warmup, total):
349
  return torch.optim.lr_scheduler.LambdaLR(opt, lr)
350
 
351
 
352
- # =============================================================================
353
- # Main Training Loop
354
- # =============================================================================
355
-
356
  def train(config):
357
  from model import LiquidGen
358
-
359
  torch.manual_seed(config.seed)
360
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
361
- print(f"Device: {device}")
362
  if torch.cuda.is_available():
363
- print(f"GPU: {torch.cuda.get_device_name(0)} "
364
- f"({torch.cuda.get_device_properties(0).total_mem/1024**3:.1f} GB)")
 
 
 
 
 
 
 
 
365
 
366
  os.makedirs(config.output_dir, exist_ok=True)
367
  os.makedirs(f"{config.output_dir}/samples", exist_ok=True)
368
  os.makedirs(f"{config.output_dir}/checkpoints", exist_ok=True)
369
-
370
  with open(f"{config.output_dir}/config.json", "w") as f:
371
  json.dump(asdict(config), f, indent=2)
372
 
373
  cache_path = precache_latents(config)
374
-
375
  train_ds = CachedLatentDataset(cache_path)
376
  train_dl = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True,
377
  num_workers=config.num_workers, pin_memory=True, drop_last=True)
@@ -379,6 +320,12 @@ def train(config):
379
  mcfg = get_model_config(config.model_size, config.num_classes, config.class_drop_prob)
380
  mcfg["in_channels"] = config.latent_channels
381
  model = LiquidGen(**mcfg).to(device)
 
 
 
 
 
 
382
  print(f"LiquidGen-{config.model_size}: {model.count_params()/1e6:.1f}M params")
383
 
384
  if config.compile_model and hasattr(torch, "compile"):
@@ -393,11 +340,10 @@ def train(config):
393
  fm = FlowMatchingScheduler(config.min_timestep, config.max_timestep)
394
  lat_size = config.image_size // 8
395
 
396
- print(f"\nSteps: {total_steps}, Batch: {config.batch_size}x{config.gradient_accumulation_steps}")
397
  print(f"Latent: [{config.batch_size}, {config.latent_channels}, {lat_size}, {lat_size}]")
398
  if torch.cuda.is_available():
399
- print(f"VRAM: {torch.cuda.memory_allocated()/1024**3:.1f} / "
400
- f"{torch.cuda.get_device_properties(0).total_mem/1024**3:.1f} GB")
401
 
402
  gs = 0; la = 0.0; vae = None; vae_loaded = False
403
  print(f"\n{'='*60}\nTraining!\n{'='*60}\n")
@@ -408,76 +354,49 @@ def train(config):
408
  for bi, (lats, lbls) in enumerate(train_dl):
409
  lats = lats.to(device)
410
  lbls = lbls.to(device) if config.num_classes > 0 else None
411
-
412
  t = fm.sample_timesteps(lats.shape[0], device)
413
  noise = torch.randn_like(lats)
414
  xt = fm.add_noise(lats, noise, t)
415
  vtgt = fm.get_velocity_target(lats, noise)
416
-
417
  with autocast("cuda", enabled=config.mixed_precision and torch.cuda.is_available()):
418
  vp = model(xt, t, lbls)
419
  loss = F.mse_loss(vp, vtgt) / config.gradient_accumulation_steps
420
-
421
  scaler.scale(loss).backward()
422
  la += loss.item()
423
-
424
  if (bi + 1) % config.gradient_accumulation_steps == 0:
425
  scaler.unscale_(opt)
426
  gn = torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
427
  scaler.step(opt); scaler.update(); opt.zero_grad(); sched.step()
428
  ema.update(model); gs += 1
429
-
430
  if gs % config.log_every_n_steps == 0:
431
  al = la / config.log_every_n_steps
432
- lr = opt.param_groups[0]["lr"]
433
  vram = torch.cuda.memory_allocated()/1024**3 if torch.cuda.is_available() else 0
434
  sps = gs / max(time.time() - t_start, 1)
435
  print(f"step={gs:>6d} | ep={epoch} | loss={al:.4f} | gn={gn:.2f} | "
436
- f"lr={lr:.2e} | vram={vram:.1f}G | {sps:.1f} st/s")
437
  la = 0.0
438
- if math.isnan(al) or al > 50:
439
- print("Diverged!"); return
440
-
441
  if gs % config.sample_every_n_steps == 0:
442
  if not vae_loaded:
443
  from diffusers import AutoencoderKL
444
- vae = AutoencoderKL.from_pretrained(
445
- config.vae_id, torch_dtype=torch.float16
446
- ).to(device).eval()
447
  for p in vae.parameters(): p.requires_grad_(False)
448
  vae_loaded = True
449
  ema.apply(model); model.eval()
450
- sl = torch.randint(0, max(1, config.num_classes), (config.num_samples,),
451
- device=device) if config.num_classes > 0 else None
452
  samp = fm.sample(model, (config.num_samples, config.latent_channels, lat_size, lat_size),
453
  device, config.num_sample_steps, sl, config.cfg_scale)
454
  with torch.no_grad():
455
- dec = samp.half() / config.vae_scaling_factor
456
- imgs = ((vae.decode(dec).sample + 1) / 2).clamp(0, 1).float()
457
  from torchvision.utils import save_image
458
- sp = f"{config.output_dir}/samples/step_{gs:07d}.png"
459
- save_image(imgs, sp, nrow=2); print(f" Saved: {sp}")
460
- ema.restore(model); model.train()
461
-
462
  if gs % config.save_every_n_steps == 0:
463
- cp = f"{config.output_dir}/checkpoints/step_{gs:07d}.pt"
464
  torch.save({"model": model.state_dict(), "ema": ema.shadow,
465
- "optimizer": opt.state_dict(), "scheduler": sched.state_dict(),
466
- "step": gs, "epoch": epoch, "model_config": mcfg}, cp)
467
- print(f" Saved: {cp}")
468
-
469
  print(f"Epoch {epoch} | {time.time()-et:.0f}s\n")
470
 
471
  final = f"{config.output_dir}/checkpoints/final.pt"
472
- torch.save({"model": model.state_dict(), "ema": ema.shadow,
473
- "model_config": mcfg, "step": gs}, final)
474
  print(f"\nDone! {gs} steps, {(time.time()-t_start)/60:.1f}min -> {final}")
475
-
476
-
477
- if __name__ == "__main__":
478
- config = TrainConfig(
479
- model_size="small", dataset_preset="cartoon",
480
- image_size=256, batch_size=8, num_epochs=5,
481
- log_every_n_steps=5, sample_every_n_steps=99999,
482
- )
483
- train(config)
 
4
  Optimized for Colab free tier:
5
  - Latent pre-caching: encode images with VAE once, save to disk, train on pure tensors
6
  - No VAE needed during training loop -> saves ~1GB VRAM + faster iterations
7
+ - Gradient checkpointing enabled by default (saves ~50% activation VRAM)
8
+ - Auto batch size selection based on model size + image size + GPU VRAM
9
  - All datasets are pure parquet — no legacy loading scripts
10
  - Uses madebyollin/sdxl-vae-fp16-fix (fully open, no login, fp16 stable)
 
 
 
 
 
11
  """
12
 
13
  import torch
 
23
  from dataclasses import dataclass, asdict
24
 
25
 
 
 
 
 
26
  DATASET_PRESETS = {
27
  "cartoon": {
28
  "name": "Norod78/cartoon-blip-captions",
 
45
  "config": "",
46
  "image_column": "image",
47
  "label_column": "style",
48
+ "num_classes": 0,
49
  "description": "~105K paintings with style labels, 1.6GB (use max_images to limit)",
50
  },
51
  "art_painting": {
 
59
  }
60
 
61
 
62
def auto_batch_size(model_size, image_size, gpu_mem_gb):
    """Pick a conservative training batch size for a model, resolution and GPU.

    The estimate subtracts a fixed per-model cost (weights + grads + Adam
    states) and a 1.5 GB safety margin from the GPU memory, then divides the
    remainder by a per-sample activation cost looked up by model size and
    image resolution. The per-sample figures are hand-tuned estimates that
    assume gradient checkpointing is enabled.

    Args:
        model_size: "small", "base" or "large". Any other value falls back to
            a 1.0 GB fixed cost and 0.1 GB per sample.
        image_size: 256 or 512. Any other value falls back to 0.1 GB per
            sample.
        gpu_mem_gb: Total GPU memory in GB.

    Returns:
        A power of two between 1 and 32 (inclusive). 1 is returned when the
        estimate leaves no room at all.
    """
    # Fixed memory per model (weights + grads + optimizer) in GB
    fixed_cost_gb = {"small": 0.66, "base": 1.68, "large": 3.35}
    # Activation memory per sample at this resolution (GB, with grad checkpointing)
    # 256px: lat=32x32, patch=16x16 | 512px: lat=64x64, patch=32x32
    activation_gb = {
        "small": {256: 0.02, 512: 0.07},
        "base": {256: 0.03, 512: 0.13},
        "large": {256: 0.05, 512: 0.21},
    }
    # Headroom for PyTorch overhead, CUDA kernels, VAE loading
    headroom_gb = 1.5
    max_batch = 32

    base = fixed_cost_gb.get(model_size, 1.0)
    per_sample = activation_gb.get(model_size, {}).get(image_size, 0.1)

    available = gpu_mem_gb - base - headroom_gb
    bs = min(max(1, int(available / per_sample)), max_batch)

    # Round down to the nearest power of two. The previous if/elif chain
    # stopped at 4, so an estimate of 3 was returned unrounded.
    return 1 << (bs.bit_length() - 1)
89
+
90
+
91
  @dataclass
92
  class TrainConfig:
93
+ model_size: str = "small"
94
+ num_classes: int = 0
 
 
95
  class_drop_prob: float = 0.1
96
+ dataset_preset: str = "cartoon"
97
+ image_size: int = 256
98
+ max_images: int = 0
 
 
 
 
99
  vae_id: str = "madebyollin/sdxl-vae-fp16-fix"
100
  vae_scaling_factor: float = 0.13025
101
  latent_channels: int = 4
102
+ batch_size: int = 0 # 0 = auto-detect based on GPU
 
 
103
  gradient_accumulation_steps: int = 1
104
  learning_rate: float = 1e-4
105
  weight_decay: float = 0.01
 
108
  warmup_steps: int = 500
109
  ema_decay: float = 0.9999
110
  mixed_precision: bool = True
111
+ gradient_checkpointing: bool = True # Enabled by default!
 
112
  min_timestep: float = 0.001
113
  max_timestep: float = 0.999
 
 
114
  output_dir: str = "./outputs"
115
  save_every_n_steps: int = 2000
116
  sample_every_n_steps: int = 500
117
  log_every_n_steps: int = 25
 
 
118
  num_sample_steps: int = 50
119
  cfg_scale: float = 2.0
120
  num_samples: int = 4
 
 
121
  seed: int = 42
122
  num_workers: int = 2
123
  compile_model: bool = False
 
 
124
  push_to_hub: bool = False
125
  hub_model_id: str = ""
126
 
 
141
  return cfg
142
 
143
 
 
 
 
 
144
  class CachedLatentDataset(Dataset):
 
 
145
  def __init__(self, cache_path):
146
  data = torch.load(cache_path, map_location="cpu", weights_only=True)
147
  self.latents = data["latents"]
148
  self.labels = data.get("labels", None)
149
  print(f"Loaded {len(self.latents)} cached latents from {cache_path}")
150
+ print(f" Shape: {self.latents.shape}")
151
  if self.labels is not None and (self.labels >= 0).any():
152
+ print(f" Labels: {self.labels[self.labels >= 0].unique().shape[0]} classes")
153
 
154
+ def __len__(self): return len(self.latents)
 
155
 
156
  def __getitem__(self, idx):
157
+ return self.latents[idx], (self.labels[idx] if self.labels is not None else -1)
 
 
158
 
159
 
160
  def precache_latents(config, cache_path=None):
 
 
 
161
  if cache_path is None:
162
  cache_path = os.path.join(config.output_dir, "cached_latents.pt")
 
163
  if os.path.exists(cache_path):
164
  print(f"Cache exists: {cache_path}")
165
  data = torch.load(cache_path, map_location="cpu", weights_only=True)
 
169
  os.makedirs(os.path.dirname(cache_path) if os.path.dirname(cache_path) else ".", exist_ok=True)
170
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
171
 
 
172
  print(f"Loading VAE: {config.vae_id}...")
173
  from diffusers import AutoencoderKL
174
+ vae = AutoencoderKL.from_pretrained(config.vae_id, torch_dtype=torch.float16).to(device).eval()
175
+ for p in vae.parameters(): p.requires_grad_(False)
 
 
 
176
  print(f" VAE: {sum(p.numel() for p in vae.parameters())/1e6:.0f}M params")
177
 
 
178
  preset = DATASET_PRESETS[config.dataset_preset]
179
  print(f"Loading: {preset['name']} ({preset['description']})")
 
180
  from datasets import load_dataset
181
  from torchvision import transforms
182
 
183
  ds_kwargs = {"split": "train"}
184
+ if preset["config"]: ds_kwargs["name"] = preset["config"]
 
 
185
  dataset = load_dataset(preset["name"], **ds_kwargs)
186
 
187
  transform = transforms.Compose([
188
  transforms.Resize(config.image_size, interpolation=transforms.InterpolationMode.LANCZOS),
189
+ transforms.CenterCrop(config.image_size), transforms.ToTensor(),
 
190
  ])
191
 
192
+ img_col, lbl_col = preset["image_column"], preset["label_column"]
 
 
193
  style_to_id = {}
194
+ all_latents, all_labels = [], []
195
+ batch_px, batch_lb = [], []
196
+ count, max_imgs = 0, config.max_images if config.max_images > 0 else float("inf")
 
 
 
 
 
 
 
197
  t0 = time.time()
198
 
199
  for item in dataset:
200
+ if count >= max_imgs: break
 
201
  img = item[img_col]
202
+ if img.mode != "RGB": img = img.convert("RGB")
203
+ batch_px.append(transform(img))
 
 
 
204
  if lbl_col and lbl_col in item:
205
+ raw = item[lbl_col]
206
+ if isinstance(raw, str):
207
+ if raw not in style_to_id: style_to_id[raw] = len(style_to_id)
208
+ batch_lb.append(style_to_id[raw])
209
+ elif isinstance(raw, int): batch_lb.append(raw)
210
+ else: batch_lb.append(-1)
211
+ else: batch_lb.append(-1)
 
 
 
 
212
  count += 1
213
+ if len(batch_px) >= 16:
 
214
  with torch.no_grad():
215
+ px = torch.stack(batch_px).to(device, dtype=torch.float16) * 2 - 1
216
+ lat = vae.encode(px).latent_dist.sample() * config.vae_scaling_factor
 
217
  all_latents.append(lat.cpu().float())
218
+ all_labels.extend(batch_lb); batch_px, batch_lb = [], []
219
+ if count % 500 == 0: print(f" {count} images ({time.time()-t0:.0f}s)")
 
 
220
 
221
+ if batch_px:
222
  with torch.no_grad():
223
+ px = torch.stack(batch_px).to(device, dtype=torch.float16) * 2 - 1
224
+ lat = vae.encode(px).latent_dist.sample() * config.vae_scaling_factor
 
225
  all_latents.append(lat.cpu().float())
226
+ all_labels.extend(batch_lb)
227
 
228
  all_latents = torch.cat(all_latents, dim=0)
229
  all_labels = torch.tensor(all_labels, dtype=torch.long)
 
230
  save_data = {"latents": all_latents, "labels": all_labels}
231
  if style_to_id:
232
  save_data["style_to_id"] = style_to_id
233
+ print(f" {len(style_to_id)} style classes mapped")
234
  torch.save(save_data, cache_path)
 
 
235
  mb = os.path.getsize(cache_path) / 1024**2
236
+ print(f"\nCached {count} latents -> {cache_path} ({all_latents.shape}, {mb:.0f}MB, {time.time()-t0:.0f}s)")
 
 
237
  del vae
238
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
 
239
  print(" VAE unloaded\n")
240
  return cache_path
241
 
242
 
 
 
 
 
243
  class EMAModel:
244
  def __init__(self, model, decay=0.9999):
245
  self.decay = decay
246
  self.shadow = {n: p.clone().detach() for n, p in model.named_parameters() if p.requires_grad}
 
247
  @torch.no_grad()
248
  def update(self, model):
249
  for n, p in model.named_parameters():
250
  if p.requires_grad and n in self.shadow:
251
  self.shadow[n].mul_(self.decay).add_(p.data, alpha=1 - self.decay)
 
252
  def apply(self, model):
253
  self.backup = {n: p.data.clone() for n, p in model.named_parameters() if p.requires_grad}
254
  for n, p in model.named_parameters():
255
+ if p.requires_grad and n in self.shadow: p.data.copy_(self.shadow[n])
 
 
256
  def restore(self, model):
257
  for n, p in model.named_parameters():
258
+ if p.requires_grad and n in self.backup: p.data.copy_(self.backup[n])
 
259
  self.backup = {}
260
 
261
 
262
  class FlowMatchingScheduler:
263
  def __init__(self, min_t=0.001, max_t=0.999):
264
  self.min_t, self.max_t = min_t, max_t
 
265
  def sample_timesteps(self, bs, dev):
266
  return torch.rand(bs, device=dev) * (self.max_t - self.min_t) + self.min_t
 
267
  def add_noise(self, x0, noise, t):
268
  t = t.view(-1, 1, 1, 1); return (1 - t) * x0 + t * noise
 
269
  def get_velocity_target(self, x0, noise):
270
  return noise - x0
 
271
  @torch.no_grad()
272
  def sample(self, model, shape, dev, num_steps=50, labels=None, cfg=1.0):
273
+ model.eval(); x = torch.randn(shape, device=dev); dt = 1.0 / num_steps
 
274
  for tv in torch.linspace(1.0, dt, num_steps, device=dev):
275
  t = torch.full((shape[0],), tv.item(), device=dev)
276
  with torch.amp.autocast("cuda"):
277
  if cfg > 1.0 and labels is not None:
278
  vc = model(x, t, labels); vu = model(x, t, torch.zeros_like(labels))
279
  v = vu + cfg * (vc - vu)
280
+ else: v = model(x, t, labels)
 
281
  x = x - dt * v.float()
282
  return x
283
 
 
289
  return torch.optim.lr_scheduler.LambdaLR(opt, lr)
290
 
291
 
 
 
 
 
292
  def train(config):
293
  from model import LiquidGen
 
294
  torch.manual_seed(config.seed)
295
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
296
+ gpu_mem = 0
297
  if torch.cuda.is_available():
298
+ gpu_mem = torch.cuda.get_device_properties(0).total_mem / 1024**3
299
+ print(f"GPU: {torch.cuda.get_device_name(0)} ({gpu_mem:.1f} GB)")
300
+
301
+ # Auto batch size if not set
302
+ if config.batch_size <= 0:
303
+ if gpu_mem > 0:
304
+ config.batch_size = auto_batch_size(config.model_size, config.image_size, gpu_mem)
305
+ print(f"Auto batch size: {config.batch_size} (for {config.model_size} at {config.image_size}px on {gpu_mem:.0f}GB)")
306
+ else:
307
+ config.batch_size = 4
308
 
309
  os.makedirs(config.output_dir, exist_ok=True)
310
  os.makedirs(f"{config.output_dir}/samples", exist_ok=True)
311
  os.makedirs(f"{config.output_dir}/checkpoints", exist_ok=True)
 
312
  with open(f"{config.output_dir}/config.json", "w") as f:
313
  json.dump(asdict(config), f, indent=2)
314
 
315
  cache_path = precache_latents(config)
 
316
  train_ds = CachedLatentDataset(cache_path)
317
  train_dl = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True,
318
  num_workers=config.num_workers, pin_memory=True, drop_last=True)
 
320
  mcfg = get_model_config(config.model_size, config.num_classes, config.class_drop_prob)
321
  mcfg["in_channels"] = config.latent_channels
322
  model = LiquidGen(**mcfg).to(device)
323
+
324
+ # Enable gradient checkpointing (saves ~50% activation VRAM)
325
+ if config.gradient_checkpointing:
326
+ model.enable_gradient_checkpointing()
327
+ print(f"Gradient checkpointing: ON")
328
+
329
  print(f"LiquidGen-{config.model_size}: {model.count_params()/1e6:.1f}M params")
330
 
331
  if config.compile_model and hasattr(torch, "compile"):
 
340
  fm = FlowMatchingScheduler(config.min_timestep, config.max_timestep)
341
  lat_size = config.image_size // 8
342
 
343
+ print(f"Steps: {total_steps}, Batch: {config.batch_size}x{config.gradient_accumulation_steps}")
344
  print(f"Latent: [{config.batch_size}, {config.latent_channels}, {lat_size}, {lat_size}]")
345
  if torch.cuda.is_available():
346
+ print(f"VRAM: {torch.cuda.memory_allocated()/1024**3:.1f} / {gpu_mem:.1f} GB")
 
347
 
348
  gs = 0; la = 0.0; vae = None; vae_loaded = False
349
  print(f"\n{'='*60}\nTraining!\n{'='*60}\n")
 
354
  for bi, (lats, lbls) in enumerate(train_dl):
355
  lats = lats.to(device)
356
  lbls = lbls.to(device) if config.num_classes > 0 else None
 
357
  t = fm.sample_timesteps(lats.shape[0], device)
358
  noise = torch.randn_like(lats)
359
  xt = fm.add_noise(lats, noise, t)
360
  vtgt = fm.get_velocity_target(lats, noise)
 
361
  with autocast("cuda", enabled=config.mixed_precision and torch.cuda.is_available()):
362
  vp = model(xt, t, lbls)
363
  loss = F.mse_loss(vp, vtgt) / config.gradient_accumulation_steps
 
364
  scaler.scale(loss).backward()
365
  la += loss.item()
 
366
  if (bi + 1) % config.gradient_accumulation_steps == 0:
367
  scaler.unscale_(opt)
368
  gn = torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
369
  scaler.step(opt); scaler.update(); opt.zero_grad(); sched.step()
370
  ema.update(model); gs += 1
 
371
  if gs % config.log_every_n_steps == 0:
372
  al = la / config.log_every_n_steps
 
373
  vram = torch.cuda.memory_allocated()/1024**3 if torch.cuda.is_available() else 0
374
  sps = gs / max(time.time() - t_start, 1)
375
  print(f"step={gs:>6d} | ep={epoch} | loss={al:.4f} | gn={gn:.2f} | "
376
+ f"lr={opt.param_groups[0]['lr']:.2e} | vram={vram:.1f}G | {sps:.1f} st/s")
377
  la = 0.0
378
+ if math.isnan(al) or al > 50: print("Diverged!"); return
 
 
379
  if gs % config.sample_every_n_steps == 0:
380
  if not vae_loaded:
381
  from diffusers import AutoencoderKL
382
+ vae = AutoencoderKL.from_pretrained(config.vae_id, torch_dtype=torch.float16).to(device).eval()
 
 
383
  for p in vae.parameters(): p.requires_grad_(False)
384
  vae_loaded = True
385
  ema.apply(model); model.eval()
386
+ sl = torch.randint(0, max(1, config.num_classes), (config.num_samples,), device=device) if config.num_classes > 0 else None
 
387
  samp = fm.sample(model, (config.num_samples, config.latent_channels, lat_size, lat_size),
388
  device, config.num_sample_steps, sl, config.cfg_scale)
389
  with torch.no_grad():
390
+ imgs = ((vae.decode(samp.half() / config.vae_scaling_factor).sample + 1) / 2).clamp(0, 1).float()
 
391
  from torchvision.utils import save_image
392
+ save_image(imgs, f"{config.output_dir}/samples/step_{gs:07d}.png", nrow=2)
393
+ print(f" Saved samples"); ema.restore(model); model.train()
 
 
394
  if gs % config.save_every_n_steps == 0:
 
395
  torch.save({"model": model.state_dict(), "ema": ema.shadow,
396
+ "optimizer": opt.state_dict(), "step": gs, "model_config": mcfg},
397
+ f"{config.output_dir}/checkpoints/step_{gs:07d}.pt")
 
 
398
  print(f"Epoch {epoch} | {time.time()-et:.0f}s\n")
399
 
400
  final = f"{config.output_dir}/checkpoints/final.pt"
401
+ torch.save({"model": model.state_dict(), "ema": ema.shadow, "model_config": mcfg, "step": gs}, final)
 
402
  print(f"\nDone! {gs} steps, {(time.time()-t_start)/60:.1f}min -> {final}")