| """ |
| VAE Wrappers — compatible VAE interfaces for LiquidFlow. |
| |
| Supports two VAE backends: |
| 1. TAESD (Tiny AutoEncoder for SD): < 1M params, extremely fast, perfect for mobile |
| 2. SD-VAE (Stability AI VAE): Higher quality, 84M params, standard for SD pipelines |
| |
| TAESD is the DEFAULT for LiquidFlow — it's designed to be lightweight and |
| fast enough for Colab/Kaggle free tier. |
| |
| Paper reference: "Tiny AutoEncoder for Stable Diffusion" (madebyollin/taesd) |
| Model: madebyollin/taesd (335K downloads on HF) |
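
Quick-start (minimal sketch; assumes `diffusers` is installed):

    vae = TAESDWrapper.load(device='cpu')
    z = TAESDWrapper.encode(vae, images)    # [B, 3, H, W] -> [B, 4, H/8, W/8]
    recon = TAESDWrapper.decode(vae, z)     # back to [B, 3, H, W], values in [-1, 1]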
| """ |

import torch


class TAESDWrapper:
    """
    Wrapper for Tiny AutoEncoder for Stable Diffusion (TAESD).

    TAESD properties:
    - ~1M parameters (vs 84M for SD VAE)
    - Latent dim: 4 channels @ 8x compression
    - Extremely fast encode/decode
    - Works on CPU — no GPU needed
    - Perfect for Colab/Kaggle free tier

    Model on HF: madebyollin/taesd
    """

    def __init__(self, device='cpu'):
        self.device = device
        self.model = None

    @staticmethod
    def is_available():
        """Check if TAESD can be loaded."""
        try:
            from diffusers import AutoencoderTiny
            return True
        except ImportError:
            return False

    @staticmethod
    def load(device='cpu'):
        """Load TAESD model."""
        from diffusers import AutoencoderTiny
        model = AutoencoderTiny.from_pretrained(
            "madebyollin/taesd",
            torch_dtype=torch.float32,
        )
        model = model.to(device)
        model.eval()
        return model

    @staticmethod
    def get_latent_shape(image_size):
        """Get latent spatial size given image size (8x compression)."""
        return image_size // 8

    @staticmethod
    def encode(vae, x):
        """
        Encode image to latent.
        Args:
            vae: TAESD model (diffusers AutoencoderTiny)
            x: [B, 3, H, W] images in [-1, 1]
        Returns:
            z: [B, 4, H/8, W/8] latents
        """
        with torch.no_grad():
            # AutoencoderTiny is deterministic: encode() exposes the latents
            # via .latents — there is no .latent_dist to sample, unlike
            # AutoencoderKL.
            z = vae.encode(x).latents
            # No-op when scaling_factor == 1.0; kept for interface parity
            # with SDVAEWrapper.
            z = z * vae.config.scaling_factor
        return z

    @staticmethod
    def decode(vae, z):
        """
        Decode latent to image.
        Args:
            vae: TAESD model
            z: [B, 4, H/8, W/8]
        Returns:
            x: [B, 3, H, W] images in [-1, 1]
        """
        with torch.no_grad():
            z = z / vae.config.scaling_factor
            x = vae.decode(z).sample
        return x


class SDVAEWrapper:
    """
    Wrapper for Stability AI VAE (sd-vae-ft-mse).

    Properties:
    - ~84M parameters
    - Latent dim: 4 channels @ 8x compression
    - Higher quality reconstruction than TAESD
    - Requires GPU for reasonable speed

    Model on HF: stabilityai/sd-vae-ft-mse
    """

    def __init__(self, device='cpu'):
        self.device = device
        self.model = None

    @staticmethod
    def load(device='cpu'):
        """Load SD VAE model."""
        from diffusers import AutoencoderKL
        model = AutoencoderKL.from_pretrained(
            "stabilityai/sd-vae-ft-mse",
            torch_dtype=torch.float32,
        )
        model = model.to(device)
        model.eval()
        return model

    @staticmethod
    def encode(vae, x):
        """Encode image to latent."""
        with torch.no_grad():
            posterior = vae.encode(x).latent_dist
            z = posterior.sample()
            z = z * vae.config.scaling_factor
        return z

    @staticmethod
    def decode(vae, z):
        """Decode latent to image."""
        with torch.no_grad():
            z = z / vae.config.scaling_factor
            x = vae.decode(z).sample
        return x
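

# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch, not part of the library API).
# Assumes `diffusers` is installed; runs on CPU with a random 256x256 batch,
# so the decoded output is meaningless — it only checks shapes round-trip.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    if not TAESDWrapper.is_available():
        print("diffusers not installed — skipping TAESD smoke test")
    else:
        device = 'cpu'
        vae = TAESDWrapper.load(device=device)

        # Fake batch of 2 RGB images in [-1, 1].
        x = torch.rand(2, 3, 256, 256, device=device) * 2 - 1

        z = TAESDWrapper.encode(vae, x)
        recon = TAESDWrapper.decode(vae, z)

        latent_size = TAESDWrapper.get_latent_shape(256)
        assert z.shape == (2, 4, latent_size, latent_size), z.shape
        assert recon.shape == x.shape, recon.shape
        print(f"TAESD round trip OK: {tuple(x.shape)} -> {tuple(z.shape)} -> {tuple(recon.shape)}")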