File size: 1,665 Bytes
"""2D VAE encoder — wraps the encoder half of the PixArt SDXL AutoencoderKL.
Encodes images or mel spectrograms to [B, 4, H/8, W/8] latents (spatial sizes are rounded down, with a minimum of 1).
The same encoder is used for images AND audio spectrograms (via MelSpectrogram3Band).
Frozen float32 sidecar: every parameter has requires_grad=False.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
def load_vae2d(device="cuda", quantize=None):
    """Load the pretrained VAE and return its encoder half, frozen.

    Args:
        device: Device the float32 weights are moved to after loading.
        quantize: Accepted for call compatibility; not read by this function.

    Returns:
        A ``VAE2DEncoder`` built from the loaded autoencoder, which is in
        eval mode with ``requires_grad`` disabled on every parameter.
    """
    # Local import: diffusers is only imported when this function is called.
    from diffusers import AutoencoderKL

    autoencoder = AutoencoderKL.from_pretrained(
        "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
        subfolder="vae",
        torch_dtype=torch.float32,
    )
    autoencoder = autoencoder.to(device)
    autoencoder.eval()
    # Freeze every weight in one call instead of looping over parameters.
    autoencoder.requires_grad_(False)
    return VAE2DEncoder(autoencoder)
class VAE2DEncoder(nn.Module):
    """Encoder half of an AutoencoderKL-style VAE that emits scaled latents.

    Holds the ``encoder`` and ``quant_conv`` sub-modules of the VAE passed to
    the constructor and samples a latent from the diagonal Gaussian posterior
    they parameterize.

    Attributes:
        encoder: The VAE's encoder module.
        quant_conv: The VAE's ``quant_conv`` module (may be ``None``).
        latent_channels: Number of latent channels (4). The first
            ``latent_channels`` channels of the moments are the mean, the
            remaining channels are the log-variance.
        input_scale: Constant the sampled latent is multiplied by.
    """

    def __init__(self, vae):
        """Keep references to the encoder-side modules of ``vae``.

        Args:
            vae: Object exposing ``encoder`` and ``quant_conv`` attributes
                (e.g. a diffusers ``AutoencoderKL``).
        """
        super().__init__()
        self.encoder = vae.encoder
        self.quant_conv = vae.quant_conv
        self.latent_channels = 4
        # NOTE(review): 0.18215 is the scaling factor usually associated with
        # the SD 1.x VAE; SDXL VAE configs typically specify 0.13025. Left
        # unchanged because downstream consumers may depend on this value —
        # confirm before changing.
        self.input_scale = 0.18215

    def forward(self, x):
        """Encode ``x`` and return a sampled, scaled latent.

        Args:
            x: Input tensor whose last two dimensions are height and width
                (presumably ``[B, C, H, W]`` — confirm against callers).

        Returns:
            Tensor with ``latent_channels`` channels. If the input had to be
            padded, the spatial size is cropped to ``max(H // 8, 1)`` by
            ``max(W // 8, 1)``.
        """
        height, width = x.shape[-2], x.shape[-1]
        # Amount needed to reach the next multiple of 8 (0 if already aligned).
        pad_h = -height % 8
        pad_w = -width % 8
        padded = pad_h > 0 or pad_w > 0
        if padded:
            # Pad on the right / bottom only so the original content stays
            # anchored at the top-left corner.
            x = F.pad(x, (0, pad_w, 0, pad_h))

        moments = self.encoder(x)
        if self.quant_conv is not None:
            moments = self.quant_conv(moments)

        mean = moments[:, :self.latent_channels]
        # The second half of the moments is the log-variance (diffusers'
        # DiagonalGaussianDistribution convention): clamp it for numerical
        # stability and convert to a standard deviation with exp(0.5 * logvar).
        logvar = torch.clamp(moments[:, self.latent_channels:], -30.0, 20.0)
        std = torch.exp(0.5 * logvar)

        # Reparameterized sample from N(mean, std^2), then scale.
        latent = (mean + std * torch.randn_like(mean)) * self.input_scale

        if padded:
            # Drop latent rows/columns that only exist because of padding,
            # but always keep at least one.
            out_h = max(height // 8, 1)
            out_w = max(width // 8, 1)
            latent = latent[:, :, :out_h, :out_w]
        return latent
|