Fix: use SDXL VAE (4ch, no login needed)
Browse files
model.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
LiquidGen: A Novel Liquid Neural Network Image Generation Model
|
| 3 |
|
| 4 |
Architecture Overview:
|
| 5 |
-
- Frozen VAE encoder/decoder (
|
| 6 |
- Liquid backbone for denoising (fully parallelizable, no attention, no sequential ODE)
|
| 7 |
- Flow matching training objective (velocity prediction)
|
| 8 |
|
|
@@ -305,7 +305,7 @@ class LiquidGen(nn.Module):
|
|
| 305 |
|
| 306 |
def __init__(
|
| 307 |
self,
|
| 308 |
-
in_channels: int =
|
| 309 |
patch_size: int = 2,
|
| 310 |
embed_dim: int = 512,
|
| 311 |
depth: int = 16,
|
|
@@ -385,7 +385,7 @@ class LiquidGen(nn.Module):
|
|
| 385 |
"""
|
| 386 |
Predict velocity field for flow matching.
|
| 387 |
Args:
|
| 388 |
-
x: [B, C, H, W] noisy latent (C=
|
| 389 |
t: [B] timestep in [0, 1]
|
| 390 |
class_labels: [B] optional class labels
|
| 391 |
Returns:
|
|
@@ -459,13 +459,15 @@ if __name__ == "__main__":
|
|
| 459 |
model = factory(num_classes=27).to(device)
|
| 460 |
print(f"LiquidGen-{name}: {model.count_params() / 1e6:.1f}M params")
|
| 461 |
|
| 462 |
-
|
|
|
|
| 463 |
t = torch.rand(2, device=device)
|
| 464 |
labels = torch.randint(0, 27, (2,), device=device)
|
| 465 |
v = model(x, t, labels)
|
| 466 |
assert v.shape == x.shape
|
| 467 |
|
| 468 |
-
|
|
|
|
| 469 |
v512 = model(x512, t[:1], labels[:1])
|
| 470 |
assert v512.shape == x512.shape
|
| 471 |
print(f" 256px ✅ 512px ✅")
|
|
|
|
| 2 |
LiquidGen: A Novel Liquid Neural Network Image Generation Model
|
| 3 |
|
| 4 |
Architecture Overview:
|
| 5 |
+
- Frozen VAE encoder/decoder (SDXL VAE, 4ch latent, 8x compression, no login needed)
|
| 6 |
- Liquid backbone for denoising (fully parallelizable, no attention, no sequential ODE)
|
| 7 |
- Flow matching training objective (velocity prediction)
|
| 8 |
|
|
|
|
| 305 |
|
| 306 |
def __init__(
|
| 307 |
self,
|
| 308 |
+
in_channels: int = 4, # 4 for SDXL VAE
|
| 309 |
patch_size: int = 2,
|
| 310 |
embed_dim: int = 512,
|
| 311 |
depth: int = 16,
|
|
|
|
| 385 |
"""
|
| 386 |
Predict velocity field for flow matching.
|
| 387 |
Args:
|
| 388 |
+
x: [B, C, H, W] noisy latent (C=4 for SDXL VAE)
|
| 389 |
t: [B] timestep in [0, 1]
|
| 390 |
class_labels: [B] optional class labels
|
| 391 |
Returns:
|
|
|
|
| 459 |
model = factory(num_classes=27).to(device)
|
| 460 |
print(f"LiquidGen-{name}: {model.count_params() / 1e6:.1f}M params")
|
| 461 |
|
| 462 |
+
# 256px: image/8 = 32x32 latent, 4 channels (SDXL VAE)
|
| 463 |
+
x = torch.randn(2, 4, 32, 32, device=device)
|
| 464 |
t = torch.rand(2, device=device)
|
| 465 |
labels = torch.randint(0, 27, (2,), device=device)
|
| 466 |
v = model(x, t, labels)
|
| 467 |
assert v.shape == x.shape
|
| 468 |
|
| 469 |
+
# 512px: image/8 = 64x64 latent
|
| 470 |
+
x512 = torch.randn(1, 4, 64, 64, device=device)
|
| 471 |
v512 = model(x512, t[:1], labels[:1])
|
| 472 |
assert v512.shape == x512.shape
|
| 473 |
print(f" 256px ✅ 512px ✅")
|