""" VAE Wrappers — compatible VAE interfaces for LiquidFlow. Supports two VAE backends: 1. TAESD (Tiny AutoEncoder for SD): < 1M params, extremely fast, perfect for mobile 2. SD-VAE (Stability AI VAE): Higher quality, 84M params, standard for SD pipelines TAESD is the DEFAULT for LiquidFlow — it's designed to be lightweight and fast enough for Colab/Kaggle free tier. Paper reference: "Tiny AutoEncoder for Stable Diffusion" (madebyollin/taesd) Model: madebyollin/taesd (335K downloads on HF) """ import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional class TAESDWrapper: """ Wrapper for Tiny AutoEncoder for Stable Diffusion (TAESD). TAESD properties: - ~1M parameters (vs 84M for SD VAE) - Latent dim: 4 channels @ 8x compression - Extremely fast encode/decode - Works on CPU — no GPU needed - Perfect for Colab/Kaggle free tier Model on HF: madebyollin/taesd """ def __init__(self, device='cpu'): self.device = device self.model = None @staticmethod def is_available(): """Check if TAESD can be loaded.""" try: from diffusers import AutoencoderTiny return True except ImportError: return False @staticmethod def load(device='cpu'): """Load TAESD model.""" from diffusers import AutoencoderTiny model = AutoencoderTiny.from_pretrained( "madebyollin/taesd", torch_dtype=torch.float32, ) model = model.to(device) model.eval() return model @staticmethod def get_latent_shape(image_size): """Get latent spatial size given image size (8x compression).""" return image_size // 8 @staticmethod def encode(vae, x): """ Encode image to latent. Args: vae: TAESD model x: [B, 3, H, W] images in [-1, 1] Returns: z: [B, 4, H/8, W/8] """ with torch.no_grad(): posterior = vae.encode(x).latent_dist z = posterior.sample() z = z * vae.config.scaling_factor return z @staticmethod def decode(vae, z): """ Decode latent to image. Args: vae: TAESD model z: [B, 4, H/8, W/8] Returns: x: [B, 3, H, W] images in [-1, 1] """ with torch.no_grad(): z = z / vae.config.scaling_factor x = vae.decode(z).sample return x class SDVAEWrapper: """ Wrapper for Stability AI VAE (sd-vae-ft-mse). Properties: - ~84M parameters - Latent dim: 4 channels @ 8x compression - Higher quality reconstruction than TAESD - Requires GPU for reasonable speed Model on HF: stabilityai/sd-vae-ft-mse """ def __init__(self, device='cpu'): self.device = device self.model = None @staticmethod def load(device='cpu'): """Load SD VAE model.""" from diffusers import AutoencoderKL model = AutoencoderKL.from_pretrained( "stabilityai/sd-vae-ft-mse", torch_dtype=torch.float32, ) model = model.to(device) model.eval() return model @staticmethod def encode(vae, x): """Encode image to latent.""" with torch.no_grad(): posterior = vae.encode(x).latent_dist z = posterior.sample() z = z * vae.config.scaling_factor return z @staticmethod def decode(vae, z): """Decode latent to image.""" with torch.no_grad(): z = z / vae.config.scaling_factor x = vae.decode(z).sample return x