wangfuyun committed on
Commit 522bf24 · verified · 1 Parent(s): f9f4433

Add UniRL inference code

.gitattributes CHANGED
@@ -37,3 +37,6 @@ promptrl_geneval/tokenizer.json filter=lfs diff=lfs merge=lfs -text
  promptrl_ocr/tokenizer.json filter=lfs diff=lfs merge=lfs -text
  promptrl_ps/tokenizer.json filter=lfs diff=lfs merge=lfs -text
  promptrl_edit/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ assets/edit_comparison.png filter=lfs diff=lfs merge=lfs -text
+ assets/logo.png filter=lfs diff=lfs merge=lfs -text
+ assets/t2i_comparison.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,149 @@
1
+ # Python
2
+ *.pyc
3
+ __pycache__/
4
+ *.pyo
5
+ *.pyd
6
+ .Python
7
+ env/
8
+ venv/
9
+ ENV/
10
+ env.bak/
11
+ venv.bak/
12
+
13
+ # Byte-compiled / optimized / DLL files
14
+ *.py[cod]
15
+ *$py.class
16
+
17
+ # C extensions
18
+ *.so
19
+
20
+ # Distribution / packaging
21
+ .Python
22
+ build/
23
+ dist/
24
+ downloads/
25
+ eggs/
26
+ .eggs/
27
+ lib/
28
+ lib64/
29
+ parts/
30
+ sdist/
31
+ var/
32
+ wheels/
33
+ *.egg-info/
34
+ .installed.cfg
35
+ *.egg
36
+
37
+ # PyInstaller
38
+ *.manifest
39
+ *.spec
40
+
41
+ # Installer logs
42
+ pip-log.txt
43
+ pip-delete-this-directory.txt
44
+
45
+ # Unit test / coverage reports
46
+ htmlcov/
47
+ .tox/
48
+ .nox/
49
+ .coverage
50
+ .coverage.*
51
+ .cache
52
+ nosetests.xml
53
+ coverage.xml
54
+ *.cover
55
+ *.py,cover
56
+ .hypothesis/
57
+ .pytest_cache/
58
+
59
+ # Jupyter Notebook
60
+ .ipynb_checkpoints
61
+
62
+ # IPython
63
+ profile_default/
64
+ ipython_config.py
65
+
66
+ # pyenv
67
+ .python-version
68
+
69
+ # pipenv
70
+ Pipfile.lock
71
+
72
+ # Poetry
73
+ poetry.lock
74
+
75
+ # Virtualenv
76
+ .venv
77
+ venv/
78
+ ENV/
79
+
80
+ # Spyder project settings
81
+ .spyderproject
82
+ .spyproject
83
+
84
+ # Rope project settings
85
+ .ropeproject
86
+
87
+ # mkdocs documentation
88
+ /site
89
+
90
+ # mypy
91
+ .mypy_cache/
92
+ .dmypy.json
93
+ dmypy.json
94
+
95
+ # Pyre type checker
96
+ .pyre/
97
+
98
+ # IDEs and editors
99
+ .idea/
100
+ .vscode/
101
+ *.sublime-workspace
102
+
103
+ # OS generated files
104
+ .DS_Store
105
+ Thumbs.db
106
+
107
+ # Logs
108
+ *.log
109
+ logs/
110
+ *.log.*
111
+
112
+ # Dependency directories
113
+ node_modules/
114
+ bower_components/
115
+
116
+ # Optional: Local configuration files
117
+ *.local
118
+ *.env
119
+ .env
120
+ .env.local
121
+ .env.development.local
122
+ .env.test.local
123
+ .env.production.local
124
+
125
+ # Optional: Database
126
+ *.sqlite3
127
+ *.db
128
+
129
+ # Optional: Django
130
+ *.sqlite3
131
+ migrations/
132
+ *.mo
133
+ *.pot
134
+ staticfiles/
135
+
136
+ # Optional: Flask
137
+ instance/
138
+ .webassets-cache
139
+
140
+ # Optional: Scrapy
141
+ .scrapy
142
+
143
+ outputs/
144
+
145
+ wandb/
146
+
147
+ assets/large_rl_datasets/
148
+
149
+ utils/parquet_cache/
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [2025] [Fu-Yun Wang]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,155 @@
+ <p align="center">
+ <img src="assets/logo.png" width="30%"><br>
+ PromptRL
+ </p>
+
+ <p align="center">
+ <a href="https://arxiv.org/abs/2602.01382"><img src="https://img.shields.io/badge/arXiv-2602.01382-b31b1b.svg" alt="arXiv"></a>
+ <a href="https://g-u-n.github.io/projects/promptrl/"><img src="https://img.shields.io/badge/Project-Page-green.svg" alt="Project Page"></a>
+ <a href="https://huggingface.co/wangfuyun/PrompRL"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue" alt="HuggingFace"></a>
+ </p>
+
+ ## Overview
+
+ **PromptRL** is a framework that jointly trains language models (LMs) and flow-matching models (FMs) within a unified reinforcement learning loop for text-to-image generation. By incorporating LMs as adaptive prompt refiners, PromptRL addresses two critical limitations in current flow-based RL pipelines: *exploration collapse* due to insufficient generation diversity, and *prompt overfitting* where models memorize specific training formulations.
+
+
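The loop below is only a conceptual sketch of this idea, not the released training code: an LM policy rewrites the prompt, the flow-matching policy renders it, a reward model scores the result, and both policies receive the reward signal. All names (`lm_policy`, `fm_policy`, `reward_model`, `rl_update`) are placeholders.

```python
# Conceptual sketch of one joint RL step (illustrative only; not the repo's trainer).
def joint_rl_step(user_prompt, lm_policy, fm_policy, reward_model, rl_update):
    refined_prompt = lm_policy.refine(user_prompt)   # adaptive prompt refinement
    image = fm_policy.generate(refined_prompt)       # flow-matching sampling
    reward = reward_model.score(image, user_prompt)  # task reward (exact formulation is in the paper)
    rl_update(lm_policy, reward)                     # update the prompt refiner
    rl_update(fm_policy, reward)                     # update the image generator
    return refined_prompt, image, reward
```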
+ ## Installation
+
+ ```bash
+ conda env create -f environment.yml
+ conda activate unirl
+ pip install git+https://github.com/openai/CLIP.git
+ pip install git+https://github.com/huggingface/diffusers.git
+ pip install flash-attn==2.7.4.post1 --no-build-isolation
+
+ # run gen.sh for evaluation
+ # bash gen.sh
+ ```
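After installation, a quick sanity check (illustrative, not part of the repository) can confirm that the pinned GPU stack imports cleanly before running `gen.sh`:

```python
# Post-install sanity check (illustrative): verify the CUDA-enabled PyTorch build
# and the key libraries installed above are importable.
import torch
import diffusers
import transformers

print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("diffusers", diffusers.__version__)
print("transformers", transformers.__version__)

try:
    import flash_attn  # installed with --no-build-isolation above
    print("flash-attn", flash_attn.__version__)
except ImportError:
    print("flash-attn missing; attention falls back to default kernels")
```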
+
+ ## Qualitative Results
+
+ ### Text-to-Image Generation
+ <p align="center">
+ <img src="assets/t2i_comparison.png" width="85%">
+ </p>
+
+ ### Instructional Image Editing
+ <p align="center">
+ <img src="assets/edit_comparison.png" width="75%">
+ </p>
+
+
+ ## Key Results
+
+ PromptRL achieves **2× sample efficiency** compared to flow-only RL while also producing an adaptive prompt-refinement agent that improves test-time performance.
+
+ ### Summary
+
+ | Benchmark | Metric | PromptRL w/ PE | Best Baseline |
+ |:---|:---|:---:|:---:|
+ | GenEval | Avg. Score ↑ | **0.97** | 0.92 (FlowGRPO) |
+ | Aesthetic | PickScore ↑ | **24.05** | 23.63 (DiffusionNFT) |
+ | Aesthetic | HPS ↑ | **32.03** | 31.79 (DiffusionNFT) |
+ | OCR | OCR-1k ↑ | **0.98** | 0.89 (FlowGRPO) |
+ | Image Editing | EditReward Avg. ↑ | **1.43** | 1.44 (ReasonEdit-Think) |
+
+ ---
+
+ <details>
+ <summary><b>📊 GenEval Benchmark (Full Results)</b></summary>
+
+ <br>
+
+ | Model | 1 Obj. | 2 Obj. | Cnt. | Clr. | Pos. | Attr. | Avg.↑ |
+ |:---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
+ | Show-o | 0.95 | 0.52 | 0.49 | 0.82 | 0.11 | 0.28 | 0.53 |
+ | Emu3-Gen | 0.98 | 0.71 | 0.34 | 0.81 | 0.17 | 0.21 | 0.54 |
+ | SD3 Medium | 0.98 | 0.74 | 0.63 | 0.67 | 0.34 | 0.36 | 0.62 |
+ | FLUX.1-dev | 0.98 | 0.81 | 0.74 | 0.79 | 0.22 | 0.45 | 0.66 |
+ | SD3.5 Large | 0.98 | 0.89 | 0.73 | 0.83 | 0.34 | 0.47 | 0.71 |
+ | JanusFlow | 0.97 | 0.59 | 0.45 | 0.83 | 0.53 | 0.42 | 0.63 |
+ | Janus-Pro-7B | 0.99 | 0.89 | 0.59 | 0.90 | 0.79 | 0.66 | 0.80 |
+ | HiDream | 1.00 | 0.98 | 0.79 | 0.91 | 0.60 | 0.72 | 0.83 |
+ | Seedream 3.0 | 0.99 | 0.96 | 0.91 | 0.93 | 0.47 | 0.80 | 0.84 |
+ | Qwen-Image | 0.99 | 0.92 | 0.89 | 0.88 | 0.76 | 0.77 | 0.87 |
+ | *RL-based* | | | | | | | |
+ | RePrompt | 0.98 | 0.87 | 0.77 | 0.85 | 0.62 | 0.49 | 0.76 |
+ | FlowGRPO | 1.00 | 0.99 | 0.91 | 0.89 | 0.95 | 0.80 | 0.92 |
+ | DiffusionNFT | 1.00 | 0.98 | 0.74 | 0.92 | 0.85 | 0.80 | 0.88 |
+ | PromptRL w/o PE | 1.00 | 0.96 | 0.95 | 0.95 | 0.93 | 0.85 | 0.94 |
+ | **PromptRL w/ PE** | **1.00** | **0.99** | **0.99** | **0.96** | **0.99** | **0.90** | **0.97** |
+
+ </details>
+
+ <details>
+ <summary><b>🎨 Aesthetic & OCR Metrics (Full Results)</b></summary>
+
+ <br>
+
+ | Model | P.S. | HPS | U.R. | OCR-1k | TMDB | OpenLib |
+ |:---|:---:|:---:|:---:|:---:|:---:|:---:|
+ | SD1.5 | 20.92 | 23.71 | 2.00 | 0.05 | 0.13 | 0.08 |
+ | SDXL | 22.14 | 26.67 | 2.78 | 0.13 | 0.20 | 0.09 |
+ | SD3 Medium | 22.38 | 28.56 | 3.09 | — | 0.44 | 0.33 |
+ | FLUX.1-schnell | 22.64 | 29.39 | 3.25 | 0.54 | 0.66 | 0.50 |
+ | FLUX.2-klein | 22.79 | 29.03 | 3.29 | 0.55 | 0.22 | 0.46 |
+ | Z-Image | 20.14 | 28.22 | 3.51 | 0.70 | 0.71 | 0.83 |
+ | Qwen-Image | 23.05 | 30.40 | 3.53 | 0.65 | 0.79 | 0.94 |
+ | Qwen-Image-2512 | 23.16 | 30.79 | 3.40 | 0.72 | 0.81 | 0.87 |
+ | *RL-based* | | | | | | |
+ | FlowGRPO | 23.33 | 29.80 | 3.33 | 0.89 | 0.83 | 0.73 |
+ | DiffusionNFT | 23.63 | 31.79 | 3.39 | 0.89 | 0.91 | 0.86 |
+ | PromptRL w/o PE | 24.01 | 31.79 | 3.38 | 0.97 | 0.92 | 0.95 |
+ | **PromptRL w/ PE** | **24.05** | **32.03** | **3.44** | **0.98** | **0.91** | **0.95** |
+
+ </details>
+
+ <details>
+ <summary><b>✏️ Image Editing - EditReward (Full Results)</b></summary>
+
+ <br>
+
+ | Model | Swap | Style | Add. | Attr. | Env. | Removal | Avg.↑ |
+ |:---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
+ | InstructPix2Pix | -0.24 | 0.91 | -0.45 | 0.45 | 0.48 | -0.80 | 0.02 |
+ | MagicBrush | -0.38 | 0.36 | -0.78 | -0.80 | 0.91 | -0.85 | -0.27 |
+ | LEDITS++ | -0.81 | -0.32 | -0.30 | -0.60 | -0.37 | -0.97 | -0.60 |
+ | Qwen-Image-Edit | 1.11 | 1.14 | 0.95 | 0.90 | 1.39 | 0.61 | 1.03 |
+ | FLUX.2-klein | 1.42 | 1.73 | 1.29 | 1.42 | 1.80 | 0.32 | 1.34 |
+ | Nano Banana | 1.58 | 1.20 | 1.28 | 1.18 | 1.61 | 1.13 | 1.37 |
+ | Step1X-Edit | 1.39 | 1.58 | 1.19 | 1.34 | 1.57 | 0.22 | 1.24 |
+ | ReasonEdit | 1.51 | 1.43 | 1.19 | 1.47 | 1.58 | 1.14 | 1.40 |
+ | ReasonEdit-Think | 1.52 | 1.47 | 1.19 | 1.44 | 1.69 | 1.27 | 1.44 |
+ | FLUX.1-Kontext | 1.35 | 1.36 | 1.16 | 1.15 | 1.44 | 0.55 | 1.19 |
+ | FLUX.1-Kontext w/ PE | 1.35 | 0.97 | 1.04 | 0.48 | 1.22 | 0.65 | 1.01 |
+ | PromptRL w/o PE | 1.45 | 1.46 | 1.28 | 1.35 | 1.56 | 0.98 | 1.36 |
+ | **PromptRL w/ PE** | **1.47** | **1.43** | **1.29** | **1.39** | **1.72** | **1.24** | **1.43** |
+
+ </details>
+
+
+
+ ## Citation
+
+ ```bibtex
+ @article{wang2025promptrl,
+   title={PromptRL: Prompt Matters in RL for Flow-Based Image Generation},
+   author={Wang, Fu-Yun and Zhang, Han and Gharbi, Michael and Li, Hongsheng and Park, Taesung},
+   journal={arXiv preprint arXiv:2602.01382},
+   year={2026}
+ }
+ ```
+
+ ```bibtex
+ @article{wang2025unirl,
+   title={UniRL-Zero: Reinforcement Learning on Unified Models with Joint Language Model and Diffusion Model Experts},
+   author={Wang, Fu-Yun and Zhang, Han and Gharbi, Michael and Li, Hongsheng and Park, Taesung},
+   journal={arXiv preprint arXiv:2510.17937},
+   year={2025}
+ }
+ ```
+
+ ## Acknowledgments
+
+ This codebase builds upon [UniRL-Zero](https://github.com/G-U-N/UniRL/tree/master).
assets/edit_comparison.png ADDED

Git LFS Details

  • SHA256: 7c9ca476030f9ea93db556f9157a3b94d113deaf7e18857a1712d33f9727f6ee
  • Pointer size: 132 Bytes
  • Size of remote file: 5.64 MB
assets/logo.png ADDED

Git LFS Details

  • SHA256: 9feacdd6b0ae47cc59b2bded2d663e66014fd8c642fbf945a9c0f819e675cfbe
  • Pointer size: 132 Bytes
  • Size of remote file: 1.48 MB
assets/t2i_comparison.png ADDED

Git LFS Details

  • SHA256: 23470ac01392176140c71274f9353dfb4a06c311c92c2c3d5b2c5b9064117e09
  • Pointer size: 132 Bytes
  • Size of remote file: 6.63 MB
environment.yml ADDED
@@ -0,0 +1,238 @@
1
+ name: unirl
2
+ channels:
3
+ - defaults
4
+ dependencies:
5
+ - _libgcc_mutex=0.1=main
6
+ - _openmp_mutex=5.1=1_gnu
7
+ - bzip2=1.0.8=h5eee18b_6
8
+ - ca-certificates=2025.2.25=h06a4308_0
9
+ - expat=2.7.1=h6a678d5_0
10
+ - ld_impl_linux-64=2.40=h12ee557_0
11
+ - libffi=3.4.4=h6a678d5_1
12
+ - libgcc-ng=11.2.0=h1234567_1
13
+ - libgomp=11.2.0=h1234567_1
14
+ - libstdcxx-ng=11.2.0=h1234567_1
15
+ - libuuid=1.41.5=h5eee18b_0
16
+ - libxcb=1.17.0=h9b100fa_0
17
+ - ncurses=6.4=h6a678d5_0
18
+ - openssl=3.0.16=h5eee18b_0
19
+ - pip=25.1=pyhc872135_2
20
+ - pthread-stubs=0.3=h0ce48e5_1
21
+ - python=3.11.13=h1a3bd86_0
22
+ - readline=8.2=h5eee18b_0
23
+ - setuptools=78.1.1=py311h06a4308_0
24
+ - sqlite=3.45.3=h5eee18b_0
25
+ - tk=8.6.14=h993c535_1
26
+ - wheel=0.45.1=py311h06a4308_0
27
+ - xorg-libx11=1.8.12=h9b100fa_1
28
+ - xorg-libxau=1.0.12=h9b100fa_0
29
+ - xorg-libxdmcp=1.1.5=h9b100fa_0
30
+ - xorg-xorgproto=2024.1=h5eee18b_1
31
+ - xz=5.6.4=h5eee18b_1
32
+ - zlib=1.2.13=h5eee18b_1
33
+ - pip:
34
+ - accelerate==1.7.0
35
+ - aiohappyeyeballs==2.6.1
36
+ - aiohttp==3.12.9
37
+ - aiosignal==1.3.2
38
+ - airportsdata==20250523
39
+ - annotated-types==0.7.0
40
+ - anthropic==0.54.0
41
+ - antlr4-python3-runtime==4.13.2
42
+ - anyio==4.9.0
43
+ - astor==0.8.1
44
+ - asttokens==3.0.0
45
+ - attrs==25.3.0
46
+ - av==14.4.0
47
+ - bitsandbytes==0.46.0
48
+ - blake3==1.0.5
49
+ - cachetools==6.0.0
50
+ - certifi==2025.4.26
51
+ - charset-normalizer==3.4.2
52
+ - click==8.2.1
53
+ # - clip==1.0
54
+ - cloudpickle==3.1.1
55
+ - compressed-tensors==0.9.4
56
+ - contourpy==1.3.2
57
+ - cupy-cuda12x==13.4.1
58
+ - cycler==0.12.1
59
+ - datasets==3.6.0
60
+ - decorator==5.2.1
61
+ - deepspeed==0.15.4
62
+ - depyf==0.18.0
63
+ # - diffusers==0.34.0.dev0
64
+ - dill==0.3.8
65
+ - diskcache==5.6.3
66
+ - distro==1.9.0
67
+ - dnspython==2.7.0
68
+ - docker-pycreds==0.4.0
69
+ - einops==0.8.1
70
+ - email-validator==2.2.0
71
+ - executing==2.2.0
72
+ - fastapi==0.115.12
73
+ - fastapi-cli==0.0.7
74
+ - fastrlock==0.8.3
75
+ - filelock==3.18.0
76
+ # - flash-attn==2.7.4.post1
77
+ - fonttools==4.58.4
78
+ - frozenlist==1.6.2
79
+ - fsspec==2025.3.0
80
+ - ftfy==6.3.1
81
+ - gguf==0.17.0
82
+ - gitdb==4.0.12
83
+ - gitpython==3.1.44
84
+ - googleapis-common-protos==1.70.0
85
+ - grpcio==1.72.1
86
+ - h11==0.16.0
87
+ - hf-transfer==0.1.9
88
+ - hf-xet==1.1.3
89
+ - hjson==3.1.0
90
+ - httpcore==1.0.9
91
+ - httptools==0.6.4
92
+ - httpx==0.28.1
93
+ - huggingface-hub==0.32.4
94
+ - idna==3.10
95
+ - importlib-metadata==8.7.0
96
+ - inquirerpy==0.3.4
97
+ - interegular==0.3.3
98
+ - ipython==9.3.0
99
+ - ipython-pygments-lexers==1.1.1
100
+ - jedi==0.19.2
101
+ - jinja2==3.1.6
102
+ - jiter==0.10.0
103
+ - jsonschema==4.24.0
104
+ - jsonschema-specifications==2025.4.1
105
+ - kiwisolver==1.4.8
106
+ - lark==1.2.2
107
+ - latex2sympy2-extended==1.10.1
108
+ - liger-kernel==0.5.2
109
+ - llguidance==0.7.29
110
+ - llvmlite==0.44.0
111
+ - lm-format-enforcer==0.10.11
112
+ - markdown-it-py==3.0.0
113
+ - markupsafe==3.0.2
114
+ - math-verify==0.7.0
115
+ - matplotlib==3.10.3
116
+ - matplotlib-inline==0.1.7
117
+ - mdurl==0.1.2
118
+ - mistral-common==1.5.6
119
+ - mpmath==1.3.0
120
+ - msgpack==1.1.0
121
+ - msgspec==0.19.0
122
+ - multidict==6.4.4
123
+ - multiprocess==0.70.16
124
+ - nest-asyncio==1.6.0
125
+ - networkx==3.5
126
+ - ninja==1.11.1.4
127
+ - numba==0.61.2
128
+ - numpy==2.2.6
129
+ - nvidia-cublas-cu12==12.6.4.1
130
+ - nvidia-cuda-cupti-cu12==12.6.80
131
+ - nvidia-cuda-nvrtc-cu12==12.6.77
132
+ - nvidia-cuda-runtime-cu12==12.6.77
133
+ - nvidia-cudnn-cu12==9.5.1.17
134
+ - nvidia-cufft-cu12==11.3.0.4
135
+ - nvidia-cufile-cu12==1.11.1.6
136
+ - nvidia-curand-cu12==10.3.7.77
137
+ - nvidia-cusolver-cu12==11.7.1.2
138
+ - nvidia-cusparse-cu12==12.5.4.2
139
+ - nvidia-cusparselt-cu12==0.6.3
140
+ - nvidia-nccl-cu12==2.26.2
141
+ - nvidia-nvjitlink-cu12==12.6.85
142
+ - nvidia-nvtx-cu12==12.6.77
143
+ - openai==1.84.0
144
+ - opencv-python-headless==4.11.0.86
145
+ - opentelemetry-api==1.34.0
146
+ - opentelemetry-exporter-otlp==1.34.0
147
+ - opentelemetry-exporter-otlp-proto-common==1.34.0
148
+ - opentelemetry-exporter-otlp-proto-grpc==1.34.0
149
+ - opentelemetry-exporter-otlp-proto-http==1.34.0
150
+ - opentelemetry-proto==1.34.0
151
+ - opentelemetry-sdk==1.34.0
152
+ - opentelemetry-semantic-conventions==0.55b0
153
+ - opentelemetry-semantic-conventions-ai==0.4.9
154
+ - outlines==0.1.11
155
+ - outlines-core==0.1.26
156
+ - packaging==25.0
157
+ - pandas==2.3.0
158
+ - parso==0.8.4
159
+ - partial-json-parser==0.2.1.1.post5
160
+ - peft==0.17.1
161
+ - pexpect==4.9.0
162
+ - pfzy==0.3.4
163
+ - pillow==11.2.1
164
+ - platformdirs==4.3.8
165
+ - prometheus-client==0.22.1
166
+ - prometheus-fastapi-instrumentator==7.1.0
167
+ - prompt-toolkit==3.0.51
168
+ - propcache==0.3.1
169
+ - protobuf==5.29.5
170
+ - psutil==7.0.0
171
+ - ptyprocess==0.7.0
172
+ - pure-eval==0.2.3
173
+ - py-cpuinfo==9.0.0
174
+ - pyarrow==20.0.0
175
+ - pycountry==24.6.1
176
+ - pydantic==2.11.5
177
+ - pydantic-core==2.33.2
178
+ - pygments==2.19.1
179
+ - pyparsing==3.2.3
180
+ - python-dateutil==2.9.0.post0
181
+ - python-dotenv==1.1.0
182
+ - python-json-logger==3.3.0
183
+ - python-multipart==0.0.20
184
+ - pytz==2025.2
185
+ - pyyaml==6.0.2
186
+ - pyzmq==26.4.0
187
+ - qwen-vl-utils==0.0.11
188
+ - ray==2.46.0
189
+ - referencing==0.36.2
190
+ - regex==2024.11.6
191
+ - requests==2.32.3
192
+ - rich==14.0.0
193
+ - rich-toolkit==0.14.7
194
+ - rpds-py==0.25.1
195
+ - safetensors==0.5.3
196
+ - scipy==1.15.3
197
+ - seaborn==0.13.2
198
+ - sentencepiece==0.2.0
199
+ - sentry-sdk==2.29.1
200
+ - setproctitle==1.3.6
201
+ - shellingham==1.5.4
202
+ - six==1.17.0
203
+ - smmap==5.0.2
204
+ - sniffio==1.3.1
205
+ - stack-data==0.6.3
206
+ - starlette==0.46.2
207
+ - sympy==1.14.0
208
+ - tabulate==0.9.0
209
+ - tiktoken==0.9.0
210
+ - timm==0.6.13
211
+ - tokenizers==0.21.1
212
+ - torch==2.7.0
213
+ - torchaudio==2.7.0
214
+ - torchvision==0.22.0
215
+ - tqdm==4.67.1
216
+ - traitlets==5.14.3
217
+ - transformers==4.51.3
218
+ - triton==3.3.0
219
+ - trl==0.19.0
220
+ - typer==0.16.0
221
+ - typing-extensions==4.14.0
222
+ - typing-inspection==0.4.1
223
+ - tzdata==2025.2
224
+ - urllib3==2.4.0
225
+ - utils==1.0.2
226
+ - uvicorn==0.34.3
227
+ - uvloop==0.21.0
228
+ - vllm==0.9.0.1
229
+ - wandb==0.18.3
230
+ - watchfiles==1.0.5
231
+ - wcwidth==0.2.13
232
+ - websockets==15.0.1
233
+ - xformers==0.0.30
234
+ - xgrammar==0.1.19
235
+ - xxhash==3.5.0
236
+ - yarl==1.20.0
237
+ - zipp==3.22.0
238
+ - tensorboardX==2.6.4
eval.py ADDED
@@ -0,0 +1,367 @@
1
+ #!/usr/bin/env python3
2
+ """Batch image evaluation tool with YAML configuration."""
3
+
4
+ import requests
5
+ import pickle
6
+ from PIL import Image
7
+ from typing import List, Dict, Any, Union, Optional, Tuple
8
+ import sys
9
+ import os
10
+ import json
11
+ import yaml
12
+ from io import BytesIO
13
+ from tqdm import tqdm
14
+ from datetime import datetime
15
+
16
+
17
+ PAIR_SCORERS = {"editreward"}
18
+ CAPTION_SUFFIXES = ["_caption.txt", "_prompt.txt"]
19
+
20
+
21
+ class RewardEvaluatorClient:
22
+ def __init__(self, scorer_urls: Dict[str, str]):
23
+ self.scorer_urls = scorer_urls
24
+
25
+ def evaluate(self,
26
+ model_name: str,
27
+ images: Union[List[Image.Image], Dict[str, List[Image.Image]]],
28
+ prompts: List[str],
29
+ metadata: Dict[str, Any] = None) -> Union[List[float], Dict[str, Any]]:
30
+ url = self.scorer_urls.get(model_name)
31
+ if not url:
32
+ raise ValueError(f"Reward model '{model_name}' URL not configured.")
33
+
34
+ payload_bytes = create_payload(images, prompts, metadata)
35
+
36
+ try:
37
+ response = requests.post(url, data=payload_bytes, timeout=600)
38
+ response.raise_for_status()
39
+ result = parse_response(response.content)
40
+
41
+ if isinstance(result, dict) and "error" in result:
42
+ raise RuntimeError(f"Scorer '{model_name}' returned error: {result['error']}")
43
+
44
+ return result
45
+
46
+ except requests.exceptions.RequestException as e:
47
+ raise RuntimeError(f"HTTP request to '{model_name}' failed: {e}")
48
+ except Exception as e:
49
+ raise RuntimeError(f"Failed to process response from '{model_name}': {e}")
50
+
51
+
52
+ def serialize_images(images: List[Image.Image]) -> List[bytes]:
53
+ images_bytes = []
54
+ for img in images:
55
+ img_byte_arr = BytesIO()
56
+ if img.mode != 'RGB':
57
+ img = img.convert('RGB')
58
+ img.save(img_byte_arr, format="JPEG")
59
+ images_bytes.append(img_byte_arr.getvalue())
60
+ return images_bytes
61
+
62
+
63
+ def create_payload(images: Union[List[Image.Image], Dict[str, List[Image.Image]]],
64
+ prompts: List[str],
65
+ metadata: Dict[str, Any] = None) -> bytes:
66
+ if isinstance(images, dict):
67
+ serialized_images = {key: serialize_images(value) for key, value in images.items()}
68
+ else:
69
+ serialized_images = serialize_images(images)
70
+
71
+ return pickle.dumps({
72
+ "images": serialized_images,
73
+ "prompts": prompts,
74
+ "metadata": metadata or {}
75
+ })
76
+
77
+
78
+ def parse_response(response_content: bytes) -> Union[List[float], Dict[str, Any]]:
79
+ return pickle.loads(response_content)
80
+
81
+
82
+ def find_caption_file(base_path: str, base_name: str) -> Optional[str]:
83
+ for suffix in CAPTION_SUFFIXES:
84
+ caption_path = os.path.join(base_path, f"{base_name}{suffix}")
85
+ if os.path.exists(caption_path):
86
+ return caption_path
87
+ return None
88
+
89
+
90
+ def collect_standard_samples(folder_path: str) -> Tuple[List[Image.Image], List[str], List[str]]:
91
+ images, prompts, filenames = [], [], []
92
+
93
+ for file in sorted(os.listdir(folder_path)):
94
+ if not file.lower().endswith(('.png', '.jpg', '.jpeg')):
95
+ continue
96
+ if any(suffix in file for suffix in ['_edited', '_reference', '_source']):
97
+ continue
98
+
99
+ base_name = os.path.splitext(file)[0]
100
+ img_path = os.path.join(folder_path, file)
101
+ caption_path = find_caption_file(folder_path, base_name)
102
+
103
+ if not caption_path:
104
+ continue
105
+
106
+ try:
107
+ img = Image.open(img_path)
108
+ with open(caption_path, 'r', encoding='utf-8') as f:
109
+ prompt = f.read().strip()
110
+ images.append(img)
111
+ prompts.append(prompt)
112
+ filenames.append(file)
113
+ except Exception as e:
114
+ print(f" Warning: Failed to process {file}: {e}")
115
+
116
+ return images, prompts, filenames
117
+
118
+
119
+ def collect_edit_samples(folder_path: str) -> Tuple[Dict[str, List[Image.Image]], List[str], List[str]]:
120
+ source_images, edited_images, prompts, filenames = [], [], [], []
121
+
122
+ edited_files = [f for f in os.listdir(folder_path) if f.endswith('_edited.png')]
123
+
124
+ for edited_file in sorted(edited_files):
125
+ base_name = edited_file.replace('_edited.png', '')
126
+ source_file = f"{base_name}_reference.png"
127
+
128
+ if not os.path.exists(os.path.join(folder_path, source_file)):
129
+ source_file = f"{base_name}_source.png"
130
+
131
+ source_path = os.path.join(folder_path, source_file)
132
+ edited_path = os.path.join(folder_path, edited_file)
133
+ caption_path = find_caption_file(folder_path, base_name)
134
+
135
+ if not os.path.exists(source_path) or not caption_path:
136
+ continue
137
+
138
+ try:
139
+ source_img = Image.open(source_path)
140
+ edited_img = Image.open(edited_path)
141
+ with open(caption_path, 'r', encoding='utf-8') as f:
142
+ prompt = f.read().strip()
143
+
144
+ source_images.append(source_img)
145
+ edited_images.append(edited_img)
146
+ prompts.append(prompt)
147
+ filenames.append(base_name)
148
+ except Exception as e:
149
+ print(f" Warning: Failed to process {base_name}: {e}")
150
+
151
+ return {'source': source_images, 'edited': edited_images}, prompts, filenames
152
+
153
+
154
+ def evaluate_folder(folder_path: str,
155
+ model_name: str,
156
+ batch_size: int,
157
+ scorer_urls: Dict[str, str],
158
+ verbose: bool = True) -> Optional[Dict[str, Any]]:
159
+ if not os.path.isdir(folder_path):
160
+ return None
161
+
162
+ evaluator = RewardEvaluatorClient(scorer_urls)
163
+ is_pair_scorer = model_name in PAIR_SCORERS
164
+
165
+ if is_pair_scorer:
166
+ images, prompts, filenames = collect_edit_samples(folder_path)
167
+ sample_count = len(prompts)
168
+ else:
169
+ images, prompts, filenames = collect_standard_samples(folder_path)
170
+ sample_count = len(images)
171
+
172
+ if sample_count == 0:
173
+ if verbose:
174
+ print(f" Skipped (no valid samples): {folder_path}")
175
+ return None
176
+
177
+ if verbose:
178
+ print(f" Evaluating {sample_count} samples: {folder_path}")
179
+
180
+ all_scores = []
181
+
182
+ if is_pair_scorer:
183
+ source_images = images['source']
184
+ edited_images = images['edited']
185
+
186
+ for start_idx in tqdm(range(0, sample_count, batch_size), disable=not verbose):
187
+ end_idx = min(start_idx + batch_size, sample_count)
188
+ batch_images = {
189
+ 'source': source_images[start_idx:end_idx],
190
+ 'edited': edited_images[start_idx:end_idx]
191
+ }
192
+ batch_prompts = prompts[start_idx:end_idx]
193
+
194
+ try:
195
+ batch_results = evaluator.evaluate(model_name, batch_images, batch_prompts)
196
+ scores = batch_results.get('scores', batch_results) if isinstance(batch_results, dict) else batch_results
197
+ all_scores.extend(scores)
198
+ except Exception as e:
199
+ print(f" Batch evaluation failed [{start_idx}:{end_idx}]: {e}")
200
+ return None
201
+ else:
202
+ for start_idx in tqdm(range(0, sample_count, batch_size), disable=not verbose):
203
+ end_idx = min(start_idx + batch_size, sample_count)
204
+ batch_images = images[start_idx:end_idx]
205
+ batch_prompts = prompts[start_idx:end_idx]
206
+
207
+ try:
208
+ batch_results = evaluator.evaluate(model_name, batch_images, batch_prompts)
209
+ scores = batch_results.get('scores', batch_results) if isinstance(batch_results, dict) else batch_results
210
+ all_scores.extend(scores)
211
+ except Exception as e:
212
+ print(f" Batch evaluation failed [{start_idx}:{end_idx}]: {e}")
213
+ continue
214
+
215
+ if not all_scores:
216
+ return None
217
+
218
+ return {
219
+ 'folder': folder_path,
220
+ 'model': model_name,
221
+ 'average': sum(all_scores) / len(all_scores),
222
+ 'scores': all_scores,
223
+ 'count': len(all_scores)
224
+ }
225
+
226
+
227
+ def find_leaf_folders(root_path: str, min_depth: int = 0, max_depth: int = -1) -> List[str]:
228
+ result = []
229
+ root_path = os.path.abspath(root_path)
230
+
231
+ def has_images(folder: str) -> bool:
232
+ for f in os.listdir(folder):
233
+ if f.lower().endswith(('.png', '.jpg', '.jpeg')):
234
+ return True
235
+ return False
236
+
237
+ def recurse(current_path: str, depth: int):
238
+ if max_depth >= 0 and depth > max_depth:
239
+ return
240
+
241
+ try:
242
+ entries = os.listdir(current_path)
243
+ except PermissionError:
244
+ return
245
+
246
+ subdirs = [e for e in entries if os.path.isdir(os.path.join(current_path, e))]
247
+
248
+ if not subdirs or (max_depth >= 0 and depth == max_depth):
249
+ if depth >= min_depth and has_images(current_path):
250
+ result.append(current_path)
251
+ else:
252
+ for subdir in subdirs:
253
+ recurse(os.path.join(current_path, subdir), depth + 1)
254
+ if depth >= min_depth and has_images(current_path):
255
+ result.append(current_path)
256
+
257
+ recurse(root_path, 0)
258
+ return sorted(result)
259
+
260
+
261
+ def run(config: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
262
+ scorer_urls = config['scorer_urls']
263
+ defaults = config.get('defaults', {})
264
+ evaluations = config['evaluations']
265
+ output_file = config.get('output')
266
+ verbose = config.get('verbose', True)
267
+
268
+ default_batch_size = defaults.get('batch_size', 64)
269
+ default_recursive = defaults.get('recursive', False)
270
+ default_min_depth = defaults.get('min_depth', 0)
271
+ default_max_depth = defaults.get('max_depth', -1)
272
+
273
+ all_results = {}
274
+
275
+ for eval_item in evaluations:
276
+ path = eval_item.get('path')
277
+ if not path:
278
+ print("Warning: Evaluation item missing 'path', skipping")
279
+ continue
280
+
281
+ models = eval_item.get('models', [])
282
+ if not models:
283
+ print(f"Warning: No models specified for {path}, skipping")
284
+ continue
285
+
286
+ batch_size = eval_item.get('batch_size', default_batch_size)
287
+ recursive = eval_item.get('recursive', default_recursive)
288
+ min_depth = eval_item.get('min_depth', default_min_depth)
289
+ max_depth = eval_item.get('max_depth', default_max_depth)
290
+
291
+ if not recursive:
292
+ max_depth = 0
293
+
294
+ folders = find_leaf_folders(path, min_depth, max_depth)
295
+
296
+ if not folders:
297
+ print(f"No image folders found in: {path}")
298
+ continue
299
+
300
+ print(f"\nProcessing {len(folders)} folder(s) from: {path}")
301
+ print(f"Models: {', '.join(models)}")
302
+ print("-" * 60)
303
+
304
+ for folder in tqdm(folders, desc="Folders", disable=not verbose):
305
+ folder_results = {}
306
+
307
+ for model in models:
308
+ if verbose:
309
+ print(f"\n[{model}] ", end="")
310
+
311
+ result = evaluate_folder(folder, model, batch_size, scorer_urls, verbose)
312
+
313
+ if result:
314
+ folder_results[model] = result
315
+ if verbose:
316
+ print(f" -> Average: {result['average']:.4f} (n={result['count']})")
317
+
318
+ if folder_results:
319
+ rel_path = os.path.relpath(folder, path)
320
+ key = f"{path}:{rel_path}" if rel_path != "." else path
321
+ all_results[key] = folder_results
322
+
323
+ # Print summary
324
+ print("\n" + "=" * 60)
325
+ print("Evaluation Summary")
326
+ print("=" * 60)
327
+ for folder, results in all_results.items():
328
+ print(f"\n{folder}")
329
+ for model, data in results.items():
330
+ print(f" [{model}] avg={data['average']:.4f}, n={data['count']}")
331
+
332
+ # Save results
333
+ if output_file:
334
+ serializable = {
335
+ folder: {
336
+ model: {'average': data['average'], 'count': data['count']}
337
+ for model, data in results.items()
338
+ }
339
+ for folder, results in all_results.items()
340
+ }
341
+
342
+ with open(output_file, 'w', encoding='utf-8') as f:
343
+ json.dump({
344
+ 'timestamp': datetime.now().isoformat(),
345
+ 'results': serializable
346
+ }, f, indent=2, ensure_ascii=False)
347
+
348
+ print(f"\nResults saved to: {output_file}")
349
+
350
+ return all_results
351
+
352
+
353
+ def main():
354
+ if len(sys.argv) != 2:
355
+ print(f"Usage: python {sys.argv[0]} <config.yaml>")
356
+ sys.exit(1)
357
+
358
+ config_path = sys.argv[1]
359
+ with open(config_path, 'r', encoding='utf-8') as f:
360
+ config = yaml.safe_load(f)
361
+
362
+ results = run(config)
363
+ sys.exit(0 if results else 1)
364
+
365
+
366
+ if __name__ == "__main__":
367
+ main()
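`eval.py` talks to each scorer URL over a simple pickle-over-HTTP protocol: `create_payload` POSTs a pickled dict with JPEG-encoded `images` (a flat list, or `source`/`edited` lists for pair scorers such as `editreward`), `prompts`, and `metadata`, and expects back a pickled list of per-sample scores or a dict with a `scores` key. The scorer servers themselves are not part of this commit; the sketch below is a minimal compatible endpoint with a placeholder constant reward and an arbitrary port.

```python
# Minimal scorer endpoint compatible with eval.py's RewardEvaluatorClient
# (illustrative sketch; the real scorer servers are not included in this commit).
import pickle
from http.server import BaseHTTPRequestHandler, HTTPServer
from io import BytesIO
from PIL import Image

class ScorerHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        length = int(self.headers.get("Content-Length", 0))
        payload = pickle.loads(self.rfile.read(length))  # {"images", "prompts", "metadata"}
        images = [Image.open(BytesIO(b)) for b in payload["images"]]  # flat-list form
        prompts = payload["prompts"]
        scores = [0.0 for _ in zip(images, prompts)]  # placeholder: plug a real reward model in here
        body = pickle.dumps({"scores": scores})
        self.send_response(200)
        self.send_header("Content-Type", "application/octet-stream")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

if __name__ == "__main__":
    HTTPServer(("0.0.0.0", 18080), ScorerHandler).serve_forever()  # port matches the sample config
```

A pair scorer would instead read `payload["images"]["source"]` and `payload["images"]["edited"]`.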
gen.sh ADDED
@@ -0,0 +1,48 @@
+ #!/bin/bash
+ set -e
+
+ # Download eval datasets if not present
+ EDIT_DATA="data/omni_edit_dev.parquet"
+ if [ ! -f "$EDIT_DATA" ]; then
+     echo "Downloading edit eval dataset..."
+     mkdir -p data
+     huggingface-cli download wangfuyun/PrompRL data/omni_edit_dev.parquet \
+         --repo-type model --local-dir . --local-dir-use-symlinks False
+ fi
+
+ # Text-to-Image OCR
+ python unified_inference.py --mode t2i \
+     --model_path wangfuyun/PrompRL/promptrl_ocr \
+     --model_type flux \
+     --prompt_file prompts/ocr_test.txt \
+     --output_dir outputs/ocr \
+     --use_cot --cot_template ocr_clarity_v2
+
+ # Text-to-Image PS
+ python unified_inference.py --mode t2i \
+     --model_path wangfuyun/PrompRL/promptrl_ps \
+     --model_type flux \
+     --prompt_file prompts/draw_test.txt \
+     --output_dir outputs/pickscore \
+     --use_cot --cot_template quality_purev2
+
+ # GenEval
+ python unified_inference.py --mode geneval \
+     --model_path wangfuyun/PrompRL/promptrl_geneval \
+     --model_type flux \
+     --metadata_file prompts/evaluation_metadata.jsonl \
+     --output_dir outputs/geneval \
+     --use_cot --cot_template geneval \
+     --n_samples 4
+
+ # Image Editing
+ python unified_inference.py --mode edit \
+     --model_path wangfuyun/PrompRL/promptrl_edit \
+     --model_type kontext \
+     --data_file "$EDIT_DATA" \
+     --output_dir outputs/edit \
+     --use_cot --cot_template edit_general \
+     --guidance_scale 2.5
+
+
+ # python eval.py prompts/config.yaml
prompts/config.yaml ADDED
@@ -0,0 +1,35 @@
+ # Batch Image Evaluator Configuration
+
+ scorer_urls:
+   aesthetic: "http://YOUR_SERVER_IP:18080/"
+   image_reward: "http://YOUR_SERVER_IP:18081/"
+   ocr: "http://YOUR_SERVER_IP:18082/"
+   pickscore: "http://YOUR_SERVER_IP:18083/"
+   deqa: "http://YOUR_SERVER_IP:18084/"
+   gen_eval: "http://YOUR_SERVER_IP:18085/"
+   unifiedreward_sglang: "http://YOUR_SERVER_IP:18086/"
+   hps: "http://YOUR_SERVER_IP:18087/"
+   editreward: "http://YOUR_SERVER_IP:18088/"
+
+ defaults:
+   batch_size: 64
+   recursive: false
+   min_depth: 0
+   max_depth: -1
+
+ output: results.json
+ verbose: true
+
+ evaluations:
+   - path: ./outputs/ocr
+     models: [ocr]
+     batch_size: 32
+     recursive: true
+
+   - path: ./outputs/edit
+     models: [editreward]
+     batch_size: 32
+
+   - path: ./outputs/pickscore
+     models: [pickscore]
+     batch_size: 32
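With the scorer endpoints filled in, evaluation can be driven from the CLI (`python eval.py prompts/config.yaml`, as noted at the end of `gen.sh`) or programmatically; a minimal sketch, assuming it runs from the repository root:

```python
# Minimal sketch: drive eval.py programmatically with the config above
# (equivalent to `python eval.py prompts/config.yaml`).
import yaml
from eval import run

with open("prompts/config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

results = run(config)  # {folder: {model: {"average", "scores", "count", ...}}}
for folder, per_model in results.items():
    for model, data in per_model.items():
        print(f"{folder} [{model}] avg={data['average']:.4f} (n={data['count']})")
```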
prompts/draw_test.txt ADDED
@@ -0,0 +1,1000 @@
1
+ New York Skyline with 'Google Research Pizza Cafe' written with fireworks on the sky.
2
+ A maglev train going vertically downward in high speed, New York Times photojournalism.
3
+ A pyramid made of falafel with a partial solar eclipse in the background.
4
+ A storefront with 'Google Brain Toronto' written on it.
5
+ An elephant under the sea.
6
+ Lego Arnold Schwarzenegger.
7
+ A keyboard made of water, the water is made of light, the light is turned off.
8
+ Artophagous.
9
+ One cat and one dog sitting on the grass.
10
+ A laptop on top of a teddy bear.
11
+ A red colored car.
12
+ A stack of 3 books. A green book is on the top, sitting on a red book. The red book is in the middle, sitting on a blue book. The blue book is on the bottom.
13
+ A green colored banana.
14
+ Matutinal.
15
+ A green cup and a blue cell phone.
16
+ A stack of 3 plates. A blue plate is on the top, sitting on a blue plate. The blue plate is in the middle, sitting on a green plate. The green plate is on the bottom.
17
+ A large thick-skinned semiaquatic African mammal, with massive jaws and large tusks.
18
+ A red colored banana.
19
+ Jentacular.
20
+ A sign that says 'Hello World'.
21
+ A blue cup and a green cell phone.
22
+ A black colored banana.
23
+ Two cats and two dogs sitting on the grass.
24
+ A ldarge keybord msical instroument lwith a woden case enmclosig a qsouvnkboajrd and mfgtal strivgf, which are strucrk b hammrs when the nels are depresdsmed.f lhe strsingsj' vibration ie stopped by damperds when the keys re released and can bce regulavewdd for lengh and vnolume y two or three pedalvs.
25
+ A magnifying glass over a page of a 1950s batman comic.
26
+ A separate seat for one person, typically with a back and four legs.
27
+ Two dogs on the street.
28
+ New York Skyline with 'Diffusion' written with fireworks on the sky.
29
+ A black colored banana.
30
+ An ancient Egyptian painting depicting an argument over whose turn it is to take out the trash.
31
+ A wine glass on top of a dog.
32
+ An emoji of a baby panda wearing a red hat, green gloves, red shirt, and green pants.
33
+ A pear cut into seven pieces arranged in a ring.
34
+ A large thick-skinned semiaquatic African mammal, with massive jaws and large tusks.
35
+ A baby fennec sneezing onto a strawberry, detailed, macro, studio light, droplets, backlit ears.
36
+ A panda making latte art.
37
+ An IT-guy trying to fix hardware of a PC tower is being tangled by the PC cables like Laokoon. Marble, copy after Hellenistic original from ca. 200 BC. Found in the Baths of Trajan, 1506.
38
+ A blue bird and a brown bear.
39
+ A triangular purple flower pot. A purple flower pot in the shape of a triangle.
40
+ A green apple and a black backpack.
41
+ A grocery store refrigerator has pint cartons of milk on the top shelf, quart cartons on the middle shelf, and gallon plastic jugs on the bottom shelf.
42
+ A small domesticated carnivorous mammal with soft fur, a short snout, and retractable claws. It is widely kept as a pet or for catching mice, and many breeds have been developed.
43
+ An orange colored sandwich.
44
+ A large motor vehicle carrying passengers by road, typically one serving the public on a fixed route and for a fare.
45
+ A sphere made of kitchen tile. A sphere with the texture of kitchen tile.
46
+ A cat on the right of a tennis racket.
47
+ Bzaseball galove.
48
+ A sign that says 'NeurIPS'.
49
+ A 1960s yearbook photo with animals dressed as humans.
50
+ New York Skyline with 'Hello World' written with fireworks on the sky.
51
+ Hovering cow abducting aliens.
52
+ A small vessel propelled on water by oars, sails, or an engine.
53
+ A type of digital currency in which a record of transactions is maintained and new units of currency are generated by the computational solution of mathematical problems, and which operates independently of a central bank.
54
+ A pink colored car.
55
+ A storefront with 'NeurIPS' written on it.
56
+ A black apple and a green backpack.
57
+ A large motor vehicle carrying passengers by road, typically one serving the public on a fixed route and for a fare.
58
+ A long curved fruit which grows in clusters and has soft pulpy flesh and yellow skin when ripe.
59
+ A black colored car.
60
+ A realistic photo of a Pomeranian dressed up like a 1980s professional wrestler with neon green and neon orange face paint and bright green wrestling tights with bright orange boots.
61
+ Tcennis rpacket.
62
+ McDonalds Church.
63
+ Painting of Mona Lisa but the view is from behind of Mona Lisa.
64
+ An elephant is behind a tree. You can see the trunk on one side and the back legs on the other.
65
+ Hovering cow abducting aliens.
66
+ Photo of a mega Lego space station inside a kid's bedroom.
67
+ An elephant under the sea.
68
+ One cat and two dogs sitting on the grass.
69
+ A green colored banana.
70
+ An American multinational technology company that focuses on artificial intelligence, search engine, online advertising, cloud computing, computer software, quantum computing, e-commerce, and consumer electronics.
71
+ A domesticated carnivvorous mzammal that typicbally hfaas a lons sfnout, an acxujte sense off osmell, noneetractaaln crlaws, anid xbarkring,y howlingu, or whining rvoiche.
72
+ Jentacular.
73
+ A wine glass on top of a dog.
74
+ A carrot on the left of a broccoli.
75
+ Pafrking metr.
76
+ Three cars on the street.
77
+ In late afternoon in January in New England, a man stands in the shadow of a maple tree.
78
+ An oil painting portrait of the regal Burger King posing with a Whopper.
79
+ A sign that says 'Text to Image'.
80
+ A small vessel propelled on water by oars, sails, or an engine.
81
+ A single clock is sitting on a table.
82
+ A stack of 3 plates. A blue plate is on the top, sitting on a blue plate. The blue plate is in the middle, sitting on a green plate. The green plate is on the bottom.
83
+ An elephant under the sea.
84
+ A type of digital currency in which a record of transactions is maintained and new units of currency are generated by the computational solution of mathematical problems, and which operates independently of a central bank.
85
+ A yellow colored giraffe.
86
+ An elephant is behind a tree. You can see the trunk on one side and the back legs on the other.
87
+ A device consisting of a circular canopy of cloth on a folding metal frame supported by a central rod, used as protection against rain or sometimes sun.
88
+ A fluffy baby sloth with a knitted hat trying to figure out a laptop, close up, highly detailed, studio lighting, screen reflecting in its eyes.
89
+ A stack of 3 plates. A blue plate is on the top, sitting on a blue plate. The blue plate is in the middle, sitting on a green plate. The green plate is on the bottom.
90
+ Three cats and three dogs sitting on the grass.
91
+ A large keyboard musical instrument with a wooden case enclosing a soundboard and metal strings, which are struck by hammers when the keys are depressed. The strings' vibration is stopped by dampers when the keys are released and can be regulated for length and volume by two or three pedals.
92
+ A blue coloured pizza.
93
+ A storefront with 'Google Research Pizza Cafe' written on it.
94
+ A sjmall domesticated carnivorious mammnal with sof fuh,y a sthort sout, and retracwtablbe flaws. It iw widexly kept as a pet or for catchitng mic, ad many breeds zhlyde beefn develvoked.
95
+ A green apple and a black backpack.
96
+ A pink colored car.
97
+ A pear cut into seven pieces arranged in a ring.
98
+ A screenshot of an iOS app for ordering different types of milk.
99
+ Rbefraigerator.
100
+ A blue colored dog.
101
+ Two cats and two dogs sitting on the grass.
102
+ A real life photography of super mario, 8k Ultra HD.
103
+ New York Skyline with 'Hello World' written with fireworks on the sky.
104
+ A realistic photo of a Pomeranian dressed up like a 1980s professional wrestler with neon green and neon orange face paint and bright green wrestling tights with bright orange boots.
105
+ A panda making latte art.
106
+ A storefront with 'NeurIPS' written on it.
107
+ A large keyboard musical instrument with a wooden case enclosing a soundboard and metal strings, which are struck by hammers when the keys are depressed. The strings' vibration is stopped by dampers when the keys are released and can be regulated for length and volume by two or three pedals.
108
+ A blue colored dog.
109
+ Three cats and two dogs sitting on the grass.
110
+ New York Skyline with 'Google Brain Toronto' written with fireworks on the sky.
111
+ A blue coloured pizza.
112
+ A panda making latte art.
113
+ An American multinational technology company that focuses on artificial intelligence, search engine, online advertising, cloud computing, computer software, quantum computing, e-commerce, and consumer electronics.
114
+ Backlotter.
115
+ A black colored sandwich.
116
+ A large thick-skinned semiaquatic African mammal, with massive jaws and large tusks.
117
+ A domesticated carnivorous mammal that typically has a long snout, an acute sense of smell, nonretractable claws, and a barking, howling, or whining voice.
118
+ New York Skyline with 'Deep Learning' written with fireworks on the sky.
119
+ A black colored dog.
120
+ A stack of 3 cubes. A red cube is on the top, sitting on a red cube. The red cube is in the middle, sitting on a green cube. The green cube is on the bottom.
121
+ A type of digital currency in which a record of transactions is maintained and new units of currency are generated by the computational solution of mathematical problems, and which operates independently of a central bank.
122
+ Five cars on the street.
123
+ An old photograph of a 1920s airship shaped like a pig, floating over a wheat field.
124
+ Illustration of a mouse using a mushroom as an umbrella.
125
+ Three cats and one dog sitting on the grass.
126
+ Four cars on the street.
127
+ A black colored sandwich.
128
+ Five cars on the street.
129
+ An American multinational technology company that focuses on artificial intelligence, search engine, online advertising, cloud computing, computer software, quantum computing, e-commerce, and consumer electronics.
130
+ A sign that says 'Google Brain Toronto'.
131
+ A storefront with 'Text to Image' written on it.
132
+ A magnifying glass over a page of a 1950s batman comic.
133
+ A sphere made of kitchen tile. A sphere with the texture of kitchen tile.
134
+ An IT-guy trying to fix hardware of a PC tower is being tangled by the PC cables like Laokoon. Marble, copy after Hellenistic original from ca. 200 BC. Found in the Baths of Trajan, 1506.
135
+ A sign that says 'Diffusion'.
136
+ A blue bird and a brown bear.
137
+ A photo of a confused grizzly bear in calculus class.
138
+ A grocery store refrigerator has pint cartons of milk on the top shelf, quart cartons on the middle shelf, and gallon plastic jugs on the bottom shelf.
139
+ A hair drier underneath a sheep.
140
+ Pafrking metr.
141
+ Peristeronic.
142
+ Two cats and one dog sitting on the grass.
143
+ New York Skyline with 'Google Research Pizza Cafe' written with fireworks on the sky.
144
+ A side view of an owl sitting in a field.
145
+ A pink colored car.
146
+ Paying for a quarter-sized pizza with a pizza-sized quarter.
147
+ Dininrg tablez.
148
+ A fish eating a pelican.
149
+ One cat and three dogs sitting on the grass.
150
+ An instrument used for cutting cloth, paper, and other thin material, consisting of two blades laid one on top of the other and fastened in the middle so as to allow them to be opened and closed by a thumb and finger inserted through rings on the end of their handles.
151
+ A side view of an owl sitting in a field.
152
+ A large thick-skinned semiaquatic African mammal, with massive jaws and large tusks.
153
+ A large keyboard musical instrument with a wooden case enclosing a soundboard and metal strings, which are struck by hammers when the keys are depressed. The strings' vibration is stopped by dampers when the keys are released and can be regulated for length and volume by two or three pedals.
154
+ Pafrking metr.
155
+ A sign that says 'Deep Learning'.
156
+ A collection of nail is sitting on a table.
157
+ One car on the street.
158
+ An emoji of a baby panda wearing a red hat, blue gloves, green shirt, and blue pants.
159
+ A brown bird and a blue bear.
160
+ A donkey and an octopus are playing a game. The donkey is holding a rope on one end, the octopus is holding onto the other. The donkey holds the rope in its mouth. A cat is jumping over the rope.
161
+ A fisheye lens view of a turtle sitting in a forest.
162
+ A large motor vehicle carrying passengers by road, typically one serving the public on a fixed route and for a fare.
163
+ New York Skyline with 'Hello World' written with fireworks on the sky.
164
+ An emoji of a baby panda wearing a red hat, green gloves, red shirt, and green pants.
165
+ A black colored dog.
166
+ A ldarge keybord msical instroument lwith a woden case enmclosig a qsouvnkboajrd and mfgtal strivgf, which are strucrk b hammrs when the nels are depresdsmed.f lhe strsingsj' vibration ie stopped by damperds when the keys re released and can bce regulavewdd for lengh and vnolume y two or three pedalvs.
167
+ Artophagous.
168
+ A yellow book and a red vase.
169
+ A stack of 3 books. A green book is on the top, sitting on a red book. The red book is in the middle, sitting on a blue book. The blue book is on the bottom.
170
+ A pizza on the right of a suitcase.
171
+ A tiger in a lab coat with a 1980s Miami vibe, turning a well oiled science content machine, digital art.
172
+ A storefront with 'Hello World' written on it.
173
+ A tiger in a lab coat with a 1980s Miami vibe, turning a well oiled science content machine, digital art.
174
+ A storefront with 'Google Brain Toronto' written on it.
175
+ A 1960s poster warning against climate change.
176
+ An organ of soft nervous tissue contained in the skull of vertebrates, functioning as the coordinating center of sensation and intellectual and nervous activity.
177
+ A long curved fruit which grows in clusters and has soft pulpy flesh and yellow skin when ripe.
178
+ Supreme Court Justices play a baseball game with the FBI. The FBI is at bat, the justices are on the field.
179
+ A pyramid made of falafel with a partial solar eclipse in the background.
180
+ A single clock is sitting on a table.
181
+ New York Skyline with 'Google Research Pizza Cafe' written with fireworks on the sky.
182
+ A blue cup and a green cell phone.
183
+ An oil painting of a couple in formal evening wear going home get caught in a heavy downpour with no umbrellas.
184
+ Darth Vader playing with raccoon in Mars during sunset.
185
+ A red car and a white sheep.
186
+ An illustration of a large red elephant sitting on a small blue mouse.
187
+ An illustration of a small green elephant standing behind a large red mouse.
188
+ A domesticated carnivorous mammal that typically has a long snout, an acute sense of smell, nonretractable claws, and a barking, howling, or whining voice.
189
+ A medieval painting of the wifi not working.
190
+ An American multinational technology company that focuses on artificial intelligence, search engine, online advertising, cloud computing, computer software, quantum computing, e-commerce, and consumer electronics.
191
+ One cat and two dogs sitting on the grass.
192
+ An IT-guy trying to fix hardware of a PC tower is being tangled by the PC cables like Laokoon. Marble, copy after Hellenistic original from ca. 200 BC. Found in the Baths of Trajan, 1506.
193
+ A fluffy baby sloth with a knitted hat trying to figure out a laptop, close up, highly detailed, studio lighting, screen reflecting in its eyes.
194
+ Abraham Lincoln touches his toes while George Washington does chin-ups. Lincoln is barefoot. Washington is wearing boots.
195
+ An umbrella on top of a spoon.
196
+ Matutinal.
197
+ A pink colored giraffe.
198
+ An emoji of a baby panda wearing a red hat, green gloves, red shirt, and green pants.
199
+ Illustration of a mouse using a mushroom as an umbrella.
200
+ A brown bird and a blue bear.
201
+ A painting by Grant Wood of an astronaut couple, american gothic style.
202
+ A sign that says 'Diffusion'.
203
+ Five dogs on the street.
204
+ Four dogs on the street.
205
+ A cat on the left of a dog.
206
+ A zebra underneath a broccoli.
207
+ A banana on the left of an apple.
208
+ Two cats and three dogs sitting on the grass.
209
+ A yellow colored giraffe.
210
+ Three cats and one dog sitting on the grass.
211
+ A ldarge keybord msical instroument lwith a woden case enmclosig a qsouvnkboajrd and mfgtal strivgf, which are strucrk b hammrs when the nels are depresdsmed.f lhe strsingsj' vibration ie stopped by damperds when the keys re released and can bce regulavewdd for lengh and vnolume y two or three pedalvs.
212
+ Abraham Lincoln touches his toes while George Washington does chin-ups. Lincoln is barefoot. Washington is wearing boots.
213
+ A yellow book and a red vase.
214
+ A cat on the left of a dog.
215
+ A stop sign on the right of a refrigerator.
216
+ A shark in the desert.
217
+ Octothorpe.
218
+ A red colored car.
219
+ Four cars on the street.
220
+ A tiger in a lab coat with a 1980s Miami vibe, turning a well oiled science content machine, digital art.
221
+ Three cats and one dog sitting on the grass.
222
+ Paying for a quarter-sized pizza with a pizza-sized quarter.
223
+ A zebra to the right of a fire hydrant.
224
+ A stack of 3 cubes. A red cube is on the top, sitting on a red cube. The red cube is in the middle, sitting on a green cube. The green cube is on the bottom.
225
+ A 1960s poster warning against climate change.
226
+ A storefront with 'Google Research Pizza Cafe' written on it.
227
+ A laptop on top of a teddy bear.
228
+ A painting by Grant Wood of an astronaut couple, american gothic style.
229
+ New York Skyline with 'Deep Learning' written with fireworks on the sky.
230
+ A storefront with 'Diffusion' written on it.
231
+ A storefront with 'Text to Image' written on it.
232
+ A small blue book sitting on a large red book.
233
+ Colouring page of large cats climbing the eifel tower in a cyberpunk future.
234
+ An emoji of a baby panda wearing a red hat, blue gloves, green shirt, and blue pants.
235
+ A photo of a confused grizzly bear in calculus class.
236
+ Paying for a quarter-sized pizza with a pizza-sized quarter.
237
+ Painting of the orange cat Otto von Garfield, Count of Bismarck-Schönhausen, Duke of Lauenburg, Minister-President of Prussia. Depicted wearing a Prussian Pickelhaube and eating his favorite meal - lasagna.
238
+ A device consisting of a circular canopy of cloth on a folding metal frame supported by a central rod, used as protection against rain or sometimes sun.
239
+ Supreme Court Justices play a baseball game with the FBI. The FBI is at bat, the justices are on the field.
240
+ A triangular pink stop sign. A pink stop sign in the shape of a triangle.
241
+ Painting of the orange cat Otto von Garfield, Count of Bismarck-Schönhausen, Duke of Lauenburg, Minister-President of Prussia. Depicted wearing a Prussian Pickelhaube and eating his favorite meal - lasagna.
242
+ A train on top of a surfboard.
243
+ A stack of 3 cubes. A red cube is on the top, sitting on a red cube. The red cube is in the middle, sitting on a green cube. The green cube is on the bottom.
244
+ A sjmall domesticated carnivorious mammnal with sof fuh,y a sthort sout, and retracwtablbe flaws. It iw widexly kept as a pet or for catchitng mic, ad many breeds zhlyde beefn develvoked.
245
+ A laptop on top of a teddy bear.
246
+ A train on top of a surfboard.
247
+ A photocopy of a photograph of a painting of a sculpture of a giraffe.
248
+ A 1960s yearbook photo with animals dressed as humans.
249
+ A pink colored giraffe.
250
+ A maglev train going vertically downward in high speed, New York Times photojournalism.
251
+ A domesticated carnivvorous mzammal that typicbally hfaas a lons sfnout, an acxujte sense off osmell, noneetractaaln crlaws, anid xbarkring,y howlingu, or whining rvoiche.
252
+ A sign that says 'Google Research Pizza Cafe'.
253
+ Two cars on the street.
254
+ A tennis racket underneath a traffic light.
255
+ A cross-section view of a brain.
256
+ One cat and one dog sitting on the grass.
257
+ A horse riding an astronaut.
258
+ A car playing soccer, digital art.
259
+ A large plant-eating domesticated mammal with solid hoofs and a flowing mane and tail, used for riding, racing, and to carry and pull loads.
260
+ Three dogs on the street.
261
+ A separate seat for one person, typically with a back and four legs.
262
+ A couple of glasses are sitting on a table.
263
+ A couch on the left of a chair.
264
+ Two cars on the street.
265
+ A photocopy of a photograph of a painting of a sculpture of a giraffe.
266
+ A black apple and a green backpack.
267
+ A pyramid made of falafel with a partial solar eclipse in the background.
268
+ A brown colored giraffe.
269
+ One cat and one dog sitting on the grass.
270
+ A pizza cooking an oven.
271
+ A church with stained glass windows depicting a hamburger and french fries.
272
+ A connection point by which firefighters can tap into a water supply.
273
+ A sign that says 'Google Research Pizza Cafe'.
274
+ 35mm macro shot a kitten licking a baby duck, studio lighting.
275
+ New York Skyline with 'Text to Image' written with fireworks on the sky.
276
+ An oil painting portrait of the regal Burger King posing with a Whopper.
277
+ A storefront with 'Google Brain Toronto' written on it.
278
+ A bridge connecting Europe and North America on the Atlantic Ocean, bird's eye view.
279
+ One cat and three dogs sitting on the grass.
280
+ Octothorpe.
281
+ A connection point by which firefighters can tap into a water supply.
282
+ A donut underneath a toilet.
283
+ Colouring page of large cats climbing the eifel tower in a cyberpunk future.
284
+ A panda making latte art.
285
+ A machine next to a parking space in a street, into which the driver puts money so as to be authorized to park the vehicle for a particular length of time.
286
+ New York Skyline with 'Google Brain Toronto' written with fireworks on the sky.
287
+ A real life photography of super mario, 8k Ultra HD.
288
+ A cat on the right of a tennis racket.
289
+ A sign that says 'Diffusion'.
290
+ An illustration of a large red elephant sitting on a small blue mouse.
291
+ A collection of nail is sitting on a table.
292
+ An appliance or compartment which is artificially kept cool and used to store food and drink.
293
+ An oil painting portrait of the regal Burger King posing with a Whopper.
294
+ Abraham Lincoln touches his toes while George Washington does chin-ups. Lincoln is barefoot. Washington is wearing boots.
295
+ A black colored dog.
296
+ One cat and two dogs sitting on the grass.
297
+ A donkey and an octopus are playing a game. The donkey is holding a rope on one end, the octopus is holding onto the other. The donkey holds the rope in its mouth. A cat is jumping over the rope.
298
+ A pink colored giraffe.
299
+ A hair drier underneath a sheep.
300
+ A couch on the left of a chair.
301
+ A cube made of denim. A cube with the texture of denim.
302
+ Jentacular.
303
+ An old photograph of a 1920s airship shaped like a pig, floating over a wheat field.
304
+ Colouring page of large cats climbing the eifel tower in a cyberpunk future.
305
+ A collection of nail is sitting on a table.
306
+ One dog on the street.
307
+ A stack of 3 cubes. A red cube is on the top, sitting on a red cube. The red cube is in the middle, sitting on a green cube. The green cube is on the bottom.
308
+ Illustration of a mouse using a mushroom as an umbrella.
309
+ A zebra to the right of a fire hydrant.
310
+ Two dogs on the street.
311
+ Photo of an athlete cat explaining it's latest scandal at a press conference to journalists.
312
+ A domesticated carnivvorous mzammal that typicbally hfaas a lons sfnout, an acxujte sense off osmell, noneetractaaln crlaws, anid xbarkring,y howlingu, or whining rvoiche.
313
+ A vehicle composed of two wheels held in a frame one behind the other, propelled by pedals and steered with handlebars attached to the front wheel.
314
+ A sign that says 'NeurIPS'.
315
+ A church with stained glass windows depicting a hamburger and french fries.
316
+ A shark in the desert.
317
+ An emoji of a baby panda wearing a red hat, blue gloves, green shirt, and blue pants.
318
+ A machine next to a parking space in a street, into which the driver puts money so as to be authorized to park the vehicle for a particular length of time.
319
+ Artophagous.
320
+ A car on the left of a bus.
321
+ A storefront with 'Google Brain Toronto' written on it.
322
+ A cube made of denim. A cube with the texture of denim.
323
+ A red colored banana.
324
+ Two dogs on the street.
325
+ Five cars on the street.
326
+ A mechanical or electrical device for measuring time.
327
+ Acersecomicke.
328
+ An illustration of a large red elephant sitting on a small blue mouse.
329
+ A triangular pink stop sign. A pink stop sign in the shape of a triangle.
330
+ Peristeronic.
331
+ A keyboard made of water, the water is made of light, the light is turned off.
332
+ Greek statue of a man tripping over a cat.
333
+ Two cats and three dogs sitting on the grass.
334
+ New York Skyline with 'Google Brain Toronto' written with fireworks on the sky.
335
+ Rbefraigerator.
336
+ A storefront with 'Google Research Pizza Cafe' written on it.
337
+ Four cars on the street.
338
+ An oil painting portrait of the regal Burger King posing with a Whopper.
339
+ A fluffy baby sloth with a knitted hat trying to figure out a laptop, close up, highly detailed, studio lighting, screen reflecting in its eyes.
340
+ An oil painting of a couple in formal evening wear going home get caught in a heavy downpour with no umbrellas.
341
+ Painting of the orange cat Otto von Garfield, Count of Bismarck-Schönhausen, Duke of Lauenburg, Minister-President of Prussia. Depicted wearing a Prussian Pickelhaube and eating his favorite meal - lasagna.
342
+ A grocery store refrigerator has pint cartons of milk on the top shelf, quart cartons on the middle shelf, and gallon plastic jugs on the bottom shelf.
343
+ A real life photography of super mario, 8k Ultra HD.
344
+ A carrot on the left of a broccoli.
345
+ Darth Vader playing with raccoon in Mars during sunset.
346
+ Four dogs on the street.
347
+ Photo of a cat singing in a barbershop quartet.
348
+ A real life photography of super mario, 8k Ultra HD.
349
+ A triangular pink stop sign. A pink stop sign in the shape of a triangle.
350
+ A small blue book sitting on a large red book.
351
+ A green colored banana.
352
+ A bicycle on top of a boat.
353
+ A blue cup and a green cell phone.
354
+ A cat on the right of a tennis racket.
355
+ A stop sign on the right of a refrigerator.
356
+ A sign that says 'Diffusion'.
357
+ A blue coloured pizza.
358
+ A device consisting of a circular canopy of cloth on a folding metal frame supported by a central rod, used as protection against rain or sometimes sun.
359
+ A green cup and a blue cell phone.
360
+ Three cats and two dogs sitting on the grass.
361
+ A laptop on top of a teddy bear.
362
+ A medieval painting of the wifi not working.
363
+ A small vessel propelled on water by oars, sails, or an engine.
364
+ Photo of a mega Lego space station inside a kid's bedroom.
365
+ A car on the left of a bus.
366
+ A green colored banana.
367
+ A photo of a confused grizzly bear in calculus class.
368
+ Three dogs on the street.
369
+ A medieval painting of the wifi not working.
370
+ One cat and three dogs sitting on the grass.
371
+ A red colored car.
372
+ Photo of a mega Lego space station inside a kid's bedroom.
373
+ Abraham Lincoln touches his toes while George Washington does chin-ups. Lincoln is barefoot. Washington is wearing boots.
374
+ Photo of a cat singing in a barbershop quartet.
375
+ A tennis racket underneath a traffic light.
376
+ Two cars on the street.
377
+ A sign that says 'Hello World'.
378
+ A church with stained glass windows depicting a hamburger and french fries.
379
+ A horse riding an astronaut.
380
+ A cross-section view of a brain.
381
+ A couple of glasses are sitting on a table.
382
+ A domesticated carnivorous mammal that typically has a long snout, an acute sense of smell, nonretractable claws, and a barking, howling, or whining voice.
383
+ A green cup and a blue cell phone.
384
+ Acersecomicke.
385
+ A giraffe underneath a microwave.
386
+ An elephant is behind a tree. You can see the trunk on one side and the back legs on the other.
387
+ A train on top of a surfboard.
388
+ A banana on the left of an apple.
389
+ A blue cup and a green cell phone.
390
+ A blue colored dog.
391
+ A sphere made of kitchen tile. A sphere with the texture of kitchen tile.
392
+ A couple of glasses are sitting on a table.
393
+ Matutinal.
394
+ An instrument used for cutting cloth, paper, and other thin material, consisting of two blades laid one on top of the other and fastened in the middle so as to allow them to be opened and closed by a thumb and finger inserted through rings on the end of their handles.
395
+ New York Skyline with 'Diffusion' written with fireworks on the sky.
396
+ A white car and a red sheep.
397
+ A sign that says 'NeurIPS'.
398
+ Five cars on the street.
399
+ A red colored dog.
400
+ New York Skyline with 'Text to Image' written with fireworks on the sky.
401
+ New York Skyline with 'Diffusion' written with fireworks on the sky.
402
+ Three cats and three dogs sitting on the grass.
403
+ A storefront with 'Deep Learning' written on it.
404
+ A hair drier underneath a sheep.
405
+ An instqrumemnt used for cutting cloth, paper, axdz othr thdin mteroial, consamistng of two blades lad one on tvopb of the other and fhastned in tle mixdqdjle so as to bllow them txo be pened and closed by thumb and fitngesr inserted tgrough rings on kthe end oc thei vatndlzes.
406
+ One dog on the street.
407
+ A fish eating a pelican.
408
+ A baby fennec sneezing onto a strawberry, detailed, macro, studio light, droplets, backlit ears.
409
+ A maglev train going vertically downward in high speed, New York Times photojournalism.
410
+ Supreme Court Justices play a baseball game with the FBI. The FBI is at bat, the justices are on the field.
411
+ A photo of a confused grizzly bear in calculus class.
412
+ A triangular pink stop sign. A pink stop sign in the shape of a triangle.
413
+ Matutinal.
414
+ Two cars on the street.
415
+ An orange colored sandwich.
416
+ A storefront with 'NeurIPS' written on it.
417
+ A grocery store refrigerator has pint cartons of milk on the top shelf, quart cartons on the middle shelf, and gallon plastic jugs on the bottom shelf.
418
+ A stack of 3 plates. A blue plate is on the top, sitting on a blue plate. The blue plate is in the middle, sitting on a green plate. The green plate is on the bottom.
419
+ In late afternoon in January in New England, a man stands in the shadow of a maple tree.
420
+ Hovering cow abducting aliens.
421
+ A triangular pink stop sign. A pink stop sign in the shape of a triangle.
422
+ A photocopy of a photograph of a painting of a sculpture of a giraffe.
423
+ A separate seat for one person, typically with a back and four legs.
424
+ A horse riding an astronaut.
425
+ Three cats and three dogs sitting on the grass.
426
+ A bird scaring a scarecrow.
427
+ Tcennis rpacket.
428
+ One car on the street.
429
+ A mechanical or electrical device for measuring time.
430
+ New York Skyline with 'NeurIPS' written with fireworks on the sky.
431
+ A fish eating a pelican.
432
+ A black apple and a green backpack.
433
+ A cube made of denim. A cube with the texture of denim.
434
+ A storefront with 'Deep Learning' written on it.
435
+ New York Skyline with 'Deep Learning' written with fireworks on the sky.
436
+ A brown colored giraffe.
437
+ A bird scaring a scarecrow.
438
+ A blue colored dog.
439
+ An emoji of a baby panda wearing a red hat, green gloves, red shirt, and green pants.
440
+ A green cup and a blue cell phone.
441
+ A carrot on the left of a broccoli.
442
+ A green apple and a black backpack.
443
+ A yellow book and a red vase.
444
+ A triangular purple flower pot. A purple flower pot in the shape of a triangle.
445
+ A small vessel propelled on water by oars, sails, or an engine.
446
+ An orange colored sandwich.
447
+ A tomato has been put on top of a pumpkin on a kitchen stool. There is a fork sticking into the pumpkin. The scene is viewed from above.
448
+ Rbefraigerator.
449
+ A machine next to a parking space in a street, into which the driver puts money so as to be authorized to park the vehicle for a particular length of time.
450
+ A hair drier underneath a sheep.
451
+ A grocery store refrigerator has pint cartons of milk on the top shelf, quart cartons on the middle shelf, and gallon plastic jugs on the bottom shelf.
452
+ A sign that says 'Deep Learning'.
453
+ A cross-section view of a brain.
454
+ A black colored car.
455
+ Two cars on the street.
456
+ Photo of an athlete cat explaining it's latest scandal at a press conference to journalists.
457
+ Rainbow coloured penguin.
458
+ A black apple and a green backpack.
459
+ Darth Vader playing with raccoon in Mars during sunset.
460
+ A spider with a moustache bidding an equally gentlemanly grasshopper a good day during his walk to work.
461
+ One cat and three dogs sitting on the grass.
462
+ 35mm macro shot a kitten licking a baby duck, studio lighting.
463
+ An umbrella on top of a spoon.
464
+ Bzaseball galove.
465
+ Greek statue of a man tripping over a cat.
466
+ Supreme Court Justices play a baseball game with the FBI. The FBI is at bat, the justices are on the field.
467
+ An instqrumemnt used for cutting cloth, paper, axdz othr thdin mteroial, consamistng of two blades lad one on tvopb of the other and fhastned in tle mixdqdjle so as to bllow them txo be pened and closed by thumb and fitngesr inserted tgrough rings on kthe end oc thei vatndlzes.
468
+ A car on the left of a bus.
469
+ One dog on the street.
470
+ A church with stained glass windows depicting a hamburger and french fries.
471
+ A vehicle composed of two wheels held in a frame one behind the other, propelled by pedals and steered with handlebars attached to the front wheel.
472
+ A cross-section view of a brain.
473
+ A donut underneath a toilet.
474
+ A small blue book sitting on a large red book.
475
+ A smafml vessef epropoeilled on watvewr by ors, sauls, or han engie.
476
+ A sign that says 'Deep Learning'.
477
+ Photo of a cat singing in a barbershop quartet.
478
+ A cube made of brick. A cube with the texture of brick.
479
+ An oil painting of a couple in formal evening wear going home get caught in a heavy downpour with no umbrellas.
480
+ One car on the street.
481
+ A mechanical or electrical device for measuring time.
482
+ Hyper-realistic photo of an abandoned industrial site during a storm.
483
+ A giraffe underneath a microwave.
484
+ New York Skyline with 'Google Brain Toronto' written with fireworks on the sky.
485
+ An ancient Egyptian painting depicting an argument over whose turn it is to take out the trash.
486
+ A red book and a yellow vase.
487
+ A yellow colored giraffe.
488
+ A smafml vessef epropoeilled on watvewr by ors, sauls, or han engie.
489
+ A long curved fruit which grows in clusters and has soft pulpy flesh and yellow skin when ripe.
490
+ New York Skyline with 'Hello World' written with fireworks on the sky.
491
+ Two cats and two dogs sitting on the grass.
492
+ Photo of a cat singing in a barbershop quartet.
493
+ Colouring page of large cats climbing the eifel tower in a cyberpunk future.
494
+ Abraham Lincoln touches his toes while George Washington does chin-ups. Lincoln is barefoot. Washington is wearing boots.
495
+ A medieval painting of the wifi not working.
496
+ A car playing soccer, digital art.
497
+ A black colored car.
498
+ An orange colored sandwich.
499
+ A ldarge keybord msical instroument lwith a woden case enmclosig a qsouvnkboajrd and mfgtal strivgf, which are strucrk b hammrs when the nels are depresdsmed.f lhe strsingsj' vibration ie stopped by damperds when the keys re released and can bce regulavewdd for lengh and vnolume y two or three pedalvs.
500
+ An instrument used for cutting cloth, paper, and other thin material, consisting of two blades laid one on top of the other and fastened in the middle so as to allow them to be opened and closed by a thumb and finger inserted through rings on the end of their handles.
501
+ Four cars on the street.
502
+ A small domesticated carnivorous mammal with soft fur, a short snout, and retractable claws. It is widely kept as a pet or for catching mice, and many breeds have been developed.
503
+ A donkey and an octopus are playing a game. The donkey is holding a rope on one end, the octopus is holding onto the other. The donkey holds the rope in its mouth. A cat is jumping over the rope.
504
+ An illustration of a large red elephant sitting on a small blue mouse.
505
+ Octothorpe.
506
+ A fisheye lens view of a turtle sitting in a forest.
507
+ New York Skyline with 'Text to Image' written with fireworks on the sky.
508
+ A storefront with 'Deep Learning' written on it.
509
+ A spider with a moustache bidding an equally gentlemanly grasshopper a good day during his walk to work.
510
+ An oil painting of a couple in formal evening wear going home get caught in a heavy downpour with no umbrellas.
511
+ An emoji of a baby panda wearing a red hat, blue gloves, green shirt, and blue pants.
512
+ An IT-guy trying to fix hardware of a PC tower is being tangled by the PC cables like Laokoon. Marble, copy after Hellenistic original from ca. 200 BC. Found in the Baths of Trajan, 1506.
513
+ A sheep to the right of a wine glass.
514
+ A cube made of denim. A cube with the texture of denim.
515
+ Painting of the orange cat Otto von Garfield, Count of Bismarck-Schönhausen, Duke of Lauenburg, Minister-President of Prussia. Depicted wearing a Prussian Pickelhaube and eating his favorite meal - lasagna.
516
+ A sign that says 'Google Research Pizza Cafe'.
517
+ A ldarge keybord msical instroument lwith a woden case enmclosig a qsouvnkboajrd and mfgtal strivgf, which are strucrk b hammrs when the nels are depresdsmed.f lhe strsingsj' vibration ie stopped by damperds when the keys re released and can bce regulavewdd for lengh and vnolume y two or three pedalvs.
518
+ 35mm macro shot a kitten licking a baby duck, studio lighting.
519
+ A shark in the desert.
520
+ A green colored banana.
521
+ A green cup and a blue cell phone.
522
+ Backlotter.
523
+ Darth Vader playing with raccoon in Mars during sunset.
524
+ A green apple and a black backpack.
525
+ A tomato has been put on top of a pumpkin on a kitchen stool. There is a fork sticking into the pumpkin. The scene is viewed from above.
526
+ A red colored dog.
527
+ A red book and a yellow vase.
528
+ Rbefraigerator.
529
+ A train on top of a surfboard.
530
+ Dininrg tablez.
531
+ A separate seat for one person, typically with a back and four legs.
532
+ A domesticated carnivvorous mzammal that typicbally hfaas a lons sfnout, an acxujte sense off osmell, noneetractaaln crlaws, anid xbarkring,y howlingu, or whining rvoiche.
533
+ A black colored dog.
534
+ A pink colored giraffe.
535
+ New York Skyline with 'Google Brain Toronto' written with fireworks on the sky.
536
+ Supreme Court Justices play a baseball game with the FBI. The FBI is at bat, the justices are on the field.
537
+ A white car and a red sheep.
538
+ An organ of soft nervous tissue contained in the skull of vertebrates, functioning as the coordinating center of sensation and intellectual and nervous activity.
539
+ Tcennis rpacket.
540
+ A red book and a yellow vase.
541
+ A cross-section view of a brain.
542
+ An illustration of a small green elephant standing behind a large red mouse.
543
+ One dog on the street.
544
+ A zebra underneath a broccoli.
545
+ A zebra to the right of a fire hydrant.
546
+ A large plant-eating domesticated mammal with solid hoofs and a flowing mane and tail, used for riding, racing, and to carry and pull loads.
547
+ An elephant under the sea.
548
+ An elephant under the sea.
549
+ A pizza on the right of a suitcase.
550
+ Greek statue of a man tripping over a cat.
551
+ A couple of glasses are sitting on a table.
552
+ A storefront with 'Diffusion' written on it.
553
+ A sheep to the right of a wine glass.
554
+ A fisheye lens view of a turtle sitting in a forest.
555
+ A vehicle composed of two wheels held in a frame one behind the other, propelled by pedals and steered with handlebars attached to the front wheel.
556
+ A 1960s poster warning against climate change.
557
+ Three cars on the street.
558
+ An umbrella on top of a spoon.
559
+ A zebra underneath a broccoli.
560
+ A black colored dog.
561
+ A small domesticated carnivorous mammal with soft fur, a short snout, and retractable claws. It is widely kept as a pet or for catching mice, and many breeds have been developed.
562
+ A sign that says 'Google Brain Toronto'.
563
+ A large plant-eating domesticated mammal with solid hoofs and a flowing mane and tail, used for riding, racing, and to carry and pull loads.
564
+ A sign that says 'NeurIPS'.
565
+ Pafrking metr.
566
+ A sign that says 'Text to Image'.
567
+ A screenshot of an iOS app for ordering different types of milk.
568
+ A large motor vehicle carrying passengers by road, typically one serving the public on a fixed route and for a fare.
569
+ One cat and two dogs sitting on the grass.
570
+ A cube made of brick. A cube with the texture of brick.
571
+ A storefront with 'Text to Image' written on it.
572
+ A screenshot of an iOS app for ordering different types of milk.
573
+ Two dogs on the street.
574
+ Dininrg tablez.
575
+ A baby fennec sneezing onto a strawberry, detailed, macro, studio light, droplets, backlit ears.
576
+ A cat on the left of a dog.
577
+ A machine resembling a human being and able to replicate certain human movements and functions automatically.
578
+ A panda making latte art.
579
+ A storefront with 'Hello World' written on it.
580
+ New York Skyline with 'Diffusion' written with fireworks on the sky.
581
+ Two cats and three dogs sitting on the grass.
582
+ McDonalds Church.
583
+ A cat on the left of a dog.
584
+ Octothorpe.
585
+ Painting of Mona Lisa but the view is from behind of Mona Lisa.
586
+ A smafml vessef epropoeilled on watvewr by ors, sauls, or han engie.
587
+ A maglev train going vertically downward in high speed, New York Times photojournalism.
588
+ Three dogs on the street.
589
+ A mechanical or electrical device for measuring time.
590
+ A pear cut into seven pieces arranged in a ring.
591
+ Lego Arnold Schwarzenegger.
592
+ An appliance or compartment which is artificially kept cool and used to store food and drink.
593
+ A black colored car.
594
+ An oil painting portrait of the regal Burger King posing with a Whopper.
595
+ A black colored banana.
596
+ Three cats and three dogs sitting on the grass.
597
+ A domesticated carnivorous mammal that typically has a long snout, an acute sense of smell, nonretractable claws, and a barking, howling, or whining voice.
598
+ A wine glass on top of a dog.
599
+ A tomato has been put on top of a pumpkin on a kitchen stool. There is a fork sticking into the pumpkin. The scene is viewed from above.
600
+ Backlotter.
601
+ A bird scaring a scarecrow.
602
+ A single clock is sitting on a table.
603
+ Bzaseball galove.
604
+ A yellow colored giraffe.
605
+ A white colored sandwich.
606
+ A giraffe underneath a microwave.
607
+ A couch on the left of a chair.
608
+ A pizza on the right of a suitcase.
609
+ Lego Arnold Schwarzenegger.
610
+ A donut underneath a toilet.
611
+ A triangular orange picture frame. An orange picture frame in the shape of a triangle.
612
+ McDonalds Church.
613
+ 35mm macro shot a kitten licking a baby duck, studio lighting.
614
+ A machine resembling a human being and able to replicate certain human movements and functions automatically.
615
+ An elephant is behind a tree. You can see the trunk on one side and the back legs on the other.
616
+ A spider with a moustache bidding an equally gentlemanly grasshopper a good day during his walk to work.
617
+ An umbrella on top of a spoon.
618
+ Lego Arnold Schwarzenegger.
619
+ A yellow and black bus cruising through the rainforest.
620
+ A giraffe underneath a microwave.
621
+ A cube made of denim. A cube with the texture of denim.
622
+ A sheep to the right of a wine glass.
623
+ A bridge connecting Europe and North America on the Atlantic Ocean, bird's eye view.
624
+ A 1960s yearbook photo with animals dressed as humans.
625
+ Paying for a quarter-sized pizza with a pizza-sized quarter.
626
+ A black colored sandwich.
627
+ A large keyboard musical instrument with a wooden case enclosing a soundboard and metal strings, which are struck by hammers when the keys are depressed. The strings' vibration is stopped by dampers when the keys are released and can be regulated for length and volume by two or three pedals.
628
+ A spider with a moustache bidding an equally gentlemanly grasshopper a good day during his walk to work.
629
+ One car on the street.
630
+ A carrot on the left of a broccoli.
631
+ Two cats and three dogs sitting on the grass.
632
+ A stack of 3 books. A green book is on the top, sitting on a red book. The red book is in the middle, sitting on a blue book. The blue book is on the bottom.
633
+ Two cats and one dog sitting on the grass.
634
+ An instqrumemnt used for cutting cloth, paper, axdz othr thdin mteroial, consamistng of two blades lad one on tvopb of the other and fhastned in tle mixdqdjle so as to bllow them txo be pened and closed by thumb and fitngesr inserted tgrough rings on kthe end oc thei vatndlzes.
635
+ Dininrg tablez.
636
+ A connection point by which firefighters can tap into a water supply.
637
+ Four dogs on the street.
638
+ A sign that says 'Hello World'.
639
+ Photo of a mega Lego space station inside a kid's bedroom.
640
+ McDonalds Church.
641
+ Illustration of a mouse using a mushroom as an umbrella.
642
+ A magnifying glass over a page of a 1950s batman comic.
643
+ Hyper-realistic photo of an abandoned industrial site during a storm.
644
+ A magnifying glass over a page of a 1950s batman comic.
645
+ An umbrella on top of a spoon.
646
+ A fluffy baby sloth with a knitted hat trying to figure out a laptop, close up, highly detailed, studio lighting, screen reflecting in its eyes.
647
+ A large keyboard musical instrument with a wooden case enclosing a soundboard and metal strings, which are struck by hammers when the keys are depressed. The strings' vibration is stopped by dampers when the keys are released and can be regulated for length and volume by two or three pedals.
648
+ A red colored dog.
649
+ A red colored car.
650
+ A black colored car.
651
+ Five cars on the street.
652
+ A baby fennec sneezing onto a strawberry, detailed, macro, studio light, droplets, backlit ears.
653
+ In late afternoon in January in New England, a man stands in the shadow of a maple tree.
654
+ Photo of a cat singing in a barbershop quartet.
655
+ Hovering cow abducting aliens.
656
+ An old photograph of a 1920s airship shaped like a pig, floating over a wheat field.
657
+ An organ of soft nervous tissue contained in the skull of vertebrates, functioning as the coordinating center of sensation and intellectual and nervous activity.
658
+ A triangular purple flower pot. A purple flower pot in the shape of a triangle.
659
+ A pear cut into seven pieces arranged in a ring.
660
+ A red colored car.
661
+ Two cats and one dog sitting on the grass.
662
+ A cube made of brick. A cube with the texture of brick.
663
+ A pyramid made of falafel with a partial solar eclipse in the background.
664
+ A yellow colored giraffe.
665
+ An instrument used for cutting cloth, paper, and other thin material, consisting of two blades laid one on top of the other and fastened in the middle so as to allow them to be opened and closed by a thumb and finger inserted through rings on the end of their handles.
666
+ A sjmall domesticated carnivorious mammnal with sof fuh,y a sthort sout, and retracwtablbe flaws. It iw widexly kept as a pet or for catchitng mic, ad many breeds zhlyde beefn develvoked.
667
+ A couch on the left of a chair.
668
+ A photocopy of a photograph of a painting of a sculpture of a giraffe.
669
+ A sign that says 'Google Brain Toronto'.
670
+ A sign that says 'Text to Image'.
671
+ Rainbow coloured penguin.
672
+ Two dogs on the street.
673
+ A triangular orange picture frame. An orange picture frame in the shape of a triangle.
674
+ Colouring page of large cats climbing the eifel tower in a cyberpunk future.
675
+ A white colored sandwich.
676
+ A stop sign on the right of a refrigerator.
677
+ A long curved fruit which grows in clusters and has soft pulpy flesh and yellow skin when ripe.
678
+ Three cats and two dogs sitting on the grass.
679
+ A hair drier underneath a sheep.
680
+ A train on top of a surfboard.
681
+ A large plant-eating domesticated mammal with solid hoofs and a flowing mane and tail, used for riding, racing, and to carry and pull loads.
682
+ A sign that says 'Google Research Pizza Cafe'.
683
+ A stop sign on the right of a refrigerator.
684
+ A tiger in a lab coat with a 1980s Miami vibe, turning a well oiled science content machine, digital art.
685
+ A device consisting of a circular canopy of cloth on a folding metal frame supported by a central rod, used as protection against rain or sometimes sun.
686
+ An illustration of a small green elephant standing behind a large red mouse.
687
+ A sign that says 'Hello World'.
688
+ Lego Arnold Schwarzenegger.
689
+ Five dogs on the street.
690
+ A storefront with 'Hello World' written on it.
691
+ An instrument used for cutting cloth, paper, and other thin material, consisting of two blades laid one on top of the other and fastened in the middle so as to allow them to be opened and closed by a thumb and finger inserted through rings on the end of their handles.
692
+ A tomato has been put on top of a pumpkin on a kitchen stool. There is a fork sticking into the pumpkin. The scene is viewed from above.
693
+ In late afternoon in January in New England, a man stands in the shadow of a maple tree.
694
+ Jentacular.
695
+ Four dogs on the street.
696
+ An old photograph of a 1920s airship shaped like a pig, floating over a wheat field.
697
+ A triangular purple flower pot. A purple flower pot in the shape of a triangle.
698
+ A black colored banana.
699
+ 35mm macro shot a kitten licking a baby duck, studio lighting.
700
+ Bzaseball galove.
701
+ A fisheye lens view of a turtle sitting in a forest.
702
+ A donut underneath a toilet.
703
+ A donkey and an octopus are playing a game. The donkey is holding a rope on one end, the octopus is holding onto the other. The donkey holds the rope in its mouth. A cat is jumping over the rope.
704
+ A fluffy baby sloth with a knitted hat trying to figure out a laptop, close up, highly detailed, studio lighting, screen reflecting in its eyes.
705
+ A green apple and a black backpack.
706
+ An illustration of a large red elephant sitting on a small blue mouse.
707
+ A smafml vessef epropoeilled on watvewr by ors, sauls, or han engie.
708
+ A machine next to a parking space in a street, into which the driver puts money so as to be authorized to park the vehicle for a particular length of time.
709
+ Rainbow coloured penguin.
710
+ Three cats and one dog sitting on the grass.
711
+ An old photograph of a 1920s airship shaped like a pig, floating over a wheat field.
712
+ New York Skyline with 'Google Research Pizza Cafe' written with fireworks on the sky.
713
+ A painting by Grant Wood of an astronaut couple, american gothic style.
714
+ Four dogs on the street.
715
+ A large motor vehicle carrying passengers by road, typically one serving the public on a fixed route and for a fare.
716
+ A couple of glasses are sitting on a table.
717
+ In late afternoon in January in New England, a man stands in the shadow of a maple tree.
718
+ A brown colored giraffe.
719
+ A bridge connecting Europe and North America on the Atlantic Ocean, bird's eye view.
720
+ Five dogs on the street.
721
+ New York Skyline with 'Text to Image' written with fireworks on the sky.
722
+ An appliance or compartment which is artificially kept cool and used to store food and drink.
723
+ A real life photography of super mario, 8k Ultra HD.
724
+ A pink colored car.
725
+ A painting by Grant Wood of an astronaut couple, american gothic style.
726
+ A car on the left of a bus.
727
+ A large plant-eating domesticated mammal with solid hoofs and a flowing mane and tail, used for riding, racing, and to carry and pull loads.
728
+ Pafrking metr.
729
+ An illustration of a small green elephant standing behind a large red mouse.
730
+ A blue cup and a green cell phone.
731
+ New York Skyline with 'NeurIPS' written with fireworks on the sky.
732
+ A storefront with 'Google Brain Toronto' written on it.
733
+ A painting by Grant Wood of an astronaut couple, american gothic style.
734
+ A black colored sandwich.
735
+ A fish eating a pelican.
736
+ An emoji of a baby panda wearing a red hat, blue gloves, green shirt, and blue pants.
737
+ A vehicle composed of two wheels held in a frame one behind the other, propelled by pedals and steered with handlebars attached to the front wheel.
738
+ A tennis racket underneath a traffic light.
739
+ Three cars on the street.
740
+ One car on the street.
741
+ A tennis racket underneath a traffic light.
742
+ A maglev train going vertically downward in high speed, New York Times photojournalism.
743
+ Photo of an athlete cat explaining it's latest scandal at a press conference to journalists.
744
+ A red book and a yellow vase.
745
+ A shark in the desert.
746
+ An organ of soft nervous tissue contained in the skull of vertebrates, functioning as the coordinating center of sensation and intellectual and nervous activity.
747
+ A sign that says 'Text to Image'.
748
+ A stack of 3 books. A green book is on the top, sitting on a red book. The red book is in the middle, sitting on a blue book. The blue book is on the bottom.
749
+ A shark in the desert.
750
+ A 1960s poster warning against climate change.
751
+ Backlotter.
752
+ One cat and two dogs sitting on the grass.
753
+ Matutinal.
754
+ A cat on the right of a tennis racket.
755
+ A laptop on top of a teddy bear.
756
+ A white colored sandwich.
757
+ A yellow and black bus cruising through the rainforest.
758
+ A photocopy of a photograph of a painting of a sculpture of a giraffe.
759
+ A side view of an owl sitting in a field.
760
+ A pizza on the right of a suitcase.
761
+ A wine glass on top of a dog.
762
+ A realistic photo of a Pomeranian dressed up like a 1980s professional wrestler with neon green and neon orange face paint and bright green wrestling tights with bright orange boots.
763
+ A pear cut into seven pieces arranged in a ring.
764
+ Acersecomicke.
765
+ Painting of Mona Lisa but the view is from behind of Mona Lisa.
766
+ A small vessel propelled on water by oars, sails, or an engine.
767
+ Painting of the orange cat Otto von Garfield, Count of Bismarck-Schönhausen, Duke of Lauenburg, Minister-President of Prussia. Depicted wearing a Prussian Pickelhaube and eating his favorite meal - lasagna.
768
+ A cat on the left of a dog.
769
+ A red colored banana.
770
+ A domesticated carnivorous mammal that typically has a long snout, an acute sense of smell, nonretractable claws, and a barking, howling, or whining voice.
771
+ A sign that says 'Google Brain Toronto'.
772
+ A collection of nail is sitting on a table.
773
+ A pyramid made of falafel with a partial solar eclipse in the background.
774
+ A realistic photo of a Pomeranian dressed up like a 1980s professional wrestler with neon green and neon orange face paint and bright green wrestling tights with bright orange boots.
775
+ A cube made of brick. A cube with the texture of brick.
776
+ New York Skyline with 'Text to Image' written with fireworks on the sky.
777
+ A fish eating a pelican.
778
+ A pink colored giraffe.
779
+ One cat and three dogs sitting on the grass.
780
+ A keyboard made of water, the water is made of light, the light is turned off.
781
+ Greek statue of a man tripping over a cat.
782
+ A machine resembling a human being and able to replicate certain human movements and functions automatically.
783
+ A yellow and black bus cruising through the rainforest.
784
+ An elephant is behind a tree. You can see the trunk on one side and the back legs on the other.
785
+ A small domesticated carnivorous mammal with soft fur, a short snout, and retractable claws. It is widely kept as a pet or for catching mice, and many breeds have been developed.
786
+ Dininrg tablez.
787
+ A sign that says 'NeurIPS'.
788
+ An illustration of a small green elephant standing behind a large red mouse.
789
+ A collection of nail is sitting on a table.
790
+ An oil painting of a couple in formal evening wear going home get caught in a heavy downpour with no umbrellas.
791
+ New York Skyline with 'Hello World' written with fireworks on the sky.
792
+ A storefront with 'Text to Image' written on it.
793
+ A storefront with 'Deep Learning' written on it.
794
+ Three cats and two dogs sitting on the grass.
795
+ A red car and a white sheep.
796
+ A domesticated carnivvorous mzammal that typicbally hfaas a lons sfnout, an acxujte sense off osmell, noneetractaaln crlaws, anid xbarkring,y howlingu, or whining rvoiche.
797
+ A mechanical or electrical device for measuring time.
798
+ A bridge connecting Europe and North America on the Atlantic Ocean, bird's eye view.
799
+ An appliance or compartment which is artificially kept cool and used to store food and drink.
800
+ A pizza cooking an oven.
801
+ A car playing soccer, digital art.
802
+ A blue coloured pizza.
803
+ A machine next to a parking space in a street, into which the driver puts money so as to be authorized to park the vehicle for a particular length of time.
804
+ Octothorpe.
805
+ A yellow book and a red vase.
806
+ A bicycle on top of a boat.
807
+ A device consisting of a circular canopy of cloth on a folding metal frame supported by a central rod, used as protection against rain or sometimes sun.
808
+ An orange colored sandwich.
809
+ Acersecomicke.
810
+ A magnifying glass over a page of a 1950s batman comic.
811
+ A black apple and a green backpack.
812
+ A bird scaring a scarecrow.
813
+ A sign that says 'Deep Learning'.
814
+ A bicycle on top of a boat.
815
+ Painting of Mona Lisa but the view is from behind of Mona Lisa.
816
+ Three dogs on the street.
817
+ A stack of 3 plates. A blue plate is on the top, sitting on a blue plate. The blue plate is in the middle, sitting on a green plate. The green plate is on the bottom.
818
+ A red car and a white sheep.
819
+ Greek statue of a man tripping over a cat.
820
+ Three dogs on the street.
821
+ A sheep to the right of a wine glass.
822
+ One cat and one dog sitting on the grass.
823
+ A black colored sandwich.
824
+ Peristeronic.
825
+ Three cats and two dogs sitting on the grass.
826
+ A 1960s yearbook photo with animals dressed as humans.
827
+ A sign that says 'Diffusion'.
828
+ A sign that says 'Google Research Pizza Cafe'.
829
+ A blue bird and a brown bear.
830
+ A yellow and black bus cruising through the rainforest.
831
+ A smafml vessef epropoeilled on watvewr by ors, sauls, or han engie.
832
+ Bzaseball galove.
833
+ Artophagous.
834
+ A sign that says 'Text to Image'.
835
+ A baby fennec sneezing onto a strawberry, detailed, macro, studio light, droplets, backlit ears.
836
+ A fisheye lens view of a turtle sitting in a forest.
837
+ A storefront with 'Hello World' written on it.
838
+ A connection point by which firefighters can tap into a water supply.
839
+ A separate seat for one person, typically with a back and four legs.
840
+ A 1960s yearbook photo with animals dressed as humans.
841
+ A sphere made of kitchen tile. A sphere with the texture of kitchen tile.
842
+ A black colored banana.
843
+ A vehicle composed of two wheels held in a frame one behind the other, propelled by pedals and steered with handlebars attached to the front wheel.
844
+ Four cars on the street.
845
+ Three cats and three dogs sitting on the grass.
846
+ Five dogs on the street.
847
+ An ancient Egyptian painting depicting an argument over whose turn it is to take out the trash.
848
+ A storefront with 'Diffusion' written on it.
849
+ A pizza cooking an oven.
850
+ Darth Vader playing with raccoon in Mars during sunset.
851
+ A carrot on the left of a broccoli.
852
+ A tomato has been put on top of a pumpkin on a kitchen stool. There is a fork sticking into the pumpkin. The scene is viewed from above.
853
+ A storefront with 'Diffusion' written on it.
854
+ A red book and a yellow vase.
855
+ Peristeronic.
856
+ An organ of soft nervous tissue contained in the skull of vertebrates, functioning as the coordinating center of sensation and intellectual and nervous activity.
857
+ A donkey and an octopus are playing a game. The donkey is holding a rope on one end, the octopus is holding onto the other. The donkey holds the rope in its mouth. A cat is jumping over the rope.
858
+ A couch on the left of a chair.
859
+ A sphere made of kitchen tile. A sphere with the texture of kitchen tile.
860
+ A white car and a red sheep.
861
+ Artophagous.
862
+ A stack of 3 books. A green book is on the top, sitting on a red book. The red book is in the middle, sitting on a blue book. The blue book is on the bottom.
863
+ A pizza cooking an oven.
864
+ A triangular purple flower pot. A purple flower pot in the shape of a triangle.
865
+ A brown bird and a blue bear.
866
+ An IT-guy trying to fix hardware of a PC tower is being tangled by the PC cables like Laokoon. Marble, copy after Hellenistic original from ca. 200 BC. Found in the Baths of Trajan, 1506.
867
+ A storefront with 'Google Research Pizza Cafe' written on it.
868
+ A storefront with 'Google Research Pizza Cafe' written on it.
869
+ A large thick-skinned semiaquatic African mammal, with massive jaws and large tusks.
870
+ An appliance or compartment which is artificially kept cool and used to store food and drink.
871
+ A donut underneath a toilet.
872
+ A blue bird and a brown bear.
873
+ A 1960s poster warning against climate change.
874
+ A white colored sandwich.
875
+ A white colored sandwich.
876
+ A stop sign on the right of a refrigerator.
877
+ A storefront with 'Hello World' written on it.
878
+ Five dogs on the street.
879
+ Three cars on the street.
880
+ A keyboard made of water, the water is made of light, the light is turned off.
881
+ A red colored dog.
882
+ Two cats and three dogs sitting on the grass.
883
+ A spider with a moustache bidding an equally gentlemanly grasshopper a good day during his walk to work.
884
+ A pink colored car.
885
+ A tiger in a lab coat with a 1980s Miami vibe, turning a well oiled science content machine, digital art.
886
+ Photo of an athlete cat explaining it's latest scandal at a press conference to journalists.
887
+ A realistic photo of a Pomeranian dressed up like a 1980s professional wrestler with neon green and neon orange face paint and bright green wrestling tights with bright orange boots.
888
+ A type of digital currency in which a record of transactions is maintained and new units of currency are generated by the computational solution of mathematical problems, and which operates independently of a central bank.
889
+ A sign that says 'Hello World'.
890
+ An ancient Egyptian painting depicting an argument over whose turn it is to take out the trash.
891
+ A white car and a red sheep.
892
+ Illustration of a mouse using a mushroom as an umbrella.
893
+ A red colored banana.
894
+ Three cats and one dog sitting on the grass.
895
+ A car playing soccer, digital art.
896
+ A sjmall domesticated carnivorious mammnal with sof fuh,y a sthort sout, and retracwtablbe flaws. It iw widexly kept as a pet or for catchitng mic, ad many breeds zhlyde beefn develvoked.
897
+ Rbefraigerator.
898
+ A triangular orange picture frame. An orange picture frame in the shape of a triangle.
899
+ Rainbow coloured penguin.
900
+ A storefront with 'Text to Image' written on it.
901
+ A cat on the right of a tennis racket.
902
+ A small blue book sitting on a large red book.
903
+ Two cats and one dog sitting on the grass.
904
+ An emoji of a baby panda wearing a red hat, green gloves, red shirt, and green pants.
905
+ A brown bird and a blue bear.
906
+ A red car and a white sheep.
907
+ A pizza on the right of a suitcase.
908
+ A small blue book sitting on a large red book.
909
+ A horse riding an astronaut.
910
+ A sign that says 'Google Brain Toronto'.
911
+ Hyper-realistic photo of an abandoned industrial site during a storm.
912
+ A side view of an owl sitting in a field.
913
+ A photo of a confused grizzly bear in calculus class.
914
+ An American multinational technology company that focuses on artificial intelligence, search engine, online advertising, cloud computing, computer software, quantum computing, e-commerce, and consumer electronics.
915
+ A storefront with 'NeurIPS' written on it.
916
+ A storefront with 'NeurIPS' written on it.
917
+ Two cats and one dog sitting on the grass.
918
+ New York Skyline with 'Diffusion' written with fireworks on the sky.
919
+ A storefront with 'Diffusion' written on it.
920
+ A blue coloured pizza.
921
+ A single clock is sitting on a table.
922
+ A zebra to the right of a fire hydrant.
923
+ Backlotter.
924
+ An ancient Egyptian painting depicting an argument over whose turn it is to take out the trash.
925
+ Two cats and two dogs sitting on the grass.
926
+ Painting of Mona Lisa but the view is from behind of Mona Lisa.
927
+ A triangular orange picture frame. An orange picture frame in the shape of a triangle.
928
+ A bird scaring a scarecrow.
929
+ A keyboard made of water, the water is made of light, the light is turned off.
930
+ A tennis racket underneath a traffic light.
931
+ A banana on the left of an apple.
932
+ A screenshot of an iOS app for ordering different types of milk.
933
+ A long curved fruit which grows in clusters and has soft pulpy flesh and yellow skin when ripe.
934
+ A side view of an owl sitting in a field.
935
+ Two cats and two dogs sitting on the grass.
936
+ Hovering cow abducting aliens.
937
+ A red car and a white sheep.
938
+ A zebra underneath a broccoli.
939
+ Rainbow coloured penguin.
940
+ A storefront with 'Deep Learning' written on it.
941
+ Three cars on the street.
942
+ A red colored banana.
943
+ A blue bird and a brown bear.
944
+ New York Skyline with 'NeurIPS' written with fireworks on the sky.
945
+ A sjmall domesticated carnivorious mammnal with sof fuh,y a sthort sout, and retracwtablbe flaws. It iw widexly kept as a pet or for catchitng mic, ad many breeds zhlyde beefn develvoked.
946
+ A giraffe underneath a microwave.
947
+ A brown colored giraffe.
948
+ An instqrumemnt used for cutting cloth, paper, axdz othr thdin mteroial, consamistng of two blades lad one on tvopb of the other and fhastned in tle mixdqdjle so as to bllow them txo be pened and closed by thumb and fitngesr inserted tgrough rings on kthe end oc thei vatndlzes.
949
+ A pizza cooking an oven.
950
+ A bicycle on top of a boat.
951
+ A screenshot of an iOS app for ordering different types of milk.
952
+ A car playing soccer, digital art.
953
+ A banana on the left of an apple.
954
+ A cube made of brick. A cube with the texture of brick.
955
+ A sheep to the right of a wine glass.
956
+ A type of digital currency in which a record of transactions is maintained and new units of currency are generated by the computational solution of mathematical problems, and which operates independently of a central bank.
957
+ A medieval painting of the wifi not working.
958
+ A brown bird and a blue bear.
959
+ A yellow and black bus cruising through the rainforest.
960
+ A bridge connecting Europe and North America on the Atlantic Ocean, bird's eye view.
961
+ Hyper-realistic photo of an abandoned industrial site during a storm.
962
+ Photo of an athlete cat explaining it's latest scandal at a press conference to journalists.
963
+ A stack of 3 cubes. A red cube is on the top, sitting on a red cube. The red cube is in the middle, sitting on a green cube. The green cube is on the bottom.
964
+ A yellow book and a red vase.
965
+ A wine glass on top of a dog.
966
+ A sign that says 'Deep Learning'.
967
+ A small domesticated carnivorous mammal with soft fur, a short snout, and retractable claws. It is widely kept as a pet or for catching mice, and many breeds have been developed.
968
+ Jentacular.
969
+ A car on the left of a bus.
970
+ A machine resembling a human being and able to replicate certain human movements and functions automatically.
971
+ New York Skyline with 'Google Research Pizza Cafe' written with fireworks on the sky.
972
+ Photo of a mega Lego space station inside a kid's bedroom.
973
+ Peristeronic.
974
+ One cat and one dog sitting on the grass.
975
+ A horse riding an astronaut.
976
+ New York Skyline with 'Deep Learning' written with fireworks on the sky.
977
+ A zebra underneath a broccoli.
978
+ A machine resembling a human being and able to replicate certain human movements and functions automatically.
979
+ A red colored dog.
980
+ Acersecomicke.
981
+ One dog on the street.
982
+ A white car and a red sheep.
983
+ New York Skyline with 'NeurIPS' written with fireworks on the sky.
984
+ A single clock is sitting on a table.
985
+ A zebra to the right of a fire hydrant.
986
+ A triangular orange picture frame. An orange picture frame in the shape of a triangle.
987
+ A blue colored dog.
988
+ McDonalds Church.
989
+ Tcennis rpacket.
990
+ A brown colored giraffe.
991
+ Hyper-realistic photo of an abandoned industrial site during a storm.
992
+ Tcennis rpacket.
993
+ A church with stained glass windows depicting a hamburger and french fries.
994
+ A bicycle on top of a boat.
995
+ A banana on the left of an apple.
996
+ A connection point by which firefighters can tap into a water supply.
997
+ New York Skyline with 'Deep Learning' written with fireworks on the sky.
998
+ An instqrumemnt used for cutting cloth, paper, axdz othr thdin mteroial, consamistng of two blades lad one on tvopb of the other and fhastned in tle mixdqdjle so as to bllow them txo be pened and closed by thumb and fitngesr inserted tgrough rings on kthe end oc thei vatndlzes.
999
+ New York Skyline with 'NeurIPS' written with fireworks on the sky.
1000
+ Paying for a quarter-sized pizza with a pizza-sized quarter.
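The list above is stored one prompt per line; the interleaved numbers are the diff view's line counters, not part of the file. A minimal sketch of iterating such a list follows. The file name is a placeholder (the actual prompt file is named where this hunk begins, earlier in the diff), and the print call merely stands in for the repository's text-to-image inference entry point, which this sketch does not assume.

# Sketch only (not part of this commit): walk a line-per-prompt evaluation list in order.
PROMPT_FILE = "prompts/eval_prompts.txt"  # placeholder path, not taken from this commit

with open(PROMPT_FILE, encoding="utf-8") as f:
    prompts = [line.strip() for line in f if line.strip()]

for idx, prompt in enumerate(prompts, start=1):
    # Replace this print with the actual generation call used by the inference code.
    print(f"{idx:04d}: {prompt}")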
prompts/evaluation_metadata.jsonl ADDED
@@ -0,0 +1,553 @@
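The file below is evaluation metadata in a GenEval-style schema: one JSON object per line with a task "tag" (single_object, two_object, counting, colors, position in the entries shown), the "prompt" to generate, an "include" list of required object classes and counts, an optional "exclude" list of counts that must not be reached, and optional per-object "color" or "position" fields, where a value like ["right of", 0] pairs a spatial relation with the index of the reference object inside the same "include" list. A minimal reading sketch, assuming only the path added by this commit (it is not part of the repository's inference code):

import json
from collections import Counter

META_PATH = "prompts/evaluation_metadata.jsonl"  # path as added in this commit

# Parse one JSON object per non-empty line.
with open(META_PATH, encoding="utf-8") as f:
    entries = [json.loads(line) for line in f if line.strip()]

# Summarise how many prompts each task tag contributes.
print(Counter(entry["tag"] for entry in entries))

# Resolve the spatial relation of the first "position" entry:
# the second element of a "position" value indexes into the entry's "include" list.
for entry in entries:
    if entry["tag"] == "position":
        placed = next(obj for obj in entry["include"] if "position" in obj)
        relation, ref_idx = placed["position"]
        reference = entry["include"][ref_idx]["class"]
        print(f'{entry["prompt"]!r}: {placed["class"]} {relation} {reference}')
        break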
1
+ {"tag": "single_object", "include": [{"class": "bench", "count": 1}], "prompt": "a photo of a bench"}
2
+ {"tag": "single_object", "include": [{"class": "cow", "count": 1}], "prompt": "a photo of a cow"}
3
+ {"tag": "single_object", "include": [{"class": "bicycle", "count": 1}], "prompt": "a photo of a bicycle"}
4
+ {"tag": "single_object", "include": [{"class": "clock", "count": 1}], "prompt": "a photo of a clock"}
5
+ {"tag": "single_object", "include": [{"class": "carrot", "count": 1}], "prompt": "a photo of a carrot"}
6
+ {"tag": "single_object", "include": [{"class": "suitcase", "count": 1}], "prompt": "a photo of a suitcase"}
7
+ {"tag": "single_object", "include": [{"class": "fork", "count": 1}], "prompt": "a photo of a fork"}
8
+ {"tag": "single_object", "include": [{"class": "surfboard", "count": 1}], "prompt": "a photo of a surfboard"}
9
+ {"tag": "single_object", "include": [{"class": "refrigerator", "count": 1}], "prompt": "a photo of a refrigerator"}
10
+ {"tag": "single_object", "include": [{"class": "cup", "count": 1}], "prompt": "a photo of a cup"}
11
+ {"tag": "single_object", "include": [{"class": "microwave", "count": 1}], "prompt": "a photo of a microwave"}
12
+ {"tag": "single_object", "include": [{"class": "potted plant", "count": 1}], "prompt": "a photo of a potted plant"}
13
+ {"tag": "single_object", "include": [{"class": "snowboard", "count": 1}], "prompt": "a photo of a snowboard"}
14
+ {"tag": "single_object", "include": [{"class": "zebra", "count": 1}], "prompt": "a photo of a zebra"}
15
+ {"tag": "single_object", "include": [{"class": "parking meter", "count": 1}], "prompt": "a photo of a parking meter"}
16
+ {"tag": "single_object", "include": [{"class": "spoon", "count": 1}], "prompt": "a photo of a spoon"}
17
+ {"tag": "single_object", "include": [{"class": "skateboard", "count": 1}], "prompt": "a photo of a skateboard"}
18
+ {"tag": "single_object", "include": [{"class": "car", "count": 1}], "prompt": "a photo of a car"}
19
+ {"tag": "single_object", "include": [{"class": "motorcycle", "count": 1}], "prompt": "a photo of a motorcycle"}
20
+ {"tag": "single_object", "include": [{"class": "traffic light", "count": 1}], "prompt": "a photo of a traffic light"}
21
+ {"tag": "single_object", "include": [{"class": "book", "count": 1}], "prompt": "a photo of a book"}
22
+ {"tag": "single_object", "include": [{"class": "couch", "count": 1}], "prompt": "a photo of a couch"}
23
+ {"tag": "single_object", "include": [{"class": "backpack", "count": 1}], "prompt": "a photo of a backpack"}
24
+ {"tag": "single_object", "include": [{"class": "computer keyboard", "count": 1}], "prompt": "a photo of a computer keyboard"}
25
+ {"tag": "single_object", "include": [{"class": "toaster", "count": 1}], "prompt": "a photo of a toaster"}
26
+ {"tag": "single_object", "include": [{"class": "bird", "count": 1}], "prompt": "a photo of a bird"}
27
+ {"tag": "single_object", "include": [{"class": "bowl", "count": 1}], "prompt": "a photo of a bowl"}
28
+ {"tag": "single_object", "include": [{"class": "dog", "count": 1}], "prompt": "a photo of a dog"}
29
+ {"tag": "single_object", "include": [{"class": "tie", "count": 1}], "prompt": "a photo of a tie"}
30
+ {"tag": "single_object", "include": [{"class": "laptop", "count": 1}], "prompt": "a photo of a laptop"}
31
+ {"tag": "single_object", "include": [{"class": "computer mouse", "count": 1}], "prompt": "a photo of a computer mouse"}
32
+ {"tag": "single_object", "include": [{"class": "sandwich", "count": 1}], "prompt": "a photo of a sandwich"}
33
+ {"tag": "single_object", "include": [{"class": "baseball bat", "count": 1}], "prompt": "a photo of a baseball bat"}
34
+ {"tag": "single_object", "include": [{"class": "train", "count": 1}], "prompt": "a photo of a train"}
35
+ {"tag": "single_object", "include": [{"class": "cell phone", "count": 1}], "prompt": "a photo of a cell phone"}
36
+ {"tag": "single_object", "include": [{"class": "chair", "count": 1}], "prompt": "a photo of a chair"}
37
+ {"tag": "single_object", "include": [{"class": "tv", "count": 1}], "prompt": "a photo of a tv"}
38
+ {"tag": "single_object", "include": [{"class": "broccoli", "count": 1}], "prompt": "a photo of a broccoli"}
39
+ {"tag": "single_object", "include": [{"class": "bed", "count": 1}], "prompt": "a photo of a bed"}
40
+ {"tag": "single_object", "include": [{"class": "skis", "count": 1}], "prompt": "a photo of a skis"}
41
+ {"tag": "single_object", "include": [{"class": "handbag", "count": 1}], "prompt": "a photo of a handbag"}
42
+ {"tag": "single_object", "include": [{"class": "pizza", "count": 1}], "prompt": "a photo of a pizza"}
43
+ {"tag": "single_object", "include": [{"class": "frisbee", "count": 1}], "prompt": "a photo of a frisbee"}
44
+ {"tag": "single_object", "include": [{"class": "scissors", "count": 1}], "prompt": "a photo of a scissors"}
45
+ {"tag": "single_object", "include": [{"class": "bottle", "count": 1}], "prompt": "a photo of a bottle"}
46
+ {"tag": "single_object", "include": [{"class": "elephant", "count": 1}], "prompt": "a photo of an elephant"}
47
+ {"tag": "single_object", "include": [{"class": "toilet", "count": 1}], "prompt": "a photo of a toilet"}
48
+ {"tag": "single_object", "include": [{"class": "oven", "count": 1}], "prompt": "a photo of an oven"}
49
+ {"tag": "single_object", "include": [{"class": "orange", "count": 1}], "prompt": "a photo of an orange"}
50
+ {"tag": "single_object", "include": [{"class": "person", "count": 1}], "prompt": "a photo of a person"}
51
+ {"tag": "single_object", "include": [{"class": "teddy bear", "count": 1}], "prompt": "a photo of a teddy bear"}
52
+ {"tag": "single_object", "include": [{"class": "vase", "count": 1}], "prompt": "a photo of a vase"}
53
+ {"tag": "single_object", "include": [{"class": "banana", "count": 1}], "prompt": "a photo of a banana"}
54
+ {"tag": "single_object", "include": [{"class": "toothbrush", "count": 1}], "prompt": "a photo of a toothbrush"}
55
+ {"tag": "single_object", "include": [{"class": "tv remote", "count": 1}], "prompt": "a photo of a tv remote"}
56
+ {"tag": "single_object", "include": [{"class": "dining table", "count": 1}], "prompt": "a photo of a dining table"}
57
+ {"tag": "single_object", "include": [{"class": "stop sign", "count": 1}], "prompt": "a photo of a stop sign"}
58
+ {"tag": "single_object", "include": [{"class": "sheep", "count": 1}], "prompt": "a photo of a sheep"}
59
+ {"tag": "single_object", "include": [{"class": "fire hydrant", "count": 1}], "prompt": "a photo of a fire hydrant"}
60
+ {"tag": "single_object", "include": [{"class": "airplane", "count": 1}], "prompt": "a photo of an airplane"}
61
+ {"tag": "single_object", "include": [{"class": "giraffe", "count": 1}], "prompt": "a photo of a giraffe"}
62
+ {"tag": "single_object", "include": [{"class": "horse", "count": 1}], "prompt": "a photo of a horse"}
63
+ {"tag": "single_object", "include": [{"class": "cat", "count": 1}], "prompt": "a photo of a cat"}
64
+ {"tag": "single_object", "include": [{"class": "donut", "count": 1}], "prompt": "a photo of a donut"}
65
+ {"tag": "single_object", "include": [{"class": "boat", "count": 1}], "prompt": "a photo of a boat"}
66
+ {"tag": "single_object", "include": [{"class": "baseball glove", "count": 1}], "prompt": "a photo of a baseball glove"}
67
+ {"tag": "single_object", "include": [{"class": "hair drier", "count": 1}], "prompt": "a photo of a hair drier"}
68
+ {"tag": "single_object", "include": [{"class": "sink", "count": 1}], "prompt": "a photo of a sink"}
69
+ {"tag": "single_object", "include": [{"class": "cake", "count": 1}], "prompt": "a photo of a cake"}
70
+ {"tag": "single_object", "include": [{"class": "wine glass", "count": 1}], "prompt": "a photo of a wine glass"}
71
+ {"tag": "single_object", "include": [{"class": "apple", "count": 1}], "prompt": "a photo of an apple"}
72
+ {"tag": "single_object", "include": [{"class": "bus", "count": 1}], "prompt": "a photo of a bus"}
73
+ {"tag": "single_object", "include": [{"class": "tennis racket", "count": 1}], "prompt": "a photo of a tennis racket"}
74
+ {"tag": "single_object", "include": [{"class": "knife", "count": 1}], "prompt": "a photo of a knife"}
75
+ {"tag": "single_object", "include": [{"class": "hot dog", "count": 1}], "prompt": "a photo of a hot dog"}
76
+ {"tag": "single_object", "include": [{"class": "truck", "count": 1}], "prompt": "a photo of a truck"}
77
+ {"tag": "single_object", "include": [{"class": "umbrella", "count": 1}], "prompt": "a photo of an umbrella"}
78
+ {"tag": "single_object", "include": [{"class": "sports ball", "count": 1}], "prompt": "a photo of a sports ball"}
79
+ {"tag": "single_object", "include": [{"class": "bear", "count": 1}], "prompt": "a photo of a bear"}
80
+ {"tag": "single_object", "include": [{"class": "kite", "count": 1}], "prompt": "a photo of a kite"}
81
+ {"tag": "two_object", "include": [{"class": "bench", "count": 1}, {"class": "sports ball", "count": 1}], "prompt": "a photo of a bench and a sports ball"}
82
+ {"tag": "two_object", "include": [{"class": "toothbrush", "count": 1}, {"class": "snowboard", "count": 1}], "prompt": "a photo of a toothbrush and a snowboard"}
83
+ {"tag": "two_object", "include": [{"class": "toaster", "count": 1}, {"class": "oven", "count": 1}], "prompt": "a photo of a toaster and an oven"}
84
+ {"tag": "two_object", "include": [{"class": "broccoli", "count": 1}, {"class": "vase", "count": 1}], "prompt": "a photo of a broccoli and a vase"}
85
+ {"tag": "two_object", "include": [{"class": "tennis racket", "count": 1}, {"class": "wine glass", "count": 1}], "prompt": "a photo of a tennis racket and a wine glass"}
86
+ {"tag": "two_object", "include": [{"class": "fork", "count": 1}, {"class": "knife", "count": 1}], "prompt": "a photo of a fork and a knife"}
87
+ {"tag": "two_object", "include": [{"class": "hair drier", "count": 1}, {"class": "cake", "count": 1}], "prompt": "a photo of a hair drier and a cake"}
88
+ {"tag": "two_object", "include": [{"class": "horse", "count": 1}, {"class": "giraffe", "count": 1}], "prompt": "a photo of a horse and a giraffe"}
89
+ {"tag": "two_object", "include": [{"class": "horse", "count": 1}, {"class": "computer keyboard", "count": 1}], "prompt": "a photo of a horse and a computer keyboard"}
90
+ {"tag": "two_object", "include": [{"class": "toothbrush", "count": 1}, {"class": "carrot", "count": 1}], "prompt": "a photo of a toothbrush and a carrot"}
91
+ {"tag": "two_object", "include": [{"class": "cake", "count": 1}, {"class": "zebra", "count": 1}], "prompt": "a photo of a cake and a zebra"}
92
+ {"tag": "two_object", "include": [{"class": "hair drier", "count": 1}, {"class": "bear", "count": 1}], "prompt": "a photo of a hair drier and a bear"}
93
+ {"tag": "two_object", "include": [{"class": "knife", "count": 1}, {"class": "zebra", "count": 1}], "prompt": "a photo of a knife and a zebra"}
94
+ {"tag": "two_object", "include": [{"class": "couch", "count": 1}, {"class": "wine glass", "count": 1}], "prompt": "a photo of a couch and a wine glass"}
95
+ {"tag": "two_object", "include": [{"class": "frisbee", "count": 1}, {"class": "vase", "count": 1}], "prompt": "a photo of a frisbee and a vase"}
96
+ {"tag": "two_object", "include": [{"class": "book", "count": 1}, {"class": "laptop", "count": 1}], "prompt": "a photo of a book and a laptop"}
97
+ {"tag": "two_object", "include": [{"class": "dining table", "count": 1}, {"class": "bear", "count": 1}], "prompt": "a photo of a dining table and a bear"}
98
+ {"tag": "two_object", "include": [{"class": "frisbee", "count": 1}, {"class": "couch", "count": 1}], "prompt": "a photo of a frisbee and a couch"}
99
+ {"tag": "two_object", "include": [{"class": "couch", "count": 1}, {"class": "horse", "count": 1}], "prompt": "a photo of a couch and a horse"}
100
+ {"tag": "two_object", "include": [{"class": "toilet", "count": 1}, {"class": "computer mouse", "count": 1}], "prompt": "a photo of a toilet and a computer mouse"}
101
+ {"tag": "two_object", "include": [{"class": "bottle", "count": 1}, {"class": "refrigerator", "count": 1}], "prompt": "a photo of a bottle and a refrigerator"}
102
+ {"tag": "two_object", "include": [{"class": "potted plant", "count": 1}, {"class": "backpack", "count": 1}], "prompt": "a photo of a potted plant and a backpack"}
103
+ {"tag": "two_object", "include": [{"class": "skateboard", "count": 1}, {"class": "cake", "count": 1}], "prompt": "a photo of a skateboard and a cake"}
104
+ {"tag": "two_object", "include": [{"class": "broccoli", "count": 1}, {"class": "parking meter", "count": 1}], "prompt": "a photo of a broccoli and a parking meter"}
105
+ {"tag": "two_object", "include": [{"class": "zebra", "count": 1}, {"class": "bed", "count": 1}], "prompt": "a photo of a zebra and a bed"}
106
+ {"tag": "two_object", "include": [{"class": "oven", "count": 1}, {"class": "bed", "count": 1}], "prompt": "a photo of an oven and a bed"}
107
+ {"tag": "two_object", "include": [{"class": "baseball bat", "count": 1}, {"class": "fork", "count": 1}], "prompt": "a photo of a baseball bat and a fork"}
108
+ {"tag": "two_object", "include": [{"class": "vase", "count": 1}, {"class": "spoon", "count": 1}], "prompt": "a photo of a vase and a spoon"}
109
+ {"tag": "two_object", "include": [{"class": "skateboard", "count": 1}, {"class": "sink", "count": 1}], "prompt": "a photo of a skateboard and a sink"}
110
+ {"tag": "two_object", "include": [{"class": "pizza", "count": 1}, {"class": "bench", "count": 1}], "prompt": "a photo of a pizza and a bench"}
111
+ {"tag": "two_object", "include": [{"class": "bowl", "count": 1}, {"class": "pizza", "count": 1}], "prompt": "a photo of a bowl and a pizza"}
112
+ {"tag": "two_object", "include": [{"class": "tennis racket", "count": 1}, {"class": "bird", "count": 1}], "prompt": "a photo of a tennis racket and a bird"}
113
+ {"tag": "two_object", "include": [{"class": "wine glass", "count": 1}, {"class": "bear", "count": 1}], "prompt": "a photo of a wine glass and a bear"}
114
+ {"tag": "two_object", "include": [{"class": "fork", "count": 1}, {"class": "book", "count": 1}], "prompt": "a photo of a fork and a book"}
115
+ {"tag": "two_object", "include": [{"class": "scissors", "count": 1}, {"class": "bowl", "count": 1}], "prompt": "a photo of a scissors and a bowl"}
116
+ {"tag": "two_object", "include": [{"class": "laptop", "count": 1}, {"class": "carrot", "count": 1}], "prompt": "a photo of a laptop and a carrot"}
117
+ {"tag": "two_object", "include": [{"class": "stop sign", "count": 1}, {"class": "bottle", "count": 1}], "prompt": "a photo of a stop sign and a bottle"}
118
+ {"tag": "two_object", "include": [{"class": "microwave", "count": 1}, {"class": "truck", "count": 1}], "prompt": "a photo of a microwave and a truck"}
119
+ {"tag": "two_object", "include": [{"class": "person", "count": 1}, {"class": "bear", "count": 1}], "prompt": "a photo of a person and a bear"}
120
+ {"tag": "two_object", "include": [{"class": "frisbee", "count": 1}, {"class": "cell phone", "count": 1}], "prompt": "a photo of a frisbee and a cell phone"}
121
+ {"tag": "two_object", "include": [{"class": "parking meter", "count": 1}, {"class": "teddy bear", "count": 1}], "prompt": "a photo of a parking meter and a teddy bear"}
122
+ {"tag": "two_object", "include": [{"class": "tennis racket", "count": 1}, {"class": "bicycle", "count": 1}], "prompt": "a photo of a tennis racket and a bicycle"}
123
+ {"tag": "two_object", "include": [{"class": "stop sign", "count": 1}, {"class": "motorcycle", "count": 1}], "prompt": "a photo of a stop sign and a motorcycle"}
124
+ {"tag": "two_object", "include": [{"class": "fire hydrant", "count": 1}, {"class": "tennis racket", "count": 1}], "prompt": "a photo of a fire hydrant and a tennis racket"}
125
+ {"tag": "two_object", "include": [{"class": "scissors", "count": 1}, {"class": "sandwich", "count": 1}], "prompt": "a photo of a scissors and a sandwich"}
126
+ {"tag": "two_object", "include": [{"class": "pizza", "count": 1}, {"class": "book", "count": 1}], "prompt": "a photo of a pizza and a book"}
127
+ {"tag": "two_object", "include": [{"class": "giraffe", "count": 1}, {"class": "computer mouse", "count": 1}], "prompt": "a photo of a giraffe and a computer mouse"}
128
+ {"tag": "two_object", "include": [{"class": "stop sign", "count": 1}, {"class": "toaster", "count": 1}], "prompt": "a photo of a stop sign and a toaster"}
129
+ {"tag": "two_object", "include": [{"class": "computer mouse", "count": 1}, {"class": "zebra", "count": 1}], "prompt": "a photo of a computer mouse and a zebra"}
130
+ {"tag": "two_object", "include": [{"class": "chair", "count": 1}, {"class": "bench", "count": 1}], "prompt": "a photo of a chair and a bench"}
131
+ {"tag": "two_object", "include": [{"class": "tv", "count": 1}, {"class": "carrot", "count": 1}], "prompt": "a photo of a tv and a carrot"}
132
+ {"tag": "two_object", "include": [{"class": "surfboard", "count": 1}, {"class": "suitcase", "count": 1}], "prompt": "a photo of a surfboard and a suitcase"}
133
+ {"tag": "two_object", "include": [{"class": "computer keyboard", "count": 1}, {"class": "laptop", "count": 1}], "prompt": "a photo of a computer keyboard and a laptop"}
134
+ {"tag": "two_object", "include": [{"class": "computer keyboard", "count": 1}, {"class": "microwave", "count": 1}], "prompt": "a photo of a computer keyboard and a microwave"}
135
+ {"tag": "two_object", "include": [{"class": "scissors", "count": 1}, {"class": "bird", "count": 1}], "prompt": "a photo of a scissors and a bird"}
136
+ {"tag": "two_object", "include": [{"class": "person", "count": 1}, {"class": "snowboard", "count": 1}], "prompt": "a photo of a person and a snowboard"}
137
+ {"tag": "two_object", "include": [{"class": "cow", "count": 1}, {"class": "horse", "count": 1}], "prompt": "a photo of a cow and a horse"}
138
+ {"tag": "two_object", "include": [{"class": "handbag", "count": 1}, {"class": "refrigerator", "count": 1}], "prompt": "a photo of a handbag and a refrigerator"}
139
+ {"tag": "two_object", "include": [{"class": "chair", "count": 1}, {"class": "laptop", "count": 1}], "prompt": "a photo of a chair and a laptop"}
140
+ {"tag": "two_object", "include": [{"class": "toothbrush", "count": 1}, {"class": "bench", "count": 1}], "prompt": "a photo of a toothbrush and a bench"}
141
+ {"tag": "two_object", "include": [{"class": "book", "count": 1}, {"class": "baseball bat", "count": 1}], "prompt": "a photo of a book and a baseball bat"}
142
+ {"tag": "two_object", "include": [{"class": "horse", "count": 1}, {"class": "train", "count": 1}], "prompt": "a photo of a horse and a train"}
143
+ {"tag": "two_object", "include": [{"class": "bench", "count": 1}, {"class": "vase", "count": 1}], "prompt": "a photo of a bench and a vase"}
144
+ {"tag": "two_object", "include": [{"class": "traffic light", "count": 1}, {"class": "backpack", "count": 1}], "prompt": "a photo of a traffic light and a backpack"}
145
+ {"tag": "two_object", "include": [{"class": "sports ball", "count": 1}, {"class": "cow", "count": 1}], "prompt": "a photo of a sports ball and a cow"}
146
+ {"tag": "two_object", "include": [{"class": "computer mouse", "count": 1}, {"class": "spoon", "count": 1}], "prompt": "a photo of a computer mouse and a spoon"}
147
+ {"tag": "two_object", "include": [{"class": "tv", "count": 1}, {"class": "bicycle", "count": 1}], "prompt": "a photo of a tv and a bicycle"}
148
+ {"tag": "two_object", "include": [{"class": "bench", "count": 1}, {"class": "snowboard", "count": 1}], "prompt": "a photo of a bench and a snowboard"}
149
+ {"tag": "two_object", "include": [{"class": "toothbrush", "count": 1}, {"class": "toilet", "count": 1}], "prompt": "a photo of a toothbrush and a toilet"}
150
+ {"tag": "two_object", "include": [{"class": "person", "count": 1}, {"class": "apple", "count": 1}], "prompt": "a photo of a person and an apple"}
151
+ {"tag": "two_object", "include": [{"class": "sink", "count": 1}, {"class": "sports ball", "count": 1}], "prompt": "a photo of a sink and a sports ball"}
152
+ {"tag": "two_object", "include": [{"class": "stop sign", "count": 1}, {"class": "dog", "count": 1}], "prompt": "a photo of a stop sign and a dog"}
153
+ {"tag": "two_object", "include": [{"class": "knife", "count": 1}, {"class": "stop sign", "count": 1}], "prompt": "a photo of a knife and a stop sign"}
154
+ {"tag": "two_object", "include": [{"class": "wine glass", "count": 1}, {"class": "handbag", "count": 1}], "prompt": "a photo of a wine glass and a handbag"}
155
+ {"tag": "two_object", "include": [{"class": "bowl", "count": 1}, {"class": "skis", "count": 1}], "prompt": "a photo of a bowl and a skis"}
156
+ {"tag": "two_object", "include": [{"class": "frisbee", "count": 1}, {"class": "apple", "count": 1}], "prompt": "a photo of a frisbee and an apple"}
157
+ {"tag": "two_object", "include": [{"class": "computer keyboard", "count": 1}, {"class": "cell phone", "count": 1}], "prompt": "a photo of a computer keyboard and a cell phone"}
158
+ {"tag": "two_object", "include": [{"class": "stop sign", "count": 1}, {"class": "fork", "count": 1}], "prompt": "a photo of a stop sign and a fork"}
159
+ {"tag": "two_object", "include": [{"class": "potted plant", "count": 1}, {"class": "boat", "count": 1}], "prompt": "a photo of a potted plant and a boat"}
160
+ {"tag": "two_object", "include": [{"class": "tv", "count": 1}, {"class": "cell phone", "count": 1}], "prompt": "a photo of a tv and a cell phone"}
161
+ {"tag": "two_object", "include": [{"class": "tie", "count": 1}, {"class": "broccoli", "count": 1}], "prompt": "a photo of a tie and a broccoli"}
162
+ {"tag": "two_object", "include": [{"class": "potted plant", "count": 1}, {"class": "donut", "count": 1}], "prompt": "a photo of a potted plant and a donut"}
163
+ {"tag": "two_object", "include": [{"class": "person", "count": 1}, {"class": "sink", "count": 1}], "prompt": "a photo of a person and a sink"}
164
+ {"tag": "two_object", "include": [{"class": "couch", "count": 1}, {"class": "snowboard", "count": 1}], "prompt": "a photo of a couch and a snowboard"}
165
+ {"tag": "two_object", "include": [{"class": "fork", "count": 1}, {"class": "baseball glove", "count": 1}], "prompt": "a photo of a fork and a baseball glove"}
166
+ {"tag": "two_object", "include": [{"class": "apple", "count": 1}, {"class": "toothbrush", "count": 1}], "prompt": "a photo of an apple and a toothbrush"}
167
+ {"tag": "two_object", "include": [{"class": "bus", "count": 1}, {"class": "baseball glove", "count": 1}], "prompt": "a photo of a bus and a baseball glove"}
168
+ {"tag": "two_object", "include": [{"class": "person", "count": 1}, {"class": "stop sign", "count": 1}], "prompt": "a photo of a person and a stop sign"}
169
+ {"tag": "two_object", "include": [{"class": "carrot", "count": 1}, {"class": "couch", "count": 1}], "prompt": "a photo of a carrot and a couch"}
170
+ {"tag": "two_object", "include": [{"class": "baseball bat", "count": 1}, {"class": "bear", "count": 1}], "prompt": "a photo of a baseball bat and a bear"}
171
+ {"tag": "two_object", "include": [{"class": "fire hydrant", "count": 1}, {"class": "train", "count": 1}], "prompt": "a photo of a fire hydrant and a train"}
172
+ {"tag": "two_object", "include": [{"class": "baseball glove", "count": 1}, {"class": "carrot", "count": 1}], "prompt": "a photo of a baseball glove and a carrot"}
173
+ {"tag": "two_object", "include": [{"class": "microwave", "count": 1}, {"class": "bench", "count": 1}], "prompt": "a photo of a microwave and a bench"}
174
+ {"tag": "two_object", "include": [{"class": "cake", "count": 1}, {"class": "stop sign", "count": 1}], "prompt": "a photo of a cake and a stop sign"}
175
+ {"tag": "two_object", "include": [{"class": "car", "count": 1}, {"class": "computer mouse", "count": 1}], "prompt": "a photo of a car and a computer mouse"}
176
+ {"tag": "two_object", "include": [{"class": "suitcase", "count": 1}, {"class": "dining table", "count": 1}], "prompt": "a photo of a suitcase and a dining table"}
177
+ {"tag": "two_object", "include": [{"class": "person", "count": 1}, {"class": "traffic light", "count": 1}], "prompt": "a photo of a person and a traffic light"}
178
+ {"tag": "two_object", "include": [{"class": "cell phone", "count": 1}, {"class": "horse", "count": 1}], "prompt": "a photo of a cell phone and a horse"}
179
+ {"tag": "two_object", "include": [{"class": "baseball bat", "count": 1}, {"class": "giraffe", "count": 1}], "prompt": "a photo of a baseball bat and a giraffe"}
180
+ {"tag": "counting", "include": [{"class": "clock", "count": 2}], "exclude": [{"class": "clock", "count": 3}], "prompt": "a photo of two clocks"}
181
+ {"tag": "counting", "include": [{"class": "backpack", "count": 2}], "exclude": [{"class": "backpack", "count": 3}], "prompt": "a photo of two backpacks"}
182
+ {"tag": "counting", "include": [{"class": "handbag", "count": 4}], "exclude": [{"class": "handbag", "count": 5}], "prompt": "a photo of four handbags"}
183
+ {"tag": "counting", "include": [{"class": "frisbee", "count": 2}], "exclude": [{"class": "frisbee", "count": 3}], "prompt": "a photo of two frisbees"}
184
+ {"tag": "counting", "include": [{"class": "sports ball", "count": 3}], "exclude": [{"class": "sports ball", "count": 4}], "prompt": "a photo of three sports balls"}
185
+ {"tag": "counting", "include": [{"class": "bear", "count": 2}], "exclude": [{"class": "bear", "count": 3}], "prompt": "a photo of two bears"}
186
+ {"tag": "counting", "include": [{"class": "tie", "count": 2}], "exclude": [{"class": "tie", "count": 3}], "prompt": "a photo of two ties"}
187
+ {"tag": "counting", "include": [{"class": "sink", "count": 4}], "exclude": [{"class": "sink", "count": 5}], "prompt": "a photo of four sinks"}
188
+ {"tag": "counting", "include": [{"class": "toothbrush", "count": 2}], "exclude": [{"class": "toothbrush", "count": 3}], "prompt": "a photo of two toothbrushs"}
189
+ {"tag": "counting", "include": [{"class": "person", "count": 3}], "exclude": [{"class": "person", "count": 4}], "prompt": "a photo of three persons"}
190
+ {"tag": "counting", "include": [{"class": "tennis racket", "count": 3}], "exclude": [{"class": "tennis racket", "count": 4}], "prompt": "a photo of three tennis rackets"}
191
+ {"tag": "counting", "include": [{"class": "bowl", "count": 4}], "exclude": [{"class": "bowl", "count": 5}], "prompt": "a photo of four bowls"}
192
+ {"tag": "counting", "include": [{"class": "vase", "count": 4}], "exclude": [{"class": "vase", "count": 5}], "prompt": "a photo of four vases"}
193
+ {"tag": "counting", "include": [{"class": "cup", "count": 3}], "exclude": [{"class": "cup", "count": 4}], "prompt": "a photo of three cups"}
194
+ {"tag": "counting", "include": [{"class": "computer keyboard", "count": 4}], "exclude": [{"class": "computer keyboard", "count": 5}], "prompt": "a photo of four computer keyboards"}
195
+ {"tag": "counting", "include": [{"class": "sink", "count": 3}], "exclude": [{"class": "sink", "count": 4}], "prompt": "a photo of three sinks"}
196
+ {"tag": "counting", "include": [{"class": "oven", "count": 2}], "exclude": [{"class": "oven", "count": 3}], "prompt": "a photo of two ovens"}
197
+ {"tag": "counting", "include": [{"class": "toilet", "count": 2}], "exclude": [{"class": "toilet", "count": 3}], "prompt": "a photo of two toilets"}
198
+ {"tag": "counting", "include": [{"class": "bicycle", "count": 2}], "exclude": [{"class": "bicycle", "count": 3}], "prompt": "a photo of two bicycles"}
199
+ {"tag": "counting", "include": [{"class": "train", "count": 2}], "exclude": [{"class": "train", "count": 3}], "prompt": "a photo of two trains"}
200
+ {"tag": "counting", "include": [{"class": "orange", "count": 3}], "exclude": [{"class": "orange", "count": 4}], "prompt": "a photo of three oranges"}
201
+ {"tag": "counting", "include": [{"class": "bus", "count": 3}], "exclude": [{"class": "bus", "count": 4}], "prompt": "a photo of three buses"}
202
+ {"tag": "counting", "include": [{"class": "handbag", "count": 3}], "exclude": [{"class": "handbag", "count": 4}], "prompt": "a photo of three handbags"}
203
+ {"tag": "counting", "include": [{"class": "snowboard", "count": 3}], "exclude": [{"class": "snowboard", "count": 4}], "prompt": "a photo of three snowboards"}
204
+ {"tag": "counting", "include": [{"class": "snowboard", "count": 2}], "exclude": [{"class": "snowboard", "count": 3}], "prompt": "a photo of two snowboards"}
205
+ {"tag": "counting", "include": [{"class": "dog", "count": 4}], "exclude": [{"class": "dog", "count": 5}], "prompt": "a photo of four dogs"}
206
+ {"tag": "counting", "include": [{"class": "apple", "count": 3}], "exclude": [{"class": "apple", "count": 4}], "prompt": "a photo of three apples"}
207
+ {"tag": "counting", "include": [{"class": "sheep", "count": 2}], "exclude": [{"class": "sheep", "count": 3}], "prompt": "a photo of two sheeps"}
208
+ {"tag": "counting", "include": [{"class": "hot dog", "count": 3}], "exclude": [{"class": "hot dog", "count": 4}], "prompt": "a photo of three hot dogs"}
209
+ {"tag": "counting", "include": [{"class": "zebra", "count": 3}], "exclude": [{"class": "zebra", "count": 4}], "prompt": "a photo of three zebras"}
210
+ {"tag": "counting", "include": [{"class": "kite", "count": 3}], "exclude": [{"class": "kite", "count": 4}], "prompt": "a photo of three kites"}
211
+ {"tag": "counting", "include": [{"class": "apple", "count": 4}], "exclude": [{"class": "apple", "count": 5}], "prompt": "a photo of four apples"}
212
+ {"tag": "counting", "include": [{"class": "cell phone", "count": 3}], "exclude": [{"class": "cell phone", "count": 4}], "prompt": "a photo of three cell phones"}
213
+ {"tag": "counting", "include": [{"class": "baseball glove", "count": 4}], "exclude": [{"class": "baseball glove", "count": 5}], "prompt": "a photo of four baseball gloves"}
214
+ {"tag": "counting", "include": [{"class": "computer keyboard", "count": 3}], "exclude": [{"class": "computer keyboard", "count": 4}], "prompt": "a photo of three computer keyboards"}
215
+ {"tag": "counting", "include": [{"class": "bed", "count": 2}], "exclude": [{"class": "bed", "count": 3}], "prompt": "a photo of two beds"}
216
+ {"tag": "counting", "include": [{"class": "tv remote", "count": 2}], "exclude": [{"class": "tv remote", "count": 3}], "prompt": "a photo of two tv remotes"}
217
+ {"tag": "counting", "include": [{"class": "fire hydrant", "count": 3}], "exclude": [{"class": "fire hydrant", "count": 4}], "prompt": "a photo of three fire hydrants"}
218
+ {"tag": "counting", "include": [{"class": "book", "count": 3}], "exclude": [{"class": "book", "count": 4}], "prompt": "a photo of three books"}
219
+ {"tag": "counting", "include": [{"class": "giraffe", "count": 4}], "exclude": [{"class": "giraffe", "count": 5}], "prompt": "a photo of four giraffes"}
220
+ {"tag": "counting", "include": [{"class": "vase", "count": 2}], "exclude": [{"class": "vase", "count": 3}], "prompt": "a photo of two vases"}
221
+ {"tag": "counting", "include": [{"class": "donut", "count": 4}], "exclude": [{"class": "donut", "count": 5}], "prompt": "a photo of four donuts"}
222
+ {"tag": "counting", "include": [{"class": "chair", "count": 4}], "exclude": [{"class": "chair", "count": 5}], "prompt": "a photo of four chairs"}
223
+ {"tag": "counting", "include": [{"class": "baseball bat", "count": 3}], "exclude": [{"class": "baseball bat", "count": 4}], "prompt": "a photo of three baseball bats"}
224
+ {"tag": "counting", "include": [{"class": "stop sign", "count": 4}], "exclude": [{"class": "stop sign", "count": 5}], "prompt": "a photo of four stop signs"}
225
+ {"tag": "counting", "include": [{"class": "pizza", "count": 2}], "exclude": [{"class": "pizza", "count": 3}], "prompt": "a photo of two pizzas"}
226
+ {"tag": "counting", "include": [{"class": "refrigerator", "count": 3}], "exclude": [{"class": "refrigerator", "count": 4}], "prompt": "a photo of three refrigerators"}
227
+ {"tag": "counting", "include": [{"class": "fire hydrant", "count": 2}], "exclude": [{"class": "fire hydrant", "count": 3}], "prompt": "a photo of two fire hydrants"}
228
+ {"tag": "counting", "include": [{"class": "giraffe", "count": 3}], "exclude": [{"class": "giraffe", "count": 4}], "prompt": "a photo of three giraffes"}
229
+ {"tag": "counting", "include": [{"class": "tv", "count": 4}], "exclude": [{"class": "tv", "count": 5}], "prompt": "a photo of four tvs"}
230
+ {"tag": "counting", "include": [{"class": "wine glass", "count": 3}], "exclude": [{"class": "wine glass", "count": 4}], "prompt": "a photo of three wine glasses"}
231
+ {"tag": "counting", "include": [{"class": "broccoli", "count": 4}], "exclude": [{"class": "broccoli", "count": 5}], "prompt": "a photo of four broccolis"}
232
+ {"tag": "counting", "include": [{"class": "truck", "count": 3}], "exclude": [{"class": "truck", "count": 4}], "prompt": "a photo of three trucks"}
233
+ {"tag": "counting", "include": [{"class": "truck", "count": 2}], "exclude": [{"class": "truck", "count": 3}], "prompt": "a photo of two trucks"}
234
+ {"tag": "counting", "include": [{"class": "carrot", "count": 2}], "exclude": [{"class": "carrot", "count": 3}], "prompt": "a photo of two carrots"}
235
+ {"tag": "counting", "include": [{"class": "sandwich", "count": 2}], "exclude": [{"class": "sandwich", "count": 3}], "prompt": "a photo of two sandwichs"}
236
+ {"tag": "counting", "include": [{"class": "traffic light", "count": 4}], "exclude": [{"class": "traffic light", "count": 5}], "prompt": "a photo of four traffic lights"}
237
+ {"tag": "counting", "include": [{"class": "clock", "count": 4}], "exclude": [{"class": "clock", "count": 5}], "prompt": "a photo of four clocks"}
238
+ {"tag": "counting", "include": [{"class": "car", "count": 2}], "exclude": [{"class": "car", "count": 3}], "prompt": "a photo of two cars"}
239
+ {"tag": "counting", "include": [{"class": "banana", "count": 2}], "exclude": [{"class": "banana", "count": 3}], "prompt": "a photo of two bananas"}
240
+ {"tag": "counting", "include": [{"class": "wine glass", "count": 2}], "exclude": [{"class": "wine glass", "count": 3}], "prompt": "a photo of two wine glasses"}
241
+ {"tag": "counting", "include": [{"class": "pizza", "count": 3}], "exclude": [{"class": "pizza", "count": 4}], "prompt": "a photo of three pizzas"}
242
+ {"tag": "counting", "include": [{"class": "knife", "count": 4}], "exclude": [{"class": "knife", "count": 5}], "prompt": "a photo of four knifes"}
243
+ {"tag": "counting", "include": [{"class": "suitcase", "count": 3}], "exclude": [{"class": "suitcase", "count": 4}], "prompt": "a photo of three suitcases"}
244
+ {"tag": "counting", "include": [{"class": "zebra", "count": 4}], "exclude": [{"class": "zebra", "count": 5}], "prompt": "a photo of four zebras"}
245
+ {"tag": "counting", "include": [{"class": "teddy bear", "count": 2}], "exclude": [{"class": "teddy bear", "count": 3}], "prompt": "a photo of two teddy bears"}
246
+ {"tag": "counting", "include": [{"class": "skateboard", "count": 4}], "exclude": [{"class": "skateboard", "count": 5}], "prompt": "a photo of four skateboards"}
247
+ {"tag": "counting", "include": [{"class": "hot dog", "count": 4}], "exclude": [{"class": "hot dog", "count": 5}], "prompt": "a photo of four hot dogs"}
248
+ {"tag": "counting", "include": [{"class": "bird", "count": 3}], "exclude": [{"class": "bird", "count": 4}], "prompt": "a photo of three birds"}
249
+ {"tag": "counting", "include": [{"class": "boat", "count": 4}], "exclude": [{"class": "boat", "count": 5}], "prompt": "a photo of four boats"}
250
+ {"tag": "counting", "include": [{"class": "microwave", "count": 4}], "exclude": [{"class": "microwave", "count": 5}], "prompt": "a photo of four microwaves"}
251
+ {"tag": "counting", "include": [{"class": "hair drier", "count": 2}], "exclude": [{"class": "hair drier", "count": 3}], "prompt": "a photo of two hair driers"}
252
+ {"tag": "counting", "include": [{"class": "laptop", "count": 3}], "exclude": [{"class": "laptop", "count": 4}], "prompt": "a photo of three laptops"}
253
+ {"tag": "counting", "include": [{"class": "cow", "count": 3}], "exclude": [{"class": "cow", "count": 4}], "prompt": "a photo of three cows"}
254
+ {"tag": "counting", "include": [{"class": "parking meter", "count": 2}], "exclude": [{"class": "parking meter", "count": 3}], "prompt": "a photo of two parking meters"}
255
+ {"tag": "counting", "include": [{"class": "bench", "count": 4}], "exclude": [{"class": "bench", "count": 5}], "prompt": "a photo of four benchs"}
256
+ {"tag": "counting", "include": [{"class": "bench", "count": 3}], "exclude": [{"class": "bench", "count": 4}], "prompt": "a photo of three benchs"}
257
+ {"tag": "counting", "include": [{"class": "frisbee", "count": 4}], "exclude": [{"class": "frisbee", "count": 5}], "prompt": "a photo of four frisbees"}
258
+ {"tag": "counting", "include": [{"class": "book", "count": 4}], "exclude": [{"class": "book", "count": 5}], "prompt": "a photo of four books"}
259
+ {"tag": "counting", "include": [{"class": "bus", "count": 4}], "exclude": [{"class": "bus", "count": 5}], "prompt": "a photo of four buses"}
260
+ {"tag": "colors", "include": [{"class": "fire hydrant", "count": 1, "color": "blue"}], "prompt": "a photo of a blue fire hydrant"}
261
+ {"tag": "colors", "include": [{"class": "car", "count": 1, "color": "pink"}], "prompt": "a photo of a pink car"}
262
+ {"tag": "colors", "include": [{"class": "cup", "count": 1, "color": "purple"}], "prompt": "a photo of a purple cup"}
263
+ {"tag": "colors", "include": [{"class": "cow", "count": 1, "color": "blue"}], "prompt": "a photo of a blue cow"}
264
+ {"tag": "colors", "include": [{"class": "boat", "count": 1, "color": "yellow"}], "prompt": "a photo of a yellow boat"}
265
+ {"tag": "colors", "include": [{"class": "umbrella", "count": 1, "color": "blue"}], "prompt": "a photo of a blue umbrella"}
266
+ {"tag": "colors", "include": [{"class": "elephant", "count": 1, "color": "blue"}], "prompt": "a photo of a blue elephant"}
267
+ {"tag": "colors", "include": [{"class": "elephant", "count": 1, "color": "yellow"}], "prompt": "a photo of a yellow elephant"}
268
+ {"tag": "colors", "include": [{"class": "bicycle", "count": 1, "color": "red"}], "prompt": "a photo of a red bicycle"}
269
+ {"tag": "colors", "include": [{"class": "suitcase", "count": 1, "color": "purple"}], "prompt": "a photo of a purple suitcase"}
270
+ {"tag": "colors", "include": [{"class": "hair drier", "count": 1, "color": "purple"}], "prompt": "a photo of a purple hair drier"}
271
+ {"tag": "colors", "include": [{"class": "sandwich", "count": 1, "color": "white"}], "prompt": "a photo of a white sandwich"}
272
+ {"tag": "colors", "include": [{"class": "elephant", "count": 1, "color": "purple"}], "prompt": "a photo of a purple elephant"}
273
+ {"tag": "colors", "include": [{"class": "microwave", "count": 1, "color": "green"}], "prompt": "a photo of a green microwave"}
274
+ {"tag": "colors", "include": [{"class": "zebra", "count": 1, "color": "red"}], "prompt": "a photo of a red zebra"}
275
+ {"tag": "colors", "include": [{"class": "apple", "count": 1, "color": "red"}], "prompt": "a photo of a red apple"}
276
+ {"tag": "colors", "include": [{"class": "tv remote", "count": 1, "color": "yellow"}], "prompt": "a photo of a yellow tv remote"}
277
+ {"tag": "colors", "include": [{"class": "toilet", "count": 1, "color": "blue"}], "prompt": "a photo of a blue toilet"}
278
+ {"tag": "colors", "include": [{"class": "orange", "count": 1, "color": "orange"}], "prompt": "a photo of an orange orange"}
279
+ {"tag": "colors", "include": [{"class": "donut", "count": 1, "color": "black"}], "prompt": "a photo of a black donut"}
280
+ {"tag": "colors", "include": [{"class": "vase", "count": 1, "color": "red"}], "prompt": "a photo of a red vase"}
281
+ {"tag": "colors", "include": [{"class": "pizza", "count": 1, "color": "purple"}], "prompt": "a photo of a purple pizza"}
282
+ {"tag": "colors", "include": [{"class": "skateboard", "count": 1, "color": "pink"}], "prompt": "a photo of a pink skateboard"}
283
+ {"tag": "colors", "include": [{"class": "skateboard", "count": 1, "color": "green"}], "prompt": "a photo of a green skateboard"}
284
+ {"tag": "colors", "include": [{"class": "bear", "count": 1, "color": "purple"}], "prompt": "a photo of a purple bear"}
285
+ {"tag": "colors", "include": [{"class": "chair", "count": 1, "color": "brown"}], "prompt": "a photo of a brown chair"}
286
+ {"tag": "colors", "include": [{"class": "computer keyboard", "count": 1, "color": "brown"}], "prompt": "a photo of a brown computer keyboard"}
287
+ {"tag": "colors", "include": [{"class": "cow", "count": 1, "color": "orange"}], "prompt": "a photo of an orange cow"}
288
+ {"tag": "colors", "include": [{"class": "skis", "count": 1, "color": "brown"}], "prompt": "a photo of a brown skis"}
289
+ {"tag": "colors", "include": [{"class": "kite", "count": 1, "color": "white"}], "prompt": "a photo of a white kite"}
290
+ {"tag": "colors", "include": [{"class": "dog", "count": 1, "color": "red"}], "prompt": "a photo of a red dog"}
291
+ {"tag": "colors", "include": [{"class": "couch", "count": 1, "color": "green"}], "prompt": "a photo of a green couch"}
292
+ {"tag": "colors", "include": [{"class": "airplane", "count": 1, "color": "yellow"}], "prompt": "a photo of a yellow airplane"}
293
+ {"tag": "colors", "include": [{"class": "tv", "count": 1, "color": "orange"}], "prompt": "a photo of an orange tv"}
294
+ {"tag": "colors", "include": [{"class": "scissors", "count": 1, "color": "white"}], "prompt": "a photo of a white scissors"}
295
+ {"tag": "colors", "include": [{"class": "cell phone", "count": 1, "color": "pink"}], "prompt": "a photo of a pink cell phone"}
296
+ {"tag": "colors", "include": [{"class": "surfboard", "count": 1, "color": "green"}], "prompt": "a photo of a green surfboard"}
297
+ {"tag": "colors", "include": [{"class": "fire hydrant", "count": 1, "color": "white"}], "prompt": "a photo of a white fire hydrant"}
298
+ {"tag": "colors", "include": [{"class": "bicycle", "count": 1, "color": "black"}], "prompt": "a photo of a black bicycle"}
299
+ {"tag": "colors", "include": [{"class": "carrot", "count": 1, "color": "purple"}], "prompt": "a photo of a purple carrot"}
300
+ {"tag": "colors", "include": [{"class": "dining table", "count": 1, "color": "black"}], "prompt": "a photo of a black dining table"}
301
+ {"tag": "colors", "include": [{"class": "potted plant", "count": 1, "color": "purple"}], "prompt": "a photo of a purple potted plant"}
302
+ {"tag": "colors", "include": [{"class": "backpack", "count": 1, "color": "purple"}], "prompt": "a photo of a purple backpack"}
303
+ {"tag": "colors", "include": [{"class": "train", "count": 1, "color": "yellow"}], "prompt": "a photo of a yellow train"}
304
+ {"tag": "colors", "include": [{"class": "potted plant", "count": 1, "color": "pink"}], "prompt": "a photo of a pink potted plant"}
305
+ {"tag": "colors", "include": [{"class": "giraffe", "count": 1, "color": "red"}], "prompt": "a photo of a red giraffe"}
306
+ {"tag": "colors", "include": [{"class": "bear", "count": 1, "color": "brown"}], "prompt": "a photo of a brown bear"}
307
+ {"tag": "colors", "include": [{"class": "train", "count": 1, "color": "black"}], "prompt": "a photo of a black train"}
308
+ {"tag": "colors", "include": [{"class": "laptop", "count": 1, "color": "orange"}], "prompt": "a photo of an orange laptop"}
309
+ {"tag": "colors", "include": [{"class": "hot dog", "count": 1, "color": "green"}], "prompt": "a photo of a green hot dog"}
310
+ {"tag": "colors", "include": [{"class": "parking meter", "count": 1, "color": "yellow"}], "prompt": "a photo of a yellow parking meter"}
311
+ {"tag": "colors", "include": [{"class": "potted plant", "count": 1, "color": "red"}], "prompt": "a photo of a red potted plant"}
312
+ {"tag": "colors", "include": [{"class": "traffic light", "count": 1, "color": "green"}], "prompt": "a photo of a green traffic light"}
313
+ {"tag": "colors", "include": [{"class": "tv", "count": 1, "color": "blue"}], "prompt": "a photo of a blue tv"}
314
+ {"tag": "colors", "include": [{"class": "refrigerator", "count": 1, "color": "brown"}], "prompt": "a photo of a brown refrigerator"}
315
+ {"tag": "colors", "include": [{"class": "tv remote", "count": 1, "color": "black"}], "prompt": "a photo of a black tv remote"}
316
+ {"tag": "colors", "include": [{"class": "scissors", "count": 1, "color": "purple"}], "prompt": "a photo of a purple scissors"}
317
+ {"tag": "colors", "include": [{"class": "orange", "count": 1, "color": "yellow"}], "prompt": "a photo of a yellow orange"}
318
+ {"tag": "colors", "include": [{"class": "toaster", "count": 1, "color": "brown"}], "prompt": "a photo of a brown toaster"}
319
+ {"tag": "colors", "include": [{"class": "parking meter", "count": 1, "color": "red"}], "prompt": "a photo of a red parking meter"}
320
+ {"tag": "colors", "include": [{"class": "orange", "count": 1, "color": "brown"}], "prompt": "a photo of a brown orange"}
321
+ {"tag": "colors", "include": [{"class": "clock", "count": 1, "color": "green"}], "prompt": "a photo of a green clock"}
322
+ {"tag": "colors", "include": [{"class": "sheep", "count": 1, "color": "white"}], "prompt": "a photo of a white sheep"}
323
+ {"tag": "colors", "include": [{"class": "oven", "count": 1, "color": "yellow"}], "prompt": "a photo of a yellow oven"}
324
+ {"tag": "colors", "include": [{"class": "vase", "count": 1, "color": "green"}], "prompt": "a photo of a green vase"}
325
+ {"tag": "colors", "include": [{"class": "teddy bear", "count": 1, "color": "black"}], "prompt": "a photo of a black teddy bear"}
326
+ {"tag": "colors", "include": [{"class": "carrot", "count": 1, "color": "yellow"}], "prompt": "a photo of a yellow carrot"}
327
+ {"tag": "colors", "include": [{"class": "hot dog", "count": 1, "color": "black"}], "prompt": "a photo of a black hot dog"}
328
+ {"tag": "colors", "include": [{"class": "scissors", "count": 1, "color": "red"}], "prompt": "a photo of a red scissors"}
329
+ {"tag": "colors", "include": [{"class": "teddy bear", "count": 1, "color": "white"}], "prompt": "a photo of a white teddy bear"}
330
+ {"tag": "colors", "include": [{"class": "skis", "count": 1, "color": "black"}], "prompt": "a photo of a black skis"}
331
+ {"tag": "colors", "include": [{"class": "dining table", "count": 1, "color": "blue"}], "prompt": "a photo of a blue dining table"}
332
+ {"tag": "colors", "include": [{"class": "refrigerator", "count": 1, "color": "black"}], "prompt": "a photo of a black refrigerator"}
333
+ {"tag": "colors", "include": [{"class": "dog", "count": 1, "color": "white"}], "prompt": "a photo of a white dog"}
334
+ {"tag": "colors", "include": [{"class": "scissors", "count": 1, "color": "orange"}], "prompt": "a photo of an orange scissors"}
335
+ {"tag": "colors", "include": [{"class": "cell phone", "count": 1, "color": "red"}], "prompt": "a photo of a red cell phone"}
336
+ {"tag": "colors", "include": [{"class": "orange", "count": 1, "color": "white"}], "prompt": "a photo of a white orange"}
337
+ {"tag": "colors", "include": [{"class": "clock", "count": 1, "color": "blue"}], "prompt": "a photo of a blue clock"}
338
+ {"tag": "colors", "include": [{"class": "carrot", "count": 1, "color": "blue"}], "prompt": "a photo of a blue carrot"}
339
+ {"tag": "colors", "include": [{"class": "motorcycle", "count": 1, "color": "green"}], "prompt": "a photo of a green motorcycle"}
340
+ {"tag": "colors", "include": [{"class": "stop sign", "count": 1, "color": "pink"}], "prompt": "a photo of a pink stop sign"}
341
+ {"tag": "colors", "include": [{"class": "vase", "count": 1, "color": "black"}], "prompt": "a photo of a black vase"}
342
+ {"tag": "colors", "include": [{"class": "backpack", "count": 1, "color": "black"}], "prompt": "a photo of a black backpack"}
343
+ {"tag": "colors", "include": [{"class": "car", "count": 1, "color": "red"}], "prompt": "a photo of a red car"}
344
+ {"tag": "colors", "include": [{"class": "computer mouse", "count": 1, "color": "green"}], "prompt": "a photo of a green computer mouse"}
345
+ {"tag": "colors", "include": [{"class": "backpack", "count": 1, "color": "red"}], "prompt": "a photo of a red backpack"}
346
+ {"tag": "colors", "include": [{"class": "bus", "count": 1, "color": "green"}], "prompt": "a photo of a green bus"}
347
+ {"tag": "colors", "include": [{"class": "toaster", "count": 1, "color": "orange"}], "prompt": "a photo of an orange toaster"}
348
+ {"tag": "colors", "include": [{"class": "fork", "count": 1, "color": "yellow"}], "prompt": "a photo of a yellow fork"}
349
+ {"tag": "colors", "include": [{"class": "parking meter", "count": 1, "color": "pink"}], "prompt": "a photo of a pink parking meter"}
350
+ {"tag": "colors", "include": [{"class": "book", "count": 1, "color": "blue"}], "prompt": "a photo of a blue book"}
351
+ {"tag": "colors", "include": [{"class": "broccoli", "count": 1, "color": "yellow"}], "prompt": "a photo of a yellow broccoli"}
352
+ {"tag": "colors", "include": [{"class": "computer mouse", "count": 1, "color": "orange"}], "prompt": "a photo of an orange computer mouse"}
353
+ {"tag": "colors", "include": [{"class": "cake", "count": 1, "color": "red"}], "prompt": "a photo of a red cake"}
354
+ {"tag": "position", "include": [{"class": "teddy bear", "count": 1}, {"class": "dog", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a dog right of a teddy bear"}
355
+ {"tag": "position", "include": [{"class": "kite", "count": 1}, {"class": "wine glass", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a wine glass above a kite"}
356
+ {"tag": "position", "include": [{"class": "cup", "count": 1}, {"class": "couch", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a couch below a cup"}
357
+ {"tag": "position", "include": [{"class": "cow", "count": 1}, {"class": "laptop", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a laptop left of a cow"}
358
+ {"tag": "position", "include": [{"class": "hair drier", "count": 1}, {"class": "fork", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a fork above a hair drier"}
359
+ {"tag": "position", "include": [{"class": "baseball bat", "count": 1}, {"class": "tie", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a tie right of a baseball bat"}
360
+ {"tag": "position", "include": [{"class": "fork", "count": 1}, {"class": "stop sign", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a stop sign above a fork"}
361
+ {"tag": "position", "include": [{"class": "skateboard", "count": 1}, {"class": "bird", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a bird below a skateboard"}
362
+ {"tag": "position", "include": [{"class": "tv", "count": 1}, {"class": "apple", "count": 1, "position": ["above", 0]}], "prompt": "a photo of an apple above a tv"}
363
+ {"tag": "position", "include": [{"class": "potted plant", "count": 1}, {"class": "train", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a train above a potted plant"}
364
+ {"tag": "position", "include": [{"class": "refrigerator", "count": 1}, {"class": "truck", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a truck left of a refrigerator"}
365
+ {"tag": "position", "include": [{"class": "cow", "count": 1}, {"class": "tv remote", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a tv remote below a cow"}
366
+ {"tag": "position", "include": [{"class": "train", "count": 1}, {"class": "bottle", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a bottle right of a train"}
367
+ {"tag": "position", "include": [{"class": "cow", "count": 1}, {"class": "dog", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a dog above a cow"}
368
+ {"tag": "position", "include": [{"class": "person", "count": 1}, {"class": "skateboard", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a skateboard above a person"}
369
+ {"tag": "position", "include": [{"class": "umbrella", "count": 1}, {"class": "baseball glove", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a baseball glove below an umbrella"}
370
+ {"tag": "position", "include": [{"class": "oven", "count": 1}, {"class": "dining table", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a dining table right of an oven"}
371
+ {"tag": "position", "include": [{"class": "suitcase", "count": 1}, {"class": "hot dog", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a hot dog left of a suitcase"}
372
+ {"tag": "position", "include": [{"class": "toothbrush", "count": 1}, {"class": "bus", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a bus below a toothbrush"}
373
+ {"tag": "position", "include": [{"class": "sandwich", "count": 1}, {"class": "backpack", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a backpack right of a sandwich"}
374
+ {"tag": "position", "include": [{"class": "baseball bat", "count": 1}, {"class": "cake", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a cake below a baseball bat"}
375
+ {"tag": "position", "include": [{"class": "tie", "count": 1}, {"class": "dog", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a dog right of a tie"}
376
+ {"tag": "position", "include": [{"class": "boat", "count": 1}, {"class": "suitcase", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a suitcase right of a boat"}
377
+ {"tag": "position", "include": [{"class": "clock", "count": 1}, {"class": "bear", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a bear above a clock"}
378
+ {"tag": "position", "include": [{"class": "umbrella", "count": 1}, {"class": "tv remote", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a tv remote left of an umbrella"}
379
+ {"tag": "position", "include": [{"class": "umbrella", "count": 1}, {"class": "sports ball", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a sports ball left of an umbrella"}
380
+ {"tag": "position", "include": [{"class": "dining table", "count": 1}, {"class": "train", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a train right of a dining table"}
381
+ {"tag": "position", "include": [{"class": "elephant", "count": 1}, {"class": "hair drier", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a hair drier below an elephant"}
382
+ {"tag": "position", "include": [{"class": "spoon", "count": 1}, {"class": "tennis racket", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a tennis racket right of a spoon"}
383
+ {"tag": "position", "include": [{"class": "hot dog", "count": 1}, {"class": "wine glass", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a wine glass right of a hot dog"}
384
+ {"tag": "position", "include": [{"class": "bench", "count": 1}, {"class": "computer mouse", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a computer mouse left of a bench"}
385
+ {"tag": "position", "include": [{"class": "orange", "count": 1}, {"class": "carrot", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a carrot left of an orange"}
386
+ {"tag": "position", "include": [{"class": "toothbrush", "count": 1}, {"class": "kite", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a kite above a toothbrush"}
387
+ {"tag": "position", "include": [{"class": "traffic light", "count": 1}, {"class": "toaster", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a toaster below a traffic light"}
388
+ {"tag": "position", "include": [{"class": "baseball glove", "count": 1}, {"class": "cat", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a cat below a baseball glove"}
389
+ {"tag": "position", "include": [{"class": "zebra", "count": 1}, {"class": "skis", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a skis right of a zebra"}
390
+ {"tag": "position", "include": [{"class": "chair", "count": 1}, {"class": "stop sign", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a stop sign above a chair"}
391
+ {"tag": "position", "include": [{"class": "parking meter", "count": 1}, {"class": "stop sign", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a stop sign above a parking meter"}
392
+ {"tag": "position", "include": [{"class": "skateboard", "count": 1}, {"class": "hot dog", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a hot dog right of a skateboard"}
393
+ {"tag": "position", "include": [{"class": "computer keyboard", "count": 1}, {"class": "pizza", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a pizza below a computer keyboard"}
394
+ {"tag": "position", "include": [{"class": "toilet", "count": 1}, {"class": "hair drier", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a hair drier left of a toilet"}
395
+ {"tag": "position", "include": [{"class": "stop sign", "count": 1}, {"class": "cow", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a cow left of a stop sign"}
396
+ {"tag": "position", "include": [{"class": "skis", "count": 1}, {"class": "suitcase", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a suitcase above a skis"}
397
+ {"tag": "position", "include": [{"class": "laptop", "count": 1}, {"class": "book", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a book above a laptop"}
398
+ {"tag": "position", "include": [{"class": "pizza", "count": 1}, {"class": "toothbrush", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a toothbrush below a pizza"}
399
+ {"tag": "position", "include": [{"class": "kite", "count": 1}, {"class": "toilet", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a toilet left of a kite"}
400
+ {"tag": "position", "include": [{"class": "sink", "count": 1}, {"class": "tie", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a tie above a sink"}
401
+ {"tag": "position", "include": [{"class": "couch", "count": 1}, {"class": "bird", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a bird left of a couch"}
402
+ {"tag": "position", "include": [{"class": "sports ball", "count": 1}, {"class": "bed", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a bed right of a sports ball"}
403
+ {"tag": "position", "include": [{"class": "surfboard", "count": 1}, {"class": "elephant", "count": 1, "position": ["below", 0]}], "prompt": "a photo of an elephant below a surfboard"}
404
+ {"tag": "position", "include": [{"class": "motorcycle", "count": 1}, {"class": "frisbee", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a frisbee right of a motorcycle"}
405
+ {"tag": "position", "include": [{"class": "fire hydrant", "count": 1}, {"class": "vase", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a vase above a fire hydrant"}
406
+ {"tag": "position", "include": [{"class": "elephant", "count": 1}, {"class": "zebra", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a zebra left of an elephant"}
407
+ {"tag": "position", "include": [{"class": "bear", "count": 1}, {"class": "bench", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a bench left of a bear"}
408
+ {"tag": "position", "include": [{"class": "bench", "count": 1}, {"class": "donut", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a donut right of a bench"}
409
+ {"tag": "position", "include": [{"class": "horse", "count": 1}, {"class": "frisbee", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a frisbee below a horse"}
410
+ {"tag": "position", "include": [{"class": "snowboard", "count": 1}, {"class": "computer keyboard", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a computer keyboard above a snowboard"}
411
+ {"tag": "position", "include": [{"class": "cow", "count": 1}, {"class": "tv", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a tv below a cow"}
412
+ {"tag": "position", "include": [{"class": "horse", "count": 1}, {"class": "elephant", "count": 1, "position": ["below", 0]}], "prompt": "a photo of an elephant below a horse"}
413
+ {"tag": "position", "include": [{"class": "banana", "count": 1}, {"class": "suitcase", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a suitcase left of a banana"}
414
+ {"tag": "position", "include": [{"class": "airplane", "count": 1}, {"class": "train", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a train below an airplane"}
415
+ {"tag": "position", "include": [{"class": "backpack", "count": 1}, {"class": "cat", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a cat below a backpack"}
416
+ {"tag": "position", "include": [{"class": "cake", "count": 1}, {"class": "backpack", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a backpack below a cake"}
417
+ {"tag": "position", "include": [{"class": "knife", "count": 1}, {"class": "sandwich", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a sandwich below a knife"}
418
+ {"tag": "position", "include": [{"class": "parking meter", "count": 1}, {"class": "bicycle", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a bicycle above a parking meter"}
419
+ {"tag": "position", "include": [{"class": "suitcase", "count": 1}, {"class": "knife", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a knife right of a suitcase"}
420
+ {"tag": "position", "include": [{"class": "knife", "count": 1}, {"class": "hot dog", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a hot dog above a knife"}
421
+ {"tag": "position", "include": [{"class": "parking meter", "count": 1}, {"class": "zebra", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a zebra right of a parking meter"}
422
+ {"tag": "position", "include": [{"class": "zebra", "count": 1}, {"class": "chair", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a chair left of a zebra"}
423
+ {"tag": "position", "include": [{"class": "airplane", "count": 1}, {"class": "cow", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a cow below an airplane"}
424
+ {"tag": "position", "include": [{"class": "umbrella", "count": 1}, {"class": "cup", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a cup left of an umbrella"}
425
+ {"tag": "position", "include": [{"class": "computer keyboard", "count": 1}, {"class": "zebra", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a zebra below a computer keyboard"}
426
+ {"tag": "position", "include": [{"class": "broccoli", "count": 1}, {"class": "zebra", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a zebra below a broccoli"}
427
+ {"tag": "position", "include": [{"class": "sports ball", "count": 1}, {"class": "laptop", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a laptop below a sports ball"}
428
+ {"tag": "position", "include": [{"class": "baseball bat", "count": 1}, {"class": "truck", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a truck left of a baseball bat"}
429
+ {"tag": "position", "include": [{"class": "baseball bat", "count": 1}, {"class": "refrigerator", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a refrigerator above a baseball bat"}
430
+ {"tag": "position", "include": [{"class": "baseball bat", "count": 1}, {"class": "tv", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a tv above a baseball bat"}
431
+ {"tag": "position", "include": [{"class": "bear", "count": 1}, {"class": "baseball glove", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a baseball glove right of a bear"}
432
+ {"tag": "position", "include": [{"class": "scissors", "count": 1}, {"class": "refrigerator", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a refrigerator below a scissors"}
433
+ {"tag": "position", "include": [{"class": "suitcase", "count": 1}, {"class": "dining table", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a dining table above a suitcase"}
434
+ {"tag": "position", "include": [{"class": "broccoli", "count": 1}, {"class": "parking meter", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a parking meter above a broccoli"}
435
+ {"tag": "position", "include": [{"class": "truck", "count": 1}, {"class": "frisbee", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a frisbee above a truck"}
436
+ {"tag": "position", "include": [{"class": "banana", "count": 1}, {"class": "pizza", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a pizza right of a banana"}
437
+ {"tag": "position", "include": [{"class": "boat", "count": 1}, {"class": "bus", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a bus above a boat"}
438
+ {"tag": "position", "include": [{"class": "tennis racket", "count": 1}, {"class": "cell phone", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a cell phone left of a tennis racket"}
439
+ {"tag": "position", "include": [{"class": "broccoli", "count": 1}, {"class": "horse", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a horse right of a broccoli"}
440
+ {"tag": "position", "include": [{"class": "bottle", "count": 1}, {"class": "broccoli", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a broccoli above a bottle"}
441
+ {"tag": "position", "include": [{"class": "horse", "count": 1}, {"class": "vase", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a vase right of a horse"}
442
+ {"tag": "position", "include": [{"class": "spoon", "count": 1}, {"class": "bear", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a bear above a spoon"}
443
+ {"tag": "position", "include": [{"class": "bed", "count": 1}, {"class": "zebra", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a zebra right of a bed"}
444
+ {"tag": "position", "include": [{"class": "laptop", "count": 1}, {"class": "cow", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a cow right of a laptop"}
445
+ {"tag": "position", "include": [{"class": "frisbee", "count": 1}, {"class": "bed", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a bed right of a frisbee"}
446
+ {"tag": "position", "include": [{"class": "motorcycle", "count": 1}, {"class": "tie", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a tie right of a motorcycle"}
447
+ {"tag": "position", "include": [{"class": "tv", "count": 1}, {"class": "laptop", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a laptop right of a tv"}
448
+ {"tag": "position", "include": [{"class": "chair", "count": 1}, {"class": "cell phone", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a cell phone right of a chair"}
449
+ {"tag": "position", "include": [{"class": "potted plant", "count": 1}, {"class": "couch", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a couch below a potted plant"}
450
+ {"tag": "position", "include": [{"class": "tv", "count": 1}, {"class": "clock", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a clock below a tv"}
451
+ {"tag": "position", "include": [{"class": "vase", "count": 1}, {"class": "couch", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a couch below a vase"}
452
+ {"tag": "position", "include": [{"class": "cat", "count": 1}, {"class": "donut", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a donut below a cat"}
453
+ {"tag": "position", "include": [{"class": "toaster", "count": 1}, {"class": "couch", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a couch left of a toaster"}
454
+ {"tag": "color_attr", "include": [{"class": "wine glass", "count": 1, "color": "purple"}, {"class": "apple", "count": 1, "color": "black"}], "prompt": "a photo of a purple wine glass and a black apple"}
455
+ {"tag": "color_attr", "include": [{"class": "bus", "count": 1, "color": "green"}, {"class": "microwave", "count": 1, "color": "purple"}], "prompt": "a photo of a green bus and a purple microwave"}
456
+ {"tag": "color_attr", "include": [{"class": "skis", "count": 1, "color": "green"}, {"class": "airplane", "count": 1, "color": "brown"}], "prompt": "a photo of a green skis and a brown airplane"}
457
+ {"tag": "color_attr", "include": [{"class": "computer keyboard", "count": 1, "color": "yellow"}, {"class": "sink", "count": 1, "color": "black"}], "prompt": "a photo of a yellow computer keyboard and a black sink"}
458
+ {"tag": "color_attr", "include": [{"class": "oven", "count": 1, "color": "pink"}, {"class": "motorcycle", "count": 1, "color": "green"}], "prompt": "a photo of a pink oven and a green motorcycle"}
459
+ {"tag": "color_attr", "include": [{"class": "parking meter", "count": 1, "color": "purple"}, {"class": "laptop", "count": 1, "color": "red"}], "prompt": "a photo of a purple parking meter and a red laptop"}
460
+ {"tag": "color_attr", "include": [{"class": "skateboard", "count": 1, "color": "yellow"}, {"class": "computer mouse", "count": 1, "color": "orange"}], "prompt": "a photo of a yellow skateboard and an orange computer mouse"}
461
+ {"tag": "color_attr", "include": [{"class": "skis", "count": 1, "color": "red"}, {"class": "tie", "count": 1, "color": "brown"}], "prompt": "a photo of a red skis and a brown tie"}
462
+ {"tag": "color_attr", "include": [{"class": "skateboard", "count": 1, "color": "pink"}, {"class": "train", "count": 1, "color": "black"}], "prompt": "a photo of a pink skateboard and a black train"}
463
+ {"tag": "color_attr", "include": [{"class": "handbag", "count": 1, "color": "white"}, {"class": "bed", "count": 1, "color": "purple"}], "prompt": "a photo of a white handbag and a purple bed"}
464
+ {"tag": "color_attr", "include": [{"class": "elephant", "count": 1, "color": "purple"}, {"class": "sports ball", "count": 1, "color": "brown"}], "prompt": "a photo of a purple elephant and a brown sports ball"}
465
+ {"tag": "color_attr", "include": [{"class": "dog", "count": 1, "color": "purple"}, {"class": "dining table", "count": 1, "color": "black"}], "prompt": "a photo of a purple dog and a black dining table"}
466
+ {"tag": "color_attr", "include": [{"class": "dining table", "count": 1, "color": "white"}, {"class": "car", "count": 1, "color": "red"}], "prompt": "a photo of a white dining table and a red car"}
467
+ {"tag": "color_attr", "include": [{"class": "cell phone", "count": 1, "color": "blue"}, {"class": "apple", "count": 1, "color": "green"}], "prompt": "a photo of a blue cell phone and a green apple"}
468
+ {"tag": "color_attr", "include": [{"class": "car", "count": 1, "color": "red"}, {"class": "potted plant", "count": 1, "color": "orange"}], "prompt": "a photo of a red car and an orange potted plant"}
469
+ {"tag": "color_attr", "include": [{"class": "carrot", "count": 1, "color": "brown"}, {"class": "potted plant", "count": 1, "color": "white"}], "prompt": "a photo of a brown carrot and a white potted plant"}
470
+ {"tag": "color_attr", "include": [{"class": "kite", "count": 1, "color": "black"}, {"class": "bear", "count": 1, "color": "green"}], "prompt": "a photo of a black kite and a green bear"}
471
+ {"tag": "color_attr", "include": [{"class": "laptop", "count": 1, "color": "blue"}, {"class": "bear", "count": 1, "color": "brown"}], "prompt": "a photo of a blue laptop and a brown bear"}
472
+ {"tag": "color_attr", "include": [{"class": "teddy bear", "count": 1, "color": "green"}, {"class": "kite", "count": 1, "color": "brown"}], "prompt": "a photo of a green teddy bear and a brown kite"}
473
+ {"tag": "color_attr", "include": [{"class": "stop sign", "count": 1, "color": "yellow"}, {"class": "potted plant", "count": 1, "color": "blue"}], "prompt": "a photo of a yellow stop sign and a blue potted plant"}
474
+ {"tag": "color_attr", "include": [{"class": "snowboard", "count": 1, "color": "orange"}, {"class": "cat", "count": 1, "color": "green"}], "prompt": "a photo of an orange snowboard and a green cat"}
475
+ {"tag": "color_attr", "include": [{"class": "truck", "count": 1, "color": "orange"}, {"class": "sink", "count": 1, "color": "pink"}], "prompt": "a photo of an orange truck and a pink sink"}
476
+ {"tag": "color_attr", "include": [{"class": "hot dog", "count": 1, "color": "brown"}, {"class": "pizza", "count": 1, "color": "purple"}], "prompt": "a photo of a brown hot dog and a purple pizza"}
477
+ {"tag": "color_attr", "include": [{"class": "couch", "count": 1, "color": "green"}, {"class": "umbrella", "count": 1, "color": "orange"}], "prompt": "a photo of a green couch and an orange umbrella"}
478
+ {"tag": "color_attr", "include": [{"class": "bed", "count": 1, "color": "brown"}, {"class": "cell phone", "count": 1, "color": "pink"}], "prompt": "a photo of a brown bed and a pink cell phone"}
479
+ {"tag": "color_attr", "include": [{"class": "broccoli", "count": 1, "color": "black"}, {"class": "cake", "count": 1, "color": "yellow"}], "prompt": "a photo of a black broccoli and a yellow cake"}
480
+ {"tag": "color_attr", "include": [{"class": "train", "count": 1, "color": "red"}, {"class": "bear", "count": 1, "color": "purple"}], "prompt": "a photo of a red train and a purple bear"}
481
+ {"tag": "color_attr", "include": [{"class": "tennis racket", "count": 1, "color": "purple"}, {"class": "sink", "count": 1, "color": "black"}], "prompt": "a photo of a purple tennis racket and a black sink"}
482
+ {"tag": "color_attr", "include": [{"class": "vase", "count": 1, "color": "blue"}, {"class": "banana", "count": 1, "color": "black"}], "prompt": "a photo of a blue vase and a black banana"}
483
+ {"tag": "color_attr", "include": [{"class": "clock", "count": 1, "color": "blue"}, {"class": "cup", "count": 1, "color": "white"}], "prompt": "a photo of a blue clock and a white cup"}
484
+ {"tag": "color_attr", "include": [{"class": "umbrella", "count": 1, "color": "red"}, {"class": "couch", "count": 1, "color": "blue"}], "prompt": "a photo of a red umbrella and a blue couch"}
485
+ {"tag": "color_attr", "include": [{"class": "handbag", "count": 1, "color": "white"}, {"class": "giraffe", "count": 1, "color": "red"}], "prompt": "a photo of a white handbag and a red giraffe"}
486
+ {"tag": "color_attr", "include": [{"class": "tv remote", "count": 1, "color": "pink"}, {"class": "airplane", "count": 1, "color": "blue"}], "prompt": "a photo of a pink tv remote and a blue airplane"}
487
+ {"tag": "color_attr", "include": [{"class": "handbag", "count": 1, "color": "pink"}, {"class": "scissors", "count": 1, "color": "black"}], "prompt": "a photo of a pink handbag and a black scissors"}
488
+ {"tag": "color_attr", "include": [{"class": "car", "count": 1, "color": "brown"}, {"class": "hair drier", "count": 1, "color": "pink"}], "prompt": "a photo of a brown car and a pink hair drier"}
489
+ {"tag": "color_attr", "include": [{"class": "bus", "count": 1, "color": "black"}, {"class": "cell phone", "count": 1, "color": "brown"}], "prompt": "a photo of a black bus and a brown cell phone"}
490
+ {"tag": "color_attr", "include": [{"class": "sheep", "count": 1, "color": "purple"}, {"class": "banana", "count": 1, "color": "pink"}], "prompt": "a photo of a purple sheep and a pink banana"}
491
+ {"tag": "color_attr", "include": [{"class": "handbag", "count": 1, "color": "blue"}, {"class": "cell phone", "count": 1, "color": "white"}], "prompt": "a photo of a blue handbag and a white cell phone"}
492
+ {"tag": "color_attr", "include": [{"class": "pizza", "count": 1, "color": "white"}, {"class": "umbrella", "count": 1, "color": "green"}], "prompt": "a photo of a white pizza and a green umbrella"}
493
+ {"tag": "color_attr", "include": [{"class": "tie", "count": 1, "color": "white"}, {"class": "skateboard", "count": 1, "color": "purple"}], "prompt": "a photo of a white tie and a purple skateboard"}
494
+ {"tag": "color_attr", "include": [{"class": "sports ball", "count": 1, "color": "yellow"}, {"class": "boat", "count": 1, "color": "green"}], "prompt": "a photo of a yellow sports ball and a green boat"}
495
+ {"tag": "color_attr", "include": [{"class": "wine glass", "count": 1, "color": "white"}, {"class": "giraffe", "count": 1, "color": "brown"}], "prompt": "a photo of a white wine glass and a brown giraffe"}
496
+ {"tag": "color_attr", "include": [{"class": "bowl", "count": 1, "color": "yellow"}, {"class": "baseball glove", "count": 1, "color": "white"}], "prompt": "a photo of a yellow bowl and a white baseball glove"}
497
+ {"tag": "color_attr", "include": [{"class": "microwave", "count": 1, "color": "orange"}, {"class": "spoon", "count": 1, "color": "black"}], "prompt": "a photo of an orange microwave and a black spoon"}
498
+ {"tag": "color_attr", "include": [{"class": "skateboard", "count": 1, "color": "orange"}, {"class": "bowl", "count": 1, "color": "pink"}], "prompt": "a photo of an orange skateboard and a pink bowl"}
499
+ {"tag": "color_attr", "include": [{"class": "toilet", "count": 1, "color": "blue"}, {"class": "suitcase", "count": 1, "color": "white"}], "prompt": "a photo of a blue toilet and a white suitcase"}
500
+ {"tag": "color_attr", "include": [{"class": "boat", "count": 1, "color": "white"}, {"class": "hot dog", "count": 1, "color": "orange"}], "prompt": "a photo of a white boat and an orange hot dog"}
501
+ {"tag": "color_attr", "include": [{"class": "dining table", "count": 1, "color": "yellow"}, {"class": "dog", "count": 1, "color": "pink"}], "prompt": "a photo of a yellow dining table and a pink dog"}
502
+ {"tag": "color_attr", "include": [{"class": "cake", "count": 1, "color": "red"}, {"class": "chair", "count": 1, "color": "purple"}], "prompt": "a photo of a red cake and a purple chair"}
503
+ {"tag": "color_attr", "include": [{"class": "tie", "count": 1, "color": "blue"}, {"class": "dining table", "count": 1, "color": "pink"}], "prompt": "a photo of a blue tie and a pink dining table"}
504
+ {"tag": "color_attr", "include": [{"class": "cow", "count": 1, "color": "blue"}, {"class": "computer keyboard", "count": 1, "color": "black"}], "prompt": "a photo of a blue cow and a black computer keyboard"}
505
+ {"tag": "color_attr", "include": [{"class": "pizza", "count": 1, "color": "yellow"}, {"class": "oven", "count": 1, "color": "green"}], "prompt": "a photo of a yellow pizza and a green oven"}
506
+ {"tag": "color_attr", "include": [{"class": "laptop", "count": 1, "color": "red"}, {"class": "car", "count": 1, "color": "brown"}], "prompt": "a photo of a red laptop and a brown car"}
507
+ {"tag": "color_attr", "include": [{"class": "computer keyboard", "count": 1, "color": "purple"}, {"class": "scissors", "count": 1, "color": "blue"}], "prompt": "a photo of a purple computer keyboard and a blue scissors"}
508
+ {"tag": "color_attr", "include": [{"class": "surfboard", "count": 1, "color": "green"}, {"class": "oven", "count": 1, "color": "orange"}], "prompt": "a photo of a green surfboard and an orange oven"}
509
+ {"tag": "color_attr", "include": [{"class": "parking meter", "count": 1, "color": "yellow"}, {"class": "refrigerator", "count": 1, "color": "pink"}], "prompt": "a photo of a yellow parking meter and a pink refrigerator"}
510
+ {"tag": "color_attr", "include": [{"class": "computer mouse", "count": 1, "color": "brown"}, {"class": "bottle", "count": 1, "color": "purple"}], "prompt": "a photo of a brown computer mouse and a purple bottle"}
511
+ {"tag": "color_attr", "include": [{"class": "umbrella", "count": 1, "color": "red"}, {"class": "cow", "count": 1, "color": "green"}], "prompt": "a photo of a red umbrella and a green cow"}
512
+ {"tag": "color_attr", "include": [{"class": "giraffe", "count": 1, "color": "red"}, {"class": "cell phone", "count": 1, "color": "black"}], "prompt": "a photo of a red giraffe and a black cell phone"}
513
+ {"tag": "color_attr", "include": [{"class": "oven", "count": 1, "color": "brown"}, {"class": "train", "count": 1, "color": "purple"}], "prompt": "a photo of a brown oven and a purple train"}
514
+ {"tag": "color_attr", "include": [{"class": "baseball bat", "count": 1, "color": "blue"}, {"class": "book", "count": 1, "color": "pink"}], "prompt": "a photo of a blue baseball bat and a pink book"}
515
+ {"tag": "color_attr", "include": [{"class": "cup", "count": 1, "color": "green"}, {"class": "bowl", "count": 1, "color": "yellow"}], "prompt": "a photo of a green cup and a yellow bowl"}
516
+ {"tag": "color_attr", "include": [{"class": "suitcase", "count": 1, "color": "yellow"}, {"class": "bus", "count": 1, "color": "brown"}], "prompt": "a photo of a yellow suitcase and a brown bus"}
517
+ {"tag": "color_attr", "include": [{"class": "motorcycle", "count": 1, "color": "orange"}, {"class": "donut", "count": 1, "color": "pink"}], "prompt": "a photo of an orange motorcycle and a pink donut"}
518
+ {"tag": "color_attr", "include": [{"class": "giraffe", "count": 1, "color": "orange"}, {"class": "baseball glove", "count": 1, "color": "white"}], "prompt": "a photo of an orange giraffe and a white baseball glove"}
519
+ {"tag": "color_attr", "include": [{"class": "handbag", "count": 1, "color": "orange"}, {"class": "carrot", "count": 1, "color": "green"}], "prompt": "a photo of an orange handbag and a green carrot"}
520
+ {"tag": "color_attr", "include": [{"class": "bottle", "count": 1, "color": "black"}, {"class": "refrigerator", "count": 1, "color": "white"}], "prompt": "a photo of a black bottle and a white refrigerator"}
521
+ {"tag": "color_attr", "include": [{"class": "dog", "count": 1, "color": "white"}, {"class": "potted plant", "count": 1, "color": "blue"}], "prompt": "a photo of a white dog and a blue potted plant"}
522
+ {"tag": "color_attr", "include": [{"class": "handbag", "count": 1, "color": "orange"}, {"class": "car", "count": 1, "color": "red"}], "prompt": "a photo of an orange handbag and a red car"}
523
+ {"tag": "color_attr", "include": [{"class": "stop sign", "count": 1, "color": "red"}, {"class": "book", "count": 1, "color": "blue"}], "prompt": "a photo of a red stop sign and a blue book"}
524
+ {"tag": "color_attr", "include": [{"class": "car", "count": 1, "color": "yellow"}, {"class": "toothbrush", "count": 1, "color": "orange"}], "prompt": "a photo of a yellow car and an orange toothbrush"}
525
+ {"tag": "color_attr", "include": [{"class": "potted plant", "count": 1, "color": "black"}, {"class": "toilet", "count": 1, "color": "yellow"}], "prompt": "a photo of a black potted plant and a yellow toilet"}
526
+ {"tag": "color_attr", "include": [{"class": "dining table", "count": 1, "color": "brown"}, {"class": "suitcase", "count": 1, "color": "white"}], "prompt": "a photo of a brown dining table and a white suitcase"}
527
+ {"tag": "color_attr", "include": [{"class": "donut", "count": 1, "color": "orange"}, {"class": "stop sign", "count": 1, "color": "yellow"}], "prompt": "a photo of an orange donut and a yellow stop sign"}
528
+ {"tag": "color_attr", "include": [{"class": "suitcase", "count": 1, "color": "green"}, {"class": "boat", "count": 1, "color": "blue"}], "prompt": "a photo of a green suitcase and a blue boat"}
529
+ {"tag": "color_attr", "include": [{"class": "tennis racket", "count": 1, "color": "orange"}, {"class": "sports ball", "count": 1, "color": "yellow"}], "prompt": "a photo of an orange tennis racket and a yellow sports ball"}
530
+ {"tag": "color_attr", "include": [{"class": "computer keyboard", "count": 1, "color": "purple"}, {"class": "chair", "count": 1, "color": "red"}], "prompt": "a photo of a purple computer keyboard and a red chair"}
531
+ {"tag": "color_attr", "include": [{"class": "suitcase", "count": 1, "color": "purple"}, {"class": "pizza", "count": 1, "color": "orange"}], "prompt": "a photo of a purple suitcase and an orange pizza"}
532
+ {"tag": "color_attr", "include": [{"class": "bottle", "count": 1, "color": "white"}, {"class": "sheep", "count": 1, "color": "blue"}], "prompt": "a photo of a white bottle and a blue sheep"}
533
+ {"tag": "color_attr", "include": [{"class": "backpack", "count": 1, "color": "purple"}, {"class": "umbrella", "count": 1, "color": "white"}], "prompt": "a photo of a purple backpack and a white umbrella"}
534
+ {"tag": "color_attr", "include": [{"class": "potted plant", "count": 1, "color": "orange"}, {"class": "spoon", "count": 1, "color": "black"}], "prompt": "a photo of an orange potted plant and a black spoon"}
535
+ {"tag": "color_attr", "include": [{"class": "tennis racket", "count": 1, "color": "green"}, {"class": "dog", "count": 1, "color": "black"}], "prompt": "a photo of a green tennis racket and a black dog"}
536
+ {"tag": "color_attr", "include": [{"class": "handbag", "count": 1, "color": "yellow"}, {"class": "refrigerator", "count": 1, "color": "blue"}], "prompt": "a photo of a yellow handbag and a blue refrigerator"}
537
+ {"tag": "color_attr", "include": [{"class": "broccoli", "count": 1, "color": "pink"}, {"class": "sink", "count": 1, "color": "red"}], "prompt": "a photo of a pink broccoli and a red sink"}
538
+ {"tag": "color_attr", "include": [{"class": "bowl", "count": 1, "color": "red"}, {"class": "sink", "count": 1, "color": "pink"}], "prompt": "a photo of a red bowl and a pink sink"}
539
+ {"tag": "color_attr", "include": [{"class": "toilet", "count": 1, "color": "white"}, {"class": "apple", "count": 1, "color": "red"}], "prompt": "a photo of a white toilet and a red apple"}
540
+ {"tag": "color_attr", "include": [{"class": "dining table", "count": 1, "color": "pink"}, {"class": "sandwich", "count": 1, "color": "black"}], "prompt": "a photo of a pink dining table and a black sandwich"}
541
+ {"tag": "color_attr", "include": [{"class": "car", "count": 1, "color": "black"}, {"class": "parking meter", "count": 1, "color": "green"}], "prompt": "a photo of a black car and a green parking meter"}
542
+ {"tag": "color_attr", "include": [{"class": "bird", "count": 1, "color": "yellow"}, {"class": "motorcycle", "count": 1, "color": "black"}], "prompt": "a photo of a yellow bird and a black motorcycle"}
543
+ {"tag": "color_attr", "include": [{"class": "giraffe", "count": 1, "color": "brown"}, {"class": "stop sign", "count": 1, "color": "white"}], "prompt": "a photo of a brown giraffe and a white stop sign"}
544
+ {"tag": "color_attr", "include": [{"class": "banana", "count": 1, "color": "white"}, {"class": "elephant", "count": 1, "color": "black"}], "prompt": "a photo of a white banana and a black elephant"}
545
+ {"tag": "color_attr", "include": [{"class": "cow", "count": 1, "color": "orange"}, {"class": "sandwich", "count": 1, "color": "purple"}], "prompt": "a photo of an orange cow and a purple sandwich"}
546
+ {"tag": "color_attr", "include": [{"class": "clock", "count": 1, "color": "red"}, {"class": "cell phone", "count": 1, "color": "black"}], "prompt": "a photo of a red clock and a black cell phone"}
547
+ {"tag": "color_attr", "include": [{"class": "knife", "count": 1, "color": "brown"}, {"class": "donut", "count": 1, "color": "blue"}], "prompt": "a photo of a brown knife and a blue donut"}
548
+ {"tag": "color_attr", "include": [{"class": "cup", "count": 1, "color": "red"}, {"class": "handbag", "count": 1, "color": "pink"}], "prompt": "a photo of a red cup and a pink handbag"}
549
+ {"tag": "color_attr", "include": [{"class": "bicycle", "count": 1, "color": "yellow"}, {"class": "motorcycle", "count": 1, "color": "red"}], "prompt": "a photo of a yellow bicycle and a red motorcycle"}
550
+ {"tag": "color_attr", "include": [{"class": "orange", "count": 1, "color": "red"}, {"class": "broccoli", "count": 1, "color": "purple"}], "prompt": "a photo of a red orange and a purple broccoli"}
551
+ {"tag": "color_attr", "include": [{"class": "traffic light", "count": 1, "color": "orange"}, {"class": "toilet", "count": 1, "color": "white"}], "prompt": "a photo of an orange traffic light and a white toilet"}
552
+ {"tag": "color_attr", "include": [{"class": "cup", "count": 1, "color": "green"}, {"class": "pizza", "count": 1, "color": "red"}], "prompt": "a photo of a green cup and a red pizza"}
553
+ {"tag": "color_attr", "include": [{"class": "pizza", "count": 1, "color": "blue"}, {"class": "baseball glove", "count": 1, "color": "yellow"}], "prompt": "a photo of a blue pizza and a yellow baseball glove"}
prompts/ocr_test.txt ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,202 @@
1
+ accelerate==1.7.0
2
+ aiohappyeyeballs==2.6.1
3
+ aiohttp==3.12.9
4
+ aiosignal==1.3.2
5
+ airportsdata==20250523
6
+ annotated-types==0.7.0
7
+ anthropic==0.54.0
8
+ antlr4-python3-runtime==4.13.2
9
+ anyio==4.9.0
10
+ astor==0.8.1
11
+ asttokens==3.0.0
12
+ attrs==25.3.0
13
+ av==14.4.0
14
+ bitsandbytes==0.46.0
15
+ blake3==1.0.5
16
+ cachetools==6.0.0
17
+ certifi==2025.4.26
18
+ charset-normalizer==3.4.2
19
+ click==8.2.1
20
+ cloudpickle==3.1.1
21
+ compressed-tensors==0.9.4
22
+ contourpy==1.3.2
23
+ cupy-cuda12x==13.4.1
24
+ cycler==0.12.1
25
+ datasets==3.6.0
26
+ decorator==5.2.1
27
+ deepspeed==0.15.4
28
+ depyf==0.18.0
29
+ dill==0.3.8
30
+ diskcache==5.6.3
31
+ distro==1.9.0
32
+ dnspython==2.7.0
33
+ docker-pycreds==0.4.0
34
+ einops==0.8.1
35
+ email-validator==2.2.0
36
+ executing==2.2.0
37
+ fastapi==0.115.12
38
+ fastapi-cli==0.0.7
39
+ fastrlock==0.8.3
40
+ filelock==3.18.0
41
+ fonttools==4.58.4
42
+ frozenlist==1.6.2
43
+ fsspec==2025.3.0
44
+ ftfy==6.3.1
45
+ gguf==0.17.0
46
+ gitdb==4.0.12
47
+ gitpython==3.1.44
48
+ googleapis-common-protos==1.70.0
49
+ grpcio==1.72.1
50
+ h11==0.16.0
51
+ hf-transfer==0.1.9
52
+ hf-xet==1.1.3
53
+ hjson==3.1.0
54
+ httpcore==1.0.9
55
+ httptools==0.6.4
56
+ httpx==0.28.1
57
+ huggingface-hub==0.32.4
58
+ idna==3.10
59
+ importlib-metadata==8.7.0
60
+ inquirerpy==0.3.4
61
+ interegular==0.3.3
62
+ ipython==9.3.0
63
+ ipython-pygments-lexers==1.1.1
64
+ jedi==0.19.2
65
+ jinja2==3.1.6
66
+ jiter==0.10.0
67
+ jsonschema==4.24.0
68
+ jsonschema-specifications==2025.4.1
69
+ kiwisolver==1.4.8
70
+ lark==1.2.2
71
+ latex2sympy2-extended==1.10.1
72
+ liger-kernel==0.5.2
73
+ llguidance==0.7.29
74
+ llvmlite==0.44.0
75
+ lm-format-enforcer==0.10.11
76
+ markdown-it-py==3.0.0
77
+ markupsafe==3.0.2
78
+ math-verify==0.7.0
79
+ matplotlib==3.10.3
80
+ matplotlib-inline==0.1.7
81
+ mdurl==0.1.2
82
+ mistral-common==1.5.6
83
+ mpmath==1.3.0
84
+ msgpack==1.1.0
85
+ msgspec==0.19.0
86
+ multidict==6.4.4
87
+ multiprocess==0.70.16
88
+ nest-asyncio==1.6.0
89
+ networkx==3.5
90
+ ninja==1.11.1.4
91
+ numba==0.61.2
92
+ numpy==2.2.6
93
+ nvidia-cublas-cu12==12.6.4.1
94
+ nvidia-cuda-cupti-cu12==12.6.80
95
+ nvidia-cuda-nvrtc-cu12==12.6.77
96
+ nvidia-cuda-runtime-cu12==12.6.77
97
+ nvidia-cudnn-cu12==9.5.1.17
98
+ nvidia-cufft-cu12==11.3.0.4
99
+ nvidia-cufile-cu12==1.11.1.6
100
+ nvidia-curand-cu12==10.3.7.77
101
+ nvidia-cusolver-cu12==11.7.1.2
102
+ nvidia-cusparse-cu12==12.5.4.2
103
+ nvidia-cusparselt-cu12==0.6.3
104
+ nvidia-nccl-cu12==2.26.2
105
+ nvidia-nvjitlink-cu12==12.6.85
106
+ nvidia-nvtx-cu12==12.6.77
107
+ openai==1.84.0
108
+ opencv-python-headless==4.11.0.86
109
+ opentelemetry-api==1.34.0
110
+ opentelemetry-exporter-otlp==1.34.0
111
+ opentelemetry-exporter-otlp-proto-common==1.34.0
112
+ opentelemetry-exporter-otlp-proto-grpc==1.34.0
113
+ opentelemetry-exporter-otlp-proto-http==1.34.0
114
+ opentelemetry-proto==1.34.0
115
+ opentelemetry-sdk==1.34.0
116
+ opentelemetry-semantic-conventions==0.55b0
117
+ opentelemetry-semantic-conventions-ai==0.4.9
118
+ outlines==0.1.11
119
+ outlines-core==0.1.26
120
+ packaging==25.0
121
+ pandas==2.3.0
122
+ parso==0.8.4
123
+ partial-json-parser==0.2.1.1.post5
124
+ peft==0.17.1
125
+ pexpect==4.9.0
126
+ pfzy==0.3.4
127
+ pillow==11.2.1
128
+ platformdirs==4.3.8
129
+ prometheus-client==0.22.1
130
+ prometheus-fastapi-instrumentator==7.1.0
131
+ prompt-toolkit==3.0.51
132
+ propcache==0.3.1
133
+ protobuf==5.29.5
134
+ psutil==7.0.0
135
+ ptyprocess==0.7.0
136
+ pure-eval==0.2.3
137
+ py-cpuinfo==9.0.0
138
+ pyarrow==20.0.0
139
+ pycountry==24.6.1
140
+ pydantic==2.11.5
141
+ pydantic-core==2.33.2
142
+ pygments==2.19.1
143
+ pyparsing==3.2.3
144
+ python-dateutil==2.9.0.post0
145
+ python-dotenv==1.1.0
146
+ python-json-logger==3.3.0
147
+ python-multipart==0.0.20
148
+ pytz==2025.2
149
+ pyyaml==6.0.2
150
+ pyzmq==26.4.0
151
+ qwen-vl-utils==0.0.11
152
+ ray==2.46.0
153
+ referencing==0.36.2
154
+ regex==2024.11.6
155
+ requests==2.32.3
156
+ rich==14.0.0
157
+ rich-toolkit==0.14.7
158
+ rpds-py==0.25.1
159
+ safetensors==0.5.3
160
+ scipy==1.15.3
161
+ seaborn==0.13.2
162
+ sentencepiece==0.2.0
163
+ sentry-sdk==2.29.1
164
+ setproctitle==1.3.6
165
+ shellingham==1.5.4
166
+ six==1.17.0
167
+ smmap==5.0.2
168
+ sniffio==1.3.1
169
+ stack-data==0.6.3
170
+ starlette==0.46.2
171
+ sympy==1.14.0
172
+ tabulate==0.9.0
173
+ tiktoken==0.9.0
174
+ timm==0.6.13
175
+ tokenizers==0.21.1
176
+ torch==2.7.0
177
+ torchaudio==2.7.0
178
+ torchvision==0.22.0
179
+ tqdm==4.67.1
180
+ traitlets==5.14.3
181
+ transformers==4.51.3
182
+ triton==3.3.0
183
+ trl==0.19.0
184
+ typer==0.16.0
185
+ typing-extensions==4.14.0
186
+ typing-inspection==0.4.1
187
+ tzdata==2025.2
188
+ urllib3==2.4.0
189
+ utils==1.0.2
190
+ uvicorn==0.34.3
191
+ uvloop==0.21.0
192
+ vllm==0.9.0.1
193
+ wandb==0.18.3
194
+ watchfiles==1.0.5
195
+ wcwidth==0.2.13
196
+ websockets==15.0.1
197
+ xformers==0.0.30
198
+ xgrammar==0.1.19
199
+ xxhash==3.5.0
200
+ yarl==1.20.0
201
+ zipp==3.22.0
202
+ tensorboardX==2.6.4
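Since the requirements are fully pinned, a quick sanity check of the active environment can catch version drift before running inference. A minimal sketch (illustrative, not part of the repository) that compares a few of the pins above against the installed packages:

# Compare a handful of the pinned versions above with the active environment.
from importlib.metadata import version, PackageNotFoundError

EXPECTED = {
    "torch": "2.7.0",
    "transformers": "4.51.3",
    "vllm": "0.9.0.1",
    "accelerate": "1.7.0",
}

for package, expected in EXPECTED.items():
    try:
        installed = version(package)
    except PackageNotFoundError:
        print(f"{package}: not installed (expected {expected})")
        continue
    status = "OK" if installed == expected else f"MISMATCH (expected {expected})"
    print(f"{package}: {installed} {status}")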
unified_inference.py ADDED
@@ -0,0 +1,660 @@
1
+ """
2
+ Unified Inference Script for Multi-Modal Image Generation and Editing
3
+
4
+ Supports three modes:
5
+ 1. t2i (Text-to-Image): Generate images from text prompts (txt file)
6
+ 2. geneval: Generate multiple samples per prompt for evaluation (jsonl file)
7
+ 3. edit: Edit images based on prompts (parquet file)
8
+
9
+ Example usage:
10
+ # Text-to-Image
11
+ python unified_inference.py --mode t2i --model_path ./model --model_type flux \
12
+ --prompt_file prompts.txt --output_dir outputs/t2i
13
+
14
+ # GenEval
15
+ python unified_inference.py --mode geneval --model_path ./model --model_type flux \
16
+ --metadata_file evaluation_metadata.jsonl --output_dir outputs/geneval --n_samples 4
17
+
18
+ # Image Editing
19
+ python unified_inference.py --mode edit --model_path ./model --model_type kontext \
20
+ --data_file data.parquet --output_dir outputs/edit
21
+ """
22
+
23
+ from concurrent.futures import ThreadPoolExecutor, as_completed
24
+ import argparse
25
+ import json
26
+ import os
27
+ import traceback
28
+ from tqdm import tqdm
29
+ import torch
30
+ import numpy as np
31
+ from PIL import Image
32
+ from transformers import AutoProcessor
33
+ import random
34
+ import multiprocessing as mp
35
+ import pandas as pd
36
+ from io import BytesIO
37
+ import base64
38
+ from torchvision import transforms as TF
39
+
40
+ # Model imports. Note: only the flux and kontext pipelines are imported here; the "sana" and "sd3" branches in load_model_pipeline below additionally require QwenSanaForInferenceLM / QwenSD3ForInferenceLM, which this file does not import.
41
+ from unimodel.qwenflux.qwenflux_inference import QwenFluxForInferenceLM
42
+ from unimodel.qwenkontext.qwenkontext_inference import QwenKontextForInferenceLM
43
+
44
+ # Global configuration
45
+ NUM_DEVICE = 8
46
+ NUM_PROCESSES = 8
47
+
48
+
49
+ # =============================================================================
50
+ # CoT Prompt Templates
51
+ # =============================================================================
52
+ COT_PROMPT_TEMPLATES = {
53
+ # General enhancement
54
+ "geneval": """Please provide an enhanced prompt for the following image generation prompt to make the image more realistic, detailed, with clear separation and precise alignment of all entities.
55
+ Original prompt: {original_prompt}. Directly provide the improved prompt in <answer> </answer> tags.""",
56
+
57
+
58
+ "ocr_clarity_v2": """Please enhance the following image generation prompt with specific focus on TEXT clarity and readability.
59
+ Original prompt: {original_prompt}. Directly provide the improved prompt in <answer> </answer> tags.""",
60
+
61
+
62
+ "quality_purev2": """Rewrite the following image generation prompt to improve its visual quality, detail level, realism, and artistic sophistication.
63
+
64
+ Original prompt: {original_prompt}
65
+
66
+ Provide the enhanced version directly in <answer></answer> tags.""",
67
+
68
+
69
+ "edit_general": """Please provide an enhanced prompt for the following image editing prompt.
70
+ Ensure the revised prompt is clear, specific, and includes detailed instructions to achieve the desired outcome while maintaining the original intent.
71
+ Original prompt: {original_prompt}. Directly provide the improved prompt in <answer> </answer> tags.""",
72
+
73
+ }
74
+
75
+
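# Illustrative note (not part of the original file): each entry above is a plain
# format string, so enhancing a prompt amounts to a single call such as
#   COT_PROMPT_TEMPLATES["geneval"].format(original_prompt="a photo of a red car")
# The model is then expected to wrap its rewritten prompt in <answer> ... </answer>
# tags, which generate_image_cot presumably extracts as "improved_prompts".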
76
+ # =============================================================================
77
+ # Utility Functions
78
+ # =============================================================================
79
+ def set_global_seed(seed):
80
+ """Set global random seed for reproducibility."""
81
+ random.seed(seed)
82
+ np.random.seed(seed)
83
+ torch.manual_seed(seed)
84
+ torch.cuda.manual_seed(seed)
85
+ torch.cuda.manual_seed_all(seed)
86
+
87
+
88
+
89
+
90
+ # =============================================================================
91
+ # Model Loading
92
+ # =============================================================================
93
+ def load_model_pipeline(model_path, model_type, device):
94
+ """Load model pipeline based on model type."""
95
+ processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
96
+ subfolder = model_path.split('/')[-1]
97
+ model_path = model_path.replace(f"/{subfolder}", "")
98
+ if model_type == "flux":
99
+ model = QwenFluxForInferenceLM.from_pretrained(
100
+ model_path, torch_dtype=torch.bfloat16, subfolder=subfolder
101
+ )
102
+ elif model_type == "sana":
103
+ model = QwenSanaForInferenceLM.from_pretrained(
104
+ model_path, torch_dtype=torch.bfloat16, subfolder=subfolder
105
+ )
106
+ elif model_type == "sd3":
107
+ model = QwenSD3ForInferenceLM.from_pretrained(
108
+ model_path, torch_dtype=torch.bfloat16, subfolder=subfolder
109
+ )
110
+ elif model_type == "kontext":
111
+ model = QwenKontextForInferenceLM.from_pretrained(
112
+ model_path, torch_dtype=torch.bfloat16, subfolder=subfolder
113
+ )
114
+ else:
115
+ raise ValueError(f"Unknown model type: {model_type}")
116
+
117
+ processor.tokenizer.padding_side = "left" # for batch inference
118
+ model.to(device)
119
+
120
+ return model, processor
121
+
122
+
123
+ # =============================================================================
124
+ # Data Loading Functions
125
+ # =============================================================================
126
+ def load_prompts_from_txt(txt_file):
127
+ """Load prompts from text file (one per line)."""
128
+ with open(txt_file, 'r', encoding='utf-8') as f:
129
+ prompts = [line.strip() for line in f if line.strip()]
130
+ return prompts
131
+
132
+
133
+ def load_prompts_from_jsonl(metadata_file):
134
+ """Load prompts and metadata from JSONL file."""
135
+ with open(metadata_file) as fp:
136
+ metadatas = [json.loads(line) for line in fp]
137
+ prompts = [metadata['prompt'].strip() for metadata in metadatas]
138
+ return prompts, metadatas
139
+
140
+
141
+ def load_data_from_parquet(parquet_file):
142
+ """Load images and prompts from parquet file."""
143
+ df = pd.read_parquet(parquet_file)
144
+
145
+ # Identify column names
146
+ image_col = None
147
+ prompt_col = None
148
+ id_col = None
149
+
150
+ for col in df.columns:
151
+ col_lower = col.lower()
152
+ if 'image' in col_lower and image_col is None:
153
+ image_col = col
154
+ elif any(kw in col_lower for kw in ['prompt', 'text', 'caption', 'instruction']) and prompt_col is None:
155
+ prompt_col = col
156
+ elif any(kw in col_lower for kw in ['id', 'index']) and id_col is None:
157
+ id_col = col
158
+
159
+ if image_col is None or prompt_col is None:
160
+ raise ValueError(
161
+ f"Cannot identify columns. Found: {df.columns.tolist()}\n"
162
+ f"Expected 'image' and 'prompt'/'text'/'caption'"
163
+ )
164
+
165
+ print(f"Using columns - Image: '{image_col}', Prompt: '{prompt_col}', ID: '{id_col}'")
166
+
167
+ data_list = []
168
+ for idx, row in tqdm(df.iterrows(), total=len(df), desc="Loading parquet"):
169
+ try:
170
+ image_data = row[image_col]["bytes"]
171
+
172
+ if isinstance(image_data, bytes):
173
+ image = Image.open(BytesIO(image_data)).convert('RGB')
174
+ elif isinstance(image_data, str):
175
+ if image_data.startswith('data:image') or image_data.startswith('/9j/') or image_data.startswith('iVBOR'):
176
+ if 'base64,' in image_data:
177
+ image_data = image_data.split('base64,')[1]
178
+ image_bytes = base64.b64decode(image_data)
179
+ image = Image.open(BytesIO(image_bytes)).convert('RGB')
180
+ else:
181
+ image = Image.open(image_data).convert('RGB')
182
+ else:
183
+ print(f"Warning: Skipping row {idx} - unsupported image format")
184
+ continue
185
+
186
+ prompt = str(row[prompt_col])
187
+ item_id = row[id_col] if id_col else idx
188
+
189
+ data_list.append({
190
+ 'image': image,
191
+ 'prompt': prompt,
192
+ 'id': item_id,
193
+ 'index': idx
194
+ })
195
+ except Exception as e:
196
+ print(f"Error loading row {idx}: {e}")
197
+ continue
198
+
199
+ print(f"Loaded {len(data_list)} samples from parquet")
200
+ return data_list
201
+
202
+
203
+ # =============================================================================
204
+ # Image Grid Utility
205
+ # =============================================================================
206
+ def create_image_grid(images, rows, cols):
207
+ """Create a grid image from a list of images."""
208
+ assert len(images) == rows * cols
209
+ width, height = images[0].size
210
+ grid_width = width * cols
211
+ grid_height = height * rows
212
+ grid_image = Image.new('RGB', (grid_width, grid_height))
213
+ for i, image in enumerate(images):
214
+ x = (i % cols) * width
215
+ y = (i // cols) * height
216
+ grid_image.paste(image, (x, y))
217
+ return grid_image
218
+
219
+
220
+ # =============================================================================
221
+ # Generation Functions
222
+ # =============================================================================
223
+ def generate_t2i_batch(
224
+ prompts, start_idx, pipeline, processor, output_dir, batch_size,
225
+ guidance_scale, num_inference_steps, seed, use_cot, cot_template_name,
226
+ add_instruction, device_id
227
+ ):
228
+ """Generate images from text prompts (T2I mode)."""
229
+ os.makedirs(output_dir, exist_ok=True)
230
+
231
+ for i in tqdm(range(0, len(prompts), batch_size), desc=f"GPU {device_id} T2I"):
232
+ batch_prompts = prompts[i:i + batch_size]
233
+ batch_start_idx = start_idx + i
234
+ original_prompts = batch_prompts.copy()
235
+
236
+ if add_instruction:
237
+ batch_prompts = [
238
+ f"Please generate image based on the following caption: {p}"
239
+ for p in batch_prompts
240
+ ]
241
+
242
+ diffusion_kwargs = dict(
243
+ guidance_scale=guidance_scale,
244
+ num_inference_steps=num_inference_steps,
245
+ num_images_per_prompt=1,
246
+ generator=torch.Generator("cpu").manual_seed(seed)
247
+ )
248
+
249
+ try:
250
+ with torch.no_grad():
251
+ if use_cot:
252
+ llm_kwargs = dict(
253
+ max_new_tokens=256, temperature=0.7, top_p=0.9,
254
+ do_sample=False, num_return_sequences=1
255
+ )
256
+ cot_template = COT_PROMPT_TEMPLATES.get(cot_template_name)
257
+ outputs = pipeline.generate_image_cot(
258
+ texts=batch_prompts,
259
+ diffusion_kwargs=diffusion_kwargs,
260
+ processor=processor,
261
+ llm_kwargs=llm_kwargs,
262
+ cot_prompt_template=cot_template
263
+ )
264
+ images = outputs["images"]
265
+ thinking_prompts = outputs.get("improved_prompts", [])
266
+ else:
267
+ images = pipeline.generate_image(
268
+ texts=batch_prompts,
269
+ diffusion_kwargs=diffusion_kwargs
270
+ )
271
+ thinking_prompts = []
272
+
273
+ for j, img in enumerate(images):
274
+ img_idx = batch_start_idx + j
275
+ base_name = f"{img_idx:05d}"
276
+
277
+ img.save(os.path.join(output_dir, f"{base_name}.png"))
278
+
279
+ with open(os.path.join(output_dir, f"{base_name}_caption.txt"), 'w', encoding='utf-8') as f:
280
+ f.write(original_prompts[j])
281
+
282
+ if use_cot and j < len(thinking_prompts):
283
+ with open(os.path.join(output_dir, f"{base_name}_thinking.txt"), 'w', encoding='utf-8') as f:
284
+ f.write(thinking_prompts[j])
285
+
286
+ except Exception as e:
287
+ print(f"Error at batch {batch_start_idx}: {e}")
288
+ traceback.print_exc()
289
+
290
+
291
+ def generate_geneval_batch(
292
+ prompts, metadatas, start_idx, pipeline, processor, output_dir, batch_size,
293
+ guidance_scale, num_inference_steps, seed, n_samples, use_cot,
294
+ cot_template_name, skip_grid, device_id
295
+ ):
296
+ """Generate multiple samples per prompt for evaluation (GenEval mode)."""
297
+ for prompt_idx, (prompt, metadata) in enumerate(zip(prompts, metadatas)):
298
+ global_idx = start_idx + prompt_idx
299
+ outpath = os.path.join(output_dir, f"{device_id}_{prompt_idx:0>5}")
300
+ os.makedirs(outpath, exist_ok=True)
301
+ sample_path = os.path.join(outpath, "samples")
302
+ os.makedirs(sample_path, exist_ok=True)
303
+
304
+ with open(os.path.join(outpath, "metadata.jsonl"), "w") as fp:
305
+ json.dump(metadata, fp)
306
+
307
+ sample_count = 0
308
+ all_samples = []
309
+ enhanced_prompts = []
310
+ total_batches = (n_samples + batch_size - 1) // batch_size
311
+
312
+ for batch_idx in tqdm(range(total_batches), desc=f"GPU {device_id} prompt {prompt_idx}"):
313
+ num_images = min(batch_size, n_samples - sample_count)
314
+
315
+ diffusion_kwargs = dict(
316
+ guidance_scale=guidance_scale,
317
+ num_inference_steps=num_inference_steps,
318
+ num_images_per_prompt=num_images,
319
+ generator=torch.Generator("cpu").manual_seed(seed)
320
+ )
321
+
322
+ try:
323
+ with torch.inference_mode():
324
+ if use_cot:
325
+ llm_kwargs = dict(
326
+ max_new_tokens=256, temperature=0.7, top_p=0.9,
327
+ do_sample=False, num_return_sequences=1
328
+ )
329
+ cot_template = COT_PROMPT_TEMPLATES.get(cot_template_name)
330
+ outputs = pipeline.generate_image_cot(
331
+ texts=prompt,
332
+ diffusion_kwargs=diffusion_kwargs,
333
+ processor=processor,
334
+ llm_kwargs=llm_kwargs,
335
+ cot_prompt_template=cot_template
336
+ )
337
+ images = outputs["images"]
338
+ enhanced_prompts.extend(outputs.get("improved_prompts", []))
339
+ else:
340
+ images = pipeline.generate_image(
341
+ texts=prompt,
342
+ diffusion_kwargs=diffusion_kwargs
343
+ )
344
+
345
+ for img in images:
346
+ img.save(os.path.join(sample_path, f"{sample_count:05}.png"))
347
+ sample_count += 1
348
+ if not skip_grid:
349
+ all_samples.append(img)
350
+
351
+ except Exception as e:
352
+ print(f"Error at prompt {prompt_idx}, batch {batch_idx}: {e}")
353
+ traceback.print_exc()
354
+
355
+ # Save enhanced prompts
356
+ with open(os.path.join(outpath, "thinking_prompts.txt"), "w") as fp:
357
+ for ep in enhanced_prompts:
358
+ fp.write(f"{ep}\n")
359
+
360
+ # Create grid
361
+ if not skip_grid and all_samples:
362
+ rows = int(np.sqrt(n_samples))
363
+ cols = (n_samples + rows - 1) // rows
364
+ if rows * cols >= len(all_samples):
365
+ grid_image = create_image_grid(all_samples[:rows * cols], rows, cols)
366
+ grid_image.save(os.path.join(outpath, "grid.jpg"))
367
+
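As a quick sanity check of the GenEval layout above: with the default n_samples=4 the grid is 2x2 (rows = int(sqrt(4)) = 2). A hedged sketch of the resulting tree for device 0 and its first prompt (the output root is a placeholder):

# <output_dir>/0_00000/metadata.jsonl         # this prompt's metadata record
# <output_dir>/0_00000/samples/00000.png      # n_samples independent samples
# <output_dir>/0_00000/samples/00001.png
# <output_dir>/0_00000/samples/00002.png
# <output_dir>/0_00000/samples/00003.png
# <output_dir>/0_00000/thinking_prompts.txt   # CoT prompts (empty unless --use_cot)
# <output_dir>/0_00000/grid.jpg               # 2x2 grid, omitted when --skip_grid is passed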
368
+
369
+ def generate_edit_batch(
370
+ data_batch, start_idx, pipeline, processor, output_dir, batch_size,
371
+ guidance_scale, num_inference_steps, seed, use_cot, cot_template_name,
372
+ device_id, resolution
373
+ ):
374
+ """Edit images based on prompts (Edit mode)."""
375
+ os.makedirs(output_dir, exist_ok=True)
376
+
377
+ transform = TF.Compose([
378
+ TF.Resize(resolution),
379
+ TF.CenterCrop(resolution)
380
+ ])
381
+
382
+ for i in tqdm(range(0, len(data_batch), batch_size), desc=f"GPU {device_id} Edit"):
383
+ batch_data = data_batch[i:i + batch_size]
384
+ batch_start_idx = start_idx + i
385
+
386
+ batch_images = [transform(item['image']) for item in batch_data]
387
+ batch_prompts = [item['prompt'] for item in batch_data]
388
+ batch_ids = [item['id'] for item in batch_data]
389
+
390
+ diffusion_kwargs = dict(
391
+ guidance_scale=guidance_scale,
392
+ num_inference_steps=num_inference_steps,
393
+ num_images_per_prompt=1,
394
+ generator=torch.Generator("cpu").manual_seed(seed),
395
+ max_area=resolution ** 2
396
+ )
397
+
398
+ try:
399
+ with torch.no_grad():
400
+ if use_cot:
401
+ llm_kwargs = dict(
402
+ max_new_tokens=256, temperature=0.7, top_p=0.9,
403
+ do_sample=False, num_return_sequences=1
404
+ )
405
+ cot_template = COT_PROMPT_TEMPLATES.get(cot_template_name)
406
+ outputs = pipeline.generate_image_cot(
407
+ images=batch_images,
408
+ texts=batch_prompts,
409
+ diffusion_kwargs=diffusion_kwargs,
410
+ processor=processor,
411
+ llm_kwargs=llm_kwargs,
412
+ cot_prompt_template=cot_template
413
+ )
414
+ edited_images = outputs["images"]
415
+ improved_prompts = outputs.get("improved_prompts", [])
416
+ else:
417
+ edited_images = pipeline.generate_image(
418
+ images=batch_images,
419
+ texts=batch_prompts,
420
+ diffusion_kwargs=diffusion_kwargs
421
+ )
422
+ improved_prompts = []
423
+
424
+ for j, (edited_img, ref_img) in enumerate(zip(edited_images, batch_images)):
425
+ item_id = batch_ids[j]
426
+ base_name = f"{item_id}"
427
+
428
+ edited_img.save(os.path.join(output_dir, f"{base_name}_edited.png"))
429
+ ref_img.save(os.path.join(output_dir, f"{base_name}_reference.png"))
430
+
431
+ with open(os.path.join(output_dir, f"{base_name}_prompt.txt"), 'w', encoding='utf-8') as f:
432
+ f.write(batch_prompts[j])
433
+
434
+ if use_cot and j < len(improved_prompts):
435
+ with open(os.path.join(output_dir, f"{base_name}_improved_prompt.txt"), 'w', encoding='utf-8') as f:
436
+ f.write(improved_prompts[j])
437
+
438
+ except Exception as e:
439
+ print(f"Error at batch {batch_start_idx}: {e}")
440
+ traceback.print_exc()
441
+
442
+
443
+ # =============================================================================
444
+ # Worker Process
445
+ # =============================================================================
446
+ def worker_process(
447
+ device_id, mode, data, start_idx, pipeline, processor, output_dir,
448
+ batch_size, guidance_scale, num_inference_steps, seed, use_cot,
449
+ cot_template_name, add_instruction, n_samples, skip_grid, resolution, metadatas=None
450
+ ):
451
+ """Single GPU worker process."""
452
+ torch.cuda.set_device(f"cuda:{device_id % NUM_DEVICE}")
453
+
454
+ print(f"GPU {device_id}: Processing {len(data)} items (indices {start_idx} to {start_idx + len(data) - 1})")
455
+
456
+ if mode == "t2i":
457
+ generate_t2i_batch(
458
+ prompts=data, start_idx=start_idx, pipeline=pipeline,
459
+ processor=processor, output_dir=output_dir, batch_size=batch_size,
460
+ guidance_scale=guidance_scale, num_inference_steps=num_inference_steps,
461
+ seed=seed, use_cot=use_cot, cot_template_name=cot_template_name,
462
+ add_instruction=add_instruction, device_id=device_id
463
+ )
464
+ elif mode == "geneval":
465
+ generate_geneval_batch(
466
+ prompts=data, metadatas=metadatas, start_idx=start_idx,
467
+ pipeline=pipeline, processor=processor, output_dir=output_dir,
468
+ batch_size=batch_size, guidance_scale=guidance_scale,
469
+ num_inference_steps=num_inference_steps, seed=seed,
470
+ n_samples=n_samples, use_cot=use_cot, cot_template_name=cot_template_name,
471
+ skip_grid=skip_grid, device_id=device_id
472
+ )
473
+ elif mode == "edit":
474
+ generate_edit_batch(
475
+ data_batch=data, start_idx=start_idx, pipeline=pipeline,
476
+ processor=processor, output_dir=output_dir, batch_size=batch_size,
477
+ guidance_scale=guidance_scale, num_inference_steps=num_inference_steps,
478
+ seed=seed, use_cot=use_cot, cot_template_name=cot_template_name,
479
+ device_id=device_id, resolution=resolution
480
+ )
481
+
482
+ print(f"GPU {device_id}: Completed!")
483
+
484
+
485
+ # =============================================================================
486
+ # Argument Parser
487
+ # =============================================================================
488
+ def parse_args():
489
+ parser = argparse.ArgumentParser(
490
+ description="Unified Inference Script for Image Generation and Editing"
491
+ )
492
+
493
+ # Mode selection
494
+ parser.add_argument(
495
+ "--mode", type=str, required=True,
496
+ choices=["t2i", "geneval", "edit"],
497
+ help="Inference mode: t2i (text-to-image), geneval (evaluation), edit (image editing)"
498
+ )
499
+
500
+ # Input/Output
501
+ parser.add_argument("--prompt_file", type=str, help="Text file with prompts (for t2i mode)")
502
+ parser.add_argument("--metadata_file", type=str, help="JSONL metadata file (for geneval mode)")
503
+ parser.add_argument("--data_file", type=str, help="Parquet file with images and prompts (for edit mode)")
504
+ parser.add_argument("--output_dir", type=str, default="outputs", help="Output directory")
505
+
506
+ # Model configuration
507
+ parser.add_argument("--model_path", type=str, required=True, help="Model path")
508
+ parser.add_argument(
509
+ "--model_type", type=str, default="flux",
510
+ choices=["flux", "sana", "sd3", "kontext"],
511
+ help="Model type"
512
+ )
513
+
514
+ # Generation parameters
515
+ parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
516
+ parser.add_argument("--resolution", type=int, default=1024, help="Image resolution")
517
+ parser.add_argument("--guidance_scale", type=float, default=3.5, help="CFG guidance scale")
518
+ parser.add_argument("--num_inference_steps", type=int, default=40, help="Inference steps")
519
+ parser.add_argument("--seed", type=int, default=42, help="Random seed")
520
+
521
+ # CoT options
522
+ parser.add_argument("--use_cot", action="store_true", help="Use Chain of Thought")
523
+ parser.add_argument(
524
+ "--cot_template", type=str, default="general",
525
+ choices=list(COT_PROMPT_TEMPLATES.keys()),
526
+ help="CoT prompt template"
527
+ )
528
+ parser.add_argument("--add_instruction", action="store_true", help="Add instruction prefix (t2i mode)")
529
+
530
+ # GenEval specific
531
+ parser.add_argument("--n_samples", type=int, default=4, help="Samples per prompt (geneval mode)")
532
+ parser.add_argument("--skip_grid", action="store_true", help="Skip grid image (geneval mode)")
533
+
534
+ # Hardware
535
+ parser.add_argument("--num_gpus", type=int, default=None, help="Number of GPUs to use")
536
+ parser.add_argument("--max_samples", type=int, default=None, help="Max samples to process")
537
+
538
+ return parser.parse_args()
539
+
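For reference, a hedged sketch of how this CLI could be invoked for each mode. The script filename ("inference.py") and all paths below are placeholders, not names taken from this repository:

import subprocess

# Text-to-image over a plain prompt list, with Chain-of-Thought prompt rewriting.
subprocess.run([
    "python", "inference.py", "--mode", "t2i",
    "--model_path", "/path/to/checkpoint", "--model_type", "flux",
    "--prompt_file", "prompts.txt", "--output_dir", "outputs/t2i",
    "--use_cot", "--cot_template", "general",
], check=True)

# GenEval-style evaluation: four samples per metadata record, no grid image.
subprocess.run([
    "python", "inference.py", "--mode", "geneval",
    "--model_path", "/path/to/checkpoint",
    "--metadata_file", "evaluation_metadata.jsonl", "--output_dir", "outputs/geneval",
    "--n_samples", "4", "--skip_grid",
], check=True)

# Image editing from a parquet file; edit mode expects the kontext model type.
subprocess.run([
    "python", "inference.py", "--mode", "edit",
    "--model_path", "/path/to/kontext-checkpoint", "--model_type", "kontext",
    "--data_file", "edit_data.parquet", "--output_dir", "outputs/edit",
], check=True)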
540
+
541
+ # =============================================================================
542
+ # Main Function
543
+ # =============================================================================
544
+ def main():
545
+ mp.set_start_method('spawn', force=True)
546
+ args = parse_args()
547
+
548
+ global NUM_PROCESSES
549
+ if args.num_gpus is not None:
550
+ NUM_PROCESSES = min(args.num_gpus, NUM_DEVICE)
551
+
552
+ # Validate mode-specific arguments
553
+ if args.mode == "t2i" and not args.prompt_file:
554
+ raise ValueError("--prompt_file is required for t2i mode")
555
+ if args.mode == "geneval" and not args.metadata_file:
556
+ raise ValueError("--metadata_file is required for geneval mode")
557
+ if args.mode == "edit" and not args.data_file:
558
+ raise ValueError("--data_file is required for edit mode")
559
+ if args.mode == "edit" and args.model_type != "kontext":
560
+ print(f"Warning: edit mode typically uses kontext model, but got {args.model_type}")
561
+
562
+ # Load data based on mode
563
+ print(f"Mode: {args.mode}")
564
+ metadatas = None
565
+
566
+ if args.mode == "t2i":
567
+ print(f"Loading prompts from {args.prompt_file}...")
568
+ data = load_prompts_from_txt(args.prompt_file)
569
+ elif args.mode == "geneval":
570
+ print(f"Loading metadata from {args.metadata_file}...")
571
+ data, metadatas = load_prompts_from_jsonl(args.metadata_file)
572
+ elif args.mode == "edit":
573
+ print(f"Loading data from {args.data_file}...")
574
+ data = load_data_from_parquet(args.data_file)
575
+
576
+ # Apply max_samples limit
577
+ if args.max_samples is not None:
578
+ if args.mode == "geneval":
579
+ data = data[:args.max_samples]
580
+ metadatas = metadatas[:args.max_samples]
581
+ else:
582
+ data = data[:args.max_samples]
583
+ print(f"Limited to {len(data)} samples")
584
+
585
+ print(f"Total samples: {len(data)}")
586
+
587
+ # Create output directory
588
+ os.makedirs(args.output_dir, exist_ok=True)
589
+
590
+ # Save configuration
591
+ config_path = os.path.join(args.output_dir, "config.json")
592
+ config_dict = vars(args).copy()
593
+ with open(config_path, 'w') as f:
594
+ json.dump(config_dict, f, indent=2)
595
+ print(f"Config saved to {config_path}")
596
+
597
+ # Load models
598
+ print("Loading models...")
599
+ pipelines = []
600
+ processors = []
601
+
602
+ for i in range(NUM_DEVICE):
603
+ print(f"Loading model {i+1}/{NUM_DEVICE} on cuda:{i % NUM_DEVICE}...")
604
+ pipeline, processor = load_model_pipeline(
605
+ args.model_path, args.model_type, f"cuda:{i % NUM_DEVICE}"
606
+ )
607
+ pipelines.append(pipeline)
608
+ processors.append(processor)
609
+
610
+ print("All models loaded!")
611
+
612
+ # Distribute data across GPUs
613
+ samples_per_gpu = len(data) // NUM_PROCESSES
614
+
615
+ with ThreadPoolExecutor(max_workers=NUM_PROCESSES) as executor:
616
+ futures = []
617
+
618
+ for device_id in range(NUM_PROCESSES):
619
+ start_idx = device_id * samples_per_gpu
620
+ end_idx = len(data) if device_id == NUM_PROCESSES - 1 else start_idx + samples_per_gpu
621
+
622
+ gpu_data = data[start_idx:end_idx]
623
+ gpu_metadatas = metadatas[start_idx:end_idx] if metadatas else None
624
+
625
+ future = executor.submit(
626
+ worker_process,
627
+ device_id=device_id,
628
+ mode=args.mode,
629
+ data=gpu_data,
630
+ start_idx=start_idx,
631
+ pipeline=pipelines[device_id % NUM_DEVICE],
632
+ processor=processors[device_id % NUM_DEVICE],
633
+ output_dir=args.output_dir,
634
+ batch_size=args.batch_size,
635
+ guidance_scale=args.guidance_scale,
636
+ num_inference_steps=args.num_inference_steps,
637
+ seed=args.seed,
638
+ use_cot=args.use_cot,
639
+ cot_template_name=args.cot_template,
640
+ add_instruction=args.add_instruction,
641
+ n_samples=args.n_samples,
642
+ skip_grid=args.skip_grid,
643
+ resolution=args.resolution,
644
+ metadatas=gpu_metadatas
645
+ )
646
+ futures.append(future)
647
+
648
+ for future in as_completed(futures):
649
+ try:
650
+ future.result()
651
+ except Exception as e:
652
+ print(f"Worker failed: {e}")
653
+ traceback.print_exc()
654
+
655
+ print(f"\n✓ Done! Results saved to {args.output_dir}")
656
+ print(f" Total processed: {len(data)}")
657
+
658
+
659
+ if __name__ == "__main__":
660
+ main()
unimodel/qwenflux/fluxpipeline.py ADDED
@@ -0,0 +1,1543 @@
1
+ # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
2
+ # Copyright 2025 Fu-Yun Wang
3
+
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import inspect
18
+ from typing import Any, Callable, Dict, List, Optional, Union, Tuple
19
+
20
+ import numpy as np
21
+ import torch
22
+ from transformers import (
23
+ CLIPImageProcessor,
24
+ CLIPTextModel,
25
+ CLIPTokenizer,
26
+ CLIPVisionModelWithProjection,
27
+ T5EncoderModel,
28
+ T5TokenizerFast,
29
+ )
30
+
31
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
32
+ from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
33
+ from diffusers.models import AutoencoderKL, FluxTransformer2DModel
34
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
35
+ from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteSchedulerOutput
36
+ from diffusers.utils import (
37
+ USE_PEFT_BACKEND,
38
+ is_torch_xla_available,
39
+ logging,
40
+ replace_example_docstring,
41
+ scale_lora_layers,
42
+ unscale_lora_layers,
43
+ )
44
+ from diffusers.utils.torch_utils import randn_tensor
45
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
46
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
47
+ import math
48
+
49
+
50
+ if is_torch_xla_available():
51
+ import torch_xla.core.xla_model as xm
52
+
53
+ XLA_AVAILABLE = True
54
+ else:
55
+ XLA_AVAILABLE = False
56
+
57
+
58
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
59
+
60
+ EXAMPLE_DOC_STRING = """
61
+ Examples:
62
+ ```py
63
+ >>> import torch
64
+ >>> from diffusers import FluxPipeline
65
+
66
+ >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
67
+ >>> pipe.to("cuda")
68
+ >>> prompt = "A cat holding a sign that says hello world"
69
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
70
+ >>> # Refer to the pipeline documentation for more details.
71
+ >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
72
+ >>> image.save("flux.png")
73
+ ```
74
+ """
75
+
76
+
77
+ def calculate_shift(
78
+ image_seq_len,
79
+ base_seq_len: int = 256,
80
+ max_seq_len: int = 4096,
81
+ base_shift: float = 0.5,
82
+ max_shift: float = 1.15,
83
+ ):
84
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
85
+ b = base_shift - m * base_seq_len
86
+ mu = image_seq_len * m + b
87
+ return mu
88
+
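A small worked example of the linear shift above, plugging in the defaults base_seq_len=256, max_seq_len=4096, base_shift=0.5 and max_shift=1.15 (the same values __call__ reads from the scheduler config further below):

# For a 1024x1024 image the packed latent sequence length is (1024 / 16) ** 2 = 4096,
# so mu reaches the maximum shift:
#   m  = (1.15 - 0.5) / (4096 - 256) ~= 1.693e-4
#   b  = 0.5 - m * 256               ~= 0.4567
#   mu = 4096 * m + b                 = 1.15
# For a 512x512 image the sequence length is (512 / 16) ** 2 = 1024 and mu ~= 0.63.
print(calculate_shift(4096))  # -> 1.15
print(calculate_shift(1024))  # -> ~0.63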
89
+
90
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
91
+ def retrieve_timesteps(
92
+ scheduler,
93
+ num_inference_steps: Optional[int] = None,
94
+ device: Optional[Union[str, torch.device]] = None,
95
+ timesteps: Optional[List[int]] = None,
96
+ sigmas: Optional[List[float]] = None,
97
+ **kwargs,
98
+ ):
99
+ r"""
100
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
101
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
102
+
103
+ Args:
104
+ scheduler (`SchedulerMixin`):
105
+ The scheduler to get timesteps from.
106
+ num_inference_steps (`int`):
107
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
108
+ must be `None`.
109
+ device (`str` or `torch.device`, *optional*):
110
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
111
+ timesteps (`List[int]`, *optional*):
112
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
113
+ `num_inference_steps` and `sigmas` must be `None`.
114
+ sigmas (`List[float]`, *optional*):
115
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
116
+ `num_inference_steps` and `timesteps` must be `None`.
117
+
118
+ Returns:
119
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
120
+ second element is the number of inference steps.
121
+ """
122
+ if timesteps is not None and sigmas is not None:
123
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
124
+ if timesteps is not None:
125
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
126
+ if not accepts_timesteps:
127
+ raise ValueError(
128
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
129
+ f" timestep schedules. Please check whether you are using the correct scheduler."
130
+ )
131
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
132
+ timesteps = scheduler.timesteps
133
+ num_inference_steps = len(timesteps)
134
+ elif sigmas is not None:
135
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
136
+ if not accept_sigmas:
137
+ raise ValueError(
138
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
139
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
140
+ )
141
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
142
+ timesteps = scheduler.timesteps
143
+ num_inference_steps = len(timesteps)
144
+ else:
145
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
146
+ timesteps = scheduler.timesteps
147
+ return timesteps, num_inference_steps
148
+
149
+
150
+ class FluxPipeline(
151
+ DiffusionPipeline,
152
+ FluxLoraLoaderMixin,
153
+ FromSingleFileMixin,
154
+ TextualInversionLoaderMixin,
155
+ FluxIPAdapterMixin,
156
+ ):
157
+ r"""
158
+ The Flux pipeline for text-to-image generation.
159
+
160
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
161
+
162
+ Args:
163
+ transformer ([`FluxTransformer2DModel`]):
164
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
165
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
166
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
167
+ vae ([`AutoencoderKL`]):
168
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
169
+ text_encoder ([`CLIPTextModel`]):
170
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
171
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
172
+ text_encoder_2 ([`T5EncoderModel`]):
173
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
174
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
175
+ tokenizer (`CLIPTokenizer`):
176
+ Tokenizer of class
177
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
178
+ tokenizer_2 (`T5TokenizerFast`):
179
+ Second Tokenizer of class
180
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
181
+ """
182
+
183
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
184
+ _optional_components = ["image_encoder", "feature_extractor"]
185
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
186
+
187
+ def __init__(
188
+ self,
189
+ scheduler: FlowMatchEulerDiscreteScheduler,
190
+ vae: AutoencoderKL,
191
+ text_encoder: CLIPTextModel,
192
+ tokenizer: CLIPTokenizer,
193
+ text_encoder_2: T5EncoderModel,
194
+ tokenizer_2: T5TokenizerFast,
195
+ transformer: FluxTransformer2DModel,
196
+ image_encoder: CLIPVisionModelWithProjection = None,
197
+ feature_extractor: CLIPImageProcessor = None,
198
+ ):
199
+ super().__init__()
200
+
201
+ self.register_modules(
202
+ vae=vae,
203
+ text_encoder=text_encoder,
204
+ text_encoder_2=text_encoder_2,
205
+ tokenizer=tokenizer,
206
+ tokenizer_2=tokenizer_2,
207
+ transformer=transformer,
208
+ scheduler=scheduler,
209
+ image_encoder=image_encoder,
210
+ feature_extractor=feature_extractor,
211
+ )
212
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
213
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height have to be divisible
214
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
215
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
216
+ self.tokenizer_max_length = (
217
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
218
+ )
219
+ self.default_sample_size = 128
220
+
221
+ def _get_t5_prompt_embeds(
222
+ self,
223
+ prompt: Union[str, List[str]] = None,
224
+ num_images_per_prompt: int = 1,
225
+ max_sequence_length: int = 512,
226
+ device: Optional[torch.device] = None,
227
+ dtype: Optional[torch.dtype] = None,
228
+ ):
229
+ device = device or self._execution_device
230
+ dtype = dtype or self.text_encoder.dtype
231
+
232
+ prompt = [prompt] if isinstance(prompt, str) else prompt
233
+ batch_size = len(prompt)
234
+
235
+ if isinstance(self, TextualInversionLoaderMixin):
236
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
237
+
238
+ text_inputs = self.tokenizer_2(
239
+ prompt,
240
+ padding="max_length",
241
+ max_length=max_sequence_length,
242
+ truncation=True,
243
+ return_length=False,
244
+ return_overflowing_tokens=False,
245
+ return_tensors="pt",
246
+ )
247
+ text_input_ids = text_inputs.input_ids
248
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
249
+
250
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
251
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
252
+ logger.warning(
253
+ "The following part of your input was truncated because `max_sequence_length` is set to "
254
+ f" {max_sequence_length} tokens: {removed_text}"
255
+ )
256
+
257
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
258
+
259
+ dtype = self.text_encoder_2.dtype
260
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
261
+
262
+ _, seq_len, _ = prompt_embeds.shape
263
+
264
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
265
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
266
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
267
+
268
+ return prompt_embeds
269
+
270
+ def _get_clip_prompt_embeds(
271
+ self,
272
+ prompt: Union[str, List[str]],
273
+ num_images_per_prompt: int = 1,
274
+ device: Optional[torch.device] = None,
275
+ ):
276
+ device = device or self._execution_device
277
+
278
+ prompt = [prompt] if isinstance(prompt, str) else prompt
279
+ batch_size = len(prompt)
280
+
281
+ if isinstance(self, TextualInversionLoaderMixin):
282
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
283
+
284
+ text_inputs = self.tokenizer(
285
+ prompt,
286
+ padding="max_length",
287
+ max_length=self.tokenizer_max_length,
288
+ truncation=True,
289
+ return_overflowing_tokens=False,
290
+ return_length=False,
291
+ return_tensors="pt",
292
+ )
293
+
294
+ text_input_ids = text_inputs.input_ids
295
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
296
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
297
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
298
+ logger.warning(
299
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
300
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
301
+ )
302
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
303
+
304
+ # Use pooled output of CLIPTextModel
305
+ prompt_embeds = prompt_embeds.pooler_output
306
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
307
+
308
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
309
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
310
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
311
+
312
+ return prompt_embeds
313
+
314
+ def encode_prompt(
315
+ self,
316
+ prompt: Union[str, List[str]],
317
+ prompt_2: Union[str, List[str]],
318
+ device: Optional[torch.device] = None,
319
+ num_images_per_prompt: int = 1,
320
+ prompt_embeds: Optional[torch.FloatTensor] = None,
321
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
322
+ max_sequence_length: int = 512,
323
+ lora_scale: Optional[float] = None,
324
+ ):
325
+ r"""
326
+
327
+ Args:
328
+ prompt (`str` or `List[str]`, *optional*):
329
+ prompt to be encoded
330
+ prompt_2 (`str` or `List[str]`, *optional*):
331
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
332
+ used in all text-encoders
333
+ device: (`torch.device`):
334
+ torch device
335
+ num_images_per_prompt (`int`):
336
+ number of images that should be generated per prompt
337
+ prompt_embeds (`torch.FloatTensor`, *optional*):
338
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
339
+ provided, text embeddings will be generated from `prompt` input argument.
340
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
341
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
342
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
343
+ lora_scale (`float`, *optional*):
344
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
345
+ """
346
+ device = device or self._execution_device
347
+
348
+ # set lora scale so that monkey patched LoRA
349
+ # function of text encoder can correctly access it
350
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
351
+ self._lora_scale = lora_scale
352
+
353
+ # dynamically adjust the LoRA scale
354
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
355
+ scale_lora_layers(self.text_encoder, lora_scale)
356
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
357
+ scale_lora_layers(self.text_encoder_2, lora_scale)
358
+
359
+ prompt = [prompt] if isinstance(prompt, str) else prompt
360
+
361
+ if prompt_embeds is None:
362
+ prompt_2 = prompt_2 or prompt
363
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
364
+
365
+ # We only use the pooled prompt output from the CLIPTextModel
366
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
367
+ prompt=prompt,
368
+ device=device,
369
+ num_images_per_prompt=num_images_per_prompt,
370
+ )
371
+ prompt_embeds = self._get_t5_prompt_embeds(
372
+ prompt=prompt_2,
373
+ num_images_per_prompt=num_images_per_prompt,
374
+ max_sequence_length=max_sequence_length,
375
+ device=device,
376
+ )
377
+
378
+ if self.text_encoder is not None:
379
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
380
+ # Retrieve the original scale by scaling back the LoRA layers
381
+ unscale_lora_layers(self.text_encoder, lora_scale)
382
+
383
+ if self.text_encoder_2 is not None:
384
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
385
+ # Retrieve the original scale by scaling back the LoRA layers
386
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
387
+
388
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
389
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
390
+
391
+ return prompt_embeds, pooled_prompt_embeds, text_ids
392
+
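For orientation, a hedged note on the tensor shapes this method returns for the standard FLUX text encoders (CLIP-L pooled width 768, T5-XXL hidden size 4096; these sizes come from the usual FLUX checkpoints, not from anything stated in this file):

# prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
#     prompt="a photo of a corgi", prompt_2=None, num_images_per_prompt=2,
# )
# prompt_embeds.shape        -> (2, 512, 4096)  # T5 sequence embeddings, padded to max_sequence_length
# pooled_prompt_embeds.shape -> (2, 768)        # pooled CLIP embedding
# text_ids.shape             -> (512, 3)        # all-zero positional ids for the text tokens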
393
+ def encode_image(self, image, device, num_images_per_prompt):
394
+ dtype = next(self.image_encoder.parameters()).dtype
395
+
396
+ if not isinstance(image, torch.Tensor):
397
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
398
+
399
+ image = image.to(device=device, dtype=dtype)
400
+ image_embeds = self.image_encoder(image).image_embeds
401
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
402
+ return image_embeds
403
+
404
+ def prepare_ip_adapter_image_embeds(
405
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
406
+ ):
407
+ image_embeds = []
408
+ if ip_adapter_image_embeds is None:
409
+ if not isinstance(ip_adapter_image, list):
410
+ ip_adapter_image = [ip_adapter_image]
411
+
412
+ if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
413
+ raise ValueError(
414
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
415
+ )
416
+
417
+ for single_ip_adapter_image in ip_adapter_image:
418
+ single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
419
+ image_embeds.append(single_image_embeds[None, :])
420
+ else:
421
+ if not isinstance(ip_adapter_image_embeds, list):
422
+ ip_adapter_image_embeds = [ip_adapter_image_embeds]
423
+
424
+ if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
425
+ raise ValueError(
426
+ f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
427
+ )
428
+
429
+ for single_image_embeds in ip_adapter_image_embeds:
430
+ image_embeds.append(single_image_embeds)
431
+
432
+ ip_adapter_image_embeds = []
433
+ for single_image_embeds in image_embeds:
434
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
435
+ single_image_embeds = single_image_embeds.to(device=device)
436
+ ip_adapter_image_embeds.append(single_image_embeds)
437
+
438
+ return ip_adapter_image_embeds
439
+
440
+ def check_inputs(
441
+ self,
442
+ prompt,
443
+ prompt_2,
444
+ height,
445
+ width,
446
+ negative_prompt=None,
447
+ negative_prompt_2=None,
448
+ prompt_embeds=None,
449
+ negative_prompt_embeds=None,
450
+ pooled_prompt_embeds=None,
451
+ negative_pooled_prompt_embeds=None,
452
+ callback_on_step_end_tensor_inputs=None,
453
+ max_sequence_length=None,
454
+ ):
455
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
456
+ logger.warning(
457
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
458
+ )
459
+
460
+ if callback_on_step_end_tensor_inputs is not None and not all(
461
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
462
+ ):
463
+ raise ValueError(
464
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
465
+ )
466
+
467
+ if prompt is not None and prompt_embeds is not None:
468
+ raise ValueError(
469
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
470
+ " only forward one of the two."
471
+ )
472
+ elif prompt_2 is not None and prompt_embeds is not None:
473
+ raise ValueError(
474
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
475
+ " only forward one of the two."
476
+ )
477
+ elif prompt is None and prompt_embeds is None:
478
+ raise ValueError(
479
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
480
+ )
481
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
482
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
483
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
484
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
485
+
486
+ if negative_prompt is not None and negative_prompt_embeds is not None:
487
+ raise ValueError(
488
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
489
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
490
+ )
491
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
492
+ raise ValueError(
493
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
494
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
495
+ )
496
+
497
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
498
+ raise ValueError(
499
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
500
+ )
501
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
502
+ raise ValueError(
503
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
504
+ )
505
+
506
+ if max_sequence_length is not None and max_sequence_length > 512:
507
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
508
+
509
+ @staticmethod
510
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
511
+ latent_image_ids = torch.zeros(height, width, 3)
512
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
513
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
514
+
515
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
516
+
517
+ latent_image_ids = latent_image_ids.reshape(
518
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
519
+ )
520
+
521
+ return latent_image_ids.to(device=device, dtype=dtype)
522
+
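A tiny worked example of the id layout produced by this helper, assuming a 2x2 grid of packed latent positions:

# _prepare_latent_image_ids(batch_size=1, height=2, width=2, device="cpu", dtype=torch.float32)
# yields one (unused, row, col) triple per packed latent token:
#   tensor([[0., 0., 0.],
#           [0., 0., 1.],
#           [0., 1., 0.],
#           [0., 1., 1.]])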
523
+ @staticmethod
524
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
525
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
526
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
527
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
528
+
529
+ return latents
530
+
531
+ @staticmethod
532
+ def _unpack_latents(latents, height, width, vae_scale_factor):
533
+ batch_size, num_patches, channels = latents.shape
534
+
535
+ # VAE applies 8x compression on images but we must also account for packing which requires
536
+ # latent height and width to be divisible by 2.
537
+ height = 2 * (int(height) // (vae_scale_factor * 2))
538
+ width = 2 * (int(width) // (vae_scale_factor * 2))
539
+
540
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
541
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
542
+
543
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
544
+
545
+ return latents
546
+
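A short round-trip sketch of the two static helpers above for a 1024x1024 generation (vae_scale_factor = 8, 16 latent channels); it assumes the FluxPipeline class defined in this file is in scope:

import torch

latents = torch.randn(1, 16, 128, 128)  # unpacked VAE-space latents for a 1024x1024 image
packed = FluxPipeline._pack_latents(latents, 1, 16, 128, 128)
print(packed.shape)   # torch.Size([1, 4096, 64]) -> 64x64 grid of 2x2x16 patches
unpacked = FluxPipeline._unpack_latents(packed, 1024, 1024, vae_scale_factor=8)
print(unpacked.shape)                     # torch.Size([1, 16, 128, 128])
print(torch.allclose(latents, unpacked))  # True: packing is lossless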
547
+ def enable_vae_slicing(self):
548
+ r"""
549
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
550
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
551
+ """
552
+ self.vae.enable_slicing()
553
+
554
+ def disable_vae_slicing(self):
555
+ r"""
556
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
557
+ computing decoding in one step.
558
+ """
559
+ self.vae.disable_slicing()
560
+
561
+ def enable_vae_tiling(self):
562
+ r"""
563
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
564
+ compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow
565
+ the processing of larger images.
566
+ """
567
+ self.vae.enable_tiling()
568
+
569
+ def disable_vae_tiling(self):
570
+ r"""
571
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
572
+ computing decoding in one step.
573
+ """
574
+ self.vae.disable_tiling()
575
+
576
+ def prepare_latents(
577
+ self,
578
+ batch_size,
579
+ num_channels_latents,
580
+ height,
581
+ width,
582
+ dtype,
583
+ device,
584
+ generator,
585
+ latents=None,
586
+ ):
587
+ # VAE applies 8x compression on images but we must also account for packing which requires
588
+ # latent height and width to be divisible by 2.
589
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
590
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
591
+
592
+ shape = (batch_size, num_channels_latents, height, width)
593
+
594
+ if latents is not None:
595
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
596
+ return latents.to(device=device, dtype=dtype), latent_image_ids
597
+
598
+ if isinstance(generator, list) and len(generator) != batch_size:
599
+ raise ValueError(
600
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
601
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
602
+ )
603
+
604
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
605
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
606
+
607
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
608
+
609
+ return latents, latent_image_ids
610
+
611
+ @property
612
+ def guidance_scale(self):
613
+ return self._guidance_scale
614
+
615
+ @property
616
+ def joint_attention_kwargs(self):
617
+ return self._joint_attention_kwargs
618
+
619
+ @property
620
+ def num_timesteps(self):
621
+ return self._num_timesteps
622
+
623
+ @property
624
+ def current_timestep(self):
625
+ return self._current_timestep
626
+
627
+ @property
628
+ def interrupt(self):
629
+ return self._interrupt
630
+
631
+ @torch.no_grad()
632
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
633
+ def __call__(
634
+ self,
635
+ prompt: Union[str, List[str]] = None,
636
+ prompt_2: Optional[Union[str, List[str]]] = None,
637
+ negative_prompt: Union[str, List[str]] = None,
638
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
639
+ true_cfg_scale: float = 1.0,
640
+ height: Optional[int] = None,
641
+ width: Optional[int] = None,
642
+ num_inference_steps: int = 28,
643
+ sigmas: Optional[List[float]] = None,
644
+ guidance_scale: float = 3.5,
645
+ num_images_per_prompt: Optional[int] = 1,
646
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
647
+ latents: Optional[torch.FloatTensor] = None,
648
+ prompt_embeds: Optional[torch.FloatTensor] = None,
649
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
650
+ ip_adapter_image: Optional[PipelineImageInput] = None,
651
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
652
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
653
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
654
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
655
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
656
+ output_type: Optional[str] = "pil",
657
+ return_dict: bool = True,
658
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
659
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
660
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
661
+ max_sequence_length: int = 512,
662
+ ):
663
+ r"""
664
+ Function invoked when calling the pipeline for generation.
665
+
666
+ Args:
667
+ prompt (`str` or `List[str]`, *optional*):
668
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
669
+ instead.
670
+ prompt_2 (`str` or `List[str]`, *optional*):
671
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
672
+ will be used instead.
673
+ negative_prompt (`str` or `List[str]`, *optional*):
674
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
675
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
676
+ not greater than `1`).
677
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
678
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
679
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
680
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
681
+ When > 1.0 and a `negative_prompt` is provided, this enables true classifier-free guidance.
682
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
683
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
684
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
685
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
686
+ num_inference_steps (`int`, *optional*, defaults to 28):
687
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
688
+ expense of slower inference.
689
+ sigmas (`List[float]`, *optional*):
690
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
691
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
692
+ will be used.
693
+ guidance_scale (`float`, *optional*, defaults to 3.5):
694
+ Guidance scale as defined in [Classifier-Free Diffusion
695
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2
696
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
697
+ `guidance_scale > 1`. A higher guidance scale encourages the model to generate images that are closely linked to
698
+ the text `prompt`, usually at the expense of lower image quality.
699
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
700
+ The number of images to generate per prompt.
701
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
702
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
703
+ to make generation deterministic.
704
+ latents (`torch.FloatTensor`, *optional*):
705
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
706
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
707
+ tensor will be generated by sampling using the supplied random `generator`.
708
+ prompt_embeds (`torch.FloatTensor`, *optional*):
709
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
710
+ provided, text embeddings will be generated from `prompt` input argument.
711
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
712
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
713
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
714
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
715
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
716
+ Pre-generated image embeddings for IP-Adapter. It should be a list whose length matches the number of
717
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
718
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
719
+ negative_ip_adapter_image:
720
+ (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
721
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
722
+ Pre-generated image embeddings for IP-Adapter. It should be a list whose length matches the number of
723
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
724
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
725
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
726
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
727
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
728
+ argument.
729
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
730
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
731
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
732
+ input argument.
733
+ output_type (`str`, *optional*, defaults to `"pil"`):
734
+ The output format of the generated image. Choose between
735
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
736
+ return_dict (`bool`, *optional*, defaults to `True`):
737
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
738
+ joint_attention_kwargs (`dict`, *optional*):
739
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
740
+ `self.processor` in
741
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
742
+ callback_on_step_end (`Callable`, *optional*):
743
+ A function that calls at the end of each denoising steps during the inference. The function is called
744
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
745
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
746
+ `callback_on_step_end_tensor_inputs`.
747
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
748
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
749
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
750
+ `._callback_tensor_inputs` attribute of your pipeline class.
751
+ max_sequence_length (`int`, defaults to 512): Maximum sequence length to use with the `prompt`.
752
+
753
+ Examples:
754
+
755
+ Returns:
756
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
757
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
758
+ images.
759
+ """
760
+
761
+ height = height or self.default_sample_size * self.vae_scale_factor
762
+ width = width or self.default_sample_size * self.vae_scale_factor
763
+
764
+ # 1. Check inputs. Raise error if not correct
765
+ self.check_inputs(
766
+ prompt,
767
+ prompt_2,
768
+ height,
769
+ width,
770
+ negative_prompt=negative_prompt,
771
+ negative_prompt_2=negative_prompt_2,
772
+ prompt_embeds=prompt_embeds,
773
+ negative_prompt_embeds=negative_prompt_embeds,
774
+ pooled_prompt_embeds=pooled_prompt_embeds,
775
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
776
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
777
+ max_sequence_length=max_sequence_length,
778
+ )
779
+
780
+ self._guidance_scale = guidance_scale
781
+ self._joint_attention_kwargs = joint_attention_kwargs
782
+ self._current_timestep = None
783
+ self._interrupt = False
784
+
785
+ # 2. Define call parameters
786
+ if prompt is not None and isinstance(prompt, str):
787
+ batch_size = 1
788
+ elif prompt is not None and isinstance(prompt, list):
789
+ batch_size = len(prompt)
790
+ else:
791
+ batch_size = prompt_embeds.shape[0]
792
+
793
+ device = self._execution_device
794
+
795
+ lora_scale = (
796
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
797
+ )
798
+ has_neg_prompt = negative_prompt is not None or (
799
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
800
+ )
801
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
802
+ (
803
+ prompt_embeds,
804
+ pooled_prompt_embeds,
805
+ text_ids,
806
+ ) = self.encode_prompt(
807
+ prompt=prompt,
808
+ prompt_2=prompt_2,
809
+ prompt_embeds=prompt_embeds,
810
+ pooled_prompt_embeds=pooled_prompt_embeds,
811
+ device=device,
812
+ num_images_per_prompt=num_images_per_prompt,
813
+ max_sequence_length=max_sequence_length,
814
+ lora_scale=lora_scale,
815
+ )
816
+ if do_true_cfg:
817
+ (
818
+ negative_prompt_embeds,
819
+ negative_pooled_prompt_embeds,
820
+ negative_text_ids,
821
+ ) = self.encode_prompt(
822
+ prompt=negative_prompt,
823
+ prompt_2=negative_prompt_2,
824
+ prompt_embeds=negative_prompt_embeds,
825
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
826
+ device=device,
827
+ num_images_per_prompt=num_images_per_prompt,
828
+ max_sequence_length=max_sequence_length,
829
+ lora_scale=lora_scale,
830
+ )
831
+
832
+ # 4. Prepare latent variables
833
+ num_channels_latents = self.transformer.config.in_channels // 4
834
+ latents, latent_image_ids = self.prepare_latents(
835
+ batch_size * num_images_per_prompt,
836
+ num_channels_latents,
837
+ height,
838
+ width,
839
+ prompt_embeds.dtype,
840
+ device,
841
+ generator,
842
+ latents,
843
+ )
844
+
845
+ # 5. Prepare timesteps
846
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
847
+ image_seq_len = latents.shape[1]
848
+ mu = calculate_shift(
849
+ image_seq_len,
850
+ self.scheduler.config.get("base_image_seq_len", 256),
851
+ self.scheduler.config.get("max_image_seq_len", 4096),
852
+ self.scheduler.config.get("base_shift", 0.5),
853
+ self.scheduler.config.get("max_shift", 1.15),
854
+ )
855
+ timesteps, num_inference_steps = retrieve_timesteps(
856
+ self.scheduler,
857
+ num_inference_steps,
858
+ device,
859
+ sigmas=sigmas,
860
+ mu=mu,
861
+ )
862
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
863
+ self._num_timesteps = len(timesteps)
864
+
865
+ # handle guidance
866
+ if self.transformer.config.guidance_embeds:
867
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
868
+ guidance = guidance.expand(latents.shape[0])
869
+ else:
870
+ guidance = None
871
+
872
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
873
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
874
+ ):
875
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
876
+ negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
877
+
878
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
879
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
880
+ ):
881
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
882
+ ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
883
+
884
+ if self.joint_attention_kwargs is None:
885
+ self._joint_attention_kwargs = {}
886
+
887
+ image_embeds = None
888
+ negative_image_embeds = None
889
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
890
+ image_embeds = self.prepare_ip_adapter_image_embeds(
891
+ ip_adapter_image,
892
+ ip_adapter_image_embeds,
893
+ device,
894
+ batch_size * num_images_per_prompt,
895
+ )
896
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
897
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
898
+ negative_ip_adapter_image,
899
+ negative_ip_adapter_image_embeds,
900
+ device,
901
+ batch_size * num_images_per_prompt,
902
+ )
903
+
904
+ # 6. Denoising loop
905
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
906
+ for i, t in enumerate(timesteps):
907
+ if self.interrupt:
908
+ continue
909
+
910
+ self._current_timestep = t
911
+ if image_embeds is not None:
912
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
913
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
914
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
915
+
916
+ noise_pred = self.transformer(
917
+ hidden_states=latents,
918
+ timestep=timestep / 1000,
919
+ guidance=guidance,
920
+ pooled_projections=pooled_prompt_embeds,
921
+ encoder_hidden_states=prompt_embeds,
922
+ txt_ids=text_ids,
923
+ img_ids=latent_image_ids,
924
+ joint_attention_kwargs=self.joint_attention_kwargs,
925
+ return_dict=False,
926
+ )[0]
927
+
928
+ if do_true_cfg:
929
+ if negative_image_embeds is not None:
930
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
931
+ neg_noise_pred = self.transformer(
932
+ hidden_states=latents,
933
+ timestep=timestep / 1000,
934
+ guidance=guidance,
935
+ pooled_projections=negative_pooled_prompt_embeds,
936
+ encoder_hidden_states=negative_prompt_embeds,
937
+ txt_ids=negative_text_ids,
938
+ img_ids=latent_image_ids,
939
+ joint_attention_kwargs=self.joint_attention_kwargs,
940
+ return_dict=False,
941
+ )[0]
942
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
943
+
944
+ # compute the previous noisy sample x_t -> x_t-1
945
+ latents_dtype = latents.dtype
946
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
947
+
948
+ if latents.dtype != latents_dtype:
949
+ if torch.backends.mps.is_available():
950
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
951
+ latents = latents.to(latents_dtype)
952
+
953
+ if callback_on_step_end is not None:
954
+ callback_kwargs = {}
955
+ for k in callback_on_step_end_tensor_inputs:
956
+ callback_kwargs[k] = locals()[k]
957
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
958
+
959
+ latents = callback_outputs.pop("latents", latents)
960
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
961
+
962
+ # call the callback, if provided
963
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
964
+ progress_bar.update()
965
+
966
+ if XLA_AVAILABLE:
967
+ xm.mark_step()
968
+
969
+ self._current_timestep = None
970
+
971
+ if output_type == "latent":
972
+ image = latents
973
+ else:
974
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
975
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
976
+ image = self.vae.decode(latents, return_dict=False)[0]
977
+ image = self.image_processor.postprocess(image, output_type=output_type)
978
+
979
+ # Offload all models
980
+ self.maybe_free_model_hooks()
981
+
982
+ if not return_dict:
983
+ return (image,)
984
+
985
+ return FluxPipelineOutput(images=image)
986
+
987
+
988
+ @torch.no_grad()
989
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
990
+ def sde_sampling(
991
+ self,
992
+ prompt: Union[str, List[str]] = None,
993
+ prompt_2: Optional[Union[str, List[str]]] = None,
994
+ negative_prompt: Union[str, List[str]] = None,
995
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
996
+ true_cfg_scale: float = 1.0,
997
+ height: Optional[int] = None,
998
+ width: Optional[int] = None,
999
+ num_inference_steps: int = 28,
1000
+ sigmas: Optional[List[float]] = None,
1001
+ guidance_scale: float = 3.5,
1002
+ num_images_per_prompt: Optional[int] = 1,
1003
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1004
+ latents: Optional[torch.FloatTensor] = None,
1005
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1006
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
1007
+ ip_adapter_image: Optional[PipelineImageInput] = None,
1008
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
1009
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
1010
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
1011
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1012
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
1013
+ output_type: Optional[str] = "pil",
1014
+ return_dict: bool = True,
1015
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
1016
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
1017
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
1018
+ max_sequence_length: int = 512,
1019
+ ):
1020
+ r"""
1021
+ Function invoked when calling the pipeline for generation.
1022
+
1023
+ Args:
1024
+ prompt (`str` or `List[str]`, *optional*):
1025
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
1026
+ instead.
1027
+ prompt_2 (`str` or `List[str]`, *optional*):
1028
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
1029
+ will be used instead.
1030
+ negative_prompt (`str` or `List[str]`, *optional*):
1031
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
1032
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
1033
+ not greater than `1`).
1034
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
1035
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
1036
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
1037
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
1038
+ When > 1.0 and a `negative_prompt` is provided, enables true classifier-free guidance.
1039
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1040
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
1041
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1042
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
1043
+ num_inference_steps (`int`, *optional*, defaults to 28):
1044
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1045
+ expense of slower inference.
1046
+ sigmas (`List[float]`, *optional*):
1047
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
1048
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
1049
+ will be used.
1050
+ guidance_scale (`float`, *optional*, defaults to 3.5):
1051
+ Guidance scale as defined in [Classifier-Free Diffusion
1052
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
1053
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
1054
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
1055
+ the text `prompt`, usually at the expense of lower image quality.
1056
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1057
+ The number of images to generate per prompt.
1058
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1059
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1060
+ to make generation deterministic.
1061
+ latents (`torch.FloatTensor`, *optional*):
1062
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1063
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1064
+ tensor will be generated by sampling using the supplied random `generator`.
1065
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1066
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1067
+ provided, text embeddings will be generated from `prompt` input argument.
1068
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1069
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
1070
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
1071
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
1072
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
1073
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
1074
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
1075
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
1076
+ negative_ip_adapter_image:
1077
+ (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
1078
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
1079
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
1080
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
1081
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
1082
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1083
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1084
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1085
+ argument.
1086
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1087
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1088
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
1089
+ input argument.
1090
+ output_type (`str`, *optional*, defaults to `"pil"`):
1091
+ The output format of the generated image. Choose between
1092
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1093
+ return_dict (`bool`, *optional*, defaults to `True`):
1094
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
1095
+ joint_attention_kwargs (`dict`, *optional*):
1096
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1097
+ `self.processor` in
1098
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1099
+ callback_on_step_end (`Callable`, *optional*):
1100
+ A function that is called at the end of each denoising step during inference. The function is called
1101
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
1102
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
1103
+ `callback_on_step_end_tensor_inputs`.
1104
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1105
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1106
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1107
+ `._callback_tensor_inputs` attribute of your pipeline class.
1108
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
1109
+
1110
+ Examples:
1111
+
1112
+ Returns:
1113
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
1114
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
1115
+ images.
1116
+ """
1117
+
1118
+ height = height or self.default_sample_size * self.vae_scale_factor
1119
+ width = width or self.default_sample_size * self.vae_scale_factor
1120
+
1121
+ # 1. Check inputs. Raise error if not correct
1122
+ self.check_inputs(
1123
+ prompt,
1124
+ prompt_2,
1125
+ height,
1126
+ width,
1127
+ negative_prompt=negative_prompt,
1128
+ negative_prompt_2=negative_prompt_2,
1129
+ prompt_embeds=prompt_embeds,
1130
+ negative_prompt_embeds=negative_prompt_embeds,
1131
+ pooled_prompt_embeds=pooled_prompt_embeds,
1132
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1133
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
1134
+ max_sequence_length=max_sequence_length,
1135
+ )
1136
+
1137
+ self._guidance_scale = guidance_scale
1138
+ self._joint_attention_kwargs = joint_attention_kwargs
1139
+ self._current_timestep = None
1140
+ self._interrupt = False
1141
+
1142
+ # 2. Define call parameters
1143
+ if prompt is not None and isinstance(prompt, str):
1144
+ batch_size = 1
1145
+ elif prompt is not None and isinstance(prompt, list):
1146
+ batch_size = len(prompt)
1147
+ else:
1148
+ batch_size = prompt_embeds.shape[0]
1149
+
1150
+ device = self._execution_device
1151
+
1152
+ lora_scale = (
1153
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
1154
+ )
1155
+ has_neg_prompt = negative_prompt is not None or (
1156
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
1157
+ )
1158
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
1159
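+ # 3. Encode prompt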
+ (
1160
+ prompt_embeds,
1161
+ pooled_prompt_embeds,
1162
+ text_ids,
1163
+ ) = self.encode_prompt(
1164
+ prompt=prompt,
1165
+ prompt_2=prompt_2,
1166
+ prompt_embeds=prompt_embeds,
1167
+ pooled_prompt_embeds=pooled_prompt_embeds,
1168
+ device=device,
1169
+ num_images_per_prompt=num_images_per_prompt,
1170
+ max_sequence_length=max_sequence_length,
1171
+ lora_scale=lora_scale,
1172
+ )
1173
+ if do_true_cfg:
1174
+ (
1175
+ negative_prompt_embeds,
1176
+ negative_pooled_prompt_embeds,
1177
+ negative_text_ids,
1178
+ ) = self.encode_prompt(
1179
+ prompt=negative_prompt,
1180
+ prompt_2=negative_prompt_2,
1181
+ prompt_embeds=negative_prompt_embeds,
1182
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
1183
+ device=device,
1184
+ num_images_per_prompt=num_images_per_prompt,
1185
+ max_sequence_length=max_sequence_length,
1186
+ lora_scale=lora_scale,
1187
+ )
1188
+
1189
+ # 4. Prepare latent variables
1190
+ num_channels_latents = self.transformer.config.in_channels // 4
1191
+ latents, latent_image_ids = self.prepare_latents(
1192
+ batch_size * num_images_per_prompt,
1193
+ num_channels_latents,
1194
+ height,
1195
+ width,
1196
+ prompt_embeds.dtype,
1197
+ device,
1198
+ generator,
1199
+ latents,
1200
+ )
1201
+
1202
+ # 5. Prepare timesteps
1203
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
1204
+ image_seq_len = latents.shape[1]
1205
+ mu = calculate_shift(
1206
+ image_seq_len,
1207
+ self.scheduler.config.get("base_image_seq_len", 256),
1208
+ self.scheduler.config.get("max_image_seq_len", 4096),
1209
+ self.scheduler.config.get("base_shift", 0.5),
1210
+ self.scheduler.config.get("max_shift", 1.15),
1211
+ )
1212
+ timesteps, num_inference_steps = retrieve_timesteps(
1213
+ self.scheduler,
1214
+ num_inference_steps,
1215
+ device,
1216
+ sigmas=sigmas,
1217
+ mu=mu,
1218
+ )
1219
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1220
+ self._num_timesteps = len(timesteps)
1221
+
1222
+ # handle guidance
1223
+ if self.transformer.config.guidance_embeds:
1224
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
1225
+ guidance = guidance.expand(latents.shape[0])
1226
+ else:
1227
+ guidance = None
1228
+
1229
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
1230
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
1231
+ ):
1232
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
1233
+ negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
1234
+
1235
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
1236
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
1237
+ ):
1238
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
1239
+ ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
1240
+
1241
+ if self.joint_attention_kwargs is None:
1242
+ self._joint_attention_kwargs = {}
1243
+
1244
+ image_embeds = None
1245
+ negative_image_embeds = None
1246
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1247
+ image_embeds = self.prepare_ip_adapter_image_embeds(
1248
+ ip_adapter_image,
1249
+ ip_adapter_image_embeds,
1250
+ device,
1251
+ batch_size * num_images_per_prompt,
1252
+ )
1253
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
1254
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
1255
+ negative_ip_adapter_image,
1256
+ negative_ip_adapter_image_embeds,
1257
+ device,
1258
+ batch_size * num_images_per_prompt,
1259
+ )
1260
+
1261
+ # 6. Denoising loop
1262
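+ # Buffers for the per-step trajectory (transformer inputs, previous latents, sampled latents, log-probs), returned alongside the decoded image.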
+ prev_latents = []
1263
+ pred_latents = []
1264
+ # preds_lst = []
1265
+ states = {
1266
+ "timestep": [],
1267
+ "guidance": [],
1268
+ "pooled_projections": [],
1269
+ "encoder_hidden_states": [],
1270
+ "txt_ids": None,
1271
+ "img_ids": None,
1272
+ }
1273
+ log_probs = []
1274
+ ts = []
1275
+ states["txt_ids"] = text_ids if text_ids is not None else None
1276
+ states["img_ids"] = latent_image_ids if latent_image_ids is not None else None
1277
+
1278
+ # self.scheduler.set_begin_index(0)
1279
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1280
+ for i, t in enumerate(timesteps):
1281
+ if self.interrupt:
1282
+ continue
1283
+
1284
+ self._current_timestep = t
1285
+ if image_embeds is not None:
1286
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
1287
+
1288
+
1289
+ timestep = (t.expand(latents.shape[0]) / 1000.0).to(latents.dtype)  # already scaled to [0, 1]; the transformer call below uses it as-is
1290
+
1291
+ states["timestep"].append(timestep.unsqueeze(1)) # Unsqueezed if needed for batch/timestep handling
1292
+ states["guidance"].append(guidance.unsqueeze(1) if torch.is_tensor(guidance) else guidance) # Handle if tensor
1293
+ states["pooled_projections"].append(pooled_prompt_embeds.unsqueeze(1) if pooled_prompt_embeds is not None else None) # Unsqueezed along seq/batch if applicable
1294
+ states["encoder_hidden_states"].append(prompt_embeds.unsqueeze(1) if prompt_embeds is not None else None) # Unsqueezed along seq dim if needed
1295
+
1296
+ ts.append(t.expand(latents.shape[0]).unsqueeze(1))
1297
+ prev_latents.append(latents.detach().clone().unsqueeze(1))
1298
+
1299
+ noise_pred = self.transformer(
1300
+ hidden_states=latents,
1301
+ timestep=timestep,
1302
+ guidance=guidance,
1303
+ pooled_projections=pooled_prompt_embeds,
1304
+ encoder_hidden_states=prompt_embeds,
1305
+ txt_ids=text_ids,
1306
+ img_ids=latent_image_ids,
1307
+ joint_attention_kwargs=self.joint_attention_kwargs,
1308
+ return_dict=False,
1309
+ )[0]
1310
+
1311
+ if do_true_cfg:
1312
+ if negative_image_embeds is not None:
1313
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
1314
+
1315
+ neg_noise_pred = self.transformer(
1316
+ hidden_states=latents,
1317
+ timestep=timestep,
1318
+ guidance=guidance,
1319
+ pooled_projections=negative_pooled_prompt_embeds,
1320
+ encoder_hidden_states=negative_prompt_embeds,
1321
+ txt_ids=negative_text_ids,
1322
+ img_ids=latent_image_ids,
1323
+ joint_attention_kwargs=self.joint_attention_kwargs,
1324
+ return_dict=False,
1325
+ )[0]
1326
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
1327
+
1328
+ latents_dtype = latents.dtype
1329
+ latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob(self.scheduler, noise_pred.float(), t.expand(latents.shape[0]), latents.float())
1330
+
1331
+ log_probs.append(log_prob.detach().clone().unsqueeze(1))
1332
+ pred_latents.append(latents.detach().clone().unsqueeze(1))
1333
+ if latents.dtype != latents_dtype:
1334
+ latents = latents.to(latents_dtype)
1335
+
1336
+ if callback_on_step_end is not None:
1337
+ callback_kwargs = {}
1338
+ for k in callback_on_step_end_tensor_inputs:
1339
+ callback_kwargs[k] = locals()[k]
1340
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1341
+
1342
+ latents = callback_outputs.pop("latents", latents)
1343
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1344
+
1345
+ # call the callback, if provided
1346
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1347
+ progress_bar.update()
1348
+
1349
+ if XLA_AVAILABLE:
1350
+ xm.mark_step()
1351
+
1352
+ self._current_timestep = None
1353
+
1354
+ if output_type == "latent":
1355
+ image = latents
1356
+ else:
1357
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
1358
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1359
+ image = self.vae.decode(latents, return_dict=False)[0]
1360
+ image = self.image_processor.postprocess(image, output_type=output_type)
1361
+
1362
+
1363
+ batched_states = {}
1364
+ batch_size = latents.shape[0]
1365
+ num_steps = len(timesteps)
1366
+
1367
+ for key, value_list in states.items():
1368
+ if value_list is None or len(value_list) == 0: # Skip None or empty lists
1369
+ batched_states[key] = None
1370
+ continue
1371
+ if value_list[0] is None: # Handle lists of None (e.g., optional inputs)
1372
+ batched_states[key] = None
1373
+ continue
1374
+ # Concatenate along dim=1
1375
+ if isinstance(value_list, list):
1376
+ concatenated = torch.cat(value_list, dim=1) # Shape: (batch, steps, ...)
1377
+ if len(concatenated.shape) <= 2: # 1D tensors (e.g., timestep: batch, steps)
1378
+ # print(key, concatenated.shape)
1379
+ batched_states[key] = concatenated.view(-1)
1380
+ else: # Higher-dim tensors (e.g., latents: batch, steps, channels, h, w)
1381
+ batched_states[key] = concatenated.view(-1, *concatenated.shape[2:])
1382
+ else:
1383
+ batched_states[key] = value_list
1384
+ # assert 0
1385
+ prev_latents = torch.cat(prev_latents, dim=1)
1386
+ log_probs = torch.cat(log_probs, dim=1)
1387
+ pred_latents = torch.cat(pred_latents, dim=1)
1388
+ ts = torch.cat(ts, dim=1)
1389
+
1390
+ prev_latents = prev_latents.view(prev_latents.shape[0] * prev_latents.shape[1], *prev_latents.shape[2:])
1391
+ log_probs = log_probs.view(log_probs.shape[0] * log_probs.shape[1], *log_probs.shape[2:])
1392
+ pred_latents = pred_latents.view(pred_latents.shape[0] * pred_latents.shape[1], *pred_latents.shape[2:])
1393
+ ts = ts.view(-1)
1394
+
1395
+ # Offload all models
1396
+ self.maybe_free_model_hooks()
1397
+
1398
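+ # Trajectory tensors have been flattened to (batch * num_steps, ...), so each denoising transition is a separate row.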
+ return (image, prev_latents, log_probs, pred_latents, ts, batched_states)
1399
+
1400
+ def sde_step_with_logprob(
1401
+ self: FlowMatchEulerDiscreteScheduler,
1402
+ model_output: torch.FloatTensor,
1403
+ timestep: Union[float, torch.FloatTensor],
1404
+ sample: torch.FloatTensor,
1405
+ prev_sample: Optional[torch.FloatTensor] = None,
1406
+ generator: Optional[torch.Generator] = None,
1407
+ ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
1408
+ """
1409
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the flow
1410
+ process from the learned model outputs (most often the predicted velocity).
1411
+
1412
+ Args:
1413
+ model_output (`torch.FloatTensor`):
1414
+ The direct output from learned flow model.
1415
+ timestep (`float`):
1416
+ The current discrete timestep in the diffusion chain.
1417
+ sample (`torch.FloatTensor`):
1418
+ A current instance of a sample created by the diffusion process.
1419
+ generator (`torch.Generator`, *optional*):
1420
+ A random number generator.
1421
+ """
1422
+ step_index = [self.index_for_timestep(t) for t in timestep]
1423
+ prev_step_index = [step+1 for step in step_index]
1424
+ sigma = self.sigmas[step_index].view(-1, 1, 1).to(model_output.device)
1425
+ sigma_prev = self.sigmas[prev_step_index].view(-1, 1, 1).to(model_output.device)
1426
+ sigma_max = self.sigmas[1].item()
1427
+ dt = sigma_prev - sigma
1428
+
1429
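+ # Where sigma == 1 (the very first step), substitute self.sigmas[1] so the 1 - sigma denominator does not vanish.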
+ std_dev_t = torch.sqrt(sigma / (1 - torch.where(sigma == 1, sigma_max, sigma))) * 1.0
1430
+
1431
+
1432
+ # our sde
1433
+ prev_sample_mean = sample*(1+std_dev_t**2/(2*sigma)*dt)+model_output*(1+std_dev_t**2*(1-sigma)/(2*sigma))*dt
1434
+
1435
+ if prev_sample is not None and generator is not None:
1436
+ raise ValueError(
1437
+ "Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
1438
+ " `prev_sample` stays `None`."
1439
+ )
1440
+
1441
+ if prev_sample is None:
1442
+ variance_noise = randn_tensor(
1443
+ model_output.shape,
1444
+ generator=generator,
1445
+ device=model_output.device,
1446
+ dtype=model_output.dtype,
1447
+ )
1448
+ prev_sample = prev_sample_mean + std_dev_t * torch.sqrt(-1*dt) * variance_noise
1449
+
1450
+
1451
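+ # Element-wise Gaussian log-density of prev_sample under N(prev_sample_mean, (std_dev_t * sqrt(-dt))**2).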
+ log_prob = (
1452
+ -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * ((std_dev_t * torch.sqrt(-1*dt))**2))
1453
+ - torch.log(std_dev_t * torch.sqrt(-1*dt))
1454
+ - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
1455
+ )
1456
+
1457
+ # mean along all but batch dimension
1458
+ log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
1459
+
1460
+ return prev_sample, log_prob, prev_sample_mean, std_dev_t * torch.sqrt(-1*dt)
1461
+
1462
+
1463
+
1464
+ # Copyright 2025 Fu-Yun Wang
1465
+ #
1466
+ # Licensed under the Apache License, Version 2.0 (the "License");
1467
+ # you may not use this file except in compliance with the License.
1468
+ # You may obtain a copy of the License at
1469
+ #
1470
+ # http://www.apache.org/licenses/LICENSE-2.0
1471
+ #
1472
+ # Unless required by applicable law or agreed to in writing, software
1473
+ # distributed under the License is distributed on an "AS IS" BASIS,
1474
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1475
+ # See the License for the specific language governing permissions and
1476
+ # limitations under the License.
1477
+
1478
+ def sde_step_with_logprob_simple(
1479
+ self: FlowMatchEulerDiscreteScheduler,
1480
+ model_output: torch.FloatTensor,
1481
+ timestep: Union[float, torch.FloatTensor],
1482
+ sample: torch.FloatTensor,
1483
+ prev_sample: Optional[torch.FloatTensor] = None,
1484
+ generator: Optional[torch.Generator] = None,
1485
+ ):
1486
+ """
1487
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the flow
1488
+ process from the learned model outputs (most often the predicted velocity).
1489
+
1490
+ Args:
1491
+ model_output (`torch.FloatTensor`):
1492
+ The direct output from learned flow model.
1493
+ timestep (`float`):
1494
+ The current discrete timestep in the diffusion chain.
1495
+ sample (`torch.FloatTensor`):
1496
+ A current instance of a sample created by the diffusion process.
1497
+ generator (`torch.Generator`, *optional*):
1498
+ A random number generator.
1499
+ """
1500
+
1501
+ step_index = [self.index_for_timestep(t) for t in timestep]
1502
+ prev_step_index = [step+1 for step in step_index]
1503
+ sigma = self.sigmas[step_index].view(-1, 1, 1).to(model_output.device)
1504
+ sigma_prev = self.sigmas[prev_step_index].view(-1, 1, 1).to(model_output.device)
1505
+ sigma_max = self.sigmas[1].item()
1506
+ dt = sigma_prev - sigma
1507
+
1508
+
1509
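+ # eta scales the injected noise; with eta = 0 this reduces to the deterministic Euler flow-matching update.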
+ eta = 0.5
1510
+ Dt = - dt * eta
1511
+
1512
+ prev_sample_mean = sample * (1 - Dt / (1 - torch.where(sigma == 1, sigma_max, sigma))) + model_output * (dt - Dt)
1513
+
1514
+ std_dev_t = torch.sqrt(2 * Dt * (sigma / (1 - torch.where(sigma == 1, sigma_max, sigma))))
1515
+
1516
+ if prev_sample is not None and generator is not None:
1517
+ raise ValueError(
1518
+ "Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
1519
+ " `prev_sample` stays `None`."
1520
+ )
1521
+
1522
+ if prev_sample is None:
1523
+ # Generate noise if not provided
1524
+ variance_noise = randn_tensor(
1525
+ model_output.shape,
1526
+ generator=generator,
1527
+ device=model_output.device,
1528
+ dtype=model_output.dtype,
1529
+ )
1530
+
1531
+ prev_sample = prev_sample_mean + std_dev_t * variance_noise
1532
+
1533
+
1534
+ log_prob = (
1535
+ -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2))
1536
+ - torch.log(std_dev_t)
1537
+ - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
1538
+ )
1539
+
1540
+ # mean along all but batch dimension
1541
+ log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
1542
+
1543
+ return prev_sample, log_prob, prev_sample_mean, std_dev_t
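
A minimal usage sketch (not part of the committed file): assuming `pipe` is an instance of the pipeline class defined above, `sde_sampling` can be called in place of the standard `__call__` to additionally obtain the per-step trajectory and log-probabilities.

```py
# Hedged sketch: run the stochastic (SDE) sampler instead of the deterministic __call__.
image, prev_latents, log_probs, pred_latents, ts, states = pipe.sde_sampling(
    prompt="a photo of a cat",
    num_inference_steps=10,
    guidance_scale=3.5,
)
# log_probs has shape (batch * num_steps,): one log-probability per denoising transition.
```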
unimodel/qwenflux/qwenflux_inference.py ADDED
@@ -0,0 +1,418 @@
1
+ # Copyright 2025 Fu-Yun Wang
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union, Dict
17
+ import torch
18
+ import torch.nn as nn
19
+ from PIL import Image
20
+ import torch.nn.functional as F
21
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoProcessor
22
+ from transformers import Qwen2_5_VLConfig, Qwen2_5_VLModel, Qwen2_5_VLForConditionalGeneration
23
+
24
+
25
+ from diffusers.utils.torch_utils import randn_tensor
26
+ from diffusers.pipelines.pipeline_utils import numpy_to_pil
27
+ import numpy as np
28
+ from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler, FlowMatchEulerDiscreteSchedulerOutput
29
+ from diffusers.schedulers import DPMSolverMultistepScheduler
30
+ import math
31
+ from diffusers.utils.torch_utils import randn_tensor
32
+ from diffusers import FluxTransformer2DModel, AutoencoderKL, FlowMatchEulerDiscreteScheduler
33
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast, CLIPTextConfig, T5Config
34
+ from .fluxpipeline import FluxPipeline
35
+ import re
36
+ import datetime
37
+ import os
38
+
39
+
40
+ def save_grid_image(prompt, images, rows, cols):
41
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")  # `import datetime` imports the module, so the class is datetime.datetime
42
+ base_dir = os.path.join("samples", timestamp, prompt[:100])
43
+ os.makedirs(base_dir, exist_ok=True)
44
+
45
+ filename = os.path.join(base_dir, "grid.jpg")
46
+ grid_image = create_image_grid(images, rows, cols)
47
+ grid_image.save(filename)
48
+
49
+ print(f"Saved: {filename}")
50
+
51
+ def create_image_grid(images, rows, cols):
52
+ """Creates a grid of images and returns a single PIL Image."""
53
+
54
+ assert len(images) == rows * cols
55
+
56
+ width, height = images[0].size
57
+ grid_width = width * cols
58
+ grid_height = height * rows
59
+
60
+ grid_image = Image.new('RGB', (grid_width, grid_height))
61
+
62
+ for i, image in enumerate(images):
63
+ x = (i % cols) * width
64
+ y = (i // cols) * height
65
+ grid_image.paste(image, (x, y))
66
+
67
+ return grid_image
68
+
69
+
70
+ def sde_step_with_logprob(
71
+ self: FlowMatchEulerDiscreteScheduler,
72
+ model_output: torch.FloatTensor,
73
+ timestep: Union[float, torch.FloatTensor],
74
+ sample: torch.FloatTensor,
75
+ prev_sample: Optional[torch.FloatTensor] = None,
76
+ generator: Optional[torch.Generator] = None,
77
+ ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
78
+ """
79
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the flow
80
+ process from the learned model outputs (most often the predicted velocity).
81
+
82
+ Args:
83
+ model_output (`torch.FloatTensor`):
84
+ The direct output from learned flow model.
85
+ timestep (`float`):
86
+ The current discrete timestep in the diffusion chain.
87
+ sample (`torch.FloatTensor`):
88
+ A current instance of a sample created by the diffusion process.
89
+ generator (`torch.Generator`, *optional*):
90
+ A random number generator.
91
+ """
92
+ step_index = [self.index_for_timestep(t) for t in timestep]
93
+ prev_step_index = [step+1 for step in step_index]
94
+ sigma = self.sigmas[step_index].view(-1, 1, 1).to(model_output.device)
95
+ sigma_prev = self.sigmas[prev_step_index].view(-1, 1, 1).to(model_output.device)
96
+ sigma_max = self.sigmas[1].item()
97
+ dt = sigma_prev - sigma
98
+
99
+ std_dev_t = torch.sqrt(sigma / (1 - torch.where(sigma == 1, sigma_max, sigma)))*1.0
100
+
101
+
102
+ # our sde
103
+ prev_sample_mean = sample*(1+std_dev_t**2/(2*sigma)*dt)+model_output*(1+std_dev_t**2*(1-sigma)/(2*sigma))*dt
104
+
105
+ if prev_sample is not None and generator is not None:
106
+ raise ValueError(
107
+ "Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
108
+ " `prev_sample` stays `None`."
109
+ )
110
+
111
+ if prev_sample is None:
112
+ variance_noise = randn_tensor(
113
+ model_output.shape,
114
+ generator=generator,
115
+ device=model_output.device,
116
+ dtype=model_output.dtype,
117
+ )
118
+ prev_sample = prev_sample_mean + std_dev_t * torch.sqrt(-1*dt) * variance_noise
119
+
120
+
121
+ log_prob = (
122
+ -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * ((std_dev_t * torch.sqrt(-1*dt))**2))
123
+ - torch.log(std_dev_t * torch.sqrt(-1*dt))
124
+ - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
125
+ )
126
+
127
+ # mean along all but batch dimension
128
+ log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
129
+
130
+ return prev_sample, log_prob, prev_sample_mean, std_dev_t * torch.sqrt(-1*dt)
131
+
132
+
133
+ # Copyright 2025 Fu-Yun Wang
134
+ #
135
+ # Licensed under the Apache License, Version 2.0 (the "License");
136
+ # you may not use this file except in compliance with the License.
137
+ # You may obtain a copy of the License at
138
+ #
139
+ # http://www.apache.org/licenses/LICENSE-2.0
140
+ #
141
+ # Unless required by applicable law or agreed to in writing, software
142
+ # distributed under the License is distributed on an "AS IS" BASIS,
143
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
144
+ # See the License for the specific language governing permissions and
145
+ # limitations under the License.
146
+
147
+ def sde_step_with_logprob_simple(
148
+ self: FlowMatchEulerDiscreteScheduler,
149
+ model_output: torch.FloatTensor,
150
+ timestep: Union[float, torch.FloatTensor],
151
+ sample: torch.FloatTensor,
152
+ prev_sample: Optional[torch.FloatTensor] = None,
153
+ generator: Optional[torch.Generator] = None,
154
+ ):
155
+ """
156
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the flow
157
+ process from the learned model outputs (most often the predicted velocity).
158
+
159
+ Args:
160
+ model_output (`torch.FloatTensor`):
161
+ The direct output from learned flow model.
162
+ timestep (`float`):
163
+ The current discrete timestep in the diffusion chain.
164
+ sample (`torch.FloatTensor`):
165
+ A current instance of a sample created by the diffusion process.
166
+ generator (`torch.Generator`, *optional*):
167
+ A random number generator.
168
+ """
169
+
170
+ step_index = [self.index_for_timestep(t) for t in timestep]
171
+ prev_step_index = [step+1 for step in step_index]
172
+ sigma = self.sigmas[step_index].view(-1, 1, 1, 1).to(model_output.device)
173
+ sigma_prev = self.sigmas[prev_step_index].view(-1, 1, 1, 1).to(model_output.device)
174
+ sigma_max = self.sigmas[1].item()
175
+ dt = sigma_prev - sigma
176
+
177
+
178
+ eta = 0.5
179
+ Dt = - dt * eta
180
+
181
+ prev_sample_mean = sample * (1 - Dt / (1 - torch.where(sigma == 1, sigma_max, sigma))) + model_output * (dt - Dt)
182
+
183
+ std_dev_t = torch.sqrt(2 * Dt * (sigma / (1 - torch.where(sigma == 1, sigma_max, sigma))))
184
+
185
+ if prev_sample is not None and generator is not None:
186
+ raise ValueError(
187
+ "Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
188
+ " `prev_sample` stays `None`."
189
+ )
190
+
191
+ if prev_sample is None:
192
+ # Generate noise if not provided
193
+ variance_noise = randn_tensor(
194
+ model_output.shape,
195
+ generator=generator,
196
+ device=model_output.device,
197
+ dtype=model_output.dtype,
198
+ )
199
+
200
+ prev_sample = prev_sample_mean + std_dev_t * variance_noise
201
+
202
+
203
+ log_prob = (
204
+ -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2))
205
+ - torch.log(std_dev_t)
206
+ - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
207
+ )
208
+
209
+ # mean along all but batch dimension
210
+ log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
211
+
212
+ return prev_sample, log_prob, prev_sample_mean, std_dev_t
213
+
214
+ class QwenFluxMetaModel:
215
+
216
+ def __init__(self, config):
217
+ super(QwenFluxMetaModel, self).__init__(config)
218
+
219
+ if hasattr(config, "diffusion_expert"):
220
+ ckpt_id = "black-forest-labs/FLUX.1-dev"
221
+ # Load configuration for each component
222
+ transformer_config = FluxTransformer2DModel.load_config(ckpt_id, subfolder="transformer")
223
+ vae_config = AutoencoderKL.load_config(ckpt_id, subfolder="vae")
224
+ text_encoder_config = CLIPTextConfig.from_pretrained(ckpt_id, subfolder="text_encoder")
225
+ text_encoder_2_config = T5Config.from_pretrained(ckpt_id, subfolder="text_encoder_2")
226
+
227
+ # Initialize components from their configurations
228
+ self.transformer = FluxTransformer2DModel.from_config(transformer_config)
229
+ self.vae = AutoencoderKL.from_config(vae_config)
230
+ self.text_encoder = CLIPTextModel(text_encoder_config)
231
+ self.text_encoder_2 = T5EncoderModel(text_encoder_2_config)
232
+
233
+ # Initialize tokenizers (these don't use from_config as they are not models)
234
+ self.tokenizer = CLIPTokenizer.from_pretrained(ckpt_id, subfolder="tokenizer")
235
+ self.tokenizer_2 = T5TokenizerFast.from_pretrained(ckpt_id, subfolder="tokenizer_2")
236
+
237
+ self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(ckpt_id, subfolder="scheduler")
238
+
239
+ # Create the pipeline configuration dictionary
240
+ pipeline_config = {
241
+ "transformer": self.transformer,
242
+ "scheduler": self.scheduler,
243
+ "vae": self.vae,
244
+ "text_encoder": self.text_encoder,
245
+ "text_encoder_2": self.text_encoder_2,
246
+ "tokenizer": self.tokenizer,
247
+ "tokenizer_2": self.tokenizer_2,
248
+ }
249
+
250
+ self.diffusion_expert = FluxPipeline(**pipeline_config)
251
+
252
+
253
+ def initialize_diffusion_expert(self, fsdp=None):
254
+
255
+ if getattr(self, 'diffusion_expert', None) is None:
256
+ print("Randomly initializing the diffusion expert!!!")
257
+ self.diffusion_expert = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", revision="main", torch_dtype=torch.bfloat16).to(torch.bfloat16)
258
+ self.text_encoder = self.diffusion_expert.text_encoder
259
+ self.text_encoder_2 = self.diffusion_expert.text_encoder_2
260
+ self.tokenizer = self.diffusion_expert.tokenizer
261
+ self.tokenizer_2 = self.diffusion_expert.tokenizer_2
262
+ self.vae = self.diffusion_expert.vae
263
+ self.transformer = self.diffusion_expert.transformer
264
+ self.scheduler = self.diffusion_expert.scheduler
265
+
266
+ self.config.diffusion_expert = "flux"
267
+
268
+
269
+
270
+ class QwenFluxConfig(Qwen2_5_VLConfig):
271
+ model_type = "QwenFlux"
272
+
273
+
274
+ class QwenFluxModel(QwenFluxMetaModel, Qwen2_5_VLModel):
275
+ config_class = QwenFluxConfig
276
+
277
+ def __init__(self, config: Qwen2_5_VLConfig):
278
+ super(QwenFluxModel, self).__init__(config)
279
+
280
+
281
+ class QwenFluxForInferenceLM(Qwen2_5_VLForConditionalGeneration):
282
+ config_class = QwenFluxConfig
283
+
284
+ def __init__(self, config):
285
+ Qwen2_5_VLForConditionalGeneration.__init__(self, config)
286
+ config.model_type = "QwenFlux"
287
+
288
+ self.model = QwenFluxModel(config)
289
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
290
+ self.post_init()
291
+
292
+ def get_model(self):
293
+ return self.model
294
+
295
+ @torch.no_grad()
296
+ def generate_image(
297
+ self,
298
+ texts: List[str],
299
+ diffusion_kwargs: Optional[Dict] = dict(guidance_scale = 3.5, num_inference_steps=25),
300
+ sde_sampling: Optional[bool] = False,
301
+ ):
302
+
303
+ if isinstance(texts, str):
304
+ texts = [texts]
305
+
306
+ if not sde_sampling:
307
+ output_img = self.model.diffusion_expert(
308
+ texts,
309
+ max_sequence_length=512,
310
+ **diffusion_kwargs,
311
+ ).images
312
+ return output_img
313
+ else:
314
+ return self.model.diffusion_expert.sde_sampling(
315
+ texts,
316
+ max_sequence_length=512,
317
+ **diffusion_kwargs,
318
+ )
319
+
320
+
321
+ def extract_thinking_content(self, text: str) -> str:
322
+ pattern = r'<answer>(.*?)</answer>'
323
+ matches = re.findall(pattern, text, re.DOTALL)
324
+
325
+ if matches:
326
+ return matches[-1].strip().replace("<answer>", "").replace("</answer>", "")
327
+ else:
328
+ return text.strip().replace("<answer>", "").replace("</answer>", "")
329
+
330
+ @torch.no_grad()
331
+ def generate_image_cot(
332
+ self,
333
+ texts: List[str],
334
+ processor: Optional[object] = None,
335
+ diffusion_kwargs: Optional[Dict] = dict(guidance_scale = 3.5, num_inference_steps=25),
336
+ llm_kwargs: Optional[Dict] = dict(max_new_tokens=256, temperature=0.7, top_p=0.9, do_sample=True),
337
+ cot_prompt_template: Optional[str] = None,
338
+ ):
339
+
340
+ if isinstance(texts, str):
341
+ texts = [texts]
342
+
343
+ if cot_prompt_template is None:
344
+ # cot_prompt_template = """Please improve the following image generation prompt to make it more detailed and specific for better image quality. Think step by step about what visual elements would make this image more compelling. Original prompt: {original_prompt}. Please provide the improved prompt in <thinking> </thinking> tags."""
345
+ cot_prompt_template = """Please provide an enhanced prompt for the following image generation prompt to make the image more realistic, detailed, with clear separation and precise alignment of all entities.
346
+ Original prompt: {original_prompt}. Directly provide the improved prompt in <answer> </answer> tags."""
347
+
348
+ improved_prompts = []
349
+
350
+ for text in texts:
351
+ cot_input = cot_prompt_template.format(original_prompt=text)
352
+
353
+ messages = [{"role": "user", "content": cot_input}]
354
+ input_text_formatted = processor.apply_chat_template(
355
+ messages, tokenize=False, add_generation_prompt=True
356
+ )
357
+ model_inputs = processor(
358
+ text=[input_text_formatted],
359
+ return_tensors="pt"
360
+ ).to(self.device)
361
+
362
+ generated_ids = self.generate(
363
+ **model_inputs,
364
+ **llm_kwargs,
365
+ eos_token_id=processor.tokenizer.eos_token_id,
366
+ pad_token_id=processor.tokenizer.pad_token_id
367
+ )
368
+
369
+ generated_text = processor.batch_decode(
370
+ generated_ids[:, model_inputs['input_ids'].shape[1]:],
371
+ skip_special_tokens=True
372
+ )
373
+
374
+ improved_prompt = [self.extract_thinking_content(decode_text) for decode_text in generated_text]
375
+ improved_prompts.extend(improved_prompt)
376
+
377
+ print(f"Original prompt: {text}")
378
+ print(f"Improved prompt: {improved_prompt}")
379
+ print("-" * 50)
380
+
381
+ output_images = self.generate_image(improved_prompts, diffusion_kwargs)
382
+
383
+ return {
384
+ 'images': output_images,
385
+ 'original_prompts': texts,
386
+ 'improved_prompts': improved_prompts
387
+ }
388
+
389
+ AutoConfig.register("QwenFlux", QwenFluxConfig)
390
+ AutoModelForCausalLM.register(QwenFluxConfig, QwenFluxForInferenceLM)
391
+
392
+
393
+ if __name__ == "__main__":
394
+
395
+ model = QwenFluxForInferenceLM.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", torch_dtype=torch.bfloat16)
396
+ model.model.initialize_diffusion_expert()
397
+ model.model.diffusion_expert.to("cuda:0")
398
+ model.to("cuda:0")
399
+ AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
400
+ text = ["a photo of a cat"]
401
+ images = model.generate_image(text)
402
+ images[0].save("test_flux.png")
403
+
404
+ model.save_pretrained("outputs/pretrain/qwenflux")
405
+
406
+
407
+ model = QwenFluxForInferenceLM.from_pretrained("outputs/pretrain/qwenflux", torch_dtype=torch.bfloat16)
408
+ model.to("cuda:0")
409
+ processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
410
+ text = ["a photo of a cat"]
411
+ images = model.generate_image(text)
412
+ images[0].save("test_flux.jpg")
413
+
414
+ outputs = model.generate_image_cot(text, processor = processor)
415
+ outputs['images'][0].save("test_flux_cot.jpg")
416
+
417
+
418
+
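
An illustrative sketch (not part of the commit) of consuming the SDE trajectory through the wrapper; the checkpoint path mirrors the `__main__` block above and the module path follows this file's location:

```py
# Hedged sketch: collect per-step latents and log-probabilities via generate_image(..., sde_sampling=True).
import torch
from unimodel.qwenflux.qwenflux_inference import QwenFluxForInferenceLM

model = QwenFluxForInferenceLM.from_pretrained("outputs/pretrain/qwenflux", torch_dtype=torch.bfloat16).to("cuda:0")
image, prev_latents, log_probs, pred_latents, ts, states = model.generate_image(
    ["a photo of a cat"],
    diffusion_kwargs=dict(guidance_scale=3.5, num_inference_steps=10),
    sde_sampling=True,
)
# prev_latents, pred_latents, log_probs and ts are flattened to (batch * num_steps, ...),
# so each denoising transition can later be paired with a reward signal.
```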
unimodel/qwenkontext/fluxkontext_pipeline.py ADDED
@@ -0,0 +1,1161 @@
1
+ # Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
2
+ # Copyright 2025 Fu-Yun Wang
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ from typing import Any, Callable, Dict, List, Optional, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+ from transformers import (
22
+ CLIPImageProcessor,
23
+ CLIPTextModel,
24
+ CLIPTokenizer,
25
+ CLIPVisionModelWithProjection,
26
+ T5EncoderModel,
27
+ T5TokenizerFast,
28
+ )
29
+
30
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
31
+ from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
32
+ from diffusers.models import AutoencoderKL, FluxTransformer2DModel
33
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
34
+ from diffusers.utils import (
35
+ USE_PEFT_BACKEND,
36
+ deprecate,
37
+ is_torch_xla_available,
38
+ logging,
39
+ replace_example_docstring,
40
+ scale_lora_layers,
41
+ unscale_lora_layers,
42
+ )
43
+ from diffusers.utils.torch_utils import randn_tensor
44
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
45
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
46
+
47
+
48
+ if is_torch_xla_available():
49
+ import torch_xla.core.xla_model as xm
50
+
51
+ XLA_AVAILABLE = True
52
+ else:
53
+ XLA_AVAILABLE = False
54
+
55
+
56
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
57
+
58
+ EXAMPLE_DOC_STRING = """
59
+ Examples:
60
+ ```py
61
+ >>> import torch
62
+ >>> from diffusers import FluxKontextPipeline
63
+ >>> from diffusers.utils import load_image
64
+
65
+ >>> pipe = FluxKontextPipeline.from_pretrained(
66
+ ... "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
67
+ ... )
68
+ >>> pipe.to("cuda")
69
+
70
+ >>> image = load_image(
71
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
72
+ ... ).convert("RGB")
73
+ >>> prompt = "Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors"
74
+ >>> image = pipe(
75
+ ... image=image,
76
+ ... prompt=prompt,
77
+ ... guidance_scale=2.5,
78
+ ... generator=torch.Generator().manual_seed(42),
79
+ ... ).images[0]
80
+ >>> image.save("output.png")
81
+ ```
82
+ """
83
+
84
+ PREFERRED_KONTEXT_RESOLUTIONS = [
85
+ (672, 1568),
86
+ (688, 1504),
87
+ (720, 1456),
88
+ (752, 1392),
89
+ (800, 1328),
90
+ (832, 1248),
91
+ (880, 1184),
92
+ (944, 1104),
93
+ (1024, 1024),
94
+ (1104, 944),
95
+ (1184, 880),
96
+ (1248, 832),
97
+ (1328, 800),
98
+ (1392, 752),
99
+ (1456, 720),
100
+ (1504, 688),
101
+ (1568, 672),
102
+ ]
103
+
104
+
105
+ def calculate_shift(
106
+ image_seq_len,
107
+ base_seq_len: int = 256,
108
+ max_seq_len: int = 4096,
109
+ base_shift: float = 0.5,
110
+ max_shift: float = 1.15,
111
+ ):
112
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
113
+ b = base_shift - m * base_seq_len
114
+ mu = image_seq_len * m + b
115
+ return mu
116
+
117
+
118
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
119
+ def retrieve_timesteps(
120
+ scheduler,
121
+ num_inference_steps: Optional[int] = None,
122
+ device: Optional[Union[str, torch.device]] = None,
123
+ timesteps: Optional[List[int]] = None,
124
+ sigmas: Optional[List[float]] = None,
125
+ **kwargs,
126
+ ):
127
+ r"""
128
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
129
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
130
+
131
+ Args:
132
+ scheduler (`SchedulerMixin`):
133
+ The scheduler to get timesteps from.
134
+ num_inference_steps (`int`):
135
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
136
+ must be `None`.
137
+ device (`str` or `torch.device`, *optional*):
138
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
139
+ timesteps (`List[int]`, *optional*):
140
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
141
+ `num_inference_steps` and `sigmas` must be `None`.
142
+ sigmas (`List[float]`, *optional*):
143
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
144
+ `num_inference_steps` and `timesteps` must be `None`.
145
+
146
+ Returns:
147
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
148
+ second element is the number of inference steps.
149
+ """
150
+ if timesteps is not None and sigmas is not None:
151
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
152
+ if timesteps is not None:
153
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
154
+ if not accepts_timesteps:
155
+ raise ValueError(
156
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
157
+ f" timestep schedules. Please check whether you are using the correct scheduler."
158
+ )
159
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
160
+ timesteps = scheduler.timesteps
161
+ num_inference_steps = len(timesteps)
162
+ elif sigmas is not None:
163
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
164
+ if not accept_sigmas:
165
+ raise ValueError(
166
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
167
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
168
+ )
169
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
170
+ timesteps = scheduler.timesteps
171
+ num_inference_steps = len(timesteps)
172
+ else:
173
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
174
+ timesteps = scheduler.timesteps
175
+ return timesteps, num_inference_steps
176
+
177
+
178
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
179
+ def retrieve_latents(
180
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
181
+ ):
182
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
183
+ return encoder_output.latent_dist.sample(generator)
184
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
185
+ return encoder_output.latent_dist.mode()
186
+ elif hasattr(encoder_output, "latents"):
187
+ return encoder_output.latents
188
+ else:
189
+ raise AttributeError("Could not access latents of provided encoder_output")
190
+
191
+
192
+ class FluxKontextPipeline(
193
+ DiffusionPipeline,
194
+ FluxLoraLoaderMixin,
195
+ FromSingleFileMixin,
196
+ TextualInversionLoaderMixin,
197
+ FluxIPAdapterMixin,
198
+ ):
199
+ r"""
200
+ The Flux Kontext pipeline for image-to-image and text-to-image generation.
201
+
202
+ Reference: https://bfl.ai/announcements/flux-1-kontext-dev
203
+
204
+ Args:
205
+ transformer ([`FluxTransformer2DModel`]):
206
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
207
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
208
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
209
+ vae ([`AutoencoderKL`]):
210
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
211
+ text_encoder ([`CLIPTextModel`]):
212
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
213
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
214
+ text_encoder_2 ([`T5EncoderModel`]):
215
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
216
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
217
+ tokenizer (`CLIPTokenizer`):
218
+ Tokenizer of class
219
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
220
+ tokenizer_2 (`T5TokenizerFast`):
221
+ Second Tokenizer of class
222
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
223
+ """
224
+
225
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
226
+ _optional_components = ["image_encoder", "feature_extractor"]
227
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
228
+
229
+ def __init__(
230
+ self,
231
+ scheduler: FlowMatchEulerDiscreteScheduler,
232
+ vae: AutoencoderKL,
233
+ text_encoder: CLIPTextModel,
234
+ tokenizer: CLIPTokenizer,
235
+ text_encoder_2: T5EncoderModel,
236
+ tokenizer_2: T5TokenizerFast,
237
+ transformer: FluxTransformer2DModel,
238
+ image_encoder: CLIPVisionModelWithProjection = None,
239
+ feature_extractor: CLIPImageProcessor = None,
240
+ ):
241
+ super().__init__()
242
+
243
+ self.register_modules(
244
+ vae=vae,
245
+ text_encoder=text_encoder,
246
+ text_encoder_2=text_encoder_2,
247
+ tokenizer=tokenizer,
248
+ tokenizer_2=tokenizer_2,
249
+ transformer=transformer,
250
+ scheduler=scheduler,
251
+ image_encoder=image_encoder,
252
+ feature_extractor=feature_extractor,
253
+ )
254
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
255
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
256
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
257
+ self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
258
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
259
+ self.tokenizer_max_length = (
260
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
261
+ )
262
+ self.default_sample_size = 128
263
+
264
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
265
+ def _get_t5_prompt_embeds(
266
+ self,
267
+ prompt: Union[str, List[str]] = None,
268
+ num_images_per_prompt: int = 1,
269
+ max_sequence_length: int = 512,
270
+ device: Optional[torch.device] = None,
271
+ dtype: Optional[torch.dtype] = None,
272
+ ):
273
+ device = device or self._execution_device
274
+ dtype = dtype or self.text_encoder.dtype
275
+
276
+ prompt = [prompt] if isinstance(prompt, str) else prompt
277
+ batch_size = len(prompt)
278
+
279
+ if isinstance(self, TextualInversionLoaderMixin):
280
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
281
+
282
+ text_inputs = self.tokenizer_2(
283
+ prompt,
284
+ padding="max_length",
285
+ max_length=max_sequence_length,
286
+ truncation=True,
287
+ return_length=False,
288
+ return_overflowing_tokens=False,
289
+ return_tensors="pt",
290
+ )
291
+ text_input_ids = text_inputs.input_ids
292
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
293
+
294
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
295
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
296
+ logger.warning(
297
+ "The following part of your input was truncated because `max_sequence_length` is set to "
298
+ f" {max_sequence_length} tokens: {removed_text}"
299
+ )
300
+
301
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
302
+
303
+ dtype = self.text_encoder_2.dtype
304
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
305
+
306
+ _, seq_len, _ = prompt_embeds.shape
307
+
308
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
309
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
310
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
311
+
312
+ return prompt_embeds
313
+
314
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
315
+ def _get_clip_prompt_embeds(
316
+ self,
317
+ prompt: Union[str, List[str]],
318
+ num_images_per_prompt: int = 1,
319
+ device: Optional[torch.device] = None,
320
+ ):
321
+ device = device or self._execution_device
322
+
323
+ prompt = [prompt] if isinstance(prompt, str) else prompt
324
+ batch_size = len(prompt)
325
+
326
+ if isinstance(self, TextualInversionLoaderMixin):
327
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
328
+
329
+ text_inputs = self.tokenizer(
330
+ prompt,
331
+ padding="max_length",
332
+ max_length=self.tokenizer_max_length,
333
+ truncation=True,
334
+ return_overflowing_tokens=False,
335
+ return_length=False,
336
+ return_tensors="pt",
337
+ )
338
+
339
+ text_input_ids = text_inputs.input_ids
340
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
341
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
342
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
343
+ logger.warning(
344
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
345
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
346
+ )
347
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
348
+
349
+ # Use pooled output of CLIPTextModel
350
+ prompt_embeds = prompt_embeds.pooler_output
351
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
352
+
353
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
354
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
355
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
356
+
357
+ return prompt_embeds
358
+
359
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
360
+ def encode_prompt(
361
+ self,
362
+ prompt: Union[str, List[str]],
363
+ prompt_2: Optional[Union[str, List[str]]] = None,
364
+ device: Optional[torch.device] = None,
365
+ num_images_per_prompt: int = 1,
366
+ prompt_embeds: Optional[torch.FloatTensor] = None,
367
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
368
+ max_sequence_length: int = 512,
369
+ lora_scale: Optional[float] = None,
370
+ ):
371
+ r"""
372
+
373
+ Args:
374
+ prompt (`str` or `List[str]`, *optional*):
375
+ prompt to be encoded
376
+ prompt_2 (`str` or `List[str]`, *optional*):
377
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
378
+ used in all text-encoders
379
+ device: (`torch.device`):
380
+ torch device
381
+ num_images_per_prompt (`int`):
382
+ number of images that should be generated per prompt
383
+ prompt_embeds (`torch.FloatTensor`, *optional*):
384
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
385
+ provided, text embeddings will be generated from `prompt` input argument.
386
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
387
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
388
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
389
+ lora_scale (`float`, *optional*):
390
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
391
+ """
392
+ device = device or self._execution_device
393
+
394
+ # set lora scale so that monkey patched LoRA
395
+ # function of text encoder can correctly access it
396
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
397
+ self._lora_scale = lora_scale
398
+
399
+ # dynamically adjust the LoRA scale
400
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
401
+ scale_lora_layers(self.text_encoder, lora_scale)
402
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
403
+ scale_lora_layers(self.text_encoder_2, lora_scale)
404
+
405
+ prompt = [prompt] if isinstance(prompt, str) else prompt
406
+
407
+ if prompt_embeds is None:
408
+ prompt_2 = prompt_2 or prompt
409
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
410
+
411
+ # We only use the pooled prompt output from the CLIPTextModel
412
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
413
+ prompt=prompt,
414
+ device=device,
415
+ num_images_per_prompt=num_images_per_prompt,
416
+ )
417
+ prompt_embeds = self._get_t5_prompt_embeds(
418
+ prompt=prompt_2,
419
+ num_images_per_prompt=num_images_per_prompt,
420
+ max_sequence_length=max_sequence_length,
421
+ device=device,
422
+ )
423
+
424
+ if self.text_encoder is not None:
425
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
426
+ # Retrieve the original scale by scaling back the LoRA layers
427
+ unscale_lora_layers(self.text_encoder, lora_scale)
428
+
429
+ if self.text_encoder_2 is not None:
430
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
431
+ # Retrieve the original scale by scaling back the LoRA layers
432
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
433
+
434
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
435
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
436
+
437
+ return prompt_embeds, pooled_prompt_embeds, text_ids
438
+
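+ # --- Editor's sketch (illustrative, not part of the original file): embeddings returned by
+ # `encode_prompt` can be precomputed once and fed back into `__call__` via `prompt_embeds` /
+ # `pooled_prompt_embeds`. `pipe` and `init_image` are assumed to exist.
+ #
+ #     prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
+ #         prompt="replace the sky with a starry night", prompt_2=None, max_sequence_length=512
+ #     )
+ #     out = pipe(image=init_image, prompt_embeds=prompt_embeds,
+ #                pooled_prompt_embeds=pooled_prompt_embeds)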
439
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
440
+ def encode_image(self, image, device, num_images_per_prompt):
441
+ dtype = next(self.image_encoder.parameters()).dtype
442
+
443
+ if not isinstance(image, torch.Tensor):
444
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
445
+
446
+ image = image.to(device=device, dtype=dtype)
447
+ image_embeds = self.image_encoder(image).image_embeds
448
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
449
+ return image_embeds
450
+
451
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
452
+ def prepare_ip_adapter_image_embeds(
453
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
454
+ ):
455
+ image_embeds = []
456
+ if ip_adapter_image_embeds is None:
457
+ if not isinstance(ip_adapter_image, list):
458
+ ip_adapter_image = [ip_adapter_image]
459
+
460
+ if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
461
+ raise ValueError(
462
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
463
+ )
464
+
465
+ for single_ip_adapter_image in ip_adapter_image:
466
+ single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
467
+ image_embeds.append(single_image_embeds[None, :])
468
+ else:
469
+ if not isinstance(ip_adapter_image_embeds, list):
470
+ ip_adapter_image_embeds = [ip_adapter_image_embeds]
471
+
472
+ if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
473
+ raise ValueError(
474
+ f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
475
+ )
476
+
477
+ for single_image_embeds in ip_adapter_image_embeds:
478
+ image_embeds.append(single_image_embeds)
479
+
480
+ ip_adapter_image_embeds = []
481
+ for single_image_embeds in image_embeds:
482
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
483
+ single_image_embeds = single_image_embeds.to(device=device)
484
+ ip_adapter_image_embeds.append(single_image_embeds)
485
+
486
+ return ip_adapter_image_embeds
487
+
488
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs
489
+ def check_inputs(
490
+ self,
491
+ prompt,
492
+ prompt_2,
493
+ height,
494
+ width,
495
+ negative_prompt=None,
496
+ negative_prompt_2=None,
497
+ prompt_embeds=None,
498
+ negative_prompt_embeds=None,
499
+ pooled_prompt_embeds=None,
500
+ negative_pooled_prompt_embeds=None,
501
+ callback_on_step_end_tensor_inputs=None,
502
+ max_sequence_length=None,
503
+ ):
504
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
505
+ logger.warning(
506
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
507
+ )
508
+
509
+ if callback_on_step_end_tensor_inputs is not None and not all(
510
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
511
+ ):
512
+ raise ValueError(
513
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
514
+ )
515
+
516
+ if prompt is not None and prompt_embeds is not None:
517
+ raise ValueError(
518
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
519
+ " only forward one of the two."
520
+ )
521
+ elif prompt_2 is not None and prompt_embeds is not None:
522
+ raise ValueError(
523
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
524
+ " only forward one of the two."
525
+ )
526
+ elif prompt is None and prompt_embeds is None:
527
+ raise ValueError(
528
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
529
+ )
530
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
531
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
532
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
533
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
534
+
535
+ if negative_prompt is not None and negative_prompt_embeds is not None:
536
+ raise ValueError(
537
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
538
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
539
+ )
540
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
541
+ raise ValueError(
542
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
543
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
544
+ )
545
+
546
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
547
+ raise ValueError(
548
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
549
+ )
550
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
551
+ raise ValueError(
552
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
553
+ )
554
+
555
+ if max_sequence_length is not None and max_sequence_length > 512:
556
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
557
+
558
+ @staticmethod
559
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
560
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
561
+ latent_image_ids = torch.zeros(height, width, 3)
562
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
563
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
564
+
565
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
566
+
567
+ latent_image_ids = latent_image_ids.reshape(
568
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
569
+ )
570
+
571
+ return latent_image_ids.to(device=device, dtype=dtype)
572
+
573
+ @staticmethod
574
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
575
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
576
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
577
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
578
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
579
+
580
+ return latents
581
+
582
+ @staticmethod
583
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
584
+ def _unpack_latents(latents, height, width, vae_scale_factor):
585
+ batch_size, num_patches, channels = latents.shape
586
+
587
+ # VAE applies 8x compression on images but we must also account for packing which requires
588
+ # latent height and width to be divisible by 2.
589
+ height = 2 * (int(height) // (vae_scale_factor * 2))
590
+ width = 2 * (int(width) // (vae_scale_factor * 2))
591
+
592
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
593
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
594
+
595
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
596
+
597
+ return latents
598
+
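+ # --- Editor's sketch (illustrative, not part of the original file): `_pack_latents` and
+ # `_unpack_latents` are inverse reshapes between the (B, C, H, W) latent grid and the
+ # (B, H/2 * W/2, 4C) token sequence the transformer consumes. For a 1024x1024 image with
+ # vae_scale_factor=8 the latent grid is 128x128, i.e. 4096 tokens of width 64:
+ #
+ #     lat = torch.randn(1, 16, 128, 128)
+ #     packed = FluxKontextPipeline._pack_latents(lat, 1, 16, 128, 128)              # (1, 4096, 64)
+ #     restored = FluxKontextPipeline._unpack_latents(packed, 1024, 1024, vae_scale_factor=8)
+ #     assert torch.equal(lat, restored)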
599
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
600
+ if isinstance(generator, list):
601
+ image_latents = [
602
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
603
+ for i in range(image.shape[0])
604
+ ]
605
+ image_latents = torch.cat(image_latents, dim=0)
606
+ else:
607
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
608
+
609
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
610
+
611
+ return image_latents
612
+
613
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing
614
+ def enable_vae_slicing(self):
615
+ r"""
616
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
617
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
618
+ """
619
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
620
+ deprecate(
621
+ "enable_vae_slicing",
622
+ "0.40.0",
623
+ depr_message,
624
+ )
625
+ self.vae.enable_slicing()
626
+
627
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing
628
+ def disable_vae_slicing(self):
629
+ r"""
630
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
631
+ computing decoding in one step.
632
+ """
633
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
634
+ deprecate(
635
+ "disable_vae_slicing",
636
+ "0.40.0",
637
+ depr_message,
638
+ )
639
+ self.vae.disable_slicing()
640
+
641
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling
642
+ def enable_vae_tiling(self):
643
+ r"""
644
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
645
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
646
+ processing larger images.
647
+ """
648
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
649
+ deprecate(
650
+ "enable_vae_tiling",
651
+ "0.40.0",
652
+ depr_message,
653
+ )
654
+ self.vae.enable_tiling()
655
+
656
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling
657
+ def disable_vae_tiling(self):
658
+ r"""
659
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
660
+ computing decoding in one step.
661
+ """
662
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
663
+ deprecate(
664
+ "disable_vae_tiling",
665
+ "0.40.0",
666
+ depr_message,
667
+ )
668
+ self.vae.disable_tiling()
669
+
670
+ def prepare_latents(
671
+ self,
672
+ image: Optional[torch.Tensor],
673
+ batch_size: int,
674
+ num_channels_latents: int,
675
+ height: int,
676
+ width: int,
677
+ dtype: torch.dtype,
678
+ device: torch.device,
679
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
680
+ latents: Optional[torch.Tensor] = None,
681
+ ):
682
+ if isinstance(generator, list) and len(generator) != batch_size:
683
+ raise ValueError(
684
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
685
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
686
+ )
687
+
688
+ # VAE applies 8x compression on images but we must also account for packing which requires
689
+ # latent height and width to be divisible by 2.
690
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
691
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
692
+ shape = (batch_size, num_channels_latents, height, width)
693
+
694
+ image_latents = image_ids = None
695
+ if image is not None:
696
+ image = image.to(device=device, dtype=dtype)
697
+ if image.shape[1] != self.latent_channels:
698
+ image_latents = self._encode_vae_image(image=image, generator=generator)
699
+ else:
700
+ image_latents = image
701
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
702
+ # expand init_latents for batch_size
703
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
704
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
705
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
706
+ raise ValueError(
707
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
708
+ )
709
+ else:
710
+ image_latents = torch.cat([image_latents], dim=0)
711
+
712
+ image_latent_height, image_latent_width = image_latents.shape[2:]
713
+ image_latents = self._pack_latents(
714
+ image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
715
+ )
716
+ image_ids = self._prepare_latent_image_ids(
717
+ batch_size, image_latent_height // 2, image_latent_width // 2, device, dtype
718
+ )
719
+ # image ids are the same as latent ids with the first dimension set to 1 instead of 0
720
+ image_ids[..., 0] = 1
721
+
722
+ latent_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
723
+
724
+ if latents is None:
725
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
726
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
727
+ else:
728
+ latents = latents.to(device=device, dtype=dtype)
729
+
730
+ return latents, image_latents, latent_ids, image_ids
731
+
732
+ @property
733
+ def guidance_scale(self):
734
+ return self._guidance_scale
735
+
736
+ @property
737
+ def joint_attention_kwargs(self):
738
+ return self._joint_attention_kwargs
739
+
740
+ @property
741
+ def num_timesteps(self):
742
+ return self._num_timesteps
743
+
744
+ @property
745
+ def current_timestep(self):
746
+ return self._current_timestep
747
+
748
+ @property
749
+ def interrupt(self):
750
+ return self._interrupt
751
+
752
+ @torch.no_grad()
753
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
754
+ def __call__(
755
+ self,
756
+ image: Optional[PipelineImageInput] = None,
757
+ prompt: Union[str, List[str]] = None,
758
+ prompt_2: Optional[Union[str, List[str]]] = None,
759
+ negative_prompt: Union[str, List[str]] = None,
760
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
761
+ true_cfg_scale: float = 1.0,
762
+ height: Optional[int] = None,
763
+ width: Optional[int] = None,
764
+ num_inference_steps: int = 28,
765
+ sigmas: Optional[List[float]] = None,
766
+ guidance_scale: float = 3.5,
767
+ num_images_per_prompt: Optional[int] = 1,
768
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
769
+ latents: Optional[torch.FloatTensor] = None,
770
+ prompt_embeds: Optional[torch.FloatTensor] = None,
771
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
772
+ ip_adapter_image: Optional[PipelineImageInput] = None,
773
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
774
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
775
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
776
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
777
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
778
+ output_type: Optional[str] = "pil",
779
+ return_dict: bool = True,
780
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
781
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
782
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
783
+ max_sequence_length: int = 512,
784
+ max_area: int = 1024**2,
785
+ _auto_resize: bool = True,
786
+ ):
787
+ r"""
788
+ Function invoked when calling the pipeline for generation.
789
+
790
+ Args:
791
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
792
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
793
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
794
+ of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
795
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
796
+ latents as `image`, but if passing latents directly it is not encoded again.
797
+ prompt (`str` or `List[str]`, *optional*):
798
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead.
800
+ prompt_2 (`str` or `List[str]`, *optional*):
801
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` will be used instead.
803
+ negative_prompt (`str` or `List[str]`, *optional*):
804
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
805
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
806
+ not greater than `1`).
807
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
808
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
809
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
810
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
811
+ When > 1.0 and a `negative_prompt` is provided, true classifier-free guidance is enabled.
812
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
813
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
814
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
815
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
816
+ num_inference_steps (`int`, *optional*, defaults to 28):
817
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
818
+ expense of slower inference.
819
+ sigmas (`List[float]`, *optional*):
820
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
821
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
822
+ will be used.
823
+ guidance_scale (`float`, *optional*, defaults to 3.5):
824
+ Embedded guidance scale is enabled by setting `guidance_scale` > 1. A higher `guidance_scale` encourages
825
+ the model to generate images more closely aligned with the prompt, at the expense of lower image quality.
826
+
827
+ Guidance-distilled models approximate true classifier-free guidance for `guidance_scale` > 1. Refer to
828
+ the [paper](https://huggingface.co/papers/2210.03142) to learn more.
829
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
830
+ The number of images to generate per prompt.
831
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
832
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
833
+ to make generation deterministic.
834
+ latents (`torch.FloatTensor`, *optional*):
835
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
836
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
837
+ tensor will be generated by sampling using the supplied random `generator`.
838
+ prompt_embeds (`torch.FloatTensor`, *optional*):
839
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
840
+ provided, text embeddings will be generated from `prompt` input argument.
841
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
842
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
843
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
844
+ ip_adapter_image (`PipelineImageInput`, *optional*):
845
+ Optional image input to work with IP Adapters.
846
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
847
+ Pre-generated image embeddings for IP-Adapter. It should be a list whose length matches the number of
848
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
849
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
850
+ negative_ip_adapter_image (`PipelineImageInput`, *optional*):
851
+ Optional image input to work with IP Adapters.
852
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
853
+ Pre-generated image embeddings for IP-Adapter. It should be a list whose length matches the number of
854
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
855
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
856
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
857
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
858
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
859
+ argument.
860
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
861
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
862
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
863
+ input argument.
864
+ output_type (`str`, *optional*, defaults to `"pil"`):
865
+ The output format of the generated image. Choose between
866
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
867
+ return_dict (`bool`, *optional*, defaults to `True`):
868
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
869
+ joint_attention_kwargs (`dict`, *optional*):
870
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
871
+ `self.processor` in
872
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
873
+ callback_on_step_end (`Callable`, *optional*):
874
+ A function that is called at the end of each denoising step during inference. The function is called
875
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
876
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
877
+ `callback_on_step_end_tensor_inputs`.
878
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
879
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
880
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
881
+ `._callback_tensor_inputs` attribute of your pipeline class.
882
+ max_sequence_length (`int` defaults to 512):
883
+ Maximum sequence length to use with the `prompt`.
884
+ max_area (`int`, defaults to `1024 ** 2`):
885
+ The maximum area of the generated image in pixels. The height and width will be adjusted to fit this
886
+ area while maintaining the aspect ratio.
887
+
888
+ Examples:
889
+
890
+ Returns:
891
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
892
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
893
+ images.
894
+ """
895
+
896
+ height = height or self.default_sample_size * self.vae_scale_factor
897
+ width = width or self.default_sample_size * self.vae_scale_factor
898
+
899
+ original_height, original_width = height, width
900
+ aspect_ratio = width / height
901
+
902
+ width = round((max_area * aspect_ratio) ** 0.5)
903
+ height = round((max_area / aspect_ratio) ** 0.5)
904
+
905
+ multiple_of = self.vae_scale_factor * 2
906
+ width = width // multiple_of * multiple_of
907
+ height = height // multiple_of * multiple_of
908
+
909
+ if height != original_height or width != original_width:
910
+ logger.warning(
911
+ f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
912
+ )
913
+
914
+ # 1. Check inputs. Raise error if not correct
915
+ self.check_inputs(
916
+ prompt,
917
+ prompt_2,
918
+ height,
919
+ width,
920
+ negative_prompt=negative_prompt,
921
+ negative_prompt_2=negative_prompt_2,
922
+ prompt_embeds=prompt_embeds,
923
+ negative_prompt_embeds=negative_prompt_embeds,
924
+ pooled_prompt_embeds=pooled_prompt_embeds,
925
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
926
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
927
+ max_sequence_length=max_sequence_length,
928
+ )
929
+
930
+ self._guidance_scale = guidance_scale
931
+ self._joint_attention_kwargs = joint_attention_kwargs
932
+ self._current_timestep = None
933
+ self._interrupt = False
934
+
935
+ # 2. Define call parameters
936
+ if prompt is not None and isinstance(prompt, str):
937
+ batch_size = 1
938
+ elif prompt is not None and isinstance(prompt, list):
939
+ batch_size = len(prompt)
940
+ else:
941
+ batch_size = prompt_embeds.shape[0]
942
+
943
+ device = self._execution_device
944
+
945
+ lora_scale = (
946
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
947
+ )
948
+ has_neg_prompt = negative_prompt is not None or (
949
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
950
+ )
951
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
952
+ (
953
+ prompt_embeds,
954
+ pooled_prompt_embeds,
955
+ text_ids,
956
+ ) = self.encode_prompt(
957
+ prompt=prompt,
958
+ prompt_2=prompt_2,
959
+ prompt_embeds=prompt_embeds,
960
+ pooled_prompt_embeds=pooled_prompt_embeds,
961
+ device=device,
962
+ num_images_per_prompt=num_images_per_prompt,
963
+ max_sequence_length=max_sequence_length,
964
+ lora_scale=lora_scale,
965
+ )
966
+ if do_true_cfg:
967
+ (
968
+ negative_prompt_embeds,
969
+ negative_pooled_prompt_embeds,
970
+ negative_text_ids,
971
+ ) = self.encode_prompt(
972
+ prompt=negative_prompt,
973
+ prompt_2=negative_prompt_2,
974
+ prompt_embeds=negative_prompt_embeds,
975
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
976
+ device=device,
977
+ num_images_per_prompt=num_images_per_prompt,
978
+ max_sequence_length=max_sequence_length,
979
+ lora_scale=lora_scale,
980
+ )
981
+
982
+ # 3. Preprocess image
983
+ if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
984
+ img = image[0] if isinstance(image, list) else image
985
+ image_height, image_width = self.image_processor.get_default_height_width(img)
986
+ aspect_ratio = image_width / image_height
987
+ if _auto_resize:
988
+ # Kontext is trained on specific resolutions, using one of them is recommended
989
+ _, image_width, image_height = min(
990
+ (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
991
+ )
992
+ image_width = image_width // multiple_of * multiple_of
993
+ image_height = image_height // multiple_of * multiple_of
994
+ image = self.image_processor.resize(image, image_height, image_width)
995
+ image = self.image_processor.preprocess(image, image_height, image_width)
996
+
997
+ # 4. Prepare latent variables
998
+ num_channels_latents = self.transformer.config.in_channels // 4
999
+ latents, image_latents, latent_ids, image_ids = self.prepare_latents(
1000
+ image,
1001
+ batch_size * num_images_per_prompt,
1002
+ num_channels_latents,
1003
+ height,
1004
+ width,
1005
+ prompt_embeds.dtype,
1006
+ device,
1007
+ generator,
1008
+ latents,
1009
+ )
1010
+ if image_ids is not None:
1011
+ latent_ids = torch.cat([latent_ids, image_ids], dim=0) # dim 0 is sequence dimension
1012
+
1013
+ # 5. Prepare timesteps
1014
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
1015
+ image_seq_len = latents.shape[1]
1016
+ mu = calculate_shift(
1017
+ image_seq_len,
1018
+ self.scheduler.config.get("base_image_seq_len", 256),
1019
+ self.scheduler.config.get("max_image_seq_len", 4096),
1020
+ self.scheduler.config.get("base_shift", 0.5),
1021
+ self.scheduler.config.get("max_shift", 1.15),
1022
+ )
1023
+ timesteps, num_inference_steps = retrieve_timesteps(
1024
+ self.scheduler,
1025
+ num_inference_steps,
1026
+ device,
1027
+ sigmas=sigmas,
1028
+ mu=mu,
1029
+ )
1030
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1031
+ self._num_timesteps = len(timesteps)
1032
+
1033
+ # handle guidance
1034
+ if self.transformer.config.guidance_embeds:
1035
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
1036
+ guidance = guidance.expand(latents.shape[0])
1037
+ else:
1038
+ guidance = None
1039
+
1040
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
1041
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
1042
+ ):
1043
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
1044
+ negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
1045
+
1046
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
1047
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
1048
+ ):
1049
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
1050
+ ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
1051
+
1052
+ if self.joint_attention_kwargs is None:
1053
+ self._joint_attention_kwargs = {}
1054
+
1055
+ image_embeds = None
1056
+ negative_image_embeds = None
1057
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1058
+ image_embeds = self.prepare_ip_adapter_image_embeds(
1059
+ ip_adapter_image,
1060
+ ip_adapter_image_embeds,
1061
+ device,
1062
+ batch_size * num_images_per_prompt,
1063
+ )
1064
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
1065
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
1066
+ negative_ip_adapter_image,
1067
+ negative_ip_adapter_image_embeds,
1068
+ device,
1069
+ batch_size * num_images_per_prompt,
1070
+ )
1071
+
1072
+ # 6. Denoising loop
1073
+ # We set the index here to remove DtoH sync, helpful especially during compilation.
1074
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
1075
+ self.scheduler.set_begin_index(0)
1076
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1077
+ for i, t in enumerate(timesteps):
1078
+ if self.interrupt:
1079
+ continue
1080
+
1081
+ self._current_timestep = t
1082
+ if image_embeds is not None:
1083
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
1084
+
1085
+ latent_model_input = latents
1086
+ if image_latents is not None:
1087
+ latent_model_input = torch.cat([latents, image_latents], dim=1)
1088
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
1089
+
1090
+ noise_pred = self.transformer(
1091
+ hidden_states=latent_model_input,
1092
+ timestep=timestep / 1000,
1093
+ guidance=guidance,
1094
+ pooled_projections=pooled_prompt_embeds,
1095
+ encoder_hidden_states=prompt_embeds,
1096
+ txt_ids=text_ids,
1097
+ img_ids=latent_ids,
1098
+ joint_attention_kwargs=self.joint_attention_kwargs,
1099
+ return_dict=False,
1100
+ )[0]
1101
+ noise_pred = noise_pred[:, : latents.size(1)]
1102
+
1103
+ if do_true_cfg:
1104
+ if negative_image_embeds is not None:
1105
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
1106
+ neg_noise_pred = self.transformer(
1107
+ hidden_states=latent_model_input,
1108
+ timestep=timestep / 1000,
1109
+ guidance=guidance,
1110
+ pooled_projections=negative_pooled_prompt_embeds,
1111
+ encoder_hidden_states=negative_prompt_embeds,
1112
+ txt_ids=negative_text_ids,
1113
+ img_ids=latent_ids,
1114
+ joint_attention_kwargs=self.joint_attention_kwargs,
1115
+ return_dict=False,
1116
+ )[0]
1117
+ neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
1118
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
1119
+
1120
+ # compute the previous noisy sample x_t -> x_t-1
1121
+ latents_dtype = latents.dtype
1122
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
1123
+
1124
+ if latents.dtype != latents_dtype:
1125
+ if torch.backends.mps.is_available():
1126
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1127
+ latents = latents.to(latents_dtype)
1128
+
1129
+ if callback_on_step_end is not None:
1130
+ callback_kwargs = {}
1131
+ for k in callback_on_step_end_tensor_inputs:
1132
+ callback_kwargs[k] = locals()[k]
1133
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1134
+
1135
+ latents = callback_outputs.pop("latents", latents)
1136
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1137
+
1138
+ # call the callback, if provided
1139
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1140
+ progress_bar.update()
1141
+
1142
+ if XLA_AVAILABLE:
1143
+ xm.mark_step()
1144
+
1145
+ self._current_timestep = None
1146
+
1147
+ if output_type == "latent":
1148
+ image = latents
1149
+ else:
1150
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
1151
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1152
+ image = self.vae.decode(latents, return_dict=False)[0]
1153
+ image = self.image_processor.postprocess(image, output_type=output_type)
1154
+
1155
+ # Offload all models
1156
+ self.maybe_free_model_hooks()
1157
+
1158
+ if not return_dict:
1159
+ return (image,)
1160
+
1161
+ return FluxPipelineOutput(images=image)
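+ # --- Editor's sketch (illustrative, not part of the original file): minimal end-to-end use of the
+ # pipeline defined above. The checkpoint id matches the one referenced later in this commit; the
+ # input image path is a placeholder.
+ #
+ #     pipe = FluxKontextPipeline.from_pretrained(
+ #         "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
+ #     ).to("cuda")
+ #     init_image = PIL.Image.open("input.png").convert("RGB")
+ #     result = pipe(image=init_image, prompt="turn this photo into a watercolor painting",
+ #                   guidance_scale=3.5, num_inference_steps=28)
+ #     result.images[0].save("edited.png")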
unimodel/qwenkontext/qwenkontext_inference.py ADDED
@@ -0,0 +1,442 @@
1
+ # Copyright 2025 Fu-Yun Wang
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import List, Optional, Tuple, Union, Dict
16
+ import torch
17
+ import torch.nn as nn
18
+ from PIL import Image
19
+ import torch.nn.functional as F
20
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoProcessor
21
+ from transformers import Qwen2_5_VLConfig, Qwen2_5_VLModel, Qwen2_5_VLForConditionalGeneration
22
+ from qwen_vl_utils import process_vision_info
23
+ import torchvision.transforms as transforms
24
+
25
+
26
+ from diffusers.utils.torch_utils import randn_tensor
27
+ from diffusers.pipelines.pipeline_utils import numpy_to_pil
28
+ import numpy as np
29
+ from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler, FlowMatchEulerDiscreteSchedulerOutput
30
+ from diffusers.schedulers import DPMSolverMultistepScheduler
31
+ import math
32
+ from diffusers.utils.torch_utils import randn_tensor
33
+ from diffusers import FluxTransformer2DModel, AutoencoderKL, FlowMatchEulerDiscreteScheduler #, FluxKontextPipeline
34
+ from .fluxkontext_pipeline import FluxKontextPipeline
35
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast, CLIPTextConfig, T5Config
36
+ import re
37
+ import datetime
38
+ import os
39
+
40
+
41
+ def save_grid_image(prompt, images, rows, cols):
42
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
43
+ base_dir = os.path.join("samples", timestamp, prompt[:100])
44
+ os.makedirs(base_dir, exist_ok=True)
45
+
46
+ filename = os.path.join(base_dir, "grid.jpg")
47
+ grid_image = create_image_grid(images, rows, cols)
48
+ grid_image.save(filename)
49
+
50
+ print(f"Saved: {filename}")
51
+
52
+ def create_image_grid(images, rows, cols):
53
+ """Creates a grid of images and returns a single PIL Image."""
54
+
55
+ assert len(images) == rows * cols
56
+
57
+ width, height = images[0].size
58
+ grid_width = width * cols
59
+ grid_height = height * rows
60
+
61
+ grid_image = Image.new('RGB', (grid_width, grid_height))
62
+
63
+ for i, image in enumerate(images):
64
+ x = (i % cols) * width
65
+ y = (i // cols) * height
66
+ grid_image.paste(image, (x, y))
67
+
68
+ return grid_image
69
+
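+ # --- Editor's sketch (illustrative, not part of the original file): `create_image_grid` tiles PIL
+ # images row-major, and `save_grid_image` writes the grid under samples/<timestamp>/<prompt prefix>/.
+ #
+ #     imgs = [Image.new("RGB", (256, 256), c) for c in ("red", "green", "blue", "white")]
+ #     save_grid_image("demo prompt", imgs, rows=2, cols=2)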
70
+
71
+ def sde_step_with_logprob(
72
+ self: FlowMatchEulerDiscreteScheduler,
73
+ model_output: torch.FloatTensor,
74
+ timestep: Union[float, torch.FloatTensor],
75
+ sample: torch.FloatTensor,
76
+ prev_sample: Optional[torch.FloatTensor] = None,
77
+ generator: Optional[torch.Generator] = None,
78
+ ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
79
+ """
80
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the flow
81
+ process from the learned model outputs (most often the predicted velocity).
82
+
83
+ Args:
84
+ model_output (`torch.FloatTensor`):
85
+ The direct output from learned flow model.
86
+ timestep (`float`):
87
+ The current discrete timestep in the diffusion chain.
88
+ sample (`torch.FloatTensor`):
89
+ A current instance of a sample created by the diffusion process.
90
+ generator (`torch.Generator`, *optional*):
91
+ A random number generator.
92
+ """
93
+ step_index = [self.index_for_timestep(t) for t in timestep]
94
+ prev_step_index = [step+1 for step in step_index]
95
+ sigma = self.sigmas[step_index].view(-1, 1, 1).to(model_output.device)
96
+ sigma_prev = self.sigmas[prev_step_index].view(-1, 1, 1).to(model_output.device)
97
+ sigma_max = self.sigmas[1].item()
98
+ dt = sigma_prev - sigma
99
+
100
+ std_dev_t = torch.sqrt(sigma / (1 - torch.where(sigma == 1, sigma_max, sigma)))*1.0
101
+
102
+
103
+ # our sde
104
+ prev_sample_mean = sample*(1+std_dev_t**2/(2*sigma)*dt)+model_output*(1+std_dev_t**2*(1-sigma)/(2*sigma))*dt
105
+
106
+ if prev_sample is not None and generator is not None:
107
+ raise ValueError(
108
+ "Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
109
+ " `prev_sample` stays `None`."
110
+ )
111
+
112
+ if prev_sample is None:
113
+ variance_noise = randn_tensor(
114
+ model_output.shape,
115
+ generator=generator,
116
+ device=model_output.device,
117
+ dtype=model_output.dtype,
118
+ )
119
+ prev_sample = prev_sample_mean + std_dev_t * torch.sqrt(-1*dt) * variance_noise
120
+
121
+
122
+ log_prob = (
123
+ -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * ((std_dev_t * torch.sqrt(-1*dt))**2))
124
+ - torch.log(std_dev_t * torch.sqrt(-1*dt))
125
+ - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
126
+ )
127
+
128
+ # mean along all but batch dimension
129
+ log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
130
+
131
+ return prev_sample, log_prob, prev_sample_mean, std_dev_t * torch.sqrt(-1*dt)
132
+
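+ # --- Editor's sketch (illustrative, not part of the original file): how this SDE step could replace
+ # the deterministic `scheduler.step(...)` inside a denoising loop when per-step log-probabilities
+ # are needed (e.g. for policy-gradient fine-tuning). `scheduler`, `timesteps`, the packed `latents`
+ # and the model call are assumed to come from the surrounding pipeline code; `model_forward` is a
+ # placeholder name.
+ #
+ #     log_probs = []
+ #     for t in timesteps:
+ #         noise_pred = model_forward(latents, t)
+ #         latents, log_prob, _, _ = sde_step_with_logprob(
+ #             scheduler, noise_pred, t.expand(latents.shape[0]), latents
+ #         )
+ #         log_probs.append(log_prob)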
133
+
148
+ def sde_step_with_logprob_simple(
149
+ self: FlowMatchEulerDiscreteScheduler,
150
+ model_output: torch.FloatTensor,
151
+ timestep: Union[float, torch.FloatTensor],
152
+ sample: torch.FloatTensor,
153
+ prev_sample: Optional[torch.FloatTensor] = None,
154
+ generator: Optional[torch.Generator] = None,
155
+ ):
156
+ """
157
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the flow
158
+ process from the learned model outputs (most often the predicted velocity).
159
+
160
+ Args:
161
+ model_output (`torch.FloatTensor`):
162
+ The direct output from learned flow model.
163
+ timestep (`float`):
164
+ The current discrete timestep in the diffusion chain.
165
+ sample (`torch.FloatTensor`):
166
+ A current instance of a sample created by the diffusion process.
167
+ generator (`torch.Generator`, *optional*):
168
+ A random number generator.
169
+ """
170
+
171
+ step_index = [self.index_for_timestep(t) for t in timestep]
172
+ prev_step_index = [step+1 for step in step_index]
173
+ sigma = self.sigmas[step_index].view(-1, 1, 1, 1).to(model_output.device)
174
+ sigma_prev = self.sigmas[prev_step_index].view(-1, 1, 1, 1).to(model_output.device)
175
+ sigma_max = self.sigmas[1].item()
176
+ dt = sigma_prev - sigma
177
+
178
+
179
+ eta = 0.5
180
+ Dt = - dt * eta
181
+
182
+ prev_sample_mean = sample * (1 - Dt / (1 - torch.where(sigma == 1, sigma_max, sigma))) + model_output * (dt - Dt)
183
+
184
+ std_dev_t = torch.sqrt(2 * Dt * (sigma / (1 - torch.where(sigma == 1, sigma_max, sigma))))
185
+
186
+ if prev_sample is not None and generator is not None:
187
+ raise ValueError(
188
+ "Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
189
+ " `prev_sample` stays `None`."
190
+ )
191
+
192
+ if prev_sample is None:
193
+ # Generate noise if not provided
194
+ variance_noise = randn_tensor(
195
+ model_output.shape,
196
+ generator=generator,
197
+ device=model_output.device,
198
+ dtype=model_output.dtype,
199
+ )
200
+
201
+ prev_sample = prev_sample_mean + std_dev_t * variance_noise
202
+
203
+
204
+ log_prob = (
205
+ -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2))
206
+ - torch.log(std_dev_t)
207
+ - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
208
+ )
209
+
210
+ # mean along all but batch dimension
211
+ log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
212
+
213
+ return prev_sample, log_prob, prev_sample_mean, std_dev_t
214
+
215
+ class QwenKontextMetaModel:
216
+
217
+ def __init__(self, config):
218
+ super(QwenKontextMetaModel, self).__init__(config)
219
+
220
+ if hasattr(config, "diffusion_expert"):
221
+ ckpt_id = "black-forest-labs/FLUX.1-Kontext-dev"
222
+ # Load configuration for each component
223
+ transformer_config = FluxTransformer2DModel.load_config(ckpt_id, subfolder="transformer")
224
+ vae_config = AutoencoderKL.load_config(ckpt_id, subfolder="vae")
225
+ text_encoder_config = CLIPTextConfig.from_pretrained(ckpt_id, subfolder="text_encoder")
226
+ text_encoder_2_config = T5Config.from_pretrained(ckpt_id, subfolder="text_encoder_2")
227
+
228
+ # Initialize components from their configurations
229
+ self.transformer = FluxTransformer2DModel.from_config(transformer_config)
230
+ self.vae = AutoencoderKL.from_config(vae_config)
231
+ self.text_encoder = CLIPTextModel(text_encoder_config)
232
+ self.text_encoder_2 = T5EncoderModel(text_encoder_2_config)
233
+
234
+ # Initialize tokenizers (these don't use from_config as they are not models)
235
+ self.tokenizer = CLIPTokenizer.from_pretrained(ckpt_id, subfolder="tokenizer")
236
+ self.tokenizer_2 = T5TokenizerFast.from_pretrained(ckpt_id, subfolder="tokenizer_2")
237
+
238
+ self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(ckpt_id, subfolder="scheduler")
239
+
240
+ # Create the pipeline configuration dictionary
241
+ pipeline_config = {
242
+ "transformer": self.transformer,
243
+ "scheduler": self.scheduler,
244
+ "vae": self.vae,
245
+ "text_encoder": self.text_encoder,
246
+ "text_encoder_2": self.text_encoder_2,
247
+ "tokenizer": self.tokenizer,
248
+ "tokenizer_2": self.tokenizer_2,
249
+ }
250
+
251
+ self.diffusion_expert = FluxKontextPipeline(**pipeline_config)
252
+
253
+
254
+ def initialize_diffusion_expert(self, fsdp=None):
255
+
256
+ if getattr(self, 'diffusion_expert', None) is None:
257
+ print("random initiation the diffusion expert !!!")
258
+ self.diffusion_expert = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", revision="main", torch_dtype=torch.bfloat16).to(torch.bfloat16)
259
+ self.text_encoder = self.diffusion_expert.text_encoder
260
+ self.text_encoder_2 = self.diffusion_expert.text_encoder_2
261
+ self.tokenizer = self.diffusion_expert.tokenizer
262
+ self.tokenizer_2 = self.diffusion_expert.tokenizer_2
263
+ self.vae = self.diffusion_expert.vae
264
+ self.transformer = self.diffusion_expert.transformer
265
+ self.scheduler = self.diffusion_expert.scheduler
266
+
267
+ self.config.diffusion_expert = "flux"
268
+
269
+
270
+
271
+ class QwenKontextConfig(Qwen2_5_VLConfig):
272
+ model_type = "QwenKontext"
273
+
274
+
275
+ class QwenKontextModel(QwenKontextMetaModel, Qwen2_5_VLModel):
276
+ config_class = QwenKontextConfig
277
+
278
+ def __init__(self, config: Qwen2_5_VLConfig):
279
+ super(QwenKontextModel, self).__init__(config)
280
+
281
+
282
+ class QwenKontextForInferenceLM(Qwen2_5_VLForConditionalGeneration):
283
+ config_class = QwenKontextConfig
284
+
285
+ def __init__(self, config):
286
+ Qwen2_5_VLForConditionalGeneration.__init__(self, config)
287
+ config.model_type = "QwenKontext"
288
+
289
+ self.model = QwenKontextModel(config)
290
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
291
+ self.post_init()
292
+
293
+ def get_model(self):
294
+ return self.model
295
+
296
+ @torch.no_grad()
297
+ def generate_image(
298
+ self,
299
+ images: List[Image.Image],
300
+ texts: List[str],
301
+ diffusion_kwargs: Optional[Dict] = dict(guidance_scale = 3.5, num_inference_steps=25),
302
+ sde_sampling: Optional[bool] = False,
303
+ ):
304
+
305
+ if isinstance(texts, str):
306
+ texts = [texts]
307
+
308
+ if not sde_sampling:
309
+ output_img = self.model.diffusion_expert(
310
+ images,
311
+ texts,
312
+ max_sequence_length=512,
313
+ **diffusion_kwargs,
314
+ ).images
315
+ return output_img
316
+ else:
317
+ return self.model.diffusion_expert.sde_sampling(
318
+ images,
319
+ texts,
320
+ max_sequence_length=512,
321
+ **diffusion_kwargs,
322
+ )
323
+
324
+
325
+ def extract_thinking_content(self, text: str) -> str:
326
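+ # Keep only the content of the last <answer>...</answer> block; fall back to the stripped raw text when no tags are found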
+ pattern = r'<answer>(.*?)</answer>'
327
+ matches = re.findall(pattern, text, re.DOTALL)
328
+
329
+ if matches:
330
+ return matches[-1].strip().replace("<answer>", "").replace("</answer>", "")
331
+ else:
332
+ return text.strip().replace("<answer>", "").replace("</answer>", "")
333
+
334
+ @torch.no_grad()
335
+ def generate_image_cot(
336
+ self,
337
+ images: List[Image.Image],
338
+ texts: List[str],
339
+ processor: Optional[object] = None,
340
+ diffusion_kwargs: Optional[Dict] = dict(guidance_scale = 2.5, num_inference_steps=25),
341
+ llm_kwargs: Optional[Dict] = dict(max_new_tokens=256, temperature=0.7, top_p=0.9, do_sample=True),
342
+ cot_prompt_template: Optional[str] = None,
343
+ ):
344
+
345
+ if isinstance(texts, str):
346
+ texts = [texts]
347
+
348
+ if cot_prompt_template is None:
349
+ cot_prompt_template = """Please provide an enhanced prompt for the following image editing prompt.
350
+ Ensure the revised prompt is clear, specific, and includes detailed instructions to achieve the desired outcome while maintaining the original intent.
351
+ Original prompt: {original_prompt}. Directly provide the improved prompt in <answer> </answer> tags."""
352
+
353
+
354
+ improved_prompts = []
355
+
356
+ for text, image in zip(texts, images):
357
+ cot_input = cot_prompt_template.format(original_prompt=text)
358
+
359
+ messages = [
360
+ {
361
+ "role": "user",
362
+ "content": [
363
+ {
364
+ "type": "image",
365
+ "image": image,
366
+ },
367
+ {"type": "text", "text": cot_input},
368
+ ],
369
+ }
370
+ ]
371
+
372
+ input_text_formatted = processor.apply_chat_template(
373
+ messages, tokenize=False, add_generation_prompt=True
374
+ )
375
+ image_inputs, video_inputs = process_vision_info(messages)
376
+ model_inputs = processor(
377
+ images=image_inputs,
378
+ text=[input_text_formatted],
379
+ return_tensors="pt"
380
+ ).to(self.device)
381
+
382
+ generated_ids = self.generate(
383
+ **model_inputs,
384
+ **llm_kwargs,
385
+ eos_token_id=processor.tokenizer.eos_token_id,
386
+ pad_token_id=processor.tokenizer.pad_token_id
387
+ )
388
+
389
+ generated_text = processor.batch_decode(
390
+ generated_ids[:, model_inputs['input_ids'].shape[1]:],
391
+ skip_special_tokens=True
392
+ )
393
+
394
+ improved_prompt = [self.extract_thinking_content(decode_text) for decode_text in generated_text]
395
+ improved_prompts.extend(improved_prompt)
396
+
397
+ print(f"Original prompt: {text}")
398
+ print(f"Improved prompt: {improved_prompt}")
399
+ print("-" * 50)
400
+
401
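+ # Edit all reference images in a single batch using the refined prompts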
+ output_images = self.generate_image(images, improved_prompts, diffusion_kwargs)
402
+
403
+ return {
404
+ 'ref_images': images,
405
+ 'images': output_images,
406
+ 'original_prompts': texts,
407
+ 'improved_prompts': improved_prompts
408
+ }
409
+
410
+ AutoConfig.register("QwenKontext", QwenKontextConfig)
411
+ AutoModelForCausalLM.register(QwenKontextConfig, QwenKontextForInferenceLM)
412
+
413
+
414
+ if __name__ == "__main__":
415
+ model = QwenKontextForInferenceLM.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct",torch_dtype=torch.bfloat16)
416
+ model.model.initialize_diffusion_expert()
417
+ model.model.diffusion_expert.to("cuda:0")
418
+ model.to("cuda:0")
419
+ processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
420
+ text = ["add a hat to him"]
421
+ ref_image = [Image.open("assets/images/cat.jpg").convert("RGB")]
422
+ images = model.generate_image(ref_image, text)
423
+ images[0].save("test_flux.jpg")
424
+ model.save_pretrained("outputs/pretrain/qwenkontext")
425
+
426
+
427
+ # model = QwenKontextForInferenceLM.from_pretrained("outputs/pretrain/qwenkontext", torch_dtype=torch.bfloat16)
428
+ # model.to("cuda:0")
429
+ # transform = transforms.Compose([
430
+ # transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR), # Shortest side to 512
431
+ # transforms.CenterCrop((512, 512)) # Center crop to 512x512
432
+ # ])
433
+ # processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
434
+ # text = ["add a hat to him"]
435
+ # ref_image = [transform(Image.open("assets/images/cat.jpg").convert("RGB"))]
436
+ # ref_image[0].save("ref.jpg")
437
+ # images = model.generate_image(ref_image, text)
438
+ # images[0].save("test_flux.jpg")
439
+
440
+ # outputs = model.generate_image_cot(ref_image, text, processor = processor)
441
+ # outputs['images'][0].save("test_flux_cot.jpg")
442
+
unimodel/qwensana/qwensana_inference.py ADDED
@@ -0,0 +1,310 @@
1
+ # Copyright 2025 Fu-Yun Wang
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union, Dict
17
+ import torch
18
+ import torch.nn as nn
19
+ from PIL import Image
20
+ import torch.nn.functional as F
21
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoProcessor
22
+ from transformers import Qwen2_5_VLConfig, Qwen2_5_VLModel, Qwen2_5_VLForConditionalGeneration, T5Config, Gemma2Model, GemmaTokenizer, GemmaTokenizerFast, Gemma2Config, AutoConfig
23
+ from diffusers import SanaPipeline, AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaTransformer2DModel, DPMSolverMultistepScheduler
24
+ import re
25
+ from datetime import datetime
26
+ import os
27
+
28
+
29
+ def save_grid_image(prompt, images, rows, cols):
30
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
31
+ base_dir = os.path.join("samples", timestamp, prompt[:100])
32
+ os.makedirs(base_dir, exist_ok=True)
33
+
34
+ filename = os.path.join(base_dir, "grid.jpg")
35
+ grid_image = create_image_grid(images, rows, cols)
36
+ grid_image.save(filename)
37
+
38
+ print(f"Saved: {filename}")
39
+
40
+ def create_image_grid(images, rows, cols):
41
+ """Creates a grid of images and returns a single PIL Image."""
42
+ assert len(images) == rows * cols
43
+
44
+ width, height = images[0].size
45
+ grid_width = width * cols
46
+ grid_height = height * rows
47
+
48
+ grid_image = Image.new('RGB', (grid_width, grid_height))
49
+
50
+ for i, image in enumerate(images):
51
+ x = (i % cols) * width
52
+ y = (i // cols) * height
53
+ grid_image.paste(image, (x, y))
54
+
55
+ return grid_image
56
+
57
+
58
+ class QwenSanaMetaModel:
59
+
60
+ def __init__(self, config):
61
+ super(QwenSanaMetaModel, self).__init__(config)
62
+ if hasattr(config, "diffusion_expert"):
63
+ ckpt_id = "Efficient-Large-Model/SANA1.5_1.6B_1024px_diffusers"
64
+
65
+ # Load configuration for each component
66
+ transformer_config = SanaTransformer2DModel.load_config(ckpt_id, subfolder="transformer")
67
+ vae_config = AutoencoderDC.load_config(ckpt_id, subfolder="vae")
68
+ text_encoder_config = Gemma2Config.from_pretrained(ckpt_id, subfolder="text_encoder")
69
+ scheduler_config = DPMSolverMultistepScheduler.load_config(ckpt_id, subfolder="scheduler")
70
+ # Initialize components from their configurations
71
+ self.transformer = SanaTransformer2DModel.from_config(transformer_config)
72
+ self.vae = AutoencoderDC.from_config(vae_config)
73
+ self.text_encoder = Gemma2Model(text_encoder_config)
74
+ self.scheduler = DPMSolverMultistepScheduler.from_config(scheduler_config)
75
+
76
+ # Initialize tokenizer
77
+ self.tokenizer = GemmaTokenizerFast.from_pretrained(ckpt_id, subfolder="tokenizer")
78
+
79
+ # Create the pipeline configuration dictionary
80
+ pipeline_config = {
81
+ "transformer": self.transformer,
82
+ "scheduler": self.scheduler,
83
+ "vae": self.vae,
84
+ "text_encoder": self.text_encoder,
85
+ "tokenizer": self.tokenizer,
86
+ }
87
+
88
+ self.diffusion_expert = SanaPipeline(**pipeline_config)
89
+
90
+ def initialize_diffusion_expert(self, fsdp=None):
91
+
92
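+ # Load the pretrained SANA pipeline only when the expert was not already built from configs in __init__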
+ if getattr(self, 'diffusion_expert', None) is None:
93
+ print("Random initiation the Sana diffusion expert !!!")
94
+ self.diffusion_expert = SanaPipeline.from_pretrained(
95
+ "Efficient-Large-Model/SANA1.5_1.6B_1024px_diffusers",
96
+ torch_dtype=torch.bfloat16
97
+ )
98
+
99
+ # Store references to components for easier access
100
+ self.transformer = self.diffusion_expert.transformer
101
+ self.vae = self.diffusion_expert.vae
102
+ self.text_encoder = self.diffusion_expert.text_encoder
103
+ self.tokenizer = self.diffusion_expert.tokenizer
104
+ self.scheduler = self.diffusion_expert.scheduler
105
+
106
+ self.config.diffusion_expert = "Sana"
107
+
108
+
109
+ class QwenSanaConfig(Qwen2_5_VLConfig):
110
+ model_type = "QwenSana"
111
+
112
+
113
+ class QwenSanaModel(QwenSanaMetaModel, Qwen2_5_VLModel):
114
+ config_class = QwenSanaConfig
115
+
116
+ def __init__(self, config: Qwen2_5_VLConfig):
117
+ super(QwenSanaModel, self).__init__(config)
118
+
119
+
120
+ class QwenSanaForInferenceLM(Qwen2_5_VLForConditionalGeneration):
121
+ config_class = QwenSanaConfig
122
+
123
+ def __init__(self, config):
124
+ Qwen2_5_VLForConditionalGeneration.__init__(self, config)
125
+ config.model_type = "QwenSana"
126
+
127
+ self.model = QwenSanaModel(config)
128
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
129
+ self.post_init()
130
+
131
+ def get_model(self):
132
+ return self.model
133
+
134
+ @torch.no_grad()
135
+ def generate_image(
136
+ self,
137
+ texts: List[str],
138
+ diffusion_kwargs: Optional[Dict] = None,
139
+ ):
140
+
141
+ if isinstance(texts, str):
142
+ texts = [texts]
143
+
144
+ # Default parameters for Sana
145
+ default_kwargs = dict(
146
+ guidance_scale=3.5,
147
+ num_inference_steps=20,
148
+ height=1024,
149
+ width=1024
150
+ )
151
+
152
+ if diffusion_kwargs:
153
+ default_kwargs.update(diffusion_kwargs)
154
+
155
+ output_img = self.model.diffusion_expert(
156
+ texts,
157
+ **default_kwargs,
158
+ ).images
159
+
160
+ return output_img
161
+
162
+ def extract_thinking_content(self, text: str) -> str:
163
+ pattern = r'<answer>(.*?)</answer>'
164
+ matches = re.findall(pattern, text, re.DOTALL)
165
+
166
+ if matches:
167
+ return matches[-1].strip().replace("<answer>", "").replace("</answer>", "")
168
+ else:
169
+ return text.strip().replace("<answer>", "").replace("</answer>", "")
170
+
171
+ @torch.no_grad()
172
+ def generate_image_cot(
173
+ self,
174
+ texts: List[str],
175
+ processor: Optional[object] = None,
176
+ diffusion_kwargs: Optional[Dict] = None,
177
+ llm_kwargs: Optional[Dict] = None,
178
+ cot_prompt_template: Optional[str] = None,
179
+ ):
180
+
181
+ if isinstance(texts, str):
182
+ texts = [texts]
183
+
184
+ # Default parameters
185
+ default_diffusion_kwargs = dict(
186
+ guidance_scale=5.0,
187
+ num_inference_steps=20,
188
+ height=1024,
189
+ width=1024
190
+ )
191
+ if diffusion_kwargs:
192
+ default_diffusion_kwargs.update(diffusion_kwargs)
193
+
194
+ default_llm_kwargs = dict(
195
+ max_new_tokens=256,
196
+ temperature=0.7,
197
+ top_p=0.9,
198
+ do_sample=True
199
+ )
200
+ if llm_kwargs:
201
+ default_llm_kwargs.update(llm_kwargs)
202
+
203
+ if cot_prompt_template is None:
204
+ cot_prompt_template = """Please provide an enhanced prompt for the following image generation prompt to make the image more realistic, detailed, with clear separation and precise alignment of all entities.
205
+ Original prompt: {original_prompt}. Directly provide the improved prompt in <answer> </answer> tags."""
206
+
207
+ improved_prompts = []
208
+
209
+ for text in texts:
210
+ cot_input = cot_prompt_template.format(original_prompt=text)
211
+
212
+ messages = [{"role": "user", "content": cot_input}]
213
+ input_text_formatted = processor.apply_chat_template(
214
+ messages, tokenize=False, add_generation_prompt=True
215
+ )
216
+ model_inputs = processor(
217
+ text=[input_text_formatted],
218
+ return_tensors="pt"
219
+ ).to(self.device)
220
+
221
+ generated_ids = self.generate(
222
+ **model_inputs,
223
+ **default_llm_kwargs,
224
+ eos_token_id=processor.tokenizer.eos_token_id,
225
+ pad_token_id=processor.tokenizer.pad_token_id
226
+ )
227
+
228
+ generated_text = processor.batch_decode(
229
+ generated_ids[:, model_inputs['input_ids'].shape[1]:],
230
+ skip_special_tokens=True
231
+ )
232
+
233
+ improved_prompt = [self.extract_thinking_content(decode_text) for decode_text in generated_text]
234
+ improved_prompts.extend(improved_prompt)
235
+
236
+ print(f"Original prompt: {text}")
237
+ print(f"Improved prompt: {improved_prompt}")
238
+ print("-" * 50)
239
+
240
+ output_images = self.generate_image(improved_prompts, default_diffusion_kwargs)
241
+
242
+ return {
243
+ 'images': output_images,
244
+ 'original_prompts': texts,
245
+ 'improved_prompts': improved_prompts
246
+ }
247
+
248
+
249
+ AutoConfig.register("QwenSana", QwenSanaConfig)
250
+ AutoModelForCausalLM.register(QwenSanaConfig, QwenSanaForInferenceLM)
251
+
252
+
253
+ if __name__ == "__main__":
254
+ model = QwenSanaForInferenceLM.from_pretrained(
255
+ "Qwen/Qwen2.5-VL-3B-Instruct",
256
+ torch_dtype=torch.bfloat16
257
+ )
258
+ model.model.initialize_diffusion_expert()
259
+ model.model.diffusion_expert.to("cuda:0")
260
+ model.to("cuda:0")
261
+
262
+ processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
263
+
264
+ # Test basic image generation
265
+ text = ["a photo of a cat"]
266
+ diffusion_kwargs = dict(
267
+ guidance_scale=3.5,
268
+ num_inference_steps=20,
269
+ width=1024,
270
+ height=1024,
271
+ generator=torch.manual_seed(0)
272
+ )
273
+
274
+ images = model.generate_image(text, diffusion_kwargs=diffusion_kwargs)
275
+ images[0].save("test_Sana.jpg")
276
+
277
+ # Test chain-of-thought image generation
278
+ outputs = model.generate_image_cot(text, processor=processor, diffusion_kwargs=diffusion_kwargs)
279
+ outputs['images'][0].save("test_Sana_cot.jpg")
280
+
281
+ # Save the model
282
+ model.save_pretrained("outputs/pretrain/qwenSana-1.5")
283
+
284
+ # print("Sana model integration completed successfully!")
285
+
286
+ # model = QwenSanaForInferenceLM.from_pretrained(
287
+ # "outputs/pretrain/qwenSana-1.5",
288
+ # torch_dtype=torch.bfloat16
289
+ # ).to("cuda")
290
+
291
+
292
+ # processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
293
+
294
+ # # Test basic image generation
295
+ # text = ["a photo of a cat"]
296
+ # diffusion_kwargs = dict(
297
+ # guidance_scale=5.0,
298
+ # num_inference_steps=20,
299
+ # width=1024,
300
+ # height=1024,
301
+ # generator=torch.manual_seed(0)
302
+ # )
303
+
304
+ # images = model.generate_image(text, diffusion_kwargs=diffusion_kwargs)
305
+ # images[0].save("test_Sana.jpg")
306
+
307
+ # # Test chain-of-thought image generation
308
+ # outputs = model.generate_image_cot(text, processor=processor, diffusion_kwargs=diffusion_kwargs)
309
+ # outputs['images'][0].save("test_Sana_cot.jpg")
310
+
unimodel/qwensd3/qwensd3_inference.py ADDED
@@ -0,0 +1,447 @@
1
+ # Copyright 2025 Fu-Yun Wang
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import List, Optional, Tuple, Union, Dict
16
+ import torch
17
+ import torch.nn as nn
18
+ from PIL import Image
19
+ import torch.nn.functional as F
20
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoProcessor
21
+ from transformers import Qwen2_5_VLConfig, Qwen2_5_VLModel, Qwen2_5_VLForConditionalGeneration
22
+
23
+
24
+ from diffusers.utils.torch_utils import randn_tensor
25
+ from diffusers.pipelines.pipeline_utils import numpy_to_pil
26
+ import numpy as np
27
+ from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler, FlowMatchEulerDiscreteSchedulerOutput
28
+ from diffusers.schedulers import DPMSolverMultistepScheduler
29
+ import math
30
+ from diffusers.utils.torch_utils import randn_tensor
31
+ from diffusers import SD3Transformer2DModel, AutoencoderKL, FlowMatchEulerDiscreteScheduler
32
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast, CLIPTextConfig, T5Config, CLIPTextModelWithProjection
33
+ try:
34
+ from .sd3pipeline import StableDiffusion3Pipeline as SD3Pipeline
35
+ except:
36
+ from sd3pipeline import StableDiffusion3Pipeline as SD3Pipeline
37
+ # from diffusers import StableDiffusion3Pipeline as SD3Pipeline
38
+ import re
39
+ from datetime import datetime
40
+ import os
41
+ from transformers import GenerationConfig
42
+
43
+
44
+ def save_grid_image(prompt, images, rows, cols):
45
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
46
+ base_dir = os.path.join("samples", timestamp, prompt[:100])
47
+ os.makedirs(base_dir, exist_ok=True)
48
+
49
+ filename = os.path.join(base_dir, "grid.jpg")
50
+ grid_image = create_image_grid(images, rows, cols)
51
+ grid_image.save(filename)
52
+
53
+ print(f"Saved: {filename}")
54
+
55
+ def create_image_grid(images, rows, cols):
56
+ """Creates a grid of images and returns a single PIL Image."""
57
+
58
+ assert len(images) == rows * cols
59
+
60
+ width, height = images[0].size
61
+ grid_width = width * cols
62
+ grid_height = height * rows
63
+
64
+ grid_image = Image.new('RGB', (grid_width, grid_height))
65
+
66
+ for i, image in enumerate(images):
67
+ x = (i % cols) * width
68
+ y = (i // cols) * height
69
+ grid_image.paste(image, (x, y))
70
+
71
+ return grid_image
72
+
73
+
74
+ def sde_step_with_logprob(
75
+ self: FlowMatchEulerDiscreteScheduler,
76
+ model_output: torch.FloatTensor,
77
+ timestep: Union[float, torch.FloatTensor],
78
+ sample: torch.FloatTensor,
79
+ prev_sample: Optional[torch.FloatTensor] = None,
80
+ generator: Optional[torch.Generator] = None,
81
+ ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
82
+ """
83
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the flow
84
+ process from the learned model outputs (most often the predicted velocity).
85
+
86
+ Args:
87
+ model_output (`torch.FloatTensor`):
88
+ The direct output from learned flow model.
89
+ timestep (`float`):
90
+ The current discrete timestep in the diffusion chain.
91
+ sample (`torch.FloatTensor`):
92
+ A current instance of a sample created by the diffusion process.
93
+ generator (`torch.Generator`, *optional*):
94
+ A random number generator.
95
+ """
96
+ step_index = [self.index_for_timestep(t) for t in timestep]
97
+ prev_step_index = [step+1 for step in step_index]
98
+ sigma = self.sigmas[step_index].view(-1, 1, 1, 1).to(model_output.device)
99
+ sigma_prev = self.sigmas[prev_step_index].view(-1, 1, 1, 1).to(model_output.device)
100
+ sigma_max = self.sigmas[1].item()
101
+ dt = sigma_prev - sigma
102
+
103
+ std_dev_t = torch.sqrt(sigma / (1 - torch.where(sigma == 1, sigma_max, sigma)))*0.7
104
+
105
+
106
+ # our sde
107
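+ # Mean of the reverse-SDE transition: the deterministic flow update (sample + velocity * dt) plus correction terms scaled by std_dev_t**2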
+ prev_sample_mean = sample*(1+std_dev_t**2/(2*sigma)*dt)+model_output*(1+std_dev_t**2*(1-sigma)/(2*sigma))*dt
108
+
109
+ if prev_sample is not None and generator is not None:
110
+ raise ValueError(
111
+ "Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
112
+ " `prev_sample` stays `None`."
113
+ )
114
+
115
+ if prev_sample is None:
116
+ variance_noise = randn_tensor(
117
+ model_output.shape,
118
+ generator=generator,
119
+ device=model_output.device,
120
+ dtype=model_output.dtype,
121
+ )
122
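+ # Diffusion term: noise scaled by std_dev_t * sqrt(-dt); dt = sigma_prev - sigma is negative while denoising, so -dt > 0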
+ prev_sample = prev_sample_mean + std_dev_t * torch.sqrt(-1*dt) * variance_noise
123
+
124
+
125
+ log_prob = (
126
+ -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * ((std_dev_t * torch.sqrt(-1*dt))**2))
127
+ - torch.log(std_dev_t * torch.sqrt(-1*dt))
128
+ - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
129
+ )
130
+
131
+ # mean along all but batch dimension
132
+ log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
133
+
134
+ return prev_sample, log_prob, prev_sample_mean, std_dev_t * torch.sqrt(-1*dt)
135
+
136
+
137
+
138
+ # Copyright 2025 Fu-Yun Wang
139
+ #
140
+ # Licensed under the Apache License, Version 2.0 (the "License");
141
+ # you may not use this file except in compliance with the License.
142
+ # You may obtain a copy of the License at
143
+ #
144
+ # http://www.apache.org/licenses/LICENSE-2.0
145
+ #
146
+ # Unless required by applicable law or agreed to in writing, software
147
+ # distributed under the License is distributed on an "AS IS" BASIS,
148
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
149
+ # See the License for the specific language governing permissions and
150
+ # limitations under the License.
151
+
152
+ def sde_step_with_logprob_simple(
153
+ self: FlowMatchEulerDiscreteScheduler,
154
+ model_output: torch.FloatTensor,
155
+ timestep: Union[float, torch.FloatTensor],
156
+ sample: torch.FloatTensor,
157
+ prev_sample: Optional[torch.FloatTensor] = None,
158
+ generator: Optional[torch.Generator] = None,
159
+ ):
160
+ """
161
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the flow
162
+ process from the learned model outputs (most often the predicted velocity).
163
+
164
+ Args:
165
+ model_output (`torch.FloatTensor`):
166
+ The direct output from learned flow model.
167
+ timestep (`float`):
168
+ The current discrete timestep in the diffusion chain.
169
+ sample (`torch.FloatTensor`):
170
+ A current instance of a sample created by the diffusion process.
171
+ generator (`torch.Generator`, *optional*):
172
+ A random number generator.
173
+ """
174
+
175
+ step_index = [self.index_for_timestep(t) for t in timestep]
176
+ prev_step_index = [step+1 for step in step_index]
177
+ sigma = self.sigmas[step_index].view(-1, 1, 1, 1).to(model_output.device)
178
+ sigma_prev = self.sigmas[prev_step_index].view(-1, 1, 1, 1).to(model_output.device)
179
+ sigma_max = self.sigmas[1].item()
180
+ dt = sigma_prev - sigma
181
+
182
+
183
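+ # eta controls the stochasticity injected per step; with eta = 0 the mean update reduces to the deterministic Euler flow step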
+ eta = 0.5
184
+ Dt = - dt * eta
185
+
186
+ prev_sample_mean = sample * (1 - Dt / (1 - torch.where(sigma == 1, sigma_max, sigma))) + model_output * (dt - Dt)
187
+
188
+ std_dev_t = torch.sqrt(2 * Dt * (sigma / (1 - torch.where(sigma == 1, sigma_max, sigma))))
189
+
190
+ if prev_sample is not None and generator is not None:
191
+ raise ValueError(
192
+ "Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
193
+ " `prev_sample` stays `None`."
194
+ )
195
+
196
+ if prev_sample is None:
197
+ # Generate noise if not provided
198
+ variance_noise = randn_tensor(
199
+ model_output.shape,
200
+ generator=generator,
201
+ device=model_output.device,
202
+ dtype=model_output.dtype,
203
+ )
204
+
205
+ prev_sample = prev_sample_mean + std_dev_t * variance_noise
206
+
207
+
208
+ log_prob = (
209
+ -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2))
210
+ - torch.log(std_dev_t)
211
+ - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
212
+ )
213
+
214
+ # mean along all but batch dimension
215
+ log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
216
+
217
+ return prev_sample, log_prob, prev_sample_mean, std_dev_t
218
+
219
+ class QwenSD3MetaModel:
220
+
221
+ def __init__(self, config):
222
+ super(QwenSD3MetaModel, self).__init__(config)
223
+ if hasattr(config, "diffusion_expert"):
224
+ ckpt_id = "stabilityai/stable-diffusion-3.5-medium"
225
+
226
+ transformer_config = SD3Transformer2DModel.load_config(ckpt_id, subfolder="transformer")
227
+ vae_config = AutoencoderKL.load_config(ckpt_id, subfolder="vae")
228
+ text_encoder_config = CLIPTextConfig.from_pretrained(ckpt_id, subfolder="text_encoder", torch_dtype=config.torch_dtype)
229
+ text_encoder_2_config = CLIPTextConfig.from_pretrained(ckpt_id, subfolder="text_encoder_2", torch_dtype=config.torch_dtype)
230
+ text_encoder_3_config = T5Config.from_pretrained(ckpt_id, subfolder="text_encoder_3", torch_dtype=config.torch_dtype)
231
+
232
+ # Initialize components from their configurations
233
+ self.transformer = SD3Transformer2DModel.from_config(transformer_config)
234
+ self.vae = AutoencoderKL.from_config(vae_config)
235
+ self.text_encoder = CLIPTextModelWithProjection(text_encoder_config)
236
+ self.text_encoder_2 = CLIPTextModelWithProjection(text_encoder_2_config)
237
+ self.text_encoder_3 = T5EncoderModel(text_encoder_3_config)
238
+
239
+ # Initialize tokenizers (these don't use from_config as they are not models)
240
+ self.tokenizer = CLIPTokenizer.from_pretrained(ckpt_id, subfolder="tokenizer")
241
+ self.tokenizer_2 = CLIPTokenizer.from_pretrained(ckpt_id, subfolder="tokenizer_2")
242
+ self.tokenizer_3 = T5TokenizerFast.from_pretrained(ckpt_id, subfolder="tokenizer_3")
243
+ self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(ckpt_id, subfolder="scheduler")
244
+
245
+ # Create the pipeline configuration dictionary
246
+ pipeline_config = {
247
+ "transformer": self.transformer,
248
+ "scheduler": self.scheduler,
249
+ "vae": self.vae,
250
+ "text_encoder": self.text_encoder,
251
+ "text_encoder_2": self.text_encoder_2,
252
+ "text_encoder_3": self.text_encoder_3,
253
+ "tokenizer": self.tokenizer,
254
+ "tokenizer_2": self.tokenizer_2,
255
+ "tokenizer_3": self.tokenizer_3,
256
+ }
257
+
258
+ self.diffusion_expert = SD3Pipeline(**pipeline_config)
259
+
260
+
261
+ def initialize_diffusion_expert(self, fsdp=None):
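+ # Note: unlike the Kontext/Sana variants, this always reloads the pretrained SD3.5 pipeline, even if a diffusion expert already exists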
262
+
263
+ print("random initiation the diffusion expert !!!")
264
+ self.diffusion_expert = SD3Pipeline.from_pretrained("stabilityai/stable-diffusion-3.5-medium", revision="main", torch_dtype=torch.bfloat16)
265
+ self.text_encoder = self.diffusion_expert.text_encoder
266
+ self.text_encoder_model = self.diffusion_expert.text_encoder.text_model
267
+ self.text_encoder_2 = self.diffusion_expert.text_encoder_2
268
+ self.text_encoder_2_model = self.diffusion_expert.text_encoder_2.text_model
269
+ self.text_encoder_3 = self.diffusion_expert.text_encoder_3
270
+ self.tokenizer = self.diffusion_expert.tokenizer
271
+ self.tokenizer_2 = self.diffusion_expert.tokenizer_2
272
+ self.tokenizer_3 = self.diffusion_expert.tokenizer_3
273
+ self.vae = self.diffusion_expert.vae
274
+ self.transformer = self.diffusion_expert.transformer
275
+ self.scheduler = self.diffusion_expert.scheduler
276
+
277
+ self.config.diffusion_expert = "SD3"
278
+
279
+
280
+
281
+ class QwenSD3Config(Qwen2_5_VLConfig):
282
+ model_type = "QwenSD3"
283
+
284
+
285
+ class QwenSD3Model(QwenSD3MetaModel, Qwen2_5_VLModel):
286
+ config_class = QwenSD3Config
287
+
288
+ def __init__(self, config: Qwen2_5_VLConfig):
289
+ super(QwenSD3Model, self).__init__(config)
290
+
291
+
292
+ class QwenSD3ForInferenceLM(Qwen2_5_VLForConditionalGeneration):
293
+ config_class = QwenSD3Config
294
+
295
+ def __init__(self, config):
296
+ Qwen2_5_VLForConditionalGeneration.__init__(self, config)
297
+ config.model_type = "QwenSD3"
298
+
299
+ self.model = QwenSD3Model(config)
300
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
301
+ self.post_init()
302
+
303
+ def get_model(self):
304
+ return self.model
305
+
306
+
307
+
308
+ @torch.no_grad()
309
+ def generate_image(
310
+ self,
311
+ texts: List[str],
312
+ diffusion_kwargs: Optional[Dict] = dict(guidance_scale = 3.5, num_inference_steps=25),
313
+ sde_sampling: Optional[bool] = False,
314
+ ):
315
+
316
+ if isinstance(texts, str):
317
+ texts = [texts]
318
+
319
+ if not sde_sampling:
320
+ output_img = self.model.diffusion_expert(
321
+ texts,
322
+ max_sequence_length=512,
323
+ **diffusion_kwargs,
324
+ ).images
325
+ return output_img
326
+ else:
327
+ return self.model.diffusion_expert.sde_sampling(
328
+ texts,
329
+ max_sequence_length=512,
330
+ **diffusion_kwargs,
331
+ )
332
+
333
+
334
+ def extract_thinking_content(self, text: str) -> str:
335
+ pattern = r'<answer>(.*?)</answer>'
336
+ matches = re.findall(pattern, text, re.DOTALL)
337
+
338
+ if matches:
339
+ return matches[-1].strip().replace("<answer>", "").replace("</answer>", "")
340
+ else:
341
+ return text.strip().replace("<answer>", "").replace("</answer>", "")
342
+
343
+ @torch.no_grad()
344
+ def generate_image_cot(
345
+ self,
346
+ texts: List[str],
347
+ processor: Optional[object] = None,
348
+ diffusion_kwargs: Optional[Dict] = dict(guidance_scale = 3.5, num_inference_steps=25),
349
+ llm_kwargs: Optional[Dict] = dict(max_new_tokens=256, temperature=0.7, top_p=0.9, do_sample=True),
350
+ cot_prompt_template: Optional[str] = None,
351
+ ):
352
+
353
+ if isinstance(texts, str):
354
+ texts = [texts]
355
+
356
+ if cot_prompt_template is None:
357
+ # cot_prompt_template = """Please improve the following image generation prompt to make it more detailed and specific for better image quality. Think step by step about what visual elements would make this image more compelling. Original prompt: {original_prompt}. Please provide the improved prompt in <thinking> </thinking> tags."""
358
+ cot_prompt_template = """Please provide an enhanced prompt for the following image generation prompt to make the image more realistic, detailed, with clear separation and precise alignment of all entities.
359
+ Original prompt: {original_prompt}. Directly provide the improved prompt in <answer> </answer> tags."""
360
+
361
+ improved_prompts = []
362
+
363
+ for text in texts:
364
+ cot_input = cot_prompt_template.format(original_prompt=text)
365
+
366
+ messages = [{"role": "user", "content": cot_input}]
367
+ input_text_formatted = processor.apply_chat_template(
368
+ messages, tokenize=False, add_generation_prompt=True
369
+ )
370
+ model_inputs = processor(
371
+ text=[input_text_formatted],
372
+ return_tensors="pt"
373
+ ).to(self.device)
374
+
375
+ generated_ids = self.generate(
376
+ **model_inputs,
377
+ **llm_kwargs,
378
+ eos_token_id=processor.tokenizer.eos_token_id,
379
+ pad_token_id=processor.tokenizer.pad_token_id
380
+ )
381
+
382
+ generated_text = processor.batch_decode(
383
+ generated_ids[:, model_inputs['input_ids'].shape[1]:],
384
+ skip_special_tokens=True
385
+ )
386
+
387
+ improved_prompt = [self.extract_thinking_content(decode_text) for decode_text in generated_text]
388
+ improved_prompts.extend(improved_prompt)
389
+
390
+ print(f"Original prompt: {text}")
391
+ print(f"Improved prompt: {improved_prompt}")
392
+ print("-" * 50)
393
+
394
+ output_images = self.generate_image(improved_prompts, diffusion_kwargs)
395
+
396
+ return {
397
+ 'images': output_images,
398
+ 'original_prompts': texts,
399
+ 'improved_prompts': improved_prompts
400
+ }
401
+
402
+ AutoConfig.register("QwenSD3", QwenSD3Config)
403
+ AutoModelForCausalLM.register(QwenSD3Config, QwenSD3ForInferenceLM)
404
+
405
+
406
+ if __name__ == "__main__":
407
+ pass
408
+
409
+
410
+ model = QwenSD3ForInferenceLM.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct",torch_dtype=torch.bfloat16)
411
+ model.model.initialize_diffusion_expert()
412
+ model.model.diffusion_expert.to("cuda:0")
413
+ model.to("cuda:0")
414
+ processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
415
+ text = ["a photo of a cat"]
416
+ images = model.generate_image(text)
417
+ images[0].save("test_SD3.jpg")
418
+ outputs = model.generate_image_cot(text, processor = processor)
419
+ outputs['images'][0].save("test_SD3_cot.jpg")
420
+
421
+ model.save_pretrained("qwensd3")
422
+
423
+ # model = QwenSD3ForInferenceLM.from_pretrained("qwenSD3.0", torch_dtype=torch.bfloat16)
424
+ # model.to("cuda:0")
425
+ # model.save_pretrained("qwenSD3-test-2", torch_dtype=torch.bfloat16)
426
+
427
+ # model = QwenSD3ForInferenceLM.from_pretrained("qwenSD3-test", torch_dtype=torch.float16)
428
+ # # model.to("cuda:0")
429
+ # for n, p in model.named_parameters():
430
+ # if not p.dtype == torch.float16:
431
+ # print(n)
432
+ # processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
433
+ # text = ["a photo of a cat"]
434
+ # diffusion_kwargs = dict(guidance_scale = 5., num_inference_steps=20, width = 512, height = 512, generator = torch.manual_seed(0))
435
+ # images = model.generate_image(text, diffusion_kwargs=diffusion_kwargs)
436
+ # images[0].save("test_SD3.jpg")
437
+
438
+ # llm_kwargs = dict(max_new_tokens=256, temperature=0.7, top_p=0.9, do_sample=True, num_return_sequences=8)
439
+ # # generation_config = GenerationConfig.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", trust_remote_code=True)
440
+ # # generation_config.num_return_sequences = 8
441
+ # # print(generation_config)
442
+ # # llm_kwargs = dict(max_new_tokens=256, temperature=0.7, top_p=0.9, do_sample=True, generation_config=generation_config)
443
+
444
+ # outputs = model.generate_image_cot(text, processor = processor, llm_kwargs = llm_kwargs)
445
+ # # save_grid_image("cat", images['images'], 2, 2)
446
+ # for idx, image in enumerate(outputs['images']):
447
+ # image.save(f"test_SD3_cot_{idx}.jpg")
unimodel/qwensd3/sd3pipeline.py ADDED
@@ -0,0 +1,1162 @@
1
+ # Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
2
+ # Copyright 2025 Fu-Yun Wang
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ from typing import Any, Callable, Dict, List, Optional, Union
18
+
19
+ import torch
20
+ from transformers import (
21
+ CLIPTextModelWithProjection,
22
+ CLIPTokenizer,
23
+ SiglipImageProcessor,
24
+ SiglipVisionModel,
25
+ T5EncoderModel,
26
+ T5TokenizerFast,
27
+ )
28
+
29
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
30
+ from diffusers.loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
31
+ from diffusers.models.autoencoders import AutoencoderKL
32
+ from diffusers.models.transformers import SD3Transformer2DModel
33
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
34
+ from diffusers.utils import (
35
+ USE_PEFT_BACKEND,
36
+ is_torch_xla_available,
37
+ logging,
38
+ replace_example_docstring,
39
+ scale_lora_layers,
40
+ unscale_lora_layers,
41
+ )
42
+ from diffusers.utils.torch_utils import randn_tensor
43
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
44
+ from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
45
+ import deepspeed
46
+ from PIL import Image
47
+ import numpy as np
48
+ if is_torch_xla_available():
49
+ import torch_xla.core.xla_model as xm
50
+
51
+ XLA_AVAILABLE = True
52
+ else:
53
+ XLA_AVAILABLE = False
54
+
55
+
56
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
57
+
58
+ EXAMPLE_DOC_STRING = """
59
+ Examples:
60
+ ```py
61
+ >>> import torch
62
+ >>> from diffusers import StableDiffusion3Pipeline
63
+
64
+ >>> pipe = StableDiffusion3Pipeline.from_pretrained(
65
+ ... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
66
+ ... )
67
+ >>> pipe.to("cuda")
68
+ >>> prompt = "A cat holding a sign that says hello world"
69
+ >>> image = pipe(prompt).images[0]
70
+ >>> image.save("sd3.png")
71
+ ```
72
+ """
73
+
74
+
75
+ # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
76
+ def calculate_shift(
77
+ image_seq_len,
78
+ base_seq_len: int = 256,
79
+ max_seq_len: int = 4096,
80
+ base_shift: float = 0.5,
81
+ max_shift: float = 1.15,
82
+ ):
83
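+ # Linearly interpolate the flow-shift parameter mu between base_shift and max_shift as a function of the image token sequence length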
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
84
+ b = base_shift - m * base_seq_len
85
+ mu = image_seq_len * m + b
86
+ return mu
87
+
88
+
89
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
90
+ def retrieve_timesteps(
91
+ scheduler,
92
+ num_inference_steps: Optional[int] = None,
93
+ device: Optional[Union[str, torch.device]] = None,
94
+ timesteps: Optional[List[int]] = None,
95
+ sigmas: Optional[List[float]] = None,
96
+ **kwargs,
97
+ ):
98
+ r"""
99
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
100
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
101
+
102
+ Args:
103
+ scheduler (`SchedulerMixin`):
104
+ The scheduler to get timesteps from.
105
+ num_inference_steps (`int`):
106
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
107
+ must be `None`.
108
+ device (`str` or `torch.device`, *optional*):
109
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
110
+ timesteps (`List[int]`, *optional*):
111
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
112
+ `num_inference_steps` and `sigmas` must be `None`.
113
+ sigmas (`List[float]`, *optional*):
114
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
115
+ `num_inference_steps` and `timesteps` must be `None`.
116
+
117
+ Returns:
118
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
119
+ second element is the number of inference steps.
120
+ """
121
+ if timesteps is not None and sigmas is not None:
122
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
123
+ if timesteps is not None:
124
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
125
+ if not accepts_timesteps:
126
+ raise ValueError(
127
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
128
+ f" timestep schedules. Please check whether you are using the correct scheduler."
129
+ )
130
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
131
+ timesteps = scheduler.timesteps
132
+ num_inference_steps = len(timesteps)
133
+ elif sigmas is not None:
134
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
135
+ if not accept_sigmas:
136
+ raise ValueError(
137
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
138
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
139
+ )
140
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
141
+ timesteps = scheduler.timesteps
142
+ num_inference_steps = len(timesteps)
143
+ else:
144
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
145
+ timesteps = scheduler.timesteps
146
+ return timesteps, num_inference_steps
147
+
148
+
149
+ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
150
+ r"""
151
+ Args:
152
+ transformer ([`SD3Transformer2DModel`]):
153
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
154
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
155
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
156
+ vae ([`AutoencoderKL`]):
157
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
158
+ text_encoder ([`CLIPTextModelWithProjection`]):
159
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
160
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
161
+ with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
162
+ as its dimension.
163
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
164
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
165
+ specifically the
166
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
167
+ variant.
168
+ text_encoder_3 ([`T5EncoderModel`]):
169
+ Frozen text-encoder. Stable Diffusion 3 uses
170
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
171
+ [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
172
+ tokenizer (`CLIPTokenizer`):
173
+ Tokenizer of class
174
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
175
+ tokenizer_2 (`CLIPTokenizer`):
176
+ Second Tokenizer of class
177
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
178
+ tokenizer_3 (`T5TokenizerFast`):
179
+ Tokenizer of class
180
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
181
+ image_encoder (`SiglipVisionModel`, *optional*):
182
+ Pre-trained Vision Model for IP Adapter.
183
+ feature_extractor (`SiglipImageProcessor`, *optional*):
184
+ Image processor for IP Adapter.
185
+ """
186
+
187
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
188
+ _optional_components = ["image_encoder", "feature_extractor"]
189
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
190
+
191
+ def __init__(
192
+ self,
193
+ transformer: SD3Transformer2DModel,
194
+ scheduler: FlowMatchEulerDiscreteScheduler,
195
+ vae: AutoencoderKL,
196
+ text_encoder: CLIPTextModelWithProjection,
197
+ tokenizer: CLIPTokenizer,
198
+ text_encoder_2: CLIPTextModelWithProjection,
199
+ tokenizer_2: CLIPTokenizer,
200
+ text_encoder_3: T5EncoderModel,
201
+ tokenizer_3: T5TokenizerFast,
202
+ image_encoder: SiglipVisionModel = None,
203
+ feature_extractor: SiglipImageProcessor = None,
204
+ ):
205
+ super().__init__()
206
+
207
+ self.register_modules(
208
+ vae=vae,
209
+ text_encoder=text_encoder,
210
+ text_encoder_2=text_encoder_2,
211
+ text_encoder_3=text_encoder_3,
212
+ tokenizer=tokenizer,
213
+ tokenizer_2=tokenizer_2,
214
+ tokenizer_3=tokenizer_3,
215
+ transformer=transformer,
216
+ scheduler=scheduler,
217
+ image_encoder=image_encoder,
218
+ feature_extractor=feature_extractor,
219
+ )
220
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
221
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
222
+ self.tokenizer_max_length = (
223
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
224
+ )
225
+ self.default_sample_size = (
226
+ self.transformer.config.sample_size
227
+ if hasattr(self, "transformer") and self.transformer is not None
228
+ else 128
229
+ )
230
+ self.patch_size = (
231
+ self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
232
+ )
233
+
234
+ def _get_t5_prompt_embeds(
235
+ self,
236
+ prompt: Union[str, List[str]] = None,
237
+ num_images_per_prompt: int = 1,
238
+ max_sequence_length: int = 256,
239
+ device: Optional[torch.device] = None,
240
+ dtype: Optional[torch.dtype] = None,
241
+ ):
242
+ device = device or self._execution_device
243
+ dtype = dtype or self.text_encoder.dtype
244
+
245
+ prompt = [prompt] if isinstance(prompt, str) else prompt
246
+ batch_size = len(prompt)
247
+
248
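+ # Without a T5 text encoder, return zero embeddings with the expected sequence length and joint attention width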
+ if self.text_encoder_3 is None:
249
+ return torch.zeros(
250
+ (
251
+ batch_size * num_images_per_prompt,
252
+ self.tokenizer_max_length,
253
+ self.transformer.config.joint_attention_dim,
254
+ ),
255
+ device=device,
256
+ dtype=dtype,
257
+ )
258
+
259
+ text_inputs = self.tokenizer_3(
260
+ prompt,
261
+ padding="max_length",
262
+ max_length=max_sequence_length,
263
+ truncation=True,
264
+ add_special_tokens=True,
265
+ return_tensors="pt",
266
+ )
267
+ text_input_ids = text_inputs.input_ids
268
+ untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
269
+
270
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
271
+ removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
272
+ logger.warning(
273
+ "The following part of your input was truncated because `max_sequence_length` is set to "
274
+ f" {max_sequence_length} tokens: {removed_text}"
275
+ )
276
+
277
+ prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
278
+
279
+ dtype = self.text_encoder_3.dtype
280
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
281
+
282
+ _, seq_len, _ = prompt_embeds.shape
283
+
284
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
285
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
286
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
287
+
288
+ return prompt_embeds
289
+
290
+ def _get_clip_prompt_embeds(
291
+ self,
292
+ prompt: Union[str, List[str]],
293
+ num_images_per_prompt: int = 1,
294
+ device: Optional[torch.device] = None,
295
+ clip_skip: Optional[int] = None,
296
+ clip_model_index: int = 0,
297
+ ):
298
+ device = device or self._execution_device
299
+
300
+ clip_tokenizers = [self.tokenizer, self.tokenizer_2]
301
+ clip_text_encoders = [self.text_encoder, self.text_encoder_2]
302
+
303
+ tokenizer = clip_tokenizers[clip_model_index]
304
+ text_encoder = clip_text_encoders[clip_model_index]
305
+
306
+ prompt = [prompt] if isinstance(prompt, str) else prompt
307
+ batch_size = len(prompt)
308
+
309
+ text_inputs = tokenizer(
310
+ prompt,
311
+ padding="max_length",
312
+ max_length=self.tokenizer_max_length,
313
+ truncation=True,
314
+ return_tensors="pt",
315
+ )
316
+
317
+ text_input_ids = text_inputs.input_ids
318
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
319
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
320
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
321
+ logger.warning(
322
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
323
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
324
+ )
325
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
326
+ pooled_prompt_embeds = prompt_embeds[0]
327
+
328
+ if clip_skip is None:
329
+ prompt_embeds = prompt_embeds.hidden_states[-2]
330
+ else:
331
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
332
+
333
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
334
+
335
+ _, seq_len, _ = prompt_embeds.shape
336
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
337
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
338
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
339
+
340
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
341
+ pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
342
+
343
+ return prompt_embeds, pooled_prompt_embeds
344
+
345
+ def encode_prompt(
346
+ self,
347
+ prompt: Union[str, List[str]],
348
+ prompt_2: Union[str, List[str]],
349
+ prompt_3: Union[str, List[str]],
350
+ device: Optional[torch.device] = None,
351
+ num_images_per_prompt: int = 1,
352
+ do_classifier_free_guidance: bool = True,
353
+ negative_prompt: Optional[Union[str, List[str]]] = None,
354
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
355
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
356
+ prompt_embeds: Optional[torch.FloatTensor] = None,
357
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
358
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
359
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
360
+ clip_skip: Optional[int] = None,
361
+ max_sequence_length: int = 256,
362
+ lora_scale: Optional[float] = None,
363
+ ):
364
+ r"""
365
+
366
+ Args:
367
+ prompt (`str` or `List[str]`, *optional*):
368
+ prompt to be encoded
369
+ prompt_2 (`str` or `List[str]`, *optional*):
370
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
371
+ used in all text-encoders
372
+ prompt_3 (`str` or `List[str]`, *optional*):
373
+ The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
374
+ used in all text-encoders
375
+ device: (`torch.device`):
376
+ torch device
377
+ num_images_per_prompt (`int`):
378
+ number of images that should be generated per prompt
379
+ do_classifier_free_guidance (`bool`):
380
+ whether to use classifier free guidance or not
381
+ negative_prompt (`str` or `List[str]`, *optional*):
382
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
383
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
384
+ less than `1`).
385
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
386
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
387
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
388
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
389
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
390
+ `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.
391
+ prompt_embeds (`torch.FloatTensor`, *optional*):
392
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
393
+ provided, text embeddings will be generated from `prompt` input argument.
394
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
395
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
396
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
397
+ argument.
398
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
399
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
400
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
401
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
402
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
403
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
404
+ input argument.
405
+ clip_skip (`int`, *optional*):
406
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
407
+ the output of the pre-final layer will be used for computing the prompt embeddings.
408
+ lora_scale (`float`, *optional*):
409
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
410
+ """
411
+ device = device or self._execution_device
412
+
413
+ # set lora scale so that monkey patched LoRA
414
+ # function of text encoder can correctly access it
415
+ if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
416
+ self._lora_scale = lora_scale
417
+
418
+ # dynamically adjust the LoRA scale
419
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
420
+ scale_lora_layers(self.text_encoder, lora_scale)
421
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
422
+ scale_lora_layers(self.text_encoder_2, lora_scale)
423
+
424
+ prompt = [prompt] if isinstance(prompt, str) else prompt
425
+ if prompt is not None:
426
+ batch_size = len(prompt)
427
+ else:
428
+ batch_size = prompt_embeds.shape[0]
429
+
430
+
431
+ if prompt_embeds is None:
432
+ prompt_2 = prompt_2 or prompt
433
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
434
+
435
+ prompt_3 = prompt_3 or prompt
436
+ prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
437
+
438
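+ # SD3 builds its conditioning from two CLIP text encoders (token + pooled embeddings) and one T5 encoder (token embeddings only).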
+ prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
439
+ prompt=prompt,
440
+ device=device,
441
+ num_images_per_prompt=num_images_per_prompt,
442
+ clip_skip=clip_skip,
443
+ clip_model_index=0,
444
+ )
445
+
446
+ prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
447
+ prompt=prompt_2,
448
+ device=device,
449
+ num_images_per_prompt=num_images_per_prompt,
450
+ clip_skip=clip_skip,
451
+ clip_model_index=1,
452
+ )
453
+
454
+ clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
455
+
456
+ t5_prompt_embed = self._get_t5_prompt_embeds(
457
+ prompt=prompt_3,
458
+ num_images_per_prompt=num_images_per_prompt,
459
+ max_sequence_length=max_sequence_length,
460
+ device=device,
461
+ )
463
+
464
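+ # Zero-pad the concatenated CLIP token embeddings along the feature dim to match the T5 hidden size, then stack CLIP and T5 tokens along the sequence dim.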
+ clip_prompt_embeds = torch.nn.functional.pad(
465
+ clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
466
+ )
467
+
468
+ prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
469
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
470
+
471
+
472
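+ # When classifier-free guidance is enabled and no negative embeddings were supplied, encode the negative prompts the same way.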
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
473
+ negative_prompt = negative_prompt or ""
474
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
475
+ negative_prompt_3 = negative_prompt_3 or negative_prompt
476
+
477
+ # normalize str to list
478
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
479
+ negative_prompt_2 = (
480
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
481
+ )
482
+ negative_prompt_3 = (
483
+ batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
484
+ )
485
+
486
+ if prompt is not None and type(prompt) is not type(negative_prompt):
487
+ raise TypeError(
488
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
489
+ f" {type(prompt)}."
490
+ )
491
+ elif batch_size != len(negative_prompt):
492
+ raise ValueError(
493
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
494
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
495
+ " the batch size of `prompt`."
496
+ )
497
+ negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
514
+ negative_prompt,
515
+ device=device,
516
+ num_images_per_prompt=num_images_per_prompt,
517
+ clip_skip=None,
518
+ clip_model_index=0,
519
+ )
520
+ negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
521
+ negative_prompt_2,
522
+ device=device,
523
+ num_images_per_prompt=num_images_per_prompt,
524
+ clip_skip=None,
525
+ clip_model_index=1,
526
+ )
527
+ negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
528
+
529
+ t5_negative_prompt_embed = self._get_t5_prompt_embeds(
530
+ prompt=negative_prompt_3,
531
+ num_images_per_prompt=num_images_per_prompt,
532
+ max_sequence_length=max_sequence_length,
533
+ device=device,
534
+ )
535
+
536
+ negative_clip_prompt_embeds = torch.nn.functional.pad(
537
+ negative_clip_prompt_embeds,
538
+ (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
539
+ )
540
+
541
+ negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
542
+ negative_pooled_prompt_embeds = torch.cat(
543
+ [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
544
+ )
545
+
546
+ if self.text_encoder is not None:
547
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
548
+ # Retrieve the original scale by scaling back the LoRA layers
549
+ unscale_lora_layers(self.text_encoder, lora_scale)
550
+
551
+ if self.text_encoder_2 is not None:
552
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
553
+ # Retrieve the original scale by scaling back the LoRA layers
554
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
555
+
556
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
557
+
558
+ def check_inputs(
559
+ self,
560
+ prompt,
561
+ prompt_2,
562
+ prompt_3,
563
+ height,
564
+ width,
565
+ negative_prompt=None,
566
+ negative_prompt_2=None,
567
+ negative_prompt_3=None,
568
+ prompt_embeds=None,
569
+ negative_prompt_embeds=None,
570
+ pooled_prompt_embeds=None,
571
+ negative_pooled_prompt_embeds=None,
572
+ callback_on_step_end_tensor_inputs=None,
573
+ max_sequence_length=None,
574
+ ):
575
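+ # The latent grid must tile evenly: height and width have to be divisible by vae_scale_factor * patch_size.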
+ if (
576
+ height % (self.vae_scale_factor * self.patch_size) != 0
577
+ or width % (self.vae_scale_factor * self.patch_size) != 0
578
+ ):
579
+ raise ValueError(
580
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}."
581
+ f" You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}."
582
+ )
583
+
584
+ if callback_on_step_end_tensor_inputs is not None and not all(
585
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
586
+ ):
587
+ raise ValueError(
588
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
589
+ )
590
+
591
+ if prompt is not None and prompt_embeds is not None:
592
+ raise ValueError(
593
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
594
+ " only forward one of the two."
595
+ )
596
+ elif prompt_2 is not None and prompt_embeds is not None:
597
+ raise ValueError(
598
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
599
+ " only forward one of the two."
600
+ )
601
+ elif prompt_3 is not None and prompt_embeds is not None:
602
+ raise ValueError(
603
+ f"Cannot forward both `prompt_3`: {prompt_3} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
604
+ " only forward one of the two."
605
+ )
606
+ elif prompt is None and prompt_embeds is None:
607
+ raise ValueError(
608
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
609
+ )
610
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
611
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
612
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
613
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
614
+ elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
615
+ raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
616
+
617
+ if negative_prompt is not None and negative_prompt_embeds is not None:
618
+ raise ValueError(
619
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
620
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
621
+ )
622
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
623
+ raise ValueError(
624
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
625
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
626
+ )
627
+ elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
628
+ raise ValueError(
629
+ f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
630
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
631
+ )
632
+
633
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
634
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
635
+ raise ValueError(
636
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
637
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
638
+ f" {negative_prompt_embeds.shape}."
639
+ )
640
+
641
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
642
+ raise ValueError(
643
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
644
+ )
645
+
646
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
647
+ raise ValueError(
648
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
649
+ )
650
+
651
+ if max_sequence_length is not None and max_sequence_length > 512:
652
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
653
+
654
+ def prepare_latents(
655
+ self,
656
+ batch_size,
657
+ num_channels_latents,
658
+ height,
659
+ width,
660
+ dtype,
661
+ device,
662
+ generator,
663
+ latents=None,
664
+ ):
665
+ if latents is not None:
666
+ return latents.to(device=device, dtype=dtype)
667
+
668
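+ # Latent spatial dims are the requested pixel dims divided by the VAE downsampling factor.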
+ shape = (
669
+ batch_size,
670
+ num_channels_latents,
671
+ int(height) // self.vae_scale_factor,
672
+ int(width) // self.vae_scale_factor,
673
+ )
674
+
675
+ if isinstance(generator, list) and len(generator) != batch_size:
676
+ raise ValueError(
677
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
678
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
679
+ )
680
+
681
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
682
+
683
+ return latents
684
+
685
+ @property
686
+ def guidance_scale(self):
687
+ return self._guidance_scale
688
+
689
+ @property
690
+ def skip_guidance_layers(self):
691
+ return self._skip_guidance_layers
692
+
693
+ @property
694
+ def clip_skip(self):
695
+ return self._clip_skip
696
+
697
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
698
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
699
+ # corresponds to doing no classifier free guidance.
700
+ @property
701
+ def do_classifier_free_guidance(self):
702
+ return self._guidance_scale > 1
703
+
704
+ @property
705
+ def joint_attention_kwargs(self):
706
+ return self._joint_attention_kwargs
707
+
708
+ @property
709
+ def num_timesteps(self):
710
+ return self._num_timesteps
711
+
712
+ @property
713
+ def interrupt(self):
714
+ return self._interrupt
715
+
716
+ def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
717
+ """Encodes the given image into a feature representation using a pre-trained image encoder.
718
+
719
+ Args:
720
+ image (`PipelineImageInput`):
721
+ Input image to be encoded.
722
+ device: (`torch.device`):
723
+ Torch device.
724
+
725
+ Returns:
726
+ `torch.Tensor`: The encoded image feature representation.
727
+ """
728
+ if not isinstance(image, torch.Tensor):
729
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
730
+
731
+ image = image.to(device=device, dtype=self.dtype)
732
+
733
+ return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
734
+
735
+ def prepare_ip_adapter_image_embeds(
736
+ self,
737
+ ip_adapter_image: Optional[PipelineImageInput] = None,
738
+ ip_adapter_image_embeds: Optional[torch.Tensor] = None,
739
+ device: Optional[torch.device] = None,
740
+ num_images_per_prompt: int = 1,
741
+ do_classifier_free_guidance: bool = True,
742
+ ) -> torch.Tensor:
743
+ """Prepares image embeddings for use in the IP-Adapter.
744
+
745
+ Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
746
+
747
+ Args:
748
+ ip_adapter_image (`PipelineImageInput`, *optional*):
749
+ The input image to extract features from for IP-Adapter.
750
+ ip_adapter_image_embeds (`torch.Tensor`, *optional*):
751
+ Precomputed image embeddings.
752
+ device: (`torch.device`, *optional*):
753
+ Torch device.
754
+ num_images_per_prompt (`int`, defaults to 1):
755
+ Number of images that should be generated per prompt.
756
+ do_classifier_free_guidance (`bool`, defaults to True):
757
+ Whether to use classifier free guidance or not.
758
+ """
759
+ device = device or self._execution_device
760
+
761
+ if ip_adapter_image_embeds is not None:
762
+ if do_classifier_free_guidance:
763
+ single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
764
+ else:
765
+ single_image_embeds = ip_adapter_image_embeds
766
+ elif ip_adapter_image is not None:
767
+ single_image_embeds = self.encode_image(ip_adapter_image, device)
768
+ if do_classifier_free_guidance:
769
+ single_negative_image_embeds = torch.zeros_like(single_image_embeds)
770
+ else:
771
+ raise ValueError("Neither `ip_adapter_image` nor `ip_adapter_image_embeds` was provided.")
772
+
773
+ image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
774
+
775
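+ # For CFG, prepend the negative (zeroed) image embeds so they line up with the unconditional half of the prompt batch.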
+ if do_classifier_free_guidance:
776
+ negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
777
+ image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
778
+
779
+ return image_embeds.to(device=device)
780
+
781
+ def enable_sequential_cpu_offload(self, *args, **kwargs):
782
+ if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
783
+ logger.warning(
784
+ "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
785
+ "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
786
+ "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
787
+ )
788
+
789
+ super().enable_sequential_cpu_offload(*args, **kwargs)
790
+
791
+ @torch.no_grad()
792
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
793
+ def __call__(
794
+ self,
795
+ prompt: Union[str, List[str]] = None,
796
+ prompt_2: Optional[Union[str, List[str]]] = None,
797
+ prompt_3: Optional[Union[str, List[str]]] = None,
798
+ height: Optional[int] = None,
799
+ width: Optional[int] = None,
800
+ num_inference_steps: int = 28,
801
+ sigmas: Optional[List[float]] = None,
802
+ guidance_scale: float = 7.0,
803
+ negative_prompt: Optional[Union[str, List[str]]] = None,
804
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
805
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
806
+ num_images_per_prompt: Optional[int] = 1,
807
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
808
+ latents: Optional[torch.FloatTensor] = None,
809
+ prompt_embeds: Optional[torch.FloatTensor] = None,
810
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
811
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
812
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
813
+ ip_adapter_image: Optional[PipelineImageInput] = None,
814
+ ip_adapter_image_embeds: Optional[torch.Tensor] = None,
815
+ output_type: Optional[str] = "pil",
816
+ return_dict: bool = True,
817
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
818
+ clip_skip: Optional[int] = None,
819
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
820
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
821
+ max_sequence_length: int = 256,
822
+ skip_guidance_layers: List[int] = None,
823
+ skip_layer_guidance_scale: float = 2.8,
824
+ skip_layer_guidance_stop: float = 0.2,
825
+ skip_layer_guidance_start: float = 0.01,
826
+ mu: Optional[float] = None,
827
+ ):
828
+ r"""
829
+ Function invoked when calling the pipeline for generation.
830
+
831
+ Args:
832
+ prompt (`str` or `List[str]`, *optional*):
833
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
834
+ instead.
835
+ prompt_2 (`str` or `List[str]`, *optional*):
836
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
837
+ will be used instead.
838
+ prompt_3 (`str` or `List[str]`, *optional*):
839
+ The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt`
840
+ will be used instead.
841
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
842
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
843
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
844
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
845
+ num_inference_steps (`int`, *optional*, defaults to 28):
846
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
847
+ expense of slower inference.
848
+ sigmas (`List[float]`, *optional*):
849
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
850
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
851
+ will be used.
852
+ guidance_scale (`float`, *optional*, defaults to 7.0):
853
+ Guidance scale as defined in [Classifier-Free Diffusion
854
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
855
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
856
+ `guidance_scale > 1`. Higher guidance scale encourages the model to generate images that are closely linked to
857
+ the text `prompt`, usually at the expense of lower image quality.
858
+ negative_prompt (`str` or `List[str]`, *optional*):
859
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
860
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
861
+ less than `1`).
862
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
863
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
864
+ `text_encoder_2`. If not defined, `negative_prompt` is used instead
865
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
866
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
867
+ `text_encoder_3`. If not defined, `negative_prompt` is used instead
868
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
869
+ The number of images to generate per prompt.
870
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
871
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
872
+ to make generation deterministic.
873
+ latents (`torch.FloatTensor`, *optional*):
874
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
875
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
876
+ tensor will be generated by sampling using the supplied random `generator`.
877
+ prompt_embeds (`torch.FloatTensor`, *optional*):
878
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
879
+ provided, text embeddings will be generated from `prompt` input argument.
880
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
881
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
882
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
883
+ argument.
884
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
885
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
886
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
887
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
888
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
889
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
890
+ input argument.
891
+ ip_adapter_image (`PipelineImageInput`, *optional*):
892
+ Optional image input to work with IP Adapters.
893
+ ip_adapter_image_embeds (`torch.Tensor`, *optional*):
894
+ Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
895
+ emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
896
+ `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
897
+ output_type (`str`, *optional*, defaults to `"pil"`):
898
+ The output format of the generated image. Choose between
899
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
900
+ return_dict (`bool`, *optional*, defaults to `True`):
901
+ Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
902
+ a plain tuple.
903
+ joint_attention_kwargs (`dict`, *optional*):
904
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
905
+ `self.processor` in
906
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
907
+ callback_on_step_end (`Callable`, *optional*):
908
+ A function that is called at the end of each denoising step during inference. The function is called
909
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
910
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
911
+ `callback_on_step_end_tensor_inputs`.
912
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
913
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
914
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
915
+ `._callback_tensor_inputs` attribute of your pipeline class.
916
+ max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
917
+ skip_guidance_layers (`List[int]`, *optional*):
918
+ A list of integers that specify layers to skip during guidance. If not provided, all layers will be
919
+ used for guidance. If provided, the guidance will only be applied to the layers specified in the list.
920
+ Recommended value by StabilityAI for Stable Diffusion 3.5 Medium is [7, 8, 9].
921
+ skip_layer_guidance_scale (`float`, *optional*): The scale of the guidance for the layers specified in
922
+ `skip_guidance_layers`. The guidance will be applied to the layers specified in `skip_guidance_layers`
923
+ with a scale of `skip_layer_guidance_scale`. The guidance will be applied to the rest of the layers
924
+ with a scale of `1`.
925
+ skip_layer_guidance_stop (`float`, *optional*): The step at which the guidance for the layers specified in
926
+ `skip_guidance_layers` will stop. The guidance will be applied to the layers specified in
927
+ `skip_guidance_layers` until the fraction specified in `skip_layer_guidance_stop`. Recommended value by
928
+ StabilityAI for Stable Diffusion 3.5 Medium is 0.2.
929
+ skip_layer_guidance_start (`float`, *optional*): The step at which the guidance for the layers specified in
930
+ `skip_guidance_layers` will start. The guidance will be applied to the layers specified in
931
+ `skip_guidance_layers` from the fraction specified in `skip_layer_guidance_start`. Recommended value by
932
+ StabilityAI for Stable Diffusion 3.5 Medium is 0.01.
933
+ mu (`float`, *optional*): `mu` value used for `dynamic_shifting`.
934
+
935
+ Examples:
936
+
937
+ Returns:
938
+ [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
939
+ [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
940
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
941
+ """
942
+
943
+ height = height or self.default_sample_size * self.vae_scale_factor
944
+ width = width or self.default_sample_size * self.vae_scale_factor
945
+
946
+ # 1. Check inputs. Raise error if not correct
947
+ self.check_inputs(
948
+ prompt,
949
+ prompt_2,
950
+ prompt_3,
951
+ height,
952
+ width,
953
+ negative_prompt=negative_prompt,
954
+ negative_prompt_2=negative_prompt_2,
955
+ negative_prompt_3=negative_prompt_3,
956
+ prompt_embeds=prompt_embeds,
957
+ negative_prompt_embeds=negative_prompt_embeds,
958
+ pooled_prompt_embeds=pooled_prompt_embeds,
959
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
960
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
961
+ max_sequence_length=max_sequence_length,
962
+ )
963
+
964
+ self._guidance_scale = guidance_scale
965
+ self._skip_layer_guidance_scale = skip_layer_guidance_scale
966
+ self._clip_skip = clip_skip
967
+ self._joint_attention_kwargs = joint_attention_kwargs
968
+ self._interrupt = False
969
+
970
+ # 2. Define call parameters
971
+ if prompt is not None and isinstance(prompt, str):
972
+ batch_size = 1
973
+ elif prompt is not None and isinstance(prompt, list):
974
+ batch_size = len(prompt)
975
+ else:
976
+ batch_size = prompt_embeds.shape[0]
977
+
978
+ device = self._execution_device
979
+
980
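+ # 3. Encode the input prompt(s) with the CLIP and T5 text encoders (applying any LoRA scale from joint_attention_kwargs).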
+ lora_scale = (
981
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
982
+ )
983
+ (
984
+ prompt_embeds,
985
+ negative_prompt_embeds,
986
+ pooled_prompt_embeds,
987
+ negative_pooled_prompt_embeds,
988
+ ) = self.encode_prompt(
989
+ prompt=prompt,
990
+ prompt_2=prompt_2,
991
+ prompt_3=prompt_3,
992
+ negative_prompt=negative_prompt,
993
+ negative_prompt_2=negative_prompt_2,
994
+ negative_prompt_3=negative_prompt_3,
995
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
996
+ prompt_embeds=prompt_embeds,
997
+ negative_prompt_embeds=negative_prompt_embeds,
998
+ pooled_prompt_embeds=pooled_prompt_embeds,
999
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1000
+ device=device,
1001
+ clip_skip=self.clip_skip,
1002
+ num_images_per_prompt=num_images_per_prompt,
1003
+ max_sequence_length=max_sequence_length,
1004
+ lora_scale=lora_scale,
1005
+ )
1006
+
1007
+ if self.do_classifier_free_guidance:
1008
+ if skip_guidance_layers is not None:
1009
+ original_prompt_embeds = prompt_embeds
1010
+ original_pooled_prompt_embeds = pooled_prompt_embeds
1011
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1012
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
1013
+
1014
+ # 4. Prepare latent variables
1015
+ num_channels_latents = self.transformer.config.in_channels
1016
+ latents = self.prepare_latents(
1017
+ batch_size * num_images_per_prompt,
1018
+ num_channels_latents,
1019
+ height,
1020
+ width,
1021
+ prompt_embeds.dtype,
1022
+ device,
1023
+ generator,
1024
+ latents,
1025
+ )
1026
+
1027
+ # 5. Prepare timesteps
1028
+ scheduler_kwargs = {}
1029
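+ # Resolution-dependent timestep shifting: when the scheduler uses dynamic shifting and no `mu` is given, derive it from the latent token count.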
+ if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
1030
+ _, _, height, width = latents.shape
1031
+ image_seq_len = (height // self.transformer.config.patch_size) * (
1032
+ width // self.transformer.config.patch_size
1033
+ )
1034
+ mu = calculate_shift(
1035
+ image_seq_len,
1036
+ self.scheduler.config.get("base_image_seq_len", 256),
1037
+ self.scheduler.config.get("max_image_seq_len", 4096),
1038
+ self.scheduler.config.get("base_shift", 0.5),
1039
+ self.scheduler.config.get("max_shift", 1.16),
1040
+ )
1041
+ scheduler_kwargs["mu"] = mu
1042
+ elif mu is not None:
1043
+ scheduler_kwargs["mu"] = mu
1044
+ timesteps, num_inference_steps = retrieve_timesteps(
1045
+ self.scheduler,
1046
+ num_inference_steps,
1047
+ device,
1048
+ sigmas=sigmas,
1049
+ **scheduler_kwargs,
1050
+ )
1051
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1052
+ self._num_timesteps = len(timesteps)
1053
+
1054
+ # 6. Prepare image embeddings
1055
+ if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
1056
+ ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
1057
+ ip_adapter_image,
1058
+ ip_adapter_image_embeds,
1059
+ device,
1060
+ batch_size * num_images_per_prompt,
1061
+ self.do_classifier_free_guidance,
1062
+ )
1063
+
1064
+ if self.joint_attention_kwargs is None:
1065
+ self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
1066
+ else:
1067
+ self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
1068
+
1069
+
1070
+ # 7. Denoising loop
1071
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1072
+ for i, t in enumerate(timesteps):
1073
+ if self.interrupt:
1074
+ continue
1075
+
1076
+ # expand the latents if we are doing classifier free guidance
1077
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1078
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1079
+ timestep = t.expand(latent_model_input.shape[0])
1080
+
1081
+ noise_pred = self.transformer(
1082
+ hidden_states=latent_model_input,
1083
+ timestep=timestep,
1084
+ encoder_hidden_states=prompt_embeds,
1085
+ pooled_projections=pooled_prompt_embeds,
1086
+ joint_attention_kwargs=self.joint_attention_kwargs,
1087
+ return_dict=False,
1088
+ )[0]
1089
+
1090
+ # perform guidance
1091
+ if self.do_classifier_free_guidance:
1092
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1093
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1094
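+ # Skip-layer guidance: within the configured [start, stop) step window, re-run the transformer with the specified layers skipped and push the prediction away from that output.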
+ should_skip_layers = (
1095
+ True
1096
+ if i > num_inference_steps * skip_layer_guidance_start
1097
+ and i < num_inference_steps * skip_layer_guidance_stop
1098
+ else False
1099
+ )
1100
+ if skip_guidance_layers is not None and should_skip_layers:
1101
+ timestep = t.expand(latents.shape[0])
1102
+ latent_model_input = latents
1103
+ noise_pred_skip_layers = self.transformer(
1104
+ hidden_states=latent_model_input,
1105
+ timestep=timestep,
1106
+ encoder_hidden_states=original_prompt_embeds,
1107
+ pooled_projections=original_pooled_prompt_embeds,
1108
+ joint_attention_kwargs=self.joint_attention_kwargs,
1109
+ return_dict=False,
1110
+ skip_layers=skip_guidance_layers,
1111
+ )[0]
1112
+ noise_pred = (
1113
+ noise_pred + (noise_pred_text - noise_pred_skip_layers) * self._skip_layer_guidance_scale
1114
+ )
1115
+
1116
+ latents_dtype = latents.dtype
1117
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
1118
+
1119
+ if latents.dtype != latents_dtype:
1120
+ if torch.backends.mps.is_available():
1121
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1122
+ latents = latents.to(latents_dtype)
1123
+
1124
+ if callback_on_step_end is not None:
1125
+ callback_kwargs = {}
1126
+ for k in callback_on_step_end_tensor_inputs:
1127
+ callback_kwargs[k] = locals()[k]
1128
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1129
+
1130
+ latents = callback_outputs.pop("latents", latents)
1131
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1132
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1133
+ negative_pooled_prompt_embeds = callback_outputs.pop(
1134
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1135
+ )
1136
+
1137
+ # call the callback, if provided
1138
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1139
+ progress_bar.update()
1140
+
1141
+ if XLA_AVAILABLE:
1142
+ xm.mark_step()
1143
+
1144
+ if output_type == "latent":
1145
+ image = latents
1146
+
1147
+ else:
1148
+
1151
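+ # Undo the latent normalization (scaling_factor / shift_factor) expected by the VAE before decoding to pixels.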
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1152
+
1153
+ image = self.vae.decode(latents, return_dict=False)[0]
1154
+ image = self.image_processor.postprocess(image, output_type=output_type)
1155
+
1156
+ # Offload all models
1157
+ self.maybe_free_model_hooks()
1158
+
1159
+ if not return_dict:
1160
+ return (image,)
1161
+
1162
+ return StableDiffusion3PipelineOutput(images=image)
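+
+ # Minimal usage sketch, not part of the pipeline itself. The class name
+ # `StableDiffusion3Pipeline` and the checkpoint path are assumptions; substitute
+ # whatever pipeline class this file actually defines and the local UniRL checkpoint.
+ # import torch
+ # pipe = StableDiffusion3Pipeline.from_pretrained("path/to/checkpoint", torch_dtype=torch.bfloat16).to("cuda")
+ # image = pipe("a photo of a corgi wearing sunglasses", num_inference_steps=28, guidance_scale=7.0).images[0]
+ # image.save("sample.png")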