Fix VAE: switch to madebyollin/sdxl-vae-fp16-fix (open, no auth needed, 4ch latent)
Browse files
train.py
CHANGED
|
@@ -6,6 +6,7 @@ Optimized for Colab free tier:
|
|
| 6 |
- No VAE needed during training loop → saves ~1GB VRAM + faster iterations
|
| 7 |
- Streaming support for large datasets
|
| 8 |
- Multiple small dataset presets
|
|
|
|
| 9 |
|
| 10 |
Flow Matching training objective (velocity prediction):
|
| 11 |
- Forward: x_t = (1 - t) * x_0 + t * ε
|
|
@@ -27,7 +28,7 @@ from dataclasses import dataclass, asdict
|
|
| 27 |
|
| 28 |
|
| 29 |
# =============================================================================
|
| 30 |
-
# Dataset Presets (all verified, fast to download)
|
| 31 |
# =============================================================================
|
| 32 |
|
| 33 |
DATASET_PRESETS = {
|
|
@@ -88,11 +89,12 @@ class TrainConfig:
|
|
| 88 |
image_size: int = 256 # 256 or 512
|
| 89 |
max_images: int = 0 # 0 = use all, >0 = limit (for streaming/testing)
|
| 90 |
|
| 91 |
-
# VAE
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
|
|
|
| 96 |
|
| 97 |
# Training
|
| 98 |
batch_size: int = 32 # Can be large since training on cached tensors!
|
|
@@ -147,7 +149,7 @@ def get_model_config(size, num_classes=0, class_drop_prob=0.1):
|
|
| 147 |
|
| 148 |
|
| 149 |
# =============================================================================
|
| 150 |
-
# Latent Pre-Caching
|
| 151 |
# =============================================================================
|
| 152 |
|
| 153 |
class CachedLatentDataset(Dataset):
|
|
@@ -174,13 +176,7 @@ class CachedLatentDataset(Dataset):
|
|
| 174 |
def precache_latents(config, cache_path=None):
|
| 175 |
"""
|
| 176 |
Encode all images to VAE latents once, save to disk.
|
| 177 |
-
|
| 178 |
-
After caching:
|
| 179 |
-
- VAE unloaded → frees ~1GB VRAM
|
| 180 |
-
- Training loads pure tensors → much faster iterations
|
| 181 |
-
- Larger batch sizes possible (no VAE memory overhead)
|
| 182 |
-
|
| 183 |
-
Returns path to cache file.
|
| 184 |
"""
|
| 185 |
if cache_path is None:
|
| 186 |
cache_path = os.path.join(config.output_dir, "cached_latents.pt")
|
|
@@ -194,14 +190,15 @@ def precache_latents(config, cache_path=None):
|
|
| 194 |
os.makedirs(os.path.dirname(cache_path) if os.path.dirname(cache_path) else ".", exist_ok=True)
|
| 195 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 196 |
|
| 197 |
-
# Load VAE
|
| 198 |
-
print("Loading VAE
|
| 199 |
from diffusers import AutoencoderKL
|
| 200 |
vae = AutoencoderKL.from_pretrained(
|
| 201 |
-
config.vae_id,
|
| 202 |
).to(device).eval()
|
| 203 |
for p in vae.parameters():
|
| 204 |
p.requires_grad_(False)
|
|
|
|
| 205 |
|
| 206 |
# Load dataset
|
| 207 |
preset = DATASET_PRESETS[config.dataset_preset]
|
|
@@ -235,7 +232,7 @@ def precache_latents(config, cache_path=None):
|
|
| 235 |
img_col = preset["image_column"]
|
| 236 |
lbl_col = preset["label_column"]
|
| 237 |
|
| 238 |
-
print(f"Encoding images to latents...")
|
| 239 |
t0 = time.time()
|
| 240 |
|
| 241 |
for item in dataset:
|
|
@@ -255,7 +252,7 @@ def precache_latents(config, cache_path=None):
|
|
| 255 |
with torch.no_grad():
|
| 256 |
px = torch.stack(batch_pixels).to(device, dtype=torch.float16) * 2 - 1
|
| 257 |
lat = vae.encode(px).latent_dist.sample()
|
| 258 |
-
lat =
|
| 259 |
all_latents.append(lat.cpu().float())
|
| 260 |
all_labels.extend(batch_labels)
|
| 261 |
batch_pixels, batch_labels = [], []
|
|
@@ -266,7 +263,7 @@ def precache_latents(config, cache_path=None):
|
|
| 266 |
with torch.no_grad():
|
| 267 |
px = torch.stack(batch_pixels).to(device, dtype=torch.float16) * 2 - 1
|
| 268 |
lat = vae.encode(px).latent_dist.sample()
|
| 269 |
-
lat =
|
| 270 |
all_latents.append(lat.cpu().float())
|
| 271 |
all_labels.extend(batch_labels)
|
| 272 |
|
|
@@ -379,8 +376,9 @@ def train(config):
|
|
| 379 |
train_dl = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True,
|
| 380 |
num_workers=config.num_workers, pin_memory=True, drop_last=True)
|
| 381 |
|
| 382 |
-
# Step 3: Model
|
| 383 |
mcfg = get_model_config(config.model_size, config.num_classes, config.class_drop_prob)
|
|
|
|
| 384 |
model = LiquidGen(**mcfg).to(device)
|
| 385 |
print(f"LiquidGen-{config.model_size}: {model.count_params()/1e6:.1f}M params")
|
| 386 |
|
|
@@ -397,8 +395,9 @@ def train(config):
|
|
| 397 |
fm = FlowMatchingScheduler(config.min_timestep, config.max_timestep)
|
| 398 |
lat_size = config.image_size // 8
|
| 399 |
|
| 400 |
-
print(f"\nTotal steps: {total_steps}, Batch: {config.batch_size}
|
| 401 |
-
print(f"
|
|
|
|
| 402 |
if torch.cuda.is_available():
|
| 403 |
print(f"VRAM: {torch.cuda.memory_allocated()/1024**3:.1f} / "
|
| 404 |
f"{torch.cuda.get_device_properties(0).total_mem/1024**3:.1f} GB")
|
|
@@ -444,20 +443,23 @@ def train(config):
|
|
| 444 |
print("π₯ Diverged!"); return
|
| 445 |
|
| 446 |
if gs % config.sample_every_n_steps == 0:
|
|
|
|
| 447 |
if not vae_loaded:
|
| 448 |
from diffusers import AutoencoderKL
|
| 449 |
vae = AutoencoderKL.from_pretrained(
|
| 450 |
-
config.vae_id,
|
| 451 |
-
|
| 452 |
for p in vae.parameters(): p.requires_grad_(False)
|
| 453 |
vae_loaded = True
|
| 454 |
ema.apply(model); model.eval()
|
| 455 |
sl = torch.randint(0, max(1, config.num_classes), (config.num_samples,),
|
| 456 |
device=device) if config.num_classes > 0 else None
|
| 457 |
-
|
|
|
|
| 458 |
device, config.num_sample_steps, sl, config.cfg_scale)
|
| 459 |
with torch.no_grad():
|
| 460 |
-
|
|
|
|
| 461 |
imgs = ((vae.decode(dec).sample + 1) / 2).clamp(0, 1).float()
|
| 462 |
from torchvision.utils import save_image
|
| 463 |
sp = f"{config.output_dir}/samples/step_{gs:07d}.png"
|
|
@@ -476,7 +478,7 @@ def train(config):
|
|
| 476 |
final = f"{config.output_dir}/checkpoints/final.pt"
|
| 477 |
torch.save({"model": model.state_dict(), "ema": ema.shadow,
|
| 478 |
"model_config": mcfg, "step": gs}, final)
|
| 479 |
-
print(f"\nπ Done! {gs} steps, {(time.time()-t_start)/60:.1f}min
|
| 480 |
|
| 481 |
|
| 482 |
if __name__ == "__main__":
|
|
|
|
| 6 |
- No VAE needed during training loop → saves ~1GB VRAM + faster iterations
|
| 7 |
- Streaming support for large datasets
|
| 8 |
- Multiple small dataset presets
|
| 9 |
+
- Uses madebyollin/sdxl-vae-fp16-fix (fully open, no login, fp16 stable)
|
| 10 |
|
| 11 |
Flow Matching training objective (velocity prediction):
|
| 12 |
- Forward: x_t = (1 - t) * x_0 + t * ε
|
|
|
|
| 28 |
|
| 29 |
|
| 30 |
# =============================================================================
|
| 31 |
+
# Dataset Presets (all verified, fast to download, no auth needed)
|
| 32 |
# =============================================================================
|
| 33 |
|
| 34 |
DATASET_PRESETS = {
|
|
|
|
| 89 |
image_size: int = 256 # 256 or 512
|
| 90 |
max_images: int = 0 # 0 = use all, >0 = limit (for streaming/testing)
|
| 91 |
|
| 92 |
+
# VAE — fully open, no login needed
|
| 93 |
+
# madebyollin/sdxl-vae-fp16-fix: SDXL VAE with fp16 NaN fix
|
| 94 |
+
# 4 latent channels, 8x spatial compression, scaling_factor=0.13025
|
| 95 |
+
vae_id: str = "madebyollin/sdxl-vae-fp16-fix"
|
| 96 |
+
vae_scaling_factor: float = 0.13025
|
| 97 |
+
latent_channels: int = 4
|
| 98 |
|
| 99 |
# Training
|
| 100 |
batch_size: int = 32 # Can be large since training on cached tensors!
|
|
|
|
| 149 |
|
| 150 |
|
| 151 |
# =============================================================================
|
| 152 |
+
# Latent Pre-Caching
|
| 153 |
# =============================================================================
|
| 154 |
|
| 155 |
class CachedLatentDataset(Dataset):
|
|
|
|
| 176 |
def precache_latents(config, cache_path=None):
|
| 177 |
"""
|
| 178 |
Encode all images to VAE latents once, save to disk.
|
| 179 |
+
Uses madebyollin/sdxl-vae-fp16-fix (no auth needed).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
"""
|
| 181 |
if cache_path is None:
|
| 182 |
cache_path = os.path.join(config.output_dir, "cached_latents.pt")
|
|
|
|
| 190 |
os.makedirs(os.path.dirname(cache_path) if os.path.dirname(cache_path) else ".", exist_ok=True)
|
| 191 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 192 |
|
| 193 |
+
# Load VAE — no subfolder, no auth needed
|
| 194 |
+
print(f"Loading VAE: {config.vae_id} (open, no login needed)...")
|
| 195 |
from diffusers import AutoencoderKL
|
| 196 |
vae = AutoencoderKL.from_pretrained(
|
| 197 |
+
config.vae_id, torch_dtype=torch.float16
|
| 198 |
).to(device).eval()
|
| 199 |
for p in vae.parameters():
|
| 200 |
p.requires_grad_(False)
|
| 201 |
+
print(f" VAE loaded: {sum(p.numel() for p in vae.parameters())/1e6:.0f}M params")
|
| 202 |
|
| 203 |
# Load dataset
|
| 204 |
preset = DATASET_PRESETS[config.dataset_preset]
|
|
|
|
| 232 |
img_col = preset["image_column"]
|
| 233 |
lbl_col = preset["label_column"]
|
| 234 |
|
| 235 |
+
print(f"Encoding images to VAE latents...")
|
| 236 |
t0 = time.time()
|
| 237 |
|
| 238 |
for item in dataset:
|
|
|
|
| 252 |
with torch.no_grad():
|
| 253 |
px = torch.stack(batch_pixels).to(device, dtype=torch.float16) * 2 - 1
|
| 254 |
lat = vae.encode(px).latent_dist.sample()
|
| 255 |
+
lat = lat * config.vae_scaling_factor # SDXL: scale only, no shift
|
| 256 |
all_latents.append(lat.cpu().float())
|
| 257 |
all_labels.extend(batch_labels)
|
| 258 |
batch_pixels, batch_labels = [], []
|
|
|
|
| 263 |
with torch.no_grad():
|
| 264 |
px = torch.stack(batch_pixels).to(device, dtype=torch.float16) * 2 - 1
|
| 265 |
lat = vae.encode(px).latent_dist.sample()
|
| 266 |
+
lat = lat * config.vae_scaling_factor
|
| 267 |
all_latents.append(lat.cpu().float())
|
| 268 |
all_labels.extend(batch_labels)
|
| 269 |
|
|
|
|
| 376 |
train_dl = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True,
|
| 377 |
num_workers=config.num_workers, pin_memory=True, drop_last=True)
|
| 378 |
|
| 379 |
+
# Step 3: Model (in_channels=4 for SDXL VAE)
|
| 380 |
mcfg = get_model_config(config.model_size, config.num_classes, config.class_drop_prob)
|
| 381 |
+
mcfg["in_channels"] = config.latent_channels # 4 for SDXL VAE
|
| 382 |
model = LiquidGen(**mcfg).to(device)
|
| 383 |
print(f"LiquidGen-{config.model_size}: {model.count_params()/1e6:.1f}M params")
|
| 384 |
|
|
|
|
| 395 |
fm = FlowMatchingScheduler(config.min_timestep, config.max_timestep)
|
| 396 |
lat_size = config.image_size // 8
|
| 397 |
|
| 398 |
+
print(f"\nTotal steps: {total_steps}, Batch: {config.batch_size}x{config.gradient_accumulation_steps}")
|
| 399 |
+
print(f"Latent: [{config.batch_size}, {config.latent_channels}, {lat_size}, {lat_size}]")
|
| 400 |
+
print(f"No VAE during training -> max VRAM for model")
|
| 401 |
if torch.cuda.is_available():
|
| 402 |
print(f"VRAM: {torch.cuda.memory_allocated()/1024**3:.1f} / "
|
| 403 |
f"{torch.cuda.get_device_properties(0).total_mem/1024**3:.1f} GB")
|
|
|
|
| 443 |
print("π₯ Diverged!"); return
|
| 444 |
|
| 445 |
if gs % config.sample_every_n_steps == 0:
|
| 446 |
+
# Load VAE lazily (only for decoding samples)
|
| 447 |
if not vae_loaded:
|
| 448 |
from diffusers import AutoencoderKL
|
| 449 |
vae = AutoencoderKL.from_pretrained(
|
| 450 |
+
config.vae_id, torch_dtype=torch.float16
|
| 451 |
+
).to(device).eval()
|
| 452 |
for p in vae.parameters(): p.requires_grad_(False)
|
| 453 |
vae_loaded = True
|
| 454 |
ema.apply(model); model.eval()
|
| 455 |
sl = torch.randint(0, max(1, config.num_classes), (config.num_samples,),
|
| 456 |
device=device) if config.num_classes > 0 else None
|
| 457 |
+
# 4 channels for SDXL VAE
|
| 458 |
+
samp = fm.sample(model, (config.num_samples, config.latent_channels, lat_size, lat_size),
|
| 459 |
device, config.num_sample_steps, sl, config.cfg_scale)
|
| 460 |
with torch.no_grad():
|
| 461 |
+
# SDXL VAE: unscale only, no shift
|
| 462 |
+
dec = samp.half() / config.vae_scaling_factor
|
| 463 |
imgs = ((vae.decode(dec).sample + 1) / 2).clamp(0, 1).float()
|
| 464 |
from torchvision.utils import save_image
|
| 465 |
sp = f"{config.output_dir}/samples/step_{gs:07d}.png"
|
|
|
|
| 478 |
final = f"{config.output_dir}/checkpoints/final.pt"
|
| 479 |
torch.save({"model": model.state_dict(), "ema": ema.shadow,
|
| 480 |
"model_config": mcfg, "step": gs}, final)
|
| 481 |
+
print(f"\nπ Done! {gs} steps, {(time.time()-t_start)/60:.1f}min -> {final}")
|
| 482 |
|
| 483 |
|
| 484 |
if __name__ == "__main__":
|