Upload folder using huggingface_hub
capacitor_diffae/model.py  CHANGED  (+14 -5)
@@ -206,19 +206,23 @@ class CapacitorDiffAE(nn.Module):
         return z * std.to(device=z.device) + mean.to(device=z.device)
 
     def encode(self, images: Tensor) -> Tensor:
-        """Encode images to latents (posterior mode).
+        """Encode images to whitened latents (posterior mode).
+
+        Returns latents whitened using per-channel running stats, ready for
+        use by downstream latent-space diffusion models.
 
         Args:
             images: [B, 3, H, W] in [-1, 1], H and W divisible by patch_size.
 
         Returns:
-            Latents [B, bottleneck_dim, H/patch, W/patch].
+            Whitened latents [B, bottleneck_dim, H/patch, W/patch].
         """
         try:
             model_dtype = next(self.parameters()).dtype
         except StopIteration:
             model_dtype = torch.float32
-        return self.encoder(images.to(dtype=model_dtype))
+        z = self.encoder(images.to(dtype=model_dtype))
+        return self.whiten(z).to(dtype=model_dtype)
 
     def encode_posterior(self, images: Tensor) -> EncoderPosterior:
         """Encode images and return the full posterior (mean + logsnr).
@@ -244,10 +248,12 @@ class CapacitorDiffAE(nn.Module):
         *,
         inference_config: CapacitorDiffAEInferenceConfig | None = None,
     ) -> Tensor:
-        """Decode latents to images via VP diffusion.
+        """Decode whitened latents to images via VP diffusion.
+
+        Latents are dewhitened internally before being passed to the decoder.
 
         Args:
-            latents: [B, bottleneck_dim, h, w] encoder latents.
+            latents: [B, bottleneck_dim, h, w] whitened encoder latents.
             height: Output image height (divisible by patch_size).
             width: Output image width (divisible by patch_size).
             inference_config: Optional inference parameters.
@@ -265,6 +271,9 @@ class CapacitorDiffAE(nn.Module):
         except StopIteration:
             model_dtype = torch.float32
 
+        # Dewhiten back to raw encoder scale for the decoder
+        latents = self.dewhiten(latents).to(dtype=model_dtype)
+
         if height % config.patch_size != 0 or width % config.patch_size != 0:
             raise ValueError(
                 f"height={height} and width={width} must be divisible by "
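
For context on the change itself: the `return z * std.to(device=z.device) + mean.to(device=z.device)` line visible at the top of the first hunk, together with the new docstring's mention of per-channel running stats, suggests that whiten/dewhiten are a pair of per-channel affine maps. A minimal sketch of that shape, assuming buffer names `running_mean` and `running_std` that do not appear in the diff:

import torch
from torch import Tensor, nn

class PerChannelWhitener(nn.Module):
    """Sketch of a whiten/dewhiten pair over [B, C, h, w] latents."""

    def __init__(self, bottleneck_dim: int) -> None:
        super().__init__()
        # Running stats with singleton dims so they broadcast over B, h, w.
        self.register_buffer("running_mean", torch.zeros(1, bottleneck_dim, 1, 1))
        self.register_buffer("running_std", torch.ones(1, bottleneck_dim, 1, 1))

    def whiten(self, z: Tensor) -> Tensor:
        # Raw encoder latents -> roughly zero-mean, unit-variance per channel.
        mean = self.running_mean.to(device=z.device)
        std = self.running_std.to(device=z.device)
        return (z - mean) / std

    def dewhiten(self, z: Tensor) -> Tensor:
        # Inverse map; matches the `z * std + mean` context line in the diff.
        mean = self.running_mean.to(device=z.device)
        std = self.running_std.to(device=z.device)
        return z * std + mean

Registering the stats as buffers keeps them in the state_dict and moves them with the module, which fits the explicit `.to(device=z.device)` casts seen in the diff.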
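
From the caller's side, the net effect of this commit is that whitening becomes invisible: encode hands back whitened latents and decode dewhitens them before running VP diffusion. A hypothetical round-trip, where the constructor, the config object, and a patch_size that divides 256 are all assumptions not shown in the diff:

import torch

model = CapacitorDiffAE(config)  # hypothetical construction; actual API not shown here
model.eval()

# Images in [-1, 1] with H and W divisible by patch_size, per the docstring.
images = torch.rand(2, 3, 256, 256) * 2 - 1

with torch.no_grad():
    latents = model.encode(images)  # whitened [B, bottleneck_dim, H/patch, W/patch]
    recon = model.decode(latents, height=256, width=256)  # dewhitened internally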