JTriggerFish committed on
Commit
e12172b
·
1 Parent(s): 0f964b6

Fix encode_posterior patchify, input/output size validation for p32 mode

Browse files
Files changed (1) hide show
  1. fcdm_diffae/model.py +28 -9
fcdm_diffae/model.py CHANGED
@@ -232,6 +232,13 @@ class FCDMDiffAE(nn.Module):
232
  Returns:
233
  Whitened latents [B, latent_channels, H/effective_patch, W/effective_patch].
234
  """
 
 
 
 
 
 
 
235
  try:
236
  model_dtype = next(self.parameters()).dtype
237
  except StopIteration:
@@ -244,17 +251,28 @@ class FCDMDiffAE(nn.Module):
244
  def encode_posterior(self, images: Tensor) -> EncoderPosterior:
245
  """Encode images and return the full posterior (mean + logsnr).
246
 
 
 
 
247
  Args:
248
- images: [B, 3, H, W] in [-1, 1], H and W divisible by patch_size.
 
249
 
250
  Returns:
251
- EncoderPosterior with mean and logsnr tensors.
 
252
  """
253
  try:
254
  model_dtype = next(self.parameters()).dtype
255
  except StopIteration:
256
  model_dtype = torch.float32
257
- return self.encoder.encode_posterior(images.to(dtype=model_dtype))
 
 
 
 
 
 
258
 
259
  @torch.no_grad()
260
  def decode(
@@ -289,18 +307,19 @@ class FCDMDiffAE(nn.Module):
289
  except StopIteration:
290
  model_dtype = torch.float32
291
 
 
 
 
 
 
 
 
292
  # Dewhiten and unpatchify back to raw encoder scale for the decoder
293
  latents = self.dewhiten(latents)
294
  if config.bottleneck_patchify_mode == "patch_2x2":
295
  latents = self._unpatchify(latents)
296
  latents = latents.to(dtype=model_dtype)
297
 
298
- if height % config.patch_size != 0 or width % config.patch_size != 0:
299
- raise ValueError(
300
- f"height={height} and width={width} must be divisible by "
301
- f"patch_size={config.patch_size}"
302
- )
303
-
304
  shape = (batch, config.in_channels, height, width)
305
  noise = sample_noise(
306
  shape,
 
232
  Returns:
233
  Whitened latents [B, latent_channels, H/effective_patch, W/effective_patch].
234
  """
235
+ eff_patch = self.config.effective_patch_size
236
+ h, w = int(images.shape[2]), int(images.shape[3])
237
+ if h % eff_patch != 0 or w % eff_patch != 0:
238
+ raise ValueError(
239
+ f"Image height={h} and width={w} must be divisible by "
240
+ f"effective_patch_size={eff_patch}"
241
+ )
242
  try:
243
  model_dtype = next(self.parameters()).dtype
244
  except StopIteration:
 
251
  def encode_posterior(self, images: Tensor) -> EncoderPosterior:
252
  """Encode images and return the full posterior (mean + logsnr).
253
 
254
+ In patch-32 mode, the posterior is returned in the patchified space
255
+ (512ch at H/32), consistent with encode() and whiten().
256
+
257
  Args:
258
+ images: [B, 3, H, W] in [-1, 1], H and W divisible by
259
+ effective_patch_size.
260
 
261
  Returns:
262
+ EncoderPosterior with mean and logsnr tensors in the exported
263
+ latent space.
264
  """
265
  try:
266
  model_dtype = next(self.parameters()).dtype
267
  except StopIteration:
268
  model_dtype = torch.float32
269
+ posterior = self.encoder.encode_posterior(images.to(dtype=model_dtype))
270
+ if self.config.bottleneck_patchify_mode == "patch_2x2":
271
+ return EncoderPosterior(
272
+ mean=self._patchify(posterior.mean),
273
+ logsnr=self._patchify(posterior.logsnr),
274
+ )
275
+ return posterior
276
 
277
  @torch.no_grad()
278
  def decode(
 
307
  except StopIteration:
308
  model_dtype = torch.float32
309
 
310
+ eff_patch = config.effective_patch_size
311
+ if height % eff_patch != 0 or width % eff_patch != 0:
312
+ raise ValueError(
313
+ f"height={height} and width={width} must be divisible by "
314
+ f"effective_patch_size={eff_patch}"
315
+ )
316
+
317
  # Dewhiten and unpatchify back to raw encoder scale for the decoder
318
  latents = self.dewhiten(latents)
319
  if config.bottleneck_patchify_mode == "patch_2x2":
320
  latents = self._unpatchify(latents)
321
  latents = latents.to(dtype=model_dtype)
322
 
 
 
 
 
 
 
323
  shape = (batch, config.in_channels, height, width)
324
  noise = sample_noise(
325
  shape,