Commit: "Fix: SDXL VAE (no login), streaming dataset, step-based training"
[Browse files]
Changed file: train.py (CHANGED)
|
@@ -44,12 +44,12 @@ class TrainConfig:

Before (removed lines marked "-"):

    44        max_samples: int = 0          # 0 = use all (only for non-streaming)
    45        streaming_buffer: int = 1000  # Shuffle buffer for streaming
    46
    47    -   # VAE
    48    -   vae_id: str = "…"                  [value truncated in page extraction]
    49    -   vae_subfolder: str = "…"           [value truncated in page extraction]
    50        vae_dtype: str = "float16"
    51    -   vae_scaling_factor: float = 0.…    [value truncated in page extraction]
    52    -   vae_shift_factor: float = 0.…      [value truncated in page extraction]
    53
    54        # Training
    55        batch_size: int = 8
@@ -323,8 +323,11 @@ def train(config: TrainConfig):

Before (removed lines marked "-"):

    323       print("Loading VAE...")
    324       from diffusers import AutoencoderKL
    325       vae_dtype = torch.float16 if config.vae_dtype == "float16" else torch.bfloat16
    326       vae = AutoencoderKL.from_pretrained(
    327   -       config.vae_id,
    328       ).to(device).eval()
    329       for p in vae.parameters():
    330           p.requires_grad_(False)
@@ -448,7 +451,8 @@ def train(config: TrainConfig):

Before (removed lines marked "-"):

    448       sample_labels = None
    449       if config.num_classes > 0:
    450           sample_labels = torch.randint(0, config.num_classes, (config.num_samples,), device=device)
    451   -   [removed line's content missing from page extraction — presumably the old
               `sampled = fm.sample(...)` call, replaced by the two added lines in the
               "After" side of this hunk; confirm against the repository diff]
    452               device, config.num_sample_steps, sample_labels, config.cfg_scale)
    453       sample_imgs = decode_latents_with_vae(sampled.to(vae_dtype), vae,
    454           config.vae_scaling_factor, config.vae_shift_factor).float()
After (added lines marked "+"):

    44        max_samples: int = 0          # 0 = use all (only for non-streaming)
    45        streaming_buffer: int = 1000  # Shuffle buffer for streaming
    46
    47    +   # VAE (SDXL VAE - open access, no login needed, fp16-safe)
    48    +   vae_id: str = "madebyollin/sdxl-vae-fp16-fix"
    49    +   vae_subfolder: str = ""
    50        vae_dtype: str = "float16"
    51    +   vae_scaling_factor: float = 0.13025
    52    +   vae_shift_factor: float = 0.0  # SDXL VAE has no shift
    53
    54        # Training
    55        batch_size: int = 8
After (added lines marked "+"):

    323       print("Loading VAE...")
    324       from diffusers import AutoencoderKL
    325       vae_dtype = torch.float16 if config.vae_dtype == "float16" else torch.bfloat16
    326   +   vae_kwargs = {"torch_dtype": vae_dtype}
    327   +   if config.vae_subfolder:
    328   +       vae_kwargs["subfolder"] = config.vae_subfolder
    329       vae = AutoencoderKL.from_pretrained(
    330   +       config.vae_id, **vae_kwargs
    331       ).to(device).eval()
    332       for p in vae.parameters():
    333           p.requires_grad_(False)
After (added lines marked "+"):

    451       sample_labels = None
    452       if config.num_classes > 0:
    453           sample_labels = torch.randint(0, config.num_classes, (config.num_samples,), device=device)
    454   +   latent_ch = vae.config.latent_channels  # 4 for SDXL, 16 for Flux
    455   +   sampled = fm.sample(model, (config.num_samples, latent_ch, latent_size, latent_size),
    456               device, config.num_sample_steps, sample_labels, config.cfg_scale)
    457       sample_imgs = decode_latents_with_vae(sampled.to(vae_dtype), vae,
    458           config.vae_scaling_factor, config.vae_shift_factor).float()