data-archetype committed on
Commit
ec37736
·
verified ·
1 Parent(s): 68589fe

Upload folder using huggingface_hub

Browse files
fcdm_diffae/__init__.py CHANGED
@@ -26,8 +26,8 @@ from .encoder import EncoderPosterior
26
  from .model import FCDMDiffAE
27
 
28
  __all__ = [
 
29
  "FCDMDiffAE",
30
  "FCDMDiffAEConfig",
31
  "FCDMDiffAEInferenceConfig",
32
- "EncoderPosterior",
33
  ]
 
26
  from .model import FCDMDiffAE
27
 
28
  __all__ = [
29
+ "EncoderPosterior",
30
  "FCDMDiffAE",
31
  "FCDMDiffAEConfig",
32
  "FCDMDiffAEInferenceConfig",
 
33
  ]
fcdm_diffae/__pycache__/__init__.cpython-312.pyc CHANGED
Binary files a/fcdm_diffae/__pycache__/__init__.cpython-312.pyc and b/fcdm_diffae/__pycache__/__init__.cpython-312.pyc differ
 
fcdm_diffae/__pycache__/config.cpython-312.pyc CHANGED
Binary files a/fcdm_diffae/__pycache__/config.cpython-312.pyc and b/fcdm_diffae/__pycache__/config.cpython-312.pyc differ
 
fcdm_diffae/__pycache__/decoder.cpython-312.pyc CHANGED
Binary files a/fcdm_diffae/__pycache__/decoder.cpython-312.pyc and b/fcdm_diffae/__pycache__/decoder.cpython-312.pyc differ
 
fcdm_diffae/__pycache__/encoder.cpython-312.pyc CHANGED
Binary files a/fcdm_diffae/__pycache__/encoder.cpython-312.pyc and b/fcdm_diffae/__pycache__/encoder.cpython-312.pyc differ
 
fcdm_diffae/__pycache__/model.cpython-312.pyc CHANGED
Binary files a/fcdm_diffae/__pycache__/model.cpython-312.pyc and b/fcdm_diffae/__pycache__/model.cpython-312.pyc differ
 
fcdm_diffae/__pycache__/samplers.cpython-312.pyc CHANGED
Binary files a/fcdm_diffae/__pycache__/samplers.cpython-312.pyc and b/fcdm_diffae/__pycache__/samplers.cpython-312.pyc differ
 
fcdm_diffae/config.py CHANGED
@@ -26,12 +26,30 @@ class FCDMDiffAEConfig:
26
  bottleneck_posterior_kind: str = "diagonal_gaussian"
27
  # Post-bottleneck normalization: "channel_wise" or "disabled"
28
  bottleneck_norm_mode: str = "disabled"
 
 
 
 
29
  # VP diffusion schedule endpoints
30
  logsnr_min: float = -10.0
31
  logsnr_max: float = 10.0
32
  # Pixel-space noise std for VP diffusion initialization
33
  pixel_noise_std: float = 0.558
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  def save(self, path: str | Path) -> None:
36
  """Save config as JSON."""
37
  p = Path(path)
 
26
  bottleneck_posterior_kind: str = "diagonal_gaussian"
27
  # Post-bottleneck normalization: "channel_wise" or "disabled"
28
  bottleneck_norm_mode: str = "disabled"
29
+ # Bottleneck patchification: "off" or "patch_2x2"
30
+ # When "patch_2x2", encoder latents are 2x2 patchified after the bottleneck
31
+ # (channels * 4, spatial / 2), and decode unpatchifies before the decoder.
32
+ bottleneck_patchify_mode: str = "off"
33
  # VP diffusion schedule endpoints
34
  logsnr_min: float = -10.0
35
  logsnr_max: float = 10.0
36
  # Pixel-space noise std for VP diffusion initialization
37
  pixel_noise_std: float = 0.558
38
 
39
@property
def latent_channels(self) -> int:
    """Channel width of the exported latent space."""
    # 2x2 patchification folds the four spatial positions into channels,
    # so the exported width is quadrupled in that mode.
    patchified = self.bottleneck_patchify_mode == "patch_2x2"
    return self.bottleneck_dim * 4 if patchified else self.bottleneck_dim
45
+
46
@property
def effective_patch_size(self) -> int:
    """Effective spatial stride from image to latent grid."""
    if self.bottleneck_patchify_mode != "patch_2x2":
        return self.patch_size
    # The 2x2 patchify halves the latent grid, doubling the stride.
    return 2 * self.patch_size
52
+
53
  def save(self, path: str | Path) -> None:
54
  """Save config as JSON."""
55
  p = Path(path)
fcdm_diffae/model.py CHANGED
@@ -71,14 +71,14 @@ class FCDMDiffAE(nn.Module):
71
  super().__init__()
72
  self.config = config
73
 
74
- # Latent running stats for whitening/dewhitening
75
  self.register_buffer(
76
  "latent_norm_running_mean",
77
- torch.zeros((config.bottleneck_dim,), dtype=torch.float32),
78
  )
79
  self.register_buffer(
80
  "latent_norm_running_var",
81
- torch.ones((config.bottleneck_dim,), dtype=torch.float32),
82
  )
83
 
