"""2D VAE encoder — wraps PixArt SDXL AutoencoderKL encoder half. Encodes images or mel spectrograms to [B, 4, H/8, W/8] latents. Same encoder used for images AND audio spectrograms (via MelSpectrogram3Band). Frozen float32 sidecar (no gradients). """ import torch import torch.nn as nn import torch.nn.functional as F def load_vae2d(device="cuda", quantize=None): from diffusers import AutoencoderKL vae = AutoencoderKL.from_pretrained( "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers", subfolder="vae", torch_dtype=torch.float32, ).to(device) vae.eval() for p in vae.parameters(): p.requires_grad = False return VAE2DEncoder(vae) class VAE2DEncoder(nn.Module): def __init__(self, vae): super().__init__() self.encoder = vae.encoder self.quant_conv = vae.quant_conv self.latent_channels = 4 self.input_scale = 0.18215 def forward(self, x): H, W = x.shape[-2], x.shape[-1] pad_h = (8 - H % 8) % 8 pad_w = (8 - W % 8) % 8 if pad_h > 0 or pad_w > 0: x = F.pad(x, (0, pad_w, 0, pad_h)) h = self.encoder(x) moments = self.quant_conv(h) posterior = torch.distributions.Normal( moments[:, :self.latent_channels], torch.nn.functional.softplus(moments[:, self.latent_channels:]) ) latent = posterior.rsample() latent = latent * self.input_scale if pad_h > 0 or pad_w > 0: out_h = H // 8 if H >= 8 else 1 out_w = W // 8 if W >= 8 else 1 latent = latent[:, :, :out_h, :out_w] return latent