Fix GroupNorm for arbitrary channel counts
Browse files- liquid_diffusion/model.py +12 -5
liquid_diffusion/model.py
CHANGED
|
@@ -78,7 +78,11 @@ class AdaLN(nn.Module):
|
|
| 78 |
"""Adaptive Layer Normalization: out = norm(x) * (1 + scale(t)) + shift(t)"""
|
| 79 |
def __init__(self, dim: int, cond_dim: int):
|
| 80 |
super().__init__()
|
| 81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
self.proj = nn.Sequential(nn.SiLU(), nn.Linear(cond_dim, dim * 2))
|
| 83 |
|
| 84 |
def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
|
|
@@ -339,8 +343,11 @@ class LiquidDiffusionUNet(nn.Module):
|
|
| 339 |
self.decoder_blocks.append(stage)
|
| 340 |
|
| 341 |
# Output head (initialized to zero for stable start)
|
|
|
|
|
|
|
|
|
|
| 342 |
self.head = nn.Sequential(
|
| 343 |
-
nn.GroupNorm(
|
| 344 |
nn.SiLU(),
|
| 345 |
nn.Conv2d(channels[0], in_channels, 3, padding=1),
|
| 346 |
)
|
|
@@ -395,19 +402,19 @@ class LiquidDiffusionUNet(nn.Module):
|
|
| 395 |
# =============================================================================
|
| 396 |
|
| 397 |
def liquid_diffusion_tiny(**kwargs):
|
| 398 |
-
"""~
|
| 399 |
return LiquidDiffusionUNet(
|
| 400 |
channels=[64, 128, 256], blocks_per_stage=[2, 2, 4],
|
| 401 |
t_dim=256, expand_ratio=2.0, kernel_size=7, **kwargs)
|
| 402 |
|
| 403 |
def liquid_diffusion_small(**kwargs):
|
| 404 |
-
"""~
|
| 405 |
return LiquidDiffusionUNet(
|
| 406 |
channels=[96, 192, 384], blocks_per_stage=[2, 3, 6],
|
| 407 |
t_dim=384, expand_ratio=2.0, kernel_size=7, **kwargs)
|
| 408 |
|
| 409 |
def liquid_diffusion_base(**kwargs):
|
| 410 |
-
"""~
|
| 411 |
return LiquidDiffusionUNet(
|
| 412 |
channels=[128, 256, 512], blocks_per_stage=[2, 4, 8],
|
| 413 |
t_dim=512, expand_ratio=2.0, kernel_size=7, **kwargs)
|
|
|
|
| 78 |
"""Adaptive Layer Normalization: out = norm(x) * (1 + scale(t)) + shift(t)"""
|
| 79 |
def __init__(self, dim: int, cond_dim: int):
|
| 80 |
super().__init__()
|
| 81 |
+
# Find largest valid group count ≤ 32
|
| 82 |
+
num_groups = min(32, dim)
|
| 83 |
+
while dim % num_groups != 0:
|
| 84 |
+
num_groups -= 1
|
| 85 |
+
self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=dim, affine=False)
|
| 86 |
self.proj = nn.Sequential(nn.SiLU(), nn.Linear(cond_dim, dim * 2))
|
| 87 |
|
| 88 |
def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
|
|
|
|
| 343 |
self.decoder_blocks.append(stage)
|
| 344 |
|
| 345 |
# Output head (initialized to zero for stable start)
|
| 346 |
+
head_groups = min(32, channels[0])
|
| 347 |
+
while channels[0] % head_groups != 0:
|
| 348 |
+
head_groups -= 1
|
| 349 |
self.head = nn.Sequential(
|
| 350 |
+
nn.GroupNorm(head_groups, channels[0]),
|
| 351 |
nn.SiLU(),
|
| 352 |
nn.Conv2d(channels[0], in_channels, 3, padding=1),
|
| 353 |
)
|
|
|
|
| 402 |
# =============================================================================
|
| 403 |
|
| 404 |
def liquid_diffusion_tiny(**kwargs):
|
| 405 |
+
"""~23M params, 256px, fits ~6GB VRAM."""
|
| 406 |
return LiquidDiffusionUNet(
|
| 407 |
channels=[64, 128, 256], blocks_per_stage=[2, 2, 4],
|
| 408 |
t_dim=256, expand_ratio=2.0, kernel_size=7, **kwargs)
|
| 409 |
|
| 410 |
def liquid_diffusion_small(**kwargs):
|
| 411 |
+
"""~69M params, 256px, fits ~10GB VRAM."""
|
| 412 |
return LiquidDiffusionUNet(
|
| 413 |
channels=[96, 192, 384], blocks_per_stage=[2, 3, 6],
|
| 414 |
t_dim=384, expand_ratio=2.0, kernel_size=7, **kwargs)
|
| 415 |
|
| 416 |
def liquid_diffusion_base(**kwargs):
|
| 417 |
+
"""~154M params, 512px, fits ~16GB VRAM."""
|
| 418 |
return LiquidDiffusionUNet(
|
| 419 |
channels=[128, 256, 512], blocks_per_stage=[2, 4, 8],
|
| 420 |
t_dim=512, expand_ratio=2.0, kernel_size=7, **kwargs)
|