84
  self.encoder = Encoder(
@@ -205,6 +205,20 @@ class FCDMDiffAE(nn.Module):
205
  mean, std = self._latent_norm_stats()
206
  return z * std.to(device=z.device) + mean.to(device=z.device)
207
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  def encode(self, images: Tensor) -> Tensor:
209
  """Encode images to whitened latents (posterior mode).
210
 
@@ -212,16 +226,19 @@ class FCDMDiffAE(nn.Module):
212
  use by downstream latent-space diffusion models.
213
 
214
  Args:
215
- images: [B, 3, H, W] in [-1, 1], H and W divisible by patch_size.
 
216
 
217
  Returns:
218
- Whitened latents [B, bottleneck_dim, H/patch, W/patch].
219
  """
220
  try:
221
  model_dtype = next(self.parameters()).dtype
222
  except StopIteration:
223
  model_dtype = torch.float32
224
  z = self.encoder(images.to(dtype=model_dtype))
 
 
225
  return self.whiten(z).to(dtype=model_dtype)
226
 
227
  def encode_posterior(self, images: Tensor) -> EncoderPosterior:
@@ -250,12 +267,13 @@ class FCDMDiffAE(nn.Module):
250
  ) -> Tensor:
251
  """Decode whitened latents to images via VP diffusion.
252
 
253
- Latents are dewhitened internally before being passed to the decoder.
 
254
 
255
  Args:
256
- latents: [B, bottleneck_dim, h, w] whitened encoder latents.
257
- height: Output image height (divisible by patch_size).
258
- width: Output image width (divisible by patch_size).
259
  inference_config: Optional inference parameters.
260
 
261
  Returns:
@@ -271,8 +289,11 @@ class FCDMDiffAE(nn.Module):
271
  except StopIteration:
272
  model_dtype = torch.float32
273
 
274
- # Dewhiten back to raw encoder scale for the decoder
275
- latents = self.dewhiten(latents).to(dtype=model_dtype)
 
 
 
276
 
277
  if height % config.patch_size != 0 or width % config.patch_size != 0:
278
  raise ValueError(
 
71
  super().__init__()
72
  self.config = config
73
 
74
+ # Latent running stats for whitening/dewhitening (at exported latent channels)
75
  self.register_buffer(
76
  "latent_norm_running_mean",
77
+ torch.zeros((config.latent_channels,), dtype=torch.float32),
78
  )
79
  self.register_buffer(
80
  "latent_norm_running_var",
81
+ torch.ones((config.latent_channels,), dtype=torch.float32),
82
  )
83
 
84
  self.encoder = Encoder(
 
205
  mean, std = self._latent_norm_stats()
206
  return z * std.to(device=z.device) + mean.to(device=z.device)
207
 
208
+ def _patchify(self, z: Tensor) -> Tensor:
209
+ """2x2 patchify: [B, C, H, W] -> [B, 4C, H/2, W/2]."""
210
+ b, c, h, w = z.shape
211
+ z = z.reshape(b, c, h // 2, 2, w // 2, 2)
212
+ z = z.permute(0, 1, 3, 5, 2, 4)
213
+ return z.reshape(b, c * 4, h // 2, w // 2)
214
+
215
+ def _unpatchify(self, z: Tensor) -> Tensor:
216
+ """2x2 unpatchify: [B, 4C, H/2, W/2] -> [B, C, H, W]."""
217
+ b, c, h, w = z.shape
218
+ z = z.reshape(b, c // 4, 2, 2, h, w)
219
+ z = z.permute(0, 1, 4, 2, 5, 3)
220
+ return z.reshape(b, c // 4, h * 2, w * 2)
221
+
222
  def encode(self, images: Tensor) -> Tensor:
223
  """Encode images to whitened latents (posterior mode).
224
 
 
226
  use by downstream latent-space diffusion models.
227
 
228
  Args:
229
+ images: [B, 3, H, W] in [-1, 1], H and W divisible by
230
+ effective_patch_size.
231
 
232
  Returns:
233
+ Whitened latents [B, latent_channels, H/effective_patch, W/effective_patch].
234
  """
235
  try:
236
  model_dtype = next(self.parameters()).dtype
237
  except StopIteration:
238
  model_dtype = torch.float32
239
  z = self.encoder(images.to(dtype=model_dtype))
240
+ if self.config.bottleneck_patchify_mode == "patch_2x2":
241
+ z = self._patchify(z)
242
  return self.whiten(z).to(dtype=model_dtype)
243
 
244
  def encode_posterior(self, images: Tensor) -> EncoderPosterior:
 
267
  ) -> Tensor:
268
  """Decode whitened latents to images via VP diffusion.
269
 
270
+ Latents are dewhitened and (if applicable) unpatchified internally
271
+ before being passed to the decoder.
272
 
273
  Args:
274
+ latents: [B, latent_channels, h, w] whitened encoder latents.
275
+ height: Output image height (divisible by effective_patch_size).
276
+ width: Output image width (divisible by effective_patch_size).
277
  inference_config: Optional inference parameters.
278
 
279
  Returns:
 
289
  except StopIteration:
290
  model_dtype = torch.float32
291
 
292
+ # Dewhiten and unpatchify back to raw encoder scale for the decoder
293
+ latents = self.dewhiten(latents)
294
+ if config.bottleneck_patchify_mode == "patch_2x2":
295
+ latents = self._unpatchify(latents)
296
+ latents = latents.to(dtype=model_dtype)
297
 
298
  if height % config.patch_size != 0 or width % config.patch_size != 0:
299
  raise ValueError(