"""
VAE Wrappers — interchangeable VAE interfaces for LiquidFlow.

Supports two VAE backends:

1. TAESD (Tiny AutoEncoder for SD): ~1M params, extremely fast, well suited to mobile
2. SD-VAE (Stability AI VAE): higher quality, ~84M params, the standard for SD pipelines

TAESD is the DEFAULT for LiquidFlow — it is lightweight and fast enough for
the Colab/Kaggle free tier.

Reference: "Tiny AutoEncoder for Stable Diffusion" (madebyollin/taesd)
Model: madebyollin/taesd (335K downloads on HF)
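
Example (illustrative sketch; assumes `diffusers` and `torch` are installed,
and weights are downloaded from the HF Hub on first use):

    wrapper = TAESDWrapper if TAESDWrapper.is_available() else SDVAEWrapper
    vae = wrapper.load(device='cpu')
    z = wrapper.encode(vae, images)   # images: [B, 3, H, W] in [-1, 1]
    x = wrapper.decode(vae, z)        # x:      [B, 3, H, W] in [-1, 1]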
"""

import torch


class TAESDWrapper:
"""
Wrapper for Tiny AutoEncoder for Stable Diffusion (TAESD).
TAESD properties:
- ~1M parameters (vs 84M for SD VAE)
- Latent dim: 4 channels @ 8x compression
- Extremely fast encode/decode
- Works on CPU — no GPU needed
- Perfect for Colab/Kaggle free tier
Model on HF: madebyollin/taesd
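
    Example (illustrative):
        vae = TAESDWrapper.load()
        z = TAESDWrapper.encode(vae, images)  # images: [B, 3, H, W] in [-1, 1]
        x = TAESDWrapper.decode(vae, z)       # x:      [B, 3, H, W] in [-1, 1]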
"""

    def __init__(self, device='cpu'):
        self.device = device
        self.model = None

    @staticmethod
    def is_available():
        """Check whether TAESD can be loaded (requires `diffusers`)."""
        try:
            from diffusers import AutoencoderTiny  # noqa: F401
            return True
        except ImportError:
            return False

    @staticmethod
    def load(device='cpu'):
        """Load TAESD from the HF Hub (downloads weights on first call)."""
        from diffusers import AutoencoderTiny
        model = AutoencoderTiny.from_pretrained(
            "madebyollin/taesd",
            torch_dtype=torch.float32,
        )
        model = model.to(device)
        model.eval()
        return model

    @staticmethod
    def get_latent_shape(image_size):
        """Latent spatial size for a given image size (8x compression, e.g. 256 -> 32)."""
        return image_size // 8

    @staticmethod
    def encode(vae, x):
        """
        Encode images to latents.

        Args:
            vae: TAESD model (diffusers AutoencoderTiny)
            x: [B, 3, H, W] images in [-1, 1]

        Returns:
            z: [B, 4, H/8, W/8] latents
        """
        with torch.no_grad():
            # AutoencoderTiny is deterministic: encode() returns latents
            # directly (there is no posterior to sample, unlike AutoencoderKL).
            z = vae.encode(x).latents
            # TAESD's scaling_factor is 1.0, so this is a no-op kept for
            # interface symmetry with SDVAEWrapper.
            z = z * vae.config.scaling_factor
        return z

    @staticmethod
    def decode(vae, z):
        """
        Decode latents to images.

        Args:
            vae: TAESD model (diffusers AutoencoderTiny)
            z: [B, 4, H/8, W/8] latents

        Returns:
            x: [B, 3, H, W] images in [-1, 1]
        """
        with torch.no_grad():
            z = z / vae.config.scaling_factor  # no-op for TAESD (factor 1.0)
            x = vae.decode(z).sample
        return x


class SDVAEWrapper:
"""
Wrapper for Stability AI VAE (sd-vae-ft-mse).
Properties:
- ~84M parameters
- Latent dim: 4 channels @ 8x compression
- Higher quality reconstruction than TAESD
- Requires GPU for reasonable speed
Model on HF: stabilityai/sd-vae-ft-mse
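
    Example (illustrative):
        vae = SDVAEWrapper.load(device='cuda')
        z = SDVAEWrapper.encode(vae, images)  # scaled by vae.config.scaling_factor
        x = SDVAEWrapper.decode(vae, z)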
"""

    def __init__(self, device='cpu'):
        self.device = device
        self.model = None

    @staticmethod
    def load(device='cpu'):
        """Load the SD VAE from the HF Hub (downloads weights on first call)."""
        from diffusers import AutoencoderKL
        model = AutoencoderKL.from_pretrained(
            "stabilityai/sd-vae-ft-mse",
            torch_dtype=torch.float32,
        )
        model = model.to(device)
        model.eval()
        return model

    @staticmethod
    def encode(vae, x):
        """Encode [B, 3, H, W] images in [-1, 1] to scaled [B, 4, H/8, W/8] latents."""
        with torch.no_grad():
            # AutoencoderKL returns a posterior distribution; sample it, then
            # apply the standard SD latent scaling (0.18215 for sd-vae-ft-mse).
            posterior = vae.encode(x).latent_dist
            z = posterior.sample()
            z = z * vae.config.scaling_factor
        return z

    @staticmethod
    def decode(vae, z):
        """Decode scaled latents back to [B, 3, H, W] images in [-1, 1]."""
        with torch.no_grad():
            z = z / vae.config.scaling_factor
            x = vae.decode(z).sample
        return x
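

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the original API).
    # Assumes `diffusers` is installed; downloads TAESD weights from the
    # HF Hub on first run.
    if TAESDWrapper.is_available():
        vae = TAESDWrapper.load(device='cpu')
        x = torch.rand(1, 3, 256, 256) * 2 - 1  # fake image batch in [-1, 1]
        z = TAESDWrapper.encode(vae, x)
        x_rec = TAESDWrapper.decode(vae, z)
        print(f"latent shape: {tuple(z.shape)}")  # expected: (1, 4, 32, 32)
        print(f"reconstruction shape: {tuple(x_rec.shape)}")
    else:
        print("diffusers not installed; skipping TAESD smoke test")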