File size: 1,665 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
"""2D VAE encoder — wraps PixArt SDXL AutoencoderKL encoder half.

Encodes images or mel spectrograms to [B, 4, H/8, W/8] latents.
Same encoder used for images AND audio spectrograms (via MelSpectrogram3Band).
Frozen float32 sidecar (no gradients).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


def load_vae2d(device="cuda", quantize=None):
    """Load the pretrained VAE and return its encoder half as a module.

    The ``vae`` subfolder of ``PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers``
    is loaded in float32, moved to ``device``, put in eval mode, and has
    gradients disabled on all of its parameters before being wrapped.

    Args:
        device: Target passed to ``.to()`` on the loaded VAE.
        quantize: Accepted for call compatibility; not read by this function.

    Returns:
        A ``VAE2DEncoder`` built from the loaded VAE.
    """
    # Deferred import: diffusers is only needed once this loader is called.
    from diffusers import AutoencoderKL

    checkpoint = "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers"
    autoencoder = AutoencoderKL.from_pretrained(
        checkpoint,
        subfolder="vae",
        torch_dtype=torch.float32,
    )
    autoencoder = autoencoder.to(device)

    # Freeze: inference mode for the layers, no gradients on any weight.
    autoencoder.eval()
    autoencoder.requires_grad_(False)

    return VAE2DEncoder(autoencoder)


class VAE2DEncoder(nn.Module):
    """Encoder-only wrapper that maps 2D inputs to sampled, scaled latents.

    Holds references to the ``encoder`` and ``quant_conv`` submodules of the
    VAE it is given. ``forward`` pads the input, runs both submodules, draws a
    reparameterised sample from the resulting Gaussian and multiplies it by
    ``input_scale``.

    Args:
        vae: Object exposing ``encoder`` and ``quant_conv`` callables
            (``load_vae2d`` passes a diffusers ``AutoencoderKL``).
        latent_channels: Number of leading channels of the ``quant_conv``
            output used as the posterior mean; the remaining channels
            parameterise the standard deviation. Defaults to 4.
        input_scale: Factor applied to the sampled latent. Defaults to
            0.18215.
    """

    # Spatial sizes are padded up to a multiple of this before encoding.
    # Assumes the wrapped encoder downsamples by the same factor -- TODO confirm
    # if a different VAE is ever plugged in.
    SPATIAL_MULTIPLE = 8

    def __init__(self, vae, *, latent_channels=4, input_scale=0.18215):
        super().__init__()
        self.encoder = vae.encoder
        self.quant_conv = vae.quant_conv
        self.latent_channels = latent_channels
        # NOTE(review): 0.18215 is the factor commonly used with the SD 1.x
        # VAE; the SDXL VAE config is believed to use 0.13025. Kept as the
        # default so existing latents stay comparable -- confirm which one
        # downstream consumers expect.
        self.input_scale = input_scale

    def forward(self, x):
        """Encode ``x`` and return a sampled latent.

        Args:
            x: Tensor whose last two dimensions are height and width. The
                result is indexed as 4-D, so the encoder output is presumably
                ``[B, C, H, W]`` -- verify against callers.

        Returns:
            ``sample * input_scale``. When the input had to be padded, the
            last two dimensions are cropped to ``max(H // 8, 1)`` and
            ``max(W // 8, 1)``, which drops latent cells that overlap padding
            (except when a dimension is smaller than 8, where one cell is
            kept).
        """
        multiple = self.SPATIAL_MULTIPLE
        height, width = x.shape[-2], x.shape[-1]

        # Amount needed to reach the next multiple (0 when already aligned).
        pad_h = -height % multiple
        pad_w = -width % multiple
        padded = pad_h > 0 or pad_w > 0
        if padded:
            # F.pad order is (left, right, top, bottom): pad right/bottom only.
            x = F.pad(x, (0, pad_w, 0, pad_h))

        moments = self.quant_conv(self.encoder(x))
        mean = moments[:, :self.latent_channels]
        # NOTE(review): diffusers' own posterior treats these channels as a
        # log-variance (std = exp(0.5 * logvar)); softplus is used here
        # instead. Left unchanged to preserve outputs -- confirm intentional.
        std = F.softplus(moments[:, self.latent_channels:])

        latent = torch.distributions.Normal(mean, std).rsample()
        latent = latent * self.input_scale

        if padded:
            out_h = max(height // multiple, 1)
            out_w = max(width // multiple, 1)
            latent = latent[:, :, :out_h, :out_w]

        return latent