JTriggerFish commited on
Commit ·
e12172b
1
Parent(s): 0f964b6
Fix encode_posterior patchify, input/output size validation for p32 mode
Browse files- fcdm_diffae/model.py +28 -9
fcdm_diffae/model.py
CHANGED
|
@@ -232,6 +232,13 @@ class FCDMDiffAE(nn.Module):
|
|
| 232 |
Returns:
|
| 233 |
Whitened latents [B, latent_channels, H/effective_patch, W/effective_patch].
|
| 234 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
try:
|
| 236 |
model_dtype = next(self.parameters()).dtype
|
| 237 |
except StopIteration:
|
|
@@ -244,17 +251,28 @@ class FCDMDiffAE(nn.Module):
|
|
| 244 |
def encode_posterior(self, images: Tensor) -> EncoderPosterior:
|
| 245 |
"""Encode images and return the full posterior (mean + logsnr).
|
| 246 |
|
|
|
|
|
|
|
|
|
|
| 247 |
Args:
|
| 248 |
-
images: [B, 3, H, W] in [-1, 1], H and W divisible by
|
|
|
|
| 249 |
|
| 250 |
Returns:
|
| 251 |
-
EncoderPosterior with mean and logsnr tensors
|
|
|
|
| 252 |
"""
|
| 253 |
try:
|
| 254 |
model_dtype = next(self.parameters()).dtype
|
| 255 |
except StopIteration:
|
| 256 |
model_dtype = torch.float32
|
| 257 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
|
| 259 |
@torch.no_grad()
|
| 260 |
def decode(
|
|
@@ -289,18 +307,19 @@ class FCDMDiffAE(nn.Module):
|
|
| 289 |
except StopIteration:
|
| 290 |
model_dtype = torch.float32
|
| 291 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
# Dewhiten and unpatchify back to raw encoder scale for the decoder
|
| 293 |
latents = self.dewhiten(latents)
|
| 294 |
if config.bottleneck_patchify_mode == "patch_2x2":
|
| 295 |
latents = self._unpatchify(latents)
|
| 296 |
latents = latents.to(dtype=model_dtype)
|
| 297 |
|
| 298 |
-
if height % config.patch_size != 0 or width % config.patch_size != 0:
|
| 299 |
-
raise ValueError(
|
| 300 |
-
f"height={height} and width={width} must be divisible by "
|
| 301 |
-
f"patch_size={config.patch_size}"
|
| 302 |
-
)
|
| 303 |
-
|
| 304 |
shape = (batch, config.in_channels, height, width)
|
| 305 |
noise = sample_noise(
|
| 306 |
shape,
|
|
|
|
| 232 |
Returns:
|
| 233 |
Whitened latents [B, latent_channels, H/effective_patch, W/effective_patch].
|
| 234 |
"""
|
| 235 |
+
eff_patch = self.config.effective_patch_size
|
| 236 |
+
h, w = int(images.shape[2]), int(images.shape[3])
|
| 237 |
+
if h % eff_patch != 0 or w % eff_patch != 0:
|
| 238 |
+
raise ValueError(
|
| 239 |
+
f"Image height={h} and width={w} must be divisible by "
|
| 240 |
+
f"effective_patch_size={eff_patch}"
|
| 241 |
+
)
|
| 242 |
try:
|
| 243 |
model_dtype = next(self.parameters()).dtype
|
| 244 |
except StopIteration:
|
|
|
|
| 251 |
def encode_posterior(self, images: Tensor) -> EncoderPosterior:
|
| 252 |
"""Encode images and return the full posterior (mean + logsnr).
|
| 253 |
|
| 254 |
+
In patch-32 mode, the posterior is returned in the patchified space
|
| 255 |
+
(512ch at H/32), consistent with encode() and whiten().
|
| 256 |
+
|
| 257 |
Args:
|
| 258 |
+
images: [B, 3, H, W] in [-1, 1], H and W divisible by
|
| 259 |
+
effective_patch_size.
|
| 260 |
|
| 261 |
Returns:
|
| 262 |
+
EncoderPosterior with mean and logsnr tensors in the exported
|
| 263 |
+
latent space.
|
| 264 |
"""
|
| 265 |
try:
|
| 266 |
model_dtype = next(self.parameters()).dtype
|
| 267 |
except StopIteration:
|
| 268 |
model_dtype = torch.float32
|
| 269 |
+
posterior = self.encoder.encode_posterior(images.to(dtype=model_dtype))
|
| 270 |
+
if self.config.bottleneck_patchify_mode == "patch_2x2":
|
| 271 |
+
return EncoderPosterior(
|
| 272 |
+
mean=self._patchify(posterior.mean),
|
| 273 |
+
logsnr=self._patchify(posterior.logsnr),
|
| 274 |
+
)
|
| 275 |
+
return posterior
|
| 276 |
|
| 277 |
@torch.no_grad()
|
| 278 |
def decode(
|
|
|
|
| 307 |
except StopIteration:
|
| 308 |
model_dtype = torch.float32
|
| 309 |
|
| 310 |
+
eff_patch = config.effective_patch_size
|
| 311 |
+
if height % eff_patch != 0 or width % eff_patch != 0:
|
| 312 |
+
raise ValueError(
|
| 313 |
+
f"height={height} and width={width} must be divisible by "
|
| 314 |
+
f"effective_patch_size={eff_patch}"
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
# Dewhiten and unpatchify back to raw encoder scale for the decoder
|
| 318 |
latents = self.dewhiten(latents)
|
| 319 |
if config.bottleneck_patchify_mode == "patch_2x2":
|
| 320 |
latents = self._unpatchify(latents)
|
| 321 |
latents = latents.to(dtype=model_dtype)
|
| 322 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 323 |
shape = (batch, config.in_channels, height, width)
|
| 324 |
noise = sample_noise(
|
| 325 |
shape,
|