AlbeRota commited on
Commit
0f47de5
·
verified ·
1 Parent(s): bb509fc

Upload weights, notebooks, sample images

Browse files
Files changed (2) hide show
  1. .gitattributes +1 -0
  2. notebooks/UnReflectAnything.ipynb +3 -485
.gitattributes CHANGED
@@ -74,3 +74,4 @@ sample_images/sampleimage_6.png filter=lfs diff=lfs merge=lfs -text
74
  sample_images/sampleimage_7.png filter=lfs diff=lfs merge=lfs -text
75
  sample_images/sampleimage_8.png filter=lfs diff=lfs merge=lfs -text
76
  sample_images/sampleimage_9.png filter=lfs diff=lfs merge=lfs -text
 
 
74
  sample_images/sampleimage_7.png filter=lfs diff=lfs merge=lfs -text
75
  sample_images/sampleimage_8.png filter=lfs diff=lfs merge=lfs -text
76
  sample_images/sampleimage_9.png filter=lfs diff=lfs merge=lfs -text
77
+ notebooks/UnReflectAnything.ipynb filter=lfs diff=lfs merge=lfs -text
notebooks/UnReflectAnything.ipynb CHANGED
@@ -1,485 +1,3 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "metadata": {},
6
- "source": [
7
- "# UnReflectAnything API & CLI Examples\n",
8
- "---\n",
9
- "\n",
10
- "### 1. Installation and assets download\n",
11
- "Ensure you have installed UnReflectAnything with \n",
12
- "```bash\n",
13
- "pip install unreflectanything\n",
14
- "```\n",
15
- "this will also install the CLI, which is also callable with aliases `unreflect` and `ura`. Verify installation and check the version with:\n",
16
- "```bash\n",
17
- "unreflectanything --help\n",
18
- "```\n",
19
- "```bash\n",
20
- "unreflect --version\n",
21
- "```\n",
22
- "```bash\n",
23
- "ura --version\n",
24
- "```"
25
- ]
26
- },
27
- {
28
- "cell_type": "code",
29
- "execution_count": 31,
30
- "metadata": {},
31
- "outputs": [
32
- {
33
- "name": "stdout",
34
- "output_type": "stream",
35
- "text": [
36
- "Using device: cuda\n"
37
- ]
38
- }
39
- ],
40
- "source": [
41
- "import torch\n",
42
- "from pathlib import Path\n",
43
- "\n",
44
- "# Import UnreflectAnything!\n",
45
- "import unreflectanything\n",
46
- "\n",
47
- "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
48
- "print(f\"Using device: {device}\")"
49
- ]
50
- },
51
- {
52
- "cell_type": "markdown",
53
- "metadata": {},
54
- "source": [
55
- "`pip install`ing UnReflectAnything does not download the pretrained model weights. Download them with the cli command\n",
56
- "```bash\n",
57
- "unrefleactanything download --weights\n",
58
- "```\n",
59
- "or"
60
- ]
61
- },
62
- {
63
- "cell_type": "code",
64
- "execution_count": 32,
65
- "metadata": {},
66
- "outputs": [
67
- {
68
- "data": {
69
- "application/vnd.jupyter.widget-view+json": {
70
- "model_id": "e2614e279f5f40609b18ac82d696b01d",
71
- "version_major": 2,
72
- "version_minor": 0
73
- },
74
- "text/plain": [
75
- "Fetching 4 files: 0%| | 0/4 [00:00<?, ?it/s]"
76
- ]
77
- },
78
- "metadata": {},
79
- "output_type": "display_data"
80
- },
81
- {
82
- "data": {
83
- "application/vnd.jupyter.widget-view+json": {
84
- "model_id": "7293fbadfcc74c438a3cb5a0e4df1ac2",
85
- "version_major": 2,
86
- "version_minor": 0
87
- },
88
- "text/plain": [
89
- "weights/diffuse_decoder.pt: 0%| | 0.00/418M [00:00<?, ?B/s]"
90
- ]
91
- },
92
- "metadata": {},
93
- "output_type": "display_data"
94
- },
95
- {
96
- "data": {
97
- "application/vnd.jupyter.widget-view+json": {
98
- "model_id": "f1253cb4b21d48c2b0c3780f4c4060ef",
99
- "version_major": 2,
100
- "version_minor": 0
101
- },
102
- "text/plain": [
103
- "weights/token_inpainter.pt: 0%| | 0.00/307M [00:00<?, ?B/s]"
104
- ]
105
- },
106
- "metadata": {},
107
- "output_type": "display_data"
108
- },
109
- {
110
- "data": {
111
- "application/vnd.jupyter.widget-view+json": {
112
- "model_id": "74ed792f61264beab4a890ea5ede8a58",
113
- "version_major": 2,
114
- "version_minor": 0
115
- },
116
- "text/plain": [
117
- "weights/highlight_decoder.pt: 0%| | 0.00/54.8M [00:00<?, ?B/s]"
118
- ]
119
- },
120
- "metadata": {},
121
- "output_type": "display_data"
122
- },
123
- {
124
- "data": {
125
- "application/vnd.jupyter.widget-view+json": {
126
- "model_id": "1ea92cd82ad34af68f06be823dd12060",
127
- "version_major": 2,
128
- "version_minor": 0
129
- },
130
- "text/plain": [
131
- "weights/full_model_weights.pt: 0%| | 0.00/3.55G [00:00<?, ?B/s]"
132
- ]
133
- },
134
- "metadata": {},
135
- "output_type": "display_data"
136
- },
137
- {
138
- "name": "stdout",
139
- "output_type": "stream",
140
- "text": [
141
- "Weights saved to /home/arota/.cache/unreflectanything/weights\n"
142
- ]
143
- }
144
- ],
145
- "source": [
146
- "weights_dir = unreflectanything.download(\"weights\")"
147
- ]
148
- },
149
- {
150
- "cell_type": "markdown",
151
- "id": "511c670a",
152
- "metadata": {},
153
- "source": [
154
- "Download some sample images which will be used in this notebook with \n",
155
- "```bash\n",
156
- "unrefleactanything download --images\n",
157
- "```\n",
158
- "or"
159
- ]
160
- },
161
- {
162
- "cell_type": "code",
163
- "execution_count": null,
164
- "id": "e0c60b28",
165
- "metadata": {},
166
- "outputs": [],
167
- "source": [
168
- "images_dir = unreflectanything.download(\"images\")"
169
- ]
170
- },
171
- {
172
- "cell_type": "markdown",
173
- "id": "36417577",
174
- "metadata": {},
175
- "source": [
176
- "### 2. Running UnReflectAnything with pretrained weights"
177
- ]
178
- },
179
- {
180
- "cell_type": "code",
181
- "execution_count": null,
182
- "id": "c318961c",
183
- "metadata": {},
184
- "outputs": [],
185
- "source": [
186
- "# Instantating the pretrained default UnreflectAnything model. \n",
187
- "unreflect = unreflectanything.model(device=device)"
188
- ]
189
- },
190
- {
191
- "cell_type": "code",
192
- "execution_count": null,
193
- "id": "448d5456",
194
- "metadata": {},
195
- "outputs": [],
196
- "source": [
197
- "from PIL import Image\n",
198
- "import numpy as np\n",
199
- "\n",
200
- "# Building a simple dataloader on a simple dataset that loads from a dir of images\n",
201
- "sample_dataset = unreflectanything.ImageDirDataset(images_dir)\n",
202
- "sample_dataloader = torch.utils.data.DataLoader(\n",
203
- " sample_dataset, batch_size=1, shuffle=False\n",
204
- ")\n",
205
- "\n",
206
- "# Threshold and Dilation in inpaint mask can be overridden; defaults 0.2 and 40\n",
207
- "THRESHOLD = 0.2\n",
208
- "DILATION = 40\n",
209
- "\n",
210
- "# Process and display only N images out of the full sample dataset\n",
211
- "DISPLAY_N_IMAGES = 2\n",
212
- "\n",
213
- "outputs = []\n",
214
- "for batch in sample_dataloader:\n",
215
- " # Forward pass\n",
216
- " batch_output = unreflect(\n",
217
- " batch.to(device), return_dict=True, threshold=THRESHOLD, dilation=DILATION\n",
218
- " )\n",
219
- " outputs.append(batch_output)\n",
220
- " if len(outputs) >= DISPLAY_N_IMAGES:\n",
221
- " break\n",
222
- "\n"
223
- ]
224
- },
225
- {
226
- "cell_type": "code",
227
- "execution_count": null,
228
- "id": "98c60bf4",
229
- "metadata": {},
230
- "outputs": [],
231
- "source": [
232
- "# Helper: Convert tensor [H, W, C] in [0,1] float32 to uint8 to display them \n",
233
- "def tensor_to_uint8_img(t):\n",
234
- " arr = t.permute(1, 2, 0).detach().numpy()\n",
235
- " arr = np.clip(arr, 0, 1)\n",
236
- " arr = (arr * 255).round().astype(np.uint8)\n",
237
- " return arr\n",
238
- "\n",
239
- "# Plotting a collage of the input, the diffuse output, and the highlight mask\n",
240
- "for input_batch, output_batch in zip(sample_dataloader, outputs):\n",
241
- " concat_images = torch.cat(\n",
242
- " [\n",
243
- " input_batch.cpu(),\n",
244
- " output_batch[\"diffuse\"].cpu(),\n",
245
- " output_batch[\"highlight\"].repeat(1, 3, 1, 1).cpu(), # \n",
246
- " ],\n",
247
- " dim=3,\n",
248
- " )\n",
249
- " for sample in concat_images:\n",
250
- " img_uint8 = tensor_to_uint8_img(sample)\n",
251
- " display(Image.fromarray(img_uint8))\n",
252
- " # break\n"
253
- ]
254
- },
255
- {
256
- "cell_type": "markdown",
257
- "id": "cf0e6ac6",
258
- "metadata": {},
259
- "source": [
260
- "### 3. Inference API and CLI endpoint\n",
261
- "The `inference` wrapper instantiates the UnReflectAnything model and calls its forward function is a single API call. It either:\n",
262
- "- Inputs a batched image tensor and outputs a batched image tensor\n",
263
- "- Inputs the path to an image (or directory of images) and saves the output results at a given path (of file or directory)\n",
264
- "- Inputs the path to an image and outputs a batched image tensor \n",
265
- "\n",
266
- "Some example CLI calls:\n",
267
- "```bash\n",
268
- "unreflect inference path/to/image/dir/ -o output/dir/ --threshold 0.3 --dilation 40\n",
269
- "```\n",
270
- "```bash\n",
271
- "unreflect inference path/to/image.png -o path/to/output.png --threshold 0.3 --dilation 40\n",
272
- "```"
273
- ]
274
- },
275
- {
276
- "cell_type": "code",
277
- "execution_count": null,
278
- "metadata": {},
279
- "outputs": [],
280
- "source": [
281
- "# Pick a sample image from the downloaded assets. `input` can also be the path to a dir\n",
282
- "input_path = list(images_dir.glob(\"*.png\"))[0]\n",
283
- "print(\"Input file: \", input_path)\n",
284
- "# Specify the outptut name. If `input` is a path to a dir, `output` should be too.\n",
285
- "output_path = Path(\"output_example.png\").resolve()\n",
286
- "print(\"Output file: \", output_path)\n",
287
- "\n",
288
- "unreflectanything.inference(\n",
289
- " input=input_path,\n",
290
- " output=output_path,\n",
291
- " device=device,\n",
292
- " threshold=THRESHOLD, \n",
293
- " dilation=DILATION, \n",
294
- ")\n",
295
- "\n",
296
- "# Loading the saved output and original input from files, then displaying them\n",
297
- "input_img = Image.open(input_path).convert(\"RGB\")\n",
298
- "output_img = Image.open(output_path).convert(\"RGB\")\n",
299
- "\n",
300
- "def to_tensor(img):\n",
301
- " return torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.\n",
302
- "\n",
303
- "input_tensor = to_tensor(input_img)\n",
304
- "output_tensor = to_tensor(output_img)\n",
305
- "concat = torch.cat([input_tensor, output_tensor], dim=2)\n",
306
- "concat_uint8 = (concat.permute(1,2,0).numpy() * 255).clip(0,255).astype(np.uint8)\n",
307
- "display(Image.fromarray(concat_uint8))"
308
- ]
309
- },
310
- {
311
- "cell_type": "code",
312
- "execution_count": null,
313
- "id": "5118ea92",
314
- "metadata": {},
315
- "outputs": [],
316
- "source": [
317
- "print(\"Equivalent CLI command:\\n\")\n",
318
- "print(f\"unreflect inference {input_path} -o {output_path} --threshold {THRESHOLD} --dilation {DILATION}\")"
319
- ]
320
- },
321
- {
322
- "cell_type": "markdown",
323
- "id": "562c17f1",
324
- "metadata": {},
325
- "source": [
326
- "`inference` initializes the model every time by default. To run it without this step, pass them model to the API call"
327
- ]
328
- },
329
- {
330
- "cell_type": "code",
331
- "execution_count": null,
332
- "id": "89fbbf62",
333
- "metadata": {},
334
- "outputs": [],
335
- "source": [
336
- "# Pick a sample image from the downloaded assets. `input` can also be the path to a dir\n",
337
- "input_path = list(images_dir.glob(\"*.png\"))[6]\n",
338
- "# Specify the outptu name\n",
339
- "output_path = Path(\"output_example.png\")\n",
340
- " \n",
341
- "unreflectanything.inference(\n",
342
- " model=unreflect, # <<<<<<<<< Pass the model instance and it won't be loaded at every `inference` call\n",
343
- " input=input_path,\n",
344
- " output=output_path,\n",
345
- " device=device,\n",
346
- " threshold=THRESHOLD, \n",
347
- " dilation=DILATION, \n",
348
- ")\n",
349
- "\n",
350
- "# Loading the saved output and original input from files, then displaying them\n",
351
- "input_img = Image.open(input_path).convert(\"RGB\")\n",
352
- "output_img = Image.open(output_path).convert(\"RGB\")\n",
353
- "\n",
354
- "def to_tensor(img):\n",
355
- " return torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.\n",
356
- "\n",
357
- "input_tensor = to_tensor(input_img)\n",
358
- "output_tensor = to_tensor(output_img)\n",
359
- "concat = torch.cat([input_tensor, output_tensor], dim=2)\n",
360
- "concat_uint8 = (concat.permute(1,2,0).numpy() * 255).clip(0,255).astype(np.uint8)\n",
361
- "display(Image.fromarray(concat_uint8))"
362
- ]
363
- },
364
- {
365
- "cell_type": "markdown",
366
- "id": "57441af0",
367
- "metadata": {},
368
- "source": [
369
- "### 4. The Cache Directory\n",
370
- "\n",
371
- "`unreflectanything download` saves the downloaded asset in your system cache. Print this path with\n",
372
- "```bash\n",
373
- "unreflectanything cache --dir\n",
374
- "```\n",
375
- "or clear the cache with \n",
376
- "```bash\n",
377
- "unreflectanything cache --clear\n",
378
- "```\n",
379
- "The same endopoints are also on the API"
380
- ]
381
- },
382
- {
383
- "cell_type": "code",
384
- "execution_count": null,
385
- "id": "b331050d",
386
- "metadata": {},
387
- "outputs": [],
388
- "source": [
389
- "unreflectanything.cache(\"dir\") # Also unreflectanything.cache()\n",
390
- "unreflectanything.cache(\"clear\")\n",
391
- "# unreflectanything.cache.clear()"
392
- ]
393
- },
394
- {
395
- "cell_type": "markdown",
396
- "metadata": {},
397
- "source": [
398
- "## 4. Verify Assets\n",
399
- "\n",
400
- "You can verify that the weights are correctly downloaded and loadable."
401
- ]
402
- },
403
- {
404
- "cell_type": "code",
405
- "execution_count": null,
406
- "metadata": {},
407
- "outputs": [],
408
- "source": [
409
- "is_valid = unreflectanything.verify(\"weights\")"
410
- ]
411
- },
412
- {
413
- "cell_type": "markdown",
414
- "metadata": {},
415
- "source": [
416
- "### CLI Equivalent\n",
417
- "\n",
418
- "```bash\n",
419
- "unreflect verify --weights\n",
420
- "```\n",
421
- "```bash\n",
422
- "unreflect verify --weights\n",
423
- "```"
424
- ]
425
- },
426
- {
427
- "cell_type": "code",
428
- "execution_count": null,
429
- "id": "85947e38",
430
- "metadata": {},
431
- "outputs": [],
432
- "source": []
433
- },
434
- {
435
- "cell_type": "markdown",
436
- "metadata": {},
437
- "source": [
438
- "## 5. Cite\n",
439
- "\n",
440
- "If you use UnReflectAnything in your research, please cite it:"
441
- ]
442
- },
443
- {
444
- "cell_type": "code",
445
- "execution_count": null,
446
- "metadata": {},
447
- "outputs": [],
448
- "source": [
449
- "print(ura.cite(format=\"bibtex\"))"
450
- ]
451
- },
452
- {
453
- "cell_type": "markdown",
454
- "metadata": {},
455
- "source": [
456
- "### CLI Equivalent\n",
457
- "\n",
458
- "```bash\n",
459
- "unreflect cite --bibtex\n",
460
- "```"
461
- ]
462
- }
463
- ],
464
- "metadata": {
465
- "kernelspec": {
466
- "display_name": "Python 3 (ipykernel)",
467
- "language": "python",
468
- "name": "python3"
469
- },
470
- "language_info": {
471
- "codemirror_mode": {
472
- "name": "ipython",
473
- "version": 3
474
- },
475
- "file_extension": ".py",
476
- "mimetype": "text/x-python",
477
- "name": "python",
478
- "nbconvert_exporter": "python",
479
- "pygments_lexer": "ipython3",
480
- "version": "3.12.3"
481
- }
482
- },
483
- "nbformat": 4,
484
- "nbformat_minor": 5
485
- }
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a9be5cfa9f72f715a1a043bf88e201ada87c4344b18356e14b89f13c79f63213
3
+ size 12164640