krystv committed
Commit b9e8cb3 · verified · 1 Parent(s): 40a4412

Upload liquid_flow/vae_wrapper.py

Files changed (1): liquid_flow/vae_wrapper.py (+141, -0)
liquid_flow/vae_wrapper.py ADDED
@@ -0,0 +1,141 @@
+ """
+ VAE wrappers: compatible VAE interfaces for LiquidFlow.
+
+ Supports two VAE backends:
+ 1. TAESD (Tiny AutoEncoder for SD): ~1M params, extremely fast, suited to mobile
+ 2. SD-VAE (Stability AI VAE): higher quality, ~84M params, standard for SD pipelines
+
+ TAESD is the DEFAULT for LiquidFlow: it is lightweight and fast enough for
+ the Colab/Kaggle free tier.
+
+ Reference: "Tiny AutoEncoder for Stable Diffusion" (madebyollin/taesd)
+ Model: madebyollin/taesd (335K downloads on HF)
+ """
+
+ import torch
+
+
+ class TAESDWrapper:
+     """
+     Wrapper for the Tiny AutoEncoder for Stable Diffusion (TAESD).
+
+     TAESD properties:
+     - ~1M parameters (vs ~84M for the SD VAE)
+     - Latent space: 4 channels at 8x spatial compression
+     - Extremely fast encode/decode
+     - Works on CPU; no GPU needed
+     - Suited to the Colab/Kaggle free tier
+
+     Model on HF: madebyollin/taesd
+     """
+
+     def __init__(self, device='cpu'):
+         self.device = device
+         self.model = None
+
+     @staticmethod
+     def is_available():
+         """Check whether TAESD can be loaded (i.e. diffusers is installed)."""
+         try:
+             from diffusers import AutoencoderTiny  # noqa: F401
+             return True
+         except ImportError:
+             return False
+
+     @staticmethod
+     def load(device='cpu'):
+         """Load the TAESD model from the Hugging Face Hub."""
+         from diffusers import AutoencoderTiny
+         model = AutoencoderTiny.from_pretrained(
+             "madebyollin/taesd",
+             torch_dtype=torch.float32,
+         )
+         model = model.to(device)
+         model.eval()
+         return model
+
+     @staticmethod
+     def get_latent_shape(image_size):
+         """Latent spatial size for a given image size (8x compression)."""
+         return image_size // 8
+
+     @staticmethod
+     def encode(vae, x):
+         """
+         Encode images to latents.
+         Args:
+             vae: TAESD model
+             x: [B, 3, H, W] images in [-1, 1]
+         Returns:
+             z: [B, 4, H/8, W/8] latents
+         """
+         with torch.no_grad():
+             # AutoencoderTiny is deterministic: encode() returns .latents
+             # directly, not a .latent_dist posterior like AutoencoderKL.
+             z = vae.encode(x).latents
+             z = z * vae.config.scaling_factor  # 1.0 for TAESD; kept for parity
+         return z
+
+     @staticmethod
+     def decode(vae, z):
+         """
+         Decode latents to images.
+         Args:
+             vae: TAESD model
+             z: [B, 4, H/8, W/8] latents
+         Returns:
+             x: [B, 3, H, W] images in [-1, 1]
+         """
+         with torch.no_grad():
+             z = z / vae.config.scaling_factor
+             x = vae.decode(z).sample
+         return x
+
+
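For orientation, here is a minimal round-trip sketch using the TAESDWrapper committed above (a hypothetical smoke test: it assumes diffusers is installed and substitutes a random batch for real images):

import torch
from liquid_flow.vae_wrapper import TAESDWrapper

# Load TAESD on CPU (it is fast enough without a GPU).
vae = TAESDWrapper.load(device="cpu")

# Dummy batch standing in for real images, scaled to [-1, 1].
x = torch.rand(2, 3, 256, 256) * 2 - 1

z = TAESDWrapper.encode(vae, x)                           # [2, 4, 32, 32]
assert z.shape[-1] == TAESDWrapper.get_latent_shape(256)  # 8x compression

x_hat = TAESDWrapper.decode(vae, z)                       # [2, 3, 256, 256]
print(z.shape, x_hat.shape)
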
+ class SDVAEWrapper:
+     """
+     Wrapper for the Stability AI VAE (sd-vae-ft-mse).
+
+     Properties:
+     - ~84M parameters
+     - Latent space: 4 channels at 8x spatial compression
+     - Higher-quality reconstruction than TAESD
+     - Needs a GPU for reasonable speed
+
+     Model on HF: stabilityai/sd-vae-ft-mse
+     """
+
+     def __init__(self, device='cpu'):
+         self.device = device
+         self.model = None
+
+     @staticmethod
+     def load(device='cpu'):
+         """Load the SD VAE model from the Hugging Face Hub."""
+         from diffusers import AutoencoderKL
+         model = AutoencoderKL.from_pretrained(
+             "stabilityai/sd-vae-ft-mse",
+             torch_dtype=torch.float32,
+         )
+         model = model.to(device)
+         model.eval()
+         return model
+
+     @staticmethod
+     def encode(vae, x):
+         """Encode [B, 3, H, W] images in [-1, 1] to [B, 4, H/8, W/8] latents."""
+         with torch.no_grad():
+             # AutoencoderKL returns a posterior; sample from it, then scale.
+             posterior = vae.encode(x).latent_dist
+             z = posterior.sample()
+             z = z * vae.config.scaling_factor
+         return z
+
+     @staticmethod
+     def decode(vae, z):
+         """Decode [B, 4, H/8, W/8] latents back to [B, 3, H, W] images in [-1, 1]."""
+         with torch.no_grad():
+             z = z / vae.config.scaling_factor
+             x = vae.decode(z).sample
+         return x
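Both wrappers expose the same static load/encode/decode interface, so a caller can switch backends with a small helper. A sketch under that assumption (the load_vae helper and its prefer_quality flag are illustrative, not part of the committed file):

import torch
from liquid_flow.vae_wrapper import TAESDWrapper, SDVAEWrapper

def load_vae(prefer_quality=False, device=None):
    # Default to TAESD, the LiquidFlow default backend; pick SD-VAE only
    # when quality is preferred and a GPU is available, since the 84M-param
    # SD-VAE is slow on CPU. Both backends need diffusers installed.
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    if prefer_quality and device == "cuda":
        return SDVAEWrapper.load(device), SDVAEWrapper
    if not TAESDWrapper.is_available():
        raise RuntimeError("diffusers is required for both VAE backends")
    return TAESDWrapper.load(device), TAESDWrapper

vae, wrapper = load_vae()
images = torch.rand(1, 3, 256, 256) * 2 - 1  # stand-in for real data
z = wrapper.encode(vae, images)
x_hat = wrapper.decode(vae, z)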