asdf98 committed on
Commit 551424e Β· verified Β· 1 Parent(s): 3063cf6

Fix VAE: switch to madebyollin/sdxl-vae-fp16-fix (open, no auth needed, 4ch latent)
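For reference, a minimal sketch (not part of this commit) of loading the new VAE through diffusers; the printed config values should match the constants hard-coded in train.py below (4 latent channels, scaling_factor 0.13025):

```python
import torch
from diffusers import AutoencoderKL

# madebyollin/sdxl-vae-fp16-fix is a public repo: no HF token or license gate needed.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
)
print(vae.config.latent_channels)   # expected: 4 (model in_channels must match)
print(vae.config.scaling_factor)    # expected: 0.13025 (latent normalization)
```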

Files changed (1): train.py (+30, -28)
train.py CHANGED
@@ -6,6 +6,7 @@ Optimized for Colab free tier:
 - No VAE needed during training loop β†’ saves ~1GB VRAM + faster iterations
 - Streaming support for large datasets
 - Multiple small dataset presets
+- Uses madebyollin/sdxl-vae-fp16-fix (fully open, no login, fp16 stable)
 
 Flow Matching training objective (velocity prediction):
 - Forward: x_t = (1 - t) * x_0 + t * Ξ΅
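In code, the velocity-prediction objective from this docstring looks roughly like the sketch below: with x_t = (1 - t) * x_0 + t * Ξ΅, the regression target is dx_t/dt = Ξ΅ - x_0. Names and the model call signature are illustrative, not the actual FlowMatchingScheduler API in this repo:

```python
import torch
import torch.nn.functional as F

def flow_matching_loss(model, x0, labels=None):
    """Illustrative velocity-prediction loss for x_t = (1 - t) * x_0 + t * Ξ΅."""
    eps = torch.randn_like(x0)                     # Ξ΅ ~ N(0, I)
    t = torch.rand(x0.shape[0], device=x0.device)  # t ~ U(0, 1)
    tb = t.view(-1, 1, 1, 1)
    x_t = (1 - tb) * x0 + tb * eps                 # forward interpolation
    v_target = eps - x0                            # dx_t/dt, the velocity
    v_pred = model(x_t, t, labels)                 # hypothetical signature
    return F.mse_loss(v_pred, v_target)
```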
@@ -27,7 +28,7 @@ from dataclasses import dataclass, asdict
 
 
 # =============================================================================
-# Dataset Presets (all verified, fast to download)
+# Dataset Presets (all verified, fast to download, no auth needed)
 # =============================================================================
 
 DATASET_PRESETS = {
@@ -88,11 +89,12 @@ class TrainConfig:
     image_size: int = 256          # 256 or 512
     max_images: int = 0            # 0 = use all, >0 = limit (for streaming/testing)
 
-    # VAE (for pre-caching only β€” NOT loaded during training)
-    vae_id: str = "black-forest-labs/FLUX.1-schnell"
-    vae_subfolder: str = "vae"
-    vae_scaling_factor: float = 0.3611
-    vae_shift_factor: float = 0.1159
+    # VAE β€” fully open, no login needed
+    # madebyollin/sdxl-vae-fp16-fix: SDXL VAE with fp16 NaN fix
+    # 4 latent channels, 8x spatial compression, scaling_factor=0.13025
+    vae_id: str = "madebyollin/sdxl-vae-fp16-fix"
+    vae_scaling_factor: float = 0.13025
+    latent_channels: int = 4
 
     # Training
     batch_size: int = 32           # Can be large since training on cached tensors!
@@ -147,7 +149,7 @@ def get_model_config(size, num_classes=0, class_drop_prob=0.1):
 
 
 # =============================================================================
-# Latent Pre-Caching (the key optimization for Colab)
+# Latent Pre-Caching
 # =============================================================================
 
 class CachedLatentDataset(Dataset):
@@ -174,13 +176,7 @@ class CachedLatentDataset(Dataset):
 def precache_latents(config, cache_path=None):
     """
     Encode all images to VAE latents once, save to disk.
-
-    After caching:
-    - VAE unloaded β†’ frees ~1GB VRAM
-    - Training loads pure tensors β†’ much faster iterations
-    - Larger batch sizes possible (no VAE memory overhead)
-
-    Returns path to cache file.
+    Uses madebyollin/sdxl-vae-fp16-fix (no auth needed).
     """
     if cache_path is None:
         cache_path = os.path.join(config.output_dir, "cached_latents.pt")
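For context, training later reads these cached latents back with a plain tensor Dataset, roughly like the sketch below; the dict keys ("latents", "labels") are assumptions, and the actual CachedLatentDataset in train.py may store the cache differently:

```python
import torch
from torch.utils.data import Dataset

class CachedLatentDatasetSketch(Dataset):
    """Serve pre-encoded VAE latents from cached_latents.pt (illustrative)."""

    def __init__(self, cache_path="cached_latents.pt"):
        cache = torch.load(cache_path, map_location="cpu")
        self.latents = cache["latents"]   # assumed: FloatTensor [N, 4, H/8, W/8]
        self.labels = cache["labels"]     # assumed: LongTensor [N]

    def __len__(self):
        return self.latents.shape[0]

    def __getitem__(self, idx):
        # No VAE in the training loop: batches are already in latent space.
        return self.latents[idx], self.labels[idx]
```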
@@ -194,14 +190,15 @@ def precache_latents(config, cache_path=None):
     os.makedirs(os.path.dirname(cache_path) if os.path.dirname(cache_path) else ".", exist_ok=True)
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    # Load VAE temporarily
-    print("Loading VAE for encoding...")
+    # Load VAE β€” no subfolder, no auth needed
+    print(f"Loading VAE: {config.vae_id} (open, no login needed)...")
     from diffusers import AutoencoderKL
     vae = AutoencoderKL.from_pretrained(
-        config.vae_id, subfolder=config.vae_subfolder, torch_dtype=torch.float16
+        config.vae_id, torch_dtype=torch.float16
     ).to(device).eval()
     for p in vae.parameters():
         p.requires_grad_(False)
+    print(f"  VAE loaded: {sum(p.numel() for p in vae.parameters())/1e6:.0f}M params")
 
     # Load dataset
     preset = DATASET_PRESETS[config.dataset_preset]
@@ -235,7 +232,7 @@ def precache_latents(config, cache_path=None):
     img_col = preset["image_column"]
     lbl_col = preset["label_column"]
 
-    print(f"Encoding images to latents...")
+    print(f"Encoding images to VAE latents...")
     t0 = time.time()
 
     for item in dataset:
@@ -255,7 +252,7 @@ def precache_latents(config, cache_path=None):
         with torch.no_grad():
             px = torch.stack(batch_pixels).to(device, dtype=torch.float16) * 2 - 1
             lat = vae.encode(px).latent_dist.sample()
-            lat = (lat - config.vae_shift_factor) * config.vae_scaling_factor
+            lat = lat * config.vae_scaling_factor  # SDXL: scale only, no shift
         all_latents.append(lat.cpu().float())
         all_labels.extend(batch_labels)
         batch_pixels, batch_labels = [], []
@@ -266,7 +263,7 @@ def precache_latents(config, cache_path=None):
         with torch.no_grad():
             px = torch.stack(batch_pixels).to(device, dtype=torch.float16) * 2 - 1
             lat = vae.encode(px).latent_dist.sample()
-            lat = (lat - config.vae_shift_factor) * config.vae_scaling_factor
+            lat = lat * config.vae_scaling_factor
         all_latents.append(lat.cpu().float())
         all_labels.extend(batch_labels)
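The normalization change above is the core of the fix: FLUX VAE latents were normalized as (lat - shift) * scale, while the SDXL VAE uses a scale factor only. A round-trip sketch with the value from TrainConfig (illustrative, not part of the commit):

```python
SDXL_SCALE = 0.13025  # TrainConfig.vae_scaling_factor

def encode_to_model_space(vae, px):
    # px in [-1, 1], fp16, on the VAE's device
    lat = vae.encode(px).latent_dist.sample()
    return lat * SDXL_SCALE              # scale only, no shift term

def decode_to_pixel_space(vae, lat):
    return vae.decode(lat / SDXL_SCALE).sample   # undo the scale before decoding
```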
 
@@ -379,8 +376,9 @@ def train(config):
     train_dl = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True,
                           num_workers=config.num_workers, pin_memory=True, drop_last=True)
 
-    # Step 3: Model
+    # Step 3: Model (in_channels=4 for SDXL VAE)
     mcfg = get_model_config(config.model_size, config.num_classes, config.class_drop_prob)
+    mcfg["in_channels"] = config.latent_channels  # 4 for SDXL VAE
     model = LiquidGen(**mcfg).to(device)
     print(f"LiquidGen-{config.model_size}: {model.count_params()/1e6:.1f}M params")
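Shape arithmetic behind this change, for the default 256-px config (illustrative): the SDXL VAE keeps the 8x spatial compression but emits 4 latent channels instead of FLUX's 16, so the model's in_channels must follow config.latent_channels:

```python
image_size = 256
lat_size = image_size // 8                   # 32: 8x spatial compression
flux_latents = (32, 16, lat_size, lat_size)  # old: 16-channel FLUX latents
sdxl_latents = (32, 4, lat_size, lat_size)   # new: 4-channel SDXL latents
# hence mcfg["in_channels"] = config.latent_channels (4) in the hunk above
```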
 
@@ -397,8 +395,9 @@ def train(config):
     fm = FlowMatchingScheduler(config.min_timestep, config.max_timestep)
     lat_size = config.image_size // 8
 
-    print(f"\nTotal steps: {total_steps}, Batch: {config.batch_size}Γ—{config.gradient_accumulation_steps}")
-    print(f"No VAE during training β†’ max VRAM for model")
+    print(f"\nTotal steps: {total_steps}, Batch: {config.batch_size}x{config.gradient_accumulation_steps}")
+    print(f"Latent: [{config.batch_size}, {config.latent_channels}, {lat_size}, {lat_size}]")
+    print(f"No VAE during training -> max VRAM for model")
     if torch.cuda.is_available():
         print(f"VRAM: {torch.cuda.memory_allocated()/1024**3:.1f} / "
               f"{torch.cuda.get_device_properties(0).total_memory/1024**3:.1f} GB")
@@ -444,20 +443,23 @@ def train(config):
             print("πŸ’₯ Diverged!"); return
 
         if gs % config.sample_every_n_steps == 0:
+            # Load VAE lazily (only for decoding samples)
             if not vae_loaded:
                 from diffusers import AutoencoderKL
                 vae = AutoencoderKL.from_pretrained(
-                    config.vae_id, subfolder=config.vae_subfolder,
-                    torch_dtype=torch.float16).to(device).eval()
+                    config.vae_id, torch_dtype=torch.float16
+                ).to(device).eval()
                 for p in vae.parameters(): p.requires_grad_(False)
                 vae_loaded = True
             ema.apply(model); model.eval()
             sl = torch.randint(0, max(1, config.num_classes), (config.num_samples,),
                                device=device) if config.num_classes > 0 else None
-            samp = fm.sample(model, (config.num_samples, 16, lat_size, lat_size),
+            # 4 channels for SDXL VAE
+            samp = fm.sample(model, (config.num_samples, config.latent_channels, lat_size, lat_size),
                              device, config.num_sample_steps, sl, config.cfg_scale)
             with torch.no_grad():
-                dec = samp.half() / config.vae_scaling_factor + config.vae_shift_factor
+                # SDXL VAE: unscale only, no shift
+                dec = samp.half() / config.vae_scaling_factor
                 imgs = ((vae.decode(dec).sample + 1) / 2).clamp(0, 1).float()
                 from torchvision.utils import save_image
                 sp = f"{config.output_dir}/samples/step_{gs:07d}.png"
@@ -476,7 +478,7 @@ def train(config):
     final = f"{config.output_dir}/checkpoints/final.pt"
     torch.save({"model": model.state_dict(), "ema": ema.shadow,
                 "model_config": mcfg, "step": gs}, final)
-    print(f"\nπŸŽ‰ Done! {gs} steps, {(time.time()-t_start)/60:.1f}min β†’ {final}")
+    print(f"\nπŸŽ‰ Done! {gs} steps, {(time.time()-t_start)/60:.1f}min -> {final}")
 
 
 if __name__ == "__main__":
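For downstream use, reloading the final checkpoint written above could look roughly like this sketch; it assumes ema.shadow is a state_dict-style mapping and that LiquidGen is importable from this repo (neither is shown in the diff):

```python
import torch
from train import LiquidGen  # assumed import path within this repo

ckpt = torch.load("checkpoints/final.pt", map_location="cpu")  # under config.output_dir
model = LiquidGen(**ckpt["model_config"])
model.load_state_dict(ckpt["ema"])   # EMA weights; ckpt["model"] holds the raw weights
model.eval()
print("trained steps:", ckpt["step"])
```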
 