Fix: add trust_remote_code=True for datasets with legacy loading scripts
Browse files
train.py
CHANGED
|
@@ -38,6 +38,7 @@ DATASET_PRESETS = {
|
|
| 38 |
"image_column": "image",
|
| 39 |
"label_column": "labels",
|
| 40 |
"num_classes": 27,
|
|
|
|
| 41 |
"description": "~200 painting samples, 27 styles, 1.7MB — instant smoke test",
|
| 42 |
},
|
| 43 |
"paintings": {
|
|
@@ -46,6 +47,7 @@ DATASET_PRESETS = {
|
|
| 46 |
"image_column": "image",
|
| 47 |
"label_column": "labels",
|
| 48 |
"num_classes": 27,
|
|
|
|
| 49 |
"description": "~8K paintings, 27 styles, 204MB — best for style-conditional training",
|
| 50 |
},
|
| 51 |
"cartoon": {
|
|
@@ -90,18 +92,16 @@ class TrainConfig:
|
|
| 90 |
max_images: int = 0 # 0 = use all, >0 = limit (for streaming/testing)
|
| 91 |
|
| 92 |
# VAE — fully open, no login needed
|
| 93 |
-
# madebyollin/sdxl-vae-fp16-fix: SDXL VAE with fp16 NaN fix
|
| 94 |
-
# 4 latent channels, 8x spatial compression, scaling_factor=0.13025
|
| 95 |
vae_id: str = "madebyollin/sdxl-vae-fp16-fix"
|
| 96 |
vae_scaling_factor: float = 0.13025
|
| 97 |
latent_channels: int = 4
|
| 98 |
|
| 99 |
# Training
|
| 100 |
-
batch_size: int = 32
|
| 101 |
gradient_accumulation_steps: int = 1
|
| 102 |
learning_rate: float = 1e-4
|
| 103 |
weight_decay: float = 0.01
|
| 104 |
-
max_grad_norm: float = 2.0
|
| 105 |
num_epochs: int = 100
|
| 106 |
warmup_steps: int = 500
|
| 107 |
ema_decay: float = 0.9999
|
|
@@ -190,7 +190,7 @@ def precache_latents(config, cache_path=None):
|
|
| 190 |
os.makedirs(os.path.dirname(cache_path) if os.path.dirname(cache_path) else ".", exist_ok=True)
|
| 191 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 192 |
|
| 193 |
-
# Load VAE
|
| 194 |
print(f"Loading VAE: {config.vae_id} (open, no login needed)...")
|
| 195 |
from diffusers import AutoencoderKL
|
| 196 |
vae = AutoencoderKL.from_pretrained(
|
|
@@ -213,6 +213,9 @@ def precache_latents(config, cache_path=None):
|
|
| 213 |
ds_kwargs["name"] = preset["config"]
|
| 214 |
if is_streaming:
|
| 215 |
ds_kwargs["streaming"] = True
|
|
|
|
|
|
|
|
|
|
| 216 |
|
| 217 |
dataset = load_dataset(preset["name"], **ds_kwargs)
|
| 218 |
|
|
@@ -252,7 +255,7 @@ def precache_latents(config, cache_path=None):
|
|
| 252 |
with torch.no_grad():
|
| 253 |
px = torch.stack(batch_pixels).to(device, dtype=torch.float16) * 2 - 1
|
| 254 |
lat = vae.encode(px).latent_dist.sample()
|
| 255 |
-
lat = lat * config.vae_scaling_factor
|
| 256 |
all_latents.append(lat.cpu().float())
|
| 257 |
all_labels.extend(batch_labels)
|
| 258 |
batch_pixels, batch_labels = [], []
|
|
@@ -273,7 +276,7 @@ def precache_latents(config, cache_path=None):
|
|
| 273 |
|
| 274 |
elapsed = time.time() - t0
|
| 275 |
mb = os.path.getsize(cache_path) / 1024**2
|
| 276 |
-
print(f"\n✅ Cached {count} latents
|
| 277 |
print(f" Shape: {all_latents.shape}, Size: {mb:.1f}MB, Time: {elapsed:.0f}s")
|
| 278 |
|
| 279 |
del vae
|
|
@@ -376,9 +379,9 @@ def train(config):
|
|
| 376 |
train_dl = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True,
|
| 377 |
num_workers=config.num_workers, pin_memory=True, drop_last=True)
|
| 378 |
|
| 379 |
-
# Step 3: Model
|
| 380 |
mcfg = get_model_config(config.model_size, config.num_classes, config.class_drop_prob)
|
| 381 |
-
mcfg["in_channels"] = config.latent_channels
|
| 382 |
model = LiquidGen(**mcfg).to(device)
|
| 383 |
print(f"LiquidGen-{config.model_size}: {model.count_params()/1e6:.1f}M params")
|
| 384 |
|
|
@@ -440,10 +443,9 @@ def train(config):
|
|
| 440 |
f"lr={lr:.2e} | vram={vram:.1f}G | {sps:.1f} st/s")
|
| 441 |
la = 0.0
|
| 442 |
if math.isnan(al) or al > 50:
|
| 443 |
-
print("
|
| 444 |
|
| 445 |
if gs % config.sample_every_n_steps == 0:
|
| 446 |
-
# Load VAE lazily (only for decoding samples)
|
| 447 |
if not vae_loaded:
|
| 448 |
from diffusers import AutoencoderKL
|
| 449 |
vae = AutoencoderKL.from_pretrained(
|
|
@@ -454,16 +456,14 @@ def train(config):
|
|
| 454 |
ema.apply(model); model.eval()
|
| 455 |
sl = torch.randint(0, max(1, config.num_classes), (config.num_samples,),
|
| 456 |
device=device) if config.num_classes > 0 else None
|
| 457 |
-
# 4 channels for SDXL VAE
|
| 458 |
samp = fm.sample(model, (config.num_samples, config.latent_channels, lat_size, lat_size),
|
| 459 |
device, config.num_sample_steps, sl, config.cfg_scale)
|
| 460 |
with torch.no_grad():
|
| 461 |
-
# SDXL VAE: unscale only, no shift
|
| 462 |
dec = samp.half() / config.vae_scaling_factor
|
| 463 |
imgs = ((vae.decode(dec).sample + 1) / 2).clamp(0, 1).float()
|
| 464 |
from torchvision.utils import save_image
|
| 465 |
sp = f"{config.output_dir}/samples/step_{gs:07d}.png"
|
| 466 |
-
save_image(imgs, sp, nrow=2); print(f"
|
| 467 |
ema.restore(model); model.train()
|
| 468 |
|
| 469 |
if gs % config.save_every_n_steps == 0:
|
|
@@ -471,14 +471,14 @@ def train(config):
|
|
| 471 |
torch.save({"model": model.state_dict(), "ema": ema.shadow,
|
| 472 |
"optimizer": opt.state_dict(), "scheduler": sched.state_dict(),
|
| 473 |
"step": gs, "epoch": epoch, "model_config": mcfg}, cp)
|
| 474 |
-
print(f"
|
| 475 |
|
| 476 |
print(f"Epoch {epoch} | {time.time()-et:.0f}s\n")
|
| 477 |
|
| 478 |
final = f"{config.output_dir}/checkpoints/final.pt"
|
| 479 |
torch.save({"model": model.state_dict(), "ema": ema.shadow,
|
| 480 |
"model_config": mcfg, "step": gs}, final)
|
| 481 |
-
print(f"\
|
| 482 |
|
| 483 |
|
| 484 |
if __name__ == "__main__":
|
|
|
|
| 38 |
"image_column": "image",
|
| 39 |
"label_column": "labels",
|
| 40 |
"num_classes": 27,
|
| 41 |
+
"trust_remote_code": True,
|
| 42 |
"description": "~200 painting samples, 27 styles, 1.7MB — instant smoke test",
|
| 43 |
},
|
| 44 |
"paintings": {
|
|
|
|
| 47 |
"image_column": "image",
|
| 48 |
"label_column": "labels",
|
| 49 |
"num_classes": 27,
|
| 50 |
+
"trust_remote_code": True,
|
| 51 |
"description": "~8K paintings, 27 styles, 204MB — best for style-conditional training",
|
| 52 |
},
|
| 53 |
"cartoon": {
|
|
|
|
| 92 |
max_images: int = 0 # 0 = use all, >0 = limit (for streaming/testing)
|
| 93 |
|
| 94 |
# VAE — fully open, no login needed
|
|
|
|
|
|
|
| 95 |
vae_id: str = "madebyollin/sdxl-vae-fp16-fix"
|
| 96 |
vae_scaling_factor: float = 0.13025
|
| 97 |
latent_channels: int = 4
|
| 98 |
|
| 99 |
# Training
|
| 100 |
+
batch_size: int = 32
|
| 101 |
gradient_accumulation_steps: int = 1
|
| 102 |
learning_rate: float = 1e-4
|
| 103 |
weight_decay: float = 0.01
|
| 104 |
+
max_grad_norm: float = 2.0
|
| 105 |
num_epochs: int = 100
|
| 106 |
warmup_steps: int = 500
|
| 107 |
ema_decay: float = 0.9999
|
|
|
|
| 190 |
os.makedirs(os.path.dirname(cache_path) if os.path.dirname(cache_path) else ".", exist_ok=True)
|
| 191 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 192 |
|
| 193 |
+
# Load VAE
|
| 194 |
print(f"Loading VAE: {config.vae_id} (open, no login needed)...")
|
| 195 |
from diffusers import AutoencoderKL
|
| 196 |
vae = AutoencoderKL.from_pretrained(
|
|
|
|
| 213 |
ds_kwargs["name"] = preset["config"]
|
| 214 |
if is_streaming:
|
| 215 |
ds_kwargs["streaming"] = True
|
| 216 |
+
# Some datasets have legacy loading scripts that need this flag
|
| 217 |
+
if preset.get("trust_remote_code", False):
|
| 218 |
+
ds_kwargs["trust_remote_code"] = True
|
| 219 |
|
| 220 |
dataset = load_dataset(preset["name"], **ds_kwargs)
|
| 221 |
|
|
|
|
| 255 |
with torch.no_grad():
|
| 256 |
px = torch.stack(batch_pixels).to(device, dtype=torch.float16) * 2 - 1
|
| 257 |
lat = vae.encode(px).latent_dist.sample()
|
| 258 |
+
lat = lat * config.vae_scaling_factor
|
| 259 |
all_latents.append(lat.cpu().float())
|
| 260 |
all_labels.extend(batch_labels)
|
| 261 |
batch_pixels, batch_labels = [], []
|
|
|
|
| 276 |
|
| 277 |
elapsed = time.time() - t0
|
| 278 |
mb = os.path.getsize(cache_path) / 1024**2
|
| 279 |
+
print(f"\n✅ Cached {count} latents -> {cache_path}")
|
| 280 |
print(f" Shape: {all_latents.shape}, Size: {mb:.1f}MB, Time: {elapsed:.0f}s")
|
| 281 |
|
| 282 |
del vae
|
|
|
|
| 379 |
train_dl = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True,
|
| 380 |
num_workers=config.num_workers, pin_memory=True, drop_last=True)
|
| 381 |
|
| 382 |
+
# Step 3: Model
|
| 383 |
mcfg = get_model_config(config.model_size, config.num_classes, config.class_drop_prob)
|
| 384 |
+
mcfg["in_channels"] = config.latent_channels
|
| 385 |
model = LiquidGen(**mcfg).to(device)
|
| 386 |
print(f"LiquidGen-{config.model_size}: {model.count_params()/1e6:.1f}M params")
|
| 387 |
|
|
|
|
| 443 |
f"lr={lr:.2e} | vram={vram:.1f}G | {sps:.1f} st/s")
|
| 444 |
la = 0.0
|
| 445 |
if math.isnan(al) or al > 50:
|
| 446 |
+
print("Diverged!"); return
|
| 447 |
|
| 448 |
if gs % config.sample_every_n_steps == 0:
|
|
|
|
| 449 |
if not vae_loaded:
|
| 450 |
from diffusers import AutoencoderKL
|
| 451 |
vae = AutoencoderKL.from_pretrained(
|
|
|
|
| 456 |
ema.apply(model); model.eval()
|
| 457 |
sl = torch.randint(0, max(1, config.num_classes), (config.num_samples,),
|
| 458 |
device=device) if config.num_classes > 0 else None
|
|
|
|
| 459 |
samp = fm.sample(model, (config.num_samples, config.latent_channels, lat_size, lat_size),
|
| 460 |
device, config.num_sample_steps, sl, config.cfg_scale)
|
| 461 |
with torch.no_grad():
|
|
|
|
| 462 |
dec = samp.half() / config.vae_scaling_factor
|
| 463 |
imgs = ((vae.decode(dec).sample + 1) / 2).clamp(0, 1).float()
|
| 464 |
from torchvision.utils import save_image
|
| 465 |
sp = f"{config.output_dir}/samples/step_{gs:07d}.png"
|
| 466 |
+
save_image(imgs, sp, nrow=2); print(f" Saved: {sp}")
|
| 467 |
ema.restore(model); model.train()
|
| 468 |
|
| 469 |
if gs % config.save_every_n_steps == 0:
|
|
|
|
| 471 |
torch.save({"model": model.state_dict(), "ema": ema.shadow,
|
| 472 |
"optimizer": opt.state_dict(), "scheduler": sched.state_dict(),
|
| 473 |
"step": gs, "epoch": epoch, "model_config": mcfg}, cp)
|
| 474 |
+
print(f" Saved: {cp}")
|
| 475 |
|
| 476 |
print(f"Epoch {epoch} | {time.time()-et:.0f}s\n")
|
| 477 |
|
| 478 |
final = f"{config.output_dir}/checkpoints/final.pt"
|
| 479 |
torch.save({"model": model.state_dict(), "ema": ema.shadow,
|
| 480 |
"model_config": mcfg, "step": gs}, final)
|
| 481 |
+
print(f"\nDone! {gs} steps, {(time.time()-t_start)/60:.1f}min -> {final}")
|
| 482 |
|
| 483 |
|
| 484 |
if __name__ == "__main__":
|