asdf98 committed
Commit 2403335 · verified · 1 Parent(s): 1aa5ac1

Fix: add trust_remote_code=True for datasets with legacy loading scripts
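Context for the fix: recent releases of the Hugging Face `datasets` library refuse to execute a dataset repository's custom (legacy) loading script unless the caller explicitly opts in, so `load_dataset` fails for script-based datasets without the flag. A minimal sketch of the call this commit enables (the dataset id below is hypothetical, not one of this repo's presets):

    from datasets import load_dataset

    # Opt in to running the dataset repo's own loading script.
    # Without this, recent `datasets` versions raise an error (or prompt
    # interactively) instead of executing the script.
    ds = load_dataset("some-user/legacy-script-dataset",  # hypothetical id
                      trust_remote_code=True)

The commit threads the flag through the per-preset config rather than passing it unconditionally, so only the presets known to ship a loading script opt in.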

Files changed (1): train.py (+16 -16)
train.py CHANGED
@@ -38,6 +38,7 @@ DATASET_PRESETS = {
         "image_column": "image",
         "label_column": "labels",
         "num_classes": 27,
+        "trust_remote_code": True,
         "description": "~200 painting samples, 27 styles, 1.7MB — instant smoke test",
     },
     "paintings": {
@@ -46,6 +47,7 @@ DATASET_PRESETS = {
         "image_column": "image",
         "label_column": "labels",
         "num_classes": 27,
+        "trust_remote_code": True,
         "description": "~8K paintings, 27 styles, 204MB — best for style-conditional training",
     },
     "cartoon": {
@@ -90,18 +92,16 @@ class TrainConfig:
     max_images: int = 0  # 0 = use all, >0 = limit (for streaming/testing)
 
     # VAE — fully open, no login needed
-    # madebyollin/sdxl-vae-fp16-fix: SDXL VAE with fp16 NaN fix
-    # 4 latent channels, 8x spatial compression, scaling_factor=0.13025
     vae_id: str = "madebyollin/sdxl-vae-fp16-fix"
     vae_scaling_factor: float = 0.13025
     latent_channels: int = 4
 
     # Training
-    batch_size: int = 32  # Can be large since training on cached tensors!
+    batch_size: int = 32
     gradient_accumulation_steps: int = 1
     learning_rate: float = 1e-4
     weight_decay: float = 0.01
-    max_grad_norm: float = 2.0  # Critical for stability (ZigMa paper)
+    max_grad_norm: float = 2.0
     num_epochs: int = 100
     warmup_steps: int = 500
     ema_decay: float = 0.9999
@@ -190,7 +190,7 @@ def precache_latents(config, cache_path=None):
     os.makedirs(os.path.dirname(cache_path) if os.path.dirname(cache_path) else ".", exist_ok=True)
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    # Load VAE — no subfolder, no auth needed
+    # Load VAE
     print(f"Loading VAE: {config.vae_id} (open, no login needed)...")
     from diffusers import AutoencoderKL
     vae = AutoencoderKL.from_pretrained(
@@ -213,6 +213,9 @@ def precache_latents(config, cache_path=None):
         ds_kwargs["name"] = preset["config"]
     if is_streaming:
         ds_kwargs["streaming"] = True
+    # Some datasets have legacy loading scripts that need this flag
+    if preset.get("trust_remote_code", False):
+        ds_kwargs["trust_remote_code"] = True
 
     dataset = load_dataset(preset["name"], **ds_kwargs)
 
@@ -252,7 +255,7 @@ def precache_latents(config, cache_path=None):
         with torch.no_grad():
             px = torch.stack(batch_pixels).to(device, dtype=torch.float16) * 2 - 1
             lat = vae.encode(px).latent_dist.sample()
-            lat = lat * config.vae_scaling_factor  # SDXL: scale only, no shift
+            lat = lat * config.vae_scaling_factor
             all_latents.append(lat.cpu().float())
         all_labels.extend(batch_labels)
         batch_pixels, batch_labels = [], []
@@ -273,7 +276,7 @@ def precache_latents(config, cache_path=None):
 
     elapsed = time.time() - t0
     mb = os.path.getsize(cache_path) / 1024**2
-    print(f"\n✅ Cached {count} latents → {cache_path}")
+    print(f"\n✅ Cached {count} latents -> {cache_path}")
     print(f"   Shape: {all_latents.shape}, Size: {mb:.1f}MB, Time: {elapsed:.0f}s")
 
     del vae
@@ -376,9 +379,9 @@ def train(config):
     train_dl = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True,
                           num_workers=config.num_workers, pin_memory=True, drop_last=True)
 
-    # Step 3: Model (in_channels=4 for SDXL VAE)
+    # Step 3: Model
     mcfg = get_model_config(config.model_size, config.num_classes, config.class_drop_prob)
-    mcfg["in_channels"] = config.latent_channels  # 4 for SDXL VAE
+    mcfg["in_channels"] = config.latent_channels
     model = LiquidGen(**mcfg).to(device)
     print(f"LiquidGen-{config.model_size}: {model.count_params()/1e6:.1f}M params")
 
@@ -440,10 +443,9 @@ def train(config):
                       f"lr={lr:.2e} | vram={vram:.1f}G | {sps:.1f} st/s")
                 la = 0.0
             if math.isnan(al) or al > 50:
-                print("💥 Diverged!"); return
+                print("Diverged!"); return
 
             if gs % config.sample_every_n_steps == 0:
-                # Load VAE lazily (only for decoding samples)
                 if not vae_loaded:
                     from diffusers import AutoencoderKL
                     vae = AutoencoderKL.from_pretrained(
@@ -454,16 +456,14 @@ def train(config):
                 ema.apply(model); model.eval()
                 sl = torch.randint(0, max(1, config.num_classes), (config.num_samples,),
                                    device=device) if config.num_classes > 0 else None
-                # 4 channels for SDXL VAE
                 samp = fm.sample(model, (config.num_samples, config.latent_channels, lat_size, lat_size),
                                  device, config.num_sample_steps, sl, config.cfg_scale)
                 with torch.no_grad():
-                    # SDXL VAE: unscale only, no shift
                     dec = samp.half() / config.vae_scaling_factor
                     imgs = ((vae.decode(dec).sample + 1) / 2).clamp(0, 1).float()
                 from torchvision.utils import save_image
                 sp = f"{config.output_dir}/samples/step_{gs:07d}.png"
-                save_image(imgs, sp, nrow=2); print(f"  📸 {sp}")
+                save_image(imgs, sp, nrow=2); print(f"  Saved: {sp}")
                 ema.restore(model); model.train()
 
             if gs % config.save_every_n_steps == 0:
@@ -471,14 +471,14 @@ def train(config):
                 torch.save({"model": model.state_dict(), "ema": ema.shadow,
                             "optimizer": opt.state_dict(), "scheduler": sched.state_dict(),
                             "step": gs, "epoch": epoch, "model_config": mcfg}, cp)
-                print(f"  💾 {cp}")
+                print(f"  Saved: {cp}")
 
         print(f"Epoch {epoch} | {time.time()-et:.0f}s\n")
 
     final = f"{config.output_dir}/checkpoints/final.pt"
     torch.save({"model": model.state_dict(), "ema": ema.shadow,
                 "model_config": mcfg, "step": gs}, final)
-    print(f"\n🎉 Done! {gs} steps, {(time.time()-t_start)/60:.1f}min -> {final}")
+    print(f"\nDone! {gs} steps, {(time.time()-t_start)/60:.1f}min -> {final}")
 
 
 if __name__ == "__main__":