krystv committed on
Commit
19b13e1
·
verified ·
1 Parent(s): 421b295

Fix GroupNorm for arbitrary channel counts

Browse files
Files changed (1) hide show
  1. liquid_diffusion/model.py +12 -5
liquid_diffusion/model.py CHANGED
@@ -78,7 +78,11 @@ class AdaLN(nn.Module):
78
  """Adaptive Layer Normalization: out = norm(x) * (1 + scale(t)) + shift(t)"""
79
  def __init__(self, dim: int, cond_dim: int):
80
  super().__init__()
81
- self.norm = nn.GroupNorm(num_groups=min(32, dim), num_channels=dim, affine=False)
 
 
 
 
82
  self.proj = nn.Sequential(nn.SiLU(), nn.Linear(cond_dim, dim * 2))
83
 
84
  def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
@@ -339,8 +343,11 @@ class LiquidDiffusionUNet(nn.Module):
339
  self.decoder_blocks.append(stage)
340
 
341
  # Output head (initialized to zero for stable start)
 
 
 
342
  self.head = nn.Sequential(
343
- nn.GroupNorm(min(32, channels[0]), channels[0]),
344
  nn.SiLU(),
345
  nn.Conv2d(channels[0], in_channels, 3, padding=1),
346
  )
@@ -395,19 +402,19 @@ class LiquidDiffusionUNet(nn.Module):
395
  # =============================================================================
396
 
397
  def liquid_diffusion_tiny(**kwargs):
398
- """~8M params, 256px, fits ~4GB VRAM."""
399
  return LiquidDiffusionUNet(
400
  channels=[64, 128, 256], blocks_per_stage=[2, 2, 4],
401
  t_dim=256, expand_ratio=2.0, kernel_size=7, **kwargs)
402
 
403
  def liquid_diffusion_small(**kwargs):
404
- """~25M params, 256px, fits ~8GB VRAM."""
405
  return LiquidDiffusionUNet(
406
  channels=[96, 192, 384], blocks_per_stage=[2, 3, 6],
407
  t_dim=384, expand_ratio=2.0, kernel_size=7, **kwargs)
408
 
409
  def liquid_diffusion_base(**kwargs):
410
- """~65M params, 512px, fits ~14GB VRAM."""
411
  return LiquidDiffusionUNet(
412
  channels=[128, 256, 512], blocks_per_stage=[2, 4, 8],
413
  t_dim=512, expand_ratio=2.0, kernel_size=7, **kwargs)
 
78
  """Adaptive Layer Normalization: out = norm(x) * (1 + scale(t)) + shift(t)"""
79
  def __init__(self, dim: int, cond_dim: int):
80
  super().__init__()
81
+ # Find the largest group count ≤ 32 that evenly divides dim (GroupNorm requires num_channels % num_groups == 0)
82
+ num_groups = min(32, dim)
83
+ while dim % num_groups != 0:
84
+ num_groups -= 1
85
+ self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=dim, affine=False)
86
  self.proj = nn.Sequential(nn.SiLU(), nn.Linear(cond_dim, dim * 2))
87
 
88
  def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
 
343
  self.decoder_blocks.append(stage)
344
 
345
  # Output head (initialized to zero for stable start)
346
+ head_groups = min(32, channels[0])
347
+ while channels[0] % head_groups != 0:
348
+ head_groups -= 1
349
  self.head = nn.Sequential(
350
+ nn.GroupNorm(head_groups, channels[0]),
351
  nn.SiLU(),
352
  nn.Conv2d(channels[0], in_channels, 3, padding=1),
353
  )
 
402
  # =============================================================================
403
 
404
  def liquid_diffusion_tiny(**kwargs):
405
+ """~23M params, 256px, fits ~6GB VRAM."""
406
  return LiquidDiffusionUNet(
407
  channels=[64, 128, 256], blocks_per_stage=[2, 2, 4],
408
  t_dim=256, expand_ratio=2.0, kernel_size=7, **kwargs)
409
 
410
  def liquid_diffusion_small(**kwargs):
411
+ """~69M params, 256px, fits ~10GB VRAM."""
412
  return LiquidDiffusionUNet(
413
  channels=[96, 192, 384], blocks_per_stage=[2, 3, 6],
414
  t_dim=384, expand_ratio=2.0, kernel_size=7, **kwargs)
415
 
416
  def liquid_diffusion_base(**kwargs):
417
+ """~154M params, 512px, fits ~16GB VRAM."""
418
  return LiquidDiffusionUNet(
419
  channels=[128, 256, 512], blocks_per_stage=[2, 4, 8],
420
  t_dim=512, expand_ratio=2.0, kernel_size=7, **kwargs)