recoilme commited on
Commit
47237d5
·
verified ·
1 Parent(s): 46a735a

Upload folder using huggingface_hub

Browse files
.ipynb_checkpoints/vae_comp-checkpoint.ipynb ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "b3b23a40-8354-4287-bac2-32f9d084fff3",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stderr",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "/usr/local/lib/python3.11/dist-packages/huggingface_hub/utils/_validators.py:202: UserWarning: The `local_dir_use_symlinks` argument is deprecated and ignored in `hf_hub_download`. Downloading to a local directory does not use symlinks anymore.\n",
14
+ " warnings.warn(\n"
15
+ ]
16
+ },
17
+ {
18
+ "data": {
19
+ "application/vnd.jupyter.widget-view+json": {
20
+ "model_id": "96d38ff0fa134b02a5a21c96bdfd36b5",
21
+ "version_major": 2,
22
+ "version_minor": 0
23
+ },
24
+ "text/plain": [
25
+ "vae/config.json: 0%| | 0.00/752 [00:00<?, ?B/s]"
26
+ ]
27
+ },
28
+ "metadata": {},
29
+ "output_type": "display_data"
30
+ },
31
+ {
32
+ "data": {
33
+ "application/vnd.jupyter.widget-view+json": {
34
+ "model_id": "0a44c60705d44f58b5a07ead45936327",
35
+ "version_major": 2,
36
+ "version_minor": 0
37
+ },
38
+ "text/plain": [
39
+ "vae/diffusion_pytorch_model.safetensors: 0%| | 0.00/191M [00:00<?, ?B/s]"
40
+ ]
41
+ },
42
+ "metadata": {},
43
+ "output_type": "display_data"
44
+ },
45
+ {
46
+ "name": "stdout",
47
+ "output_type": "stream",
48
+ "text": [
49
+ "sdxs_vae log-variance: 1.840\n"
50
+ ]
51
+ },
52
+ {
53
+ "name": "stderr",
54
+ "output_type": "stream",
55
+ "text": [
56
+ "The config attributes {'block_out_channels': [128, 128, 256, 512, 512], 'force_upcast': False} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.\n"
57
+ ]
58
+ },
59
+ {
60
+ "name": "stdout",
61
+ "output_type": "stream",
62
+ "text": [
63
+ "vae9 log-variance: 1.840\n",
64
+ "Готово\n"
65
+ ]
66
+ }
67
+ ],
68
+ "source": [
69
+ "import torch\n",
70
+ "from PIL import Image\n",
71
+ "from diffusers import AutoencoderKL,AsymmetricAutoencoderKL\n",
72
+ "from torchvision.transforms.functional import to_pil_image\n",
73
+ "import matplotlib.pyplot as plt\n",
74
+ "import os\n",
75
+ "from torchvision.transforms import ToTensor, Normalize, CenterCrop\n",
76
+ "\n",
77
+ "# путь к вашей картинке\n",
78
+ "IMG_PATH = \"1234567890.png\"\n",
79
+ "OUT_DIR = \"vaetest\"\n",
80
+ "device = \"cuda\"\n",
81
+ "dtype = torch.float32 # ← единый float32\n",
82
+ "os.makedirs(OUT_DIR, exist_ok=True)\n",
83
+ "\n",
84
+ "# список VAE\n",
85
+ "VAES = {\n",
86
+ " #\"sdxl\": \"madebyollin/sdxl-vae-fp16-fix\",\n",
87
+ " \"sdxs_vae\": \"AiArtLab/sdxs-1b\",\n",
88
+ " #\"vae8\": \"/workspace/simplevae2x/vae8\",\n",
89
+ " \"vae9\": \"/workspace/simplevae2x/vae9\"\n",
90
+ "}\n",
91
+ "\n",
92
+ "def load_image(path):\n",
93
+ " img = Image.open(path).convert('RGB')\n",
94
+ " # обрезаем до кратности 8\n",
95
+ " w, h = img.size\n",
96
+ " img = CenterCrop((h // 8 * 8, w // 8 * 8))(img)\n",
97
+ " tensor = ToTensor()(img).unsqueeze(0) # [0,1]\n",
98
+ " tensor = Normalize(mean=[0.5]*3, std=[0.5]*3)(tensor) # [-1,1]\n",
99
+ " return img, tensor.to(device, dtype=dtype)\n",
100
+ "\n",
101
+ "# обратно в PIL\n",
102
+ "def tensor_to_img(t):\n",
103
+ " t = (t * 0.5 + 0.5).clamp(0, 1)\n",
104
+ " return to_pil_image(t[0])\n",
105
+ "\n",
106
+ "def logvariance(latents):\n",
107
+ " \"\"\"Возвращает лог-дисперсию по всем элементам.\"\"\"\n",
108
+ " return torch.log(latents.var() + 1e-8).item()\n",
109
+ "\n",
110
+ "def plot_latent_distribution(latents, title, save_path):\n",
111
+ " \"\"\"Гистограмма + QQ-plot.\"\"\"\n",
112
+ " lat = latents.detach().cpu().numpy().flatten()\n",
113
+ " plt.figure(figsize=(10, 4))\n",
114
+ "\n",
115
+ " # гистограмма\n",
116
+ " plt.subplot(1, 2, 1)\n",
117
+ " plt.hist(lat, bins=100, density=True, alpha=0.7, color='steelblue')\n",
118
+ " plt.title(f\"{title} histogram\")\n",
119
+ " plt.xlabel(\"latent value\")\n",
120
+ " plt.ylabel(\"density\")\n",
121
+ "\n",
122
+ " # QQ-plot\n",
123
+ " from scipy.stats import probplot\n",
124
+ " plt.subplot(1, 2, 2)\n",
125
+ " probplot(lat, dist=\"norm\", plot=plt)\n",
126
+ " plt.title(f\"{title} QQ-plot\")\n",
127
+ "\n",
128
+ " plt.tight_layout()\n",
129
+ " plt.savefig(save_path)\n",
130
+ " plt.close()\n",
131
+ "\n",
132
+ "for name, repo in VAES.items():\n",
133
+ " if name==\"sdxs_vae\":\n",
134
+ " vae = AsymmetricAutoencoderKL.from_pretrained(repo, subfolder=\"vae\", torch_dtype=dtype).to(device)\n",
135
+ " else:\n",
136
+ " vae = AsymmetricAutoencoderKL.from_pretrained(repo, torch_dtype=dtype).to(device)#, subfolder=\"vae\", variant=\"fp16\"\n",
137
+ "\n",
138
+ " cfg = vae.config\n",
139
+ " scale = getattr(cfg, \"scaling_factor\", 1.)\n",
140
+ " shift = getattr(cfg, \"shift_factor\", 0.0)\n",
141
+ " mean = getattr(cfg, \"latents_mean\", None)\n",
142
+ " std = getattr(cfg, \"latents_std\", None)\n",
143
+ "\n",
144
+ " C = 4 # 4 для SDXL\n",
145
+ " if mean is not None:\n",
146
+ " mean = torch.tensor(mean, device=device, dtype=dtype).view(1, C, 1, 1)\n",
147
+ " if std is not None:\n",
148
+ " std = torch.tensor(std, device=device, dtype=dtype).view(1, C, 1, 1)\n",
149
+ " if shift is not None:\n",
150
+ " shift = torch.tensor(shift, device=device, dtype=dtype)\n",
151
+ " else:\n",
152
+ " shift = 0.0 \n",
153
+ "\n",
154
+ " scale = torch.tensor(scale, device=device, dtype=dtype)\n",
155
+ "\n",
156
+ " img, x = load_image(IMG_PATH)\n",
157
+ " img.save(os.path.join(OUT_DIR, f\"original.png\"))\n",
158
+ "\n",
159
+ " with torch.no_grad():\n",
160
+ " # encode\n",
161
+ " latents = vae.encode(x).latent_dist.sample().to(dtype)\n",
162
+ " if mean is not None and std is not None:\n",
163
+ " latents = (latents - mean) / std\n",
164
+ " latents = latents * scale + shift\n",
165
+ "\n",
166
+ " lv = logvariance(latents)\n",
167
+ " print(f\"{name} log-variance: {lv:.3f}\")\n",
168
+ "\n",
169
+ " # график\n",
170
+ " plot_latent_distribution(latents, f\"{name}_latents\",\n",
171
+ " os.path.join(OUT_DIR, f\"dist_{name}.png\"))\n",
172
+ "\n",
173
+ " # decode\n",
174
+ " latents = (latents - shift) / scale\n",
175
+ " if mean is not None and std is not None:\n",
176
+ " latents = latents * std + mean\n",
177
+ " rec = vae.decode(latents).sample\n",
178
+ "\n",
179
+ " tensor_to_img(rec).save(os.path.join(OUT_DIR, f\"decoded_{name}.png\"))\n",
180
+ "\n",
181
+ "print(\"Готово\")"
182
+ ]
183
+ },
184
+ {
185
+ "cell_type": "code",
186
+ "execution_count": null,
187
+ "id": "200b72ab-1978-4d71-9aba-b1ef97cf0b27",
188
+ "metadata": {},
189
+ "outputs": [],
190
+ "source": []
191
+ }
192
+ ],
193
+ "metadata": {
194
+ "kernelspec": {
195
+ "display_name": "Python 3 (ipykernel)",
196
+ "language": "python",
197
+ "name": "python3"
198
+ },
199
+ "language_info": {
200
+ "codemirror_mode": {
201
+ "name": "ipython",
202
+ "version": 3
203
+ },
204
+ "file_extension": ".py",
205
+ "mimetype": "text/x-python",
206
+ "name": "python",
207
+ "nbconvert_exporter": "python",
208
+ "pygments_lexer": "ipython3",
209
+ "version": "3.11.10"
210
+ }
211
+ },
212
+ "nbformat": 4,
213
+ "nbformat_minor": 5
214
+ }
vae_comp.ipynb ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "b3b23a40-8354-4287-bac2-32f9d084fff3",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stderr",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "/usr/local/lib/python3.11/dist-packages/huggingface_hub/utils/_validators.py:202: UserWarning: The `local_dir_use_symlinks` argument is deprecated and ignored in `hf_hub_download`. Downloading to a local directory does not use symlinks anymore.\n",
14
+ " warnings.warn(\n"
15
+ ]
16
+ },
17
+ {
18
+ "data": {
19
+ "application/vnd.jupyter.widget-view+json": {
20
+ "model_id": "96d38ff0fa134b02a5a21c96bdfd36b5",
21
+ "version_major": 2,
22
+ "version_minor": 0
23
+ },
24
+ "text/plain": [
25
+ "vae/config.json: 0%| | 0.00/752 [00:00<?, ?B/s]"
26
+ ]
27
+ },
28
+ "metadata": {},
29
+ "output_type": "display_data"
30
+ },
31
+ {
32
+ "data": {
33
+ "application/vnd.jupyter.widget-view+json": {
34
+ "model_id": "0a44c60705d44f58b5a07ead45936327",
35
+ "version_major": 2,
36
+ "version_minor": 0
37
+ },
38
+ "text/plain": [
39
+ "vae/diffusion_pytorch_model.safetensors: 0%| | 0.00/191M [00:00<?, ?B/s]"
40
+ ]
41
+ },
42
+ "metadata": {},
43
+ "output_type": "display_data"
44
+ },
45
+ {
46
+ "name": "stdout",
47
+ "output_type": "stream",
48
+ "text": [
49
+ "sdxs_vae log-variance: 1.840\n"
50
+ ]
51
+ },
52
+ {
53
+ "name": "stderr",
54
+ "output_type": "stream",
55
+ "text": [
56
+ "The config attributes {'block_out_channels': [128, 128, 256, 512, 512], 'force_upcast': False} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.\n"
57
+ ]
58
+ },
59
+ {
60
+ "name": "stdout",
61
+ "output_type": "stream",
62
+ "text": [
63
+ "vae9 log-variance: 1.840\n",
64
+ "Готово\n"
65
+ ]
66
+ }
67
+ ],
68
+ "source": [
69
+ "import torch\n",
70
+ "from PIL import Image\n",
71
+ "from diffusers import AutoencoderKL,AsymmetricAutoencoderKL\n",
72
+ "from torchvision.transforms.functional import to_pil_image\n",
73
+ "import matplotlib.pyplot as plt\n",
74
+ "import os\n",
75
+ "from torchvision.transforms import ToTensor, Normalize, CenterCrop\n",
76
+ "\n",
77
+ "# путь к вашей картинке\n",
78
+ "IMG_PATH = \"1234567890.png\"\n",
79
+ "OUT_DIR = \"vaetest\"\n",
80
+ "device = \"cuda\"\n",
81
+ "dtype = torch.float32 # ← единый float32\n",
82
+ "os.makedirs(OUT_DIR, exist_ok=True)\n",
83
+ "\n",
84
+ "# список VAE\n",
85
+ "VAES = {\n",
86
+ " #\"sdxl\": \"madebyollin/sdxl-vae-fp16-fix\",\n",
87
+ " \"sdxs_vae\": \"AiArtLab/sdxs-1b\",\n",
88
+ " #\"vae8\": \"/workspace/simplevae2x/vae8\",\n",
89
+ " \"vae9\": \"/workspace/simplevae2x/vae9\"\n",
90
+ "}\n",
91
+ "\n",
92
+ "def load_image(path):\n",
93
+ " img = Image.open(path).convert('RGB')\n",
94
+ " # обрезаем до кратности 8\n",
95
+ " w, h = img.size\n",
96
+ " img = CenterCrop((h // 8 * 8, w // 8 * 8))(img)\n",
97
+ " tensor = ToTensor()(img).unsqueeze(0) # [0,1]\n",
98
+ " tensor = Normalize(mean=[0.5]*3, std=[0.5]*3)(tensor) # [-1,1]\n",
99
+ " return img, tensor.to(device, dtype=dtype)\n",
100
+ "\n",
101
+ "# обратно в PIL\n",
102
+ "def tensor_to_img(t):\n",
103
+ " t = (t * 0.5 + 0.5).clamp(0, 1)\n",
104
+ " return to_pil_image(t[0])\n",
105
+ "\n",
106
+ "def logvariance(latents):\n",
107
+ " \"\"\"Возвращает лог-дисперсию по всем элементам.\"\"\"\n",
108
+ " return torch.log(latents.var() + 1e-8).item()\n",
109
+ "\n",
110
+ "def plot_latent_distribution(latents, title, save_path):\n",
111
+ " \"\"\"Гистограмма + QQ-plot.\"\"\"\n",
112
+ " lat = latents.detach().cpu().numpy().flatten()\n",
113
+ " plt.figure(figsize=(10, 4))\n",
114
+ "\n",
115
+ " # гистограмма\n",
116
+ " plt.subplot(1, 2, 1)\n",
117
+ " plt.hist(lat, bins=100, density=True, alpha=0.7, color='steelblue')\n",
118
+ " plt.title(f\"{title} histogram\")\n",
119
+ " plt.xlabel(\"latent value\")\n",
120
+ " plt.ylabel(\"density\")\n",
121
+ "\n",
122
+ " # QQ-plot\n",
123
+ " from scipy.stats import probplot\n",
124
+ " plt.subplot(1, 2, 2)\n",
125
+ " probplot(lat, dist=\"norm\", plot=plt)\n",
126
+ " plt.title(f\"{title} QQ-plot\")\n",
127
+ "\n",
128
+ " plt.tight_layout()\n",
129
+ " plt.savefig(save_path)\n",
130
+ " plt.close()\n",
131
+ "\n",
132
+ "for name, repo in VAES.items():\n",
133
+ " if name==\"sdxs_vae\":\n",
134
+ " vae = AsymmetricAutoencoderKL.from_pretrained(repo, subfolder=\"vae\", torch_dtype=dtype).to(device)\n",
135
+ " else:\n",
136
+ " vae = AsymmetricAutoencoderKL.from_pretrained(repo, torch_dtype=dtype).to(device)#, subfolder=\"vae\", variant=\"fp16\"\n",
137
+ "\n",
138
+ " cfg = vae.config\n",
139
+ " scale = getattr(cfg, \"scaling_factor\", 1.)\n",
140
+ " shift = getattr(cfg, \"shift_factor\", 0.0)\n",
141
+ " mean = getattr(cfg, \"latents_mean\", None)\n",
142
+ " std = getattr(cfg, \"latents_std\", None)\n",
143
+ "\n",
144
+ " C = 4 # 4 для SDXL\n",
145
+ " if mean is not None:\n",
146
+ " mean = torch.tensor(mean, device=device, dtype=dtype).view(1, C, 1, 1)\n",
147
+ " if std is not None:\n",
148
+ " std = torch.tensor(std, device=device, dtype=dtype).view(1, C, 1, 1)\n",
149
+ " if shift is not None:\n",
150
+ " shift = torch.tensor(shift, device=device, dtype=dtype)\n",
151
+ " else:\n",
152
+ " shift = 0.0 \n",
153
+ "\n",
154
+ " scale = torch.tensor(scale, device=device, dtype=dtype)\n",
155
+ "\n",
156
+ " img, x = load_image(IMG_PATH)\n",
157
+ " img.save(os.path.join(OUT_DIR, f\"original.png\"))\n",
158
+ "\n",
159
+ " with torch.no_grad():\n",
160
+ " # encode\n",
161
+ " latents = vae.encode(x).latent_dist.sample().to(dtype)\n",
162
+ " if mean is not None and std is not None:\n",
163
+ " latents = (latents - mean) / std\n",
164
+ " latents = latents * scale + shift\n",
165
+ "\n",
166
+ " lv = logvariance(latents)\n",
167
+ " print(f\"{name} log-variance: {lv:.3f}\")\n",
168
+ "\n",
169
+ " # график\n",
170
+ " plot_latent_distribution(latents, f\"{name}_latents\",\n",
171
+ " os.path.join(OUT_DIR, f\"dist_{name}.png\"))\n",
172
+ "\n",
173
+ " # decode\n",
174
+ " latents = (latents - shift) / scale\n",
175
+ " if mean is not None and std is not None:\n",
176
+ " latents = latents * std + mean\n",
177
+ " rec = vae.decode(latents).sample\n",
178
+ "\n",
179
+ " tensor_to_img(rec).save(os.path.join(OUT_DIR, f\"decoded_{name}.png\"))\n",
180
+ "\n",
181
+ "print(\"Готово\")"
182
+ ]
183
+ },
184
+ {
185
+ "cell_type": "code",
186
+ "execution_count": null,
187
+ "id": "200b72ab-1978-4d71-9aba-b1ef97cf0b27",
188
+ "metadata": {},
189
+ "outputs": [],
190
+ "source": []
191
+ }
192
+ ],
193
+ "metadata": {
194
+ "kernelspec": {
195
+ "display_name": "Python 3 (ipykernel)",
196
+ "language": "python",
197
+ "name": "python3"
198
+ },
199
+ "language_info": {
200
+ "codemirror_mode": {
201
+ "name": "ipython",
202
+ "version": 3
203
+ },
204
+ "file_extension": ".py",
205
+ "mimetype": "text/x-python",
206
+ "name": "python",
207
+ "nbconvert_exporter": "python",
208
+ "pygments_lexer": "ipython3",
209
+ "version": "3.11.10"
210
+ }
211
+ },
212
+ "nbformat": 4,
213
+ "nbformat_minor": 5
214
+ }