| """2D VAE encoder — wraps PixArt SDXL AutoencoderKL encoder half. |
| |
| Encodes images or mel spectrograms to [B, 4, H/8, W/8] latents. |
| Same encoder used for images AND audio spectrograms (via MelSpectrogram3Band). |
| Frozen float32 sidecar (no gradients). |
| """ |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
def load_vae2d(device="cuda", quantize=None):
    """Load the PixArt-Sigma SDXL VAE and return its frozen encoder half.

    Args:
        device: Device the kept encoder modules are moved to.
        quantize: Accepted for interface compatibility only. Quantization is
            not implemented; any value other than None triggers a warning and
            is otherwise ignored.

    Returns:
        A ``VAE2DEncoder`` wrapping ``vae.encoder`` and ``vae.quant_conv``.
        It is loaded in float32, in eval mode, with ``requires_grad=False``
        on every parameter.
    """
    import warnings

    from diffusers import AutoencoderKL

    if quantize is not None:
        # Previously this argument was silently dropped; make that visible.
        warnings.warn(
            f"load_vae2d: quantize={quantize!r} is not supported and is ignored; "
            "the VAE encoder runs in float32.",
            stacklevel=2,
        )

    vae = AutoencoderKL.from_pretrained(
        "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
        subfolder="vae",
        torch_dtype=torch.float32,
    )
    vae.eval()
    vae.requires_grad_(False)

    # The wrapper keeps only encoder + quant_conv. Move just those to the
    # target device rather than also allocating the discarded decoder there.
    return VAE2DEncoder(vae).to(device)
|
|
|
|
class VAE2DEncoder(nn.Module):
    """Frozen encoder half of an ``AutoencoderKL`` that produces scaled latents.

    Only ``vae.encoder`` and ``vae.quant_conv`` are kept as submodules.

    Args:
        vae: An ``AutoencoderKL``-like object exposing ``encoder`` and
            ``quant_conv``. ``quant_conv`` is expected to emit
            ``2 * latent_channels`` channels: the posterior mean followed by
            its log-variance.
        scaling_factor: Multiplier applied to the latent. Defaults to 0.18215,
            the value this module has always used.
            NOTE(review): SDXL-VAE configs usually declare
            ``scaling_factor=0.13025`` (``vae.config.scaling_factor``).
            Confirm which value the downstream model was trained with before
            changing the default.
    """

    def __init__(self, vae, scaling_factor=0.18215):
        super().__init__()
        self.encoder = vae.encoder
        self.quant_conv = vae.quant_conv
        # Channels of the latent. The first half of `moments` is the mean and
        # the second half is the log-variance.
        self.latent_channels = 4
        # Multiplier applied to the latent before it is returned.
        self.input_scale = scaling_factor

    def forward(self, x, sample=True):
        """Encode ``x`` into a scaled latent.

        Args:
            x: Tensor whose last two dims are (H, W). Presumably [B, C, H, W],
                with C matching the encoder's input channels (TODO confirm).
            sample: If True (the default and original behaviour), draw a
                reparameterized sample ``mean + std * eps``. If False, return
                the posterior mean, which is deterministic.

        Returns:
            A tensor of ``latent * input_scale`` with spatial size
            ``max(H // 8, 1) x max(W // 8, 1)`` when padding was applied. This
            assumes the encoder downsamples by 8, which the pad/crop logic
            presumes.
        """
        H, W = x.shape[-2], x.shape[-1]
        # Zero-pad H/W up to the next multiple of 8 so the encoder's
        # downsampling produces whole latent pixels.
        pad_h = (8 - H % 8) % 8
        pad_w = (8 - W % 8) % 8
        padded = pad_h > 0 or pad_w > 0
        if padded:
            x = F.pad(x, (0, pad_w, 0, pad_h))

        moments = self.quant_conv(self.encoder(x))
        c = self.latent_channels
        mean = moments[:, :c]
        # The second half is a log-variance (AutoencoderKL convention), not a
        # std. Clamp it so exp() stays finite and strictly positive. The old
        # softplus(logvar)-as-std gave the wrong noise scale and could
        # underflow to 0, which made Normal() raise.
        logvar = torch.clamp(moments[:, c:], -30.0, 20.0)

        if sample:
            std = torch.exp(0.5 * logvar)
            latent = mean + std * torch.randn_like(mean)
        else:
            latent = mean
        latent = latent * self.input_scale

        if padded:
            # Drop latent rows/cols that come from the padded region, but
            # always keep at least one.
            latent = latent[:, :, :max(H // 8, 1), :max(W // 8, 1)]
        return latent
|
|