qic999 commited on
Commit
283ab7b
·
verified ·
1 Parent(s): a38b2c1

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +33 -0
  2. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/LICENSE +9 -0
  3. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/README.md +255 -0
  4. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/__pycache__/main.cpython-310.pyc +0 -0
  5. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/__pycache__/main.cpython-311.pyc +0 -0
  6. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/__pycache__/main.cpython-39.pyc +0 -0
  7. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/checkpoints/instruct-pix2pix-00-22000.ckpt +3 -0
  8. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/config.sh +6 -0
  9. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/configs/generate.yaml +99 -0
  10. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/configs/infer.yaml +108 -0
  11. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/configs/infer3d.yaml +111 -0
  12. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/configs/train.yaml +107 -0
  13. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/configs/train_3d.yaml +109 -0
  14. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/dataset_creation/generate_img_dataset.py +315 -0
  15. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/dataset_creation/generate_txt_dataset.py +113 -0
  16. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/dataset_creation/prepare_dataset.py +29 -0
  17. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/dataset_creation/prepare_for_gpt.py +25 -0
  18. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/edit_app.py +268 -0
  19. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/edit_cli.py +128 -0
  20. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/edit_dataset.py +121 -0
  21. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/environment.yaml +38 -0
  22. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/imgs/dataset.jpg +3 -0
  23. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/imgs/edit_app.jpg +3 -0
  24. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/imgs/example.jpg +0 -0
  25. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/imgs/prompt_app.jpg +0 -0
  26. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/inference_full_ct_2d_with_body_mask.py +139 -0
  27. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/inference_full_ct_3d_with_body_mask.py +201 -0
  28. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/inference_full_ct_3d_with_body_mask_v2.py +197 -0
  29. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/logs/train_train_instructpix2pix/checkpoints/epoch=000096.ckpt +3 -0
  30. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/logs/train_train_instructpix2pix/checkpoints/epoch=000097.ckpt +0 -0
  31. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/logs/train_train_instructpix2pix/checkpoints/epoch=000145.ckpt +3 -0
  32. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/logs/train_train_instructpix2pix/checkpoints/epoch=000146.ckpt +0 -0
  33. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/logs/train_train_instructpix2pix/checkpoints/last.ckpt +3 -0
  34. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/logs/train_train_instructpix2pix/configs/2026-01-13T15-35-13-lightning.yaml +15 -0
  35. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/logs/train_train_instructpix2pix/configs/2026-01-13T15-35-13-project.yaml +95 -0
  36. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/logs/train_train_instructpix2pix/configs/2026-01-13T23-40-35-lightning.yaml +16 -0
  37. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/logs/train_train_instructpix2pix/configs/2026-01-13T23-40-35-project.yaml +95 -0
  38. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/logs/train_train_instructpix2pix/testtube/version_0/meta.experiment +1 -0
  39. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/logs/train_train_instructpix2pix/testtube/version_0/meta_tags.csv +1 -0
  40. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/logs/train_train_instructpix2pix/testtube/version_0/metrics.csv +102 -0
  41. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/logs/train_train_instructpix2pix/testtube/version_0/tf/events.out.tfevents.1768318562.node-0.1923.0 +3 -0
  42. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/logs/train_train_instructpix2pix/testtube/version_1/meta.experiment +1 -0
  43. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/logs/train_train_instructpix2pix/testtube/version_1/meta_tags.csv +1 -0
  44. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/logs/train_train_instructpix2pix/testtube/version_1/metrics.csv +54 -0
  45. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/logs/train_train_instructpix2pix/testtube/version_1/tf/events.out.tfevents.1768347806.node-0.4103.0 +3 -0
  46. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/main.py +814 -0
  47. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/metrics/clip_similarity.py +47 -0
  48. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/metrics/compute_metrics.py +235 -0
  49. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/prompt_app.py +55 -0
  50. instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/requirements.txt +101 -0
.gitattributes CHANGED
@@ -413,3 +413,36 @@ instruct-pix2pix-BioMedCLIP-concat-newdata-data-Opacity/stable_diffusion/data/in
413
  instruct-pix2pix-BioMedCLIP-concat-newdata-data-Opacity/stable_diffusion/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png filter=lfs diff=lfs merge=lfs -text
414
  instruct-pix2pix-BioMedCLIP-concat-newdata-data-Opacity/stable_diffusion/data/inpainting_examples/photo-1583445095369-9c651e7e5d34.png filter=lfs diff=lfs merge=lfs -text
415
  instruct-pix2pix-BioMedCLIP-concat-newdata-data-Opacity/stable_diffusion/ldm/modules/image_degradation/utils/test.png filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
413
  instruct-pix2pix-BioMedCLIP-concat-newdata-data-Opacity/stable_diffusion/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png filter=lfs diff=lfs merge=lfs -text
414
  instruct-pix2pix-BioMedCLIP-concat-newdata-data-Opacity/stable_diffusion/data/inpainting_examples/photo-1583445095369-9c651e7e5d34.png filter=lfs diff=lfs merge=lfs -text
415
  instruct-pix2pix-BioMedCLIP-concat-newdata-data-Opacity/stable_diffusion/ldm/modules/image_degradation/utils/test.png filter=lfs diff=lfs merge=lfs -text
