asdf98 committed on
Commit
ba78ae2
·
verified ·
1 Parent(s): a733be1

Fix: SDXL VAE (no login), streaming dataset, step-based training

Browse files
Files changed (1) hide show
  1. train.py +11 -7
train.py CHANGED
@@ -44,12 +44,12 @@ class TrainConfig:
44
  max_samples: int = 0 # 0 = use all (only for non-streaming)
45
  streaming_buffer: int = 1000 # Shuffle buffer for streaming
46
 
47
- # VAE
48
- vae_id: str = "black-forest-labs/FLUX.1-schnell"
49
- vae_subfolder: str = "vae"
50
  vae_dtype: str = "float16"
51
- vae_scaling_factor: float = 0.3611
52
- vae_shift_factor: float = 0.1159
53
 
54
  # Training
55
  batch_size: int = 8
@@ -323,8 +323,11 @@ def train(config: TrainConfig):
323
  print("Loading VAE...")
324
  from diffusers import AutoencoderKL
325
  vae_dtype = torch.float16 if config.vae_dtype == "float16" else torch.bfloat16
 
 
 
326
  vae = AutoencoderKL.from_pretrained(
327
- config.vae_id, subfolder=config.vae_subfolder, torch_dtype=vae_dtype
328
  ).to(device).eval()
329
  for p in vae.parameters():
330
  p.requires_grad_(False)
@@ -448,7 +451,8 @@ def train(config: TrainConfig):
448
  sample_labels = None
449
  if config.num_classes > 0:
450
  sample_labels = torch.randint(0, config.num_classes, (config.num_samples,), device=device)
451
- sampled = fm.sample(model, (config.num_samples, 16, latent_size, latent_size),
 
452
  device, config.num_sample_steps, sample_labels, config.cfg_scale)
453
  sample_imgs = decode_latents_with_vae(sampled.to(vae_dtype), vae,
454
  config.vae_scaling_factor, config.vae_shift_factor).float()
 
44
  max_samples: int = 0 # 0 = use all (only for non-streaming)
45
  streaming_buffer: int = 1000 # Shuffle buffer for streaming
46
 
47
+ # VAE (SDXL VAE - open access, no login needed, fp16-safe)
48
+ vae_id: str = "madebyollin/sdxl-vae-fp16-fix"
49
+ vae_subfolder: str = ""
50
  vae_dtype: str = "float16"
51
+ vae_scaling_factor: float = 0.13025
52
+ vae_shift_factor: float = 0.0 # SDXL VAE has no shift
53
 
54
  # Training
55
  batch_size: int = 8
 
323
  print("Loading VAE...")
324
  from diffusers import AutoencoderKL
325
  vae_dtype = torch.float16 if config.vae_dtype == "float16" else torch.bfloat16
326
+ vae_kwargs = {"torch_dtype": vae_dtype}
327
+ if config.vae_subfolder:
328
+ vae_kwargs["subfolder"] = config.vae_subfolder
329
  vae = AutoencoderKL.from_pretrained(
330
+ config.vae_id, **vae_kwargs
331
  ).to(device).eval()
332
  for p in vae.parameters():
333
  p.requires_grad_(False)
 
451
  sample_labels = None
452
  if config.num_classes > 0:
453
  sample_labels = torch.randint(0, config.num_classes, (config.num_samples,), device=device)
454
+ latent_ch = vae.config.latent_channels # 4 for SDXL, 16 for Flux
455
+ sampled = fm.sample(model, (config.num_samples, latent_ch, latent_size, latent_size),
456
  device, config.num_sample_steps, sample_labels, config.cfg_scale)
457
  sample_imgs = decode_latents_with_vae(sampled.to(vae_dtype), vae,
458
  config.vae_scaling_factor, config.vae_shift_factor).float()