asdf98 committed on
Commit
a733be1
·
verified ·
1 Parent(s): 1373ccf

Fix: use SDXL VAE (4ch, no login needed)

Browse files
Files changed (1) hide show
  1. model.py +7 -5
model.py CHANGED
@@ -2,7 +2,7 @@
2
  LiquidGen: A Novel Liquid Neural Network Image Generation Model
3
 
4
  Architecture Overview:
5
- - Frozen VAE encoder/decoder (FLUX.1-schnell, 16ch latent, 8x compression)
6
  - Liquid backbone for denoising (fully parallelizable, no attention, no sequential ODE)
7
  - Flow matching training objective (velocity prediction)
8
 
@@ -305,7 +305,7 @@ class LiquidGen(nn.Module):
305
 
306
  def __init__(
307
  self,
308
- in_channels: int = 16,
309
  patch_size: int = 2,
310
  embed_dim: int = 512,
311
  depth: int = 16,
@@ -385,7 +385,7 @@ class LiquidGen(nn.Module):
385
  """
386
  Predict velocity field for flow matching.
387
  Args:
388
- x: [B, C, H, W] noisy latent (C=16 for Flux VAE)
389
  t: [B] timestep in [0, 1]
390
  class_labels: [B] optional class labels
391
  Returns:
@@ -459,13 +459,15 @@ if __name__ == "__main__":
459
  model = factory(num_classes=27).to(device)
460
  print(f"LiquidGen-{name}: {model.count_params() / 1e6:.1f}M params")
461
 
462
- x = torch.randn(2, 16, 32, 32, device=device)
 
463
  t = torch.rand(2, device=device)
464
  labels = torch.randint(0, 27, (2,), device=device)
465
  v = model(x, t, labels)
466
  assert v.shape == x.shape
467
 
468
- x512 = torch.randn(1, 16, 64, 64, device=device)
 
469
  v512 = model(x512, t[:1], labels[:1])
470
  assert v512.shape == x512.shape
471
  print(f" 256px ✅ 512px ✅")
 
2
  LiquidGen: A Novel Liquid Neural Network Image Generation Model
3
 
4
  Architecture Overview:
5
+ - Frozen VAE encoder/decoder (SDXL VAE, 4ch latent, 8x compression, no login needed)
6
  - Liquid backbone for denoising (fully parallelizable, no attention, no sequential ODE)
7
  - Flow matching training objective (velocity prediction)
8
 
 
305
 
306
  def __init__(
307
  self,
308
+ in_channels: int = 4, # 4 for SDXL VAE
309
  patch_size: int = 2,
310
  embed_dim: int = 512,
311
  depth: int = 16,
 
385
  """
386
  Predict velocity field for flow matching.
387
  Args:
388
+ x: [B, C, H, W] noisy latent (C=4 for SDXL VAE)
389
  t: [B] timestep in [0, 1]
390
  class_labels: [B] optional class labels
391
  Returns:
 
459
  model = factory(num_classes=27).to(device)
460
  print(f"LiquidGen-{name}: {model.count_params() / 1e6:.1f}M params")
461
 
462
+ # 256px: image/8 = 32x32 latent, 4 channels (SDXL VAE)
463
+ x = torch.randn(2, 4, 32, 32, device=device)
464
  t = torch.rand(2, device=device)
465
  labels = torch.randint(0, 27, (2,), device=device)
466
  v = model(x, t, labels)
467
  assert v.shape == x.shape
468
 
469
+ # 512px: image/8 = 64x64 latent
470
+ x512 = torch.randn(1, 4, 64, 64, device=device)
471
  v512 = model(x512, t[:1], labels[:1])
472
  assert v512.shape == x512.shape
473
  print(f" 256px ✅ 512px ✅")