File size: 3,823 Bytes
b9e8cb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""
VAE Wrappers — compatible VAE interfaces for LiquidFlow.

Supports two VAE backends:
1. TAESD (Tiny AutoEncoder for SD): < 1M params, extremely fast, perfect for mobile
2. SD-VAE (Stability AI VAE): Higher quality, 84M params, standard for SD pipelines

TAESD is the DEFAULT for LiquidFlow — it's designed to be lightweight and
fast enough for Colab/Kaggle free tier.

Paper reference: "Tiny AutoEncoder for Stable Diffusion" (madebyollin/taesd)
Model: madebyollin/taesd (335K downloads on HF)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional


class TAESDWrapper:
    """
    Wrapper for Tiny AutoEncoder for Stable Diffusion (TAESD).
    
    TAESD properties:
    - ~1M parameters (vs 84M for SD VAE)
    - Latent dim: 4 channels @ 8x compression
    - Extremely fast encode/decode
    - Works on CPU — no GPU needed
    - Perfect for Colab/Kaggle free tier
    
    Model on HF: madebyollin/taesd
    
    Note: all heavy-lifting methods are @staticmethod and take the loaded
    diffusers model explicitly, so callers may manage the model lifecycle
    themselves (the instance attributes are optional convenience state).
    """
    
    def __init__(self, device='cpu'):
        # Target device string (e.g. 'cpu', 'cuda'); model is attached
        # separately — load() is a staticmethod and does not set self.model.
        self.device = device
        self.model = None
        
    @staticmethod
    def is_available():
        """Return True if diffusers' AutoencoderTiny can be imported."""
        try:
            from diffusers import AutoencoderTiny  # noqa: F401
            return True
        except ImportError:
            return False
    
    @staticmethod
    def load(device='cpu'):
        """Load the TAESD model from the HF hub onto `device` in eval mode."""
        from diffusers import AutoencoderTiny
        model = AutoencoderTiny.from_pretrained(
            "madebyollin/taesd",
            torch_dtype=torch.float32,
        )
        model = model.to(device)
        model.eval()
        return model
    
    @staticmethod
    def get_latent_shape(image_size):
        """Get latent spatial size given image size (8x compression)."""
        return image_size // 8
    
    @staticmethod
    def encode(vae, x):
        """
        Encode image to latent.
        Args:
            vae: TAESD model
            x: [B, 3, H, W] images in [-1, 1]
        Returns:
            z: [B, 4, H/8, W/8], multiplied by the model's scaling_factor
        """
        with torch.no_grad():
            enc = vae.encode(x)
            # BUG FIX: diffusers' AutoencoderTiny is deterministic — its
            # encode() returns an output with `.latents`, not a
            # `.latent_dist` posterior (that is AutoencoderKL's API).
            # Fall back to the KL-style posterior so either backend works.
            if hasattr(enc, "latents"):
                z = enc.latents
            else:
                z = enc.latent_dist.sample()
            z = z * vae.config.scaling_factor
        return z
    
    @staticmethod
    def decode(vae, z):
        """
        Decode latent to image.
        Args:
            vae: TAESD model  
            z: [B, 4, H/8, W/8] scaled latents (as produced by encode())
        Returns:
            x: [B, 3, H, W] images in [-1, 1]
        """
        with torch.no_grad():
            z = z / vae.config.scaling_factor
            x = vae.decode(z).sample
        return x


class SDVAEWrapper:
    """
    Wrapper for Stability AI VAE (sd-vae-ft-mse).
    
    Properties:
    - ~84M parameters
    - Latent dim: 4 channels @ 8x compression
    - Higher quality reconstruction than TAESD
    - Requires GPU for reasonable speed
    
    Model on HF: stabilityai/sd-vae-ft-mse
    """
    
    def __init__(self, device='cpu'):
        # Target device string; the model itself is produced by load(),
        # which is a staticmethod and does not populate self.model.
        self.device = device
        self.model = None
    
    @staticmethod
    def load(device='cpu'):
        """Load the SD VAE from the HF hub onto `device` in eval mode."""
        from diffusers import AutoencoderKL
        vae = AutoencoderKL.from_pretrained(
            "stabilityai/sd-vae-ft-mse",
            torch_dtype=torch.float32,
        ).to(device)
        vae.eval()
        return vae
    
    @staticmethod
    def encode(vae, x):
        """Encode images to latents scaled by the model's scaling_factor."""
        with torch.no_grad():
            latents = vae.encode(x).latent_dist.sample()
            return latents * vae.config.scaling_factor
    
    @staticmethod
    def decode(vae, z):
        """Decode scaled latents (as produced by encode()) back to images."""
        with torch.no_grad():
            return vae.decode(z / vae.config.scaling_factor).sample