416
+ instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/imgs/dataset.jpg filter=lfs diff=lfs merge=lfs -text
417
+ instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/imgs/edit_app.jpg filter=lfs diff=lfs merge=lfs -text
418
+ instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/stable_diffusion/assets/a-painting-of-a-fire.png filter=lfs diff=lfs merge=lfs -text
419
+ instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/stable_diffusion/assets/a-photograph-of-a-fire.png filter=lfs diff=lfs merge=lfs -text
420
+ instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/stable_diffusion/assets/a-shirt-with-a-fire-printed-on-it.png filter=lfs diff=lfs merge=lfs -text
421
+ instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/stable_diffusion/assets/a-shirt-with-the-inscription-'fire'.png filter=lfs diff=lfs merge=lfs -text
422
+ instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/stable_diffusion/assets/a-watercolor-painting-of-a-fire.png filter=lfs diff=lfs merge=lfs -text
423
+ instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/stable_diffusion/assets/birdhouse.png filter=lfs diff=lfs merge=lfs -text
424
+ instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/stable_diffusion/assets/fire.png filter=lfs diff=lfs merge=lfs -text
425
+ instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/stable_diffusion/assets/inpainting.png filter=lfs diff=lfs merge=lfs -text
426
+ instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/stable_diffusion/assets/rdm-preview.jpg filter=lfs diff=lfs merge=lfs -text
427
+ instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/stable_diffusion/assets/reconstruction1.png filter=lfs diff=lfs merge=lfs -text
428
+ instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/stable_diffusion/assets/reconstruction2.png filter=lfs diff=lfs merge=lfs -text
429
+ instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/stable_diffusion/assets/rick.jpeg filter=lfs diff=lfs merge=lfs -text
430
+ instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/stable_diffusion/assets/stable-samples/img2img/mountains-1.png filter=lfs diff=lfs merge=lfs -text
431
+ instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/stable_diffusion/assets/stable-samples/img2img/mountains-2.png filter=lfs diff=lfs merge=lfs -text
432
+ instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/stable_diffusion/assets/stable-samples/img2img/mountains-3.png filter=lfs diff=lfs merge=lfs -text
433
+ instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/stable_diffusion/assets/stable-samples/img2img/sketch-mountains-input.jpg filter=lfs diff=lfs merge=lfs -text
434
+ instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/stable_diffusion/assets/stable-samples/txt2img/000002025.png filter=lfs diff=lfs merge=lfs -text
435
+ instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/stable_diffusion/assets/stable-samples/txt2img/000002035.png filter=lfs diff=lfs merge=lfs -text
436
+ instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/stable_diffusion/assets/the-earth-is-on-fire,-oil-on-canvas.png filter=lfs diff=lfs merge=lfs -text
437
+ instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/stable_diffusion/assets/txt2img-convsample.png filter=lfs diff=lfs merge=lfs -text
438
+ instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/stable_diffusion/data/DejaVuSans.ttf filter=lfs diff=lfs merge=lfs -text
439
+ instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/stable_diffusion/data/imagenet_val_hr_indices.p filter=lfs diff=lfs merge=lfs -text
440
+ instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/stable_diffusion/data/inpainting_examples/6458524847_2f4c361183_k.png filter=lfs diff=lfs merge=lfs -text
441
+ instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/stable_diffusion/data/inpainting_examples/8399166846_f6fb4e4b8e_k.png filter=lfs diff=lfs merge=lfs -text
442
+ instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/stable_diffusion/data/inpainting_examples/alex-iby-G_Pk4D9rMLs.png filter=lfs diff=lfs merge=lfs -text
443
+ instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/stable_diffusion/data/inpainting_examples/bench2.png filter=lfs diff=lfs merge=lfs -text
444
+ instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/stable_diffusion/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0.png filter=lfs diff=lfs merge=lfs -text
445
+ instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/stable_diffusion/data/inpainting_examples/billow926-12-Wc-Zgx6Y.png filter=lfs diff=lfs merge=lfs -text
446
+ instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/stable_diffusion/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png filter=lfs diff=lfs merge=lfs -text
447
+ instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/stable_diffusion/data/inpainting_examples/photo-1583445095369-9c651e7e5d34.png filter=lfs diff=lfs merge=lfs -text
448
+ instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/stable_diffusion/ldm/modules/image_degradation/utils/test.png filter=lfs diff=lfs merge=lfs -text
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/LICENSE ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ Copyright 2023 Timothy Brooks, Aleksander Holynski, Alexei A. Efros
2
+
3
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4
+
5
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6
+
7
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
8
+
9
+ Portions of code and models (such as pretrained checkpoints, which are fine-tuned starting from released Stable Diffusion checkpoints) are derived from the Stable Diffusion codebase (https://github.com/CompVis/stable-diffusion). Further restrictions may apply. Please consult the Stable Diffusion license `stable_diffusion/LICENSE`. Modified code is denoted as such in comments at the start of each file.
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/README.md ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # InstructPix2Pix: Learning to Follow Image Editing Instructions
2
+ ### [Project Page](https://www.timothybrooks.com/instruct-pix2pix/) | [Paper](https://arxiv.org/abs/2211.09800) | [Data](http://instruct-pix2pix.eecs.berkeley.edu/)
3
+ PyTorch implementation of InstructPix2Pix, an instruction-based image editing model, based on the original [CompVis/stable_diffusion](https://github.com/CompVis/stable-diffusion) repo. <br>
4
+
5
+ [InstructPix2Pix: Learning to Follow Image Editing Instructions](https://www.timothybrooks.com/instruct-pix2pix/)
6
+ [Tim Brooks](https://www.timothybrooks.com/)\*,
7
+ [Aleksander Holynski](https://holynski.org/)\*,
8
+ [Alexei A. Efros](https://people.eecs.berkeley.edu/~efros/) <br>
9
+ UC Berkeley <br>
10
+ \*denotes equal contribution
11
+
12
+ <img src='https://instruct-pix2pix.timothybrooks.com/teaser.jpg'/>
13
+
14
+ ## TL;DR: quickstart
15
+
16
+ Follow the instructions below to download and run InstructPix2Pix on your own images. These instructions have been tested on a GPU with >18GB VRAM. If you don't have a GPU, you may need to change the default configuration, or check out [other ways of using the model](https://github.com/timothybrooks/instruct-pix2pix#other-ways-of-using-instructpix2pix).
17
+
18
+ ### Set up a conda environment, and download a pretrained model:
19
+ ```
20
+ conda env create -f environment.yaml
21
+ conda activate ip2p
22
+ bash scripts/download_checkpoints.sh
23
+ ```
24
+
25
+ ### Edit a single image:
26
+ ```
27
+ python edit_cli.py --input imgs/example.jpg --output imgs/output.jpg --edit "turn him into a cyborg"
28
+
29
+ # Optionally, you can specify parameters to tune your result:
30
+ # python edit_cli.py --steps 100 --resolution 512 --seed 1371 --cfg-text 7.5 --cfg-image 1.2 --input imgs/example.jpg --output imgs/output.jpg --edit "turn him into a cyborg"
31
+ ```
32
+
33
+ ### Or launch your own interactive editing Gradio app:
34
+ ```
35
+ python edit_app.py
36
+ ```
37
+ ![Edit app](https://github.com/timothybrooks/instruct-pix2pix/blob/main/imgs/edit_app.jpg?raw=true)
38
+
39
+ _(For advice on how to get the best results by tuning parameters, see the [Tips](https://github.com/timothybrooks/instruct-pix2pix#tips) section)._
40
+
41
+ ## Setup
42
+
43
+ Install all dependencies with:
44
+ ```
45
+ conda env create -f environment.yaml
46
+ ```
47
+
48
+ Download the pretrained models by running:
49
+ ```
50
+ bash scripts/download_checkpoints.sh
51
+ ```
52
+
53
+ ## Generated Dataset
54
+
55
+ Our image editing model is trained on a generated dataset consisting of 454,445 examples. Each example contains (1) an input image, (2) an editing instruction, and (3) an output edited image. We provide two versions of the dataset, one in which each pair of edited images is generated 100 times, and the best examples are chosen based on CLIP metrics (Section 3.1.2 in the paper) (`clip-filtered-dataset`), and one in which examples are randomly chosen (`random-sample-dataset`).
56
+
57
+ For the released version of this dataset, we've additionally filtered prompts and images for NSFW content. After NSFW filtering, the GPT-3 generated dataset contains 451,990 examples. The final image-pair datasets contain:
58
+
59
+ | | # of image editing examples | Dataset size |
60
+ |--|-----------------------|----------------------- |
61
+ | `random-sample-dataset` |451990|727GB|
62
+ | `clip-filtered-dataset` |313010|436GB|
63
+
64
+ To download one of these datasets, along with the entire NSFW-filtered text data, run the following command with the appropriate dataset name:
65
+
66
+ ```
67
+ bash scripts/download_data.sh clip-filtered-dataset
68
+ ```
69
+
70
+
71
+ ## Training InstructPix2Pix
72
+
73
+ InstructPix2Pix is trained by fine-tuning from an initial StableDiffusion checkpoint. The first step is to download a Stable Diffusion checkpoint. For our trained models, we used the v1.5 checkpoint as the starting point. To download the same ones we used, you can run the following script:
74
+ ```
75
+ bash scripts/download_pretrained_sd.sh
76
+ ```
77
+ If you'd like to use a different checkpoint, point to it in the config file `configs/train.yaml`, on line 8, after `ckpt_path:`.
78
+
79
+ Next, we need to change the config to point to our downloaded (or generated) dataset. If you're using the `clip-filtered-dataset` from above, you can skip this. Otherwise, you may need to edit lines 85 and 94 of the config (`data.params.train.params.path`, `data.params.validation.params.path`).
80
+
81
+ Finally, start a training job with the following command:
82
+
83
+ ```
84
+ python main.py --name default --base configs/train.yaml --train --gpus 0,1,2,3,4,5,6,7
85
+ ```
86
+
87
+
88
+ ## Creating your own dataset
89
+
90
+ Our generated dataset of paired images and editing instructions is made in two phases: First, we use GPT-3 to generate text triplets: (a) a caption describing an image, (b) an edit instruction, (c) a caption describing the image after the edit. Then, we turn pairs of captions (before/after the edit) into pairs of images using Stable Diffusion and Prompt-to-Prompt.
91
+
92
+ ### (1) Generate a dataset of captions and instructions
93
+
94
+ We provide our generated dataset of captions and edit instructions [here](https://instruct-pix2pix.eecs.berkeley.edu/gpt-generated-prompts.jsonl). If you plan to use our captions+instructions, skip to step (2). Otherwise, if you would like to create your own text dataset, please follow steps (1.1-1.3) below. Note that generating very large datasets using GPT-3 can be expensive.
95
+
96
+ #### (1.1) Manually write a dataset of instructions and captions
97
+
98
+ The first step of the process is fine-tuning GPT-3. To do this, we made a dataset of 700 examples broadly covering edits that we might want our model to be able to perform. Our examples are available [here](https://instruct-pix2pix.eecs.berkeley.edu/human-written-prompts.jsonl). These should be diverse and cover a wide range of possible captions and types of edits. Ideally, they should avoid duplication or significant overlap of captions and instructions. It is also important to be mindful of limitations of Stable Diffusion and Prompt-to-Prompt in writing these examples, such as inability to perform large spatial transformations (e.g., moving the camera, zooming in, swapping object locations).
99
+
100
+ Input prompts should closely match the distribution of input prompts used to generate the larger dataset. We sampled the 700 input prompts from the _LAION Improved Aesthetics 6.5+_ dataset and also use this dataset for generating examples. We found this dataset is quite noisy (many of the captions are overly long and contain irrelevant text). For this reason, we also considered MSCOCO and LAION-COCO datasets, but ultimately chose _LAION Improved Aesthetics 6.5+_ due to its diversity of content, proper nouns, and artistic mediums. If you choose to use another dataset or combination of datasets as input to GPT-3 when generating examples, we recommend you sample the input prompts from the same distribution when manually writing training examples.
101
+
102
+ #### (1.2) Finetune GPT-3
103
+
104
+ The next step is to finetune a large language model on the manually written instructions/outputs to generate edit instructions and edited caption from a new input caption. For this, we finetune GPT-3's Davinci model via the OpenAI API, although other language models could be used.
105
+
106
+ To prepare training data for GPT-3, one must first create an OpenAI developer account to access the needed APIs, and [set up the API keys on your local device](https://beta.openai.com/docs/api-reference/introduction). Also, run the `prompts/prepare_for_gpt.py` script, which forms the prompts into the correct format by concatenating instructions and captions and adding delimiters and stop sequences.
107
+
108
+ ```bash
109
+ python dataset_creation/prepare_for_gpt.py --input-path data/human-written-prompts.jsonl --output-path data/human-written-prompts-for-gpt.jsonl
110
+ ```
111
+
112
+ Next, finetune GPT-3 via the OpenAI CLI. We provide an example below, although please refer to OpenAI's official documentation for this, as best practices may change. We trained the Davinci model for a single epoch. You can experiment with smaller less expensive GPT-3 variants or with open source language models, although this may negatively affect performance.
113
+
114
+ ```bash
115
+ openai api fine_tunes.create -t data/human-written-prompts-for-gpt.jsonl -m davinci --n_epochs 1 --suffix "instruct-pix2pix"
116
+ ```
117
+
118
+ You can test out the finetuned GPT-3 model by launching the provided Gradio app:
119
+
120
+ ```bash
121
+ python prompt_app.py --openai-api-key OPENAI_KEY --openai-model OPENAI_MODEL_NAME
122
+ ```
123
+
124
+ ![Prompt app](https://github.com/timothybrooks/instruct-pix2pix/blob/main/imgs/prompt_app.jpg?raw=true)
125
+
126
+ #### (1.3) Generate a large dataset of captions and instructions
127
+
128
+ We now use the finetuned GPT-3 model to generate a large dataset. Our dataset cost thousands of dollars to create. See `prompts/gen_instructions_and_captions.py` for the script which generates these examples. We recommend first generating a small number of examples (by setting a low value of `--num-samples`) and gradually increasing the scale to ensure the results are working as desired before increasing scale.
129
+
130
+ ```bash
131
+ python dataset_creation/generate_txt_dataset.py --openai-api-key OPENAI_KEY --openai-model OPENAI_MODEL_NAME
132
+ ```
133
+
134
+ If you are generating at a very large scale (e.g., 100K+), it will be noticeably faster to generate the dataset with multiple processes running in parallel. This can be accomplished by setting `--partitions=N` to a higher number and running multiple processes, setting each `--partition` to the corresponding value.
135
+
136
+ ```bash
137
+ python dataset_creation/generate_txt_dataset.py --openai-api-key OPENAI_KEY --openai-model OPENAI_MODEL_NAME --partitions=10 --partition=0
138
+ ```
139
+
140
+ ### (2) Turn paired captions into paired images
141
+
142
+ The next step is to turn pairs of text captions into pairs of images. For this, we need to copy some pre-trained Stable Diffusion checkpoints to `stable_diffusion/models/ldm/stable-diffusion-v1/`. You may have already done this if you followed the instructions above for training with our provided data, but if not, you can do this by running:
143
+
144
+ ```bash
145
+ bash scripts/download_pretrained_sd.sh
146
+ ```
147
+
148
+ For our model, we used [checkpoint v1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.ckpt), and the [new autoencoder](https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt), but other models may work as well. If you choose to use other models, make sure to point to the corresponding checkpoints by passing in the `--ckpt` and `--vae-ckpt` arguments. Once all checkpoints have been downloaded, we can generate the dataset with the following command:
149
+
150
+ ```
151
+ python dataset_creation/generate_img_dataset.py --out_dir data/instruct-pix2pix-dataset-000 --prompts_file path/to/generated_prompts.jsonl
152
+ ```
153
+
154
+ This command operates on a single GPU (typically a V100 or A100). To parallelize over many GPUs/machines, set `--n-partitions` to the total number of parallel jobs and `--partition` to the index of each job.
155
+
156
+ ```
157
+ python dataset_creation/generate_img_dataset.py --out_dir data/instruct-pix2pix-dataset-000 --prompts_file path/to/generated_prompts.jsonl --n-partitions 100 --partition 0
158
+ ```
159
+
160
+ The default parameters match that of our dataset, although in practice you can use a smaller number of steps (e.g., `--steps=25`) to generate high quality data faster. By default, we generate 100 samples per prompt and use CLIP filtering to keep a max of 4 per prompt. You can experiment with fewer samples by setting `--n-samples`. The command below turns off CLIP filtering entirely and is therefore faster:
161
+
162
+ ```
163
+ python dataset_creation/generate_img_dataset.py --out_dir data/instruct-pix2pix-dataset-000 --prompts_file path/to/generated_prompts.jsonl --n-samples 4 --clip-threshold 0 --clip-dir-threshold 0 --clip-img-threshold 0 --n-partitions 100 --partition 0
164
+ ```
165
+
166
+ After generating all of the dataset examples, run the command below to create a list of the examples. This is needed for the dataset object to be able to efficiently sample examples without needing to iterate over the entire dataset directory at the start of each training run.
167
+
168
+ ```
169
+ python dataset_creation/prepare_dataset.py data/instruct-pix2pix-dataset-000
170
+ ```
171
+
172
+ ## Evaluation
173
+
174
+ To generate plots like the ones in Figures 8 and 10 in the paper, run the following command:
175
+
176
+ ```
177
+ python metrics/compute_metrics.py --ckpt /path/to/your/model.ckpt
178
+ ```
179
+
180
+ ## Tips
181
+
182
+ If you're not getting the quality result you want, there may be a few reasons:
183
+ 1. **Is the image not changing enough?** Your Image CFG weight may be too high. This value dictates how similar the output should be to the input. It's possible your edit requires larger changes from the original image, and your Image CFG weight isn't allowing that. Alternatively, your Text CFG weight may be too low. This value dictates how much to listen to the text instruction. The default Image CFG of 1.5 and Text CFG of 7.5 are a good starting point, but aren't necessarily optimal for each edit. Try:
184
+ * Decreasing the Image CFG weight, or
185
+     * Increasing the Text CFG weight
186
+ 2. Conversely, **is the image changing too much**, such that the details in the original image aren't preserved? Try:
187
+ * Increasing the Image CFG weight, or
188
+ * Decreasing the Text CFG weight
189
+ 3. Try generating results with different random seeds by setting "Randomize Seed" and running generation multiple times. You can also try setting "Randomize CFG" to sample new Text CFG and Image CFG values each time.
190
+ 4. Rephrasing the instruction sometimes improves results (e.g., "turn him into a dog" vs. "make him a dog" vs. "as a dog").
191
+ 5. Increasing the number of steps sometimes improves results.
192
+ 6. Do faces look weird? The Stable Diffusion autoencoder has a hard time with faces that are small in the image. Try cropping the image so the face takes up a larger portion of the frame.
193
+
194
+ ## Comments
195
+
196
+ - Our codebase is based on the [Stable Diffusion codebase](https://github.com/CompVis/stable-diffusion).
197
+
198
+ ## BibTeX
199
+
200
+ ```
201
+ @article{brooks2022instructpix2pix,
202
+ title={InstructPix2Pix: Learning to Follow Image Editing Instructions},
203
+ author={Brooks, Tim and Holynski, Aleksander and Efros, Alexei A},
204
+ journal={arXiv preprint arXiv:2211.09800},
205
+ year={2022}
206
+ }
207
+ ```
208
+ ## Other ways of using InstructPix2Pix
209
+
210
+ ### InstructPix2Pix on [HuggingFace](https://huggingface.co/spaces/timbrooks/instruct-pix2pix):
211
+ > A browser-based version of the demo is available as a [HuggingFace space](https://huggingface.co/spaces/timbrooks/instruct-pix2pix). For this version, you only need a browser, a picture you want to edit, and an instruction! Note that this is a shared online demo, and processing time may be slower during peak utilization.
212
+
213
+ ### InstructPix2Pix on [Replicate](https://replicate.com/timothybrooks/instruct-pix2pix):
214
+ > Replicate provides a production-ready cloud API for running the InstructPix2Pix model. You can run the model from any environment using a simple API call with cURL, Python, JavaScript, or your language of choice. Replicate also provides a web interface for running the model and sharing predictions.
215
+
216
+ ### InstructPix2Pix in [Imaginairy](https://github.com/brycedrennan/imaginAIry#-edit-images-with-instructions-alone-by-instructpix2pix):
217
+ > Imaginairy offers another way of easily installing InstructPix2Pix with a single command. It can run on devices without GPUs (like a Macbook!).
218
+ > ```bash
219
+ > pip install imaginairy --upgrade
220
+ > aimg edit any-image.jpg --gif "turn him into a cyborg"
221
+ > ```
222
+ > It also offers an easy way to perform a bunch of edits on an image, and can save edits out to an animated GIF:
223
+ > ```
224
+ > aimg edit --gif --surprise-me pearl-earring.jpg
225
+ > ```
226
+ > <img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/7c05c3aae2740278978c5e84962b826e58201bac/assets/girl_with_a_pearl_earring_suprise.gif" width="512">
227
+
228
+ ### InstructPix2Pix in [🧨 Diffusers](https://github.com/huggingface/diffusers):
229
+
230
+ > InstructPix2Pix in Diffusers is a bit more optimized, so it may be faster and more suitable for GPUs with less memory. Below are instructions for installing the library and editing an image:
231
+ > 1. Install diffusers and relevant dependencies:
232
+ >
233
+ > ```bash
234
+ > pip install transformers accelerate torch
235
+ >
236
+ > pip install git+https://github.com/huggingface/diffusers.git
237
+ > ```
238
+ >
239
+ > 2. Load the model and edit the image:
240
+ >
241
+ > ```python
242
+ >
243
+ > import torch
244
+ > from diffusers import StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler
245
+ >
246
+ > model_id = "timbrooks/instruct-pix2pix"
247
+ > pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16, safety_checker=None)
248
+ > pipe.to("cuda")
249
+ > pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
250
+ > # `image` is an RGB PIL.Image
251
+ > images = pipe("turn him into cyborg", image=image).images
252
+ > images[0]
253
+ > ```
254
+ >
255
+ > For more information, check the docs [here](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/pix2pix).
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/__pycache__/main.cpython-310.pyc ADDED
Binary file (20.4 kB). View file
 
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/__pycache__/main.cpython-311.pyc ADDED
Binary file (39.9 kB). View file
 
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/__pycache__/main.cpython-39.pyc ADDED
Binary file (20.4 kB). View file
 
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/checkpoints/instruct-pix2pix-00-22000.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ffd280ddcfc8234e4d28b93641cb83169cebcb4d70998df9ee2eabb4d705374a
3
+ size 7703927910
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/config.sh ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+
2
+
3
+
4
+ python main.py --name debug --base configs/train.yaml --train --gpus 0,
5
+
6
+ python main.py --name debug --base configs/train_3d.yaml --train --gpus 0,
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/configs/generate.yaml ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
2
+ # See more details in LICENSE.
3
+
4
+ model:
5
+ base_learning_rate: 1.0e-04
6
+ target: ldm.models.diffusion.ddpm_edit.LatentDiffusion
7
+ params:
8
+ linear_start: 0.00085
9
+ linear_end: 0.0120
10
+ num_timesteps_cond: 1
11
+ log_every_t: 200
12
+ timesteps: 1000
13
+ first_stage_key: edited
14
+ cond_stage_key: edit
15
+ # image_size: 64
16
+ # image_size: 32
17
+ image_size: 16
18
+ channels: 4
19
+ cond_stage_trainable: false # Note: different from the one we trained before
20
+ conditioning_key: hybrid
21
+ monitor: val/loss_simple_ema
22
+ scale_factor: 0.18215
23
+ use_ema: true
24
+ load_ema: true
25
+
26
+ scheduler_config: # 10000 warmup steps
27
+ target: ldm.lr_scheduler.LambdaLinearScheduler
28
+ params:
29
+ warm_up_steps: [ 0 ]
30
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
31
+ f_start: [ 1.e-6 ]
32
+ f_max: [ 1. ]
33
+ f_min: [ 1. ]
34
+
35
+ unet_config:
36
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
37
+ params:
38
+ image_size: 32 # unused
39
+ in_channels: 8
40
+ out_channels: 4
41
+ model_channels: 320
42
+ attention_resolutions: [ 4, 2, 1 ]
43
+ num_res_blocks: 2
44
+ channel_mult: [ 1, 2, 4, 4 ]
45
+ num_heads: 8
46
+ use_spatial_transformer: True
47
+ transformer_depth: 1
48
+ context_dim: 768
49
+ use_checkpoint: True
50
+ legacy: False
51
+
52
+ first_stage_config:
53
+ target: ldm.models.autoencoder.AutoencoderKL
54
+ params:
55
+ embed_dim: 4
56
+ monitor: val/rec_loss
57
+ ddconfig:
58
+ double_z: true
59
+ z_channels: 4
60
+ resolution: 256
61
+ in_channels: 3
62
+ out_ch: 3
63
+ ch: 128
64
+ ch_mult:
65
+ - 1
66
+ - 2
67
+ - 4
68
+ - 4
69
+ num_res_blocks: 2
70
+ attn_resolutions: []
71
+ dropout: 0.0
72
+ lossconfig:
73
+ target: torch.nn.Identity
74
+
75
+ cond_stage_config:
76
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
77
+
78
+ data:
79
+ target: main.DataModuleFromConfig
80
+ params:
81
+ batch_size: 128
82
+ num_workers: 1
83
+ wrap: false
84
+ validation:
85
+ target: edit_dataset.EditDataset
86
+ params:
87
+ path: data/clip-filtered-dataset
88
+ cache_dir: data/
89
+ cache_name: data_10k
90
+ split: val
91
+ min_text_sim: 0.2
92
+ min_image_sim: 0.75
93
+ min_direction_sim: 0.2
94
+ max_samples_per_prompt: 1
95
+ min_resize_res: 512
96
+ max_resize_res: 512
97
+ crop_res: 512
98
+ output_as_edit: False
99
+ real_input: True
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/configs/infer.yaml ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
2
+ # See more details in LICENSE.
3
+
4
+ model:
5
+ base_learning_rate: 1.0e-04
6
+ target: ldm.models.diffusion.ddpm_edit.LatentDiffusion
7
+ params:
8
+ ckpt_path: stable_diffusion/models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
9
+ linear_start: 0.00085
10
+ linear_end: 0.0120
11
+ num_timesteps_cond: 1
12
+ log_every_t: 200
13
+ timesteps: 1000
14
+ first_stage_key: edited
15
+ cond_stage_key: edit
16
+ image_size: 64
17
+ channels: 4
18
+ cond_stage_trainable: false # Note: different from the one we trained before
19
+ conditioning_key: hybrid
20
+ monitor: val/loss_simple_ema
21
+ scale_factor: 0.18215
22
+ use_ema: true
23
+ load_ema: false
24
+
25
+ scheduler_config: # 10000 warmup steps
26
+ target: ldm.lr_scheduler.LambdaLinearScheduler
27
+ params:
28
+ warm_up_steps: [ 0 ]
29
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
30
+ f_start: [ 1.e-6 ]
31
+ f_max: [ 1. ]
32
+ f_min: [ 1. ]
33
+
34
+ unet_config:
35
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
36
+ params:
37
+ image_size: 32 # unused
38
+ in_channels: 8
39
+ out_channels: 4
40
+ model_channels: 320
41
+ attention_resolutions: [ 4, 2, 1 ]
42
+ num_res_blocks: 2
43
+ channel_mult: [ 1, 2, 4, 4 ]
44
+ num_heads: 8
45
+ use_spatial_transformer: True
46
+ transformer_depth: 1
47
+ context_dim: 768
48
+ use_checkpoint: True
49
+ legacy: False
50
+
51
+ first_stage_config:
52
+ target: ldm.models.autoencoder.AutoencoderKL
53
+ params:
54
+ embed_dim: 4
55
+ monitor: val/rec_loss
56
+ ddconfig:
57
+ double_z: true
58
+ z_channels: 4
59
+ resolution: 256
60
+ in_channels: 3
61
+ out_ch: 3
62
+ ch: 128
63
+ ch_mult:
64
+ - 1
65
+ - 2
66
+ - 4
67
+ - 4
68
+ num_res_blocks: 2
69
+ attn_resolutions: []
70
+ dropout: 0.0
71
+ lossconfig:
72
+ target: torch.nn.Identity
73
+
74
+ cond_stage_config:
75
+ target: ldm.modules.encoders.modules.FrozenBioMedCLIPEmbedder
76
+
77
+ data:
78
+ target: main.DataModuleFromConfig
79
+ params:
80
+ batch_size: 16
81
+ num_workers: 8
82
+ train:
83
+ target: ldm.data.ct_clip_data_train.CTReportDataset
84
+ params:
85
+ data_folder: '/sd/shuhan/CT-RATE/dataset/train_fixed'
86
+ csv_file: '/sd/shuhan/CT-RATE/radiology_text_reports/train_reports.csv'
87
+
88
+ validation:
89
+ target: ldm.data.ct_clip_data_inference.CTReportDatasetinfer
90
+ params:
91
+ data_folder: '/sd/shuhan/CT-RATE/dataset/valid_fixed'
92
+ csv_file: '/sd/shuhan/CT-RATE/radiology_text_reports/valid_reports.csv'
93
+ labels: '/sd/shuhan/CT-RATE/multi_abnormality_labels/valid_predicted_labels.csv'
94
+
95
+ lightning:
96
+ callbacks:
97
+ image_logger:
98
+ target: main.ImageLogger
99
+ params:
100
+ batch_frequency: 200000000
101
+ max_images: 2
102
+ increase_log_steps: False
103
+
104
+ trainer:
105
+ max_epochs: 2000
106
+ benchmark: True
107
+ accumulate_grad_batches: 4
108
+ check_val_every_n_epoch: 10000
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/configs/infer3d.yaml ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
2
+ # See more details in LICENSE.
3
+
4
+ model:
5
+ base_learning_rate: 1.0e-04
6
+ target: ldm.models.diffusion.ddpm_edit_3d.LatentDiffusion
7
+ params:
8
+ linear_start: 0.00085
9
+ linear_end: 0.0120
10
+ num_timesteps_cond: 1
11
+ log_every_t: 200
12
+ timesteps: 1000
13
+ first_stage_key: edited
14
+ cond_stage_key: edit
15
+ image_size: 32
16
+ channels: 4
17
+ cond_stage_trainable: false # Note: different from the one we trained before
18
+ conditioning_key: hybrid
19
+ monitor: val/loss_simple_ema
20
+ scale_factor: 0.18215
21
+ use_ema: true
22
+ load_ema: false
23
+ ckpt_path: /sd/qichen/full_ct_gen/instruct-pix2pix-BioMedCLIP-concat-newdata/logs/train_instructpix2pix_2d_random/checkpoints/epoch=000091.ckpt
24
+ load_only_unet: True
25
+
26
+ scheduler_config: # 10000 warmup steps
27
+ target: ldm.lr_scheduler.LambdaLinearScheduler
28
+ params:
29
+ warm_up_steps: [ 0 ]
30
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
31
+ f_start: [ 1.e-6 ]
32
+ f_max: [ 1. ]
33
+ f_min: [ 1. ]
34
+
35
+ unet_config:
36
+ target: ldm.modules.diffusionmodules.openaimodel_pseudo3D.UNetModel
37
+ params:
38
+ image_size: 32 # unused
39
+ in_channels: 8
40
+ out_channels: 4
41
+ model_channels: 320
42
+ attention_resolutions: [ 4, 2, 1 ]
43
+ num_res_blocks: 2
44
+ channel_mult: [ 1, 2, 4, 4 ]
45
+ num_heads: 8
46
+ use_spatial_transformer: True
47
+ transformer_depth: 1
48
+ context_dim: 768
49
+ use_checkpoint: True
50
+ legacy: False
51
+
52
+ first_stage_config:
53
+ target: ldm.models.autoencoder.AutoencoderKL
54
+ params:
55
+ embed_dim: 4
56
+ monitor: val/rec_loss
57
+ ddconfig:
58
+ double_z: true
59
+ z_channels: 4
60
+ resolution: 256
61
+ in_channels: 3
62
+ out_ch: 3
63
+ ch: 128
64
+ ch_mult:
65
+ - 1
66
+ - 2
67
+ - 4
68
+ - 4
69
+ num_res_blocks: 2
70
+ attn_resolutions: []
71
+ dropout: 0.0
72
+ lossconfig:
73
+ target: torch.nn.Identity
74
+
75
+ cond_stage_config:
76
+ target: ldm.modules.encoders.modules.FrozenBioMedCLIPEmbedder
77
+
78
+
79
+ data:
80
+ target: main.DataModuleFromConfig
81
+ params:
82
+ batch_size: 16
83
+ num_workers: 8
84
+ train:
85
+ target: ldm.data.ct_clip_data_train_3d.CTReportDataset
86
+ params:
87
+ data_folder: '/sd/shuhan/CT-RATE/dataset/train_fixed'
88
+ csv_file: '/sd/shuhan/CT-RATE/radiology_text_reports/train_reports.csv'
89
+
90
+ validation:
91
+ target: ldm.data.ct_clip_data_inference_3d.CTReportDatasetinfer
92
+ params:
93
+ data_folder: '/sd/shuhan/CT-RATE/dataset/valid_fixed'
94
+ csv_file: '/sd/shuhan/CT-RATE/radiology_text_reports/valid_reports.csv'
95
+ labels: '/sd/shuhan/CT-RATE/multi_abnormality_labels/valid_predicted_labels.csv'
96
+
97
+
98
+ lightning:
99
+ callbacks:
100
+ image_logger:
101
+ target: main.ImageLogger
102
+ params:
103
+ batch_frequency: 200000000
104
+ max_images: 2
105
+ increase_log_steps: False
106
+
107
+ trainer:
108
+ max_epochs: 2000
109
+ benchmark: True
110
+ accumulate_grad_batches: 4
111
+ check_val_every_n_epoch: 10000
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/configs/train.yaml ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
2
+ # See more details in LICENSE.
3
+
4
+ model:
5
+ base_learning_rate: 1.0e-04
6
+ target: ldm.models.diffusion.ddpm_edit.LatentDiffusion
7
+ params:
8
+ ckpt_path: stable_diffusion/models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
9
+ linear_start: 0.00085
10
+ linear_end: 0.0120
11
+ num_timesteps_cond: 1
12
+ log_every_t: 200
13
+ timesteps: 1000
14
+ first_stage_key: edited
15
+ cond_stage_key: edit
16
+ image_size: 64
17
+ channels: 4
18
+ cond_stage_trainable: false # Note: different from the one we trained before
19
+ conditioning_key: hybrid
20
+ monitor: val/loss_simple_ema
21
+ scale_factor: 0.18215
22
+ use_ema: true
23
+ load_ema: false
24
+
25
+ scheduler_config: # 10000 warmup steps
26
+ target: ldm.lr_scheduler.LambdaLinearScheduler
27
+ params:
28
+ warm_up_steps: [ 0 ]
29
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
30
+ f_start: [ 1.e-6 ]
31
+ f_max: [ 1. ]
32
+ f_min: [ 1. ]
33
+
34
+ unet_config:
35
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
36
+ params:
37
+ image_size: 32 # unused
38
+ in_channels: 8
39
+ out_channels: 4
40
+ model_channels: 320
41
+ attention_resolutions: [ 4, 2, 1 ]
42
+ num_res_blocks: 2
43
+ channel_mult: [ 1, 2, 4, 4 ]
44
+ num_heads: 8
45
+ use_spatial_transformer: True
46
+ transformer_depth: 1
47
+ context_dim: 768
48
+ use_checkpoint: True
49
+ legacy: False
50
+
51
+ first_stage_config:
52
+ target: ldm.models.autoencoder.AutoencoderKL
53
+ params:
54
+ embed_dim: 4
55
+ monitor: val/rec_loss
56
+ ddconfig:
57
+ double_z: true
58
+ z_channels: 4
59
+ resolution: 256
60
+ in_channels: 3
61
+ out_ch: 3
62
+ ch: 128
63
+ ch_mult:
64
+ - 1
65
+ - 2
66
+ - 4
67
+ - 4
68
+ num_res_blocks: 2
69
+ attn_resolutions: []
70
+ dropout: 0.0
71
+ lossconfig:
72
+ target: torch.nn.Identity
73
+
74
+ cond_stage_config:
75
+ target: ldm.modules.encoders.modules.FrozenBioMedCLIPEmbedder
76
+
77
+ data:
78
+ target: main.DataModuleFromConfig
79
+ params:
80
+ batch_size: 16
81
+ num_workers: 8
82
+ train:
83
+ target: ldm.data.ct_clip_data_train.CTReportDataset
84
+ params:
85
+ data_folder: '/sd/shuhan/CT-RATE/dataset/train_fixed'
86
+ csv_file: '/sd/shuhan/CT-RATE/radiology_text_reports/train_reports.csv'
87
+
88
+ validation:
89
+ target: ldm.data.ct_clip_data_inference.CTReportDatasetinfer
90
+ params:
91
+ data_folder: '/sd/shuhan/CT-RATE/dataset/valid_fixed'
92
+ csv_file: '/sd/shuhan/CT-RATE/radiology_text_reports/valid_reports.csv'
93
+
94
+ lightning:
95
+ callbacks:
96
+ image_logger:
97
+ target: main.ImageLogger
98
+ params:
99
+ batch_frequency: 200000000
100
+ max_images: 2
101
+ increase_log_steps: False
102
+
103
+ trainer:
104
+ max_epochs: 2000
105
+ benchmark: True
106
+ accumulate_grad_batches: 4
107
+ check_val_every_n_epoch: 10000
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/configs/train_3d.yaml ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
2
+ # See more details in LICENSE.
3
+
4
+ model:
5
+ base_learning_rate: 1.0e-04
6
+ target: ldm.models.diffusion.ddpm_edit_3d.LatentDiffusion
7
+ params:
8
+ linear_start: 0.00085
9
+ linear_end: 0.0120
10
+ num_timesteps_cond: 1
11
+ log_every_t: 200
12
+ timesteps: 1000
13
+ first_stage_key: edited
14
+ cond_stage_key: edit
15
+ image_size: 32
16
+ channels: 4
17
+ cond_stage_trainable: false # Note: different from the one we trained before
18
+ conditioning_key: hybrid
19
+ monitor: val/loss_simple_ema
20
+ scale_factor: 0.18215
21
+ use_ema: true
22
+ load_ema: false
23
+ ckpt_path: ./logs/train_train_instructpix2pix/checkpoints/epoch=001959.ckpt
24
+ load_only_unet: True
25
+
26
+ scheduler_config: # 10000 warmup steps
27
+ target: ldm.lr_scheduler.LambdaLinearScheduler
28
+ params:
29
+ warm_up_steps: [ 0 ]
30
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
31
+ f_start: [ 1.e-6 ]
32
+ f_max: [ 1. ]
33
+ f_min: [ 1. ]
34
+
35
+ unet_config:
36
+ target: ldm.modules.diffusionmodules.openaimodel_pseudo3D.UNetModel
37
+ params:
38
+ image_size: 32 # unused
39
+ in_channels: 8
40
+ out_channels: 4
41
+ model_channels: 320
42
+ attention_resolutions: [ 4, 2, 1 ]
43
+ num_res_blocks: 2
44
+ channel_mult: [ 1, 2, 4, 4 ]
45
+ num_heads: 8
46
+ use_spatial_transformer: True
47
+ transformer_depth: 1
48
+ context_dim: 768
49
+ use_checkpoint: True
50
+ legacy: False
51
+
52
+ first_stage_config:
53
+ target: ldm.models.autoencoder.AutoencoderKL
54
+ params:
55
+ embed_dim: 4
56
+ monitor: val/rec_loss
57
+ ddconfig:
58
+ double_z: true
59
+ z_channels: 4
60
+ resolution: 256
61
+ in_channels: 3
62
+ out_ch: 3
63
+ ch: 128
64
+ ch_mult:
65
+ - 1
66
+ - 2
67
+ - 4
68
+ - 4
69
+ num_res_blocks: 2
70
+ attn_resolutions: []
71
+ dropout: 0.0
72
+ lossconfig:
73
+ target: torch.nn.Identity
74
+
75
+ cond_stage_config:
76
+ target: ldm.modules.encoders.modules.FrozenBioMedCLIPEmbedder
77
+
78
+ data:
79
+ target: main.DataModuleFromConfig
80
+ params:
81
+ batch_size: 1
82
+ num_workers: 8
83
+ train:
84
+ target: ldm.data.ct_clip_data_train_3d.CTReportDataset
85
+ params:
86
+ data_folder: '/sd/shuhan/CT-RATE/dataset/train_fixed'
87
+ csv_file: '/sd/shuhan/CT-RATE/radiology_text_reports/train_reports.csv'
88
+
89
+ validation:
90
+ target: ldm.data.ct_clip_data_inference_3d.CTReportDatasetinfer
91
+ params:
92
+ data_folder: '/sd/shuhan/CT-RATE/dataset/valid_fixed'
93
+ csv_file: '/sd/shuhan/CT-RATE/radiology_text_reports/valid_reports.csv'
94
+ labels: '/sd/shuhan/CT-RATE/multi_abnormality_labels/valid_predicted_labels.csv'
95
+
96
+ lightning:
97
+ callbacks:
98
+ image_logger:
99
+ target: main.ImageLogger
100
+ params:
101
+ batch_frequency: 200000000000
102
+ max_images: 2
103
+ increase_log_steps: False
104
+
105
+ trainer:
106
+ max_epochs: 2000
107
+ benchmark: True
108
+ accumulate_grad_batches: 4
109
+ check_val_every_n_epoch: 4
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/dataset_creation/generate_img_dataset.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import sys
4
+ from pathlib import Path
5
+
6
+ import k_diffusion
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ from einops import rearrange, repeat
11
+ from omegaconf import OmegaConf
12
+ from PIL import Image
13
+ from pytorch_lightning import seed_everything
14
+ from tqdm import tqdm
15
+
16
+ sys.path.append("./")
17
+ sys.path.append("./stable_diffusion")
18
+
19
+ from ldm.modules.attention import CrossAttention
20
+ from ldm.util import instantiate_from_config
21
+ from metrics.clip_similarity import ClipSimilarity
22
+
23
+
24
+ ################################################################################
25
+ # Modified K-diffusion Euler ancestral sampler with prompt-to-prompt.
26
+ # https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
27
+
28
+
29
+ def append_dims(x, target_dims):
30
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
31
+ dims_to_append = target_dims - x.ndim
32
+ if dims_to_append < 0:
33
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
34
+ return x[(...,) + (None,) * dims_to_append]
35
+
36
+
37
+ def to_d(x, sigma, denoised):
38
+ """Converts a denoiser output to a Karras ODE derivative."""
39
+ return (x - denoised) / append_dims(sigma, x.ndim)
40
+
41
+
42
+ def get_ancestral_step(sigma_from, sigma_to):
43
+ """Calculates the noise level (sigma_down) to step down to and the amount
44
+ of noise to add (sigma_up) when doing an ancestral sampling step."""
45
+ sigma_up = min(sigma_to, (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5)
46
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
47
+ return sigma_down, sigma_up
48
+
49
+
50
+ def sample_euler_ancestral(model, x, sigmas, prompt2prompt_threshold=0.0, **extra_args):
51
+ """Ancestral sampling with Euler method steps."""
52
+ s_in = x.new_ones([x.shape[0]])
53
+ for i in range(len(sigmas) - 1):
54
+ prompt_to_prompt = prompt2prompt_threshold > i / (len(sigmas) - 2)
55
+ for m in model.modules():
56
+ if isinstance(m, CrossAttention):
57
+ m.prompt_to_prompt = prompt_to_prompt
58
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
59
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
60
+ d = to_d(x, sigmas[i], denoised)
61
+ # Euler method
62
+ dt = sigma_down - sigmas[i]
63
+ x = x + d * dt
64
+ if sigmas[i + 1] > 0:
65
+ # Make noise the same across all samples in batch.
66
+ x = x + torch.randn_like(x[:1]) * sigma_up
67
+ return x
68
+
69
+
70
+ ################################################################################
71
+
72
+
73
+ def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
74
+ print(f"Loading model from {ckpt}")
75
+ pl_sd = torch.load(ckpt, map_location="cpu")
76
+ if "global_step" in pl_sd:
77
+ print(f"Global Step: {pl_sd['global_step']}")
78
+ sd = pl_sd["state_dict"]
79
+ if vae_ckpt is not None:
80
+ print(f"Loading VAE from {vae_ckpt}")
81
+ vae_sd = torch.load(vae_ckpt, map_location="cpu")["state_dict"]
82
+ sd = {
83
+ k: vae_sd[k[len("first_stage_model.") :]] if k.startswith("first_stage_model.") else v
84
+ for k, v in sd.items()
85
+ }
86
+ model = instantiate_from_config(config.model)
87
+ m, u = model.load_state_dict(sd, strict=False)
88
+ if len(m) > 0 and verbose:
89
+ print("missing keys:")
90
+ print(m)
91
+ if len(u) > 0 and verbose:
92
+ print("unexpected keys:")
93
+ print(u)
94
+ return model
95
+
96
+
97
+ class CFGDenoiser(nn.Module):
98
+ def __init__(self, model):
99
+ super().__init__()
100
+ self.inner_model = model
101
+
102
+ def forward(self, x, sigma, uncond, cond, cfg_scale):
103
+ x_in = torch.cat([x] * 2)
104
+ sigma_in = torch.cat([sigma] * 2)
105
+ cond_in = torch.cat([uncond, cond])
106
+ uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
107
+ return uncond + (cond - uncond) * cfg_scale
108
+
109
+
110
+ def to_pil(image: torch.Tensor) -> Image.Image:
111
+ image = 255.0 * rearrange(image.cpu().numpy(), "c h w -> h w c")
112
+ image = Image.fromarray(image.astype(np.uint8))
113
+ return image
114
+
115
+
116
+ def main():
117
+ parser = argparse.ArgumentParser()
118
+ parser.add_argument(
119
+ "--out_dir",
120
+ type=str,
121
+ required=True,
122
+ help="Path to output dataset directory.",
123
+ )
124
+ parser.add_argument(
125
+ "--prompts_file",
126
+ type=str,
127
+ required=True,
128
+ help="Path to prompts .jsonl file.",
129
+ )
130
+ parser.add_argument(
131
+ "--ckpt",
132
+ type=str,
133
+ default="stable_diffusion/models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt",
134
+ help="Path to stable diffusion checkpoint.",
135
+ )
136
+ parser.add_argument(
137
+ "--vae-ckpt",
138
+ type=str,
139
+ default="stable_diffusion/models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt",
140
+ help="Path to vae checkpoint.",
141
+ )
142
+ parser.add_argument(
143
+ "--steps",
144
+ type=int,
145
+ default=100,
146
+ help="Number of sampling steps.",
147
+ )
148
+ parser.add_argument(
149
+ "--n-samples",
150
+ type=int,
151
+ default=100,
152
+ help="Number of samples to generate per prompt (before CLIP filtering).",
153
+ )
154
+ parser.add_argument(
155
+ "--max-out-samples",
156
+ type=int,
157
+ default=4,
158
+ help="Max number of output samples to save per prompt (after CLIP filtering).",
159
+ )
160
+ parser.add_argument(
161
+ "--n-partitions",
162
+ type=int,
163
+ default=1,
164
+ help="Number of total partitions.",
165
+ )
166
+ parser.add_argument(
167
+ "--partition",
168
+ type=int,
169
+ default=0,
170
+ help="Partition index.",
171
+ )
172
+ parser.add_argument(
173
+ "--min-p2p",
174
+ type=float,
175
+ default=0.1,
176
+ help="Min prompt2prompt threshold (portion of denoising for which to fix self attention maps).",
177
+ )
178
+ parser.add_argument(
179
+ "--max-p2p",
180
+ type=float,
181
+ default=0.9,
182
+ help="Max prompt2prompt threshold (portion of denoising for which to fix self attention maps).",
183
+ )
184
+ parser.add_argument(
185
+ "--min-cfg",
186
+ type=float,
187
+ default=7.5,
188
+ help="Min classifier free guidance scale.",
189
+ )
190
+ parser.add_argument(
191
+ "--max-cfg",
192
+ type=float,
193
+ default=15,
194
+ help="Max classifier free guidance scale.",
195
+ )
196
+ parser.add_argument(
197
+ "--clip-threshold",
198
+ type=float,
199
+ default=0.2,
200
+ help="CLIP threshold for text-image similarity of each image.",
201
+ )
202
+ parser.add_argument(
203
+ "--clip-dir-threshold",
204
+ type=float,
205
+ default=0.2,
206
+ help="Directional CLIP threshold for similarity of change between pairs of text and pairs of images.",
207
+ )
208
+ parser.add_argument(
209
+ "--clip-img-threshold",
210
+ type=float,
211
+ default=0.7,
212
+ help="CLIP threshold for image-image similarity.",
213
+ )
214
+ opt = parser.parse_args()
215
+
216
+ global_seed = torch.randint(1 << 32, ()).item()
217
+ print(f"Global seed: {global_seed}")
218
+ seed_everything(global_seed)
219
+
220
+ model = load_model_from_config(
221
+ OmegaConf.load("stable_diffusion/configs/stable-diffusion/v1-inference.yaml"),
222
+ ckpt=opt.ckpt,
223
+ vae_ckpt=opt.vae_ckpt,
224
+ )
225
+ model.cuda().eval()
226
+ model_wrap = k_diffusion.external.CompVisDenoiser(model)
227
+
228
+ clip_similarity = ClipSimilarity().cuda()
229
+
230
+ out_dir = Path(opt.out_dir)
231
+ out_dir.mkdir(exist_ok=True, parents=True)
232
+
233
+ with open(opt.prompts_file) as fp:
234
+ prompts = [json.loads(line) for line in fp]
235
+
236
+ print(f"Partition index {opt.partition} ({opt.partition + 1} / {opt.n_partitions})")
237
+ prompts = np.array_split(list(enumerate(prompts)), opt.n_partitions)[opt.partition]
238
+
239
+ with torch.no_grad(), torch.autocast("cuda"), model.ema_scope():
240
+ uncond = model.get_learned_conditioning(2 * [""])
241
+ sigmas = model_wrap.get_sigmas(opt.steps)
242
+
243
+ for i, prompt in tqdm(prompts, desc="Prompts"):
244
+ prompt_dir = out_dir.joinpath(f"{i:07d}")
245
+ prompt_dir.mkdir(exist_ok=True)
246
+
247
+ with open(prompt_dir.joinpath("prompt.json"), "w") as fp:
248
+ json.dump(prompt, fp)
249
+
250
+ cond = model.get_learned_conditioning([prompt["caption"], prompt["output"]])
251
+ results = {}
252
+
253
+ with tqdm(total=opt.n_samples, desc="Samples") as progress_bar:
254
+
255
+ while len(results) < opt.n_samples:
256
+ seed = torch.randint(1 << 32, ()).item()
257
+ if seed in results:
258
+ continue
259
+ torch.manual_seed(seed)
260
+
261
+ x = torch.randn(1, 4, 512 // 8, 512 // 8, device="cuda") * sigmas[0]
262
+ x = repeat(x, "1 ... -> n ...", n=2)
263
+
264
+ model_wrap_cfg = CFGDenoiser(model_wrap)
265
+ p2p_threshold = opt.min_p2p + torch.rand(()).item() * (opt.max_p2p - opt.min_p2p)
266
+ cfg_scale = opt.min_cfg + torch.rand(()).item() * (opt.max_cfg - opt.min_cfg)
267
+ extra_args = {"cond": cond, "uncond": uncond, "cfg_scale": cfg_scale}
268
+ samples_ddim = sample_euler_ancestral(model_wrap_cfg, x, sigmas, p2p_threshold, **extra_args)
269
+ x_samples_ddim = model.decode_first_stage(samples_ddim)
270
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
271
+
272
+ x0 = x_samples_ddim[0]
273
+ x1 = x_samples_ddim[1]
274
+
275
+ clip_sim_0, clip_sim_1, clip_sim_dir, clip_sim_image = clip_similarity(
276
+ x0[None], x1[None], [prompt["caption"]], [prompt["output"]]
277
+ )
278
+
279
+ results[seed] = dict(
280
+ image_0=to_pil(x0),
281
+ image_1=to_pil(x1),
282
+ p2p_threshold=p2p_threshold,
283
+ cfg_scale=cfg_scale,
284
+ clip_sim_0=clip_sim_0[0].item(),
285
+ clip_sim_1=clip_sim_1[0].item(),
286
+ clip_sim_dir=clip_sim_dir[0].item(),
287
+ clip_sim_image=clip_sim_image[0].item(),
288
+ )
289
+
290
+ progress_bar.update()
291
+
292
+ # CLIP filter to get best samples for each prompt.
293
+ metadata = [
294
+ (result["clip_sim_dir"], seed)
295
+ for seed, result in results.items()
296
+ if result["clip_sim_image"] >= opt.clip_img_threshold
297
+ and result["clip_sim_dir"] >= opt.clip_dir_threshold
298
+ and result["clip_sim_0"] >= opt.clip_threshold
299
+ and result["clip_sim_1"] >= opt.clip_threshold
300
+ ]
301
+ metadata.sort(reverse=True)
302
+ for _, seed in metadata[: opt.max_out_samples]:
303
+ result = results[seed]
304
+ image_0 = result.pop("image_0")
305
+ image_1 = result.pop("image_1")
306
+ image_0.save(prompt_dir.joinpath(f"{seed}_0.jpg"), quality=100)
307
+ image_1.save(prompt_dir.joinpath(f"{seed}_1.jpg"), quality=100)
308
+ with open(prompt_dir.joinpath(f"metadata.jsonl"), "a") as fp:
309
+ fp.write(f"{json.dumps(dict(seed=seed, **result))}\n")
310
+
311
+ print("Done.")
312
+
313
+
314
+ if __name__ == "__main__":
315
+ main()
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/dataset_creation/generate_txt_dataset.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import time
5
+ from argparse import ArgumentParser
6
+ from pathlib import Path
7
+ from typing import Optional
8
+
9
+ import datasets
10
+ import numpy as np
11
+ import openai
12
+ from tqdm.auto import tqdm
13
+
14
+
15
+ DELIMITER_0 = "\n##\n"
16
+ DELIMITER_1 = "\n%%\n"
17
+ STOP = "\nEND"
18
+
19
+
20
+ def generate(
21
+ openai_model: str,
22
+ caption: str,
23
+ num_retries: int = 3,
24
+ max_tokens: int = 256,
25
+ temperature: float = 0.7,
26
+ top_p: float = 1.0,
27
+ frequency_penalty: float = 0.1,
28
+ presence_penalty: float = 0.0,
29
+ sleep_on_error: float = 1.0,
30
+ ) -> Optional[tuple[str, str]]:
31
+ for _ in range(1 + num_retries):
32
+ try:
33
+ response = openai.Completion.create(
34
+ model=openai_model,
35
+ prompt=caption + DELIMITER_0,
36
+ temperature=temperature,
37
+ max_tokens=max_tokens,
38
+ top_p=top_p,
39
+ frequency_penalty=frequency_penalty,
40
+ presence_penalty=presence_penalty,
41
+ stop=[STOP],
42
+ )
43
+ except Exception as e:
44
+ print(e)
45
+ time.sleep(sleep_on_error)
46
+ continue
47
+ output = response["choices"][0]["text"].split(DELIMITER_1)
48
+ if len(output) == 2:
49
+ instruction, edited_caption = output
50
+ results = openai.Moderation.create([instruction, edited_caption])["results"]
51
+ if results[0]["flagged"] or results[1]["flagged"]:
52
+ continue
53
+ if caption.strip().strip(".!?").lower() != edited_caption.strip().strip(".!?").lower():
54
+ return instruction, edited_caption
55
+
56
+
57
+ def main(openai_model: str, num_samples: int, num_partitions: int, partition: int, seed: int):
58
+ dataset = datasets.load_dataset("ChristophSchuhmann/improved_aesthetics_6.5plus", split="train")
59
+ # Other datasets we considered that may be worth trying:
60
+ # dataset = datasets.load_dataset("ChristophSchuhmann/MS_COCO_2017_URL_TEXT", split="train")
61
+ # dataset = datasets.load_dataset("laion/laion-coco", split="train")
62
+
63
+ np.random.seed(seed)
64
+ permutation = np.array_split(np.random.permutation(len(dataset)), num_partitions)[partition]
65
+ dataset = dataset[permutation]
66
+ captions = dataset["TEXT"]
67
+ urls = dataset["URL"]
68
+ output_path = f"data/dataset=laion-aesthetics-6.5_model={openai_model}_samples={num_samples}_partition={partition}.jsonl" # fmt: skip
69
+ print(f"Prompt file path: {output_path}")
70
+
71
+ count = 0
72
+ caption_set = set()
73
+ url_set = set()
74
+
75
+ if Path(output_path).exists():
76
+ with open(output_path, "r") as f:
77
+ for line in tqdm(f, desc="Resuming from existing prompts"):
78
+ prompt = json.loads(line)
79
+ if prompt["caption"] not in caption_set and prompt["url"] not in url_set:
80
+ caption_set.add(prompt["caption"])
81
+ url_set.add(prompt["url"])
82
+ count += 1
83
+
84
+ with open(output_path, "a") as fp:
85
+ with tqdm(total=num_samples - count, desc="Generating instructions and edited captions") as progress_bar:
86
+ for caption, url in zip(captions, urls):
87
+ if caption in caption_set or url in url_set:
88
+ continue
89
+ if openai.Moderation.create(caption)["results"][0]["flagged"]:
90
+ continue
91
+ edit_output = generate(openai_model, caption)
92
+ if edit_output is not None:
93
+ edit, output = edit_output
94
+ fp.write(f"{json.dumps(dict(caption=caption, edit=edit, output=output, url=url))}\n")
95
+ count += 1
96
+ progress_bar.update()
97
+ caption_set.add(caption)
98
+ url_set.add(url)
99
+ if count == num_samples:
100
+ break
101
+
102
+
103
+ if __name__ == "__main__":
104
+ parser = ArgumentParser()
105
+ parser.add_argument("--openai-api-key", required=True, type=str)
106
+ parser.add_argument("--openai-model", required=True, type=str)
107
+ parser.add_argument("--num-samples", default=10000, type=int)
108
+ parser.add_argument("--num-partitions", default=1, type=int)
109
+ parser.add_argument("--partition", default=0, type=int)
110
+ parser.add_argument("--seed", default=0, type=int)
111
+ args = parser.parse_args()
112
+ openai.api_key = args.openai_api_key
113
+ main(args.openai_model, args.num_samples, args.num_partitions, args.partition, args.seed)
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/dataset_creation/prepare_dataset.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from argparse import ArgumentParser
3
+ from pathlib import Path
4
+
5
+ from tqdm.auto import tqdm
6
+
7
+
8
def main():
    """Index a generated image dataset for training.

    Scans every prompt directory under ``dataset_dir`` for input images
    (files named ``<seed>_0.jpg``), collects the per-prompt seeds, and
    writes the sorted index to ``<dataset_dir>/seeds.json``.
    """
    arg_parser = ArgumentParser()
    arg_parser.add_argument("dataset_dir")
    dataset_dir = Path(arg_parser.parse_args().dataset_dir)

    index = []
    with tqdm(desc="Listing dataset image seeds") as progress_bar:
        for entry in dataset_dir.iterdir():
            if entry.is_dir():
                # The seed is the filename prefix before the first underscore.
                found_seeds = [image.name.split("_")[0] for image in sorted(entry.glob("*_0.jpg"))]
                if found_seeds:
                    index.append((entry.name, found_seeds))
            progress_bar.update()
    index.sort()

    with open(dataset_dir.joinpath("seeds.json"), "w") as f:
        json.dump(index, f)


if __name__ == "__main__":
    main()
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/dataset_creation/prepare_for_gpt.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from argparse import ArgumentParser
3
+
4
+ from generate_txt_dataset import DELIMITER_0, DELIMITER_1, STOP
5
+
6
+
7
def main(input_path: str, output_path: str):
    """Rewrite a JSONL file of edit records into the OpenAI fine-tuning
    prompt/completion format, inserting the dataset's delimiter tokens."""
    with open(input_path) as src:
        records = [json.loads(line) for line in src]

    with open(output_path, "w") as dst:
        for record in records:
            gpt_example = {
                "prompt": f"{record['input']}{DELIMITER_0}",
                "completion": f"{record['edit']}{DELIMITER_1}{record['output']}{STOP}",
            }
            dst.write(f"{json.dumps(gpt_example)}\n")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--input-path", required=True, type=str)
    parser.add_argument("--output-path", required=True, type=str)
    args = parser.parse_args()
    main(args.input_path, args.output_path)
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/edit_app.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ import random
5
+ import sys
6
+ from argparse import ArgumentParser
7
+
8
+ import einops
9
+ import gradio as gr
10
+ import k_diffusion as K
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ from einops import rearrange
15
+ from omegaconf import OmegaConf
16
+ from PIL import Image, ImageOps
17
+ from torch import autocast
18
+
19
+ sys.path.append("./stable_diffusion")
20
+
21
+ from stable_diffusion.ldm.util import instantiate_from_config
22
+
23
+
24
# User-facing troubleshooting tips rendered below the demo controls.
# (Fixed typo: "Incerasing" -> "Increasing".)
help_text = """
If you're not getting what you want, there may be a few reasons:
1. Is the image not changing enough? Your Image CFG weight may be too high. This value dictates how similar the output should be to the input. It's possible your edit requires larger changes from the original image, and your Image CFG weight isn't allowing that. Alternatively, your Text CFG weight may be too low. This value dictates how much to listen to the text instruction. The default Image CFG of 1.5 and Text CFG of 7.5 are a good starting point, but aren't necessarily optimal for each edit. Try:
    * Decreasing the Image CFG weight, or
    * Increasing the Text CFG weight, or
2. Conversely, is the image changing too much, such that the details in the original image aren't preserved? Try:
    * Increasing the Image CFG weight, or
    * Decreasing the Text CFG weight
3. Try generating results with different random seeds by setting "Randomize Seed" and running generation multiple times. You can also try setting "Randomize CFG" to sample new Text CFG and Image CFG values each time.
4. Rephrasing the instruction sometimes improves results (e.g., "turn him into a dog" vs. "make him a dog" vs. "as a dog").
5. Increasing the number of steps sometimes improves results.
6. Do faces look weird? The Stable Diffusion autoencoder has a hard time with faces that are small in the image. Try:
    * Cropping the image so the face takes up a larger portion of the frame.
"""
38
+
39
+
40
# Canned edit instructions for the demo's "Load Example" button; one is
# chosen at random and applied to the bundled example image.
example_instructions = [
    "Make it a picasso painting",
    "as if it were by modigliani",
    "convert to a bronze statue",
    "Turn it into an anime.",
    "have it look like a graphic novel",
    "make him gain weight",
    "what would he look like bald?",
    "Have him smile",
    "Put him in a cocktail party.",
    "move him at the beach.",
    "add dramatic lighting",
    "Convert to black and white",
    "What if it were snowing?",
    "Give him a leather jacket",
    "Turn him into a cyborg!",
    "make him wear a beanie",
]
58
+
59
+
60
class CFGDenoiser(nn.Module):
    """Apply dual classifier-free guidance (text + image) around a denoiser."""

    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, z, sigma, cond, uncond, text_cfg_scale, image_cfg_scale):
        # Batch the three conditioning variants (full, image-only, unconditional)
        # into one forward pass. torch.cat over the size-1 leading batch dim is
        # equivalent to einops.repeat("1 ... -> n ...", n=3).
        batched_z = torch.cat([z, z, z], dim=0)
        batched_sigma = torch.cat([sigma, sigma, sigma], dim=0)
        text_c, text_u = cond["c_crossattn"][0], uncond["c_crossattn"][0]
        img_c, img_u = cond["c_concat"][0], uncond["c_concat"][0]
        batched_cond = {
            "c_crossattn": [torch.cat([text_c, text_u, text_u])],
            "c_concat": [torch.cat([img_c, img_c, img_u])],
        }
        denoised = self.inner_model(batched_z, batched_sigma, cond=batched_cond)
        out_cond, out_img_cond, out_uncond = denoised.chunk(3)
        guided = out_uncond
        guided = guided + text_cfg_scale * (out_cond - out_img_cond)
        guided = guided + image_cfg_scale * (out_img_cond - out_uncond)
        return guided
74
+
75
+
76
def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
    """Instantiate the diffusion model described by ``config`` and load weights.

    Args:
        config: OmegaConf config whose ``model`` node is instantiable.
        ckpt: Path to a PyTorch Lightning checkpoint (.ckpt).
        vae_ckpt: Optional path to a standalone VAE checkpoint whose weights
            replace the model's ``first_stage_model.*`` entries.
        verbose: When True, print missing/unexpected state-dict keys.

    Returns:
        The instantiated model with weights loaded (non-strict).
    """
    print(f"Loading model from {ckpt}")
    # Load onto CPU first; the caller decides when/where to move the model.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    if vae_ckpt is not None:
        print(f"Loading VAE from {vae_ckpt}")
        vae_sd = torch.load(vae_ckpt, map_location="cpu")["state_dict"]
        # Swap in the standalone VAE weights for every first-stage entry;
        # all other entries keep their original values.
        sd = {
            k: vae_sd[k[len("first_stage_model.") :]] if k.startswith("first_stage_model.") else v
            for k, v in sd.items()
        }
    model = instantiate_from_config(config.model)
    # strict=False: tolerate architecture/checkpoint drift; report below.
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)
    return model
98
+
99
+
100
def main():
    """Launch the interactive Gradio demo for instruction-based image editing."""
    parser = ArgumentParser()
    parser.add_argument("--resolution", default=512, type=int)
    parser.add_argument("--config", default="configs/generate.yaml", type=str)
    parser.add_argument("--ckpt", default="checkpoints/instruct-pix2pix-00-22000.ckpt", type=str)
    parser.add_argument("--vae-ckpt", default=None, type=str)
    args = parser.parse_args()

    config = OmegaConf.load(args.config)
    model = load_model_from_config(config, args.ckpt, args.vae_ckpt)
    model.eval().cuda()
    # k-diffusion denoiser wrapper + dual (text/image) classifier-free guidance.
    model_wrap = K.external.CompVisDenoiser(model)
    model_wrap_cfg = CFGDenoiser(model_wrap)
    # Empty-prompt conditioning reused as the unconditional text embedding.
    null_token = model.get_learned_conditioning([""])
    example_image = Image.open("imgs/example.jpg").convert("RGB")

    def load_example(
        steps: int,
        randomize_seed: bool,
        seed: int,
        randomize_cfg: bool,
        text_cfg_scale: float,
        image_cfg_scale: float,
    ):
        # Pick a random canned instruction and run it on the bundled example image.
        example_instruction = random.choice(example_instructions)
        return [example_image, example_instruction] + generate(
            example_image,
            example_instruction,
            steps,
            randomize_seed,
            seed,
            randomize_cfg,
            text_cfg_scale,
            image_cfg_scale,
        )

    def generate(
        input_image: Image.Image,
        instruction: str,
        steps: int,
        randomize_seed: bool,
        seed: int,
        randomize_cfg: bool,
        text_cfg_scale: float,
        image_cfg_scale: float,
    ):
        # Optionally re-roll the seed and CFG scales for this run.
        seed = random.randint(0, 100000) if randomize_seed else seed
        text_cfg_scale = round(random.uniform(6.0, 9.0), ndigits=2) if randomize_cfg else text_cfg_scale
        image_cfg_scale = round(random.uniform(1.2, 1.8), ndigits=2) if randomize_cfg else image_cfg_scale

        # Resize so the long side is about args.resolution and both sides are
        # multiples of 64 (required by the latent-space downsampling).
        width, height = input_image.size
        factor = args.resolution / max(width, height)
        factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
        width = int((width * factor) // 64) * 64
        height = int((height * factor) // 64) * 64
        input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)

        if instruction == "":
            # No edit requested: echo the (resized) input back.
            return [input_image, seed]

        with torch.no_grad(), autocast("cuda"), model.ema_scope():
            cond = {}
            cond["c_crossattn"] = [model.get_learned_conditioning([instruction])]
            # Map pixels from [0, 255] to [-1, 1] before encoding.
            input_image = 2 * torch.tensor(np.array(input_image)).float() / 255 - 1
            input_image = rearrange(input_image, "h w c -> 1 c h w").to(model.device)
            cond["c_concat"] = [model.encode_first_stage(input_image).mode()]

            uncond = {}
            uncond["c_crossattn"] = [null_token]
            uncond["c_concat"] = [torch.zeros_like(cond["c_concat"][0])]

            sigmas = model_wrap.get_sigmas(steps)

            extra_args = {
                "cond": cond,
                "uncond": uncond,
                "text_cfg_scale": text_cfg_scale,
                "image_cfg_scale": image_cfg_scale,
            }
            torch.manual_seed(seed)
            # Start from pure noise scaled to the first sigma, then denoise.
            z = torch.randn_like(cond["c_concat"][0]) * sigmas[0]
            z = K.sampling.sample_euler_ancestral(model_wrap_cfg, z, sigmas, extra_args=extra_args)
            x = model.decode_first_stage(z)
            x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
            x = 255.0 * rearrange(x, "1 c h w -> h w c")
            edited_image = Image.fromarray(x.type(torch.uint8).cpu().numpy())

        return [seed, text_cfg_scale, image_cfg_scale, edited_image]

    def reset():
        # Default values for [steps, randomize_seed, seed, randomize_cfg,
        # text_cfg_scale, image_cfg_scale, edited_image].
        return [0, "Randomize Seed", 1371, "Fix CFG", 7.5, 1.5, None]

    # ---- UI layout and event wiring ----
    with gr.Blocks(css="footer {visibility: hidden}") as demo:
        with gr.Row():
            with gr.Column(scale=1, min_width=100):
                generate_button = gr.Button("Generate")
            with gr.Column(scale=1, min_width=100):
                load_button = gr.Button("Load Example")
            with gr.Column(scale=1, min_width=100):
                reset_button = gr.Button("Reset")
            with gr.Column(scale=3):
                instruction = gr.Textbox(lines=1, label="Edit Instruction", interactive=True)

        with gr.Row():
            input_image = gr.Image(label="Input Image", type="pil", interactive=True)
            edited_image = gr.Image(label=f"Edited Image", type="pil", interactive=False)
            input_image.style(height=512, width=512)
            edited_image.style(height=512, width=512)

        with gr.Row():
            steps = gr.Number(value=100, precision=0, label="Steps", interactive=True)
            # type="index" passes the selected option's index (0/1) to callbacks.
            randomize_seed = gr.Radio(
                ["Fix Seed", "Randomize Seed"],
                value="Randomize Seed",
                type="index",
                show_label=False,
                interactive=True,
            )
            seed = gr.Number(value=1371, precision=0, label="Seed", interactive=True)
            randomize_cfg = gr.Radio(
                ["Fix CFG", "Randomize CFG"],
                value="Fix CFG",
                type="index",
                show_label=False,
                interactive=True,
            )
            text_cfg_scale = gr.Number(value=7.5, label=f"Text CFG", interactive=True)
            image_cfg_scale = gr.Number(value=1.5, label=f"Image CFG", interactive=True)

        gr.Markdown(help_text)

        load_button.click(
            fn=load_example,
            inputs=[
                steps,
                randomize_seed,
                seed,
                randomize_cfg,
                text_cfg_scale,
                image_cfg_scale,
            ],
            outputs=[input_image, instruction, seed, text_cfg_scale, image_cfg_scale, edited_image],
        )
        generate_button.click(
            fn=generate,
            inputs=[
                input_image,
                instruction,
                steps,
                randomize_seed,
                seed,
                randomize_cfg,
                text_cfg_scale,
                image_cfg_scale,
            ],
            outputs=[seed, text_cfg_scale, image_cfg_scale, edited_image],
        )
        reset_button.click(
            fn=reset,
            inputs=[],
            outputs=[steps, randomize_seed, seed, randomize_cfg, text_cfg_scale, image_cfg_scale, edited_image],
        )

    # Serialize requests (single GPU) and expose a public share link.
    demo.queue(concurrency_count=1)
    demo.launch(share=True)
265
+
266
+
267
# Script entry point: start the Gradio editing demo.
if __name__ == "__main__":
    main()
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/edit_cli.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ import random
5
+ import sys
6
+ from argparse import ArgumentParser
7
+
8
+ import einops
9
+ import k_diffusion as K
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn as nn
13
+ from einops import rearrange
14
+ from omegaconf import OmegaConf
15
+ from PIL import Image, ImageOps
16
+ from torch import autocast
17
+
18
+ sys.path.append("./stable_diffusion")
19
+
20
+ from stable_diffusion.ldm.util import instantiate_from_config
21
+
22
+
23
class CFGDenoiser(nn.Module):
    """Apply dual classifier-free guidance (text + image) around a denoiser."""

    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, z, sigma, cond, uncond, text_cfg_scale, image_cfg_scale):
        # Batch the three conditioning variants (full, image-only, unconditional)
        # into one forward pass. torch.cat over the size-1 leading batch dim is
        # equivalent to einops.repeat("1 ... -> n ...", n=3).
        batched_z = torch.cat([z, z, z], dim=0)
        batched_sigma = torch.cat([sigma, sigma, sigma], dim=0)
        text_c, text_u = cond["c_crossattn"][0], uncond["c_crossattn"][0]
        img_c, img_u = cond["c_concat"][0], uncond["c_concat"][0]
        batched_cond = {
            "c_crossattn": [torch.cat([text_c, text_u, text_u])],
            "c_concat": [torch.cat([img_c, img_c, img_u])],
        }
        denoised = self.inner_model(batched_z, batched_sigma, cond=batched_cond)
        out_cond, out_img_cond, out_uncond = denoised.chunk(3)
        guided = out_uncond
        guided = guided + text_cfg_scale * (out_cond - out_img_cond)
        guided = guided + image_cfg_scale * (out_img_cond - out_uncond)
        return guided
37
+
38
+
39
def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
    """Instantiate the diffusion model described by ``config`` and load weights.

    Args:
        config: OmegaConf config whose ``model`` node is instantiable.
        ckpt: Path to a PyTorch Lightning checkpoint (.ckpt).
        vae_ckpt: Optional path to a standalone VAE checkpoint whose weights
            replace the model's ``first_stage_model.*`` entries.
        verbose: When True, print missing/unexpected state-dict keys.

    Returns:
        The instantiated model with weights loaded (non-strict).
    """
    print(f"Loading model from {ckpt}")
    # Load onto CPU first; the caller decides when/where to move the model.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    if vae_ckpt is not None:
        print(f"Loading VAE from {vae_ckpt}")
        vae_sd = torch.load(vae_ckpt, map_location="cpu")["state_dict"]
        # Swap in the standalone VAE weights for every first-stage entry;
        # all other entries keep their original values.
        sd = {
            k: vae_sd[k[len("first_stage_model.") :]] if k.startswith("first_stage_model.") else v
            for k, v in sd.items()
        }
    model = instantiate_from_config(config.model)
    # strict=False: tolerate architecture/checkpoint drift; report below.
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)
    return model
61
+
62
+
63
def main():
    """Edit a single image from the command line with InstructPix2Pix."""
    parser = ArgumentParser()
    parser.add_argument("--resolution", default=512, type=int)
    parser.add_argument("--steps", default=100, type=int)
    parser.add_argument("--config", default="configs/generate.yaml", type=str)
    parser.add_argument("--ckpt", default="checkpoints/instruct-pix2pix-00-22000.ckpt", type=str)
    parser.add_argument("--vae-ckpt", default=None, type=str)
    parser.add_argument("--input", required=True, type=str)
    parser.add_argument("--output", required=True, type=str)
    parser.add_argument("--edit", required=True, type=str)
    parser.add_argument("--cfg-text", default=7.5, type=float)
    parser.add_argument("--cfg-image", default=1.5, type=float)
    parser.add_argument("--seed", type=int)
    args = parser.parse_args()

    config = OmegaConf.load(args.config)
    model = load_model_from_config(config, args.ckpt, args.vae_ckpt)
    model.eval().cuda()
    # k-diffusion denoiser wrapper + dual (text/image) classifier-free guidance.
    model_wrap = K.external.CompVisDenoiser(model)
    model_wrap_cfg = CFGDenoiser(model_wrap)
    # Empty-prompt embedding used as the unconditional text branch.
    null_token = model.get_learned_conditioning([""])

    seed = random.randint(0, 100000) if args.seed is None else args.seed
    input_image = Image.open(args.input).convert("RGB")
    # Resize so the long side is about --resolution and both sides are
    # multiples of 64 (required by the latent-space downsampling).
    width, height = input_image.size
    factor = args.resolution / max(width, height)
    factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
    width = int((width * factor) // 64) * 64
    height = int((height * factor) // 64) * 64
    input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)

    if args.edit == "":
        # Empty instruction: save the (resized) input unchanged.
        input_image.save(args.output)
        return

    with torch.no_grad(), autocast("cuda"), model.ema_scope():
        cond = {}
        cond["c_crossattn"] = [model.get_learned_conditioning([args.edit])]
        # Pixels [0, 255] -> [-1, 1], then encode to the latent image condition.
        input_image = 2 * torch.tensor(np.array(input_image)).float() / 255 - 1
        input_image = rearrange(input_image, "h w c -> 1 c h w").to(model.device)
        cond["c_concat"] = [model.encode_first_stage(input_image).mode()]

        uncond = {}
        uncond["c_crossattn"] = [null_token]
        uncond["c_concat"] = [torch.zeros_like(cond["c_concat"][0])]

        sigmas = model_wrap.get_sigmas(args.steps)

        extra_args = {
            "cond": cond,
            "uncond": uncond,
            "text_cfg_scale": args.cfg_text,
            "image_cfg_scale": args.cfg_image,
        }
        torch.manual_seed(seed)
        # Start from pure noise at the first sigma, sample with Euler-ancestral.
        z = torch.randn_like(cond["c_concat"][0]) * sigmas[0]
        z = K.sampling.sample_euler_ancestral(model_wrap_cfg, z, sigmas, extra_args=extra_args)
        x = model.decode_first_stage(z)
        x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
        x = 255.0 * rearrange(x, "1 c h w -> h w c")
        edited_image = Image.fromarray(x.type(torch.uint8).cpu().numpy())
        edited_image.save(args.output)
125
+
126
+
127
# Script entry point: run one CLI edit.
if __name__ == "__main__":
    main()
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/edit_dataset.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import math
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torchvision
11
+ from einops import rearrange
12
+ from PIL import Image
13
+ from torch.utils.data import Dataset
14
+
15
+
16
class EditDataset(Dataset):
    """Training dataset of (input image, edited image, instruction) triplets.

    Expects the layout produced by the dataset-generation scripts:
    ``<path>/seeds.json`` listing ``[prompt_dir, [seed, ...]]`` pairs, and for
    each prompt directory a ``prompt.json`` plus image pairs named
    ``<seed>_0.jpg`` (input) and ``<seed>_1.jpg`` (edited).
    """

    def __init__(
        self,
        path: str,
        split: str = "train",
        splits: tuple[float, float, float] = (0.9, 0.05, 0.05),
        min_resize_res: int = 256,
        max_resize_res: int = 256,
        crop_res: int = 256,
        flip_prob: float = 0.0,
    ):
        assert split in ("train", "val", "test")
        # Tolerant comparison: an exact `sum(splits) == 1` check is fragile
        # under floating-point rounding (e.g. 0.7 + 0.2 + 0.1 != 1.0).
        assert math.isclose(sum(splits), 1, abs_tol=1e-6)
        self.path = path
        self.min_resize_res = min_resize_res
        self.max_resize_res = max_resize_res
        self.crop_res = crop_res
        self.flip_prob = flip_prob

        with open(Path(self.path, "seeds.json")) as f:
            self.seeds = json.load(f)

        # Carve the contiguous (train, val, test) fractions out of the seed list.
        split_0, split_1 = {
            "train": (0.0, splits[0]),
            "val": (splits[0], splits[0] + splits[1]),
            "test": (splits[0] + splits[1], 1.0),
        }[split]

        idx_0 = math.floor(split_0 * len(self.seeds))
        idx_1 = math.floor(split_1 * len(self.seeds))
        self.seeds = self.seeds[idx_0:idx_1]

    def __len__(self) -> int:
        return len(self.seeds)

    def __getitem__(self, i: int) -> dict[str, Any]:
        name, seeds = self.seeds[i]
        propt_dir = Path(self.path, name)
        # Sample one of the seeds generated for this prompt.
        seed = seeds[torch.randint(0, len(seeds), ()).item()]
        with open(propt_dir.joinpath("prompt.json")) as fp:
            prompt = json.load(fp)["edit"]

        image_0 = Image.open(propt_dir.joinpath(f"{seed}_0.jpg"))
        image_1 = Image.open(propt_dir.joinpath(f"{seed}_1.jpg"))

        # Randomly resize within [min_resize_res, max_resize_res] (inclusive).
        resize_res = torch.randint(self.min_resize_res, self.max_resize_res + 1, ()).item()
        image_0 = image_0.resize((resize_res, resize_res), Image.Resampling.LANCZOS)
        image_1 = image_1.resize((resize_res, resize_res), Image.Resampling.LANCZOS)

        # Scale pixels to [-1, 1] and move channels first.
        image_0 = rearrange(2 * torch.tensor(np.array(image_0)).float() / 255 - 1, "h w c -> c h w")
        image_1 = rearrange(2 * torch.tensor(np.array(image_1)).float() / 255 - 1, "h w c -> c h w")

        # Apply the same crop/flip to both images by stacking them channel-wise.
        crop = torchvision.transforms.RandomCrop(self.crop_res)
        flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob))
        image_0, image_1 = flip(crop(torch.cat((image_0, image_1)))).chunk(2)

        return dict(edited=image_1, edit=dict(c_concat=image_0, c_crossattn=prompt))
73
+
74
+
75
class EditDatasetEval(Dataset):
    """Evaluation variant of :class:`EditDataset`.

    Returns the input image together with its caption texts
    (input/edit/output) instead of the edited target image.
    """

    def __init__(
        self,
        path: str,
        split: str = "train",
        splits: tuple[float, float, float] = (0.9, 0.05, 0.05),
        res: int = 256,
    ):
        assert split in ("train", "val", "test")
        # Tolerant comparison: exact float equality rejects valid splits
        # such as (0.7, 0.2, 0.1) due to rounding.
        assert math.isclose(sum(splits), 1, abs_tol=1e-6)
        self.path = path
        self.res = res

        with open(Path(self.path, "seeds.json")) as f:
            self.seeds = json.load(f)

        # Carve the contiguous (train, val, test) fractions out of the seed list.
        split_0, split_1 = {
            "train": (0.0, splits[0]),
            "val": (splits[0], splits[0] + splits[1]),
            "test": (splits[0] + splits[1], 1.0),
        }[split]

        idx_0 = math.floor(split_0 * len(self.seeds))
        idx_1 = math.floor(split_1 * len(self.seeds))
        self.seeds = self.seeds[idx_0:idx_1]

    def __len__(self) -> int:
        return len(self.seeds)

    def __getitem__(self, i: int) -> dict[str, Any]:
        name, seeds = self.seeds[i]
        propt_dir = Path(self.path, name)
        # Sample one of the seeds generated for this prompt.
        seed = seeds[torch.randint(0, len(seeds), ()).item()]
        with open(propt_dir.joinpath("prompt.json")) as fp:
            prompt = json.load(fp)
        edit = prompt["edit"]
        input_prompt = prompt["input"]
        output_prompt = prompt["output"]

        image_0 = Image.open(propt_dir.joinpath(f"{seed}_0.jpg"))

        # `res` is fixed for evaluation, so resize directly (the original
        # torch.randint(res, res + 1) always yielded exactly `res`).
        image_0 = image_0.resize((self.res, self.res), Image.Resampling.LANCZOS)

        # Scale pixels to [-1, 1] and move channels first.
        image_0 = rearrange(2 * torch.tensor(np.array(image_0)).float() / 255 - 1, "h w c -> c h w")

        return dict(image_0=image_0, input_prompt=input_prompt, edit=edit, output_prompt=output_prompt)
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/environment.yaml ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
2
+ # See more details in LICENSE.
3
+
4
+ name: ip2p
5
+ channels:
6
+ - pytorch
7
+ - defaults
8
+ dependencies:
9
+ - python=3.8.5
10
+ - pip=20.3
11
+ - cudatoolkit=11.3
12
+ - pytorch=1.11.0
13
+ - torchvision=0.12.0
14
+ - numpy=1.19.2
15
+ - pip:
16
+ - albumentations==0.4.3
17
+ - datasets==2.8.0
18
+ - diffusers
19
+ - opencv-python==4.1.2.30
20
+ - pudb==2019.2
21
+ - invisible-watermark
22
+ - imageio==2.9.0
23
+ - imageio-ffmpeg==0.4.2
24
+ - pytorch-lightning==1.4.2
25
+ - omegaconf==2.1.1
26
+ - test-tube>=0.7.5
27
+ - streamlit>=0.73.1
28
+ - einops==0.3.0
29
+ - torch-fidelity==0.3.0
30
+ - transformers==4.19.2
31
+ - torchmetrics==0.6.0
32
+ - kornia==0.6
33
+ - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
34
+ - -e git+https://github.com/openai/CLIP.git@main#egg=clip
35
+ - openai
36
+ - gradio
37
+ - seaborn
38
+ - git+https://github.com/crowsonkb/k-diffusion.git
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/imgs/dataset.jpg ADDED

Git LFS Details

  • SHA256: 9393c97383a25b93e0bfef374033b6c6d953124bb1f37181de7345c725ff8dbc
  • Pointer size: 131 Bytes
  • Size of remote file: 100 kB
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/imgs/edit_app.jpg ADDED

Git LFS Details

  • SHA256: 395761cecae1d5442f8c48435f552f129b3425df1bccaaff82526d49bf5aeaea
  • Pointer size: 131 Bytes
  • Size of remote file: 402 kB
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/imgs/example.jpg ADDED
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/imgs/prompt_app.jpg ADDED
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/inference_full_ct_2d_with_body_mask.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import nibabel as nib
3
+ import numpy as np
4
+ import os
5
+ import torch
6
+ from skimage import io
7
+ from einops import rearrange
8
+ from omegaconf import OmegaConf
9
+ from torch.utils.data import DataLoader
10
+ from tqdm import tqdm
11
+ from argparse import ArgumentParser
12
+ import sys
13
+ sys.path.append("./stable_diffusion")
14
+ import torch.nn as nn
15
+ from ldm.util import instantiate_from_config
16
+ import k_diffusion as K
17
+ from torch import autocast
18
+ import random
19
+ import einops
20
+ from PIL import Image, ImageOps
21
class CFGDenoiser(nn.Module):
    """Apply dual classifier-free guidance (text + image) around a denoiser."""

    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, z, sigma, cond, uncond, text_cfg_scale, image_cfg_scale):
        # Batch the three conditioning variants (full, image-only, unconditional)
        # into one forward pass. torch.cat over the size-1 leading batch dim is
        # equivalent to einops.repeat("1 ... -> n ...", n=3).
        batched_z = torch.cat([z, z, z], dim=0)
        batched_sigma = torch.cat([sigma, sigma, sigma], dim=0)
        text_c, text_u = cond["c_crossattn"][0], uncond["c_crossattn"][0]
        img_c, img_u = cond["c_concat"][0], uncond["c_concat"][0]
        batched_cond = {
            "c_crossattn": [torch.cat([text_c, text_u, text_u])],
            "c_concat": [torch.cat([img_c, img_c, img_u])],
        }
        denoised = self.inner_model(batched_z, batched_sigma, cond=batched_cond)
        out_cond, out_img_cond, out_uncond = denoised.chunk(3)
        guided = out_uncond
        guided = guided + text_cfg_scale * (out_cond - out_img_cond)
        guided = guided + image_cfg_scale * (out_img_cond - out_uncond)
        return guided
35
+
36
+
37
def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
    """Instantiate the diffusion model described by ``config`` and load weights.

    Args:
        config: OmegaConf config whose ``model`` node is instantiable.
        ckpt: Path to a PyTorch Lightning checkpoint (.ckpt).
        vae_ckpt: Optional path to a standalone VAE checkpoint whose weights
            replace the model's ``first_stage_model.*`` entries.
        verbose: When True, print missing/unexpected state-dict keys.

    Returns:
        The instantiated model with weights loaded (non-strict).
    """
    print(f"Loading model from {ckpt}")
    # Load onto CPU first; the caller decides when/where to move the model.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    if vae_ckpt is not None:
        print(f"Loading VAE from {vae_ckpt}")
        vae_sd = torch.load(vae_ckpt, map_location="cpu")["state_dict"]
        # Swap in the standalone VAE weights for every first-stage entry;
        # all other entries keep their original values.
        sd = {
            k: vae_sd[k[len("first_stage_model.") :]] if k.startswith("first_stage_model.") else v
            for k, v in sd.items()
        }
    model = instantiate_from_config(config.model)
    # strict=False: tolerate architecture/checkpoint drift; report below.
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)
    return model
59
+
60
# ---- CLI arguments -------------------------------------------------------
parser = ArgumentParser()
parser.add_argument("--resolution", default=512, type=int)
parser.add_argument("--steps", default=100, type=int)
parser.add_argument("--config", default="configs/generate.yaml", type=str)
# parser.add_argument("--ckpt", default="checkpoints/instruct-pix2pix-00-22000.ckpt", type=str)
parser.add_argument("--ckpt", default="logs/train_train_instructpix2pix/checkpoints/epoch=000756.ckpt", type=str)
parser.add_argument("--vae-ckpt", default=None, type=str)

parser.add_argument("--cfg-text", default=7.5, type=float)
parser.add_argument("--cfg-image", default=1.5, type=float)
parser.add_argument("--seed", type=int)
args = parser.parse_args()
# One seed for the whole run so every validation sample uses the same noise.
seed = random.randint(0, 100000) if args.seed is None else args.seed

# ---- Model setup ---------------------------------------------------------
config = OmegaConf.load(args.config)
model = load_model_from_config(config, args.ckpt, args.vae_ckpt)
model.eval().cuda()
# k-diffusion denoiser wrapper + dual (text/image) classifier-free guidance.
model_wrap = K.external.CompVisDenoiser(model)
model_wrap_cfg = CFGDenoiser(model_wrap)
# Empty-prompt embedding used as the unconditional text branch.
null_token = model.get_learned_conditioning([""])


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using device", device)
model = model.to(device)

# Build the validation dataset from the training config's data section.
config = OmegaConf.load('./configs/train.yaml')
data = instantiate_from_config(config.data)
data.prepare_data()
data.setup()

# NOTE(review): this output dir ("train_instructpix2pix") does not match the
# --ckpt default dir ("train_train_instructpix2pix") — confirm which is intended.
save_path = 'logs/train_instructpix2pix/inference'
if not os.path.exists(save_path):
    os.makedirs(save_path)

val_dataset = data.datasets['validation']
batch_size = 1
valloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)
val_num = len(val_dataset)
save_gt = True
# val_num = 10
# breakpoint()
# NOTE: the loop variable `data` shadows the datamodule bound above.
for idx, data in tqdm(enumerate(valloader)):
    name=data['name'][0].split('.')[0]

    # breakpoint()
    # Encode the batch into latents + conditioning exactly as during training.
    z, cond = model.get_input(data, model.first_stage_key)

    with torch.no_grad(), autocast("cuda"), model.ema_scope():
        uncond = {}
        uncond["c_crossattn"] = [null_token]
        uncond["c_concat"] = [torch.zeros_like(cond["c_concat"][0])]

        sigmas = model_wrap.get_sigmas(args.steps)

        extra_args = {
            "cond": cond,
            "uncond": uncond,
            "text_cfg_scale": args.cfg_text,
            "image_cfg_scale": args.cfg_image,
        }
        torch.manual_seed(seed)
        # Start from pure noise at the first sigma, sample with Euler-ancestral.
        z = torch.randn_like(cond["c_concat"][0]) * sigmas[0]
        z = K.sampling.sample_euler_ancestral(model_wrap_cfg, z, sigmas, extra_args=extra_args)
        x = model.decode_first_stage(z)
        x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
        x = 255.0 * rearrange(x, "1 c h w -> h w c")
        # breakpoint()
        edited_image = x.type(torch.uint8).cpu().numpy()
        # edited_image = Image.fromarray(x.type(torch.uint8).cpu().numpy())

    # Recover the conditioning (input) image from [-1, 1] back to [0, 255].
    input_img = (data['edit']['c_concat'][0].detach().cpu().numpy() + 1.0) / 2.0
    input_img = np.clip(input_img ,0,1) * 255

    io.imsave(os.path.join(save_path, str(name) + f'edited_image_{idx}.png'), edited_image.astype(np.uint8))
    io.imsave(os.path.join(save_path, str(name) + f'input_image_{idx}.png'), input_img.astype(np.uint8))

    # break
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/inference_full_ct_3d_with_body_mask.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import nibabel as nib
3
+ import numpy as np
4
+ import os
5
+ import torch
6
+ import sys
7
+ sys.path.append("./stable_diffusion")
8
+ from einops import rearrange
9
+ from omegaconf import OmegaConf
10
+ from torch.utils.data import DataLoader
11
+ from tqdm import tqdm
12
+
13
+ from ldm.util import instantiate_from_config
14
+ import argparse
15
+
16
def compute_orientation(init_axcodes, final_axcodes):
    """Compute the orientation transform between two axis-code triples.

    Thin wrapper around ``nib.orientations.ornt_transform``.

    :param init_axcodes: Initial orientation codes
    :param final_axcodes: Target orientation codes
    :return: (orientation transform array, start orientation, end orientation)
    """
    start_ornt = nib.orientations.axcodes2ornt(init_axcodes)
    end_ornt = nib.orientations.axcodes2ornt(final_axcodes)

    transform = nib.orientations.ornt_transform(start_ornt, end_ornt)

    return transform, start_ornt, end_ornt
30
+
31
def do_reorientation(data_array, init_axcodes, final_axcodes):
    """Reorient a 3D array from one axis convention to another.

    source: https://niftynet.readthedocs.io/en/dev/_modules/niftynet/io/misc_io.html#do_reorientation
    Performs the reorientation (changing order of axes).

    :param data_array: 3D array to reorient
    :param init_axcodes: Initial orientation
    :param final_axcodes: Target orientation
    :return: the reoriented array (the input itself when no change is needed)
    """
    transform, start_ornt, end_ornt = compute_orientation(init_axcodes, final_axcodes)
    # Skip the axis shuffle entirely when source and target orientations match.
    if np.array_equal(start_ornt, end_ornt):
        return data_array

    return nib.orientations.apply_orientation(data_array, transform)
46
+
47
def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
    """Instantiate the model described by *config* and load weights from *ckpt*.

    When *vae_ckpt* is given, first-stage (VAE) weights in the state dict are
    replaced by the ones from that checkpoint. Loading is non-strict; missing
    and unexpected keys are printed when *verbose* is True.
    """
    print(f"Loading model from {ckpt}")
    checkpoint = torch.load(ckpt, map_location="cpu")
    if "global_step" in checkpoint:
        print(f"Global Step: {checkpoint['global_step']}")
    state_dict = checkpoint["state_dict"]
    if vae_ckpt is not None:
        print(f"Loading VAE from {vae_ckpt}")
        vae_state = torch.load(vae_ckpt, map_location="cpu")["state_dict"]
        prefix = "first_stage_model."
        # Swap in the VAE weights for every first-stage key, keep the rest.
        state_dict = {
            key: vae_state[key[len(prefix):]] if key.startswith(prefix) else value
            for key, value in state_dict.items()
        }
    model = instantiate_from_config(config.model)
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if len(missing) > 0 and verbose:
        print("missing keys:")
        print(missing)
    if len(unexpected) > 0 and verbose:
        print("unexpected keys:")
        print(unexpected)
    return model
69
+
70
# ---------------------------------------------------------------------------
# Inference script: synthesize a full 3D CT volume with a sliding window over
# slices, decode latents chunk-wise, and save the result as a NIfTI file.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument(
    "-t",
    "--time_steps",
    type=int,
    default=20,
)
parser.add_argument("--vae-ckpt", default=None, type=str)
parser.add_argument("--cfg-text", default=7.5, type=float)
parser.add_argument("--cfg-image", default=1.5, type=float)
args = parser.parse_args()
ddim_steps = args.time_steps

logdir = 'logs/full_ct_3d_with_body_mask'
ckpt = '/sd/jifu/projects/B200_logs/train_3d_opacity_opacity_3d/checkpoints/trainstep_checkpoints/epoch=000062-step=000001000.ckpt'


config = OmegaConf.load("configs/infer3d.yaml")
model = load_model_from_config(config, ckpt, args.vae_ckpt)
model.eval().cuda()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using device", device)
model = model.to(device)

data = instantiate_from_config(config.data)
data.prepare_data()
data.setup()

save_path = 'logs/train_3d_train_instructpix2pix/inference_nii'
os.makedirs(save_path, exist_ok=True)

val_dataset = data.datasets['validation']
batch_size = 1
valloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
val_num = len(val_dataset)
save_gt = True

for idx, data in tqdm(enumerate(valloader)):

    if idx >= val_num:
        break

    name = data['name'][0].split('.')[0]
    target = data['edited']           # ground-truth edited volume
    reference = data['c_concat']      # conditioning (reference) volume
    input_text = data['c_crossattn']  # text prompt

    window_length = 64  # slices sampled per sliding-window step
    h = 0               # overlap between consecutive windows
    slice_num = target.shape[2]
    result = torch.zeros((batch_size, slice_num, 4, 32, 64)).cuda()

    # Number of window positions needed to cover all slices (ceil division).
    upper_iters = (slice_num-h) // (window_length-h)+1 if (slice_num-h)%(window_length-h) != 0 else (slice_num-h) // (window_length-h)
    print('upper_iters', upper_iters)
    for i in range(upper_iters):
        print('i', i)
        input_data = {}
        if i == upper_iters-1:
            # Last window is anchored to the end of the volume so the tail is
            # never skipped (it may overlap the previous window).
            input_data['name'] = data['name']
            input_data['edited'] = torch.cat([reference[:, :, -window_length:], target[:, :, -window_length:]], dim=-1).to(device)
            input_data['edit'] = {}
            input_data['edit']['c_concat'] = torch.cat([reference[:, :, -window_length:], torch.ones_like(reference[:, :, -window_length:])*-1], dim=-1).to(device)
            input_data['edit']['c_crossattn'] = input_text
        else:
            input_data['edited'] = torch.cat([reference[:, :, i*window_length-i*h:(i+1)*window_length-i*h], target[:, :, i*window_length-i*h:(i+1)*window_length-i*h]], dim=-1).to(device)
            input_data['edit'] = {}
            input_data['edit']['c_concat'] = torch.cat([reference[:, :, i*window_length-i*h:(i+1)*window_length-i*h], torch.ones_like(reference[:, :, i*window_length-i*h:(i+1)*window_length-i*h])*-1], dim=-1).to(device)
            input_data['edit']['c_crossattn'] = input_text

        with torch.no_grad():
            z, cond = model.get_input(input_data, model.first_stage_key)
            if i == 0:
                samples_i, _ = model.sample_log(cond=cond, batch_size=window_length, ddim=True, eta=1., ddim_steps=ddim_steps)
            else:
                # Condition on the tail of the previous window to keep
                # consecutive windows temporally consistent.
                samples_i, _ = model.sample_log(cond=cond, batch_size=window_length, ddim=True, eta=1., ddim_steps=ddim_steps, previous=x_minus1)
            samples_i = rearrange(samples_i, '(b z) c h w -> b z c h w', z=window_length)

            # Stitch the window into the full-volume latent buffer.
            if i == upper_iters-1:
                result[:, -window_length+h:] = samples_i[:, h:, ...]
            elif i == 0:
                result[:, :window_length] = samples_i
            else:
                result[:, i*window_length-i*h+h:(i+1)*window_length-i*h] = samples_i[:, h:]
            x_minus1 = samples_i[:, -h:, ...]

    result = rearrange(result, 'b z c h w -> (b z) c h w')
    x_result = torch.zeros((result.shape[0], 3, 256, 512))
    # Decode latents back to image space in chunks to bound GPU memory.
    dec_unit = 64
    num_dec_iter = slice_num // dec_unit + 1 if slice_num % dec_unit != 0 else slice_num // dec_unit
    for i in range(num_dec_iter):
        if i == num_dec_iter - 1:
            # Anchor the final chunk to the end so a partial tail is decoded at
            # full chunk size. Fix: the original fell through and also decoded
            # the partial chunk a second time (redundant work, same values).
            x_result[-dec_unit:] = model.decode_first_stage(result[-dec_unit:])
        else:
            x_result[i*dec_unit:(i+1)*dec_unit] = model.decode_first_stage(result[i*dec_unit:(i+1)*dec_unit])

    # Clamp to [-1, 1] and map to [0, 1].
    x_result[x_result > 1.0] = 1.0
    x_result[x_result < -1.0] = -1.0
    x_result = (x_result+1)/2
    x_result = rearrange(x_result, '(b z) c h w -> b z c h w', z=slice_num)
    # Average the RGB channels to a single grayscale intensity per slice.
    x_result_ = x_result[0].mean(axis=1).detach().cpu().numpy()

    # (z, h, w) -> (w, h, z) so the volume matches the NIfTI axis order used below.
    x_result = x_result_.transpose(2, 1, 0)

    # Reuse the affine of the corresponding ground-truth scan so the output
    # aligns spatially with the reference NIfTI.
    ref_root = '/sd/shuhan/CT-RATE/dataset/valid_fixed'
    ref_nii = os.path.join(ref_root, name.split('_')[0]+'_'+name.split('_')[1], name.split('_')[0]+'_'+name.split('_')[1]+'_'+name.split('_')[2], name+'.nii.gz')
    affine = nib.load(ref_nii).affine

    # Map normalized intensities back to Hounsfield-unit-like range [-1000, 1000].
    x_result = x_result*2000.0 - 1000.0
    data_path = os.path.join(save_path, str(f'{name}.nii.gz'))
    data_nii = nib.Nifti1Image(x_result.astype(np.int16), affine)

    nib.save(data_nii, data_path)
201
+
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/inference_full_ct_3d_with_body_mask_v2.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import nibabel as nib
3
+ import numpy as np
4
+ import os
5
+ import torch
6
+ from skimage import io
7
+ from einops import rearrange
8
+ from omegaconf import OmegaConf
9
+ from torch.utils.data import DataLoader
10
+ from tqdm import tqdm
11
+ from argparse import ArgumentParser
12
+ import sys
13
+ sys.path.append("./stable_diffusion")
14
+ import torch.nn as nn
15
+ from ldm.util import instantiate_from_config
16
+ import k_diffusion as K
17
+ from torch import autocast
18
+ import random
19
+ import einops
20
+ from PIL import Image, ImageOps
21
class CFGDenoiser(nn.Module):
    """Classifier-free-guidance wrapper around a denoiser.

    Combines the fully-conditioned, image-only-conditioned and unconditioned
    outputs of the wrapped model using separate text and image guidance scales.
    """

    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, z, sigma, cond, uncond, text_cfg_scale, image_cfg_scale):
        # Run the three guidance branches (cond, image-only cond, uncond) as a
        # single batched forward pass of the inner model.
        batched_z = einops.repeat(z, "1 ... -> n ...", n=3)
        batched_sigma = einops.repeat(sigma, "1 ... -> n ...", n=3)
        batched_cond = {
            "c_crossattn": [torch.cat([cond["c_crossattn"][0], uncond["c_crossattn"][0], uncond["c_crossattn"][0]])],
            "c_concat": [torch.cat([cond["c_concat"][0], cond["c_concat"][0], uncond["c_concat"][0]])],
        }
        out_cond, out_img_cond, out_uncond = self.inner_model(batched_z, batched_sigma, cond=batched_cond).chunk(3)
        # Standard instruct-pix2pix CFG combination with two guidance scales.
        return out_uncond + text_cfg_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond)
35
+
36
+
37
def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
    """Instantiate the model described by *config* and load weights from *ckpt*.

    When *vae_ckpt* is given, first-stage (VAE) weights in the state dict are
    replaced by the ones from that checkpoint. Loading is non-strict; missing
    and unexpected keys are printed when *verbose* is True.
    """
    print(f"Loading model from {ckpt}")
    checkpoint = torch.load(ckpt, map_location="cpu")
    if "global_step" in checkpoint:
        print(f"Global Step: {checkpoint['global_step']}")
    state_dict = checkpoint["state_dict"]
    if vae_ckpt is not None:
        print(f"Loading VAE from {vae_ckpt}")
        vae_state = torch.load(vae_ckpt, map_location="cpu")["state_dict"]
        prefix = "first_stage_model."
        # Swap in the VAE weights for every first-stage key, keep the rest.
        state_dict = {
            key: vae_state[key[len(prefix):]] if key.startswith(prefix) else value
            for key, value in state_dict.items()
        }
    model = instantiate_from_config(config.model)
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if len(missing) > 0 and verbose:
        print("missing keys:")
        print(missing)
    if len(unexpected) > 0 and verbose:
        print("unexpected keys:")
        print(unexpected)
    return model
59
+
60
+
61
# ---------------------------------------------------------------------------
# Inference script (v2): slice-by-slice CFG sampling with k-diffusion, then
# save synthesized volume, diff map, and disease mask as NIfTI files.
# ---------------------------------------------------------------------------
parser = ArgumentParser()
parser.add_argument("--resolution", default=512, type=int)
parser.add_argument("--steps", default=20, type=int)
parser.add_argument("--config", default="configs/generate.yaml", type=str)
parser.add_argument("--ckpt", default="/sd/jifu/projects/B200_logs/train_3d_opacity_opacity_3d/checkpoints/epoch=000929.ckpt", type=str)
parser.add_argument("--vae-ckpt", default=None, type=str)

parser.add_argument("--cfg-text", default=7.5, type=float)
parser.add_argument("--cfg-image", default=1.5, type=float)
parser.add_argument("--seed", type=int)
args = parser.parse_args()
# Draw a random seed unless one is given explicitly.
seed = random.randint(0, 100000) if args.seed is None else args.seed

config = OmegaConf.load(args.config)
model = load_model_from_config(config, args.ckpt, args.vae_ckpt)
model.eval().cuda()
model_wrap = K.external.CompVisDenoiser(model)
model_wrap_cfg = CFGDenoiser(model_wrap)
# Embedding of the empty prompt, used as the unconditional text branch.
null_token = model.get_learned_conditioning([""])


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using device", device)
model = model.to(device)

config = OmegaConf.load('./configs/infer.yaml')
data = instantiate_from_config(config.data)
data.prepare_data()
data.setup()

save_path = 'logs/train_instructpix2pix/inference'
os.makedirs(save_path, exist_ok=True)

val_dataset = data.datasets['validation']
batch_size = 1
valloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)
val_num = len(val_dataset)
save_gt = True

for idx, data in tqdm(enumerate(valloader)):
    name = data['name'][0].split('.')[0]
    disease_mask_channel = data['disease_mask_channel']
    target = data['edited']           # ground-truth edited volume
    reference = data['c_concat']      # conditioning (reference) volume
    input_text = data['c_crossattn']  # text prompt

    window_length = 1  # slice-by-slice inference
    h = 0              # window overlap (unused when window_length == 1)
    slice_num = target.shape[2]
    result = torch.zeros((batch_size, slice_num, 4, 32, 64)).cuda()
    # Number of window positions needed to cover all slices (ceil division).
    upper_iters = (slice_num-h) // (window_length-h)+1 if (slice_num-h)%(window_length-h) != 0 else (slice_num-h) // (window_length-h)
    print('upper_iters', upper_iters)
    for i in range(upper_iters):
        print('i', i)
        input_data = {}

        input_data['edited'] = torch.cat([reference[:, :, i], target[:, :, i]], dim=-1).to(device)
        input_data['edit'] = {}
        # Reference slice concatenated with a constant -1 placeholder in place
        # of the (unknown) edited half.
        input_data['edit']['c_concat'] = torch.cat([reference[:, :, i], torch.ones_like(reference[:, :, i])*-1], dim=-1).to(device)
        input_data['edit']['c_crossattn'] = input_text

        z, cond = model.get_input(input_data, model.first_stage_key)

        with torch.no_grad(), autocast("cuda"), model.ema_scope():
            uncond = {}
            uncond["c_crossattn"] = [null_token]
            uncond["c_concat"] = [torch.zeros_like(cond["c_concat"][0])]

            sigmas = model_wrap.get_sigmas(args.steps)

            extra_args = {
                "cond": cond,
                "uncond": uncond,
                "text_cfg_scale": args.cfg_text,
                "image_cfg_scale": args.cfg_image,
            }
            # Same seed per slice so the noise is reproducible across slices.
            torch.manual_seed(seed)
            z = torch.randn_like(cond["c_concat"][0]) * sigmas[0]
            samples_i = K.sampling.sample_euler_ancestral(model_wrap_cfg, z, sigmas, extra_args=extra_args)

        samples_i = rearrange(samples_i, '(b z) c h w -> b z c h w', z=window_length)

        result[:, i] = samples_i[:, 0]

    result = rearrange(result, 'b z c h w -> (b z) c h w')
    x_result = torch.zeros((result.shape[0], 3, 256, 512))
    # Decode latents back to image space in chunks to bound GPU memory.
    dec_unit = 1
    num_dec_iter = slice_num // dec_unit + 1 if slice_num % dec_unit != 0 else slice_num // dec_unit
    for i in range(num_dec_iter):
        if i == num_dec_iter - 1:
            # Anchor the final chunk to the end. Fix: the original fell through
            # and decoded the last chunk a second time (redundant work).
            x_result[-dec_unit:] = model.decode_first_stage(result[-dec_unit:])
        else:
            x_result[i*dec_unit:(i+1)*dec_unit] = model.decode_first_stage(result[i*dec_unit:(i+1)*dec_unit])

    # Average the RGB channels to grayscale.
    x_result = x_result.mean(axis=1)

    # Clamp to [-1, 1], map to [0, 1], then scale to [0, 255].
    x_result = torch.clamp((x_result + 1.0) / 2.0, min=0.0, max=1.0)
    x_result = 255.0 * rearrange(x_result.detach().cpu().numpy(), "d h w -> h w d")

    # Left half (first 256 columns) is the reference, right half the synthesis.
    diff_result = np.abs(x_result[:, :256, :] - x_result[:, 256:, :])
    syn_result = x_result[:, 256:, :]

    # Reuse the affine of the corresponding ground-truth scan so the outputs
    # align spatially with the reference NIfTI.
    ref_root = '/sd/shuhan/CT-RATE/dataset/valid_fixed'
    ref_nii = os.path.join(ref_root, name.split('_')[0]+'_'+name.split('_')[1], name.split('_')[0]+'_'+name.split('_')[1]+'_'+name.split('_')[2], name+'.nii.gz')
    affine = nib.load(ref_nii).affine
    data_path = os.path.join(save_path, str(f'{name}_{disease_mask_channel[0]}.nii.gz'))
    data_nii = nib.Nifti1Image(x_result.astype(np.uint16), affine)
    nib.save(data_nii, data_path)

    diff_data_path = os.path.join(save_path, str(f'{name}_diff_{disease_mask_channel[0]}.nii.gz'))
    diff_nii = nib.Nifti1Image(diff_result.astype(np.uint16), affine)
    nib.save(diff_nii, diff_data_path)

    syn_data_path = os.path.join(save_path, str(f'{name}_syn_{disease_mask_channel[0]}.nii.gz'))
    syn_nii = nib.Nifti1Image(syn_result.astype(np.uint16), affine)
    nib.save(syn_nii, syn_data_path)

    # Save the selected channel of the disease segmentation mask alongside.
    disease_root = '/sd/shuhan/CT-RATE/seg_rxg'
    disease_path = os.path.join(disease_root, str(f'{name}.nii.gz'))
    nii_img = nib.load(str(disease_path))
    disease_data = nii_img.get_fdata()[int(disease_mask_channel[0])]
    data_nii = nib.Nifti1Image(disease_data.astype(np.uint16), affine)
    nib.save(data_nii, os.path.join(save_path, str(f'{name}_disease_{disease_mask_channel[0]}.nii.gz')))

    print('input_text', input_text)
    # Fix: removed a leftover `breakpoint()` here, which dropped into the
    # debugger and halted inference after the first validation sample.
196
+
197
+
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/logs/train_train_instructpix2pix/checkpoints/epoch=000096.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:51a1a050c2cbf06d4b3de25d461772d786ed82d94874d3f0d372de7eb013d91c
3
+ size 14872250099
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/logs/train_train_instructpix2pix/checkpoints/epoch=000097.ckpt ADDED
File without changes
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/logs/train_train_instructpix2pix/checkpoints/epoch=000145.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:66512dded55f036df038922c67eb38fa7dcde54b602063bdd9517679160aef47
3
+ size 14872249655
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/logs/train_train_instructpix2pix/checkpoints/epoch=000146.ckpt ADDED
File without changes
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/logs/train_train_instructpix2pix/checkpoints/last.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:66512dded55f036df038922c67eb38fa7dcde54b602063bdd9517679160aef47
3
+ size 14872249655
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/logs/train_train_instructpix2pix/configs/2026-01-13T15-35-13-lightning.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ lightning:
2
+ callbacks:
3
+ image_logger:
4
+ target: main.ImageLogger
5
+ params:
6
+ batch_frequency: 200000000
7
+ max_images: 2
8
+ increase_log_steps: false
9
+ trainer:
10
+ max_epochs: 2000
11
+ benchmark: true
12
+ accumulate_grad_batches: 4
13
+ check_val_every_n_epoch: 10000
14
+ accelerator: ddp
15
+ gpus: 0,1,2,3,
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/logs/train_train_instructpix2pix/configs/2026-01-13T15-35-13-project.yaml ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 0.0001
3
+ target: ldm.models.diffusion.ddpm_edit.LatentDiffusion
4
+ params:
5
+ ckpt_path: stable_diffusion/models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
6
+ linear_start: 0.00085
7
+ linear_end: 0.012
8
+ num_timesteps_cond: 1
9
+ log_every_t: 200
10
+ timesteps: 1000
11
+ first_stage_key: edited
12
+ cond_stage_key: edit
13
+ image_size: 64
14
+ channels: 4
15
+ cond_stage_trainable: false
16
+ conditioning_key: hybrid
17
+ monitor: val/loss_simple_ema
18
+ scale_factor: 0.18215
19
+ use_ema: true
20
+ load_ema: false
21
+ scheduler_config:
22
+ target: ldm.lr_scheduler.LambdaLinearScheduler
23
+ params:
24
+ warm_up_steps:
25
+ - 0
26
+ cycle_lengths:
27
+ - 10000000000000
28
+ f_start:
29
+ - 1.0e-06
30
+ f_max:
31
+ - 1.0
32
+ f_min:
33
+ - 1.0
34
+ unet_config:
35
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
36
+ params:
37
+ image_size: 32
38
+ in_channels: 8
39
+ out_channels: 4
40
+ model_channels: 320
41
+ attention_resolutions:
42
+ - 4
43
+ - 2
44
+ - 1
45
+ num_res_blocks: 2
46
+ channel_mult:
47
+ - 1
48
+ - 2
49
+ - 4
50
+ - 4
51
+ num_heads: 8
52
+ use_spatial_transformer: true
53
+ transformer_depth: 1
54
+ context_dim: 768
55
+ use_checkpoint: true
56
+ legacy: false
57
+ first_stage_config:
58
+ target: ldm.models.autoencoder.AutoencoderKL
59
+ params:
60
+ embed_dim: 4
61
+ monitor: val/rec_loss
62
+ ddconfig:
63
+ double_z: true
64
+ z_channels: 4
65
+ resolution: 256
66
+ in_channels: 3
67
+ out_ch: 3
68
+ ch: 128
69
+ ch_mult:
70
+ - 1
71
+ - 2
72
+ - 4
73
+ - 4
74
+ num_res_blocks: 2
75
+ attn_resolutions: []
76
+ dropout: 0.0
77
+ lossconfig:
78
+ target: torch.nn.Identity
79
+ cond_stage_config:
80
+ target: ldm.modules.encoders.modules.FrozenBioMedCLIPEmbedder
81
+ data:
82
+ target: main.DataModuleFromConfig
83
+ params:
84
+ batch_size: 16
85
+ num_workers: 8
86
+ train:
87
+ target: ldm.data.ct_clip_data_train.CTReportDataset
88
+ params:
89
+ data_folder: /sd/shuhan/CT-RATE/dataset/train_fixed
90
+ csv_file: /sd/shuhan/CT-RATE/radiology_text_reports/train_reports.csv
91
+ validation:
92
+ target: ldm.data.ct_clip_data_inference.CTReportDatasetinfer
93
+ params:
94
+ data_folder: /sd/shuhan/CT-RATE/dataset/valid_fixed
95
+ csv_file: /sd/shuhan/CT-RATE/radiology_text_reports/valid_reports.csv
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/logs/train_train_instructpix2pix/configs/2026-01-13T23-40-35-lightning.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ lightning:
2
+ callbacks:
3
+ image_logger:
4
+ target: main.ImageLogger
5
+ params:
6
+ batch_frequency: 200000000
7
+ max_images: 2
8
+ increase_log_steps: false
9
+ trainer:
10
+ max_epochs: 2000
11
+ benchmark: true
12
+ accumulate_grad_batches: 4
13
+ check_val_every_n_epoch: 10000
14
+ accelerator: ddp
15
+ gpus: 0,1,2,3,
16
+ resume_from_checkpoint: logs/train_train_instructpix2pix/checkpoints/last.ckpt
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/logs/train_train_instructpix2pix/configs/2026-01-13T23-40-35-project.yaml ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 0.0001
3
+ target: ldm.models.diffusion.ddpm_edit.LatentDiffusion
4
+ params:
5
+ ckpt_path: stable_diffusion/models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
6
+ linear_start: 0.00085
7
+ linear_end: 0.012
8
+ num_timesteps_cond: 1
9
+ log_every_t: 200
10
+ timesteps: 1000
11
+ first_stage_key: edited
12
+ cond_stage_key: edit
13
+ image_size: 64
14
+ channels: 4
15
+ cond_stage_trainable: false
16
+ conditioning_key: hybrid
17
+ monitor: val/loss_simple_ema
18
+ scale_factor: 0.18215
19
+ use_ema: true
20
+ load_ema: true
21
+ scheduler_config:
22
+ target: ldm.lr_scheduler.LambdaLinearScheduler
23
+ params:
24
+ warm_up_steps:
25
+ - 0
26
+ cycle_lengths:
27
+ - 10000000000000
28
+ f_start:
29
+ - 1.0e-06
30
+ f_max:
31
+ - 1.0
32
+ f_min:
33
+ - 1.0
34
+ unet_config:
35
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
36
+ params:
37
+ image_size: 32
38
+ in_channels: 8
39
+ out_channels: 4
40
+ model_channels: 320
41
+ attention_resolutions:
42
+ - 4
43
+ - 2
44
+ - 1
45
+ num_res_blocks: 2
46
+ channel_mult:
47
+ - 1
48
+ - 2
49
+ - 4
50
+ - 4
51
+ num_heads: 8
52
+ use_spatial_transformer: true
53
+ transformer_depth: 1
54
+ context_dim: 768
55
+ use_checkpoint: true
56
+ legacy: false
57
+ first_stage_config:
58
+ target: ldm.models.autoencoder.AutoencoderKL
59
+ params:
60
+ embed_dim: 4
61
+ monitor: val/rec_loss
62
+ ddconfig:
63
+ double_z: true
64
+ z_channels: 4
65
+ resolution: 256
66
+ in_channels: 3
67
+ out_ch: 3
68
+ ch: 128
69
+ ch_mult:
70
+ - 1
71
+ - 2
72
+ - 4
73
+ - 4
74
+ num_res_blocks: 2
75
+ attn_resolutions: []
76
+ dropout: 0.0
77
+ lossconfig:
78
+ target: torch.nn.Identity
79
+ cond_stage_config:
80
+ target: ldm.modules.encoders.modules.FrozenBioMedCLIPEmbedder
81
+ data:
82
+ target: main.DataModuleFromConfig
83
+ params:
84
+ batch_size: 16
85
+ num_workers: 8
86
+ train:
87
+ target: ldm.data.ct_clip_data_train.CTReportDataset
88
+ params:
89
+ data_folder: /sd/shuhan/CT-RATE/dataset/train_fixed
90
+ csv_file: /sd/shuhan/CT-RATE/radiology_text_reports/train_reports.csv
91
+ validation:
92
+ target: ldm.data.ct_clip_data_inference.CTReportDatasetinfer
93
+ params:
94
+ data_folder: /sd/shuhan/CT-RATE/dataset/valid_fixed
95
+ csv_file: /sd/shuhan/CT-RATE/radiology_text_reports/valid_reports.csv
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/logs/train_train_instructpix2pix/testtube/version_0/meta.experiment ADDED
@@ -0,0 +1 @@
 
 
1
+ {"name": "testtube", "version": 0, "tags_path": "logs/train_train_instructpix2pix/testtube/version_0/meta_tags.csv", "metrics_path": "logs/train_train_instructpix2pix/testtube/version_0/metrics.csv", "autosave": false, "description": null, "created_at": "2026-01-13 15:36:01.172177", "exp_hash": "testtube_v0"}
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/logs/train_train_instructpix2pix/testtube/version_0/meta_tags.csv ADDED
@@ -0,0 +1 @@
 
 
1
+ key,value
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/logs/train_train_instructpix2pix/testtube/version_0/metrics.csv ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ train/loss_simple_epoch,train/loss_vlb_epoch,train/loss_epoch,epoch,created_at,lr-AdamW,train/loss_simple_step,train/loss_vlb_step,train/loss_step,global_step,lr_abs
2
+ 0.04492774233222008,0.00028238692902959883,0.04492774233222008,0.0,2026-01-13 15:39:49.398976,,,,,,
3
+ 0.04945197328925133,0.001165940542705357,0.04945197328925133,1.0,2026-01-13 15:42:31.949919,,,,,,
4
+ 0.05576256290078163,0.003641557414084673,0.05576256290078163,2.0,2026-01-13 15:45:20.903362,,,,,,
5
+ 0.041883960366249084,0.0003884841571561992,0.041883960366249084,3.0,2026-01-13 15:48:00.064725,,,,,,
6
+ 0.06064687669277191,0.00043762807035818696,0.06064687669277191,4.0,2026-01-13 15:50:35.354253,,,,,,
7
+ 0.056190915405750275,0.0007902872166596353,0.056190915405750275,5.0,2026-01-13 15:53:14.286251,,,,,,
8
+ 0.05258661508560181,0.0005384630640037358,0.05258661508560181,6.0,2026-01-13 15:56:06.363793,,,,,,
9
+ 0.03564197197556496,0.00016246133600361645,0.03564197197556496,7.0,2026-01-13 15:59:07.938194,,,,,,
10
+ 0.042768098413944244,0.0002556768886279315,0.042768098413944244,8.0,2026-01-13 16:01:47.971633,,,,,,
11
+ 0.0444299541413784,0.00031295977532863617,0.0444299541413784,9.0,2026-01-13 16:04:26.817808,,,,,,
12
+ 0.04728930443525314,0.001666650758124888,0.04728930443525314,10.0,2026-01-13 16:07:07.550724,,,,,,
13
+ 0.043938156217336655,0.0004285657196305692,0.043938156217336655,11.0,2026-01-13 16:09:56.638376,,,,,,
14
+ 0.03814474865794182,0.00020941562252119184,0.03814474865794182,12.0,2026-01-13 16:12:34.041103,,,,,,
15
+ 0.046581581234931946,0.001041729818098247,0.046581581234931946,13.0,2026-01-13 16:14:59.679246,,,,,,
16
+ 0.04245039075613022,0.0043443045578897,0.04245039075613022,14.0,2026-01-13 16:17:35.659020,,,,,,
17
+ 0.03935012221336365,0.00022819600417278707,0.03935012221336365,15.0,2026-01-13 16:20:29.136322,,,,,,
18
+ 0.04137749969959259,0.0003488641232252121,0.04137749969959259,16.0,2026-01-13 16:23:08.688349,,,,,,
19
+ 0.04447738826274872,0.00022584669932257384,0.04447738826274872,17.0,2026-01-13 16:25:44.017076,,,,,,
20
+ 0.034274861216545105,0.00042341675725765526,0.034274861216545105,18.0,2026-01-13 16:28:15.980276,,,,,,
21
+ 0.04171200469136238,0.001967966090887785,0.04171200469136238,19.0,2026-01-13 16:30:45.178647,,,,,,
22
+ 0.0435456857085228,0.0012122655753046274,0.0435456857085228,20.0,2026-01-13 16:33:20.613595,,,,,,
23
+ 0.03531428799033165,0.00026467524003237486,0.03531428799033165,21.0,2026-01-13 16:35:59.972802,,,,,,
24
+ 0.036896076053380966,0.0003513060219120234,0.036896076053380966,22.0,2026-01-13 16:38:37.428488,,,,,,
25
+ 0.039321959018707275,0.0002118393749697134,0.039321959018707275,23.0,2026-01-13 16:41:16.981679,,,,,,
26
+ 0.037652213126420975,0.0004801239410880953,0.037652213126420975,24.0,2026-01-13 16:43:55.411387,,,,,,
27
+ 0.03323481231927872,0.00022156101476866752,0.03323481231927872,25.0,2026-01-13 16:46:33.674520,,,,,,
28
+ 0.033796556293964386,0.0007194963400252163,0.033796556293964386,26.0,2026-01-13 16:49:17.837438,,,,,,
29
+ 0.03964272141456604,0.00030356034403666854,0.03964272141456604,27.0,2026-01-13 16:51:54.570995,,,,,,
30
+ 0.0341181680560112,0.0003009494685102254,0.0341181680560112,28.0,2026-01-13 16:54:37.705612,,,,,,
31
+ 0.03446051850914955,0.0005756927421316504,0.03446051850914955,29.0,2026-01-13 16:57:08.658119,,,,,,
32
+ 0.028473246842622757,0.000150643551023677,0.028473246842622757,30.0,2026-01-13 16:59:37.887618,,,,,,
33
+ 0.03591898828744888,0.00016116136976052076,0.03591898828744888,31.0,2026-01-13 17:02:42.776683,,,,,,
34
+ 0.03866593912243843,0.0002652394468896091,0.03866593912243843,32.0,2026-01-13 17:05:30.622988,,,,,,
35
+ 0.039249055087566376,0.0005666539655067027,0.039249055087566376,33.0,2026-01-13 17:07:57.387815,,,,,,
36
+ 0.028335751965641975,0.00017373055743519217,0.028335751965641975,34.0,2026-01-13 17:10:34.541191,,,,,,
37
+ 0.03553164750337601,0.0003914266126230359,0.03553164750337601,35.0,2026-01-13 17:13:12.852720,,,,,,
38
+ 0.028429865837097168,0.00014207293861545622,0.028429865837097168,36.0,2026-01-13 17:16:04.967122,,,,,,
39
+ 0.036691952496767044,0.0029678838327527046,0.036691952496767044,37.0,2026-01-13 17:18:28.050656,,,,,,
40
+ 0.04113037884235382,0.0008072647615335882,0.04113037884235382,38.0,2026-01-13 17:21:11.567197,,,,,,
41
+ 0.041899241507053375,0.0015469326172024012,0.041899241507053375,39.0,2026-01-13 17:23:52.115598,,,,,,
42
+ 0.034236062318086624,0.0011594714596867561,0.034236062318086624,40.0,2026-01-13 17:26:35.607605,,,,,,
43
+ 0.030954739078879356,0.0001455929159419611,0.030954739078879356,41.0,2026-01-13 17:29:16.124200,,,,,,
44
+ 0.04418838396668434,0.00910048745572567,0.04418838396668434,42.0,2026-01-13 17:31:58.721039,,,,,,
45
+ 0.03483903035521507,0.0003233915485907346,0.03483903035521507,43.0,2026-01-13 17:34:45.769391,,,,,,
46
+ 0.029837023466825485,0.0001850406260928139,0.029837023466825485,44.0,2026-01-13 17:38:30.754686,,,,,,
47
+ 0.020381895825266838,0.00011296541924821213,0.020381895825266838,45.0,2026-01-13 17:41:15.459268,,,,,,
48
+ 0.03233564645051956,0.00018584895587991923,0.03233564645051956,46.0,2026-01-13 17:43:52.829968,,,,,,
49
+ 0.02938614971935749,0.0002223715273430571,0.02938614971935749,47.0,2026-01-13 17:46:22.069204,,,,,,
50
+ 0.04097375646233559,0.0009451856021769345,0.04097375646233559,48.0,2026-01-13 17:49:06.606438,,,,,,
51
+ ,,,,2026-01-13 17:50:03.850979,0.0001,,,,,
52
+ ,,,,2026-01-13 17:50:04.674137,0.0001,,,,,
53
+ ,,,,2026-01-13 17:50:05.517919,0.0001,,,,,
54
+ ,,,49.0,2026-01-13 17:50:06.064603,,0.020297572016716003,9.82858837232925e-05,0.020297572016716003,49.0,9.999999747378752e-05
55
+ 0.02533917874097824,0.0001783097250154242,0.02533917874097824,49.0,2026-01-13 17:51:51.632851,,,,,,
56
+ 0.030183883383870125,0.00022021151380613446,0.030183883383870125,50.0,2026-01-13 17:54:33.703026,,,,,,
57
+ 0.03008226305246353,0.002578485058620572,0.03008226305246353,51.0,2026-01-13 17:57:20.055129,,,,,,
58
+ 0.03803397715091705,0.00037925405194982886,0.03803397715091705,52.0,2026-01-13 18:00:20.944089,,,,,,
59
+ 0.041673775762319565,0.00046355099766515195,0.041673775762319565,53.0,2026-01-13 18:02:59.538489,,,,,,
60
+ 0.0244731642305851,0.0018854450900107622,0.0244731642305851,54.0,2026-01-13 18:05:31.385397,,,,,,
61
+ 0.04758085682988167,0.0011080902768298984,0.04758085682988167,55.0,2026-01-13 18:08:23.627467,,,,,,
62
+ 0.03392697870731354,0.002914445474743843,0.03392697870731354,56.0,2026-01-13 18:11:00.730189,,,,,,
63
+ 0.02654872089624405,0.00013887944805901498,0.02654872089624405,57.0,2026-01-13 18:13:38.885119,,,,,,
64
+ 0.0330892875790596,0.0036669319961220026,0.0330892875790596,58.0,2026-01-13 18:16:01.485152,,,,,,
65
+ 0.031982921063899994,0.0002334276941837743,0.031982921063899994,59.0,2026-01-13 18:18:25.702880,,,,,,
66
+ 0.035211168229579926,0.0007371274405159056,0.035211168229579926,60.0,2026-01-13 18:20:42.539903,,,,,,
67
+ 0.03253152221441269,0.0008004647097550333,0.03253152221441269,61.0,2026-01-13 18:23:40.384635,,,,,,
68
+ 0.031589679419994354,0.0006487449863925576,0.031589679419994354,62.0,2026-01-13 18:26:05.687190,,,,,,
69
+ 0.027530457824468613,0.0004612143966369331,0.027530457824468613,63.0,2026-01-13 18:28:44.825634,,,,,,
70
+ 0.026691773906350136,0.0001971848396351561,0.026691773906350136,64.0,2026-01-13 18:31:17.238475,,,,,,
71
+ 0.03211250156164169,0.0027711628936231136,0.03211250156164169,65.0,2026-01-13 18:33:57.599706,,,,,,
72
+ 0.01422487199306488,6.4654610469006e-05,0.01422487199306488,66.0,2026-01-13 18:36:10.453560,,,,,,
73
+ 0.028031136840581894,0.00031970994314178824,0.028031136840581894,67.0,2026-01-13 18:39:18.404783,,,,,,
74
+ 0.022906355559825897,0.00030137741123326123,0.022906355559825897,68.0,2026-01-13 18:41:59.531463,,,,,,
75
+ 0.03375226631760597,0.0005524270818568766,0.03375226631760597,69.0,2026-01-13 18:44:38.344377,,,,,,
76
+ 0.025169331580400467,0.0002144321333616972,0.025169331580400467,70.0,2026-01-13 18:47:06.777753,,,,,,
77
+ 0.029542915523052216,0.0002853941696230322,0.029542915523052216,71.0,2026-01-13 18:50:28.002651,,,,,,
78
+ 0.022899124771356583,0.00018100191664416343,0.022899124771356583,72.0,2026-01-13 18:53:17.050091,,,,,,
79
+ 0.021205568686127663,0.000281251355772838,0.021205568686127663,73.0,2026-01-13 18:55:51.216897,,,,,,
80
+ 0.02343231439590454,0.00016876781592145562,0.02343231439590454,74.0,2026-01-13 18:58:20.414185,,,,,,
81
+ 0.030981680378317833,0.005326937418431044,0.030981680378317833,75.0,2026-01-13 19:00:54.643760,,,,,,
82
+ 0.022693296894431114,0.0001320972660323605,0.022693296894431114,76.0,2026-01-13 19:03:42.896797,,,,,,
83
+ 0.031266022473573685,0.0010528827551752329,0.031266022473573685,77.0,2026-01-13 19:06:24.067347,,,,,,
84
+ 0.019028138369321823,0.00027022312860935926,0.019028138369321823,78.0,2026-01-13 19:09:10.997335,,,,,,
85
+ 0.023607315495610237,0.00025715187075547874,0.023607315495610237,79.0,2026-01-13 19:11:42.342316,,,,,,
86
+ 0.012905941344797611,6.465279147960246e-05,0.012905941344797611,80.0,2026-01-13 19:14:11.513844,,,,,,
87
+ 0.02133249118924141,0.0004584984271787107,0.02133249118924141,81.0,2026-01-13 19:16:38.620064,,,,,,
88
+ 0.01448127906769514,8.018436346901581e-05,0.01448127906769514,82.0,2026-01-13 19:19:26.935454,,,,,,
89
+ 0.02080843597650528,0.0015032771043479443,0.02080843597650528,83.0,2026-01-13 19:22:00.455423,,,,,,
90
+ 0.026811497285962105,0.00042073841905221343,0.026811497285962105,84.0,2026-01-13 19:24:32.022931,,,,,,
91
+ 0.017424674704670906,0.0006118972669355571,0.017424674704670906,85.0,2026-01-13 19:27:10.718300,,,,,,
92
+ 0.01699788123369217,0.00036411231849342585,0.01699788123369217,86.0,2026-01-13 19:30:30.630058,,,,,,
93
+ 0.020952994003891945,0.0002837599895428866,0.020952994003891945,87.0,2026-01-13 19:33:00.965650,,,,,,
94
+ 0.015126981772482395,0.0002649685484357178,0.015126981772482395,88.0,2026-01-13 19:35:48.905038,,,,,,
95
+ 0.025458965450525284,0.004489585291594267,0.025458965450525284,89.0,2026-01-13 19:38:44.584141,,,,,,
96
+ 0.021945003420114517,0.000531058176420629,0.021945003420114517,90.0,2026-01-13 19:41:20.476941,,,,,,
97
+ 0.014269448816776276,9.480538574280217e-05,0.014269448816776276,91.0,2026-01-13 19:44:00.040129,,,,,,
98
+ 0.015860360115766525,0.0001000822230707854,0.015860360115766525,92.0,2026-01-13 19:47:09.181119,,,,,,
99
+ 0.013729044236242771,0.00010012665006797761,0.013729044236242771,93.0,2026-01-13 19:49:40.177524,,,,,,
100
+ 0.012890013866126537,0.00023106325534172356,0.012890013866126537,94.0,2026-01-13 19:52:04.915582,,,,,,
101
+ 0.01310599036514759,8.260652248281986e-05,0.01310599036514759,95.0,2026-01-13 19:54:41.807542,,,,,,
102
+ 0.014611278660595417,0.00042112352093681693,0.014611278660595417,96.0,2026-01-13 19:57:15.067459,,,,,,
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/logs/train_train_instructpix2pix/testtube/version_0/tf/events.out.tfevents.1768318562.node-0.1923.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ff8652efa0be26a68221d9bdd7e7767e55cf2015949af538781f1d6298d62f29
3
+ size 21476
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/logs/train_train_instructpix2pix/testtube/version_1/meta.experiment ADDED
@@ -0,0 +1 @@
 
 
1
+ {"name": "testtube", "version": 1, "tags_path": "logs/train_train_instructpix2pix/testtube/version_1/meta_tags.csv", "metrics_path": "logs/train_train_instructpix2pix/testtube/version_1/metrics.csv", "autosave": false, "description": null, "created_at": "2026-01-13 23:43:19.408467", "exp_hash": "testtube_v1"}
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/logs/train_train_instructpix2pix/testtube/version_1/meta_tags.csv ADDED
@@ -0,0 +1 @@
 
 
1
+ key,value
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/logs/train_train_instructpix2pix/testtube/version_1/metrics.csv ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ train/loss_simple_epoch,train/loss_vlb_epoch,train/loss_epoch,epoch,created_at,lr-AdamW,train/loss_simple_step,train/loss_vlb_step,train/loss_step,global_step,lr_abs
2
+ 0.01197216846048832,0.00019215696374885738,0.01197216846048832,97.0,2026-01-13 23:47:52.592140,,,,,,
3
+ 0.01689619943499565,0.00022687626187689602,0.01689619943499565,98.0,2026-01-13 23:52:06.218450,,,,,,
4
+ ,,,,2026-01-13 23:53:21.502126,0.0001,,,,,
5
+ ,,,,2026-01-13 23:53:28.518867,0.0001,,,,,
6
+ ,,,,2026-01-13 23:53:35.864360,0.0001,,,,,
7
+ ,,,99.0,2026-01-13 23:53:36.421754,,0.008886125870049,4.193108907202259e-05,0.008886125870049,99.0,9.999999747378752e-05
8
+ 0.012425115332007408,0.00011319141776766628,0.012425115332007408,99.0,2026-01-13 23:55:56.614328,,,,,,
9
+ 0.009767280891537666,7.24010678823106e-05,0.009767280891537666,100.0,2026-01-13 23:59:06.002144,,,,,,
10
+ 0.017172064632177353,0.00036570191150531173,0.017172064632177353,101.0,2026-01-14 00:02:15.416377,,,,,,
11
+ 0.014853031374514103,0.00021442514844238758,0.014853031374514103,102.0,2026-01-14 00:06:42.803790,,,,,,
12
+ 0.011737050488591194,8.082642307272181e-05,0.011737050488591194,103.0,2026-01-14 00:09:48.525215,,,,,,
13
+ 0.010128101333975792,0.00028111651772633195,0.010128101333975792,104.0,2026-01-14 00:12:50.301862,,,,,,
14
+ 0.020027538761496544,0.0012822315329685807,0.020027538761496544,105.0,2026-01-14 00:16:13.522760,,,,,,
15
+ 0.012312892824411392,0.00022742303553968668,0.012312892824411392,106.0,2026-01-14 00:19:23.857237,,,,,,
16
+ 0.013259190134704113,0.00011501927656354383,0.013259190134704113,107.0,2026-01-14 00:22:31.914371,,,,,,
17
+ 0.005941064562648535,3.36673365382012e-05,0.005941064562648535,108.0,2026-01-14 00:30:33.476385,,,,,,
18
+ 0.024892671033740044,0.005268546286970377,0.024892671033740044,109.0,2026-01-14 00:37:59.013915,,,,,,
19
+ 0.010536117479205132,0.0007038437179289758,0.010536117479205132,110.0,2026-01-14 00:43:35.028266,,,,,,
20
+ 0.015121201053261757,0.0005928723840042949,0.015121201053261757,111.0,2026-01-14 00:49:04.669255,,,,,,
21
+ 0.006611506920307875,3.997121166321449e-05,0.006611506920307875,112.0,2026-01-14 00:53:54.071099,,,,,,
22
+ 0.017725441604852676,0.001076962798833847,0.017725441604852676,113.0,2026-01-14 00:58:46.243281,,,,,,
23
+ 0.012379763647913933,0.0002492000930942595,0.012379763647913933,114.0,2026-01-14 01:02:46.925807,,,,,,
24
+ 0.0070332628674805164,7.378628652077168e-05,0.0070332628674805164,115.0,2026-01-14 01:05:53.256814,,,,,,
25
+ 0.011917666532099247,7.035672024358064e-05,0.011917666532099247,116.0,2026-01-14 01:09:06.659490,,,,,,
26
+ 0.006569588091224432,3.5176315577700734e-05,0.006569588091224432,117.0,2026-01-14 01:12:13.904002,,,,,,
27
+ 0.011747820302844048,0.0022530024871230125,0.011747820302844048,118.0,2026-01-14 01:15:33.549657,,,,,,
28
+ 0.00879058800637722,5.238176527200267e-05,0.00879058800637722,119.0,2026-01-14 01:18:42.746095,,,,,,
29
+ 0.013679510913789272,0.0006952152471058071,0.013679510913789272,120.0,2026-01-14 01:22:08.793910,,,,,,
30
+ 0.011278304271399975,0.0005908889579586685,0.011278304271399975,121.0,2026-01-14 01:25:14.934196,,,,,,
31
+ 0.01308036595582962,0.00013252333155833185,0.01308036595582962,122.0,2026-01-14 01:28:22.976512,,,,,,
32
+ 0.010981367900967598,0.00018925454060081393,0.010981367900967598,123.0,2026-01-14 01:31:46.472359,,,,,,
33
+ 0.009664732031524181,0.00012013369996566325,0.009664732031524181,124.0,2026-01-14 01:36:00.441369,,,,,,
34
+ 0.011889003217220306,8.517378591932356e-05,0.011889003217220306,125.0,2026-01-14 01:39:11.873864,,,,,,
35
+ 0.018262574449181557,0.00038071247399784625,0.018262574449181557,126.0,2026-01-14 01:42:16.759612,,,,,,
36
+ 0.015098193660378456,0.00015731646271888167,0.015098193660378456,127.0,2026-01-14 01:45:31.681988,,,,,,
37
+ 0.008732830174267292,0.0017312642885372043,0.008732830174267292,128.0,2026-01-14 01:48:28.750403,,,,,,
38
+ 0.007402519695460796,0.00013707477774005383,0.007402519695460796,129.0,2026-01-14 01:51:30.697285,,,,,,
39
+ 0.01271725632250309,0.00044070437434129417,0.01271725632250309,130.0,2026-01-14 01:54:47.314776,,,,,,
40
+ 0.014155593700706959,0.0023787980899214745,0.014155593700706959,131.0,2026-01-14 01:58:02.136205,,,,,,
41
+ 0.014698449522256851,0.0007856176234781742,0.014698449522256851,132.0,2026-01-14 02:01:08.156323,,,,,,
42
+ 0.007009289693087339,4.169391104369424e-05,0.007009289693087339,133.0,2026-01-14 02:04:25.428844,,,,,,
43
+ 0.014788895845413208,0.0009002183796837926,0.014788895845413208,134.0,2026-01-14 02:07:40.663624,,,,,,
44
+ 0.013603661209344864,0.0026975262444466352,0.013603661209344864,135.0,2026-01-14 02:10:32.772745,,,,,,
45
+ 0.009711007587611675,5.309990228852257e-05,0.009711007587611675,136.0,2026-01-14 02:13:43.066020,,,,,,
46
+ 0.006434881128370762,0.0002034392673522234,0.006434881128370762,137.0,2026-01-14 02:16:47.822359,,,,,,
47
+ 0.016238775104284286,0.00015190854901447892,0.016238775104284286,138.0,2026-01-14 02:20:02.527370,,,,,,
48
+ 0.010713496245443821,0.00011459724919404835,0.010713496245443821,139.0,2026-01-14 02:23:10.495598,,,,,,
49
+ 0.013620677404105663,0.0016112453304231167,0.013620677404105663,140.0,2026-01-14 02:26:20.471643,,,,,,
50
+ 0.01036806870251894,7.42679912946187e-05,0.01036806870251894,141.0,2026-01-14 02:29:39.162748,,,,,,
51
+ 0.00851675495505333,9.04347762116231e-05,0.00851675495505333,142.0,2026-01-14 02:33:00.651105,,,,,,
52
+ 0.009983322583138943,7.42983611417003e-05,0.009983322583138943,143.0,2026-01-14 02:36:05.562470,,,,,,
53
+ 0.008650604635477066,0.0002860668464563787,0.008650604635477066,144.0,2026-01-14 02:39:17.846689,,,,,,
54
+ 0.007028167136013508,3.1158197089098394e-05,0.007028167136013508,145.0,2026-01-14 02:42:31.720644,,,,,,
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/logs/train_train_instructpix2pix/testtube/version_1/tf/events.out.tfevents.1768347806.node-0.4103.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4e92dea2056405540d75ead4ba80dfdd623c6e00d050457b6e795fe328cfdf4c
3
+ size 11188
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/main.py ADDED
@@ -0,0 +1,814 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse, os, sys, datetime, glob
2
+ import numpy as np
3
+ import time
4
+ import torch
5
+ import torchvision
6
+
7
+ import enum
8
+
9
+ # 保存原始实现(可选,用于调试或回滚)
10
+ _enum_format_orig = enum.Enum.__format__
11
+
12
+ def _enum_format_value(self, format_spec: str) -> str:
13
+ # 如果指定了 format 规范,就交给 value 本身的格式化;否则直接返回 value
14
+ return format(self.value, format_spec) if format_spec else str(self.value)
15
+
16
+ # 全局替换 Enum.__format__
17
+ enum.Enum.__format__ = _enum_format_value
18
+
19
+ import pytorch_lightning as pl
20
+ import json
21
+ import pickle
22
+
23
+ from packaging import version
24
+ from omegaconf import OmegaConf
25
+ from torch.utils.data import DataLoader, Dataset
26
+ from functools import partial
27
+ from PIL import Image
28
+
29
+ import torch.distributed as dist
30
+ from pytorch_lightning import seed_everything
31
+ from pytorch_lightning.trainer import Trainer
32
+ from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
33
+ from pytorch_lightning.utilities.distributed import rank_zero_only
34
+ from pytorch_lightning.utilities import rank_zero_info
35
+ from pytorch_lightning.plugins import DDPPlugin
36
+
37
+ sys.path.append("./stable_diffusion")
38
+
39
+ from ldm.data.base import Txt2ImgIterableBaseDataset
40
+ from ldm.util import instantiate_from_config
41
+
42
+
43
+ def get_parser(**parser_kwargs):
44
+ def str2bool(v):
45
+ if isinstance(v, bool):
46
+ return v
47
+ if v.lower() in ("yes", "true", "t", "y", "1"):
48
+ return True
49
+ elif v.lower() in ("no", "false", "f", "n", "0"):
50
+ return False
51
+ else:
52
+ raise argparse.ArgumentTypeError("Boolean value expected.")
53
+
54
+ parser = argparse.ArgumentParser(**parser_kwargs)
55
+ parser.add_argument(
56
+ "-n",
57
+ "--name",
58
+ type=str,
59
+ const=True,
60
+ default="",
61
+ nargs="?",
62
+ help="postfix for logdir",
63
+ )
64
+ parser.add_argument(
65
+ "-r",
66
+ "--resume",
67
+ type=str,
68
+ const=True,
69
+ default="",
70
+ nargs="?",
71
+ help="resume from logdir or checkpoint in logdir",
72
+ )
73
+ parser.add_argument(
74
+ "-b",
75
+ "--base",
76
+ nargs="*",
77
+ metavar="base_config.yaml",
78
+ help="paths to base configs. Loaded from left-to-right. "
79
+ "Parameters can be overwritten or added with command-line options of the form `--key value`.",
80
+ default=list(),
81
+ )
82
+ parser.add_argument(
83
+ "-t",
84
+ "--train",
85
+ type=str2bool,
86
+ const=True,
87
+ default=False,
88
+ nargs="?",
89
+ help="train",
90
+ )
91
+ parser.add_argument(
92
+ "--no-test",
93
+ type=str2bool,
94
+ const=True,
95
+ default=False,
96
+ nargs="?",
97
+ help="disable test",
98
+ )
99
+ parser.add_argument(
100
+ "-p",
101
+ "--project",
102
+ help="name of new or path to existing project"
103
+ )
104
+ parser.add_argument(
105
+ "-d",
106
+ "--debug",
107
+ type=str2bool,
108
+ nargs="?",
109
+ const=True,
110
+ default=False,
111
+ help="enable post-mortem debugging",
112
+ )
113
+ parser.add_argument(
114
+ "-s",
115
+ "--seed",
116
+ type=int,
117
+ default=23,
118
+ help="seed for seed_everything",
119
+ )
120
+ parser.add_argument(
121
+ "-f",
122
+ "--postfix",
123
+ type=str,
124
+ default="",
125
+ help="post-postfix for default name",
126
+ )
127
+ parser.add_argument(
128
+ "-l",
129
+ "--logdir",
130
+ type=str,
131
+ default="logs",
132
+ help="directory for logging dat shit",
133
+ )
134
+ parser.add_argument(
135
+ "--scale_lr",
136
+ action="store_true",
137
+ default=False,
138
+ help="scale base-lr by ngpu * batch_size * n_accumulate",
139
+ )
140
+ return parser
141
+
142
+
143
+ def nondefault_trainer_args(opt):
144
+ parser = argparse.ArgumentParser()
145
+ parser = Trainer.add_argparse_args(parser)
146
+ args = parser.parse_args([])
147
+ return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))
148
+
149
+
150
+ class WrappedDataset(Dataset):
151
+ """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
152
+
153
+ def __init__(self, dataset):
154
+ self.data = dataset
155
+
156
+ def __len__(self):
157
+ return len(self.data)
158
+
159
+ def __getitem__(self, idx):
160
+ return self.data[idx]
161
+
162
+
163
+ def worker_init_fn(_):
164
+ worker_info = torch.utils.data.get_worker_info()
165
+
166
+ dataset = worker_info.dataset
167
+ worker_id = worker_info.id
168
+
169
+ if isinstance(dataset, Txt2ImgIterableBaseDataset):
170
+ split_size = dataset.num_records // worker_info.num_workers
171
+ # reset num_records to the true number to retain reliable length information
172
+ dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]
173
+ current_id = np.random.choice(len(np.random.get_state()[1]), 1)
174
+ return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
175
+ else:
176
+ return np.random.seed(np.random.get_state()[1][0] + worker_id)
177
+
178
+
179
+ class DataModuleFromConfig(pl.LightningDataModule):
180
+ def __init__(self, batch_size, train=None, validation=None, test=None, predict=None,
181
+ wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False,
182
+ shuffle_val_dataloader=False):
183
+ super().__init__()
184
+ self.batch_size = batch_size
185
+ self.dataset_configs = dict()
186
+ self.num_workers = num_workers if num_workers is not None else batch_size * 2
187
+ self.use_worker_init_fn = use_worker_init_fn
188
+ if train is not None:
189
+ self.dataset_configs["train"] = train
190
+ self.train_dataloader = self._train_dataloader
191
+ if validation is not None:
192
+ self.dataset_configs["validation"] = validation
193
+ self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
194
+ if test is not None:
195
+ self.dataset_configs["test"] = test
196
+ self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
197
+ if predict is not None:
198
+ self.dataset_configs["predict"] = predict
199
+ self.predict_dataloader = self._predict_dataloader
200
+ self.wrap = wrap
201
+
202
+ def prepare_data(self):
203
+ for data_cfg in self.dataset_configs.values():
204
+ instantiate_from_config(data_cfg)
205
+
206
+ def setup(self, stage=None):
207
+ self.datasets = dict(
208
+ (k, instantiate_from_config(self.dataset_configs[k]))
209
+ for k in self.dataset_configs)
210
+ if self.wrap:
211
+ for k in self.datasets:
212
+ self.datasets[k] = WrappedDataset(self.datasets[k])
213
+
214
+ def _train_dataloader(self):
215
+ is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
216
+ if is_iterable_dataset or self.use_worker_init_fn:
217
+ init_fn = worker_init_fn
218
+ else:
219
+ init_fn = None
220
+ return DataLoader(self.datasets["train"], batch_size=self.batch_size,
221
+ num_workers=self.num_workers, shuffle=False if is_iterable_dataset else True,
222
+ worker_init_fn=init_fn, persistent_workers=False)
223
+
224
+ def _val_dataloader(self, shuffle=False):
225
+ if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
226
+ init_fn = worker_init_fn
227
+ else:
228
+ init_fn = None
229
+ return DataLoader(self.datasets["validation"],
230
+ batch_size=self.batch_size,
231
+ num_workers=self.num_workers,
232
+ worker_init_fn=init_fn,
233
+ shuffle=shuffle, persistent_workers=False)
234
+
235
+ def _test_dataloader(self, shuffle=False):
236
+ is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
237
+ if is_iterable_dataset or self.use_worker_init_fn:
238
+ init_fn = worker_init_fn
239
+ else:
240
+ init_fn = None
241
+
242
+ # do not shuffle dataloader for iterable dataset
243
+ shuffle = shuffle and (not is_iterable_dataset)
244
+
245
+ return DataLoader(self.datasets["test"], batch_size=self.batch_size,
246
+ num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle, persistent_workers=False)
247
+
248
+ def _predict_dataloader(self, shuffle=False):
249
+ if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
250
+ init_fn = worker_init_fn
251
+ else:
252
+ init_fn = None
253
+ return DataLoader(self.datasets["predict"], batch_size=self.batch_size,
254
+ num_workers=self.num_workers, worker_init_fn=init_fn, persistent_workers=False)
255
+
256
+
257
+ class SetupCallback(Callback):
258
+ def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
259
+ super().__init__()
260
+ self.resume = resume
261
+ self.now = now
262
+ self.logdir = logdir
263
+ self.ckptdir = ckptdir
264
+ self.cfgdir = cfgdir
265
+ self.config = config
266
+ self.lightning_config = lightning_config
267
+
268
+ def on_keyboard_interrupt(self, trainer, pl_module):
269
+ if trainer.global_rank == 0:
270
+ print("Summoning checkpoint.")
271
+ ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
272
+ trainer.save_checkpoint(ckpt_path)
273
+
274
+ def on_pretrain_routine_start(self, trainer, pl_module):
275
+ if trainer.global_rank == 0:
276
+ # Create logdirs and save configs
277
+ # os.makedirs(self.logdir, exist_ok=True)
278
+ # os.makedirs(self.ckptdir, exist_ok=True)
279
+ # os.makedirs(self.cfgdir, exist_ok=True)
280
+
281
+ if "callbacks" in self.lightning_config:
282
+ if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']:
283
+ os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True)
284
+ print("Project config")
285
+ print(OmegaConf.to_yaml(self.config))
286
+ OmegaConf.save(self.config,
287
+ os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))
288
+
289
+ print("Lightning config")
290
+ print(OmegaConf.to_yaml(self.lightning_config))
291
+ OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
292
+ os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))
293
+
294
+ def get_world_size():
295
+ if not dist.is_available():
296
+ return 1
297
+ if not dist.is_initialized():
298
+ return 1
299
+ return dist.get_world_size()
300
+
301
+ def all_gather(data):
302
+ """
303
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
304
+ Args:
305
+ data: any picklable object
306
+ Returns:
307
+ list[data]: list of data gathered from each rank
308
+ """
309
+ world_size = get_world_size()
310
+ if world_size == 1:
311
+ return [data]
312
+
313
+ # serialized to a Tensor
314
+ origin_size = None
315
+ if not isinstance(data, torch.Tensor):
316
+ buffer = pickle.dumps(data)
317
+ storage = torch.ByteStorage.from_buffer(buffer)
318
+ tensor = torch.ByteTensor(storage).to("cuda")
319
+ else:
320
+ origin_size = data.size()
321
+ tensor = data.reshape(-1)
322
+
323
+ tensor_type = tensor.dtype
324
+
325
+ # obtain Tensor size of each rank
326
+ local_size = torch.LongTensor([tensor.numel()]).to("cuda")
327
+ size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
328
+ dist.all_gather(size_list, local_size)
329
+ size_list = [int(size.item()) for size in size_list]
330
+ max_size = max(size_list)
331
+
332
+ # receiving Tensor from all ranks
333
+ # we pad the tensor because torch all_gather does not support
334
+ # gathering tensors of different shapes
335
+ tensor_list = []
336
+ for _ in size_list:
337
+ tensor_list.append(torch.FloatTensor(size=(max_size,)).cuda().to(tensor_type))
338
+ if local_size != max_size:
339
+ padding = torch.FloatTensor(size=(max_size - local_size,)).cuda().to(tensor_type)
340
+ tensor = torch.cat((tensor, padding), dim=0)
341
+ dist.all_gather(tensor_list, tensor)
342
+
343
+ data_list = []
344
+ for size, tensor in zip(size_list, tensor_list):
345
+ if origin_size is None:
346
+ buffer = tensor.cpu().numpy().tobytes()[:size]
347
+ data_list.append(pickle.loads(buffer))
348
+ else:
349
+ buffer = tensor[:size]
350
+ data_list.append(buffer)
351
+
352
+ if origin_size is not None:
353
+ new_shape = [-1] + list(origin_size[1:])
354
+ resized_list = []
355
+ for data in data_list:
356
+ # suppose the difference of tensor size exist in first dimension
357
+ data = data.reshape(new_shape)
358
+ resized_list.append(data)
359
+
360
+ return resized_list
361
+ else:
362
+ return data_list
363
+
364
+ class ImageLogger(Callback):
365
+ def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True,
366
+ rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
367
+ log_images_kwargs=None):
368
+ super().__init__()
369
+ self.rescale = rescale
370
+ self.batch_freq = batch_frequency
371
+ self.max_images = max_images
372
+ self.logger_log_images = {
373
+ pl.loggers.TestTubeLogger: self._testtube,
374
+ }
375
+ self.log_steps = [2 ** n for n in range(6, int(np.log2(self.batch_freq)) + 1)]
376
+ if not increase_log_steps:
377
+ self.log_steps = [self.batch_freq]
378
+ self.clamp = clamp
379
+ self.disabled = disabled
380
+ self.log_on_batch_idx = log_on_batch_idx
381
+ self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
382
+ self.log_first_step = log_first_step
383
+
384
+ @rank_zero_only
385
+ def _testtube(self, pl_module, images, batch_idx, split):
386
+ for k in images:
387
+ grid = torchvision.utils.make_grid(images[k])
388
+ grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
389
+
390
+ tag = f"{split}/{k}"
391
+ pl_module.logger.experiment.add_image(
392
+ tag, grid,
393
+ global_step=pl_module.global_step)
394
+
395
+ @rank_zero_only
396
+ def log_local(self, save_dir, split, images, prompts,
397
+ global_step, current_epoch, batch_idx):
398
+ root = os.path.join(save_dir, "images", split)
399
+ names = {"reals": "before", "inputs": "after", "reconstruction": "before-vq", "samples": "after-gen"}
400
+ # print(root)
401
+ for k in images:
402
+ grid = torchvision.utils.make_grid(images[k], nrow=8)
403
+ if self.rescale:
404
+ grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
405
+ grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
406
+ grid = grid.numpy()
407
+ grid = (grid * 255).astype(np.uint8)
408
+ filename = "gs-{:06}_e-{:06}_b-{:06}_{}.png".format(
409
+ global_step,
410
+ current_epoch,
411
+ batch_idx,
412
+ names[k])
413
+ path = os.path.join(root, filename)
414
+ os.makedirs(os.path.split(path)[0], exist_ok=True)
415
+ # print(path)
416
+ Image.fromarray(grid).save(path)
417
+
418
+ filename = "gs-{:06}_e-{:06}_b-{:06}_prompt.json".format(
419
+ global_step,
420
+ current_epoch,
421
+ batch_idx)
422
+ path = os.path.join(root, filename)
423
+ with open(path, "w") as f:
424
+ for p in prompts:
425
+ f.write(f"{json.dumps(p)}\n")
426
+
427
+ def log_img(self, pl_module, batch, batch_idx, split="train"):
428
+ check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
429
+ if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0
430
+ hasattr(pl_module, "log_images") and
431
+ callable(pl_module.log_images) and
432
+ self.max_images > 0) or (split == "val" and batch_idx == 0):
433
+ logger = type(pl_module.logger)
434
+
435
+ is_train = pl_module.training
436
+ if is_train:
437
+ pl_module.eval()
438
+
439
+ with torch.no_grad():
440
+ images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
441
+
442
+ prompts = batch["edit"]["c_crossattn"][:self.max_images]
443
+ prompts = [p for ps in all_gather(prompts) for p in ps]
444
+
445
+ for k in images:
446
+ N = min(images[k].shape[0], self.max_images)
447
+ images[k] = images[k][:N]
448
+ images[k] = torch.cat(all_gather(images[k][:N]))
449
+ if isinstance(images[k], torch.Tensor):
450
+ images[k] = images[k].detach().cpu()
451
+ if self.clamp:
452
+ images[k] = torch.clamp(images[k], -1., 1.)
453
+
454
+ self.log_local(pl_module.logger.save_dir, split, images, prompts,
455
+ pl_module.global_step, pl_module.current_epoch, batch_idx)
456
+
457
+ logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
458
+ logger_log_images(pl_module, images, pl_module.global_step, split)
459
+
460
+ if is_train:
461
+ pl_module.train()
462
+
463
+ def check_frequency(self, check_idx):
464
+ if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
465
+ check_idx > 0 or self.log_first_step):
466
+ if len(self.log_steps) > 0:
467
+ self.log_steps.pop(0)
468
+ return True
469
+ return False
470
+
471
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
472
+ # if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
473
+ # self.log_img(pl_module, batch, batch_idx, split="train")
474
+ pass
475
+
476
+ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
477
+ # if not self.disabled and pl_module.global_step > 0:
478
+ # self.log_img(pl_module, batch, batch_idx, split="val")
479
+ # if hasattr(pl_module, 'calibrate_grad_norm'):
480
+ # if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
481
+ # self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
482
+ pass
483
+
484
+
485
+ class CUDACallback(Callback):
486
+ # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
487
+ def on_train_epoch_start(self, trainer, pl_module):
488
+ # Reset the memory use counter
489
+ torch.cuda.reset_peak_memory_stats(trainer.root_gpu)
490
+ torch.cuda.synchronize(trainer.root_gpu)
491
+ self.start_time = time.time()
492
+
493
+ def on_train_epoch_end(self, trainer, pl_module, outputs):
494
+ torch.cuda.synchronize(trainer.root_gpu)
495
+ max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20
496
+ epoch_time = time.time() - self.start_time
497
+
498
+ try:
499
+ max_memory = trainer.training_type_plugin.reduce(max_memory)
500
+ epoch_time = trainer.training_type_plugin.reduce(epoch_time)
501
+
502
+ rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
503
+ rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
504
+ except AttributeError:
505
+ pass
506
+
507
+
508
if __name__ == "__main__":
    # custom parser to specify config files, train, test and debug mode,
    # postfix, resume.
    # `--key value` arguments are interpreted as arguments to the trainer.
    # `nested.key=value` arguments are interpreted as config parameters.
    # configs are merged from left-to-right followed by command line parameters.

    # model:
    #   base_learning_rate: float
    #   target: path to lightning module
    #   params:
    #       key: value
    # data:
    #   target: main.DataModuleFromConfig
    #   params:
    #      batch_size: int
    #      wrap: bool
    #      train:
    #          target: path to train dataset
    #          params:
    #              key: value
    #      validation:
    #          target: path to validation dataset
    #          params:
    #              key: value
    #      test:
    #          target: path to test dataset
    #          params:
    #              key: value
    # lightning: (optional, has sane defaults and can be specified on cmdline)
    #   trainer:
    #       additional arguments to trainer
    #   logger:
    #       logger to instantiate
    #   modelcheckpoint:
    #       modelcheckpoint to instantiate
    #   callbacks:
    #       callback1:
    #           target: importpath
    #           params:
    #               key: value

    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")

    # add cwd for convenience and to make classes in this file available when
    # running as `python main.py`
    # (in particular `main.DataModuleFromConfig`)
    sys.path.append(os.getcwd())

    parser = get_parser()
    parser = Trainer.add_argparse_args(parser)

    opt, unknown = parser.parse_known_args()

    # Run directory layout: <logdir>/<config-name>_<run-name>/{checkpoints,configs}
    assert opt.name
    cfg_fname = os.path.split(opt.base[0])[-1]
    cfg_name = os.path.splitext(cfg_fname)[0]
    nowname = f"{cfg_name}_{opt.name}"
    logdir = os.path.join(opt.logdir, nowname)
    ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
    resume = False

    # If a last.ckpt already exists for this run name, resume from it and
    # prepend the previously-saved configs so the old settings win by default.
    if os.path.isfile(ckpt):
        opt.resume_from_checkpoint = ckpt
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
        opt.base = base_configs + opt.base
        _tmp = logdir.split("/")
        nowname = _tmp[-1]
        resume = True

    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")

    os.makedirs(logdir, exist_ok=True)
    os.makedirs(ckptdir, exist_ok=True)
    os.makedirs(cfgdir, exist_ok=True)

    try:
        # init and save configs
        configs = [OmegaConf.load(cfg) for cfg in opt.base]
        cli = OmegaConf.from_dotlist(unknown)
        config = OmegaConf.merge(*configs, cli)

        if resume:
            # By default, when finetuning from Stable Diffusion, we load the EMA-only checkpoint to initialize all weights.
            # If resuming InstructPix2Pix from a finetuning checkpoint, instead load both EMA and non-EMA weights.
            # NOTE(review): the flag set here is `load_ema = True`; confirm the
            # model config interprets this as "load EMA *and* non-EMA weights".
            config.model.params.load_ema = True

        lightning_config = config.pop("lightning", OmegaConf.create())
        # merge trainer cli with config
        trainer_config = lightning_config.get("trainer", OmegaConf.create())
        # default to ddp
        trainer_config["accelerator"] = "ddp"
        for k in nondefault_trainer_args(opt):
            trainer_config[k] = getattr(opt, k)
        # Without an explicit `gpus` setting we fall back to CPU training.
        if not "gpus" in trainer_config:
            del trainer_config["accelerator"]
            cpu = True
        else:
            gpuinfo = trainer_config["gpus"]
            print(f"Running on GPUs {gpuinfo}")
            cpu = False
        trainer_opt = argparse.Namespace(**trainer_config)
        lightning_config.trainer = trainer_config

        # model
        model = instantiate_from_config(config.model)

        # trainer and callbacks
        trainer_kwargs = dict()

        # default logger configs
        default_logger_cfgs = {
            "wandb": {
                "target": "pytorch_lightning.loggers.WandbLogger",
                "params": {
                    "name": nowname,
                    "save_dir": logdir,
                    "id": nowname,
                }
            },
            "testtube": {
                "target": "pytorch_lightning.loggers.TestTubeLogger",
                "params": {
                    "name": "testtube",
                    "save_dir": logdir,
                }
            },
        }
        # TestTube is the default; override via `lightning.logger` in the config.
        default_logger_cfg = default_logger_cfgs["testtube"]
        if "logger" in lightning_config:
            logger_cfg = lightning_config.logger
        else:
            logger_cfg = OmegaConf.create()
        logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
        trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)

        # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
        # specify which metric is used to determine best models
        default_modelckpt_cfg = {
            "target": "pytorch_lightning.callbacks.ModelCheckpoint",
            "params": {
                "dirpath": ckptdir,
                "filename": "{epoch:06}",
                "verbose": True,
                "save_last": True,
            }
        }

        if "modelcheckpoint" in lightning_config:
            modelckpt_cfg = lightning_config.modelcheckpoint
        else:
            modelckpt_cfg = OmegaConf.create()
        modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
        print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
        # PL < 1.4 takes the checkpoint callback as a dedicated Trainer kwarg;
        # PL >= 1.4 receives it through the regular callbacks list (below).
        if version.parse(pl.__version__) < version.parse('1.4.0'):
            trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)

        # add callback which sets up log directory
        default_callbacks_cfg = {
            "setup_callback": {
                "target": "main.SetupCallback",
                "params": {
                    "resume": opt.resume,
                    "now": now,
                    "logdir": logdir,
                    "ckptdir": ckptdir,
                    "cfgdir": cfgdir,
                    "config": config,
                    "lightning_config": lightning_config,
                }
            },
            "image_logger": {
                "target": "main.ImageLogger",
                "params": {
                    "batch_frequency": 750,
                    "max_images": 4,
                    "clamp": True
                }
            },
            "learning_rate_logger": {
                "target": "main.LearningRateMonitor",
                "params": {
                    "logging_interval": "step",
                    # "log_momentum": True
                }
            },
            "cuda_callback": {
                "target": "main.CUDACallback"
            },
        }
        if version.parse(pl.__version__) >= version.parse('1.4.0'):
            default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg})

        if "callbacks" in lightning_config:
            callbacks_cfg = lightning_config.callbacks
        else:
            callbacks_cfg = OmegaConf.create()

        print(
            'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
        # Extra step-based checkpointing: keep weights-only snapshots every
        # 1000 steps (never pruned; disk usage grows unbounded).
        default_metrics_over_trainsteps_ckpt_dict = {
            'metrics_over_trainsteps_checkpoint': {
                "target": 'pytorch_lightning.callbacks.ModelCheckpoint',
                'params': {
                    "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
                    "filename": "{epoch:06}-{step:09}",
                    "verbose": True,
                    'save_top_k': -1,
                    'every_n_train_steps': 1000,
                    'save_weights_only': True
                }
            }
        }
        default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)

        callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
        # The ignore-keys callback only makes sense when resuming from a ckpt.
        if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'):
            callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint
        elif 'ignore_keys_callback' in callbacks_cfg:
            del callbacks_cfg['ignore_keys_callback']

        trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]

        trainer = Trainer.from_argparse_args(trainer_opt, plugins=DDPPlugin(find_unused_parameters=False), **trainer_kwargs)
        trainer.logdir = logdir  # stash run directory on the trainer for callbacks

        # data
        data = instantiate_from_config(config.data)
        # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
        # calling these ourselves should not be necessary but it is.
        # lightning still takes care of proper multiprocessing though
        data.prepare_data()
        data.setup()
        print("#### Data #####")
        for k in data.datasets:
            print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")

        # configure learning rate
        bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
        if not cpu:
            ngpu = len(lightning_config.trainer.gpus.strip(",").split(','))
        else:
            ngpu = 1
        if 'accumulate_grad_batches' in lightning_config.trainer:
            accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
        else:
            accumulate_grad_batches = 1
        print(f"accumulate_grad_batches = {accumulate_grad_batches}")
        lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
        # Linear LR scaling: effective batch size (accum * gpus * bs) times base LR.
        if opt.scale_lr:
            model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
            print(
                "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
                    model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
        else:
            model.learning_rate = base_lr
            print("++++ NOT USING LR SCALING ++++")
            print(f"Setting learning rate to {model.learning_rate:.2e}")


        # allow checkpointing via USR1
        def melk(*args, **kwargs):
            # run all checkpoint hooks
            if trainer.global_rank == 0:
                print("Summoning checkpoint.")
                ckpt_path = os.path.join(ckptdir, "last.ckpt")
                trainer.save_checkpoint(ckpt_path)


        def divein(*args, **kwargs):
            # drop into an interactive debugger on demand (SIGUSR2)
            if trainer.global_rank == 0:
                import pudb;
                pudb.set_trace()


        import signal

        signal.signal(signal.SIGUSR1, melk)
        signal.signal(signal.SIGUSR2, divein)

        # run
        if opt.train:
            try:
                trainer.fit(model, data)
            except Exception:
                # save a last.ckpt before propagating the failure
                melk()
                raise
        if not opt.no_test and not trainer.interrupted:
            trainer.test(model, data)
    except Exception:
        if opt.debug and trainer.global_rank == 0:
            try:
                import pudb as debugger
            except ImportError:
                import pdb as debugger
            debugger.post_mortem()
        raise
    finally:
        # move newly created debug project to debug_runs
        if opt.debug and not opt.resume and trainer.global_rank == 0:
            dst, name = os.path.split(logdir)
            dst = os.path.join(dst, "debug_runs", name)
            os.makedirs(os.path.split(dst)[0], exist_ok=True)
            os.rename(logdir, dst)
        if trainer.global_rank == 0:
            print(trainer.profiler.summary())
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/metrics/clip_similarity.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import clip
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+
9
+
10
class ClipSimilarity(nn.Module):
    """CLIP-space similarity scores for evaluating instruction-based edits.

    Produces caption/image agreement for the input and the edited image, the
    directional similarity between the caption change and the image change,
    and the raw image-to-image consistency.
    """

    def __init__(self, name: str = "ViT-L/14"):
        super().__init__()
        assert name in ("RN50", "RN101", "RN50x4", "RN50x16", "RN50x64", "ViT-B/32", "ViT-B/16", "ViT-L/14", "ViT-L/14@336px")  # fmt: skip
        # Most CLIP backbones take 224px crops; a few variants use larger inputs.
        self.size = {"RN50x4": 288, "RN50x16": 384, "RN50x64": 448, "ViT-L/14@336px": 336}.get(name, 224)

        self.model, _ = clip.load(name, device="cpu", download_root="./")
        self.model.eval().requires_grad_(False)

        # CLIP's standard channel-wise normalization constants.
        self.register_buffer("mean", torch.tensor((0.48145466, 0.4578275, 0.40821073)))
        self.register_buffer("std", torch.tensor((0.26862954, 0.26130258, 0.27577711)))

    def encode_text(self, text: list[str]) -> torch.Tensor:
        """Return L2-normalized CLIP text embeddings for a batch of strings."""
        device = next(self.parameters()).device
        tokens = clip.tokenize(text, truncate=True).to(device)
        features = self.model.encode_text(tokens)
        return features / features.norm(dim=1, keepdim=True)

    def encode_image(self, image: torch.Tensor) -> torch.Tensor:
        """Return L2-normalized CLIP image embeddings; input pixels in [0, 1]."""
        image = F.interpolate(image.float(), size=self.size, mode="bicubic", align_corners=False)
        image = image - self.mean.view(1, -1, 1, 1)
        image = image / self.std.view(1, -1, 1, 1)
        features = self.model.encode_image(image)
        return features / features.norm(dim=1, keepdim=True)

    def forward(
        self, image_0: torch.Tensor, image_1: torch.Tensor, text_0: list[str], text_1: list[str]
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (sim_0, sim_1, sim_direction, sim_image) cosine similarities."""
        img_feat_0 = self.encode_image(image_0)
        img_feat_1 = self.encode_image(image_1)
        txt_feat_0 = self.encode_text(text_0)
        txt_feat_1 = self.encode_text(text_1)
        sim_0 = F.cosine_similarity(img_feat_0, txt_feat_0)
        sim_1 = F.cosine_similarity(img_feat_1, txt_feat_1)
        # Directional similarity: does the image move the way the caption moved?
        sim_direction = F.cosine_similarity(img_feat_1 - img_feat_0, txt_feat_1 - txt_feat_0)
        sim_image = F.cosine_similarity(img_feat_0, img_feat_1)
        return sim_0, sim_1, sim_direction, sim_image
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/metrics/compute_metrics.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ import random
5
+ import sys
6
+ from argparse import ArgumentParser
7
+
8
+ import einops
9
+ import k_diffusion as K
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn as nn
13
+ from tqdm.auto import tqdm
14
+ from einops import rearrange
15
+ from omegaconf import OmegaConf
16
+ from PIL import Image, ImageOps
17
+ from torch import autocast
18
+
19
+ import json
20
+ import matplotlib.pyplot as plt
21
+ import seaborn
22
+ from pathlib import Path
23
+
24
+ sys.path.append("./")
25
+
26
+ from clip_similarity import ClipSimilarity
27
+ from edit_dataset import EditDatasetEval
28
+
29
+ sys.path.append("./stable_diffusion")
30
+
31
+ from ldm.util import instantiate_from_config
32
+
33
+
34
class CFGDenoiser(nn.Module):
    """Applies InstructPix2Pix's dual classifier-free guidance.

    A single denoiser call is batched over three branches — (text+image
    conditioned, image-only conditioned, fully unconditional) — and the
    outputs are recombined with independent text/image guidance scales.
    """

    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, z, sigma, cond, uncond, text_cfg_scale, image_cfg_scale):
        # Triplicate the latent and noise level so all branches run in one pass.
        z_batched = einops.repeat(z, "1 ... -> n ...", n=3)
        sigma_batched = einops.repeat(sigma, "1 ... -> n ...", n=3)
        cond_batched = {
            "c_crossattn": [torch.cat([cond["c_crossattn"][0], uncond["c_crossattn"][0], uncond["c_crossattn"][0]])],
            "c_concat": [torch.cat([cond["c_concat"][0], cond["c_concat"][0], uncond["c_concat"][0]])],
        }
        full_cond, img_cond, no_cond = self.inner_model(z_batched, sigma_batched, cond=cond_batched).chunk(3)
        # Guidance: push from unconditional toward image-conditioned, then
        # from image-conditioned toward fully-conditioned, each with its scale.
        return no_cond + text_cfg_scale * (full_cond - img_cond) + image_cfg_scale * (img_cond - no_cond)
48
+
49
+
50
def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
    """Instantiate the model described by `config` and load `ckpt` weights.

    If `vae_ckpt` is given, any `first_stage_model.*` entries in the state
    dict are replaced by the corresponding weights from that VAE checkpoint.
    Loading is non-strict; set `verbose=True` to print key mismatches.
    """
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]

    if vae_ckpt is not None:
        print(f"Loading VAE from {vae_ckpt}")
        vae_sd = torch.load(vae_ckpt, map_location="cpu")["state_dict"]
        prefix = "first_stage_model."
        # Swap in the VAE's weights for every first-stage key.
        sd = {
            key: vae_sd[key[len(prefix):]] if key.startswith(prefix) else value
            for key, value in sd.items()
        }

    model = instantiate_from_config(config.model)
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose and len(missing) > 0:
        print("missing keys:")
        print(missing)
    if verbose and len(unexpected) > 0:
        print("unexpected keys:")
        print(unexpected)
    return model
72
+
73
class ImageEditor(nn.Module):
    """Single-image InstructPix2Pix editor.

    Loads a checkpoint, wraps it in a k-diffusion denoiser with dual
    classifier-free guidance, and exposes `forward(image, edit)` that returns
    the edited image.
    """

    def __init__(self, config, ckpt, vae_ckpt=None):
        super().__init__()

        config = OmegaConf.load(config)
        self.model = load_model_from_config(config, ckpt, vae_ckpt)
        self.model.eval().cuda()
        # k-diffusion wrapper + CFG combiner used during sampling.
        self.model_wrap = K.external.CompVisDenoiser(self.model)
        self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
        self.null_token = self.model.get_learned_conditioning([""])

    def forward(
        self,
        image: torch.Tensor,
        edit: str,
        scale_txt: float = 7.5,
        scale_img: float = 1.0,
        steps: int = 100,
    ) -> torch.Tensor:
        # Expects a single CHW image; spatial dims must be multiples of 64
        # so the latent resolutions line up through the UNet.
        assert image.dim() == 3
        assert image.size(1) % 64 == 0
        assert image.size(2) % 64 == 0
        with torch.no_grad(), autocast("cuda"), self.model.ema_scope():
            # Conditioning: text embedding of the instruction plus the latent
            # encoding of the input image (mode of the posterior).
            cond = {
                "c_crossattn": [self.model.get_learned_conditioning([edit])],
                "c_concat": [self.model.encode_first_stage(image[None]).mode()],
            }
            # Unconditional branch: empty prompt and a zeroed image latent.
            uncond = {
                "c_crossattn": [self.model.get_learned_conditioning([""])],
                "c_concat": [torch.zeros_like(cond["c_concat"][0])],
            }
            extra_args = {
                "uncond": uncond,
                "cond": cond,
                "image_cfg_scale": scale_img,
                "text_cfg_scale": scale_txt,
            }
            sigmas = self.model_wrap.get_sigmas(steps)
            # Start from pure noise at the largest sigma, run Euler-ancestral
            # sampling, then decode the final latent back to pixel space.
            x = torch.randn_like(cond["c_concat"][0]) * sigmas[0]
            x = K.sampling.sample_euler_ancestral(self.model_wrap_cfg, x, sigmas, extra_args=extra_args)
            x = self.model.decode_first_stage(x)[0]
            return x
+ return x
115
+
116
+
117
def compute_metrics(config,
                    model_path,
                    vae_ckpt,
                    data_path,
                    output_path,
                    scales_img,
                    scales_txt,
                    num_samples = 5000,
                    split = "test",
                    steps = 50,
                    res = 512,
                    seed = 0):
    """Evaluate an InstructPix2Pix checkpoint with CLIP-based metrics.

    For every (text-scale, image-scale) pair, edits `num_samples` randomly
    permuted examples and averages four CLIP similarities: caption/image
    agreement before and after the edit, edit-direction agreement, and raw
    image consistency. One JSON line per setting is appended to the output
    file.

    Args:
        config: path to the model YAML config.
        model_path: path to the model checkpoint.
        vae_ckpt: optional VAE checkpoint overriding first-stage weights.
        data_path: root directory of the evaluation dataset.
        output_path: directory where the metrics .jsonl file is written.
        scales_img: list of image CFG scales to sweep.
        scales_txt: list of text CFG scales to sweep.
        num_samples: number of examples evaluated per setting.
        split: dataset split name.
        steps: diffusion sampling steps per edit.
        res: image resolution fed to the editor.
        seed: RNG seed for the sample permutation.

    Returns:
        Path to the appended .jsonl metrics file.
    """
    editor = ImageEditor(config, model_path, vae_ckpt).cuda()
    clip_similarity = ClipSimilarity().cuda()

    outpath = Path(output_path, f"n={num_samples}_p={split}_s={steps}_r={res}_e={seed}.jsonl")
    Path(output_path).mkdir(parents=True, exist_ok=True)

    for scale_txt in scales_txt:
        for scale_img in scales_img:
            dataset = EditDatasetEval(
                path=data_path,
                split=split,
                res=res
            )
            assert num_samples <= len(dataset)
            print(f'Processing t={scale_txt}, i={scale_img}')
            torch.manual_seed(seed)
            perm = torch.randperm(len(dataset))
            i = 0

            sim_0_avg = 0
            sim_1_avg = 0
            sim_direction_avg = 0
            sim_image_avg = 0
            count = 0

            pbar = tqdm(total=num_samples)
            while count < num_samples:

                idx = perm[i].item()
                sample = dataset[idx]
                i += 1

                gen = editor(sample["image_0"].cuda(), sample["edit"], scale_txt=scale_txt, scale_img=scale_img, steps=steps)

                sim_0, sim_1, sim_direction, sim_image = clip_similarity(
                    sample["image_0"][None].cuda(), gen[None].cuda(), [sample["input_prompt"]], [sample["output_prompt"]]
                )
                sim_0_avg += sim_0.item()
                sim_1_avg += sim_1.item()
                sim_direction_avg += sim_direction.item()
                sim_image_avg += sim_image.item()
                count += 1
                # BUG FIX: tqdm.update() takes an *increment*, not an absolute
                # position; the original `pbar.update(count)` inflated the bar
                # by the running total on every iteration.
                pbar.update(1)
            pbar.close()

            sim_0_avg /= count
            sim_1_avg /= count
            sim_direction_avg /= count
            sim_image_avg /= count

            # Append one JSON record per (scale_txt, scale_img) setting.
            with open(outpath, "a") as f:
                f.write(f"{json.dumps(dict(sim_0=sim_0_avg, sim_1=sim_1_avg, sim_direction=sim_direction_avg, sim_image=sim_image_avg, num_samples=num_samples, split=split, scale_txt=scale_txt, scale_img=scale_img, steps=steps, res=res, seed=seed))}\n")
    return outpath
185
+
186
def plot_metrics(metrics_file, output_path):
    """Plot CLIP direction similarity vs. image similarity from a .jsonl file.

    Reads one JSON record per line and saves the trade-off curve as
    `plot.pdf` inside `output_path`.
    """
    with open(metrics_file, 'r') as f:
        records = [json.loads(line) for line in f]

    plt.rcParams.update({'font.size': 11.5})
    seaborn.set_style("darkgrid")
    plt.figure(figsize=(20.5 * 0.7, 10.8 * 0.7), dpi=200)

    xs = [record["sim_direction"] for record in records]
    ys = [record["sim_image"] for record in records]
    plt.plot(xs, ys, marker='o', linewidth=2, markersize=4)

    plt.xlabel("CLIP Text-Image Direction Similarity", labelpad=10)
    plt.ylabel("CLIP Image Similarity", labelpad=10)

    plt.savefig(Path(output_path) / Path("plot.pdf"), bbox_inches="tight")
204
+
205
def main():
    """CLI entry point: sweep image-CFG scales, then plot the trade-off curve.

    Runs `compute_metrics` over a fixed grid of image guidance scales at a
    single text scale and hands the resulting .jsonl to `plot_metrics`.
    """
    parser = ArgumentParser()
    parser.add_argument("--resolution", default=512, type=int)
    parser.add_argument("--steps", default=100, type=int)
    parser.add_argument("--config", default="configs/generate.yaml", type=str)
    parser.add_argument("--output_path", default="analysis/", type=str)
    parser.add_argument("--ckpt", default="checkpoints/instruct-pix2pix-00-22000.ckpt", type=str)
    parser.add_argument("--dataset", default="data/clip-filtered-dataset/", type=str)
    parser.add_argument("--vae-ckpt", default=None, type=str)
    args = parser.parse_args()

    scales_img = [1.0, 1.2, 1.4, 1.6, 1.8, 2.0, 2.2]
    scales_txt = [7.5]

    metrics_file = compute_metrics(
        args.config,
        args.ckpt,
        args.vae_ckpt,
        args.dataset,
        args.output_path,
        scales_img,
        scales_txt,
        steps = args.steps,
        # BUG FIX: `--resolution` was parsed but never forwarded, so the flag
        # had no effect; the default (512) is unchanged.
        res = args.resolution,
    )

    plot_metrics(metrics_file, args.output_path)



if __name__ == "__main__":
    main()
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/prompt_app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from argparse import ArgumentParser
4
+
5
+ import datasets
6
+ import gradio as gr
7
+ import numpy as np
8
+ import openai
9
+
10
+ from dataset_creation.generate_txt_dataset import generate
11
+
12
+
13
def main(openai_model: str):
    """Launch a Gradio demo that turns captions into edit instructions.

    Shows a caption box, a button to cycle through random LAION captions,
    and a button that asks the GPT model for an instruction plus an edited
    caption via `generate`.
    """
    dataset = datasets.load_dataset("ChristophSchuhmann/improved_aesthetics_6.5plus", split="train")
    # Shuffle once up front; `click_random` then walks the permutation.
    captions = dataset[np.random.permutation(len(dataset))]["TEXT"]
    index = 0

    def click_random():
        nonlocal index
        caption = captions[index]
        index = (index + 1) % len(captions)
        return caption

    def click_generate(input: str):
        if input == "":
            raise gr.Error("Input caption is missing!")
        edit_output = generate(openai_model, input)
        if edit_output is None:
            return "Failed :(", "Failed :("
        return edit_output

    with gr.Blocks(css="footer {visibility: hidden}") as demo:
        txt_input = gr.Textbox(lines=3, label="Input Caption", interactive=True, placeholder="Type image caption here...")  # fmt: skip
        txt_edit = gr.Textbox(lines=1, label="GPT-3 Instruction", interactive=False)
        txt_output = gr.Textbox(lines=3, label="GPT3 Edited Caption", interactive=False)

        with gr.Row():
            clear_btn = gr.Button("Clear")
            random_btn = gr.Button("Random Input")
            generate_btn = gr.Button("Generate Instruction + Edited Caption")

        # Wire the buttons: clear all boxes, cycle a random caption, or
        # generate (instruction, edited caption) from the current input.
        clear_btn.click(fn=lambda: ("", "", ""), inputs=[], outputs=[txt_input, txt_edit, txt_output])
        random_btn.click(fn=click_random, inputs=[], outputs=[txt_input])
        generate_btn.click(fn=click_generate, inputs=[txt_input], outputs=[txt_edit, txt_output])

    demo.launch(share=True)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--openai-api-key", required=True, type=str)
    parser.add_argument("--openai-model", required=True, type=str)
    args = parser.parse_args()
    openai.api_key = args.openai_api_key
    main(args.openai_model)
instruct-pix2pix-BioMedCLIP-concat-newdata-data-effusion/requirements.txt ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.3.0
2
+ aiohappyeyeballs==2.6.1
3
+ aiohttp==3.12.12
4
+ aiosignal==1.3.2
5
+ antlr4-python3-runtime==4.8
6
+ async-timeout==5.0.1
7
+ attrs==25.3.0
8
+ blosc2==2.5.1
9
+ Brotli==1.0.9
10
+ cachetools==5.5.2
11
+ certifi==2025.4.26
12
+ charset-normalizer==3.3.2
13
+ contourpy==1.3.0
14
+ cycler==0.12.1
15
+ easydict==1.11
16
+ einops==0.7.0
17
+ filelock==3.17.0
18
+ fonttools==4.58.2
19
+ frozenlist==1.7.0
20
+ fsspec==2025.5.1
21
+ ftfy==6.3.1
22
+ future==1.0.0
23
+ gmpy2==2.2.1
24
+ google-auth==2.40.3
25
+ google-auth-oauthlib==1.2.2
26
+ grpcio==1.73.0
27
+ hf-xet==1.1.3
28
+ huggingface-hub==0.33.0
29
+ idna==3.7
30
+ imageio==2.31.6
31
+ importlib_metadata==8.7.0
32
+ importlib_resources==6.5.2
33
+ Jinja2==3.1.6
34
+ joblib==1.5.1
35
+ kiwisolver==1.4.7
36
+ lazy_loader==0.4
37
+ lightning-utilities==0.9.0
38
+ lpips==0.1.4
39
+ Markdown==3.8
40
+ MarkupSafe==3.0.2
41
+ matplotlib==3.9.4
42
+ mkl_fft==1.3.11
43
+ mkl_random==1.2.8
44
+ mkl-service==2.4.1
45
+ mpmath==1.3.0
46
+ msgpack==1.1.0
47
+ multidict==6.4.4
48
+ ndindex==1.10.0
49
+ networkx==3.2.1
50
+ nibabel==5.2.0
51
+ numpy
52
+ oauthlib==3.2.2
53
+ omegaconf==2.1.1
54
+ open-clip-torch==2.7.0
55
+ opencv-python==4.8.1.78
56
+ packaging==25.0
57
+ pandas==2.1.2
58
+ Pillow==10.0.1
59
+ pip==22.3.1
60
+ propcache==0.3.2
61
+ protobuf==4.23.4
62
+ py-cpuinfo==9.0.0
63
+ pyasn1==0.6.1
64
+ pyasn1_modules==0.4.2
65
+ pyDeprecate==0.3.1
66
+ pyparsing==3.1.1
67
+ PySocks==1.7.1
68
+ python-dateutil==2.9.0.post0
69
+ pytorch-lightning==1.4.2
70
+ pytz==2025.2
71
+ PyYAML==6.0.2
72
+ regex==2024.11.6
73
+ requests==2.32.3
74
+ requests-oauthlib==2.0.0
75
+ rsa==4.9.1
76
+ scikit-image==0.22.0
77
+ scikit-learn==1.3.2
78
+ scipy==1.11.3
79
+ seaborn==0.13.0
80
+ setuptools==78.1.1
81
+ six==1.17.0
82
+ sympy==1.13.3
83
+ tensorboard==2.15.1
84
+ tensorboard-data-server==0.7.2
85
+ tensorboardX==2.6.2.2
86
+ test_tube==0.7.5
87
+ threadpoolctl==3.6.0
88
+ tifffile==2024.8.30
89
+ tokenizers==0.13.3
90
+ tqdm==4.66.1
91
+ transformers==4.29.2
92
+ triton==2.1.0
93
+ typing_extensions==4.12.2
94
+ tzdata==2025.2
95
+ urllib3==2.3.0
96
+ volumentations-3D==1.0.4
97
+ wcwidth==0.2.13
98
+ Werkzeug==3.1.3
99
+ wheel==0.45.1
100
+ yarl==1.20.1
101
+ zipp==3.23.0