Add UniRL inference code
Browse files- .gitattributes +3 -0
- .gitignore +149 -0
- LICENSE +201 -0
- README.md +155 -0
- assets/edit_comparison.png +3 -0
- assets/logo.png +3 -0
- assets/t2i_comparison.png +3 -0
- environment.yml +238 -0
- eval.py +367 -0
- gen.sh +48 -0
- prompts/config.yaml +35 -0
- prompts/draw_test.txt +1000 -0
- prompts/evaluation_metadata.jsonl +553 -0
- prompts/ocr_test.txt +0 -0
- requirements.txt +202 -0
- unified_inference.py +660 -0
- unimodel/qwenflux/fluxpipeline.py +1543 -0
- unimodel/qwenflux/qwenflux_inference.py +418 -0
- unimodel/qwenkontext/fluxkontext_pipeline.py +1161 -0
- unimodel/qwenkontext/qwenkontext_inference.py +442 -0
- unimodel/qwensana/qwensana_inference.py +310 -0
- unimodel/qwensd3/qwensd3_inference.py +447 -0
- unimodel/qwensd3/sd3pipeline.py +1162 -0
.gitattributes
CHANGED
|
@@ -37,3 +37,6 @@ promptrl_geneval/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
|
| 37 |
promptrl_ocr/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 38 |
promptrl_ps/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 39 |
promptrl_edit/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
promptrl_ocr/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 38 |
promptrl_ps/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 39 |
promptrl_edit/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
assets/edit_comparison.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
assets/logo.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
assets/t2i_comparison.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
*.pyc
|
| 3 |
+
__pycache__/
|
| 4 |
+
*.pyo
|
| 5 |
+
*.pyd
|
| 6 |
+
.Python
|
| 7 |
+
env/
|
| 8 |
+
venv/
|
| 9 |
+
ENV/
|
| 10 |
+
env.bak/
|
| 11 |
+
venv.bak/
|
| 12 |
+
|
| 13 |
+
# Byte-compiled / optimized / DLL files
|
| 14 |
+
*.py[cod]
|
| 15 |
+
*$py.class
|
| 16 |
+
|
| 17 |
+
# C extensions
|
| 18 |
+
*.so
|
| 19 |
+
|
| 20 |
+
# Distribution / packaging
|
| 21 |
+
.Python
|
| 22 |
+
build/
|
| 23 |
+
dist/
|
| 24 |
+
downloads/
|
| 25 |
+
eggs/
|
| 26 |
+
.eggs/
|
| 27 |
+
lib/
|
| 28 |
+
lib64/
|
| 29 |
+
parts/
|
| 30 |
+
sdist/
|
| 31 |
+
var/
|
| 32 |
+
wheels/
|
| 33 |
+
*.egg-info/
|
| 34 |
+
.installed.cfg
|
| 35 |
+
*.egg
|
| 36 |
+
|
| 37 |
+
# PyInstaller
|
| 38 |
+
*.manifest
|
| 39 |
+
*.spec
|
| 40 |
+
|
| 41 |
+
# Installer logs
|
| 42 |
+
pip-log.txt
|
| 43 |
+
pip-delete-this-directory.txt
|
| 44 |
+
|
| 45 |
+
# Unit test / coverage reports
|
| 46 |
+
htmlcov/
|
| 47 |
+
.tox/
|
| 48 |
+
.nox/
|
| 49 |
+
.coverage
|
| 50 |
+
.coverage.*
|
| 51 |
+
.cache
|
| 52 |
+
nosetests.xml
|
| 53 |
+
coverage.xml
|
| 54 |
+
*.cover
|
| 55 |
+
*.py,cover
|
| 56 |
+
.hypothesis/
|
| 57 |
+
.pytest_cache/
|
| 58 |
+
|
| 59 |
+
# Jupyter Notebook
|
| 60 |
+
.ipynb_checkpoints
|
| 61 |
+
|
| 62 |
+
# IPython
|
| 63 |
+
profile_default/
|
| 64 |
+
ipython_config.py
|
| 65 |
+
|
| 66 |
+
# pyenv
|
| 67 |
+
.python-version
|
| 68 |
+
|
| 69 |
+
# pipenv
|
| 70 |
+
Pipfile.lock
|
| 71 |
+
|
| 72 |
+
# Poetry
|
| 73 |
+
poetry.lock
|
| 74 |
+
|
| 75 |
+
# Virtualenv
|
| 76 |
+
.venv
|
| 77 |
+
venv/
|
| 78 |
+
ENV/
|
| 79 |
+
|
| 80 |
+
# Spyder project settings
|
| 81 |
+
.spyderproject
|
| 82 |
+
.spyproject
|
| 83 |
+
|
| 84 |
+
# Rope project settings
|
| 85 |
+
.ropeproject
|
| 86 |
+
|
| 87 |
+
# mkdocs documentation
|
| 88 |
+
/site
|
| 89 |
+
|
| 90 |
+
# mypy
|
| 91 |
+
.mypy_cache/
|
| 92 |
+
.dmypy.json
|
| 93 |
+
dmypy.json
|
| 94 |
+
|
| 95 |
+
# Pyre type checker
|
| 96 |
+
.pyre/
|
| 97 |
+
|
| 98 |
+
# IDEs and editors
|
| 99 |
+
.idea/
|
| 100 |
+
.vscode/
|
| 101 |
+
*.sublime-workspace
|
| 102 |
+
|
| 103 |
+
# OS generated files
|
| 104 |
+
.DS_Store
|
| 105 |
+
Thumbs.db
|
| 106 |
+
|
| 107 |
+
# Logs
|
| 108 |
+
*.log
|
| 109 |
+
logs/
|
| 110 |
+
*.log.*
|
| 111 |
+
|
| 112 |
+
# Dependency directories
|
| 113 |
+
node_modules/
|
| 114 |
+
bower_components/
|
| 115 |
+
|
| 116 |
+
# Optional: Local configuration files
|
| 117 |
+
*.local
|
| 118 |
+
*.env
|
| 119 |
+
.env
|
| 120 |
+
.env.local
|
| 121 |
+
.env.development.local
|
| 122 |
+
.env.test.local
|
| 123 |
+
.env.production.local
|
| 124 |
+
|
| 125 |
+
# Optional: Database
|
| 126 |
+
*.sqlite3
|
| 127 |
+
*.db
|
| 128 |
+
|
| 129 |
+
# Optional: Django
|
| 130 |
+
*.sqlite3
|
| 131 |
+
migrations/
|
| 132 |
+
*.mo
|
| 133 |
+
*.pot
|
| 134 |
+
staticfiles/
|
| 135 |
+
|
| 136 |
+
# Optional: Flask
|
| 137 |
+
instance/
|
| 138 |
+
.webassets-cache
|
| 139 |
+
|
| 140 |
+
# Optional: Scrapy
|
| 141 |
+
.scrapy
|
| 142 |
+
|
| 143 |
+
outputs/
|
| 144 |
+
|
| 145 |
+
wandb/
|
| 146 |
+
|
| 147 |
+
assets/large_rl_datasets/
|
| 148 |
+
|
| 149 |
+
utils/parquet_cache/
|
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [2025] [Fu-Yun Wang]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
README.md
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<p align="center">
|
| 2 |
+
<img src="assets/logo.png" width="30%"><br>
|
| 3 |
+
PromptRL
|
| 4 |
+
</p>
|
| 5 |
+
|
| 6 |
+
<p align="center">
|
| 7 |
+
<a href="https://arxiv.org/abs/2602.01382"><img src="https://img.shields.io/badge/arXiv-2602.01382-b31b1b.svg" alt="arXiv"></a>
|
| 8 |
+
<a href="https://g-u-n.github.io/projects/promptrl/"><img src="https://img.shields.io/badge/Project-Page-green.svg" alt="Project Page"></a>
|
| 9 |
+
<a href="https://huggingface.co/wangfuyun/PrompRL"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue" alt="HuggingFace"></a>
|
| 10 |
+
</p>
|
| 11 |
+
|
| 12 |
+
## Overview
|
| 13 |
+
|
| 14 |
+
**PromptRL** is a framework that jointly trains language models (LMs) and flow-matching models (FMs) within a unified reinforcement learning loop for text-to-image generation. By incorporating LMs as adaptive prompt refiners, PromptRL addresses two critical limitations in current flow-based RL pipelines: *exploration collapse* due to insufficient generation diversity, and *prompt overfitting* where models memorize specific training formulations.
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
## Installation
|
| 18 |
+
|
| 19 |
+
```bash
|
| 20 |
+
conda env create -f environment.yml
|
| 21 |
+
conda activate unirl
|
| 22 |
+
pip install git+https://github.com/openai/CLIP.git
|
| 23 |
+
pip install git+https://github.com/huggingface/diffusers.git
|
| 24 |
+
pip install flash-attn==2.7.4.post1 --no-build-isolation
|
| 25 |
+
|
| 26 |
+
# run gen.sh for evaluation
|
| 27 |
+
# bash gen.sh
|
| 28 |
+
```
|
| 29 |
+
|
| 30 |
+
## Qualitative Results
|
| 31 |
+
|
| 32 |
+
### Text-to-Image Generation
|
| 33 |
+
<p align="center">
|
| 34 |
+
<img src="assets/t2i_comparison.png" width="85%">
|
| 35 |
+
</p>
|
| 36 |
+
|
| 37 |
+
### Instructional Image Editing
|
| 38 |
+
<p align="center">
|
| 39 |
+
<img src="assets/edit_comparison.png" width="75%">
|
| 40 |
+
</p>
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
## Key Results
|
| 44 |
+
|
| 45 |
+
PromptRL achieves **2× sample efficiency** compared to flow-only RL while obtains a adaptative prompt refinement agent to improve test-time performance.
|
| 46 |
+
|
| 47 |
+
### Summary
|
| 48 |
+
|
| 49 |
+
| Benchmark | Metric | PromptRL w/ PE | Best Baseline |
|
| 50 |
+
|:---|:---|:---:|:---:|
|
| 51 |
+
| GenEval | Avg. Score ↑ | **0.97** | 0.92 (FlowGRPO) |
|
| 52 |
+
| Aesthetic | PickScore ↑ | **24.05** | 23.63 (DiffusionNFT) |
|
| 53 |
+
| Aesthetic | HPS ↑ | **32.03** | 31.79 (DiffusionNFT) |
|
| 54 |
+
| OCR | OCR-1k ↑ | **0.98** | 0.89 (FlowGRPO) |
|
| 55 |
+
| Image Editing | EditReward Avg. ↑ | **1.43** | 1.44 (ReasonEdit-Think) |
|
| 56 |
+
|
| 57 |
+
---
|
| 58 |
+
|
| 59 |
+
<details>
|
| 60 |
+
<summary><b>📊 GenEval Benchmark (Full Results)</b></summary>
|
| 61 |
+
|
| 62 |
+
<br>
|
| 63 |
+
|
| 64 |
+
| Model | 1 Obj. | 2 Obj. | Cnt. | Clr. | Pos. | Attr. | Avg.↑ |
|
| 65 |
+
|:---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|
| 66 |
+
| Show-o | 0.95 | 0.52 | 0.49 | 0.82 | 0.11 | 0.28 | 0.53 |
|
| 67 |
+
| Emu3-Gen | 0.98 | 0.71 | 0.34 | 0.81 | 0.17 | 0.21 | 0.54 |
|
| 68 |
+
| SD3 Medium | 0.98 | 0.74 | 0.63 | 0.67 | 0.34 | 0.36 | 0.62 |
|
| 69 |
+
| FLUX.1-dev | 0.98 | 0.81 | 0.74 | 0.79 | 0.22 | 0.45 | 0.66 |
|
| 70 |
+
| SD3.5 Large | 0.98 | 0.89 | 0.73 | 0.83 | 0.34 | 0.47 | 0.71 |
|
| 71 |
+
| JanusFlow | 0.97 | 0.59 | 0.45 | 0.83 | 0.53 | 0.42 | 0.63 |
|
| 72 |
+
| Janus-Pro-7B | 0.99 | 0.89 | 0.59 | 0.90 | 0.79 | 0.66 | 0.80 |
|
| 73 |
+
| HiDream | 1.00 | 0.98 | 0.79 | 0.91 | 0.60 | 0.72 | 0.83 |
|
| 74 |
+
| Seedream 3.0 | 0.99 | 0.96 | 0.91 | 0.93 | 0.47 | 0.80 | 0.84 |
|
| 75 |
+
| Qwen-Image | 0.99 | 0.92 | 0.89 | 0.88 | 0.76 | 0.77 | 0.87 |
|
| 76 |
+
| *RL-based* | | | | | | | |
|
| 77 |
+
| RePrompt | 0.98 | 0.87 | 0.77 | 0.85 | 0.62 | 0.49 | 0.76 |
|
| 78 |
+
| FlowGRPO | 1.00 | 0.99 | 0.91 | 0.89 | 0.95 | 0.80 | 0.92 |
|
| 79 |
+
| DiffusionNFT | 1.00 | 0.98 | 0.74 | 0.92 | 0.85 | 0.80 | 0.88 |
|
| 80 |
+
| PromptRL w/o PE | 1.00 | 0.96 | 0.95 | 0.95 | 0.93 | 0.85 | 0.94 |
|
| 81 |
+
| **PromptRL w/ PE** | **1.00** | **0.99** | **0.99** | **0.96** | **0.99** | **0.90** | **0.97** |
|
| 82 |
+
|
| 83 |
+
</details>
|
| 84 |
+
|
| 85 |
+
<details>
|
| 86 |
+
<summary><b>🎨 Aesthetic & OCR Metrics (Full Results)</b></summary>
|
| 87 |
+
|
| 88 |
+
<br>
|
| 89 |
+
|
| 90 |
+
| Model | P.S. | HPS | U.R. | OCR-1k | TMDB | OpenLib |
|
| 91 |
+
|:---|:---:|:---:|:---:|:---:|:---:|:---:|
|
| 92 |
+
| SD1.5 | 20.92 | 23.71 | 2.00 | 0.05 | 0.13 | 0.08 |
|
| 93 |
+
| SDXL | 22.14 | 26.67 | 2.78 | 0.13 | 0.20 | 0.09 |
|
| 94 |
+
| SD3 Medium | 22.38 | 28.56 | 3.09 | — | 0.44 | 0.33 |
|
| 95 |
+
| FLUX.1-schnell | 22.64 | 29.39 | 3.25 | 0.54 | 0.66 | 0.50 |
|
| 96 |
+
| FLUX.2-klein | 22.79 | 29.03 | 3.29 | 0.55 | 0.22 | 0.46 |
|
| 97 |
+
| Z-Image | 20.14 | 28.22 | 3.51 | 0.70 | 0.71 | 0.83 |
|
| 98 |
+
| Qwen-Image | 23.05 | 30.40 | 3.53 | 0.65 | 0.79 | 0.94 |
|
| 99 |
+
| Qwen-Image-2512 | 23.16 | 30.79 | 3.40 | 0.72 | 0.81 | 0.87 |
|
| 100 |
+
| *RL-based* | | | | | | |
|
| 101 |
+
| FlowGRPO | 23.33 | 29.80 | 3.33 | 0.89 | 0.83 | 0.73 |
|
| 102 |
+
| DiffusionNFT | 23.63 | 31.79 | 3.39 | 0.89 | 0.91 | 0.86 |
|
| 103 |
+
| PromptRL w/o PE | 24.01 | 31.79 | 3.38 | 0.97 | 0.92 | 0.95 |
|
| 104 |
+
| **PromptRL w/ PE** | **24.05** | **32.03** | **3.44** | **0.98** | **0.91** | **0.95** |
|
| 105 |
+
|
| 106 |
+
</details>
|
| 107 |
+
|
| 108 |
+
<details>
|
| 109 |
+
<summary><b>✏️ Image Editing - EditReward (Full Results)</b></summary>
|
| 110 |
+
|
| 111 |
+
<br>
|
| 112 |
+
|
| 113 |
+
| Model | Swap | Style | Add. | Attr. | Env. | Removal | Avg.↑ |
|
| 114 |
+
|:---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|
| 115 |
+
| InstructPix2Pix | -0.24 | 0.91 | -0.45 | 0.45 | 0.48 | -0.80 | 0.02 |
|
| 116 |
+
| MagicBrush | -0.38 | 0.36 | -0.78 | -0.80 | 0.91 | -0.85 | -0.27 |
|
| 117 |
+
| LEDITS++ | -0.81 | -0.32 | -0.30 | -0.60 | -0.37 | -0.97 | -0.60 |
|
| 118 |
+
| Qwen-Image-Edit | 1.11 | 1.14 | 0.95 | 0.90 | 1.39 | 0.61 | 1.03 |
|
| 119 |
+
| FLUX.2-klein | 1.42 | 1.73 | 1.29 | 1.42 | 1.80 | 0.32 | 1.34 |
|
| 120 |
+
| Nano Banana | 1.58 | 1.20 | 1.28 | 1.18 | 1.61 | 1.13 | 1.37 |
|
| 121 |
+
| Step1X-Edit | 1.39 | 1.58 | 1.19 | 1.34 | 1.57 | 0.22 | 1.24 |
|
| 122 |
+
| ReasonEdit | 1.51 | 1.43 | 1.19 | 1.47 | 1.58 | 1.14 | 1.40 |
|
| 123 |
+
| ReasonEdit-Think | 1.52 | 1.47 | 1.19 | 1.44 | 1.69 | 1.27 | 1.44 |
|
| 124 |
+
| FLUX.1-Kontext | 1.35 | 1.36 | 1.16 | 1.15 | 1.44 | 0.55 | 1.19 |
|
| 125 |
+
| FLUX.1-Kontext w/ PE | 1.35 | 0.97 | 1.04 | 0.48 | 1.22 | 0.65 | 1.01 |
|
| 126 |
+
| PromptRL w/o PE | 1.45 | 1.46 | 1.28 | 1.35 | 1.56 | 0.98 | 1.36 |
|
| 127 |
+
| **PromptRL w/ PE** | **1.47** | **1.43** | **1.29** | **1.39** | **1.72** | **1.24** | **1.43** |
|
| 128 |
+
|
| 129 |
+
</details>
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
## Citation
|
| 134 |
+
|
| 135 |
+
```bibtext
|
| 136 |
+
@article{wang2025promptrl,
|
| 137 |
+
title={PromptRL: Prompt Matters in RL for Flow-Based Image Generation},
|
| 138 |
+
author={Wang, Fu-Yun and Zhang, Han and Gharbi, Michael and Li, Hongsheng and Park, Taesung},
|
| 139 |
+
journal={arXiv preprint arXiv:2602.01382},
|
| 140 |
+
year={2026}
|
| 141 |
+
}
|
| 142 |
+
```
|
| 143 |
+
|
| 144 |
+
```bibtext
|
| 145 |
+
@article{wang2025unirl,
|
| 146 |
+
title={UniRL-Zero: Reinforcement Learning on Unified Models with Joint Language Model and Diffusion Model Experts},
|
| 147 |
+
author={Wang, Fu-Yun and Zhang, Han and Gharbi, Michael and Li, Hongsheng and Park, Taesung},
|
| 148 |
+
journal={arXiv preprint arXiv:2510.17937},
|
| 149 |
+
year={2025}
|
| 150 |
+
}
|
| 151 |
+
```
|
| 152 |
+
|
| 153 |
+
## Acknowledgments
|
| 154 |
+
|
| 155 |
+
This codebase builds upon [UniRL-Zero](https://github.com/G-U-N/UniRL/tree/master).
|
assets/edit_comparison.png
ADDED
|
Git LFS Details
|
assets/logo.png
ADDED
|
Git LFS Details
|
assets/t2i_comparison.png
ADDED
|
Git LFS Details
|
environment.yml
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: unirl
|
| 2 |
+
channels:
|
| 3 |
+
- defaults
|
| 4 |
+
dependencies:
|
| 5 |
+
- _libgcc_mutex=0.1=main
|
| 6 |
+
- _openmp_mutex=5.1=1_gnu
|
| 7 |
+
- bzip2=1.0.8=h5eee18b_6
|
| 8 |
+
- ca-certificates=2025.2.25=h06a4308_0
|
| 9 |
+
- expat=2.7.1=h6a678d5_0
|
| 10 |
+
- ld_impl_linux-64=2.40=h12ee557_0
|
| 11 |
+
- libffi=3.4.4=h6a678d5_1
|
| 12 |
+
- libgcc-ng=11.2.0=h1234567_1
|
| 13 |
+
- libgomp=11.2.0=h1234567_1
|
| 14 |
+
- libstdcxx-ng=11.2.0=h1234567_1
|
| 15 |
+
- libuuid=1.41.5=h5eee18b_0
|
| 16 |
+
- libxcb=1.17.0=h9b100fa_0
|
| 17 |
+
- ncurses=6.4=h6a678d5_0
|
| 18 |
+
- openssl=3.0.16=h5eee18b_0
|
| 19 |
+
- pip=25.1=pyhc872135_2
|
| 20 |
+
- pthread-stubs=0.3=h0ce48e5_1
|
| 21 |
+
- python=3.11.13=h1a3bd86_0
|
| 22 |
+
- readline=8.2=h5eee18b_0
|
| 23 |
+
- setuptools=78.1.1=py311h06a4308_0
|
| 24 |
+
- sqlite=3.45.3=h5eee18b_0
|
| 25 |
+
- tk=8.6.14=h993c535_1
|
| 26 |
+
- wheel=0.45.1=py311h06a4308_0
|
| 27 |
+
- xorg-libx11=1.8.12=h9b100fa_1
|
| 28 |
+
- xorg-libxau=1.0.12=h9b100fa_0
|
| 29 |
+
- xorg-libxdmcp=1.1.5=h9b100fa_0
|
| 30 |
+
- xorg-xorgproto=2024.1=h5eee18b_1
|
| 31 |
+
- xz=5.6.4=h5eee18b_1
|
| 32 |
+
- zlib=1.2.13=h5eee18b_1
|
| 33 |
+
- pip:
|
| 34 |
+
- accelerate==1.7.0
|
| 35 |
+
- aiohappyeyeballs==2.6.1
|
| 36 |
+
- aiohttp==3.12.9
|
| 37 |
+
- aiosignal==1.3.2
|
| 38 |
+
- airportsdata==20250523
|
| 39 |
+
- annotated-types==0.7.0
|
| 40 |
+
- anthropic==0.54.0
|
| 41 |
+
- antlr4-python3-runtime==4.13.2
|
| 42 |
+
- anyio==4.9.0
|
| 43 |
+
- astor==0.8.1
|
| 44 |
+
- asttokens==3.0.0
|
| 45 |
+
- attrs==25.3.0
|
| 46 |
+
- av==14.4.0
|
| 47 |
+
- bitsandbytes==0.46.0
|
| 48 |
+
- blake3==1.0.5
|
| 49 |
+
- cachetools==6.0.0
|
| 50 |
+
- certifi==2025.4.26
|
| 51 |
+
- charset-normalizer==3.4.2
|
| 52 |
+
- click==8.2.1
|
| 53 |
+
# - clip==1.0
|
| 54 |
+
- cloudpickle==3.1.1
|
| 55 |
+
- compressed-tensors==0.9.4
|
| 56 |
+
- contourpy==1.3.2
|
| 57 |
+
- cupy-cuda12x==13.4.1
|
| 58 |
+
- cycler==0.12.1
|
| 59 |
+
- datasets==3.6.0
|
| 60 |
+
- decorator==5.2.1
|
| 61 |
+
- deepspeed==0.15.4
|
| 62 |
+
- depyf==0.18.0
|
| 63 |
+
# - diffusers==0.34.0.dev0
|
| 64 |
+
- dill==0.3.8
|
| 65 |
+
- diskcache==5.6.3
|
| 66 |
+
- distro==1.9.0
|
| 67 |
+
- dnspython==2.7.0
|
| 68 |
+
- docker-pycreds==0.4.0
|
| 69 |
+
- einops==0.8.1
|
| 70 |
+
- email-validator==2.2.0
|
| 71 |
+
- executing==2.2.0
|
| 72 |
+
- fastapi==0.115.12
|
| 73 |
+
- fastapi-cli==0.0.7
|
| 74 |
+
- fastrlock==0.8.3
|
| 75 |
+
- filelock==3.18.0
|
| 76 |
+
# - flash-attn==2.7.4.post1
|
| 77 |
+
- fonttools==4.58.4
|
| 78 |
+
- frozenlist==1.6.2
|
| 79 |
+
- fsspec==2025.3.0
|
| 80 |
+
- ftfy==6.3.1
|
| 81 |
+
- gguf==0.17.0
|
| 82 |
+
- gitdb==4.0.12
|
| 83 |
+
- gitpython==3.1.44
|
| 84 |
+
- googleapis-common-protos==1.70.0
|
| 85 |
+
- grpcio==1.72.1
|
| 86 |
+
- h11==0.16.0
|
| 87 |
+
- hf-transfer==0.1.9
|
| 88 |
+
- hf-xet==1.1.3
|
| 89 |
+
- hjson==3.1.0
|
| 90 |
+
- httpcore==1.0.9
|
| 91 |
+
- httptools==0.6.4
|
| 92 |
+
- httpx==0.28.1
|
| 93 |
+
- huggingface-hub==0.32.4
|
| 94 |
+
- idna==3.10
|
| 95 |
+
- importlib-metadata==8.7.0
|
| 96 |
+
- inquirerpy==0.3.4
|
| 97 |
+
- interegular==0.3.3
|
| 98 |
+
- ipython==9.3.0
|
| 99 |
+
- ipython-pygments-lexers==1.1.1
|
| 100 |
+
- jedi==0.19.2
|
| 101 |
+
- jinja2==3.1.6
|
| 102 |
+
- jiter==0.10.0
|
| 103 |
+
- jsonschema==4.24.0
|
| 104 |
+
- jsonschema-specifications==2025.4.1
|
| 105 |
+
- kiwisolver==1.4.8
|
| 106 |
+
- lark==1.2.2
|
| 107 |
+
- latex2sympy2-extended==1.10.1
|
| 108 |
+
- liger-kernel==0.5.2
|
| 109 |
+
- llguidance==0.7.29
|
| 110 |
+
- llvmlite==0.44.0
|
| 111 |
+
- lm-format-enforcer==0.10.11
|
| 112 |
+
- markdown-it-py==3.0.0
|
| 113 |
+
- markupsafe==3.0.2
|
| 114 |
+
- math-verify==0.7.0
|
| 115 |
+
- matplotlib==3.10.3
|
| 116 |
+
- matplotlib-inline==0.1.7
|
| 117 |
+
- mdurl==0.1.2
|
| 118 |
+
- mistral-common==1.5.6
|
| 119 |
+
- mpmath==1.3.0
|
| 120 |
+
- msgpack==1.1.0
|
| 121 |
+
- msgspec==0.19.0
|
| 122 |
+
- multidict==6.4.4
|
| 123 |
+
- multiprocess==0.70.16
|
| 124 |
+
- nest-asyncio==1.6.0
|
| 125 |
+
- networkx==3.5
|
| 126 |
+
- ninja==1.11.1.4
|
| 127 |
+
- numba==0.61.2
|
| 128 |
+
- numpy==2.2.6
|
| 129 |
+
- nvidia-cublas-cu12==12.6.4.1
|
| 130 |
+
- nvidia-cuda-cupti-cu12==12.6.80
|
| 131 |
+
- nvidia-cuda-nvrtc-cu12==12.6.77
|
| 132 |
+
- nvidia-cuda-runtime-cu12==12.6.77
|
| 133 |
+
- nvidia-cudnn-cu12==9.5.1.17
|
| 134 |
+
- nvidia-cufft-cu12==11.3.0.4
|
| 135 |
+
- nvidia-cufile-cu12==1.11.1.6
|
| 136 |
+
- nvidia-curand-cu12==10.3.7.77
|
| 137 |
+
- nvidia-cusolver-cu12==11.7.1.2
|
| 138 |
+
- nvidia-cusparse-cu12==12.5.4.2
|
| 139 |
+
- nvidia-cusparselt-cu12==0.6.3
|
| 140 |
+
- nvidia-nccl-cu12==2.26.2
|
| 141 |
+
- nvidia-nvjitlink-cu12==12.6.85
|
| 142 |
+
- nvidia-nvtx-cu12==12.6.77
|
| 143 |
+
- openai==1.84.0
|
| 144 |
+
- opencv-python-headless==4.11.0.86
|
| 145 |
+
- opentelemetry-api==1.34.0
|
| 146 |
+
- opentelemetry-exporter-otlp==1.34.0
|
| 147 |
+
- opentelemetry-exporter-otlp-proto-common==1.34.0
|
| 148 |
+
- opentelemetry-exporter-otlp-proto-grpc==1.34.0
|
| 149 |
+
- opentelemetry-exporter-otlp-proto-http==1.34.0
|
| 150 |
+
- opentelemetry-proto==1.34.0
|
| 151 |
+
- opentelemetry-sdk==1.34.0
|
| 152 |
+
- opentelemetry-semantic-conventions==0.55b0
|
| 153 |
+
- opentelemetry-semantic-conventions-ai==0.4.9
|
| 154 |
+
- outlines==0.1.11
|
| 155 |
+
- outlines-core==0.1.26
|
| 156 |
+
- packaging==25.0
|
| 157 |
+
- pandas==2.3.0
|
| 158 |
+
- parso==0.8.4
|
| 159 |
+
- partial-json-parser==0.2.1.1.post5
|
| 160 |
+
- peft==0.17.1
|
| 161 |
+
- pexpect==4.9.0
|
| 162 |
+
- pfzy==0.3.4
|
| 163 |
+
- pillow==11.2.1
|
| 164 |
+
- platformdirs==4.3.8
|
| 165 |
+
- prometheus-client==0.22.1
|
| 166 |
+
- prometheus-fastapi-instrumentator==7.1.0
|
| 167 |
+
- prompt-toolkit==3.0.51
|
| 168 |
+
- propcache==0.3.1
|
| 169 |
+
- protobuf==5.29.5
|
| 170 |
+
- psutil==7.0.0
|
| 171 |
+
- ptyprocess==0.7.0
|
| 172 |
+
- pure-eval==0.2.3
|
| 173 |
+
- py-cpuinfo==9.0.0
|
| 174 |
+
- pyarrow==20.0.0
|
| 175 |
+
- pycountry==24.6.1
|
| 176 |
+
- pydantic==2.11.5
|
| 177 |
+
- pydantic-core==2.33.2
|
| 178 |
+
- pygments==2.19.1
|
| 179 |
+
- pyparsing==3.2.3
|
| 180 |
+
- python-dateutil==2.9.0.post0
|
| 181 |
+
- python-dotenv==1.1.0
|
| 182 |
+
- python-json-logger==3.3.0
|
| 183 |
+
- python-multipart==0.0.20
|
| 184 |
+
- pytz==2025.2
|
| 185 |
+
- pyyaml==6.0.2
|
| 186 |
+
- pyzmq==26.4.0
|
| 187 |
+
- qwen-vl-utils==0.0.11
|
| 188 |
+
- ray==2.46.0
|
| 189 |
+
- referencing==0.36.2
|
| 190 |
+
- regex==2024.11.6
|
| 191 |
+
- requests==2.32.3
|
| 192 |
+
- rich==14.0.0
|
| 193 |
+
- rich-toolkit==0.14.7
|
| 194 |
+
- rpds-py==0.25.1
|
| 195 |
+
- safetensors==0.5.3
|
| 196 |
+
- scipy==1.15.3
|
| 197 |
+
- seaborn==0.13.2
|
| 198 |
+
- sentencepiece==0.2.0
|
| 199 |
+
- sentry-sdk==2.29.1
|
| 200 |
+
- setproctitle==1.3.6
|
| 201 |
+
- shellingham==1.5.4
|
| 202 |
+
- six==1.17.0
|
| 203 |
+
- smmap==5.0.2
|
| 204 |
+
- sniffio==1.3.1
|
| 205 |
+
- stack-data==0.6.3
|
| 206 |
+
- starlette==0.46.2
|
| 207 |
+
- sympy==1.14.0
|
| 208 |
+
- tabulate==0.9.0
|
| 209 |
+
- tiktoken==0.9.0
|
| 210 |
+
- timm==0.6.13
|
| 211 |
+
- tokenizers==0.21.1
|
| 212 |
+
- torch==2.7.0
|
| 213 |
+
- torchaudio==2.7.0
|
| 214 |
+
- torchvision==0.22.0
|
| 215 |
+
- tqdm==4.67.1
|
| 216 |
+
- traitlets==5.14.3
|
| 217 |
+
- transformers==4.51.3
|
| 218 |
+
- triton==3.3.0
|
| 219 |
+
- trl==0.19.0
|
| 220 |
+
- typer==0.16.0
|
| 221 |
+
- typing-extensions==4.14.0
|
| 222 |
+
- typing-inspection==0.4.1
|
| 223 |
+
- tzdata==2025.2
|
| 224 |
+
- urllib3==2.4.0
|
| 225 |
+
- utils==1.0.2
|
| 226 |
+
- uvicorn==0.34.3
|
| 227 |
+
- uvloop==0.21.0
|
| 228 |
+
- vllm==0.9.0.1
|
| 229 |
+
- wandb==0.18.3
|
| 230 |
+
- watchfiles==1.0.5
|
| 231 |
+
- wcwidth==0.2.13
|
| 232 |
+
- websockets==15.0.1
|
| 233 |
+
- xformers==0.0.30
|
| 234 |
+
- xgrammar==0.1.19
|
| 235 |
+
- xxhash==3.5.0
|
| 236 |
+
- yarl==1.20.0
|
| 237 |
+
- zipp==3.22.0
|
| 238 |
+
- tensorboardX==2.6.4
|
eval.py
ADDED
|
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Batch image evaluation tool with YAML configuration."""
|
| 3 |
+
|
| 4 |
+
import requests
|
| 5 |
+
import pickle
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from typing import List, Dict, Any, Union, Optional, Tuple
|
| 8 |
+
import sys
|
| 9 |
+
import os
|
| 10 |
+
import json
|
| 11 |
+
import yaml
|
| 12 |
+
from io import BytesIO
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
from datetime import datetime
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
PAIR_SCORERS = {"editreward"}
|
| 18 |
+
CAPTION_SUFFIXES = ["_caption.txt", "_prompt.txt"]
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class RewardEvaluatorClient:
|
| 22 |
+
def __init__(self, scorer_urls: Dict[str, str]):
|
| 23 |
+
self.scorer_urls = scorer_urls
|
| 24 |
+
|
| 25 |
+
def evaluate(self,
|
| 26 |
+
model_name: str,
|
| 27 |
+
images: Union[List[Image.Image], Dict[str, List[Image.Image]]],
|
| 28 |
+
prompts: List[str],
|
| 29 |
+
metadata: Dict[str, Any] = None) -> Union[List[float], Dict[str, Any]]:
|
| 30 |
+
url = self.scorer_urls.get(model_name)
|
| 31 |
+
if not url:
|
| 32 |
+
raise ValueError(f"Reward model '{model_name}' URL not configured.")
|
| 33 |
+
|
| 34 |
+
payload_bytes = create_payload(images, prompts, metadata)
|
| 35 |
+
|
| 36 |
+
try:
|
| 37 |
+
response = requests.post(url, data=payload_bytes, timeout=600)
|
| 38 |
+
response.raise_for_status()
|
| 39 |
+
result = parse_response(response.content)
|
| 40 |
+
|
| 41 |
+
if isinstance(result, dict) and "error" in result:
|
| 42 |
+
raise RuntimeError(f"Scorer '{model_name}' returned error: {result['error']}")
|
| 43 |
+
|
| 44 |
+
return result
|
| 45 |
+
|
| 46 |
+
except requests.exceptions.RequestException as e:
|
| 47 |
+
raise RuntimeError(f"HTTP request to '{model_name}' failed: {e}")
|
| 48 |
+
except Exception as e:
|
| 49 |
+
raise RuntimeError(f"Failed to process response from '{model_name}': {e}")
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def serialize_images(images: List[Image.Image]) -> List[bytes]:
|
| 53 |
+
images_bytes = []
|
| 54 |
+
for img in images:
|
| 55 |
+
img_byte_arr = BytesIO()
|
| 56 |
+
if img.mode != 'RGB':
|
| 57 |
+
img = img.convert('RGB')
|
| 58 |
+
img.save(img_byte_arr, format="JPEG")
|
| 59 |
+
images_bytes.append(img_byte_arr.getvalue())
|
| 60 |
+
return images_bytes
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def create_payload(images: Union[List[Image.Image], Dict[str, List[Image.Image]]],
|
| 64 |
+
prompts: List[str],
|
| 65 |
+
metadata: Dict[str, Any] = None) -> bytes:
|
| 66 |
+
if isinstance(images, dict):
|
| 67 |
+
serialized_images = {key: serialize_images(value) for key, value in images.items()}
|
| 68 |
+
else:
|
| 69 |
+
serialized_images = serialize_images(images)
|
| 70 |
+
|
| 71 |
+
return pickle.dumps({
|
| 72 |
+
"images": serialized_images,
|
| 73 |
+
"prompts": prompts,
|
| 74 |
+
"metadata": metadata or {}
|
| 75 |
+
})
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def parse_response(response_content: bytes) -> Union[List[float], Dict[str, Any]]:
|
| 79 |
+
return pickle.loads(response_content)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def find_caption_file(base_path: str, base_name: str) -> Optional[str]:
|
| 83 |
+
for suffix in CAPTION_SUFFIXES:
|
| 84 |
+
caption_path = os.path.join(base_path, f"{base_name}{suffix}")
|
| 85 |
+
if os.path.exists(caption_path):
|
| 86 |
+
return caption_path
|
| 87 |
+
return None
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def collect_standard_samples(folder_path: str) -> Tuple[List[Image.Image], List[str], List[str]]:
|
| 91 |
+
images, prompts, filenames = [], [], []
|
| 92 |
+
|
| 93 |
+
for file in sorted(os.listdir(folder_path)):
|
| 94 |
+
if not file.lower().endswith(('.png', '.jpg', '.jpeg')):
|
| 95 |
+
continue
|
| 96 |
+
if any(suffix in file for suffix in ['_edited', '_reference', '_source']):
|
| 97 |
+
continue
|
| 98 |
+
|
| 99 |
+
base_name = os.path.splitext(file)[0]
|
| 100 |
+
img_path = os.path.join(folder_path, file)
|
| 101 |
+
caption_path = find_caption_file(folder_path, base_name)
|
| 102 |
+
|
| 103 |
+
if not caption_path:
|
| 104 |
+
continue
|
| 105 |
+
|
| 106 |
+
try:
|
| 107 |
+
img = Image.open(img_path)
|
| 108 |
+
with open(caption_path, 'r', encoding='utf-8') as f:
|
| 109 |
+
prompt = f.read().strip()
|
| 110 |
+
images.append(img)
|
| 111 |
+
prompts.append(prompt)
|
| 112 |
+
filenames.append(file)
|
| 113 |
+
except Exception as e:
|
| 114 |
+
print(f" Warning: Failed to process {file}: {e}")
|
| 115 |
+
|
| 116 |
+
return images, prompts, filenames
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def collect_edit_samples(folder_path: str) -> Tuple[Dict[str, List[Image.Image]], List[str], List[str]]:
|
| 120 |
+
source_images, edited_images, prompts, filenames = [], [], [], []
|
| 121 |
+
|
| 122 |
+
edited_files = [f for f in os.listdir(folder_path) if f.endswith('_edited.png')]
|
| 123 |
+
|
| 124 |
+
for edited_file in sorted(edited_files):
|
| 125 |
+
base_name = edited_file.replace('_edited.png', '')
|
| 126 |
+
source_file = f"{base_name}_reference.png"
|
| 127 |
+
|
| 128 |
+
if not os.path.exists(os.path.join(folder_path, source_file)):
|
| 129 |
+
source_file = f"{base_name}_source.png"
|
| 130 |
+
|
| 131 |
+
source_path = os.path.join(folder_path, source_file)
|
| 132 |
+
edited_path = os.path.join(folder_path, edited_file)
|
| 133 |
+
caption_path = find_caption_file(folder_path, base_name)
|
| 134 |
+
|
| 135 |
+
if not os.path.exists(source_path) or not caption_path:
|
| 136 |
+
continue
|
| 137 |
+
|
| 138 |
+
try:
|
| 139 |
+
source_img = Image.open(source_path)
|
| 140 |
+
edited_img = Image.open(edited_path)
|
| 141 |
+
with open(caption_path, 'r', encoding='utf-8') as f:
|
| 142 |
+
prompt = f.read().strip()
|
| 143 |
+
|
| 144 |
+
source_images.append(source_img)
|
| 145 |
+
edited_images.append(edited_img)
|
| 146 |
+
prompts.append(prompt)
|
| 147 |
+
filenames.append(base_name)
|
| 148 |
+
except Exception as e:
|
| 149 |
+
print(f" Warning: Failed to process {base_name}: {e}")
|
| 150 |
+
|
| 151 |
+
return {'source': source_images, 'edited': edited_images}, prompts, filenames
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def evaluate_folder(folder_path: str,
|
| 155 |
+
model_name: str,
|
| 156 |
+
batch_size: int,
|
| 157 |
+
scorer_urls: Dict[str, str],
|
| 158 |
+
verbose: bool = True) -> Optional[Dict[str, Any]]:
|
| 159 |
+
if not os.path.isdir(folder_path):
|
| 160 |
+
return None
|
| 161 |
+
|
| 162 |
+
evaluator = RewardEvaluatorClient(scorer_urls)
|
| 163 |
+
is_pair_scorer = model_name in PAIR_SCORERS
|
| 164 |
+
|
| 165 |
+
if is_pair_scorer:
|
| 166 |
+
images, prompts, filenames = collect_edit_samples(folder_path)
|
| 167 |
+
sample_count = len(prompts)
|
| 168 |
+
else:
|
| 169 |
+
images, prompts, filenames = collect_standard_samples(folder_path)
|
| 170 |
+
sample_count = len(images)
|
| 171 |
+
|
| 172 |
+
if sample_count == 0:
|
| 173 |
+
if verbose:
|
| 174 |
+
print(f" Skipped (no valid samples): {folder_path}")
|
| 175 |
+
return None
|
| 176 |
+
|
| 177 |
+
if verbose:
|
| 178 |
+
print(f" Evaluating {sample_count} samples: {folder_path}")
|
| 179 |
+
|
| 180 |
+
all_scores = []
|
| 181 |
+
|
| 182 |
+
if is_pair_scorer:
|
| 183 |
+
source_images = images['source']
|
| 184 |
+
edited_images = images['edited']
|
| 185 |
+
|
| 186 |
+
for start_idx in tqdm(range(0, sample_count, batch_size), disable=not verbose):
|
| 187 |
+
end_idx = min(start_idx + batch_size, sample_count)
|
| 188 |
+
batch_images = {
|
| 189 |
+
'source': source_images[start_idx:end_idx],
|
| 190 |
+
'edited': edited_images[start_idx:end_idx]
|
| 191 |
+
}
|
| 192 |
+
batch_prompts = prompts[start_idx:end_idx]
|
| 193 |
+
|
| 194 |
+
try:
|
| 195 |
+
batch_results = evaluator.evaluate(model_name, batch_images, batch_prompts)
|
| 196 |
+
scores = batch_results.get('scores', batch_results) if isinstance(batch_results, dict) else batch_results
|
| 197 |
+
all_scores.extend(scores)
|
| 198 |
+
except Exception as e:
|
| 199 |
+
print(f" Batch evaluation failed [{start_idx}:{end_idx}]: {e}")
|
| 200 |
+
return None
|
| 201 |
+
else:
|
| 202 |
+
for start_idx in tqdm(range(0, sample_count, batch_size), disable=not verbose):
|
| 203 |
+
end_idx = min(start_idx + batch_size, sample_count)
|
| 204 |
+
batch_images = images[start_idx:end_idx]
|
| 205 |
+
batch_prompts = prompts[start_idx:end_idx]
|
| 206 |
+
|
| 207 |
+
try:
|
| 208 |
+
batch_results = evaluator.evaluate(model_name, batch_images, batch_prompts)
|
| 209 |
+
scores = batch_results.get('scores', batch_results) if isinstance(batch_results, dict) else batch_results
|
| 210 |
+
all_scores.extend(scores)
|
| 211 |
+
except Exception as e:
|
| 212 |
+
print(f" Batch evaluation failed [{start_idx}:{end_idx}]: {e}")
|
| 213 |
+
continue
|
| 214 |
+
|
| 215 |
+
if not all_scores:
|
| 216 |
+
return None
|
| 217 |
+
|
| 218 |
+
return {
|
| 219 |
+
'folder': folder_path,
|
| 220 |
+
'model': model_name,
|
| 221 |
+
'average': sum(all_scores) / len(all_scores),
|
| 222 |
+
'scores': all_scores,
|
| 223 |
+
'count': len(all_scores)
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def find_leaf_folders(root_path: str, min_depth: int = 0, max_depth: int = -1) -> List[str]:
|
| 228 |
+
result = []
|
| 229 |
+
root_path = os.path.abspath(root_path)
|
| 230 |
+
|
| 231 |
+
def has_images(folder: str) -> bool:
|
| 232 |
+
for f in os.listdir(folder):
|
| 233 |
+
if f.lower().endswith(('.png', '.jpg', '.jpeg')):
|
| 234 |
+
return True
|
| 235 |
+
return False
|
| 236 |
+
|
| 237 |
+
def recurse(current_path: str, depth: int):
|
| 238 |
+
if max_depth >= 0 and depth > max_depth:
|
| 239 |
+
return
|
| 240 |
+
|
| 241 |
+
try:
|
| 242 |
+
entries = os.listdir(current_path)
|
| 243 |
+
except PermissionError:
|
| 244 |
+
return
|
| 245 |
+
|
| 246 |
+
subdirs = [e for e in entries if os.path.isdir(os.path.join(current_path, e))]
|
| 247 |
+
|
| 248 |
+
if not subdirs or (max_depth >= 0 and depth == max_depth):
|
| 249 |
+
if depth >= min_depth and has_images(current_path):
|
| 250 |
+
result.append(current_path)
|
| 251 |
+
else:
|
| 252 |
+
for subdir in subdirs:
|
| 253 |
+
recurse(os.path.join(current_path, subdir), depth + 1)
|
| 254 |
+
if depth >= min_depth and has_images(current_path):
|
| 255 |
+
result.append(current_path)
|
| 256 |
+
|
| 257 |
+
recurse(root_path, 0)
|
| 258 |
+
return sorted(result)
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def run(config: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
|
| 262 |
+
scorer_urls = config['scorer_urls']
|
| 263 |
+
defaults = config.get('defaults', {})
|
| 264 |
+
evaluations = config['evaluations']
|
| 265 |
+
output_file = config.get('output')
|
| 266 |
+
verbose = config.get('verbose', True)
|
| 267 |
+
|
| 268 |
+
default_batch_size = defaults.get('batch_size', 64)
|
| 269 |
+
default_recursive = defaults.get('recursive', False)
|
| 270 |
+
default_min_depth = defaults.get('min_depth', 0)
|
| 271 |
+
default_max_depth = defaults.get('max_depth', -1)
|
| 272 |
+
|
| 273 |
+
all_results = {}
|
| 274 |
+
|
| 275 |
+
for eval_item in evaluations:
|
| 276 |
+
path = eval_item.get('path')
|
| 277 |
+
if not path:
|
| 278 |
+
print("Warning: Evaluation item missing 'path', skipping")
|
| 279 |
+
continue
|
| 280 |
+
|
| 281 |
+
models = eval_item.get('models', [])
|
| 282 |
+
if not models:
|
| 283 |
+
print(f"Warning: No models specified for {path}, skipping")
|
| 284 |
+
continue
|
| 285 |
+
|
| 286 |
+
batch_size = eval_item.get('batch_size', default_batch_size)
|
| 287 |
+
recursive = eval_item.get('recursive', default_recursive)
|
| 288 |
+
min_depth = eval_item.get('min_depth', default_min_depth)
|
| 289 |
+
max_depth = eval_item.get('max_depth', default_max_depth)
|
| 290 |
+
|
| 291 |
+
if not recursive:
|
| 292 |
+
max_depth = 0
|
| 293 |
+
|
| 294 |
+
folders = find_leaf_folders(path, min_depth, max_depth)
|
| 295 |
+
|
| 296 |
+
if not folders:
|
| 297 |
+
print(f"No image folders found in: {path}")
|
| 298 |
+
continue
|
| 299 |
+
|
| 300 |
+
print(f"\nProcessing {len(folders)} folder(s) from: {path}")
|
| 301 |
+
print(f"Models: {', '.join(models)}")
|
| 302 |
+
print("-" * 60)
|
| 303 |
+
|
| 304 |
+
for folder in tqdm(folders, desc="Folders", disable=not verbose):
|
| 305 |
+
folder_results = {}
|
| 306 |
+
|
| 307 |
+
for model in models:
|
| 308 |
+
if verbose:
|
| 309 |
+
print(f"\n[{model}] ", end="")
|
| 310 |
+
|
| 311 |
+
result = evaluate_folder(folder, model, batch_size, scorer_urls, verbose)
|
| 312 |
+
|
| 313 |
+
if result:
|
| 314 |
+
folder_results[model] = result
|
| 315 |
+
if verbose:
|
| 316 |
+
print(f" -> Average: {result['average']:.4f} (n={result['count']})")
|
| 317 |
+
|
| 318 |
+
if folder_results:
|
| 319 |
+
rel_path = os.path.relpath(folder, path)
|
| 320 |
+
key = f"{path}:{rel_path}" if rel_path != "." else path
|
| 321 |
+
all_results[key] = folder_results
|
| 322 |
+
|
| 323 |
+
# Print summary
|
| 324 |
+
print("\n" + "=" * 60)
|
| 325 |
+
print("Evaluation Summary")
|
| 326 |
+
print("=" * 60)
|
| 327 |
+
for folder, results in all_results.items():
|
| 328 |
+
print(f"\n{folder}")
|
| 329 |
+
for model, data in results.items():
|
| 330 |
+
print(f" [{model}] avg={data['average']:.4f}, n={data['count']}")
|
| 331 |
+
|
| 332 |
+
# Save results
|
| 333 |
+
if output_file:
|
| 334 |
+
serializable = {
|
| 335 |
+
folder: {
|
| 336 |
+
model: {'average': data['average'], 'count': data['count']}
|
| 337 |
+
for model, data in results.items()
|
| 338 |
+
}
|
| 339 |
+
for folder, results in all_results.items()
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
with open(output_file, 'w', encoding='utf-8') as f:
|
| 343 |
+
json.dump({
|
| 344 |
+
'timestamp': datetime.now().isoformat(),
|
| 345 |
+
'results': serializable
|
| 346 |
+
}, f, indent=2, ensure_ascii=False)
|
| 347 |
+
|
| 348 |
+
print(f"\nResults saved to: {output_file}")
|
| 349 |
+
|
| 350 |
+
return all_results
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
def main():
|
| 354 |
+
if len(sys.argv) != 2:
|
| 355 |
+
print(f"Usage: python {sys.argv[0]} <config.yaml>")
|
| 356 |
+
sys.exit(1)
|
| 357 |
+
|
| 358 |
+
config_path = sys.argv[1]
|
| 359 |
+
with open(config_path, 'r', encoding='utf-8') as f:
|
| 360 |
+
config = yaml.safe_load(f)
|
| 361 |
+
|
| 362 |
+
results = run(config)
|
| 363 |
+
sys.exit(0 if results else 1)
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
if __name__ == "__main__":
|
| 367 |
+
main()
|
gen.sh
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
set -e
|
| 3 |
+
|
| 4 |
+
# Download eval datasets if not present
|
| 5 |
+
EDIT_DATA="data/omni_edit_dev.parquet"
|
| 6 |
+
if [ ! -f "$EDIT_DATA" ]; then
|
| 7 |
+
echo "Downloading edit eval dataset..."
|
| 8 |
+
mkdir -p data
|
| 9 |
+
huggingface-cli download wangfuyun/PrompRL data/omni_edit_dev.parquet \
|
| 10 |
+
--repo-type model --local-dir . --local-dir-use-symlinks False
|
| 11 |
+
fi
|
| 12 |
+
|
| 13 |
+
# # Text-to-Image OCR
|
| 14 |
+
python unified_inference.py --mode t2i \
|
| 15 |
+
--model_path wangfuyun/PrompRL/promptrl_ocr \
|
| 16 |
+
--model_type flux \
|
| 17 |
+
--prompt_file prompts/ocr_test.txt \
|
| 18 |
+
--output_dir outputs/ocr \
|
| 19 |
+
--use_cot --cot_template ocr_clarity_v2
|
| 20 |
+
|
| 21 |
+
# # Text-to-Image PS
|
| 22 |
+
python unified_inference.py --mode t2i \
|
| 23 |
+
--model_path wangfuyun/PrompRL/promptrl_ps \
|
| 24 |
+
--model_type flux \
|
| 25 |
+
--prompt_file prompts/draw_test.txt \
|
| 26 |
+
--output_dir outputs/pickscore \
|
| 27 |
+
--use_cot --cot_template quality_purev2
|
| 28 |
+
|
| 29 |
+
# # GenEval
|
| 30 |
+
python unified_inference.py --mode geneval \
|
| 31 |
+
--model_path wangfuyun/PrompRL/promptrl_geneval \
|
| 32 |
+
--model_type flux \
|
| 33 |
+
--metadata_file prompts/evaluation_metadata.jsonl \
|
| 34 |
+
--output_dir outputs/geneval \
|
| 35 |
+
--use_cot --cot_template geneval \
|
| 36 |
+
--n_samples 4
|
| 37 |
+
|
| 38 |
+
# # Image Editing
|
| 39 |
+
python unified_inference.py --mode edit \
|
| 40 |
+
--model_path wangfuyun/PrompRL/promptrl_edit \
|
| 41 |
+
--model_type kontext \
|
| 42 |
+
--data_file "$EDIT_DATA" \
|
| 43 |
+
--output_dir outputs/edit \
|
| 44 |
+
--use_cot --cot_template edit_general \
|
| 45 |
+
--guidance_scale 2.5
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# python eval.py prompts/config.yaml
|
prompts/config.yaml
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Batch Image Evaluator Configuration
|
| 2 |
+
|
| 3 |
+
scorer_urls:
|
| 4 |
+
aesthetic: "http://YOUR_SERVER_IP:18080/"
|
| 5 |
+
image_reward: "http://YOUR_SERVER_IP:18081/"
|
| 6 |
+
ocr: "http://YOUR_SERVER_IP:18082/"
|
| 7 |
+
pickscore: "http://YOUR_SERVER_IP:18083/"
|
| 8 |
+
deqa: "http://YOUR_SERVER_IP:18084/"
|
| 9 |
+
gen_eval: "http://YOUR_SERVER_IP:18085/"
|
| 10 |
+
unifiedreward_sglang: "http://YOUR_SERVER_IP:18086/"
|
| 11 |
+
hps: "http://YOUR_SERVER_IP:18087/"
|
| 12 |
+
editreward: "http://YOUR_SERVER_IP:18088/"
|
| 13 |
+
|
| 14 |
+
defaults:
|
| 15 |
+
batch_size: 64
|
| 16 |
+
recursive: false
|
| 17 |
+
min_depth: 0
|
| 18 |
+
max_depth: -1
|
| 19 |
+
|
| 20 |
+
output: results.json
|
| 21 |
+
verbose: true
|
| 22 |
+
|
| 23 |
+
evaluations:
|
| 24 |
+
- path: ./outputs/ocr
|
| 25 |
+
models: [ocr]
|
| 26 |
+
batch_size: 32
|
| 27 |
+
recursive: true
|
| 28 |
+
|
| 29 |
+
- path: ./outputs/edit
|
| 30 |
+
models: [editreward]
|
| 31 |
+
batch_size: 32
|
| 32 |
+
|
| 33 |
+
- path: ./outputs/pickscore
|
| 34 |
+
models: [pickscore]
|
| 35 |
+
batch_size: 32
|
prompts/draw_test.txt
ADDED
|
@@ -0,0 +1,1000 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
New York Skyline with 'Google Research Pizza Cafe' written with fireworks on the sky.
|
| 2 |
+
A maglev train going vertically downward in high speed, New York Times photojournalism.
|
| 3 |
+
A pyramid made of falafel with a partial solar eclipse in the background.
|
| 4 |
+
A storefront with 'Google Brain Toronto' written on it.
|
| 5 |
+
An elephant under the sea.
|
| 6 |
+
Lego Arnold Schwarzenegger.
|
| 7 |
+
A keyboard made of water, the water is made of light, the light is turned off.
|
| 8 |
+
Artophagous.
|
| 9 |
+
One cat and one dog sitting on the grass.
|
| 10 |
+
A laptop on top of a teddy bear.
|
| 11 |
+
A red colored car.
|
| 12 |
+
A stack of 3 books. A green book is on the top, sitting on a red book. The red book is in the middle, sitting on a blue book. The blue book is on the bottom.
|
| 13 |
+
A green colored banana.
|
| 14 |
+
Matutinal.
|
| 15 |
+
A green cup and a blue cell phone.
|
| 16 |
+
A stack of 3 plates. A blue plate is on the top, sitting on a blue plate. The blue plate is in the middle, sitting on a green plate. The green plate is on the bottom.
|
| 17 |
+
A large thick-skinned semiaquatic African mammal, with massive jaws and large tusks.
|
| 18 |
+
A red colored banana.
|
| 19 |
+
Jentacular.
|
| 20 |
+
A sign that says 'Hello World'.
|
| 21 |
+
A blue cup and a green cell phone.
|
| 22 |
+
A black colored banana.
|
| 23 |
+
Two cats and two dogs sitting on the grass.
|
| 24 |
+
A ldarge keybord msical instroument lwith a woden case enmclosig a qsouvnkboajrd and mfgtal strivgf, which are strucrk b hammrs when the nels are depresdsmed.f lhe strsingsj' vibration ie stopped by damperds when the keys re released and can bce regulavewdd for lengh and vnolume y two or three pedalvs.
|
| 25 |
+
A magnifying glass over a page of a 1950s batman comic.
|
| 26 |
+
A separate seat for one person, typically with a back and four legs.
|
| 27 |
+
Two dogs on the street.
|
| 28 |
+
New York Skyline with 'Diffusion' written with fireworks on the sky.
|
| 29 |
+
A black colored banana.
|
| 30 |
+
An ancient Egyptian painting depicting an argument over whose turn it is to take out the trash.
|
| 31 |
+
A wine glass on top of a dog.
|
| 32 |
+
An emoji of a baby panda wearing a red hat, green gloves, red shirt, and green pants.
|
| 33 |
+
A pear cut into seven pieces arranged in a ring.
|
| 34 |
+
A large thick-skinned semiaquatic African mammal, with massive jaws and large tusks.
|
| 35 |
+
A baby fennec sneezing onto a strawberry, detailed, macro, studio light, droplets, backlit ears.
|
| 36 |
+
A panda making latte art.
|
| 37 |
+
An IT-guy trying to fix hardware of a PC tower is being tangled by the PC cables like Laokoon. Marble, copy after Hellenistic original from ca. 200 BC. Found in the Baths of Trajan, 1506.
|
| 38 |
+
A blue bird and a brown bear.
|
| 39 |
+
A triangular purple flower pot. A purple flower pot in the shape of a triangle.
|
| 40 |
+
A green apple and a black backpack.
|
| 41 |
+
A grocery store refrigerator has pint cartons of milk on the top shelf, quart cartons on the middle shelf, and gallon plastic jugs on the bottom shelf.
|
| 42 |
+
A small domesticated carnivorous mammal with soft fur, a short snout, and retractable claws. It is widely kept as a pet or for catching mice, and many breeds have been developed.
|
| 43 |
+
An orange colored sandwich.
|
| 44 |
+
A large motor vehicle carrying passengers by road, typically one serving the public on a fixed route and for a fare.
|
| 45 |
+
A sphere made of kitchen tile. A sphere with the texture of kitchen tile.
|
| 46 |
+
A cat on the right of a tennis racket.
|
| 47 |
+
Bzaseball galove.
|
| 48 |
+
A sign that says 'NeurIPS'.
|
| 49 |
+
A 1960s yearbook photo with animals dressed as humans.
|
| 50 |
+
New York Skyline with 'Hello World' written with fireworks on the sky.
|
| 51 |
+
Hovering cow abducting aliens.
|
| 52 |
+
A small vessel propelled on water by oars, sails, or an engine.
|
| 53 |
+
A type of digital currency in which a record of transactions is maintained and new units of currency are generated by the computational solution of mathematical problems, and which operates independently of a central bank.
|
| 54 |
+
A pink colored car.
|
| 55 |
+
A storefront with 'NeurIPS' written on it.
|
| 56 |
+
A black apple and a green backpack.
|
| 57 |
+
A large motor vehicle carrying passengers by road, typically one serving the public on a fixed route and for a fare.
|
| 58 |
+
A long curved fruit which grows in clusters and has soft pulpy flesh and yellow skin when ripe.
|
| 59 |
+
A black colored car.
|
| 60 |
+
A realistic photo of a Pomeranian dressed up like a 1980s professional wrestler with neon green and neon orange face paint and bright green wrestling tights with bright orange boots.
|
| 61 |
+
Tcennis rpacket.
|
| 62 |
+
McDonalds Church.
|
| 63 |
+
Painting of Mona Lisa but the view is from behind of Mona Lisa.
|
| 64 |
+
An elephant is behind a tree. You can see the trunk on one side and the back legs on the other.
|
| 65 |
+
Hovering cow abducting aliens.
|
| 66 |
+
Photo of a mega Lego space station inside a kid's bedroom.
|
| 67 |
+
An elephant under the sea.
|
| 68 |
+
One cat and two dogs sitting on the grass.
|
| 69 |
+
A green colored banana.
|
| 70 |
+
An American multinational technology company that focuses on artificial intelligence, search engine, online advertising, cloud computing, computer software, quantum computing, e-commerce, and consumer electronics.
|
| 71 |
+
A domesticated carnivvorous mzammal that typicbally hfaas a lons sfnout, an acxujte sense off osmell, noneetractaaln crlaws, anid xbarkring,y howlingu, or whining rvoiche.
|
| 72 |
+
Jentacular.
|
| 73 |
+
A wine glass on top of a dog.
|
| 74 |
+
A carrot on the left of a broccoli.
|
| 75 |
+
Pafrking metr.
|
| 76 |
+
Three cars on the street.
|
| 77 |
+
In late afternoon in January in New England, a man stands in the shadow of a maple tree.
|
| 78 |
+
An oil painting portrait of the regal Burger King posing with a Whopper.
|
| 79 |
+
A sign that says 'Text to Image'.
|
| 80 |
+
A small vessel propelled on water by oars, sails, or an engine.
|
| 81 |
+
A single clock is sitting on a table.
|
| 82 |
+
A stack of 3 plates. A blue plate is on the top, sitting on a blue plate. The blue plate is in the middle, sitting on a green plate. The green plate is on the bottom.
|
| 83 |
+
An elephant under the sea.
|
| 84 |
+
A type of digital currency in which a record of transactions is maintained and new units of currency are generated by the computational solution of mathematical problems, and which operates independently of a central bank.
|
| 85 |
+
A yellow colored giraffe.
|
| 86 |
+
An elephant is behind a tree. You can see the trunk on one side and the back legs on the other.
|
| 87 |
+
A device consisting of a circular canopy of cloth on a folding metal frame supported by a central rod, used as protection against rain or sometimes sun.
|
| 88 |
+
A fluffy baby sloth with a knitted hat trying to figure out a laptop, close up, highly detailed, studio lighting, screen reflecting in its eyes.
|
| 89 |
+
A stack of 3 plates. A blue plate is on the top, sitting on a blue plate. The blue plate is in the middle, sitting on a green plate. The green plate is on the bottom.
|
| 90 |
+
Three cats and three dogs sitting on the grass.
|
| 91 |
+
A large keyboard musical instrument with a wooden case enclosing a soundboard and metal strings, which are struck by hammers when the keys are depressed. The strings' vibration is stopped by dampers when the keys are released and can be regulated for length and volume by two or three pedals.
|
| 92 |
+
A blue coloured pizza.
|
| 93 |
+
A storefront with 'Google Research Pizza Cafe' written on it.
|
| 94 |
+
A sjmall domesticated carnivorious mammnal with sof fuh,y a sthort sout, and retracwtablbe flaws. It iw widexly kept as a pet or for catchitng mic, ad many breeds zhlyde beefn develvoked.
|
| 95 |
+
A green apple and a black backpack.
|
| 96 |
+
A pink colored car.
|
| 97 |
+
A pear cut into seven pieces arranged in a ring.
|
| 98 |
+
A screenshot of an iOS app for ordering different types of milk.
|
| 99 |
+
Rbefraigerator.
|
| 100 |
+
A blue colored dog.
|
| 101 |
+
Two cats and two dogs sitting on the grass.
|
| 102 |
+
A real life photography of super mario, 8k Ultra HD.
|
| 103 |
+
New York Skyline with 'Hello World' written with fireworks on the sky.
|
| 104 |
+
A realistic photo of a Pomeranian dressed up like a 1980s professional wrestler with neon green and neon orange face paint and bright green wrestling tights with bright orange boots.
|
| 105 |
+
A panda making latte art.
|
| 106 |
+
A storefront with 'NeurIPS' written on it.
|
| 107 |
+
A large keyboard musical instrument with a wooden case enclosing a soundboard and metal strings, which are struck by hammers when the keys are depressed. The strings' vibration is stopped by dampers when the keys are released and can be regulated for length and volume by two or three pedals.
|
| 108 |
+
A blue colored dog.
|
| 109 |
+
Three cats and two dogs sitting on the grass.
|
| 110 |
+
New York Skyline with 'Google Brain Toronto' written with fireworks on the sky.
|
| 111 |
+
A blue coloured pizza.
|
| 112 |
+
A panda making latte art.
|
| 113 |
+
An American multinational technology company that focuses on artificial intelligence, search engine, online advertising, cloud computing, computer software, quantum computing, e-commerce, and consumer electronics.
|
| 114 |
+
Backlotter.
|
| 115 |
+
A black colored sandwich.
|
| 116 |
+
A large thick-skinned semiaquatic African mammal, with massive jaws and large tusks.
|
| 117 |
+
A domesticated carnivorous mammal that typically has a long snout, an acute sense of smell, nonretractable claws, and a barking, howling, or whining voice.
|
| 118 |
+
New York Skyline with 'Deep Learning' written with fireworks on the sky.
|
| 119 |
+
A black colored dog.
|
| 120 |
+
A stack of 3 cubes. A red cube is on the top, sitting on a red cube. The red cube is in the middle, sitting on a green cube. The green cube is on the bottom.
|
| 121 |
+
A type of digital currency in which a record of transactions is maintained and new units of currency are generated by the computational solution of mathematical problems, and which operates independently of a central bank.
|
| 122 |
+
Five cars on the street.
|
| 123 |
+
An old photograph of a 1920s airship shaped like a pig, floating over a wheat field.
|
| 124 |
+
Illustration of a mouse using a mushroom as an umbrella.
|
| 125 |
+
Three cats and one dog sitting on the grass.
|
| 126 |
+
Four cars on the street.
|
| 127 |
+
A black colored sandwich.
|
| 128 |
+
Five cars on the street.
|
| 129 |
+
An American multinational technology company that focuses on artificial intelligence, search engine, online advertising, cloud computing, computer software, quantum computing, e-commerce, and consumer electronics.
|
| 130 |
+
A sign that says 'Google Brain Toronto'.
|
| 131 |
+
A storefront with 'Text to Image' written on it.
|
| 132 |
+
A magnifying glass over a page of a 1950s batman comic.
|
| 133 |
+
A sphere made of kitchen tile. A sphere with the texture of kitchen tile.
|
| 134 |
+
An IT-guy trying to fix hardware of a PC tower is being tangled by the PC cables like Laokoon. Marble, copy after Hellenistic original from ca. 200 BC. Found in the Baths of Trajan, 1506.
|
| 135 |
+
A sign that says 'Diffusion'.
|
| 136 |
+
A blue bird and a brown bear.
|
| 137 |
+
A photo of a confused grizzly bear in calculus class.
|
| 138 |
+
A grocery store refrigerator has pint cartons of milk on the top shelf, quart cartons on the middle shelf, and gallon plastic jugs on the bottom shelf.
|
| 139 |
+
A hair drier underneath a sheep.
|
| 140 |
+
Pafrking metr.
|
| 141 |
+
Peristeronic.
|
| 142 |
+
Two cats and one dog sitting on the grass.
|
| 143 |
+
New York Skyline with 'Google Research Pizza Cafe' written with fireworks on the sky.
|
| 144 |
+
A side view of an owl sitting in a field.
|
| 145 |
+
A pink colored car.
|
| 146 |
+
Paying for a quarter-sized pizza with a pizza-sized quarter.
|
| 147 |
+
Dininrg tablez.
|
| 148 |
+
A fish eating a pelican.
|
| 149 |
+
One cat and three dogs sitting on the grass.
|
| 150 |
+
An instrument used for cutting cloth, paper, and other thin material, consisting of two blades laid one on top of the other and fastened in the middle so as to allow them to be opened and closed by a thumb and finger inserted through rings on the end of their handles.
|
| 151 |
+
A side view of an owl sitting in a field.
|
| 152 |
+
A large thick-skinned semiaquatic African mammal, with massive jaws and large tusks.
|
| 153 |
+
A large keyboard musical instrument with a wooden case enclosing a soundboard and metal strings, which are struck by hammers when the keys are depressed. The strings' vibration is stopped by dampers when the keys are released and can be regulated for length and volume by two or three pedals.
|
| 154 |
+
Pafrking metr.
|
| 155 |
+
A sign that says 'Deep Learning'.
|
| 156 |
+
A collection of nail is sitting on a table.
|
| 157 |
+
One car on the street.
|
| 158 |
+
An emoji of a baby panda wearing a red hat, blue gloves, green shirt, and blue pants.
|
| 159 |
+
A brown bird and a blue bear.
|
| 160 |
+
A donkey and an octopus are playing a game. The donkey is holding a rope on one end, the octopus is holding onto the other. The donkey holds the rope in its mouth. A cat is jumping over the rope.
|
| 161 |
+
A fisheye lens view of a turtle sitting in a forest.
|
| 162 |
+
A large motor vehicle carrying passengers by road, typically one serving the public on a fixed route and for a fare.
|
| 163 |
+
New York Skyline with 'Hello World' written with fireworks on the sky.
|
| 164 |
+
An emoji of a baby panda wearing a red hat, green gloves, red shirt, and green pants.
|
| 165 |
+
A black colored dog.
|
| 166 |
+
A ldarge keybord msical instroument lwith a woden case enmclosig a qsouvnkboajrd and mfgtal strivgf, which are strucrk b hammrs when the nels are depresdsmed.f lhe strsingsj' vibration ie stopped by damperds when the keys re released and can bce regulavewdd for lengh and vnolume y two or three pedalvs.
|
| 167 |
+
Artophagous.
|
| 168 |
+
A yellow book and a red vase.
|
| 169 |
+
A stack of 3 books. A green book is on the top, sitting on a red book. The red book is in the middle, sitting on a blue book. The blue book is on the bottom.
|
| 170 |
+
A pizza on the right of a suitcase.
|
| 171 |
+
A tiger in a lab coat with a 1980s Miami vibe, turning a well oiled science content machine, digital art.
|
| 172 |
+
A storefront with 'Hello World' written on it.
|
| 173 |
+
A tiger in a lab coat with a 1980s Miami vibe, turning a well oiled science content machine, digital art.
|
| 174 |
+
A storefront with 'Google Brain Toronto' written on it.
|
| 175 |
+
A 1960s poster warning against climate change.
|
| 176 |
+
An organ of soft nervous tissue contained in the skull of vertebrates, functioning as the coordinating center of sensation and intellectual and nervous activity.
|
| 177 |
+
A long curved fruit which grows in clusters and has soft pulpy flesh and yellow skin when ripe.
|
| 178 |
+
Supreme Court Justices play a baseball game with the FBI. The FBI is at bat, the justices are on the field.
|
| 179 |
+
A pyramid made of falafel with a partial solar eclipse in the background.
|
| 180 |
+
A single clock is sitting on a table.
|
| 181 |
+
New York Skyline with 'Google Research Pizza Cafe' written with fireworks on the sky.
|
| 182 |
+
A blue cup and a green cell phone.
|
| 183 |
+
An oil painting of a couple in formal evening wear going home get caught in a heavy downpour with no umbrellas.
|
| 184 |
+
Darth Vader playing with raccoon in Mars during sunset.
|
| 185 |
+
A red car and a white sheep.
|
| 186 |
+
An illustration of a large red elephant sitting on a small blue mouse.
|
| 187 |
+
An illustration of a small green elephant standing behind a large red mouse.
|
| 188 |
+
A domesticated carnivorous mammal that typically has a long snout, an acute sense of smell, nonretractable claws, and a barking, howling, or whining voice.
|
| 189 |
+
A medieval painting of the wifi not working.
|
| 190 |
+
An American multinational technology company that focuses on artificial intelligence, search engine, online advertising, cloud computing, computer software, quantum computing, e-commerce, and consumer electronics.
|
| 191 |
+
One cat and two dogs sitting on the grass.
|
| 192 |
+
An IT-guy trying to fix hardware of a PC tower is being tangled by the PC cables like Laokoon. Marble, copy after Hellenistic original from ca. 200 BC. Found in the Baths of Trajan, 1506.
|
| 193 |
+
A fluffy baby sloth with a knitted hat trying to figure out a laptop, close up, highly detailed, studio lighting, screen reflecting in its eyes.
|
| 194 |
+
Abraham Lincoln touches his toes while George Washington does chin-ups. Lincoln is barefoot. Washington is wearing boots.
|
| 195 |
+
An umbrella on top of a spoon.
|
| 196 |
+
Matutinal.
|
| 197 |
+
A pink colored giraffe.
|
| 198 |
+
An emoji of a baby panda wearing a red hat, green gloves, red shirt, and green pants.
|
| 199 |
+
Illustration of a mouse using a mushroom as an umbrella.
|
| 200 |
+
A brown bird and a blue bear.
|
| 201 |
+
A painting by Grant Wood of an astronaut couple, american gothic style.
|
| 202 |
+
A sign that says 'Diffusion'.
|
| 203 |
+
Five dogs on the street.
|
| 204 |
+
Four dogs on the street.
|
| 205 |
+
A cat on the left of a dog.
|
| 206 |
+
A zebra underneath a broccoli.
|
| 207 |
+
A banana on the left of an apple.
|
| 208 |
+
Two cats and three dogs sitting on the grass.
|
| 209 |
+
A yellow colored giraffe.
|
| 210 |
+
Three cats and one dog sitting on the grass.
|
| 211 |
+
A ldarge keybord msical instroument lwith a woden case enmclosig a qsouvnkboajrd and mfgtal strivgf, which are strucrk b hammrs when the nels are depresdsmed.f lhe strsingsj' vibration ie stopped by damperds when the keys re released and can bce regulavewdd for lengh and vnolume y two or three pedalvs.
|
| 212 |
+
Abraham Lincoln touches his toes while George Washington does chin-ups. Lincoln is barefoot. Washington is wearing boots.
|
| 213 |
+
A yellow book and a red vase.
|
| 214 |
+
A cat on the left of a dog.
|
| 215 |
+
A stop sign on the right of a refrigerator.
|
| 216 |
+
A shark in the desert.
|
| 217 |
+
Octothorpe.
|
| 218 |
+
A red colored car.
|
| 219 |
+
Four cars on the street.
|
| 220 |
+
A tiger in a lab coat with a 1980s Miami vibe, turning a well oiled science content machine, digital art.
|
| 221 |
+
Three cats and one dog sitting on the grass.
|
| 222 |
+
Paying for a quarter-sized pizza with a pizza-sized quarter.
|
| 223 |
+
A zebra to the right of a fire hydrant.
|
| 224 |
+
A stack of 3 cubes. A red cube is on the top, sitting on a red cube. The red cube is in the middle, sitting on a green cube. The green cube is on the bottom.
|
| 225 |
+
A 1960s poster warning against climate change.
|
| 226 |
+
A storefront with 'Google Research Pizza Cafe' written on it.
|
| 227 |
+
A laptop on top of a teddy bear.
|
| 228 |
+
A painting by Grant Wood of an astronaut couple, american gothic style.
|
| 229 |
+
New York Skyline with 'Deep Learning' written with fireworks on the sky.
|
| 230 |
+
A storefront with 'Diffusion' written on it.
|
| 231 |
+
A storefront with 'Text to Image' written on it.
|
| 232 |
+
A small blue book sitting on a large red book.
|
| 233 |
+
Colouring page of large cats climbing the eifel tower in a cyberpunk future.
|
| 234 |
+
An emoji of a baby panda wearing a red hat, blue gloves, green shirt, and blue pants.
|
| 235 |
+
A photo of a confused grizzly bear in calculus class.
|
| 236 |
+
Paying for a quarter-sized pizza with a pizza-sized quarter.
|
| 237 |
+
Painting of the orange cat Otto von Garfield, Count of Bismarck-Schönhausen, Duke of Lauenburg, Minister-President of Prussia. Depicted wearing a Prussian Pickelhaube and eating his favorite meal - lasagna.
|
| 238 |
+
A device consisting of a circular canopy of cloth on a folding metal frame supported by a central rod, used as protection against rain or sometimes sun.
|
| 239 |
+
Supreme Court Justices play a baseball game with the FBI. The FBI is at bat, the justices are on the field.
|
| 240 |
+
A triangular pink stop sign. A pink stop sign in the shape of a triangle.
|
| 241 |
+
Painting of the orange cat Otto von Garfield, Count of Bismarck-Schönhausen, Duke of Lauenburg, Minister-President of Prussia. Depicted wearing a Prussian Pickelhaube and eating his favorite meal - lasagna.
|
| 242 |
+
A train on top of a surfboard.
|
| 243 |
+
A stack of 3 cubes. A red cube is on the top, sitting on a red cube. The red cube is in the middle, sitting on a green cube. The green cube is on the bottom.
|
| 244 |
+
A sjmall domesticated carnivorious mammnal with sof fuh,y a sthort sout, and retracwtablbe flaws. It iw widexly kept as a pet or for catchitng mic, ad many breeds zhlyde beefn develvoked.
|
| 245 |
+
A laptop on top of a teddy bear.
|
| 246 |
+
A train on top of a surfboard.
|
| 247 |
+
A photocopy of a photograph of a painting of a sculpture of a giraffe.
|
| 248 |
+
A 1960s yearbook photo with animals dressed as humans.
|
| 249 |
+
A pink colored giraffe.
|
| 250 |
+
A maglev train going vertically downward in high speed, New York Times photojournalism.
|
| 251 |
+
A domesticated carnivvorous mzammal that typicbally hfaas a lons sfnout, an acxujte sense off osmell, noneetractaaln crlaws, anid xbarkring,y howlingu, or whining rvoiche.
|
| 252 |
+
A sign that says 'Google Research Pizza Cafe'.
|
| 253 |
+
Two cars on the street.
|
| 254 |
+
A tennis racket underneath a traffic light.
|
| 255 |
+
A cross-section view of a brain.
|
| 256 |
+
One cat and one dog sitting on the grass.
|
| 257 |
+
A horse riding an astronaut.
|
| 258 |
+
A car playing soccer, digital art.
|
| 259 |
+
A large plant-eating domesticated mammal with solid hoofs and a flowing mane and tail, used for riding, racing, and to carry and pull loads.
|
| 260 |
+
Three dogs on the street.
|
| 261 |
+
A separate seat for one person, typically with a back and four legs.
|
| 262 |
+
A couple of glasses are sitting on a table.
|
| 263 |
+
A couch on the left of a chair.
|
| 264 |
+
Two cars on the street.
|
| 265 |
+
A photocopy of a photograph of a painting of a sculpture of a giraffe.
|
| 266 |
+
A black apple and a green backpack.
|
| 267 |
+
A pyramid made of falafel with a partial solar eclipse in the background.
|
| 268 |
+
A brown colored giraffe.
|
| 269 |
+
One cat and one dog sitting on the grass.
|
| 270 |
+
A pizza cooking an oven.
|
| 271 |
+
A church with stained glass windows depicting a hamburger and french fries.
|
| 272 |
+
A connection point by which firefighters can tap into a water supply.
|
| 273 |
+
A sign that says 'Google Research Pizza Cafe'.
|
| 274 |
+
35mm macro shot a kitten licking a baby duck, studio lighting.
|
| 275 |
+
New York Skyline with 'Text to Image' written with fireworks on the sky.
|
| 276 |
+
An oil painting portrait of the regal Burger King posing with a Whopper.
|
| 277 |
+
A storefront with 'Google Brain Toronto' written on it.
|
| 278 |
+
A bridge connecting Europe and North America on the Atlantic Ocean, bird's eye view.
|
| 279 |
+
One cat and three dogs sitting on the grass.
|
| 280 |
+
Octothorpe.
|
| 281 |
+
A connection point by which firefighters can tap into a water supply.
|
| 282 |
+
A donut underneath a toilet.
|
| 283 |
+
Colouring page of large cats climbing the eifel tower in a cyberpunk future.
|
| 284 |
+
A panda making latte art.
|
| 285 |
+
A machine next to a parking space in a street, into which the driver puts money so as to be authorized to park the vehicle for a particular length of time.
|
| 286 |
+
New York Skyline with 'Google Brain Toronto' written with fireworks on the sky.
|
| 287 |
+
A real life photography of super mario, 8k Ultra HD.
|
| 288 |
+
A cat on the right of a tennis racket.
|
| 289 |
+
A sign that says 'Diffusion'.
|
| 290 |
+
An illustration of a large red elephant sitting on a small blue mouse.
|
| 291 |
+
A collection of nail is sitting on a table.
|
| 292 |
+
An appliance or compartment which is artificially kept cool and used to store food and drink.
|
| 293 |
+
An oil painting portrait of the regal Burger King posing with a Whopper.
|
| 294 |
+
Abraham Lincoln touches his toes while George Washington does chin-ups. Lincoln is barefoot. Washington is wearing boots.
|
| 295 |
+
A black colored dog.
|
| 296 |
+
One cat and two dogs sitting on the grass.
|
| 297 |
+
A donkey and an octopus are playing a game. The donkey is holding a rope on one end, the octopus is holding onto the other. The donkey holds the rope in its mouth. A cat is jumping over the rope.
|
| 298 |
+
A pink colored giraffe.
|
| 299 |
+
A hair drier underneath a sheep.
|
| 300 |
+
A couch on the left of a chair.
|
| 301 |
+
A cube made of denim. A cube with the texture of denim.
|
| 302 |
+
Jentacular.
|
| 303 |
+
An old photograph of a 1920s airship shaped like a pig, floating over a wheat field.
|
| 304 |
+
Colouring page of large cats climbing the eifel tower in a cyberpunk future.
|
| 305 |
+
A collection of nail is sitting on a table.
|
| 306 |
+
One dog on the street.
|
| 307 |
+
A stack of 3 cubes. A red cube is on the top, sitting on a red cube. The red cube is in the middle, sitting on a green cube. The green cube is on the bottom.
|
| 308 |
+
Illustration of a mouse using a mushroom as an umbrella.
|
| 309 |
+
A zebra to the right of a fire hydrant.
|
| 310 |
+
Two dogs on the street.
|
| 311 |
+
Photo of an athlete cat explaining it's latest scandal at a press conference to journalists.
|
| 312 |
+
A domesticated carnivvorous mzammal that typicbally hfaas a lons sfnout, an acxujte sense off osmell, noneetractaaln crlaws, anid xbarkring,y howlingu, or whining rvoiche.
|
| 313 |
+
A vehicle composed of two wheels held in a frame one behind the other, propelled by pedals and steered with handlebars attached to the front wheel.
|
| 314 |
+
A sign that says 'NeurIPS'.
|
| 315 |
+
A church with stained glass windows depicting a hamburger and french fries.
|
| 316 |
+
A shark in the desert.
|
| 317 |
+
An emoji of a baby panda wearing a red hat, blue gloves, green shirt, and blue pants.
|
| 318 |
+
A machine next to a parking space in a street, into which the driver puts money so as to be authorized to park the vehicle for a particular length of time.
|
| 319 |
+
Artophagous.
|
| 320 |
+
A car on the left of a bus.
|
| 321 |
+
A storefront with 'Google Brain Toronto' written on it.
|
| 322 |
+
A cube made of denim. A cube with the texture of denim.
|
| 323 |
+
A red colored banana.
|
| 324 |
+
Two dogs on the street.
|
| 325 |
+
Five cars on the street.
|
| 326 |
+
A mechanical or electrical device for measuring time.
|
| 327 |
+
Acersecomicke.
|
| 328 |
+
An illustration of a large red elephant sitting on a small blue mouse.
|
| 329 |
+
A triangular pink stop sign. A pink stop sign in the shape of a triangle.
|
| 330 |
+
Peristeronic.
|
| 331 |
+
A keyboard made of water, the water is made of light, the light is turned off.
|
| 332 |
+
Greek statue of a man tripping over a cat.
|
| 333 |
+
Two cats and three dogs sitting on the grass.
|
| 334 |
+
New York Skyline with 'Google Brain Toronto' written with fireworks on the sky.
|
| 335 |
+
Rbefraigerator.
|
| 336 |
+
A storefront with 'Google Research Pizza Cafe' written on it.
|
| 337 |
+
Four cars on the street.
|
| 338 |
+
An oil painting portrait of the regal Burger King posing with a Whopper.
|
| 339 |
+
A fluffy baby sloth with a knitted hat trying to figure out a laptop, close up, highly detailed, studio lighting, screen reflecting in its eyes.
|
| 340 |
+
An oil painting of a couple in formal evening wear going home get caught in a heavy downpour with no umbrellas.
|
| 341 |
+
Painting of the orange cat Otto von Garfield, Count of Bismarck-Schönhausen, Duke of Lauenburg, Minister-President of Prussia. Depicted wearing a Prussian Pickelhaube and eating his favorite meal - lasagna.
|
| 342 |
+
A grocery store refrigerator has pint cartons of milk on the top shelf, quart cartons on the middle shelf, and gallon plastic jugs on the bottom shelf.
|
| 343 |
+
A real life photography of super mario, 8k Ultra HD.
|
| 344 |
+
A carrot on the left of a broccoli.
|
| 345 |
+
Darth Vader playing with raccoon in Mars during sunset.
|
| 346 |
+
Four dogs on the street.
|
| 347 |
+
Photo of a cat singing in a barbershop quartet.
|
| 348 |
+
A real life photography of super mario, 8k Ultra HD.
|
| 349 |
+
A triangular pink stop sign. A pink stop sign in the shape of a triangle.
|
| 350 |
+
A small blue book sitting on a large red book.
|
| 351 |
+
A green colored banana.
|
| 352 |
+
A bicycle on top of a boat.
|
| 353 |
+
A blue cup and a green cell phone.
|
| 354 |
+
A cat on the right of a tennis racket.
|
| 355 |
+
A stop sign on the right of a refrigerator.
|
| 356 |
+
A sign that says 'Diffusion'.
|
| 357 |
+
A blue coloured pizza.
|
| 358 |
+
A device consisting of a circular canopy of cloth on a folding metal frame supported by a central rod, used as protection against rain or sometimes sun.
|
| 359 |
+
A green cup and a blue cell phone.
|
| 360 |
+
Three cats and two dogs sitting on the grass.
|
| 361 |
+
A laptop on top of a teddy bear.
|
| 362 |
+
A medieval painting of the wifi not working.
|
| 363 |
+
A small vessel propelled on water by oars, sails, or an engine.
|
| 364 |
+
Photo of a mega Lego space station inside a kid's bedroom.
|
| 365 |
+
A car on the left of a bus.
|
| 366 |
+
A green colored banana.
|
| 367 |
+
A photo of a confused grizzly bear in calculus class.
|
| 368 |
+
Three dogs on the street.
|
| 369 |
+
A medieval painting of the wifi not working.
|
| 370 |
+
One cat and three dogs sitting on the grass.
|
| 371 |
+
A red colored car.
|
| 372 |
+
Photo of a mega Lego space station inside a kid's bedroom.
|
| 373 |
+
Abraham Lincoln touches his toes while George Washington does chin-ups. Lincoln is barefoot. Washington is wearing boots.
|
| 374 |
+
Photo of a cat singing in a barbershop quartet.
|
| 375 |
+
A tennis racket underneath a traffic light.
|
| 376 |
+
Two cars on the street.
|
| 377 |
+
A sign that says 'Hello World'.
|
| 378 |
+
A church with stained glass windows depicting a hamburger and french fries.
|
| 379 |
+
A horse riding an astronaut.
|
| 380 |
+
A cross-section view of a brain.
|
| 381 |
+
A couple of glasses are sitting on a table.
|
| 382 |
+
A domesticated carnivorous mammal that typically has a long snout, an acute sense of smell, nonretractable claws, and a barking, howling, or whining voice.
|
| 383 |
+
A green cup and a blue cell phone.
|
| 384 |
+
Acersecomicke.
|
| 385 |
+
A giraffe underneath a microwave.
|
| 386 |
+
An elephant is behind a tree. You can see the trunk on one side and the back legs on the other.
|
| 387 |
+
A train on top of a surfboard.
|
| 388 |
+
A banana on the left of an apple.
|
| 389 |
+
A blue cup and a green cell phone.
|
| 390 |
+
A blue colored dog.
|
| 391 |
+
A sphere made of kitchen tile. A sphere with the texture of kitchen tile.
|
| 392 |
+
A couple of glasses are sitting on a table.
|
| 393 |
+
Matutinal.
|
| 394 |
+
An instrument used for cutting cloth, paper, and other thin material, consisting of two blades laid one on top of the other and fastened in the middle so as to allow them to be opened and closed by a thumb and finger inserted through rings on the end of their handles.
|
| 395 |
+
New York Skyline with 'Diffusion' written with fireworks on the sky.
|
| 396 |
+
A white car and a red sheep.
|
| 397 |
+
A sign that says 'NeurIPS'.
|
| 398 |
+
Five cars on the street.
|
| 399 |
+
A red colored dog.
|
| 400 |
+
New York Skyline with 'Text to Image' written with fireworks on the sky.
|
| 401 |
+
New York Skyline with 'Diffusion' written with fireworks on the sky.
|
| 402 |
+
Three cats and three dogs sitting on the grass.
|
| 403 |
+
A storefront with 'Deep Learning' written on it.
|
| 404 |
+
A hair drier underneath a sheep.
|
| 405 |
+
An instqrumemnt used for cutting cloth, paper, axdz othr thdin mteroial, consamistng of two blades lad one on tvopb of the other and fhastned in tle mixdqdjle so as to bllow them txo be pened and closed by thumb and fitngesr inserted tgrough rings on kthe end oc thei vatndlzes.
|
| 406 |
+
One dog on the street.
|
| 407 |
+
A fish eating a pelican.
|
| 408 |
+
A baby fennec sneezing onto a strawberry, detailed, macro, studio light, droplets, backlit ears.
|
| 409 |
+
A maglev train going vertically downward in high speed, New York Times photojournalism.
|
| 410 |
+
Supreme Court Justices play a baseball game with the FBI. The FBI is at bat, the justices are on the field.
|
| 411 |
+
A photo of a confused grizzly bear in calculus class.
|
| 412 |
+
A triangular pink stop sign. A pink stop sign in the shape of a triangle.
|
| 413 |
+
Matutinal.
|
| 414 |
+
Two cars on the street.
|
| 415 |
+
An orange colored sandwich.
|
| 416 |
+
A storefront with 'NeurIPS' written on it.
|
| 417 |
+
A grocery store refrigerator has pint cartons of milk on the top shelf, quart cartons on the middle shelf, and gallon plastic jugs on the bottom shelf.
|
| 418 |
+
A stack of 3 plates. A blue plate is on the top, sitting on a blue plate. The blue plate is in the middle, sitting on a green plate. The green plate is on the bottom.
|
| 419 |
+
In late afternoon in January in New England, a man stands in the shadow of a maple tree.
|
| 420 |
+
Hovering cow abducting aliens.
|
| 421 |
+
A triangular pink stop sign. A pink stop sign in the shape of a triangle.
|
| 422 |
+
A photocopy of a photograph of a painting of a sculpture of a giraffe.
|
| 423 |
+
A separate seat for one person, typically with a back and four legs.
|
| 424 |
+
A horse riding an astronaut.
|
| 425 |
+
Three cats and three dogs sitting on the grass.
|
| 426 |
+
A bird scaring a scarecrow.
|
| 427 |
+
Tcennis rpacket.
|
| 428 |
+
One car on the street.
|
| 429 |
+
A mechanical or electrical device for measuring time.
|
| 430 |
+
New York Skyline with 'NeurIPS' written with fireworks on the sky.
|
| 431 |
+
A fish eating a pelican.
|
| 432 |
+
A black apple and a green backpack.
|
| 433 |
+
A cube made of denim. A cube with the texture of denim.
|
| 434 |
+
A storefront with 'Deep Learning' written on it.
|
| 435 |
+
New York Skyline with 'Deep Learning' written with fireworks on the sky.
|
| 436 |
+
A brown colored giraffe.
|
| 437 |
+
A bird scaring a scarecrow.
|
| 438 |
+
A blue colored dog.
|
| 439 |
+
An emoji of a baby panda wearing a red hat, green gloves, red shirt, and green pants.
|
| 440 |
+
A green cup and a blue cell phone.
|
| 441 |
+
A carrot on the left of a broccoli.
|
| 442 |
+
A green apple and a black backpack.
|
| 443 |
+
A yellow book and a red vase.
|
| 444 |
+
A triangular purple flower pot. A purple flower pot in the shape of a triangle.
|
| 445 |
+
A small vessel propelled on water by oars, sails, or an engine.
|
| 446 |
+
An orange colored sandwich.
|
| 447 |
+
A tomato has been put on top of a pumpkin on a kitchen stool. There is a fork sticking into the pumpkin. The scene is viewed from above.
|
| 448 |
+
Rbefraigerator.
|
| 449 |
+
A machine next to a parking space in a street, into which the driver puts money so as to be authorized to park the vehicle for a particular length of time.
|
| 450 |
+
A hair drier underneath a sheep.
|
| 451 |
+
A grocery store refrigerator has pint cartons of milk on the top shelf, quart cartons on the middle shelf, and gallon plastic jugs on the bottom shelf.
|
| 452 |
+
A sign that says 'Deep Learning'.
|
| 453 |
+
A cross-section view of a brain.
|
| 454 |
+
A black colored car.
|
| 455 |
+
Two cars on the street.
|
| 456 |
+
Photo of an athlete cat explaining it's latest scandal at a press conference to journalists.
|
| 457 |
+
Rainbow coloured penguin.
|
| 458 |
+
A black apple and a green backpack.
|
| 459 |
+
Darth Vader playing with raccoon in Mars during sunset.
|
| 460 |
+
A spider with a moustache bidding an equally gentlemanly grasshopper a good day during his walk to work.
|
| 461 |
+
One cat and three dogs sitting on the grass.
|
| 462 |
+
35mm macro shot a kitten licking a baby duck, studio lighting.
|
| 463 |
+
An umbrella on top of a spoon.
|
| 464 |
+
Bzaseball galove.
|
| 465 |
+
Greek statue of a man tripping over a cat.
|
| 466 |
+
Supreme Court Justices play a baseball game with the FBI. The FBI is at bat, the justices are on the field.
|
| 467 |
+
An instqrumemnt used for cutting cloth, paper, axdz othr thdin mteroial, consamistng of two blades lad one on tvopb of the other and fhastned in tle mixdqdjle so as to bllow them txo be pened and closed by thumb and fitngesr inserted tgrough rings on kthe end oc thei vatndlzes.
|
| 468 |
+
A car on the left of a bus.
|
| 469 |
+
One dog on the street.
|
| 470 |
+
A church with stained glass windows depicting a hamburger and french fries.
|
| 471 |
+
A vehicle composed of two wheels held in a frame one behind the other, propelled by pedals and steered with handlebars attached to the front wheel.
|
| 472 |
+
A cross-section view of a brain.
|
| 473 |
+
A donut underneath a toilet.
|
| 474 |
+
A small blue book sitting on a large red book.
|
| 475 |
+
A smafml vessef epropoeilled on watvewr by ors, sauls, or han engie.
|
| 476 |
+
A sign that says 'Deep Learning'.
|
| 477 |
+
Photo of a cat singing in a barbershop quartet.
|
| 478 |
+
A cube made of brick. A cube with the texture of brick.
|
| 479 |
+
An oil painting of a couple in formal evening wear going home get caught in a heavy downpour with no umbrellas.
|
| 480 |
+
One car on the street.
|
| 481 |
+
A mechanical or electrical device for measuring time.
|
| 482 |
+
Hyper-realistic photo of an abandoned industrial site during a storm.
|
| 483 |
+
A giraffe underneath a microwave.
|
| 484 |
+
New York Skyline with 'Google Brain Toronto' written with fireworks on the sky.
|
| 485 |
+
An ancient Egyptian painting depicting an argument over whose turn it is to take out the trash.
|
| 486 |
+
A red book and a yellow vase.
|
| 487 |
+
A yellow colored giraffe.
|
| 488 |
+
A smafml vessef epropoeilled on watvewr by ors, sauls, or han engie.
|
| 489 |
+
A long curved fruit which grows in clusters and has soft pulpy flesh and yellow skin when ripe.
|
| 490 |
+
New York Skyline with 'Hello World' written with fireworks on the sky.
|
| 491 |
+
Two cats and two dogs sitting on the grass.
|
| 492 |
+
Photo of a cat singing in a barbershop quartet.
|
| 493 |
+
Colouring page of large cats climbing the eifel tower in a cyberpunk future.
|
| 494 |
+
Abraham Lincoln touches his toes while George Washington does chin-ups. Lincoln is barefoot. Washington is wearing boots.
|
| 495 |
+
A medieval painting of the wifi not working.
|
| 496 |
+
A car playing soccer, digital art.
|
| 497 |
+
A black colored car.
|
| 498 |
+
An orange colored sandwich.
|
| 499 |
+
A ldarge keybord msical instroument lwith a woden case enmclosig a qsouvnkboajrd and mfgtal strivgf, which are strucrk b hammrs when the nels are depresdsmed.f lhe strsingsj' vibration ie stopped by damperds when the keys re released and can bce regulavewdd for lengh and vnolume y two or three pedalvs.
|
| 500 |
+
An instrument used for cutting cloth, paper, and other thin material, consisting of two blades laid one on top of the other and fastened in the middle so as to allow them to be opened and closed by a thumb and finger inserted through rings on the end of their handles.
|
| 501 |
+
Four cars on the street.
|
| 502 |
+
A small domesticated carnivorous mammal with soft fur, a short snout, and retractable claws. It is widely kept as a pet or for catching mice, and many breeds have been developed.
|
| 503 |
+
A donkey and an octopus are playing a game. The donkey is holding a rope on one end, the octopus is holding onto the other. The donkey holds the rope in its mouth. A cat is jumping over the rope.
|
| 504 |
+
An illustration of a large red elephant sitting on a small blue mouse.
|
| 505 |
+
Octothorpe.
|
| 506 |
+
A fisheye lens view of a turtle sitting in a forest.
|
| 507 |
+
New York Skyline with 'Text to Image' written with fireworks on the sky.
|
| 508 |
+
A storefront with 'Deep Learning' written on it.
|
| 509 |
+
A spider with a moustache bidding an equally gentlemanly grasshopper a good day during his walk to work.
|
| 510 |
+
An oil painting of a couple in formal evening wear going home get caught in a heavy downpour with no umbrellas.
|
| 511 |
+
An emoji of a baby panda wearing a red hat, blue gloves, green shirt, and blue pants.
|
| 512 |
+
An IT-guy trying to fix hardware of a PC tower is being tangled by the PC cables like Laokoon. Marble, copy after Hellenistic original from ca. 200 BC. Found in the Baths of Trajan, 1506.
|
| 513 |
+
A sheep to the right of a wine glass.
|
| 514 |
+
A cube made of denim. A cube with the texture of denim.
|
| 515 |
+
Painting of the orange cat Otto von Garfield, Count of Bismarck-Schönhausen, Duke of Lauenburg, Minister-President of Prussia. Depicted wearing a Prussian Pickelhaube and eating his favorite meal - lasagna.
|
| 516 |
+
A sign that says 'Google Research Pizza Cafe'.
|
| 517 |
+
A ldarge keybord msical instroument lwith a woden case enmclosig a qsouvnkboajrd and mfgtal strivgf, which are strucrk b hammrs when the nels are depresdsmed.f lhe strsingsj' vibration ie stopped by damperds when the keys re released and can bce regulavewdd for lengh and vnolume y two or three pedalvs.
|
| 518 |
+
35mm macro shot a kitten licking a baby duck, studio lighting.
|
| 519 |
+
A shark in the desert.
|
| 520 |
+
A green colored banana.
|
| 521 |
+
A green cup and a blue cell phone.
|
| 522 |
+
Backlotter.
|
| 523 |
+
Darth Vader playing with raccoon in Mars during sunset.
|
| 524 |
+
A green apple and a black backpack.
|
| 525 |
+
A tomato has been put on top of a pumpkin on a kitchen stool. There is a fork sticking into the pumpkin. The scene is viewed from above.
|
| 526 |
+
A red colored dog.
|
| 527 |
+
A red book and a yellow vase.
|
| 528 |
+
Rbefraigerator.
|
| 529 |
+
A train on top of a surfboard.
|
| 530 |
+
Dininrg tablez.
|
| 531 |
+
A separate seat for one person, typically with a back and four legs.
|
| 532 |
+
A domesticated carnivvorous mzammal that typicbally hfaas a lons sfnout, an acxujte sense off osmell, noneetractaaln crlaws, anid xbarkring,y howlingu, or whining rvoiche.
|
| 533 |
+
A black colored dog.
|
| 534 |
+
A pink colored giraffe.
|
| 535 |
+
New York Skyline with 'Google Brain Toronto' written with fireworks on the sky.
|
| 536 |
+
Supreme Court Justices play a baseball game with the FBI. The FBI is at bat, the justices are on the field.
|
| 537 |
+
A white car and a red sheep.
|
| 538 |
+
An organ of soft nervous tissue contained in the skull of vertebrates, functioning as the coordinating center of sensation and intellectual and nervous activity.
|
| 539 |
+
Tcennis rpacket.
|
| 540 |
+
A red book and a yellow vase.
|
| 541 |
+
A cross-section view of a brain.
|
| 542 |
+
An illustration of a small green elephant standing behind a large red mouse.
|
| 543 |
+
One dog on the street.
|
| 544 |
+
A zebra underneath a broccoli.
|
| 545 |
+
A zebra to the right of a fire hydrant.
|
| 546 |
+
A large plant-eating domesticated mammal with solid hoofs and a flowing mane and tail, used for riding, racing, and to carry and pull loads.
|
| 547 |
+
An elephant under the sea.
|
| 548 |
+
An elephant under the sea.
|
| 549 |
+
A pizza on the right of a suitcase.
|
| 550 |
+
Greek statue of a man tripping over a cat.
|
| 551 |
+
A couple of glasses are sitting on a table.
|
| 552 |
+
A storefront with 'Diffusion' written on it.
|
| 553 |
+
A sheep to the right of a wine glass.
|
| 554 |
+
A fisheye lens view of a turtle sitting in a forest.
|
| 555 |
+
A vehicle composed of two wheels held in a frame one behind the other, propelled by pedals and steered with handlebars attached to the front wheel.
|
| 556 |
+
A 1960s poster warning against climate change.
|
| 557 |
+
Three cars on the street.
|
| 558 |
+
An umbrella on top of a spoon.
|
| 559 |
+
A zebra underneath a broccoli.
|
| 560 |
+
A black colored dog.
|
| 561 |
+
A small domesticated carnivorous mammal with soft fur, a short snout, and retractable claws. It is widely kept as a pet or for catching mice, and many breeds have been developed.
|
| 562 |
+
A sign that says 'Google Brain Toronto'.
|
| 563 |
+
A large plant-eating domesticated mammal with solid hoofs and a flowing mane and tail, used for riding, racing, and to carry and pull loads.
|
| 564 |
+
A sign that says 'NeurIPS'.
|
| 565 |
+
Pafrking metr.
|
| 566 |
+
A sign that says 'Text to Image'.
|
| 567 |
+
A screenshot of an iOS app for ordering different types of milk.
|
| 568 |
+
A large motor vehicle carrying passengers by road, typically one serving the public on a fixed route and for a fare.
|
| 569 |
+
One cat and two dogs sitting on the grass.
|
| 570 |
+
A cube made of brick. A cube with the texture of brick.
|
| 571 |
+
A storefront with 'Text to Image' written on it.
|
| 572 |
+
A screenshot of an iOS app for ordering different types of milk.
|
| 573 |
+
Two dogs on the street.
|
| 574 |
+
Dininrg tablez.
|
| 575 |
+
A baby fennec sneezing onto a strawberry, detailed, macro, studio light, droplets, backlit ears.
|
| 576 |
+
A cat on the left of a dog.
|
| 577 |
+
A machine resembling a human being and able to replicate certain human movements and functions automatically.
|
| 578 |
+
A panda making latte art.
|
| 579 |
+
A storefront with 'Hello World' written on it.
|
| 580 |
+
New York Skyline with 'Diffusion' written with fireworks on the sky.
|
| 581 |
+
Two cats and three dogs sitting on the grass.
|
| 582 |
+
McDonalds Church.
|
| 583 |
+
A cat on the left of a dog.
|
| 584 |
+
Octothorpe.
|
| 585 |
+
Painting of Mona Lisa but the view is from behind of Mona Lisa.
|
| 586 |
+
A smafml vessef epropoeilled on watvewr by ors, sauls, or han engie.
|
| 587 |
+
A maglev train going vertically downward in high speed, New York Times photojournalism.
|
| 588 |
+
Three dogs on the street.
|
| 589 |
+
A mechanical or electrical device for measuring time.
|
| 590 |
+
A pear cut into seven pieces arranged in a ring.
|
| 591 |
+
Lego Arnold Schwarzenegger.
|
| 592 |
+
An appliance or compartment which is artificially kept cool and used to store food and drink.
|
| 593 |
+
A black colored car.
|
| 594 |
+
An oil painting portrait of the regal Burger King posing with a Whopper.
|
| 595 |
+
A black colored banana.
|
| 596 |
+
Three cats and three dogs sitting on the grass.
|
| 597 |
+
A domesticated carnivorous mammal that typically has a long snout, an acute sense of smell, nonretractable claws, and a barking, howling, or whining voice.
|
| 598 |
+
A wine glass on top of a dog.
|
| 599 |
+
A tomato has been put on top of a pumpkin on a kitchen stool. There is a fork sticking into the pumpkin. The scene is viewed from above.
|
| 600 |
+
Backlotter.
|
| 601 |
+
A bird scaring a scarecrow.
|
| 602 |
+
A single clock is sitting on a table.
|
| 603 |
+
Bzaseball galove.
|
| 604 |
+
A yellow colored giraffe.
|
| 605 |
+
A white colored sandwich.
|
| 606 |
+
A giraffe underneath a microwave.
|
| 607 |
+
A couch on the left of a chair.
|
| 608 |
+
A pizza on the right of a suitcase.
|
| 609 |
+
Lego Arnold Schwarzenegger.
|
| 610 |
+
A donut underneath a toilet.
|
| 611 |
+
A triangular orange picture frame. An orange picture frame in the shape of a triangle.
|
| 612 |
+
McDonalds Church.
|
| 613 |
+
35mm macro shot a kitten licking a baby duck, studio lighting.
|
| 614 |
+
A machine resembling a human being and able to replicate certain human movements and functions automatically.
|
| 615 |
+
An elephant is behind a tree. You can see the trunk on one side and the back legs on the other.
|
| 616 |
+
A spider with a moustache bidding an equally gentlemanly grasshopper a good day during his walk to work.
|
| 617 |
+
An umbrella on top of a spoon.
|
| 618 |
+
Lego Arnold Schwarzenegger.
|
| 619 |
+
A yellow and black bus cruising through the rainforest.
|
| 620 |
+
A giraffe underneath a microwave.
|
| 621 |
+
A cube made of denim. A cube with the texture of denim.
|
| 622 |
+
A sheep to the right of a wine glass.
|
| 623 |
+
A bridge connecting Europe and North America on the Atlantic Ocean, bird's eye view.
|
| 624 |
+
A 1960s yearbook photo with animals dressed as humans.
|
| 625 |
+
Paying for a quarter-sized pizza with a pizza-sized quarter.
|
| 626 |
+
A black colored sandwich.
|
| 627 |
+
A large keyboard musical instrument with a wooden case enclosing a soundboard and metal strings, which are struck by hammers when the keys are depressed. The strings' vibration is stopped by dampers when the keys are released and can be regulated for length and volume by two or three pedals.
|
| 628 |
+
A spider with a moustache bidding an equally gentlemanly grasshopper a good day during his walk to work.
|
| 629 |
+
One car on the street.
|
| 630 |
+
A carrot on the left of a broccoli.
|
| 631 |
+
Two cats and three dogs sitting on the grass.
|
| 632 |
+
A stack of 3 books. A green book is on the top, sitting on a red book. The red book is in the middle, sitting on a blue book. The blue book is on the bottom.
|
| 633 |
+
Two cats and one dog sitting on the grass.
|
| 634 |
+
An instqrumemnt used for cutting cloth, paper, axdz othr thdin mteroial, consamistng of two blades lad one on tvopb of the other and fhastned in tle mixdqdjle so as to bllow them txo be pened and closed by thumb and fitngesr inserted tgrough rings on kthe end oc thei vatndlzes.
|
| 635 |
+
Dininrg tablez.
|
| 636 |
+
A connection point by which firefighters can tap into a water supply.
|
| 637 |
+
Four dogs on the street.
|
| 638 |
+
A sign that says 'Hello World'.
|
| 639 |
+
Photo of a mega Lego space station inside a kid's bedroom.
|
| 640 |
+
McDonalds Church.
|
| 641 |
+
Illustration of a mouse using a mushroom as an umbrella.
|
| 642 |
+
A magnifying glass over a page of a 1950s batman comic.
|
| 643 |
+
Hyper-realistic photo of an abandoned industrial site during a storm.
|
| 644 |
+
A magnifying glass over a page of a 1950s batman comic.
|
| 645 |
+
An umbrella on top of a spoon.
|
| 646 |
+
A fluffy baby sloth with a knitted hat trying to figure out a laptop, close up, highly detailed, studio lighting, screen reflecting in its eyes.
|
| 647 |
+
A large keyboard musical instrument with a wooden case enclosing a soundboard and metal strings, which are struck by hammers when the keys are depressed. The strings' vibration is stopped by dampers when the keys are released and can be regulated for length and volume by two or three pedals.
|
| 648 |
+
A red colored dog.
|
| 649 |
+
A red colored car.
|
| 650 |
+
A black colored car.
|
| 651 |
+
Five cars on the street.
|
| 652 |
+
A baby fennec sneezing onto a strawberry, detailed, macro, studio light, droplets, backlit ears.
|
| 653 |
+
In late afternoon in January in New England, a man stands in the shadow of a maple tree.
|
| 654 |
+
Photo of a cat singing in a barbershop quartet.
|
| 655 |
+
Hovering cow abducting aliens.
|
| 656 |
+
An old photograph of a 1920s airship shaped like a pig, floating over a wheat field.
|
| 657 |
+
An organ of soft nervous tissue contained in the skull of vertebrates, functioning as the coordinating center of sensation and intellectual and nervous activity.
|
| 658 |
+
A triangular purple flower pot. A purple flower pot in the shape of a triangle.
|
| 659 |
+
A pear cut into seven pieces arranged in a ring.
|
| 660 |
+
A red colored car.
|
| 661 |
+
Two cats and one dog sitting on the grass.
|
| 662 |
+
A cube made of brick. A cube with the texture of brick.
|
| 663 |
+
A pyramid made of falafel with a partial solar eclipse in the background.
|
| 664 |
+
A yellow colored giraffe.
|
| 665 |
+
An instrument used for cutting cloth, paper, and other thin material, consisting of two blades laid one on top of the other and fastened in the middle so as to allow them to be opened and closed by a thumb and finger inserted through rings on the end of their handles.
|
| 666 |
+
A sjmall domesticated carnivorious mammnal with sof fuh,y a sthort sout, and retracwtablbe flaws. It iw widexly kept as a pet or for catchitng mic, ad many breeds zhlyde beefn develvoked.
|
| 667 |
+
A couch on the left of a chair.
|
| 668 |
+
A photocopy of a photograph of a painting of a sculpture of a giraffe.
|
| 669 |
+
A sign that says 'Google Brain Toronto'.
|
| 670 |
+
A sign that says 'Text to Image'.
|
| 671 |
+
Rainbow coloured penguin.
|
| 672 |
+
Two dogs on the street.
|
| 673 |
+
A triangular orange picture frame. An orange picture frame in the shape of a triangle.
|
| 674 |
+
Colouring page of large cats climbing the eifel tower in a cyberpunk future.
|
| 675 |
+
A white colored sandwich.
|
| 676 |
+
A stop sign on the right of a refrigerator.
|
| 677 |
+
A long curved fruit which grows in clusters and has soft pulpy flesh and yellow skin when ripe.
|
| 678 |
+
Three cats and two dogs sitting on the grass.
|
| 679 |
+
A hair drier underneath a sheep.
|
| 680 |
+
A train on top of a surfboard.
|
| 681 |
+
A large plant-eating domesticated mammal with solid hoofs and a flowing mane and tail, used for riding, racing, and to carry and pull loads.
|
| 682 |
+
A sign that says 'Google Research Pizza Cafe'.
|
| 683 |
+
A stop sign on the right of a refrigerator.
|
| 684 |
+
A tiger in a lab coat with a 1980s Miami vibe, turning a well oiled science content machine, digital art.
|
| 685 |
+
A device consisting of a circular canopy of cloth on a folding metal frame supported by a central rod, used as protection against rain or sometimes sun.
|
| 686 |
+
An illustration of a small green elephant standing behind a large red mouse.
|
| 687 |
+
A sign that says 'Hello World'.
|
| 688 |
+
Lego Arnold Schwarzenegger.
|
| 689 |
+
Five dogs on the street.
|
| 690 |
+
A storefront with 'Hello World' written on it.
|
| 691 |
+
An instrument used for cutting cloth, paper, and other thin material, consisting of two blades laid one on top of the other and fastened in the middle so as to allow them to be opened and closed by a thumb and finger inserted through rings on the end of their handles.
|
| 692 |
+
A tomato has been put on top of a pumpkin on a kitchen stool. There is a fork sticking into the pumpkin. The scene is viewed from above.
|
| 693 |
+
In late afternoon in January in New England, a man stands in the shadow of a maple tree.
|
| 694 |
+
Jentacular.
|
| 695 |
+
Four dogs on the street.
|
| 696 |
+
An old photograph of a 1920s airship shaped like a pig, floating over a wheat field.
|
| 697 |
+
A triangular purple flower pot. A purple flower pot in the shape of a triangle.
|
| 698 |
+
A black colored banana.
|
| 699 |
+
35mm macro shot a kitten licking a baby duck, studio lighting.
|
| 700 |
+
Bzaseball galove.
|
| 701 |
+
A fisheye lens view of a turtle sitting in a forest.
|
| 702 |
+
A donut underneath a toilet.
|
| 703 |
+
A donkey and an octopus are playing a game. The donkey is holding a rope on one end, the octopus is holding onto the other. The donkey holds the rope in its mouth. A cat is jumping over the rope.
|
| 704 |
+
A fluffy baby sloth with a knitted hat trying to figure out a laptop, close up, highly detailed, studio lighting, screen reflecting in its eyes.
|
| 705 |
+
A green apple and a black backpack.
|
| 706 |
+
An illustration of a large red elephant sitting on a small blue mouse.
|
| 707 |
+
A smafml vessef epropoeilled on watvewr by ors, sauls, or han engie.
|
| 708 |
+
A machine next to a parking space in a street, into which the driver puts money so as to be authorized to park the vehicle for a particular length of time.
|
| 709 |
+
Rainbow coloured penguin.
|
| 710 |
+
Three cats and one dog sitting on the grass.
|
| 711 |
+
An old photograph of a 1920s airship shaped like a pig, floating over a wheat field.
|
| 712 |
+
New York Skyline with 'Google Research Pizza Cafe' written with fireworks on the sky.
|
| 713 |
+
A painting by Grant Wood of an astronaut couple, american gothic style.
|
| 714 |
+
Four dogs on the street.
|
| 715 |
+
A large motor vehicle carrying passengers by road, typically one serving the public on a fixed route and for a fare.
|
| 716 |
+
A couple of glasses are sitting on a table.
|
| 717 |
+
In late afternoon in January in New England, a man stands in the shadow of a maple tree.
|
| 718 |
+
A brown colored giraffe.
|
| 719 |
+
A bridge connecting Europe and North America on the Atlantic Ocean, bird's eye view.
|
| 720 |
+
Five dogs on the street.
|
| 721 |
+
New York Skyline with 'Text to Image' written with fireworks on the sky.
|
| 722 |
+
An appliance or compartment which is artificially kept cool and used to store food and drink.
|
| 723 |
+
A real life photography of super mario, 8k Ultra HD.
|
| 724 |
+
A pink colored car.
|
| 725 |
+
A painting by Grant Wood of an astronaut couple, american gothic style.
|
| 726 |
+
A car on the left of a bus.
|
| 727 |
+
A large plant-eating domesticated mammal with solid hoofs and a flowing mane and tail, used for riding, racing, and to carry and pull loads.
|
| 728 |
+
Pafrking metr.
|
| 729 |
+
An illustration of a small green elephant standing behind a large red mouse.
|
| 730 |
+
A blue cup and a green cell phone.
|
| 731 |
+
New York Skyline with 'NeurIPS' written with fireworks on the sky.
|
| 732 |
+
A storefront with 'Google Brain Toronto' written on it.
|
| 733 |
+
A painting by Grant Wood of an astronaut couple, american gothic style.
|
| 734 |
+
A black colored sandwich.
|
| 735 |
+
A fish eating a pelican.
|
| 736 |
+
An emoji of a baby panda wearing a red hat, blue gloves, green shirt, and blue pants.
|
| 737 |
+
A vehicle composed of two wheels held in a frame one behind the other, propelled by pedals and steered with handlebars attached to the front wheel.
|
| 738 |
+
A tennis racket underneath a traffic light.
|
| 739 |
+
Three cars on the street.
|
| 740 |
+
One car on the street.
|
| 741 |
+
A tennis racket underneath a traffic light.
|
| 742 |
+
A maglev train going vertically downward in high speed, New York Times photojournalism.
|
| 743 |
+
Photo of an athlete cat explaining it's latest scandal at a press conference to journalists.
|
| 744 |
+
A red book and a yellow vase.
|
| 745 |
+
A shark in the desert.
|
| 746 |
+
An organ of soft nervous tissue contained in the skull of vertebrates, functioning as the coordinating center of sensation and intellectual and nervous activity.
|
| 747 |
+
A sign that says 'Text to Image'.
|
| 748 |
+
A stack of 3 books. A green book is on the top, sitting on a red book. The red book is in the middle, sitting on a blue book. The blue book is on the bottom.
|
| 749 |
+
A shark in the desert.
|
| 750 |
+
A 1960s poster warning against climate change.
|
| 751 |
+
Backlotter.
|
| 752 |
+
One cat and two dogs sitting on the grass.
|
| 753 |
+
Matutinal.
|
| 754 |
+
A cat on the right of a tennis racket.
|
| 755 |
+
A laptop on top of a teddy bear.
|
| 756 |
+
A white colored sandwich.
|
| 757 |
+
A yellow and black bus cruising through the rainforest.
|
| 758 |
+
A photocopy of a photograph of a painting of a sculpture of a giraffe.
|
| 759 |
+
A side view of an owl sitting in a field.
|
| 760 |
+
A pizza on the right of a suitcase.
|
| 761 |
+
A wine glass on top of a dog.
|
| 762 |
+
A realistic photo of a Pomeranian dressed up like a 1980s professional wrestler with neon green and neon orange face paint and bright green wrestling tights with bright orange boots.
|
| 763 |
+
A pear cut into seven pieces arranged in a ring.
|
| 764 |
+
Acersecomicke.
|
| 765 |
+
Painting of Mona Lisa but the view is from behind of Mona Lisa.
|
| 766 |
+
A small vessel propelled on water by oars, sails, or an engine.
|
| 767 |
+
Painting of the orange cat Otto von Garfield, Count of Bismarck-Schönhausen, Duke of Lauenburg, Minister-President of Prussia. Depicted wearing a Prussian Pickelhaube and eating his favorite meal - lasagna.
|
| 768 |
+
A cat on the left of a dog.
|
| 769 |
+
A red colored banana.
|
| 770 |
+
A domesticated carnivorous mammal that typically has a long snout, an acute sense of smell, nonretractable claws, and a barking, howling, or whining voice.
|
| 771 |
+
A sign that says 'Google Brain Toronto'.
|
| 772 |
+
A collection of nail is sitting on a table.
|
| 773 |
+
A pyramid made of falafel with a partial solar eclipse in the background.
|
| 774 |
+
A realistic photo of a Pomeranian dressed up like a 1980s professional wrestler with neon green and neon orange face paint and bright green wrestling tights with bright orange boots.
|
| 775 |
+
A cube made of brick. A cube with the texture of brick.
|
| 776 |
+
New York Skyline with 'Text to Image' written with fireworks on the sky.
|
| 777 |
+
A fish eating a pelican.
|
| 778 |
+
A pink colored giraffe.
|
| 779 |
+
One cat and three dogs sitting on the grass.
|
| 780 |
+
A keyboard made of water, the water is made of light, the light is turned off.
|
| 781 |
+
Greek statue of a man tripping over a cat.
|
| 782 |
+
A machine resembling a human being and able to replicate certain human movements and functions automatically.
|
| 783 |
+
A yellow and black bus cruising through the rainforest.
|
| 784 |
+
An elephant is behind a tree. You can see the trunk on one side and the back legs on the other.
|
| 785 |
+
A small domesticated carnivorous mammal with soft fur, a short snout, and retractable claws. It is widely kept as a pet or for catching mice, and many breeds have been developed.
|
| 786 |
+
Dininrg tablez.
|
| 787 |
+
A sign that says 'NeurIPS'.
|
| 788 |
+
An illustration of a small green elephant standing behind a large red mouse.
|
| 789 |
+
A collection of nail is sitting on a table.
|
| 790 |
+
An oil painting of a couple in formal evening wear going home get caught in a heavy downpour with no umbrellas.
|
| 791 |
+
New York Skyline with 'Hello World' written with fireworks on the sky.
|
| 792 |
+
A storefront with 'Text to Image' written on it.
|
| 793 |
+
A storefront with 'Deep Learning' written on it.
|
| 794 |
+
Three cats and two dogs sitting on the grass.
|
| 795 |
+
A red car and a white sheep.
|
| 796 |
+
A domesticated carnivvorous mzammal that typicbally hfaas a lons sfnout, an acxujte sense off osmell, noneetractaaln crlaws, anid xbarkring,y howlingu, or whining rvoiche.
|
| 797 |
+
A mechanical or electrical device for measuring time.
|
| 798 |
+
A bridge connecting Europe and North America on the Atlantic Ocean, bird's eye view.
|
| 799 |
+
An appliance or compartment which is artificially kept cool and used to store food and drink.
|
| 800 |
+
A pizza cooking an oven.
|
| 801 |
+
A car playing soccer, digital art.
|
| 802 |
+
A blue coloured pizza.
|
| 803 |
+
A machine next to a parking space in a street, into which the driver puts money so as to be authorized to park the vehicle for a particular length of time.
|
| 804 |
+
Octothorpe.
|
| 805 |
+
A yellow book and a red vase.
|
| 806 |
+
A bicycle on top of a boat.
|
| 807 |
+
A device consisting of a circular canopy of cloth on a folding metal frame supported by a central rod, used as protection against rain or sometimes sun.
|
| 808 |
+
An orange colored sandwich.
|
| 809 |
+
Acersecomicke.
|
| 810 |
+
A magnifying glass over a page of a 1950s batman comic.
|
| 811 |
+
A black apple and a green backpack.
|
| 812 |
+
A bird scaring a scarecrow.
|
| 813 |
+
A sign that says 'Deep Learning'.
|
| 814 |
+
A bicycle on top of a boat.
|
| 815 |
+
Painting of Mona Lisa but the view is from behind of Mona Lisa.
|
| 816 |
+
Three dogs on the street.
|
| 817 |
+
A stack of 3 plates. A blue plate is on the top, sitting on a blue plate. The blue plate is in the middle, sitting on a green plate. The green plate is on the bottom.
|
| 818 |
+
A red car and a white sheep.
|
| 819 |
+
Greek statue of a man tripping over a cat.
|
| 820 |
+
Three dogs on the street.
|
| 821 |
+
A sheep to the right of a wine glass.
|
| 822 |
+
One cat and one dog sitting on the grass.
|
| 823 |
+
A black colored sandwich.
|
| 824 |
+
Peristeronic.
|
| 825 |
+
Three cats and two dogs sitting on the grass.
|
| 826 |
+
A 1960s yearbook photo with animals dressed as humans.
|
| 827 |
+
A sign that says 'Diffusion'.
|
| 828 |
+
A sign that says 'Google Research Pizza Cafe'.
|
| 829 |
+
A blue bird and a brown bear.
|
| 830 |
+
A yellow and black bus cruising through the rainforest.
|
| 831 |
+
A smafml vessef epropoeilled on watvewr by ors, sauls, or han engie.
|
| 832 |
+
Bzaseball galove.
|
| 833 |
+
Artophagous.
|
| 834 |
+
A sign that says 'Text to Image'.
|
| 835 |
+
A baby fennec sneezing onto a strawberry, detailed, macro, studio light, droplets, backlit ears.
|
| 836 |
+
A fisheye lens view of a turtle sitting in a forest.
|
| 837 |
+
A storefront with 'Hello World' written on it.
|
| 838 |
+
A connection point by which firefighters can tap into a water supply.
|
| 839 |
+
A separate seat for one person, typically with a back and four legs.
|
| 840 |
+
A 1960s yearbook photo with animals dressed as humans.
|
| 841 |
+
A sphere made of kitchen tile. A sphere with the texture of kitchen tile.
|
| 842 |
+
A black colored banana.
|
| 843 |
+
A vehicle composed of two wheels held in a frame one behind the other, propelled by pedals and steered with handlebars attached to the front wheel.
|
| 844 |
+
Four cars on the street.
|
| 845 |
+
Three cats and three dogs sitting on the grass.
|
| 846 |
+
Five dogs on the street.
|
| 847 |
+
An ancient Egyptian painting depicting an argument over whose turn it is to take out the trash.
|
| 848 |
+
A storefront with 'Diffusion' written on it.
|
| 849 |
+
A pizza cooking an oven.
|
| 850 |
+
Darth Vader playing with raccoon in Mars during sunset.
|
| 851 |
+
A carrot on the left of a broccoli.
|
| 852 |
+
A tomato has been put on top of a pumpkin on a kitchen stool. There is a fork sticking into the pumpkin. The scene is viewed from above.
|
| 853 |
+
A storefront with 'Diffusion' written on it.
|
| 854 |
+
A red book and a yellow vase.
|
| 855 |
+
Peristeronic.
|
| 856 |
+
An organ of soft nervous tissue contained in the skull of vertebrates, functioning as the coordinating center of sensation and intellectual and nervous activity.
|
| 857 |
+
A donkey and an octopus are playing a game. The donkey is holding a rope on one end, the octopus is holding onto the other. The donkey holds the rope in its mouth. A cat is jumping over the rope.
|
| 858 |
+
A couch on the left of a chair.
|
| 859 |
+
A sphere made of kitchen tile. A sphere with the texture of kitchen tile.
|
| 860 |
+
A white car and a red sheep.
|
| 861 |
+
Artophagous.
|
| 862 |
+
A stack of 3 books. A green book is on the top, sitting on a red book. The red book is in the middle, sitting on a blue book. The blue book is on the bottom.
|
| 863 |
+
A pizza cooking an oven.
|
| 864 |
+
A triangular purple flower pot. A purple flower pot in the shape of a triangle.
|
| 865 |
+
A brown bird and a blue bear.
|
| 866 |
+
An IT-guy trying to fix hardware of a PC tower is being tangled by the PC cables like Laokoon. Marble, copy after Hellenistic original from ca. 200 BC. Found in the Baths of Trajan, 1506.
|
| 867 |
+
A storefront with 'Google Research Pizza Cafe' written on it.
|
| 868 |
+
A storefront with 'Google Research Pizza Cafe' written on it.
|
| 869 |
+
A large thick-skinned semiaquatic African mammal, with massive jaws and large tusks.
|
| 870 |
+
An appliance or compartment which is artificially kept cool and used to store food and drink.
|
| 871 |
+
A donut underneath a toilet.
|
| 872 |
+
A blue bird and a brown bear.
|
| 873 |
+
A 1960s poster warning against climate change.
|
| 874 |
+
A white colored sandwich.
|
| 875 |
+
A white colored sandwich.
|
| 876 |
+
A stop sign on the right of a refrigerator.
|
| 877 |
+
A storefront with 'Hello World' written on it.
|
| 878 |
+
Five dogs on the street.
|
| 879 |
+
Three cars on the street.
|
| 880 |
+
A keyboard made of water, the water is made of light, the light is turned off.
|
| 881 |
+
A red colored dog.
|
| 882 |
+
Two cats and three dogs sitting on the grass.
|
| 883 |
+
A spider with a moustache bidding an equally gentlemanly grasshopper a good day during his walk to work.
|
| 884 |
+
A pink colored car.
|
| 885 |
+
A tiger in a lab coat with a 1980s Miami vibe, turning a well oiled science content machine, digital art.
|
| 886 |
+
Photo of an athlete cat explaining it's latest scandal at a press conference to journalists.
|
| 887 |
+
A realistic photo of a Pomeranian dressed up like a 1980s professional wrestler with neon green and neon orange face paint and bright green wrestling tights with bright orange boots.
|
| 888 |
+
A type of digital currency in which a record of transactions is maintained and new units of currency are generated by the computational solution of mathematical problems, and which operates independently of a central bank.
|
| 889 |
+
A sign that says 'Hello World'.
|
| 890 |
+
An ancient Egyptian painting depicting an argument over whose turn it is to take out the trash.
|
| 891 |
+
A white car and a red sheep.
|
| 892 |
+
Illustration of a mouse using a mushroom as an umbrella.
|
| 893 |
+
A red colored banana.
|
| 894 |
+
Three cats and one dog sitting on the grass.
|
| 895 |
+
A car playing soccer, digital art.
|
| 896 |
+
A sjmall domesticated carnivorious mammnal with sof fuh,y a sthort sout, and retracwtablbe flaws. It iw widexly kept as a pet or for catchitng mic, ad many breeds zhlyde beefn develvoked.
|
| 897 |
+
Rbefraigerator.
|
| 898 |
+
A triangular orange picture frame. An orange picture frame in the shape of a triangle.
|
| 899 |
+
Rainbow coloured penguin.
|
| 900 |
+
A storefront with 'Text to Image' written on it.
|
| 901 |
+
A cat on the right of a tennis racket.
|
| 902 |
+
A small blue book sitting on a large red book.
|
| 903 |
+
Two cats and one dog sitting on the grass.
|
| 904 |
+
An emoji of a baby panda wearing a red hat, green gloves, red shirt, and green pants.
|
| 905 |
+
A brown bird and a blue bear.
|
| 906 |
+
A red car and a white sheep.
|
| 907 |
+
A pizza on the right of a suitcase.
|
| 908 |
+
A small blue book sitting on a large red book.
|
| 909 |
+
A horse riding an astronaut.
|
| 910 |
+
A sign that says 'Google Brain Toronto'.
|
| 911 |
+
Hyper-realistic photo of an abandoned industrial site during a storm.
|
| 912 |
+
A side view of an owl sitting in a field.
|
| 913 |
+
A photo of a confused grizzly bear in calculus class.
|
| 914 |
+
An American multinational technology company that focuses on artificial intelligence, search engine, online advertising, cloud computing, computer software, quantum computing, e-commerce, and consumer electronics.
|
| 915 |
+
A storefront with 'NeurIPS' written on it.
|
| 916 |
+
A storefront with 'NeurIPS' written on it.
|
| 917 |
+
Two cats and one dog sitting on the grass.
|
| 918 |
+
New York Skyline with 'Diffusion' written with fireworks on the sky.
|
| 919 |
+
A storefront with 'Diffusion' written on it.
|
| 920 |
+
A blue coloured pizza.
|
| 921 |
+
A single clock is sitting on a table.
|
| 922 |
+
A zebra to the right of a fire hydrant.
|
| 923 |
+
Backlotter.
|
| 924 |
+
An ancient Egyptian painting depicting an argument over whose turn it is to take out the trash.
|
| 925 |
+
Two cats and two dogs sitting on the grass.
|
| 926 |
+
Painting of Mona Lisa but the view is from behind of Mona Lisa.
|
| 927 |
+
A triangular orange picture frame. An orange picture frame in the shape of a triangle.
|
| 928 |
+
A bird scaring a scarecrow.
|
| 929 |
+
A keyboard made of water, the water is made of light, the light is turned off.
|
| 930 |
+
A tennis racket underneath a traffic light.
|
| 931 |
+
A banana on the left of an apple.
|
| 932 |
+
A screenshot of an iOS app for ordering different types of milk.
|
| 933 |
+
A long curved fruit which grows in clusters and has soft pulpy flesh and yellow skin when ripe.
|
| 934 |
+
A side view of an owl sitting in a field.
|
| 935 |
+
Two cats and two dogs sitting on the grass.
|
| 936 |
+
Hovering cow abducting aliens.
|
| 937 |
+
A red car and a white sheep.
|
| 938 |
+
A zebra underneath a broccoli.
|
| 939 |
+
Rainbow coloured penguin.
|
| 940 |
+
A storefront with 'Deep Learning' written on it.
|
| 941 |
+
Three cars on the street.
|
| 942 |
+
A red colored banana.
|
| 943 |
+
A blue bird and a brown bear.
|
| 944 |
+
New York Skyline with 'NeurIPS' written with fireworks on the sky.
|
| 945 |
+
A sjmall domesticated carnivorious mammnal with sof fuh,y a sthort sout, and retracwtablbe flaws. It iw widexly kept as a pet or for catchitng mic, ad many breeds zhlyde beefn develvoked.
|
| 946 |
+
A giraffe underneath a microwave.
|
| 947 |
+
A brown colored giraffe.
|
| 948 |
+
An instqrumemnt used for cutting cloth, paper, axdz othr thdin mteroial, consamistng of two blades lad one on tvopb of the other and fhastned in tle mixdqdjle so as to bllow them txo be pened and closed by thumb and fitngesr inserted tgrough rings on kthe end oc thei vatndlzes.
|
| 949 |
+
A pizza cooking an oven.
|
| 950 |
+
A bicycle on top of a boat.
|
| 951 |
+
A screenshot of an iOS app for ordering different types of milk.
|
| 952 |
+
A car playing soccer, digital art.
|
| 953 |
+
A banana on the left of an apple.
|
| 954 |
+
A cube made of brick. A cube with the texture of brick.
|
| 955 |
+
A sheep to the right of a wine glass.
|
| 956 |
+
A type of digital currency in which a record of transactions is maintained and new units of currency are generated by the computational solution of mathematical problems, and which operates independently of a central bank.
|
| 957 |
+
A medieval painting of the wifi not working.
|
| 958 |
+
A brown bird and a blue bear.
|
| 959 |
+
A yellow and black bus cruising through the rainforest.
|
| 960 |
+
A bridge connecting Europe and North America on the Atlantic Ocean, bird's eye view.
|
| 961 |
+
Hyper-realistic photo of an abandoned industrial site during a storm.
|
| 962 |
+
Photo of an athlete cat explaining it's latest scandal at a press conference to journalists.
|
| 963 |
+
A stack of 3 cubes. A red cube is on the top, sitting on a red cube. The red cube is in the middle, sitting on a green cube. The green cube is on the bottom.
|
| 964 |
+
A yellow book and a red vase.
|
| 965 |
+
A wine glass on top of a dog.
|
| 966 |
+
A sign that says 'Deep Learning'.
|
| 967 |
+
A small domesticated carnivorous mammal with soft fur, a short snout, and retractable claws. It is widely kept as a pet or for catching mice, and many breeds have been developed.
|
| 968 |
+
Jentacular.
|
| 969 |
+
A car on the left of a bus.
|
| 970 |
+
A machine resembling a human being and able to replicate certain human movements and functions automatically.
|
| 971 |
+
New York Skyline with 'Google Research Pizza Cafe' written with fireworks on the sky.
|
| 972 |
+
Photo of a mega Lego space station inside a kid's bedroom.
|
| 973 |
+
Peristeronic.
|
| 974 |
+
One cat and one dog sitting on the grass.
|
| 975 |
+
A horse riding an astronaut.
|
| 976 |
+
New York Skyline with 'Deep Learning' written with fireworks on the sky.
|
| 977 |
+
A zebra underneath a broccoli.
|
| 978 |
+
A machine resembling a human being and able to replicate certain human movements and functions automatically.
|
| 979 |
+
A red colored dog.
|
| 980 |
+
Acersecomicke.
|
| 981 |
+
One dog on the street.
|
| 982 |
+
A white car and a red sheep.
|
| 983 |
+
New York Skyline with 'NeurIPS' written with fireworks on the sky.
|
| 984 |
+
A single clock is sitting on a table.
|
| 985 |
+
A zebra to the right of a fire hydrant.
|
| 986 |
+
A triangular orange picture frame. An orange picture frame in the shape of a triangle.
|
| 987 |
+
A blue colored dog.
|
| 988 |
+
McDonalds Church.
|
| 989 |
+
Tcennis rpacket.
|
| 990 |
+
A brown colored giraffe.
|
| 991 |
+
Hyper-realistic photo of an abandoned industrial site during a storm.
|
| 992 |
+
Tcennis rpacket.
|
| 993 |
+
A church with stained glass windows depicting a hamburger and french fries.
|
| 994 |
+
A bicycle on top of a boat.
|
| 995 |
+
A banana on the left of an apple.
|
| 996 |
+
A connection point by which firefighters can tap into a water supply.
|
| 997 |
+
New York Skyline with 'Deep Learning' written with fireworks on the sky.
|
| 998 |
+
An instqrumemnt used for cutting cloth, paper, axdz othr thdin mteroial, consamistng of two blades lad one on tvopb of the other and fhastned in tle mixdqdjle so as to bllow them txo be pened and closed by thumb and fitngesr inserted tgrough rings on kthe end oc thei vatndlzes.
|
| 999 |
+
New York Skyline with 'NeurIPS' written with fireworks on the sky.
|
| 1000 |
+
Paying for a quarter-sized pizza with a pizza-sized quarter.
|
prompts/evaluation_metadata.jsonl
ADDED
|
@@ -0,0 +1,553 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"tag": "single_object", "include": [{"class": "bench", "count": 1}], "prompt": "a photo of a bench"}
|
| 2 |
+
{"tag": "single_object", "include": [{"class": "cow", "count": 1}], "prompt": "a photo of a cow"}
|
| 3 |
+
{"tag": "single_object", "include": [{"class": "bicycle", "count": 1}], "prompt": "a photo of a bicycle"}
|
| 4 |
+
{"tag": "single_object", "include": [{"class": "clock", "count": 1}], "prompt": "a photo of a clock"}
|
| 5 |
+
{"tag": "single_object", "include": [{"class": "carrot", "count": 1}], "prompt": "a photo of a carrot"}
|
| 6 |
+
{"tag": "single_object", "include": [{"class": "suitcase", "count": 1}], "prompt": "a photo of a suitcase"}
|
| 7 |
+
{"tag": "single_object", "include": [{"class": "fork", "count": 1}], "prompt": "a photo of a fork"}
|
| 8 |
+
{"tag": "single_object", "include": [{"class": "surfboard", "count": 1}], "prompt": "a photo of a surfboard"}
|
| 9 |
+
{"tag": "single_object", "include": [{"class": "refrigerator", "count": 1}], "prompt": "a photo of a refrigerator"}
|
| 10 |
+
{"tag": "single_object", "include": [{"class": "cup", "count": 1}], "prompt": "a photo of a cup"}
|
| 11 |
+
{"tag": "single_object", "include": [{"class": "microwave", "count": 1}], "prompt": "a photo of a microwave"}
|
| 12 |
+
{"tag": "single_object", "include": [{"class": "potted plant", "count": 1}], "prompt": "a photo of a potted plant"}
|
| 13 |
+
{"tag": "single_object", "include": [{"class": "snowboard", "count": 1}], "prompt": "a photo of a snowboard"}
|
| 14 |
+
{"tag": "single_object", "include": [{"class": "zebra", "count": 1}], "prompt": "a photo of a zebra"}
|
| 15 |
+
{"tag": "single_object", "include": [{"class": "parking meter", "count": 1}], "prompt": "a photo of a parking meter"}
|
| 16 |
+
{"tag": "single_object", "include": [{"class": "spoon", "count": 1}], "prompt": "a photo of a spoon"}
|
| 17 |
+
{"tag": "single_object", "include": [{"class": "skateboard", "count": 1}], "prompt": "a photo of a skateboard"}
|
| 18 |
+
{"tag": "single_object", "include": [{"class": "car", "count": 1}], "prompt": "a photo of a car"}
|
| 19 |
+
{"tag": "single_object", "include": [{"class": "motorcycle", "count": 1}], "prompt": "a photo of a motorcycle"}
|
| 20 |
+
{"tag": "single_object", "include": [{"class": "traffic light", "count": 1}], "prompt": "a photo of a traffic light"}
|
| 21 |
+
{"tag": "single_object", "include": [{"class": "book", "count": 1}], "prompt": "a photo of a book"}
|
| 22 |
+
{"tag": "single_object", "include": [{"class": "couch", "count": 1}], "prompt": "a photo of a couch"}
|
| 23 |
+
{"tag": "single_object", "include": [{"class": "backpack", "count": 1}], "prompt": "a photo of a backpack"}
|
| 24 |
+
{"tag": "single_object", "include": [{"class": "computer keyboard", "count": 1}], "prompt": "a photo of a computer keyboard"}
|
| 25 |
+
{"tag": "single_object", "include": [{"class": "toaster", "count": 1}], "prompt": "a photo of a toaster"}
|
| 26 |
+
{"tag": "single_object", "include": [{"class": "bird", "count": 1}], "prompt": "a photo of a bird"}
|
| 27 |
+
{"tag": "single_object", "include": [{"class": "bowl", "count": 1}], "prompt": "a photo of a bowl"}
|
| 28 |
+
{"tag": "single_object", "include": [{"class": "dog", "count": 1}], "prompt": "a photo of a dog"}
|
| 29 |
+
{"tag": "single_object", "include": [{"class": "tie", "count": 1}], "prompt": "a photo of a tie"}
|
| 30 |
+
{"tag": "single_object", "include": [{"class": "laptop", "count": 1}], "prompt": "a photo of a laptop"}
|
| 31 |
+
{"tag": "single_object", "include": [{"class": "computer mouse", "count": 1}], "prompt": "a photo of a computer mouse"}
|
| 32 |
+
{"tag": "single_object", "include": [{"class": "sandwich", "count": 1}], "prompt": "a photo of a sandwich"}
|
| 33 |
+
{"tag": "single_object", "include": [{"class": "baseball bat", "count": 1}], "prompt": "a photo of a baseball bat"}
|
| 34 |
+
{"tag": "single_object", "include": [{"class": "train", "count": 1}], "prompt": "a photo of a train"}
|
| 35 |
+
{"tag": "single_object", "include": [{"class": "cell phone", "count": 1}], "prompt": "a photo of a cell phone"}
|
| 36 |
+
{"tag": "single_object", "include": [{"class": "chair", "count": 1}], "prompt": "a photo of a chair"}
|
| 37 |
+
{"tag": "single_object", "include": [{"class": "tv", "count": 1}], "prompt": "a photo of a tv"}
|
| 38 |
+
{"tag": "single_object", "include": [{"class": "broccoli", "count": 1}], "prompt": "a photo of a broccoli"}
|
| 39 |
+
{"tag": "single_object", "include": [{"class": "bed", "count": 1}], "prompt": "a photo of a bed"}
|
| 40 |
+
{"tag": "single_object", "include": [{"class": "skis", "count": 1}], "prompt": "a photo of a skis"}
|
| 41 |
+
{"tag": "single_object", "include": [{"class": "handbag", "count": 1}], "prompt": "a photo of a handbag"}
|
| 42 |
+
{"tag": "single_object", "include": [{"class": "pizza", "count": 1}], "prompt": "a photo of a pizza"}
|
| 43 |
+
{"tag": "single_object", "include": [{"class": "frisbee", "count": 1}], "prompt": "a photo of a frisbee"}
|
| 44 |
+
{"tag": "single_object", "include": [{"class": "scissors", "count": 1}], "prompt": "a photo of a scissors"}
|
| 45 |
+
{"tag": "single_object", "include": [{"class": "bottle", "count": 1}], "prompt": "a photo of a bottle"}
|
| 46 |
+
{"tag": "single_object", "include": [{"class": "elephant", "count": 1}], "prompt": "a photo of an elephant"}
|
| 47 |
+
{"tag": "single_object", "include": [{"class": "toilet", "count": 1}], "prompt": "a photo of a toilet"}
|
| 48 |
+
{"tag": "single_object", "include": [{"class": "oven", "count": 1}], "prompt": "a photo of an oven"}
|
| 49 |
+
{"tag": "single_object", "include": [{"class": "orange", "count": 1}], "prompt": "a photo of an orange"}
|
| 50 |
+
{"tag": "single_object", "include": [{"class": "person", "count": 1}], "prompt": "a photo of a person"}
|
| 51 |
+
{"tag": "single_object", "include": [{"class": "teddy bear", "count": 1}], "prompt": "a photo of a teddy bear"}
|
| 52 |
+
{"tag": "single_object", "include": [{"class": "vase", "count": 1}], "prompt": "a photo of a vase"}
|
| 53 |
+
{"tag": "single_object", "include": [{"class": "banana", "count": 1}], "prompt": "a photo of a banana"}
|
| 54 |
+
{"tag": "single_object", "include": [{"class": "toothbrush", "count": 1}], "prompt": "a photo of a toothbrush"}
|
| 55 |
+
{"tag": "single_object", "include": [{"class": "tv remote", "count": 1}], "prompt": "a photo of a tv remote"}
|
| 56 |
+
{"tag": "single_object", "include": [{"class": "dining table", "count": 1}], "prompt": "a photo of a dining table"}
|
| 57 |
+
{"tag": "single_object", "include": [{"class": "stop sign", "count": 1}], "prompt": "a photo of a stop sign"}
|
| 58 |
+
{"tag": "single_object", "include": [{"class": "sheep", "count": 1}], "prompt": "a photo of a sheep"}
|
| 59 |
+
{"tag": "single_object", "include": [{"class": "fire hydrant", "count": 1}], "prompt": "a photo of a fire hydrant"}
|
| 60 |
+
{"tag": "single_object", "include": [{"class": "airplane", "count": 1}], "prompt": "a photo of an airplane"}
|
| 61 |
+
{"tag": "single_object", "include": [{"class": "giraffe", "count": 1}], "prompt": "a photo of a giraffe"}
|
| 62 |
+
{"tag": "single_object", "include": [{"class": "horse", "count": 1}], "prompt": "a photo of a horse"}
|
| 63 |
+
{"tag": "single_object", "include": [{"class": "cat", "count": 1}], "prompt": "a photo of a cat"}
|
| 64 |
+
{"tag": "single_object", "include": [{"class": "donut", "count": 1}], "prompt": "a photo of a donut"}
|
| 65 |
+
{"tag": "single_object", "include": [{"class": "boat", "count": 1}], "prompt": "a photo of a boat"}
|
| 66 |
+
{"tag": "single_object", "include": [{"class": "baseball glove", "count": 1}], "prompt": "a photo of a baseball glove"}
|
| 67 |
+
{"tag": "single_object", "include": [{"class": "hair drier", "count": 1}], "prompt": "a photo of a hair drier"}
|
| 68 |
+
{"tag": "single_object", "include": [{"class": "sink", "count": 1}], "prompt": "a photo of a sink"}
|
| 69 |
+
{"tag": "single_object", "include": [{"class": "cake", "count": 1}], "prompt": "a photo of a cake"}
|
| 70 |
+
{"tag": "single_object", "include": [{"class": "wine glass", "count": 1}], "prompt": "a photo of a wine glass"}
|
| 71 |
+
{"tag": "single_object", "include": [{"class": "apple", "count": 1}], "prompt": "a photo of an apple"}
|
| 72 |
+
{"tag": "single_object", "include": [{"class": "bus", "count": 1}], "prompt": "a photo of a bus"}
|
| 73 |
+
{"tag": "single_object", "include": [{"class": "tennis racket", "count": 1}], "prompt": "a photo of a tennis racket"}
|
| 74 |
+
{"tag": "single_object", "include": [{"class": "knife", "count": 1}], "prompt": "a photo of a knife"}
|
| 75 |
+
{"tag": "single_object", "include": [{"class": "hot dog", "count": 1}], "prompt": "a photo of a hot dog"}
|
| 76 |
+
{"tag": "single_object", "include": [{"class": "truck", "count": 1}], "prompt": "a photo of a truck"}
|
| 77 |
+
{"tag": "single_object", "include": [{"class": "umbrella", "count": 1}], "prompt": "a photo of an umbrella"}
|
| 78 |
+
{"tag": "single_object", "include": [{"class": "sports ball", "count": 1}], "prompt": "a photo of a sports ball"}
|
| 79 |
+
{"tag": "single_object", "include": [{"class": "bear", "count": 1}], "prompt": "a photo of a bear"}
|
| 80 |
+
{"tag": "single_object", "include": [{"class": "kite", "count": 1}], "prompt": "a photo of a kite"}
|
| 81 |
+
{"tag": "two_object", "include": [{"class": "bench", "count": 1}, {"class": "sports ball", "count": 1}], "prompt": "a photo of a bench and a sports ball"}
|
| 82 |
+
{"tag": "two_object", "include": [{"class": "toothbrush", "count": 1}, {"class": "snowboard", "count": 1}], "prompt": "a photo of a toothbrush and a snowboard"}
|
| 83 |
+
{"tag": "two_object", "include": [{"class": "toaster", "count": 1}, {"class": "oven", "count": 1}], "prompt": "a photo of a toaster and an oven"}
|
| 84 |
+
{"tag": "two_object", "include": [{"class": "broccoli", "count": 1}, {"class": "vase", "count": 1}], "prompt": "a photo of a broccoli and a vase"}
|
| 85 |
+
{"tag": "two_object", "include": [{"class": "tennis racket", "count": 1}, {"class": "wine glass", "count": 1}], "prompt": "a photo of a tennis racket and a wine glass"}
|
| 86 |
+
{"tag": "two_object", "include": [{"class": "fork", "count": 1}, {"class": "knife", "count": 1}], "prompt": "a photo of a fork and a knife"}
|
| 87 |
+
{"tag": "two_object", "include": [{"class": "hair drier", "count": 1}, {"class": "cake", "count": 1}], "prompt": "a photo of a hair drier and a cake"}
|
| 88 |
+
{"tag": "two_object", "include": [{"class": "horse", "count": 1}, {"class": "giraffe", "count": 1}], "prompt": "a photo of a horse and a giraffe"}
|
| 89 |
+
{"tag": "two_object", "include": [{"class": "horse", "count": 1}, {"class": "computer keyboard", "count": 1}], "prompt": "a photo of a horse and a computer keyboard"}
|
| 90 |
+
{"tag": "two_object", "include": [{"class": "toothbrush", "count": 1}, {"class": "carrot", "count": 1}], "prompt": "a photo of a toothbrush and a carrot"}
|
| 91 |
+
{"tag": "two_object", "include": [{"class": "cake", "count": 1}, {"class": "zebra", "count": 1}], "prompt": "a photo of a cake and a zebra"}
|
| 92 |
+
{"tag": "two_object", "include": [{"class": "hair drier", "count": 1}, {"class": "bear", "count": 1}], "prompt": "a photo of a hair drier and a bear"}
|
| 93 |
+
{"tag": "two_object", "include": [{"class": "knife", "count": 1}, {"class": "zebra", "count": 1}], "prompt": "a photo of a knife and a zebra"}
|
| 94 |
+
{"tag": "two_object", "include": [{"class": "couch", "count": 1}, {"class": "wine glass", "count": 1}], "prompt": "a photo of a couch and a wine glass"}
|
| 95 |
+
{"tag": "two_object", "include": [{"class": "frisbee", "count": 1}, {"class": "vase", "count": 1}], "prompt": "a photo of a frisbee and a vase"}
|
| 96 |
+
{"tag": "two_object", "include": [{"class": "book", "count": 1}, {"class": "laptop", "count": 1}], "prompt": "a photo of a book and a laptop"}
|
| 97 |
+
{"tag": "two_object", "include": [{"class": "dining table", "count": 1}, {"class": "bear", "count": 1}], "prompt": "a photo of a dining table and a bear"}
|
| 98 |
+
{"tag": "two_object", "include": [{"class": "frisbee", "count": 1}, {"class": "couch", "count": 1}], "prompt": "a photo of a frisbee and a couch"}
|
| 99 |
+
{"tag": "two_object", "include": [{"class": "couch", "count": 1}, {"class": "horse", "count": 1}], "prompt": "a photo of a couch and a horse"}
|
| 100 |
+
{"tag": "two_object", "include": [{"class": "toilet", "count": 1}, {"class": "computer mouse", "count": 1}], "prompt": "a photo of a toilet and a computer mouse"}
|
| 101 |
+
{"tag": "two_object", "include": [{"class": "bottle", "count": 1}, {"class": "refrigerator", "count": 1}], "prompt": "a photo of a bottle and a refrigerator"}
|
| 102 |
+
{"tag": "two_object", "include": [{"class": "potted plant", "count": 1}, {"class": "backpack", "count": 1}], "prompt": "a photo of a potted plant and a backpack"}
|
| 103 |
+
{"tag": "two_object", "include": [{"class": "skateboard", "count": 1}, {"class": "cake", "count": 1}], "prompt": "a photo of a skateboard and a cake"}
|
| 104 |
+
{"tag": "two_object", "include": [{"class": "broccoli", "count": 1}, {"class": "parking meter", "count": 1}], "prompt": "a photo of a broccoli and a parking meter"}
|
| 105 |
+
{"tag": "two_object", "include": [{"class": "zebra", "count": 1}, {"class": "bed", "count": 1}], "prompt": "a photo of a zebra and a bed"}
|
| 106 |
+
{"tag": "two_object", "include": [{"class": "oven", "count": 1}, {"class": "bed", "count": 1}], "prompt": "a photo of an oven and a bed"}
|
| 107 |
+
{"tag": "two_object", "include": [{"class": "baseball bat", "count": 1}, {"class": "fork", "count": 1}], "prompt": "a photo of a baseball bat and a fork"}
|
| 108 |
+
{"tag": "two_object", "include": [{"class": "vase", "count": 1}, {"class": "spoon", "count": 1}], "prompt": "a photo of a vase and a spoon"}
|
| 109 |
+
{"tag": "two_object", "include": [{"class": "skateboard", "count": 1}, {"class": "sink", "count": 1}], "prompt": "a photo of a skateboard and a sink"}
|
| 110 |
+
{"tag": "two_object", "include": [{"class": "pizza", "count": 1}, {"class": "bench", "count": 1}], "prompt": "a photo of a pizza and a bench"}
|
| 111 |
+
{"tag": "two_object", "include": [{"class": "bowl", "count": 1}, {"class": "pizza", "count": 1}], "prompt": "a photo of a bowl and a pizza"}
|
| 112 |
+
{"tag": "two_object", "include": [{"class": "tennis racket", "count": 1}, {"class": "bird", "count": 1}], "prompt": "a photo of a tennis racket and a bird"}
|
| 113 |
+
{"tag": "two_object", "include": [{"class": "wine glass", "count": 1}, {"class": "bear", "count": 1}], "prompt": "a photo of a wine glass and a bear"}
|
| 114 |
+
{"tag": "two_object", "include": [{"class": "fork", "count": 1}, {"class": "book", "count": 1}], "prompt": "a photo of a fork and a book"}
|
| 115 |
+
{"tag": "two_object", "include": [{"class": "scissors", "count": 1}, {"class": "bowl", "count": 1}], "prompt": "a photo of a scissors and a bowl"}
|
| 116 |
+
{"tag": "two_object", "include": [{"class": "laptop", "count": 1}, {"class": "carrot", "count": 1}], "prompt": "a photo of a laptop and a carrot"}
|
| 117 |
+
{"tag": "two_object", "include": [{"class": "stop sign", "count": 1}, {"class": "bottle", "count": 1}], "prompt": "a photo of a stop sign and a bottle"}
|
| 118 |
+
{"tag": "two_object", "include": [{"class": "microwave", "count": 1}, {"class": "truck", "count": 1}], "prompt": "a photo of a microwave and a truck"}
|
| 119 |
+
{"tag": "two_object", "include": [{"class": "person", "count": 1}, {"class": "bear", "count": 1}], "prompt": "a photo of a person and a bear"}
|
| 120 |
+
{"tag": "two_object", "include": [{"class": "frisbee", "count": 1}, {"class": "cell phone", "count": 1}], "prompt": "a photo of a frisbee and a cell phone"}
|
| 121 |
+
{"tag": "two_object", "include": [{"class": "parking meter", "count": 1}, {"class": "teddy bear", "count": 1}], "prompt": "a photo of a parking meter and a teddy bear"}
|
| 122 |
+
{"tag": "two_object", "include": [{"class": "tennis racket", "count": 1}, {"class": "bicycle", "count": 1}], "prompt": "a photo of a tennis racket and a bicycle"}
|
| 123 |
+
{"tag": "two_object", "include": [{"class": "stop sign", "count": 1}, {"class": "motorcycle", "count": 1}], "prompt": "a photo of a stop sign and a motorcycle"}
|
| 124 |
+
{"tag": "two_object", "include": [{"class": "fire hydrant", "count": 1}, {"class": "tennis racket", "count": 1}], "prompt": "a photo of a fire hydrant and a tennis racket"}
|
| 125 |
+
{"tag": "two_object", "include": [{"class": "scissors", "count": 1}, {"class": "sandwich", "count": 1}], "prompt": "a photo of a scissors and a sandwich"}
|
| 126 |
+
{"tag": "two_object", "include": [{"class": "pizza", "count": 1}, {"class": "book", "count": 1}], "prompt": "a photo of a pizza and a book"}
|
| 127 |
+
{"tag": "two_object", "include": [{"class": "giraffe", "count": 1}, {"class": "computer mouse", "count": 1}], "prompt": "a photo of a giraffe and a computer mouse"}
|
| 128 |
+
{"tag": "two_object", "include": [{"class": "stop sign", "count": 1}, {"class": "toaster", "count": 1}], "prompt": "a photo of a stop sign and a toaster"}
|
| 129 |
+
{"tag": "two_object", "include": [{"class": "computer mouse", "count": 1}, {"class": "zebra", "count": 1}], "prompt": "a photo of a computer mouse and a zebra"}
|
| 130 |
+
{"tag": "two_object", "include": [{"class": "chair", "count": 1}, {"class": "bench", "count": 1}], "prompt": "a photo of a chair and a bench"}
|
| 131 |
+
{"tag": "two_object", "include": [{"class": "tv", "count": 1}, {"class": "carrot", "count": 1}], "prompt": "a photo of a tv and a carrot"}
|
| 132 |
+
{"tag": "two_object", "include": [{"class": "surfboard", "count": 1}, {"class": "suitcase", "count": 1}], "prompt": "a photo of a surfboard and a suitcase"}
|
| 133 |
+
{"tag": "two_object", "include": [{"class": "computer keyboard", "count": 1}, {"class": "laptop", "count": 1}], "prompt": "a photo of a computer keyboard and a laptop"}
|
| 134 |
+
{"tag": "two_object", "include": [{"class": "computer keyboard", "count": 1}, {"class": "microwave", "count": 1}], "prompt": "a photo of a computer keyboard and a microwave"}
|
| 135 |
+
{"tag": "two_object", "include": [{"class": "scissors", "count": 1}, {"class": "bird", "count": 1}], "prompt": "a photo of a scissors and a bird"}
|
| 136 |
+
{"tag": "two_object", "include": [{"class": "person", "count": 1}, {"class": "snowboard", "count": 1}], "prompt": "a photo of a person and a snowboard"}
|
| 137 |
+
{"tag": "two_object", "include": [{"class": "cow", "count": 1}, {"class": "horse", "count": 1}], "prompt": "a photo of a cow and a horse"}
|
| 138 |
+
{"tag": "two_object", "include": [{"class": "handbag", "count": 1}, {"class": "refrigerator", "count": 1}], "prompt": "a photo of a handbag and a refrigerator"}
|
| 139 |
+
{"tag": "two_object", "include": [{"class": "chair", "count": 1}, {"class": "laptop", "count": 1}], "prompt": "a photo of a chair and a laptop"}
|
| 140 |
+
{"tag": "two_object", "include": [{"class": "toothbrush", "count": 1}, {"class": "bench", "count": 1}], "prompt": "a photo of a toothbrush and a bench"}
|
| 141 |
+
{"tag": "two_object", "include": [{"class": "book", "count": 1}, {"class": "baseball bat", "count": 1}], "prompt": "a photo of a book and a baseball bat"}
|
| 142 |
+
{"tag": "two_object", "include": [{"class": "horse", "count": 1}, {"class": "train", "count": 1}], "prompt": "a photo of a horse and a train"}
|
| 143 |
+
{"tag": "two_object", "include": [{"class": "bench", "count": 1}, {"class": "vase", "count": 1}], "prompt": "a photo of a bench and a vase"}
|
| 144 |
+
{"tag": "two_object", "include": [{"class": "traffic light", "count": 1}, {"class": "backpack", "count": 1}], "prompt": "a photo of a traffic light and a backpack"}
|
| 145 |
+
{"tag": "two_object", "include": [{"class": "sports ball", "count": 1}, {"class": "cow", "count": 1}], "prompt": "a photo of a sports ball and a cow"}
|
| 146 |
+
{"tag": "two_object", "include": [{"class": "computer mouse", "count": 1}, {"class": "spoon", "count": 1}], "prompt": "a photo of a computer mouse and a spoon"}
|
| 147 |
+
{"tag": "two_object", "include": [{"class": "tv", "count": 1}, {"class": "bicycle", "count": 1}], "prompt": "a photo of a tv and a bicycle"}
|
| 148 |
+
{"tag": "two_object", "include": [{"class": "bench", "count": 1}, {"class": "snowboard", "count": 1}], "prompt": "a photo of a bench and a snowboard"}
|
| 149 |
+
{"tag": "two_object", "include": [{"class": "toothbrush", "count": 1}, {"class": "toilet", "count": 1}], "prompt": "a photo of a toothbrush and a toilet"}
|
| 150 |
+
{"tag": "two_object", "include": [{"class": "person", "count": 1}, {"class": "apple", "count": 1}], "prompt": "a photo of a person and an apple"}
|
| 151 |
+
{"tag": "two_object", "include": [{"class": "sink", "count": 1}, {"class": "sports ball", "count": 1}], "prompt": "a photo of a sink and a sports ball"}
|
| 152 |
+
{"tag": "two_object", "include": [{"class": "stop sign", "count": 1}, {"class": "dog", "count": 1}], "prompt": "a photo of a stop sign and a dog"}
|
| 153 |
+
{"tag": "two_object", "include": [{"class": "knife", "count": 1}, {"class": "stop sign", "count": 1}], "prompt": "a photo of a knife and a stop sign"}
|
| 154 |
+
{"tag": "two_object", "include": [{"class": "wine glass", "count": 1}, {"class": "handbag", "count": 1}], "prompt": "a photo of a wine glass and a handbag"}
|
| 155 |
+
{"tag": "two_object", "include": [{"class": "bowl", "count": 1}, {"class": "skis", "count": 1}], "prompt": "a photo of a bowl and a skis"}
|
| 156 |
+
{"tag": "two_object", "include": [{"class": "frisbee", "count": 1}, {"class": "apple", "count": 1}], "prompt": "a photo of a frisbee and an apple"}
|
| 157 |
+
{"tag": "two_object", "include": [{"class": "computer keyboard", "count": 1}, {"class": "cell phone", "count": 1}], "prompt": "a photo of a computer keyboard and a cell phone"}
|
| 158 |
+
{"tag": "two_object", "include": [{"class": "stop sign", "count": 1}, {"class": "fork", "count": 1}], "prompt": "a photo of a stop sign and a fork"}
|
| 159 |
+
{"tag": "two_object", "include": [{"class": "potted plant", "count": 1}, {"class": "boat", "count": 1}], "prompt": "a photo of a potted plant and a boat"}
|
| 160 |
+
{"tag": "two_object", "include": [{"class": "tv", "count": 1}, {"class": "cell phone", "count": 1}], "prompt": "a photo of a tv and a cell phone"}
|
| 161 |
+
{"tag": "two_object", "include": [{"class": "tie", "count": 1}, {"class": "broccoli", "count": 1}], "prompt": "a photo of a tie and a broccoli"}
|
| 162 |
+
{"tag": "two_object", "include": [{"class": "potted plant", "count": 1}, {"class": "donut", "count": 1}], "prompt": "a photo of a potted plant and a donut"}
|
| 163 |
+
{"tag": "two_object", "include": [{"class": "person", "count": 1}, {"class": "sink", "count": 1}], "prompt": "a photo of a person and a sink"}
|
| 164 |
+
{"tag": "two_object", "include": [{"class": "couch", "count": 1}, {"class": "snowboard", "count": 1}], "prompt": "a photo of a couch and a snowboard"}
|
| 165 |
+
{"tag": "two_object", "include": [{"class": "fork", "count": 1}, {"class": "baseball glove", "count": 1}], "prompt": "a photo of a fork and a baseball glove"}
|
| 166 |
+
{"tag": "two_object", "include": [{"class": "apple", "count": 1}, {"class": "toothbrush", "count": 1}], "prompt": "a photo of an apple and a toothbrush"}
|
| 167 |
+
{"tag": "two_object", "include": [{"class": "bus", "count": 1}, {"class": "baseball glove", "count": 1}], "prompt": "a photo of a bus and a baseball glove"}
|
| 168 |
+
{"tag": "two_object", "include": [{"class": "person", "count": 1}, {"class": "stop sign", "count": 1}], "prompt": "a photo of a person and a stop sign"}
|
| 169 |
+
{"tag": "two_object", "include": [{"class": "carrot", "count": 1}, {"class": "couch", "count": 1}], "prompt": "a photo of a carrot and a couch"}
|
| 170 |
+
{"tag": "two_object", "include": [{"class": "baseball bat", "count": 1}, {"class": "bear", "count": 1}], "prompt": "a photo of a baseball bat and a bear"}
|
| 171 |
+
{"tag": "two_object", "include": [{"class": "fire hydrant", "count": 1}, {"class": "train", "count": 1}], "prompt": "a photo of a fire hydrant and a train"}
|
| 172 |
+
{"tag": "two_object", "include": [{"class": "baseball glove", "count": 1}, {"class": "carrot", "count": 1}], "prompt": "a photo of a baseball glove and a carrot"}
|
| 173 |
+
{"tag": "two_object", "include": [{"class": "microwave", "count": 1}, {"class": "bench", "count": 1}], "prompt": "a photo of a microwave and a bench"}
|
| 174 |
+
{"tag": "two_object", "include": [{"class": "cake", "count": 1}, {"class": "stop sign", "count": 1}], "prompt": "a photo of a cake and a stop sign"}
|
| 175 |
+
{"tag": "two_object", "include": [{"class": "car", "count": 1}, {"class": "computer mouse", "count": 1}], "prompt": "a photo of a car and a computer mouse"}
|
| 176 |
+
{"tag": "two_object", "include": [{"class": "suitcase", "count": 1}, {"class": "dining table", "count": 1}], "prompt": "a photo of a suitcase and a dining table"}
|
| 177 |
+
{"tag": "two_object", "include": [{"class": "person", "count": 1}, {"class": "traffic light", "count": 1}], "prompt": "a photo of a person and a traffic light"}
|
| 178 |
+
{"tag": "two_object", "include": [{"class": "cell phone", "count": 1}, {"class": "horse", "count": 1}], "prompt": "a photo of a cell phone and a horse"}
|
| 179 |
+
{"tag": "two_object", "include": [{"class": "baseball bat", "count": 1}, {"class": "giraffe", "count": 1}], "prompt": "a photo of a baseball bat and a giraffe"}
|
| 180 |
+
{"tag": "counting", "include": [{"class": "clock", "count": 2}], "exclude": [{"class": "clock", "count": 3}], "prompt": "a photo of two clocks"}
|
| 181 |
+
{"tag": "counting", "include": [{"class": "backpack", "count": 2}], "exclude": [{"class": "backpack", "count": 3}], "prompt": "a photo of two backpacks"}
|
| 182 |
+
{"tag": "counting", "include": [{"class": "handbag", "count": 4}], "exclude": [{"class": "handbag", "count": 5}], "prompt": "a photo of four handbags"}
|
| 183 |
+
{"tag": "counting", "include": [{"class": "frisbee", "count": 2}], "exclude": [{"class": "frisbee", "count": 3}], "prompt": "a photo of two frisbees"}
|
| 184 |
+
{"tag": "counting", "include": [{"class": "sports ball", "count": 3}], "exclude": [{"class": "sports ball", "count": 4}], "prompt": "a photo of three sports balls"}
|
| 185 |
+
{"tag": "counting", "include": [{"class": "bear", "count": 2}], "exclude": [{"class": "bear", "count": 3}], "prompt": "a photo of two bears"}
|
| 186 |
+
{"tag": "counting", "include": [{"class": "tie", "count": 2}], "exclude": [{"class": "tie", "count": 3}], "prompt": "a photo of two ties"}
|
| 187 |
+
{"tag": "counting", "include": [{"class": "sink", "count": 4}], "exclude": [{"class": "sink", "count": 5}], "prompt": "a photo of four sinks"}
|
| 188 |
+
{"tag": "counting", "include": [{"class": "toothbrush", "count": 2}], "exclude": [{"class": "toothbrush", "count": 3}], "prompt": "a photo of two toothbrushs"}
|
| 189 |
+
{"tag": "counting", "include": [{"class": "person", "count": 3}], "exclude": [{"class": "person", "count": 4}], "prompt": "a photo of three persons"}
|
| 190 |
+
{"tag": "counting", "include": [{"class": "tennis racket", "count": 3}], "exclude": [{"class": "tennis racket", "count": 4}], "prompt": "a photo of three tennis rackets"}
|
| 191 |
+
{"tag": "counting", "include": [{"class": "bowl", "count": 4}], "exclude": [{"class": "bowl", "count": 5}], "prompt": "a photo of four bowls"}
|
| 192 |
+
{"tag": "counting", "include": [{"class": "vase", "count": 4}], "exclude": [{"class": "vase", "count": 5}], "prompt": "a photo of four vases"}
|
| 193 |
+
{"tag": "counting", "include": [{"class": "cup", "count": 3}], "exclude": [{"class": "cup", "count": 4}], "prompt": "a photo of three cups"}
|
| 194 |
+
{"tag": "counting", "include": [{"class": "computer keyboard", "count": 4}], "exclude": [{"class": "computer keyboard", "count": 5}], "prompt": "a photo of four computer keyboards"}
|
| 195 |
+
{"tag": "counting", "include": [{"class": "sink", "count": 3}], "exclude": [{"class": "sink", "count": 4}], "prompt": "a photo of three sinks"}
|
| 196 |
+
{"tag": "counting", "include": [{"class": "oven", "count": 2}], "exclude": [{"class": "oven", "count": 3}], "prompt": "a photo of two ovens"}
|
| 197 |
+
{"tag": "counting", "include": [{"class": "toilet", "count": 2}], "exclude": [{"class": "toilet", "count": 3}], "prompt": "a photo of two toilets"}
|
| 198 |
+
{"tag": "counting", "include": [{"class": "bicycle", "count": 2}], "exclude": [{"class": "bicycle", "count": 3}], "prompt": "a photo of two bicycles"}
|
| 199 |
+
{"tag": "counting", "include": [{"class": "train", "count": 2}], "exclude": [{"class": "train", "count": 3}], "prompt": "a photo of two trains"}
|
| 200 |
+
{"tag": "counting", "include": [{"class": "orange", "count": 3}], "exclude": [{"class": "orange", "count": 4}], "prompt": "a photo of three oranges"}
|
| 201 |
+
{"tag": "counting", "include": [{"class": "bus", "count": 3}], "exclude": [{"class": "bus", "count": 4}], "prompt": "a photo of three buses"}
|
| 202 |
+
{"tag": "counting", "include": [{"class": "handbag", "count": 3}], "exclude": [{"class": "handbag", "count": 4}], "prompt": "a photo of three handbags"}
|
| 203 |
+
{"tag": "counting", "include": [{"class": "snowboard", "count": 3}], "exclude": [{"class": "snowboard", "count": 4}], "prompt": "a photo of three snowboards"}
|
| 204 |
+
{"tag": "counting", "include": [{"class": "snowboard", "count": 2}], "exclude": [{"class": "snowboard", "count": 3}], "prompt": "a photo of two snowboards"}
|
| 205 |
+
{"tag": "counting", "include": [{"class": "dog", "count": 4}], "exclude": [{"class": "dog", "count": 5}], "prompt": "a photo of four dogs"}
|
| 206 |
+
{"tag": "counting", "include": [{"class": "apple", "count": 3}], "exclude": [{"class": "apple", "count": 4}], "prompt": "a photo of three apples"}
|
| 207 |
+
{"tag": "counting", "include": [{"class": "sheep", "count": 2}], "exclude": [{"class": "sheep", "count": 3}], "prompt": "a photo of two sheeps"}
|
| 208 |
+
{"tag": "counting", "include": [{"class": "hot dog", "count": 3}], "exclude": [{"class": "hot dog", "count": 4}], "prompt": "a photo of three hot dogs"}
|
| 209 |
+
{"tag": "counting", "include": [{"class": "zebra", "count": 3}], "exclude": [{"class": "zebra", "count": 4}], "prompt": "a photo of three zebras"}
|
| 210 |
+
{"tag": "counting", "include": [{"class": "kite", "count": 3}], "exclude": [{"class": "kite", "count": 4}], "prompt": "a photo of three kites"}
|
| 211 |
+
{"tag": "counting", "include": [{"class": "apple", "count": 4}], "exclude": [{"class": "apple", "count": 5}], "prompt": "a photo of four apples"}
|
| 212 |
+
{"tag": "counting", "include": [{"class": "cell phone", "count": 3}], "exclude": [{"class": "cell phone", "count": 4}], "prompt": "a photo of three cell phones"}
|
| 213 |
+
{"tag": "counting", "include": [{"class": "baseball glove", "count": 4}], "exclude": [{"class": "baseball glove", "count": 5}], "prompt": "a photo of four baseball gloves"}
|
| 214 |
+
{"tag": "counting", "include": [{"class": "computer keyboard", "count": 3}], "exclude": [{"class": "computer keyboard", "count": 4}], "prompt": "a photo of three computer keyboards"}
|
| 215 |
+
{"tag": "counting", "include": [{"class": "bed", "count": 2}], "exclude": [{"class": "bed", "count": 3}], "prompt": "a photo of two beds"}
|
| 216 |
+
{"tag": "counting", "include": [{"class": "tv remote", "count": 2}], "exclude": [{"class": "tv remote", "count": 3}], "prompt": "a photo of two tv remotes"}
|
| 217 |
+
{"tag": "counting", "include": [{"class": "fire hydrant", "count": 3}], "exclude": [{"class": "fire hydrant", "count": 4}], "prompt": "a photo of three fire hydrants"}
|
| 218 |
+
{"tag": "counting", "include": [{"class": "book", "count": 3}], "exclude": [{"class": "book", "count": 4}], "prompt": "a photo of three books"}
|
| 219 |
+
{"tag": "counting", "include": [{"class": "giraffe", "count": 4}], "exclude": [{"class": "giraffe", "count": 5}], "prompt": "a photo of four giraffes"}
|
| 220 |
+
{"tag": "counting", "include": [{"class": "vase", "count": 2}], "exclude": [{"class": "vase", "count": 3}], "prompt": "a photo of two vases"}
|
| 221 |
+
{"tag": "counting", "include": [{"class": "donut", "count": 4}], "exclude": [{"class": "donut", "count": 5}], "prompt": "a photo of four donuts"}
|
| 222 |
+
{"tag": "counting", "include": [{"class": "chair", "count": 4}], "exclude": [{"class": "chair", "count": 5}], "prompt": "a photo of four chairs"}
|
| 223 |
+
{"tag": "counting", "include": [{"class": "baseball bat", "count": 3}], "exclude": [{"class": "baseball bat", "count": 4}], "prompt": "a photo of three baseball bats"}
|
| 224 |
+
{"tag": "counting", "include": [{"class": "stop sign", "count": 4}], "exclude": [{"class": "stop sign", "count": 5}], "prompt": "a photo of four stop signs"}
|
| 225 |
+
{"tag": "counting", "include": [{"class": "pizza", "count": 2}], "exclude": [{"class": "pizza", "count": 3}], "prompt": "a photo of two pizzas"}
|
| 226 |
+
{"tag": "counting", "include": [{"class": "refrigerator", "count": 3}], "exclude": [{"class": "refrigerator", "count": 4}], "prompt": "a photo of three refrigerators"}
|
| 227 |
+
{"tag": "counting", "include": [{"class": "fire hydrant", "count": 2}], "exclude": [{"class": "fire hydrant", "count": 3}], "prompt": "a photo of two fire hydrants"}
|
| 228 |
+
{"tag": "counting", "include": [{"class": "giraffe", "count": 3}], "exclude": [{"class": "giraffe", "count": 4}], "prompt": "a photo of three giraffes"}
|
| 229 |
+
{"tag": "counting", "include": [{"class": "tv", "count": 4}], "exclude": [{"class": "tv", "count": 5}], "prompt": "a photo of four tvs"}
|
| 230 |
+
{"tag": "counting", "include": [{"class": "wine glass", "count": 3}], "exclude": [{"class": "wine glass", "count": 4}], "prompt": "a photo of three wine glasses"}
|
| 231 |
+
{"tag": "counting", "include": [{"class": "broccoli", "count": 4}], "exclude": [{"class": "broccoli", "count": 5}], "prompt": "a photo of four broccolis"}
|
| 232 |
+
{"tag": "counting", "include": [{"class": "truck", "count": 3}], "exclude": [{"class": "truck", "count": 4}], "prompt": "a photo of three trucks"}
|
| 233 |
+
{"tag": "counting", "include": [{"class": "truck", "count": 2}], "exclude": [{"class": "truck", "count": 3}], "prompt": "a photo of two trucks"}
|
| 234 |
+
{"tag": "counting", "include": [{"class": "carrot", "count": 2}], "exclude": [{"class": "carrot", "count": 3}], "prompt": "a photo of two carrots"}
|
| 235 |
+
{"tag": "counting", "include": [{"class": "sandwich", "count": 2}], "exclude": [{"class": "sandwich", "count": 3}], "prompt": "a photo of two sandwichs"}
|
| 236 |
+
{"tag": "counting", "include": [{"class": "traffic light", "count": 4}], "exclude": [{"class": "traffic light", "count": 5}], "prompt": "a photo of four traffic lights"}
|
| 237 |
+
{"tag": "counting", "include": [{"class": "clock", "count": 4}], "exclude": [{"class": "clock", "count": 5}], "prompt": "a photo of four clocks"}
|
| 238 |
+
{"tag": "counting", "include": [{"class": "car", "count": 2}], "exclude": [{"class": "car", "count": 3}], "prompt": "a photo of two cars"}
|
| 239 |
+
{"tag": "counting", "include": [{"class": "banana", "count": 2}], "exclude": [{"class": "banana", "count": 3}], "prompt": "a photo of two bananas"}
|
| 240 |
+
{"tag": "counting", "include": [{"class": "wine glass", "count": 2}], "exclude": [{"class": "wine glass", "count": 3}], "prompt": "a photo of two wine glasses"}
|
| 241 |
+
{"tag": "counting", "include": [{"class": "pizza", "count": 3}], "exclude": [{"class": "pizza", "count": 4}], "prompt": "a photo of three pizzas"}
|
| 242 |
+
{"tag": "counting", "include": [{"class": "knife", "count": 4}], "exclude": [{"class": "knife", "count": 5}], "prompt": "a photo of four knifes"}
|
| 243 |
+
{"tag": "counting", "include": [{"class": "suitcase", "count": 3}], "exclude": [{"class": "suitcase", "count": 4}], "prompt": "a photo of three suitcases"}
|
| 244 |
+
{"tag": "counting", "include": [{"class": "zebra", "count": 4}], "exclude": [{"class": "zebra", "count": 5}], "prompt": "a photo of four zebras"}
|
| 245 |
+
{"tag": "counting", "include": [{"class": "teddy bear", "count": 2}], "exclude": [{"class": "teddy bear", "count": 3}], "prompt": "a photo of two teddy bears"}
|
| 246 |
+
{"tag": "counting", "include": [{"class": "skateboard", "count": 4}], "exclude": [{"class": "skateboard", "count": 5}], "prompt": "a photo of four skateboards"}
|
| 247 |
+
{"tag": "counting", "include": [{"class": "hot dog", "count": 4}], "exclude": [{"class": "hot dog", "count": 5}], "prompt": "a photo of four hot dogs"}
|
| 248 |
+
{"tag": "counting", "include": [{"class": "bird", "count": 3}], "exclude": [{"class": "bird", "count": 4}], "prompt": "a photo of three birds"}
|
| 249 |
+
{"tag": "counting", "include": [{"class": "boat", "count": 4}], "exclude": [{"class": "boat", "count": 5}], "prompt": "a photo of four boats"}
|
| 250 |
+
{"tag": "counting", "include": [{"class": "microwave", "count": 4}], "exclude": [{"class": "microwave", "count": 5}], "prompt": "a photo of four microwaves"}
|
| 251 |
+
{"tag": "counting", "include": [{"class": "hair drier", "count": 2}], "exclude": [{"class": "hair drier", "count": 3}], "prompt": "a photo of two hair driers"}
|
| 252 |
+
{"tag": "counting", "include": [{"class": "laptop", "count": 3}], "exclude": [{"class": "laptop", "count": 4}], "prompt": "a photo of three laptops"}
|
| 253 |
+
{"tag": "counting", "include": [{"class": "cow", "count": 3}], "exclude": [{"class": "cow", "count": 4}], "prompt": "a photo of three cows"}
|
| 254 |
+
{"tag": "counting", "include": [{"class": "parking meter", "count": 2}], "exclude": [{"class": "parking meter", "count": 3}], "prompt": "a photo of two parking meters"}
|
| 255 |
+
{"tag": "counting", "include": [{"class": "bench", "count": 4}], "exclude": [{"class": "bench", "count": 5}], "prompt": "a photo of four benchs"}
|
| 256 |
+
{"tag": "counting", "include": [{"class": "bench", "count": 3}], "exclude": [{"class": "bench", "count": 4}], "prompt": "a photo of three benchs"}
|
| 257 |
+
{"tag": "counting", "include": [{"class": "frisbee", "count": 4}], "exclude": [{"class": "frisbee", "count": 5}], "prompt": "a photo of four frisbees"}
|
| 258 |
+
{"tag": "counting", "include": [{"class": "book", "count": 4}], "exclude": [{"class": "book", "count": 5}], "prompt": "a photo of four books"}
|
| 259 |
+
{"tag": "counting", "include": [{"class": "bus", "count": 4}], "exclude": [{"class": "bus", "count": 5}], "prompt": "a photo of four buses"}
|
| 260 |
+
{"tag": "colors", "include": [{"class": "fire hydrant", "count": 1, "color": "blue"}], "prompt": "a photo of a blue fire hydrant"}
|
| 261 |
+
{"tag": "colors", "include": [{"class": "car", "count": 1, "color": "pink"}], "prompt": "a photo of a pink car"}
|
| 262 |
+
{"tag": "colors", "include": [{"class": "cup", "count": 1, "color": "purple"}], "prompt": "a photo of a purple cup"}
|
| 263 |
+
{"tag": "colors", "include": [{"class": "cow", "count": 1, "color": "blue"}], "prompt": "a photo of a blue cow"}
|
| 264 |
+
{"tag": "colors", "include": [{"class": "boat", "count": 1, "color": "yellow"}], "prompt": "a photo of a yellow boat"}
|
| 265 |
+
{"tag": "colors", "include": [{"class": "umbrella", "count": 1, "color": "blue"}], "prompt": "a photo of a blue umbrella"}
|
| 266 |
+
{"tag": "colors", "include": [{"class": "elephant", "count": 1, "color": "blue"}], "prompt": "a photo of a blue elephant"}
|
| 267 |
+
{"tag": "colors", "include": [{"class": "elephant", "count": 1, "color": "yellow"}], "prompt": "a photo of a yellow elephant"}
|
| 268 |
+
{"tag": "colors", "include": [{"class": "bicycle", "count": 1, "color": "red"}], "prompt": "a photo of a red bicycle"}
|
| 269 |
+
{"tag": "colors", "include": [{"class": "suitcase", "count": 1, "color": "purple"}], "prompt": "a photo of a purple suitcase"}
|
| 270 |
+
{"tag": "colors", "include": [{"class": "hair drier", "count": 1, "color": "purple"}], "prompt": "a photo of a purple hair drier"}
|
| 271 |
+
{"tag": "colors", "include": [{"class": "sandwich", "count": 1, "color": "white"}], "prompt": "a photo of a white sandwich"}
|
| 272 |
+
{"tag": "colors", "include": [{"class": "elephant", "count": 1, "color": "purple"}], "prompt": "a photo of a purple elephant"}
|
| 273 |
+
{"tag": "colors", "include": [{"class": "microwave", "count": 1, "color": "green"}], "prompt": "a photo of a green microwave"}
|
| 274 |
+
{"tag": "colors", "include": [{"class": "zebra", "count": 1, "color": "red"}], "prompt": "a photo of a red zebra"}
|
| 275 |
+
{"tag": "colors", "include": [{"class": "apple", "count": 1, "color": "red"}], "prompt": "a photo of a red apple"}
|
| 276 |
+
{"tag": "colors", "include": [{"class": "tv remote", "count": 1, "color": "yellow"}], "prompt": "a photo of a yellow tv remote"}
|
| 277 |
+
{"tag": "colors", "include": [{"class": "toilet", "count": 1, "color": "blue"}], "prompt": "a photo of a blue toilet"}
|
| 278 |
+
{"tag": "colors", "include": [{"class": "orange", "count": 1, "color": "orange"}], "prompt": "a photo of an orange orange"}
|
| 279 |
+
{"tag": "colors", "include": [{"class": "donut", "count": 1, "color": "black"}], "prompt": "a photo of a black donut"}
|
| 280 |
+
{"tag": "colors", "include": [{"class": "vase", "count": 1, "color": "red"}], "prompt": "a photo of a red vase"}
|
| 281 |
+
{"tag": "colors", "include": [{"class": "pizza", "count": 1, "color": "purple"}], "prompt": "a photo of a purple pizza"}
|
| 282 |
+
{"tag": "colors", "include": [{"class": "skateboard", "count": 1, "color": "pink"}], "prompt": "a photo of a pink skateboard"}
|
| 283 |
+
{"tag": "colors", "include": [{"class": "skateboard", "count": 1, "color": "green"}], "prompt": "a photo of a green skateboard"}
|
| 284 |
+
{"tag": "colors", "include": [{"class": "bear", "count": 1, "color": "purple"}], "prompt": "a photo of a purple bear"}
|
| 285 |
+
{"tag": "colors", "include": [{"class": "chair", "count": 1, "color": "brown"}], "prompt": "a photo of a brown chair"}
|
| 286 |
+
{"tag": "colors", "include": [{"class": "computer keyboard", "count": 1, "color": "brown"}], "prompt": "a photo of a brown computer keyboard"}
|
| 287 |
+
{"tag": "colors", "include": [{"class": "cow", "count": 1, "color": "orange"}], "prompt": "a photo of an orange cow"}
|
| 288 |
+
{"tag": "colors", "include": [{"class": "skis", "count": 1, "color": "brown"}], "prompt": "a photo of a brown skis"}
|
| 289 |
+
{"tag": "colors", "include": [{"class": "kite", "count": 1, "color": "white"}], "prompt": "a photo of a white kite"}
|
| 290 |
+
{"tag": "colors", "include": [{"class": "dog", "count": 1, "color": "red"}], "prompt": "a photo of a red dog"}
|
| 291 |
+
{"tag": "colors", "include": [{"class": "couch", "count": 1, "color": "green"}], "prompt": "a photo of a green couch"}
|
| 292 |
+
{"tag": "colors", "include": [{"class": "airplane", "count": 1, "color": "yellow"}], "prompt": "a photo of a yellow airplane"}
|
| 293 |
+
{"tag": "colors", "include": [{"class": "tv", "count": 1, "color": "orange"}], "prompt": "a photo of an orange tv"}
|
| 294 |
+
{"tag": "colors", "include": [{"class": "scissors", "count": 1, "color": "white"}], "prompt": "a photo of a white scissors"}
|
| 295 |
+
{"tag": "colors", "include": [{"class": "cell phone", "count": 1, "color": "pink"}], "prompt": "a photo of a pink cell phone"}
|
| 296 |
+
{"tag": "colors", "include": [{"class": "surfboard", "count": 1, "color": "green"}], "prompt": "a photo of a green surfboard"}
|
| 297 |
+
{"tag": "colors", "include": [{"class": "fire hydrant", "count": 1, "color": "white"}], "prompt": "a photo of a white fire hydrant"}
|
| 298 |
+
{"tag": "colors", "include": [{"class": "bicycle", "count": 1, "color": "black"}], "prompt": "a photo of a black bicycle"}
|
| 299 |
+
{"tag": "colors", "include": [{"class": "carrot", "count": 1, "color": "purple"}], "prompt": "a photo of a purple carrot"}
|
| 300 |
+
{"tag": "colors", "include": [{"class": "dining table", "count": 1, "color": "black"}], "prompt": "a photo of a black dining table"}
|
| 301 |
+
{"tag": "colors", "include": [{"class": "potted plant", "count": 1, "color": "purple"}], "prompt": "a photo of a purple potted plant"}
|
| 302 |
+
{"tag": "colors", "include": [{"class": "backpack", "count": 1, "color": "purple"}], "prompt": "a photo of a purple backpack"}
|
| 303 |
+
{"tag": "colors", "include": [{"class": "train", "count": 1, "color": "yellow"}], "prompt": "a photo of a yellow train"}
|
| 304 |
+
{"tag": "colors", "include": [{"class": "potted plant", "count": 1, "color": "pink"}], "prompt": "a photo of a pink potted plant"}
|
| 305 |
+
{"tag": "colors", "include": [{"class": "giraffe", "count": 1, "color": "red"}], "prompt": "a photo of a red giraffe"}
|
| 306 |
+
{"tag": "colors", "include": [{"class": "bear", "count": 1, "color": "brown"}], "prompt": "a photo of a brown bear"}
|
| 307 |
+
{"tag": "colors", "include": [{"class": "train", "count": 1, "color": "black"}], "prompt": "a photo of a black train"}
|
| 308 |
+
{"tag": "colors", "include": [{"class": "laptop", "count": 1, "color": "orange"}], "prompt": "a photo of an orange laptop"}
|
| 309 |
+
{"tag": "colors", "include": [{"class": "hot dog", "count": 1, "color": "green"}], "prompt": "a photo of a green hot dog"}
|
| 310 |
+
{"tag": "colors", "include": [{"class": "parking meter", "count": 1, "color": "yellow"}], "prompt": "a photo of a yellow parking meter"}
|
| 311 |
+
{"tag": "colors", "include": [{"class": "potted plant", "count": 1, "color": "red"}], "prompt": "a photo of a red potted plant"}
|
| 312 |
+
{"tag": "colors", "include": [{"class": "traffic light", "count": 1, "color": "green"}], "prompt": "a photo of a green traffic light"}
|
| 313 |
+
{"tag": "colors", "include": [{"class": "tv", "count": 1, "color": "blue"}], "prompt": "a photo of a blue tv"}
|
| 314 |
+
{"tag": "colors", "include": [{"class": "refrigerator", "count": 1, "color": "brown"}], "prompt": "a photo of a brown refrigerator"}
|
| 315 |
+
{"tag": "colors", "include": [{"class": "tv remote", "count": 1, "color": "black"}], "prompt": "a photo of a black tv remote"}
|
| 316 |
+
{"tag": "colors", "include": [{"class": "scissors", "count": 1, "color": "purple"}], "prompt": "a photo of a purple scissors"}
|
| 317 |
+
{"tag": "colors", "include": [{"class": "orange", "count": 1, "color": "yellow"}], "prompt": "a photo of a yellow orange"}
|
| 318 |
+
{"tag": "colors", "include": [{"class": "toaster", "count": 1, "color": "brown"}], "prompt": "a photo of a brown toaster"}
|
| 319 |
+
{"tag": "colors", "include": [{"class": "parking meter", "count": 1, "color": "red"}], "prompt": "a photo of a red parking meter"}
|
| 320 |
+
{"tag": "colors", "include": [{"class": "orange", "count": 1, "color": "brown"}], "prompt": "a photo of a brown orange"}
|
| 321 |
+
{"tag": "colors", "include": [{"class": "clock", "count": 1, "color": "green"}], "prompt": "a photo of a green clock"}
|
| 322 |
+
{"tag": "colors", "include": [{"class": "sheep", "count": 1, "color": "white"}], "prompt": "a photo of a white sheep"}
|
| 323 |
+
{"tag": "colors", "include": [{"class": "oven", "count": 1, "color": "yellow"}], "prompt": "a photo of a yellow oven"}
|
| 324 |
+
{"tag": "colors", "include": [{"class": "vase", "count": 1, "color": "green"}], "prompt": "a photo of a green vase"}
|
| 325 |
+
{"tag": "colors", "include": [{"class": "teddy bear", "count": 1, "color": "black"}], "prompt": "a photo of a black teddy bear"}
|
| 326 |
+
{"tag": "colors", "include": [{"class": "carrot", "count": 1, "color": "yellow"}], "prompt": "a photo of a yellow carrot"}
|
| 327 |
+
{"tag": "colors", "include": [{"class": "hot dog", "count": 1, "color": "black"}], "prompt": "a photo of a black hot dog"}
|
| 328 |
+
{"tag": "colors", "include": [{"class": "scissors", "count": 1, "color": "red"}], "prompt": "a photo of a red scissors"}
|
| 329 |
+
{"tag": "colors", "include": [{"class": "teddy bear", "count": 1, "color": "white"}], "prompt": "a photo of a white teddy bear"}
|
| 330 |
+
{"tag": "colors", "include": [{"class": "skis", "count": 1, "color": "black"}], "prompt": "a photo of a black skis"}
|
| 331 |
+
{"tag": "colors", "include": [{"class": "dining table", "count": 1, "color": "blue"}], "prompt": "a photo of a blue dining table"}
|
| 332 |
+
{"tag": "colors", "include": [{"class": "refrigerator", "count": 1, "color": "black"}], "prompt": "a photo of a black refrigerator"}
|
| 333 |
+
{"tag": "colors", "include": [{"class": "dog", "count": 1, "color": "white"}], "prompt": "a photo of a white dog"}
|
| 334 |
+
{"tag": "colors", "include": [{"class": "scissors", "count": 1, "color": "orange"}], "prompt": "a photo of an orange scissors"}
|
| 335 |
+
{"tag": "colors", "include": [{"class": "cell phone", "count": 1, "color": "red"}], "prompt": "a photo of a red cell phone"}
|
| 336 |
+
{"tag": "colors", "include": [{"class": "orange", "count": 1, "color": "white"}], "prompt": "a photo of a white orange"}
|
| 337 |
+
{"tag": "colors", "include": [{"class": "clock", "count": 1, "color": "blue"}], "prompt": "a photo of a blue clock"}
|
| 338 |
+
{"tag": "colors", "include": [{"class": "carrot", "count": 1, "color": "blue"}], "prompt": "a photo of a blue carrot"}
|
| 339 |
+
{"tag": "colors", "include": [{"class": "motorcycle", "count": 1, "color": "green"}], "prompt": "a photo of a green motorcycle"}
|
| 340 |
+
{"tag": "colors", "include": [{"class": "stop sign", "count": 1, "color": "pink"}], "prompt": "a photo of a pink stop sign"}
|
| 341 |
+
{"tag": "colors", "include": [{"class": "vase", "count": 1, "color": "black"}], "prompt": "a photo of a black vase"}
|
| 342 |
+
{"tag": "colors", "include": [{"class": "backpack", "count": 1, "color": "black"}], "prompt": "a photo of a black backpack"}
|
| 343 |
+
{"tag": "colors", "include": [{"class": "car", "count": 1, "color": "red"}], "prompt": "a photo of a red car"}
|
| 344 |
+
{"tag": "colors", "include": [{"class": "computer mouse", "count": 1, "color": "green"}], "prompt": "a photo of a green computer mouse"}
|
| 345 |
+
{"tag": "colors", "include": [{"class": "backpack", "count": 1, "color": "red"}], "prompt": "a photo of a red backpack"}
|
| 346 |
+
{"tag": "colors", "include": [{"class": "bus", "count": 1, "color": "green"}], "prompt": "a photo of a green bus"}
|
| 347 |
+
{"tag": "colors", "include": [{"class": "toaster", "count": 1, "color": "orange"}], "prompt": "a photo of an orange toaster"}
|
| 348 |
+
{"tag": "colors", "include": [{"class": "fork", "count": 1, "color": "yellow"}], "prompt": "a photo of a yellow fork"}
|
| 349 |
+
{"tag": "colors", "include": [{"class": "parking meter", "count": 1, "color": "pink"}], "prompt": "a photo of a pink parking meter"}
|
| 350 |
+
{"tag": "colors", "include": [{"class": "book", "count": 1, "color": "blue"}], "prompt": "a photo of a blue book"}
|
| 351 |
+
{"tag": "colors", "include": [{"class": "broccoli", "count": 1, "color": "yellow"}], "prompt": "a photo of a yellow broccoli"}
|
| 352 |
+
{"tag": "colors", "include": [{"class": "computer mouse", "count": 1, "color": "orange"}], "prompt": "a photo of an orange computer mouse"}
|
| 353 |
+
{"tag": "colors", "include": [{"class": "cake", "count": 1, "color": "red"}], "prompt": "a photo of a red cake"}
|
| 354 |
+
{"tag": "position", "include": [{"class": "teddy bear", "count": 1}, {"class": "dog", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a dog right of a teddy bear"}
|
| 355 |
+
{"tag": "position", "include": [{"class": "kite", "count": 1}, {"class": "wine glass", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a wine glass above a kite"}
|
| 356 |
+
{"tag": "position", "include": [{"class": "cup", "count": 1}, {"class": "couch", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a couch below a cup"}
|
| 357 |
+
{"tag": "position", "include": [{"class": "cow", "count": 1}, {"class": "laptop", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a laptop left of a cow"}
|
| 358 |
+
{"tag": "position", "include": [{"class": "hair drier", "count": 1}, {"class": "fork", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a fork above a hair drier"}
|
| 359 |
+
{"tag": "position", "include": [{"class": "baseball bat", "count": 1}, {"class": "tie", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a tie right of a baseball bat"}
|
| 360 |
+
{"tag": "position", "include": [{"class": "fork", "count": 1}, {"class": "stop sign", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a stop sign above a fork"}
|
| 361 |
+
{"tag": "position", "include": [{"class": "skateboard", "count": 1}, {"class": "bird", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a bird below a skateboard"}
|
| 362 |
+
{"tag": "position", "include": [{"class": "tv", "count": 1}, {"class": "apple", "count": 1, "position": ["above", 0]}], "prompt": "a photo of an apple above a tv"}
|
| 363 |
+
{"tag": "position", "include": [{"class": "potted plant", "count": 1}, {"class": "train", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a train above a potted plant"}
|
| 364 |
+
{"tag": "position", "include": [{"class": "refrigerator", "count": 1}, {"class": "truck", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a truck left of a refrigerator"}
|
| 365 |
+
{"tag": "position", "include": [{"class": "cow", "count": 1}, {"class": "tv remote", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a tv remote below a cow"}
|
| 366 |
+
{"tag": "position", "include": [{"class": "train", "count": 1}, {"class": "bottle", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a bottle right of a train"}
|
| 367 |
+
{"tag": "position", "include": [{"class": "cow", "count": 1}, {"class": "dog", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a dog above a cow"}
|
| 368 |
+
{"tag": "position", "include": [{"class": "person", "count": 1}, {"class": "skateboard", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a skateboard above a person"}
|
| 369 |
+
{"tag": "position", "include": [{"class": "umbrella", "count": 1}, {"class": "baseball glove", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a baseball glove below an umbrella"}
|
| 370 |
+
{"tag": "position", "include": [{"class": "oven", "count": 1}, {"class": "dining table", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a dining table right of an oven"}
|
| 371 |
+
{"tag": "position", "include": [{"class": "suitcase", "count": 1}, {"class": "hot dog", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a hot dog left of a suitcase"}
|
| 372 |
+
{"tag": "position", "include": [{"class": "toothbrush", "count": 1}, {"class": "bus", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a bus below a toothbrush"}
|
| 373 |
+
{"tag": "position", "include": [{"class": "sandwich", "count": 1}, {"class": "backpack", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a backpack right of a sandwich"}
|
| 374 |
+
{"tag": "position", "include": [{"class": "baseball bat", "count": 1}, {"class": "cake", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a cake below a baseball bat"}
|
| 375 |
+
{"tag": "position", "include": [{"class": "tie", "count": 1}, {"class": "dog", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a dog right of a tie"}
|
| 376 |
+
{"tag": "position", "include": [{"class": "boat", "count": 1}, {"class": "suitcase", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a suitcase right of a boat"}
|
| 377 |
+
{"tag": "position", "include": [{"class": "clock", "count": 1}, {"class": "bear", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a bear above a clock"}
|
| 378 |
+
{"tag": "position", "include": [{"class": "umbrella", "count": 1}, {"class": "tv remote", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a tv remote left of an umbrella"}
|
| 379 |
+
{"tag": "position", "include": [{"class": "umbrella", "count": 1}, {"class": "sports ball", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a sports ball left of an umbrella"}
|
| 380 |
+
{"tag": "position", "include": [{"class": "dining table", "count": 1}, {"class": "train", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a train right of a dining table"}
|
| 381 |
+
{"tag": "position", "include": [{"class": "elephant", "count": 1}, {"class": "hair drier", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a hair drier below an elephant"}
|
| 382 |
+
{"tag": "position", "include": [{"class": "spoon", "count": 1}, {"class": "tennis racket", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a tennis racket right of a spoon"}
|
| 383 |
+
{"tag": "position", "include": [{"class": "hot dog", "count": 1}, {"class": "wine glass", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a wine glass right of a hot dog"}
|
| 384 |
+
{"tag": "position", "include": [{"class": "bench", "count": 1}, {"class": "computer mouse", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a computer mouse left of a bench"}
|
| 385 |
+
{"tag": "position", "include": [{"class": "orange", "count": 1}, {"class": "carrot", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a carrot left of an orange"}
|
| 386 |
+
{"tag": "position", "include": [{"class": "toothbrush", "count": 1}, {"class": "kite", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a kite above a toothbrush"}
|
| 387 |
+
{"tag": "position", "include": [{"class": "traffic light", "count": 1}, {"class": "toaster", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a toaster below a traffic light"}
|
| 388 |
+
{"tag": "position", "include": [{"class": "baseball glove", "count": 1}, {"class": "cat", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a cat below a baseball glove"}
|
| 389 |
+
{"tag": "position", "include": [{"class": "zebra", "count": 1}, {"class": "skis", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a skis right of a zebra"}
|
| 390 |
+
{"tag": "position", "include": [{"class": "chair", "count": 1}, {"class": "stop sign", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a stop sign above a chair"}
|
| 391 |
+
{"tag": "position", "include": [{"class": "parking meter", "count": 1}, {"class": "stop sign", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a stop sign above a parking meter"}
|
| 392 |
+
{"tag": "position", "include": [{"class": "skateboard", "count": 1}, {"class": "hot dog", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a hot dog right of a skateboard"}
|
| 393 |
+
{"tag": "position", "include": [{"class": "computer keyboard", "count": 1}, {"class": "pizza", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a pizza below a computer keyboard"}
|
| 394 |
+
{"tag": "position", "include": [{"class": "toilet", "count": 1}, {"class": "hair drier", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a hair drier left of a toilet"}
|
| 395 |
+
{"tag": "position", "include": [{"class": "stop sign", "count": 1}, {"class": "cow", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a cow left of a stop sign"}
|
| 396 |
+
{"tag": "position", "include": [{"class": "skis", "count": 1}, {"class": "suitcase", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a suitcase above a skis"}
|
| 397 |
+
{"tag": "position", "include": [{"class": "laptop", "count": 1}, {"class": "book", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a book above a laptop"}
|
| 398 |
+
{"tag": "position", "include": [{"class": "pizza", "count": 1}, {"class": "toothbrush", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a toothbrush below a pizza"}
|
| 399 |
+
{"tag": "position", "include": [{"class": "kite", "count": 1}, {"class": "toilet", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a toilet left of a kite"}
|
| 400 |
+
{"tag": "position", "include": [{"class": "sink", "count": 1}, {"class": "tie", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a tie above a sink"}
|
| 401 |
+
{"tag": "position", "include": [{"class": "couch", "count": 1}, {"class": "bird", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a bird left of a couch"}
|
| 402 |
+
{"tag": "position", "include": [{"class": "sports ball", "count": 1}, {"class": "bed", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a bed right of a sports ball"}
|
| 403 |
+
{"tag": "position", "include": [{"class": "surfboard", "count": 1}, {"class": "elephant", "count": 1, "position": ["below", 0]}], "prompt": "a photo of an elephant below a surfboard"}
|
| 404 |
+
{"tag": "position", "include": [{"class": "motorcycle", "count": 1}, {"class": "frisbee", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a frisbee right of a motorcycle"}
|
| 405 |
+
{"tag": "position", "include": [{"class": "fire hydrant", "count": 1}, {"class": "vase", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a vase above a fire hydrant"}
|
| 406 |
+
{"tag": "position", "include": [{"class": "elephant", "count": 1}, {"class": "zebra", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a zebra left of an elephant"}
|
| 407 |
+
{"tag": "position", "include": [{"class": "bear", "count": 1}, {"class": "bench", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a bench left of a bear"}
|
| 408 |
+
{"tag": "position", "include": [{"class": "bench", "count": 1}, {"class": "donut", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a donut right of a bench"}
|
| 409 |
+
{"tag": "position", "include": [{"class": "horse", "count": 1}, {"class": "frisbee", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a frisbee below a horse"}
|
| 410 |
+
{"tag": "position", "include": [{"class": "snowboard", "count": 1}, {"class": "computer keyboard", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a computer keyboard above a snowboard"}
|
| 411 |
+
{"tag": "position", "include": [{"class": "cow", "count": 1}, {"class": "tv", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a tv below a cow"}
|
| 412 |
+
{"tag": "position", "include": [{"class": "horse", "count": 1}, {"class": "elephant", "count": 1, "position": ["below", 0]}], "prompt": "a photo of an elephant below a horse"}
|
| 413 |
+
{"tag": "position", "include": [{"class": "banana", "count": 1}, {"class": "suitcase", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a suitcase left of a banana"}
|
| 414 |
+
{"tag": "position", "include": [{"class": "airplane", "count": 1}, {"class": "train", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a train below an airplane"}
|
| 415 |
+
{"tag": "position", "include": [{"class": "backpack", "count": 1}, {"class": "cat", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a cat below a backpack"}
|
| 416 |
+
{"tag": "position", "include": [{"class": "cake", "count": 1}, {"class": "backpack", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a backpack below a cake"}
|
| 417 |
+
{"tag": "position", "include": [{"class": "knife", "count": 1}, {"class": "sandwich", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a sandwich below a knife"}
|
| 418 |
+
{"tag": "position", "include": [{"class": "parking meter", "count": 1}, {"class": "bicycle", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a bicycle above a parking meter"}
|
| 419 |
+
{"tag": "position", "include": [{"class": "suitcase", "count": 1}, {"class": "knife", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a knife right of a suitcase"}
|
| 420 |
+
{"tag": "position", "include": [{"class": "knife", "count": 1}, {"class": "hot dog", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a hot dog above a knife"}
|
| 421 |
+
{"tag": "position", "include": [{"class": "parking meter", "count": 1}, {"class": "zebra", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a zebra right of a parking meter"}
|
| 422 |
+
{"tag": "position", "include": [{"class": "zebra", "count": 1}, {"class": "chair", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a chair left of a zebra"}
|
| 423 |
+
{"tag": "position", "include": [{"class": "airplane", "count": 1}, {"class": "cow", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a cow below an airplane"}
|
| 424 |
+
{"tag": "position", "include": [{"class": "umbrella", "count": 1}, {"class": "cup", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a cup left of an umbrella"}
|
| 425 |
+
{"tag": "position", "include": [{"class": "computer keyboard", "count": 1}, {"class": "zebra", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a zebra below a computer keyboard"}
|
| 426 |
+
{"tag": "position", "include": [{"class": "broccoli", "count": 1}, {"class": "zebra", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a zebra below a broccoli"}
|
| 427 |
+
{"tag": "position", "include": [{"class": "sports ball", "count": 1}, {"class": "laptop", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a laptop below a sports ball"}
|
| 428 |
+
{"tag": "position", "include": [{"class": "baseball bat", "count": 1}, {"class": "truck", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a truck left of a baseball bat"}
|
| 429 |
+
{"tag": "position", "include": [{"class": "baseball bat", "count": 1}, {"class": "refrigerator", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a refrigerator above a baseball bat"}
|
| 430 |
+
{"tag": "position", "include": [{"class": "baseball bat", "count": 1}, {"class": "tv", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a tv above a baseball bat"}
|
| 431 |
+
{"tag": "position", "include": [{"class": "bear", "count": 1}, {"class": "baseball glove", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a baseball glove right of a bear"}
|
| 432 |
+
{"tag": "position", "include": [{"class": "scissors", "count": 1}, {"class": "refrigerator", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a refrigerator below a scissors"}
|
| 433 |
+
{"tag": "position", "include": [{"class": "suitcase", "count": 1}, {"class": "dining table", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a dining table above a suitcase"}
|
| 434 |
+
{"tag": "position", "include": [{"class": "broccoli", "count": 1}, {"class": "parking meter", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a parking meter above a broccoli"}
|
| 435 |
+
{"tag": "position", "include": [{"class": "truck", "count": 1}, {"class": "frisbee", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a frisbee above a truck"}
|
| 436 |
+
{"tag": "position", "include": [{"class": "banana", "count": 1}, {"class": "pizza", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a pizza right of a banana"}
|
| 437 |
+
{"tag": "position", "include": [{"class": "boat", "count": 1}, {"class": "bus", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a bus above a boat"}
|
| 438 |
+
{"tag": "position", "include": [{"class": "tennis racket", "count": 1}, {"class": "cell phone", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a cell phone left of a tennis racket"}
|
| 439 |
+
{"tag": "position", "include": [{"class": "broccoli", "count": 1}, {"class": "horse", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a horse right of a broccoli"}
|
| 440 |
+
{"tag": "position", "include": [{"class": "bottle", "count": 1}, {"class": "broccoli", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a broccoli above a bottle"}
|
| 441 |
+
{"tag": "position", "include": [{"class": "horse", "count": 1}, {"class": "vase", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a vase right of a horse"}
|
| 442 |
+
{"tag": "position", "include": [{"class": "spoon", "count": 1}, {"class": "bear", "count": 1, "position": ["above", 0]}], "prompt": "a photo of a bear above a spoon"}
|
| 443 |
+
{"tag": "position", "include": [{"class": "bed", "count": 1}, {"class": "zebra", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a zebra right of a bed"}
|
| 444 |
+
{"tag": "position", "include": [{"class": "laptop", "count": 1}, {"class": "cow", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a cow right of a laptop"}
|
| 445 |
+
{"tag": "position", "include": [{"class": "frisbee", "count": 1}, {"class": "bed", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a bed right of a frisbee"}
|
| 446 |
+
{"tag": "position", "include": [{"class": "motorcycle", "count": 1}, {"class": "tie", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a tie right of a motorcycle"}
|
| 447 |
+
{"tag": "position", "include": [{"class": "tv", "count": 1}, {"class": "laptop", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a laptop right of a tv"}
|
| 448 |
+
{"tag": "position", "include": [{"class": "chair", "count": 1}, {"class": "cell phone", "count": 1, "position": ["right of", 0]}], "prompt": "a photo of a cell phone right of a chair"}
|
| 449 |
+
{"tag": "position", "include": [{"class": "potted plant", "count": 1}, {"class": "couch", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a couch below a potted plant"}
|
| 450 |
+
{"tag": "position", "include": [{"class": "tv", "count": 1}, {"class": "clock", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a clock below a tv"}
|
| 451 |
+
{"tag": "position", "include": [{"class": "vase", "count": 1}, {"class": "couch", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a couch below a vase"}
|
| 452 |
+
{"tag": "position", "include": [{"class": "cat", "count": 1}, {"class": "donut", "count": 1, "position": ["below", 0]}], "prompt": "a photo of a donut below a cat"}
|
| 453 |
+
{"tag": "position", "include": [{"class": "toaster", "count": 1}, {"class": "couch", "count": 1, "position": ["left of", 0]}], "prompt": "a photo of a couch left of a toaster"}
|
| 454 |
+
{"tag": "color_attr", "include": [{"class": "wine glass", "count": 1, "color": "purple"}, {"class": "apple", "count": 1, "color": "black"}], "prompt": "a photo of a purple wine glass and a black apple"}
|
| 455 |
+
{"tag": "color_attr", "include": [{"class": "bus", "count": 1, "color": "green"}, {"class": "microwave", "count": 1, "color": "purple"}], "prompt": "a photo of a green bus and a purple microwave"}
|
| 456 |
+
{"tag": "color_attr", "include": [{"class": "skis", "count": 1, "color": "green"}, {"class": "airplane", "count": 1, "color": "brown"}], "prompt": "a photo of a green skis and a brown airplane"}
|
| 457 |
+
{"tag": "color_attr", "include": [{"class": "computer keyboard", "count": 1, "color": "yellow"}, {"class": "sink", "count": 1, "color": "black"}], "prompt": "a photo of a yellow computer keyboard and a black sink"}
|
| 458 |
+
{"tag": "color_attr", "include": [{"class": "oven", "count": 1, "color": "pink"}, {"class": "motorcycle", "count": 1, "color": "green"}], "prompt": "a photo of a pink oven and a green motorcycle"}
|
| 459 |
+
{"tag": "color_attr", "include": [{"class": "parking meter", "count": 1, "color": "purple"}, {"class": "laptop", "count": 1, "color": "red"}], "prompt": "a photo of a purple parking meter and a red laptop"}
|
| 460 |
+
{"tag": "color_attr", "include": [{"class": "skateboard", "count": 1, "color": "yellow"}, {"class": "computer mouse", "count": 1, "color": "orange"}], "prompt": "a photo of a yellow skateboard and an orange computer mouse"}
|
| 461 |
+
{"tag": "color_attr", "include": [{"class": "skis", "count": 1, "color": "red"}, {"class": "tie", "count": 1, "color": "brown"}], "prompt": "a photo of a red skis and a brown tie"}
|
| 462 |
+
{"tag": "color_attr", "include": [{"class": "skateboard", "count": 1, "color": "pink"}, {"class": "train", "count": 1, "color": "black"}], "prompt": "a photo of a pink skateboard and a black train"}
|
| 463 |
+
{"tag": "color_attr", "include": [{"class": "handbag", "count": 1, "color": "white"}, {"class": "bed", "count": 1, "color": "purple"}], "prompt": "a photo of a white handbag and a purple bed"}
|
| 464 |
+
{"tag": "color_attr", "include": [{"class": "elephant", "count": 1, "color": "purple"}, {"class": "sports ball", "count": 1, "color": "brown"}], "prompt": "a photo of a purple elephant and a brown sports ball"}
|
| 465 |
+
{"tag": "color_attr", "include": [{"class": "dog", "count": 1, "color": "purple"}, {"class": "dining table", "count": 1, "color": "black"}], "prompt": "a photo of a purple dog and a black dining table"}
|
| 466 |
+
{"tag": "color_attr", "include": [{"class": "dining table", "count": 1, "color": "white"}, {"class": "car", "count": 1, "color": "red"}], "prompt": "a photo of a white dining table and a red car"}
|
| 467 |
+
{"tag": "color_attr", "include": [{"class": "cell phone", "count": 1, "color": "blue"}, {"class": "apple", "count": 1, "color": "green"}], "prompt": "a photo of a blue cell phone and a green apple"}
|
| 468 |
+
{"tag": "color_attr", "include": [{"class": "car", "count": 1, "color": "red"}, {"class": "potted plant", "count": 1, "color": "orange"}], "prompt": "a photo of a red car and an orange potted plant"}
|
| 469 |
+
{"tag": "color_attr", "include": [{"class": "carrot", "count": 1, "color": "brown"}, {"class": "potted plant", "count": 1, "color": "white"}], "prompt": "a photo of a brown carrot and a white potted plant"}
|
| 470 |
+
{"tag": "color_attr", "include": [{"class": "kite", "count": 1, "color": "black"}, {"class": "bear", "count": 1, "color": "green"}], "prompt": "a photo of a black kite and a green bear"}
|
| 471 |
+
{"tag": "color_attr", "include": [{"class": "laptop", "count": 1, "color": "blue"}, {"class": "bear", "count": 1, "color": "brown"}], "prompt": "a photo of a blue laptop and a brown bear"}
|
| 472 |
+
{"tag": "color_attr", "include": [{"class": "teddy bear", "count": 1, "color": "green"}, {"class": "kite", "count": 1, "color": "brown"}], "prompt": "a photo of a green teddy bear and a brown kite"}
|
| 473 |
+
{"tag": "color_attr", "include": [{"class": "stop sign", "count": 1, "color": "yellow"}, {"class": "potted plant", "count": 1, "color": "blue"}], "prompt": "a photo of a yellow stop sign and a blue potted plant"}
|
| 474 |
+
{"tag": "color_attr", "include": [{"class": "snowboard", "count": 1, "color": "orange"}, {"class": "cat", "count": 1, "color": "green"}], "prompt": "a photo of an orange snowboard and a green cat"}
|
| 475 |
+
{"tag": "color_attr", "include": [{"class": "truck", "count": 1, "color": "orange"}, {"class": "sink", "count": 1, "color": "pink"}], "prompt": "a photo of an orange truck and a pink sink"}
|
| 476 |
+
{"tag": "color_attr", "include": [{"class": "hot dog", "count": 1, "color": "brown"}, {"class": "pizza", "count": 1, "color": "purple"}], "prompt": "a photo of a brown hot dog and a purple pizza"}
|
| 477 |
+
{"tag": "color_attr", "include": [{"class": "couch", "count": 1, "color": "green"}, {"class": "umbrella", "count": 1, "color": "orange"}], "prompt": "a photo of a green couch and an orange umbrella"}
|
| 478 |
+
{"tag": "color_attr", "include": [{"class": "bed", "count": 1, "color": "brown"}, {"class": "cell phone", "count": 1, "color": "pink"}], "prompt": "a photo of a brown bed and a pink cell phone"}
|
| 479 |
+
{"tag": "color_attr", "include": [{"class": "broccoli", "count": 1, "color": "black"}, {"class": "cake", "count": 1, "color": "yellow"}], "prompt": "a photo of a black broccoli and a yellow cake"}
|
| 480 |
+
{"tag": "color_attr", "include": [{"class": "train", "count": 1, "color": "red"}, {"class": "bear", "count": 1, "color": "purple"}], "prompt": "a photo of a red train and a purple bear"}
|
| 481 |
+
{"tag": "color_attr", "include": [{"class": "tennis racket", "count": 1, "color": "purple"}, {"class": "sink", "count": 1, "color": "black"}], "prompt": "a photo of a purple tennis racket and a black sink"}
|
| 482 |
+
{"tag": "color_attr", "include": [{"class": "vase", "count": 1, "color": "blue"}, {"class": "banana", "count": 1, "color": "black"}], "prompt": "a photo of a blue vase and a black banana"}
|
| 483 |
+
{"tag": "color_attr", "include": [{"class": "clock", "count": 1, "color": "blue"}, {"class": "cup", "count": 1, "color": "white"}], "prompt": "a photo of a blue clock and a white cup"}
|
| 484 |
+
{"tag": "color_attr", "include": [{"class": "umbrella", "count": 1, "color": "red"}, {"class": "couch", "count": 1, "color": "blue"}], "prompt": "a photo of a red umbrella and a blue couch"}
|
| 485 |
+
{"tag": "color_attr", "include": [{"class": "handbag", "count": 1, "color": "white"}, {"class": "giraffe", "count": 1, "color": "red"}], "prompt": "a photo of a white handbag and a red giraffe"}
|
| 486 |
+
{"tag": "color_attr", "include": [{"class": "tv remote", "count": 1, "color": "pink"}, {"class": "airplane", "count": 1, "color": "blue"}], "prompt": "a photo of a pink tv remote and a blue airplane"}
|
| 487 |
+
{"tag": "color_attr", "include": [{"class": "handbag", "count": 1, "color": "pink"}, {"class": "scissors", "count": 1, "color": "black"}], "prompt": "a photo of a pink handbag and a black scissors"}
|
| 488 |
+
{"tag": "color_attr", "include": [{"class": "car", "count": 1, "color": "brown"}, {"class": "hair drier", "count": 1, "color": "pink"}], "prompt": "a photo of a brown car and a pink hair drier"}
|
| 489 |
+
{"tag": "color_attr", "include": [{"class": "bus", "count": 1, "color": "black"}, {"class": "cell phone", "count": 1, "color": "brown"}], "prompt": "a photo of a black bus and a brown cell phone"}
|
| 490 |
+
{"tag": "color_attr", "include": [{"class": "sheep", "count": 1, "color": "purple"}, {"class": "banana", "count": 1, "color": "pink"}], "prompt": "a photo of a purple sheep and a pink banana"}
|
| 491 |
+
{"tag": "color_attr", "include": [{"class": "handbag", "count": 1, "color": "blue"}, {"class": "cell phone", "count": 1, "color": "white"}], "prompt": "a photo of a blue handbag and a white cell phone"}
|
| 492 |
+
{"tag": "color_attr", "include": [{"class": "pizza", "count": 1, "color": "white"}, {"class": "umbrella", "count": 1, "color": "green"}], "prompt": "a photo of a white pizza and a green umbrella"}
|
| 493 |
+
{"tag": "color_attr", "include": [{"class": "tie", "count": 1, "color": "white"}, {"class": "skateboard", "count": 1, "color": "purple"}], "prompt": "a photo of a white tie and a purple skateboard"}
|
| 494 |
+
{"tag": "color_attr", "include": [{"class": "sports ball", "count": 1, "color": "yellow"}, {"class": "boat", "count": 1, "color": "green"}], "prompt": "a photo of a yellow sports ball and a green boat"}
|
| 495 |
+
{"tag": "color_attr", "include": [{"class": "wine glass", "count": 1, "color": "white"}, {"class": "giraffe", "count": 1, "color": "brown"}], "prompt": "a photo of a white wine glass and a brown giraffe"}
|
| 496 |
+
{"tag": "color_attr", "include": [{"class": "bowl", "count": 1, "color": "yellow"}, {"class": "baseball glove", "count": 1, "color": "white"}], "prompt": "a photo of a yellow bowl and a white baseball glove"}
|
| 497 |
+
{"tag": "color_attr", "include": [{"class": "microwave", "count": 1, "color": "orange"}, {"class": "spoon", "count": 1, "color": "black"}], "prompt": "a photo of an orange microwave and a black spoon"}
|
| 498 |
+
{"tag": "color_attr", "include": [{"class": "skateboard", "count": 1, "color": "orange"}, {"class": "bowl", "count": 1, "color": "pink"}], "prompt": "a photo of an orange skateboard and a pink bowl"}
|
| 499 |
+
{"tag": "color_attr", "include": [{"class": "toilet", "count": 1, "color": "blue"}, {"class": "suitcase", "count": 1, "color": "white"}], "prompt": "a photo of a blue toilet and a white suitcase"}
|
| 500 |
+
{"tag": "color_attr", "include": [{"class": "boat", "count": 1, "color": "white"}, {"class": "hot dog", "count": 1, "color": "orange"}], "prompt": "a photo of a white boat and an orange hot dog"}
|
| 501 |
+
{"tag": "color_attr", "include": [{"class": "dining table", "count": 1, "color": "yellow"}, {"class": "dog", "count": 1, "color": "pink"}], "prompt": "a photo of a yellow dining table and a pink dog"}
|
| 502 |
+
{"tag": "color_attr", "include": [{"class": "cake", "count": 1, "color": "red"}, {"class": "chair", "count": 1, "color": "purple"}], "prompt": "a photo of a red cake and a purple chair"}
|
| 503 |
+
{"tag": "color_attr", "include": [{"class": "tie", "count": 1, "color": "blue"}, {"class": "dining table", "count": 1, "color": "pink"}], "prompt": "a photo of a blue tie and a pink dining table"}
|
| 504 |
+
{"tag": "color_attr", "include": [{"class": "cow", "count": 1, "color": "blue"}, {"class": "computer keyboard", "count": 1, "color": "black"}], "prompt": "a photo of a blue cow and a black computer keyboard"}
|
| 505 |
+
{"tag": "color_attr", "include": [{"class": "pizza", "count": 1, "color": "yellow"}, {"class": "oven", "count": 1, "color": "green"}], "prompt": "a photo of a yellow pizza and a green oven"}
|
| 506 |
+
{"tag": "color_attr", "include": [{"class": "laptop", "count": 1, "color": "red"}, {"class": "car", "count": 1, "color": "brown"}], "prompt": "a photo of a red laptop and a brown car"}
|
| 507 |
+
{"tag": "color_attr", "include": [{"class": "computer keyboard", "count": 1, "color": "purple"}, {"class": "scissors", "count": 1, "color": "blue"}], "prompt": "a photo of a purple computer keyboard and a blue scissors"}
|
| 508 |
+
{"tag": "color_attr", "include": [{"class": "surfboard", "count": 1, "color": "green"}, {"class": "oven", "count": 1, "color": "orange"}], "prompt": "a photo of a green surfboard and an orange oven"}
|
| 509 |
+
{"tag": "color_attr", "include": [{"class": "parking meter", "count": 1, "color": "yellow"}, {"class": "refrigerator", "count": 1, "color": "pink"}], "prompt": "a photo of a yellow parking meter and a pink refrigerator"}
|
| 510 |
+
{"tag": "color_attr", "include": [{"class": "computer mouse", "count": 1, "color": "brown"}, {"class": "bottle", "count": 1, "color": "purple"}], "prompt": "a photo of a brown computer mouse and a purple bottle"}
|
| 511 |
+
{"tag": "color_attr", "include": [{"class": "umbrella", "count": 1, "color": "red"}, {"class": "cow", "count": 1, "color": "green"}], "prompt": "a photo of a red umbrella and a green cow"}
|
| 512 |
+
{"tag": "color_attr", "include": [{"class": "giraffe", "count": 1, "color": "red"}, {"class": "cell phone", "count": 1, "color": "black"}], "prompt": "a photo of a red giraffe and a black cell phone"}
|
| 513 |
+
{"tag": "color_attr", "include": [{"class": "oven", "count": 1, "color": "brown"}, {"class": "train", "count": 1, "color": "purple"}], "prompt": "a photo of a brown oven and a purple train"}
|
| 514 |
+
{"tag": "color_attr", "include": [{"class": "baseball bat", "count": 1, "color": "blue"}, {"class": "book", "count": 1, "color": "pink"}], "prompt": "a photo of a blue baseball bat and a pink book"}
|
| 515 |
+
{"tag": "color_attr", "include": [{"class": "cup", "count": 1, "color": "green"}, {"class": "bowl", "count": 1, "color": "yellow"}], "prompt": "a photo of a green cup and a yellow bowl"}
|
| 516 |
+
{"tag": "color_attr", "include": [{"class": "suitcase", "count": 1, "color": "yellow"}, {"class": "bus", "count": 1, "color": "brown"}], "prompt": "a photo of a yellow suitcase and a brown bus"}
|
| 517 |
+
{"tag": "color_attr", "include": [{"class": "motorcycle", "count": 1, "color": "orange"}, {"class": "donut", "count": 1, "color": "pink"}], "prompt": "a photo of an orange motorcycle and a pink donut"}
|
| 518 |
+
{"tag": "color_attr", "include": [{"class": "giraffe", "count": 1, "color": "orange"}, {"class": "baseball glove", "count": 1, "color": "white"}], "prompt": "a photo of an orange giraffe and a white baseball glove"}
|
| 519 |
+
{"tag": "color_attr", "include": [{"class": "handbag", "count": 1, "color": "orange"}, {"class": "carrot", "count": 1, "color": "green"}], "prompt": "a photo of an orange handbag and a green carrot"}
|
| 520 |
+
{"tag": "color_attr", "include": [{"class": "bottle", "count": 1, "color": "black"}, {"class": "refrigerator", "count": 1, "color": "white"}], "prompt": "a photo of a black bottle and a white refrigerator"}
|
| 521 |
+
{"tag": "color_attr", "include": [{"class": "dog", "count": 1, "color": "white"}, {"class": "potted plant", "count": 1, "color": "blue"}], "prompt": "a photo of a white dog and a blue potted plant"}
|
| 522 |
+
{"tag": "color_attr", "include": [{"class": "handbag", "count": 1, "color": "orange"}, {"class": "car", "count": 1, "color": "red"}], "prompt": "a photo of an orange handbag and a red car"}
|
| 523 |
+
{"tag": "color_attr", "include": [{"class": "stop sign", "count": 1, "color": "red"}, {"class": "book", "count": 1, "color": "blue"}], "prompt": "a photo of a red stop sign and a blue book"}
|
| 524 |
+
{"tag": "color_attr", "include": [{"class": "car", "count": 1, "color": "yellow"}, {"class": "toothbrush", "count": 1, "color": "orange"}], "prompt": "a photo of a yellow car and an orange toothbrush"}
|
| 525 |
+
{"tag": "color_attr", "include": [{"class": "potted plant", "count": 1, "color": "black"}, {"class": "toilet", "count": 1, "color": "yellow"}], "prompt": "a photo of a black potted plant and a yellow toilet"}
|
| 526 |
+
{"tag": "color_attr", "include": [{"class": "dining table", "count": 1, "color": "brown"}, {"class": "suitcase", "count": 1, "color": "white"}], "prompt": "a photo of a brown dining table and a white suitcase"}
|
| 527 |
+
{"tag": "color_attr", "include": [{"class": "donut", "count": 1, "color": "orange"}, {"class": "stop sign", "count": 1, "color": "yellow"}], "prompt": "a photo of an orange donut and a yellow stop sign"}
|
| 528 |
+
{"tag": "color_attr", "include": [{"class": "suitcase", "count": 1, "color": "green"}, {"class": "boat", "count": 1, "color": "blue"}], "prompt": "a photo of a green suitcase and a blue boat"}
|
| 529 |
+
{"tag": "color_attr", "include": [{"class": "tennis racket", "count": 1, "color": "orange"}, {"class": "sports ball", "count": 1, "color": "yellow"}], "prompt": "a photo of an orange tennis racket and a yellow sports ball"}
|
| 530 |
+
{"tag": "color_attr", "include": [{"class": "computer keyboard", "count": 1, "color": "purple"}, {"class": "chair", "count": 1, "color": "red"}], "prompt": "a photo of a purple computer keyboard and a red chair"}
|
| 531 |
+
{"tag": "color_attr", "include": [{"class": "suitcase", "count": 1, "color": "purple"}, {"class": "pizza", "count": 1, "color": "orange"}], "prompt": "a photo of a purple suitcase and an orange pizza"}
|
| 532 |
+
{"tag": "color_attr", "include": [{"class": "bottle", "count": 1, "color": "white"}, {"class": "sheep", "count": 1, "color": "blue"}], "prompt": "a photo of a white bottle and a blue sheep"}
|
| 533 |
+
{"tag": "color_attr", "include": [{"class": "backpack", "count": 1, "color": "purple"}, {"class": "umbrella", "count": 1, "color": "white"}], "prompt": "a photo of a purple backpack and a white umbrella"}
|
| 534 |
+
{"tag": "color_attr", "include": [{"class": "potted plant", "count": 1, "color": "orange"}, {"class": "spoon", "count": 1, "color": "black"}], "prompt": "a photo of an orange potted plant and a black spoon"}
|
| 535 |
+
{"tag": "color_attr", "include": [{"class": "tennis racket", "count": 1, "color": "green"}, {"class": "dog", "count": 1, "color": "black"}], "prompt": "a photo of a green tennis racket and a black dog"}
|
| 536 |
+
{"tag": "color_attr", "include": [{"class": "handbag", "count": 1, "color": "yellow"}, {"class": "refrigerator", "count": 1, "color": "blue"}], "prompt": "a photo of a yellow handbag and a blue refrigerator"}
|
| 537 |
+
{"tag": "color_attr", "include": [{"class": "broccoli", "count": 1, "color": "pink"}, {"class": "sink", "count": 1, "color": "red"}], "prompt": "a photo of a pink broccoli and a red sink"}
|
| 538 |
+
{"tag": "color_attr", "include": [{"class": "bowl", "count": 1, "color": "red"}, {"class": "sink", "count": 1, "color": "pink"}], "prompt": "a photo of a red bowl and a pink sink"}
|
| 539 |
+
{"tag": "color_attr", "include": [{"class": "toilet", "count": 1, "color": "white"}, {"class": "apple", "count": 1, "color": "red"}], "prompt": "a photo of a white toilet and a red apple"}
|
| 540 |
+
{"tag": "color_attr", "include": [{"class": "dining table", "count": 1, "color": "pink"}, {"class": "sandwich", "count": 1, "color": "black"}], "prompt": "a photo of a pink dining table and a black sandwich"}
|
| 541 |
+
{"tag": "color_attr", "include": [{"class": "car", "count": 1, "color": "black"}, {"class": "parking meter", "count": 1, "color": "green"}], "prompt": "a photo of a black car and a green parking meter"}
|
| 542 |
+
{"tag": "color_attr", "include": [{"class": "bird", "count": 1, "color": "yellow"}, {"class": "motorcycle", "count": 1, "color": "black"}], "prompt": "a photo of a yellow bird and a black motorcycle"}
|
| 543 |
+
{"tag": "color_attr", "include": [{"class": "giraffe", "count": 1, "color": "brown"}, {"class": "stop sign", "count": 1, "color": "white"}], "prompt": "a photo of a brown giraffe and a white stop sign"}
|
| 544 |
+
{"tag": "color_attr", "include": [{"class": "banana", "count": 1, "color": "white"}, {"class": "elephant", "count": 1, "color": "black"}], "prompt": "a photo of a white banana and a black elephant"}
|
| 545 |
+
{"tag": "color_attr", "include": [{"class": "cow", "count": 1, "color": "orange"}, {"class": "sandwich", "count": 1, "color": "purple"}], "prompt": "a photo of an orange cow and a purple sandwich"}
|
| 546 |
+
{"tag": "color_attr", "include": [{"class": "clock", "count": 1, "color": "red"}, {"class": "cell phone", "count": 1, "color": "black"}], "prompt": "a photo of a red clock and a black cell phone"}
|
| 547 |
+
{"tag": "color_attr", "include": [{"class": "knife", "count": 1, "color": "brown"}, {"class": "donut", "count": 1, "color": "blue"}], "prompt": "a photo of a brown knife and a blue donut"}
|
| 548 |
+
{"tag": "color_attr", "include": [{"class": "cup", "count": 1, "color": "red"}, {"class": "handbag", "count": 1, "color": "pink"}], "prompt": "a photo of a red cup and a pink handbag"}
|
| 549 |
+
{"tag": "color_attr", "include": [{"class": "bicycle", "count": 1, "color": "yellow"}, {"class": "motorcycle", "count": 1, "color": "red"}], "prompt": "a photo of a yellow bicycle and a red motorcycle"}
|
| 550 |
+
{"tag": "color_attr", "include": [{"class": "orange", "count": 1, "color": "red"}, {"class": "broccoli", "count": 1, "color": "purple"}], "prompt": "a photo of a red orange and a purple broccoli"}
|
| 551 |
+
{"tag": "color_attr", "include": [{"class": "traffic light", "count": 1, "color": "orange"}, {"class": "toilet", "count": 1, "color": "white"}], "prompt": "a photo of an orange traffic light and a white toilet"}
|
| 552 |
+
{"tag": "color_attr", "include": [{"class": "cup", "count": 1, "color": "green"}, {"class": "pizza", "count": 1, "color": "red"}], "prompt": "a photo of a green cup and a red pizza"}
|
| 553 |
+
{"tag": "color_attr", "include": [{"class": "pizza", "count": 1, "color": "blue"}, {"class": "baseball glove", "count": 1, "color": "yellow"}], "prompt": "a photo of a blue pizza and a yellow baseball glove"}
|
prompts/ocr_test.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
requirements.txt
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate==1.7.0
|
| 2 |
+
aiohappyeyeballs==2.6.1
|
| 3 |
+
aiohttp==3.12.9
|
| 4 |
+
aiosignal==1.3.2
|
| 5 |
+
airportsdata==20250523
|
| 6 |
+
annotated-types==0.7.0
|
| 7 |
+
anthropic==0.54.0
|
| 8 |
+
antlr4-python3-runtime==4.13.2
|
| 9 |
+
anyio==4.9.0
|
| 10 |
+
astor==0.8.1
|
| 11 |
+
asttokens==3.0.0
|
| 12 |
+
attrs==25.3.0
|
| 13 |
+
av==14.4.0
|
| 14 |
+
bitsandbytes==0.46.0
|
| 15 |
+
blake3==1.0.5
|
| 16 |
+
cachetools==6.0.0
|
| 17 |
+
certifi==2025.4.26
|
| 18 |
+
charset-normalizer==3.4.2
|
| 19 |
+
click==8.2.1
|
| 20 |
+
cloudpickle==3.1.1
|
| 21 |
+
compressed-tensors==0.9.4
|
| 22 |
+
contourpy==1.3.2
|
| 23 |
+
cupy-cuda12x==13.4.1
|
| 24 |
+
cycler==0.12.1
|
| 25 |
+
datasets==3.6.0
|
| 26 |
+
decorator==5.2.1
|
| 27 |
+
deepspeed==0.15.4
|
| 28 |
+
depyf==0.18.0
|
| 29 |
+
dill==0.3.8
|
| 30 |
+
diskcache==5.6.3
|
| 31 |
+
distro==1.9.0
|
| 32 |
+
dnspython==2.7.0
|
| 33 |
+
docker-pycreds==0.4.0
|
| 34 |
+
einops==0.8.1
|
| 35 |
+
email-validator==2.2.0
|
| 36 |
+
executing==2.2.0
|
| 37 |
+
fastapi==0.115.12
|
| 38 |
+
fastapi-cli==0.0.7
|
| 39 |
+
fastrlock==0.8.3
|
| 40 |
+
filelock==3.18.0
|
| 41 |
+
fonttools==4.58.4
|
| 42 |
+
frozenlist==1.6.2
|
| 43 |
+
fsspec==2025.3.0
|
| 44 |
+
ftfy==6.3.1
|
| 45 |
+
gguf==0.17.0
|
| 46 |
+
gitdb==4.0.12
|
| 47 |
+
gitpython==3.1.44
|
| 48 |
+
googleapis-common-protos==1.70.0
|
| 49 |
+
grpcio==1.72.1
|
| 50 |
+
h11==0.16.0
|
| 51 |
+
hf-transfer==0.1.9
|
| 52 |
+
hf-xet==1.1.3
|
| 53 |
+
hjson==3.1.0
|
| 54 |
+
httpcore==1.0.9
|
| 55 |
+
httptools==0.6.4
|
| 56 |
+
httpx==0.28.1
|
| 57 |
+
huggingface-hub==0.32.4
|
| 58 |
+
idna==3.10
|
| 59 |
+
importlib-metadata==8.7.0
|
| 60 |
+
inquirerpy==0.3.4
|
| 61 |
+
interegular==0.3.3
|
| 62 |
+
ipython==9.3.0
|
| 63 |
+
ipython-pygments-lexers==1.1.1
|
| 64 |
+
jedi==0.19.2
|
| 65 |
+
jinja2==3.1.6
|
| 66 |
+
jiter==0.10.0
|
| 67 |
+
jsonschema==4.24.0
|
| 68 |
+
jsonschema-specifications==2025.4.1
|
| 69 |
+
kiwisolver==1.4.8
|
| 70 |
+
lark==1.2.2
|
| 71 |
+
latex2sympy2-extended==1.10.1
|
| 72 |
+
liger-kernel==0.5.2
|
| 73 |
+
llguidance==0.7.29
|
| 74 |
+
llvmlite==0.44.0
|
| 75 |
+
lm-format-enforcer==0.10.11
|
| 76 |
+
markdown-it-py==3.0.0
|
| 77 |
+
markupsafe==3.0.2
|
| 78 |
+
math-verify==0.7.0
|
| 79 |
+
matplotlib==3.10.3
|
| 80 |
+
matplotlib-inline==0.1.7
|
| 81 |
+
mdurl==0.1.2
|
| 82 |
+
mistral-common==1.5.6
|
| 83 |
+
mpmath==1.3.0
|
| 84 |
+
msgpack==1.1.0
|
| 85 |
+
msgspec==0.19.0
|
| 86 |
+
multidict==6.4.4
|
| 87 |
+
multiprocess==0.70.16
|
| 88 |
+
nest-asyncio==1.6.0
|
| 89 |
+
networkx==3.5
|
| 90 |
+
ninja==1.11.1.4
|
| 91 |
+
numba==0.61.2
|
| 92 |
+
numpy==2.2.6
|
| 93 |
+
nvidia-cublas-cu12==12.6.4.1
|
| 94 |
+
nvidia-cuda-cupti-cu12==12.6.80
|
| 95 |
+
nvidia-cuda-nvrtc-cu12==12.6.77
|
| 96 |
+
nvidia-cuda-runtime-cu12==12.6.77
|
| 97 |
+
nvidia-cudnn-cu12==9.5.1.17
|
| 98 |
+
nvidia-cufft-cu12==11.3.0.4
|
| 99 |
+
nvidia-cufile-cu12==1.11.1.6
|
| 100 |
+
nvidia-curand-cu12==10.3.7.77
|
| 101 |
+
nvidia-cusolver-cu12==11.7.1.2
|
| 102 |
+
nvidia-cusparse-cu12==12.5.4.2
|
| 103 |
+
nvidia-cusparselt-cu12==0.6.3
|
| 104 |
+
nvidia-nccl-cu12==2.26.2
|
| 105 |
+
nvidia-nvjitlink-cu12==12.6.85
|
| 106 |
+
nvidia-nvtx-cu12==12.6.77
|
| 107 |
+
openai==1.84.0
|
| 108 |
+
opencv-python-headless==4.11.0.86
|
| 109 |
+
opentelemetry-api==1.34.0
|
| 110 |
+
opentelemetry-exporter-otlp==1.34.0
|
| 111 |
+
opentelemetry-exporter-otlp-proto-common==1.34.0
|
| 112 |
+
opentelemetry-exporter-otlp-proto-grpc==1.34.0
|
| 113 |
+
opentelemetry-exporter-otlp-proto-http==1.34.0
|
| 114 |
+
opentelemetry-proto==1.34.0
|
| 115 |
+
opentelemetry-sdk==1.34.0
|
| 116 |
+
opentelemetry-semantic-conventions==0.55b0
|
| 117 |
+
opentelemetry-semantic-conventions-ai==0.4.9
|
| 118 |
+
outlines==0.1.11
|
| 119 |
+
outlines-core==0.1.26
|
| 120 |
+
packaging==25.0
|
| 121 |
+
pandas==2.3.0
|
| 122 |
+
parso==0.8.4
|
| 123 |
+
partial-json-parser==0.2.1.1.post5
|
| 124 |
+
peft==0.17.1
|
| 125 |
+
pexpect==4.9.0
|
| 126 |
+
pfzy==0.3.4
|
| 127 |
+
pillow==11.2.1
|
| 128 |
+
platformdirs==4.3.8
|
| 129 |
+
prometheus-client==0.22.1
|
| 130 |
+
prometheus-fastapi-instrumentator==7.1.0
|
| 131 |
+
prompt-toolkit==3.0.51
|
| 132 |
+
propcache==0.3.1
|
| 133 |
+
protobuf==5.29.5
|
| 134 |
+
psutil==7.0.0
|
| 135 |
+
ptyprocess==0.7.0
|
| 136 |
+
pure-eval==0.2.3
|
| 137 |
+
py-cpuinfo==9.0.0
|
| 138 |
+
pyarrow==20.0.0
|
| 139 |
+
pycountry==24.6.1
|
| 140 |
+
pydantic==2.11.5
|
| 141 |
+
pydantic-core==2.33.2
|
| 142 |
+
pygments==2.19.1
|
| 143 |
+
pyparsing==3.2.3
|
| 144 |
+
python-dateutil==2.9.0.post0
|
| 145 |
+
python-dotenv==1.1.0
|
| 146 |
+
python-json-logger==3.3.0
|
| 147 |
+
python-multipart==0.0.20
|
| 148 |
+
pytz==2025.2
|
| 149 |
+
pyyaml==6.0.2
|
| 150 |
+
pyzmq==26.4.0
|
| 151 |
+
qwen-vl-utils==0.0.11
|
| 152 |
+
ray==2.46.0
|
| 153 |
+
referencing==0.36.2
|
| 154 |
+
regex==2024.11.6
|
| 155 |
+
requests==2.32.3
|
| 156 |
+
rich==14.0.0
|
| 157 |
+
rich-toolkit==0.14.7
|
| 158 |
+
rpds-py==0.25.1
|
| 159 |
+
safetensors==0.5.3
|
| 160 |
+
scipy==1.15.3
|
| 161 |
+
seaborn==0.13.2
|
| 162 |
+
sentencepiece==0.2.0
|
| 163 |
+
sentry-sdk==2.29.1
|
| 164 |
+
setproctitle==1.3.6
|
| 165 |
+
shellingham==1.5.4
|
| 166 |
+
six==1.17.0
|
| 167 |
+
smmap==5.0.2
|
| 168 |
+
sniffio==1.3.1
|
| 169 |
+
stack-data==0.6.3
|
| 170 |
+
starlette==0.46.2
|
| 171 |
+
sympy==1.14.0
|
| 172 |
+
tabulate==0.9.0
|
| 173 |
+
tiktoken==0.9.0
|
| 174 |
+
timm==0.6.13
|
| 175 |
+
tokenizers==0.21.1
|
| 176 |
+
torch==2.7.0
|
| 177 |
+
torchaudio==2.7.0
|
| 178 |
+
torchvision==0.22.0
|
| 179 |
+
tqdm==4.67.1
|
| 180 |
+
traitlets==5.14.3
|
| 181 |
+
transformers==4.51.3
|
| 182 |
+
triton==3.3.0
|
| 183 |
+
trl==0.19.0
|
| 184 |
+
typer==0.16.0
|
| 185 |
+
typing-extensions==4.14.0
|
| 186 |
+
typing-inspection==0.4.1
|
| 187 |
+
tzdata==2025.2
|
| 188 |
+
urllib3==2.4.0
|
| 189 |
+
utils==1.0.2
|
| 190 |
+
uvicorn==0.34.3
|
| 191 |
+
uvloop==0.21.0
|
| 192 |
+
vllm==0.9.0.1
|
| 193 |
+
wandb==0.18.3
|
| 194 |
+
watchfiles==1.0.5
|
| 195 |
+
wcwidth==0.2.13
|
| 196 |
+
websockets==15.0.1
|
| 197 |
+
xformers==0.0.30
|
| 198 |
+
xgrammar==0.1.19
|
| 199 |
+
xxhash==3.5.0
|
| 200 |
+
yarl==1.20.0
|
| 201 |
+
zipp==3.22.0
|
| 202 |
+
tensorboardX==2.6.4
|
unified_inference.py
ADDED
|
@@ -0,0 +1,660 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unified Inference Script for Multi-Modal Image Generation and Editing
|
| 3 |
+
|
| 4 |
+
Supports three modes:
|
| 5 |
+
1. t2i (Text-to-Image): Generate images from text prompts (txt file)
|
| 6 |
+
2. geneval: Generate multiple samples per prompt for evaluation (jsonl file)
|
| 7 |
+
3. edit: Edit images based on prompts (parquet file)
|
| 8 |
+
|
| 9 |
+
Example usage:
|
| 10 |
+
# Text-to-Image
|
| 11 |
+
python unified_inference.py --mode t2i --model_path ./model --model_type flux \
|
| 12 |
+
--prompt_file prompts.txt --output_dir outputs/t2i
|
| 13 |
+
|
| 14 |
+
# GenEval
|
| 15 |
+
python unified_inference.py --mode geneval --model_path ./model --model_type flux \
|
| 16 |
+
--metadata_file evaluation_metadata.jsonl --output_dir outputs/geneval --n_samples 4
|
| 17 |
+
|
| 18 |
+
# Image Editing
|
| 19 |
+
python unified_inference.py --mode edit --model_path ./model --model_type kontext \
|
| 20 |
+
--data_file data.parquet --output_dir outputs/edit
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 24 |
+
import argparse
|
| 25 |
+
import json
|
| 26 |
+
import os
|
| 27 |
+
import traceback
|
| 28 |
+
from tqdm import tqdm
|
| 29 |
+
import torch
|
| 30 |
+
import numpy as np
|
| 31 |
+
from PIL import Image
|
| 32 |
+
from transformers import AutoProcessor
|
| 33 |
+
import random
|
| 34 |
+
import multiprocessing as mp
|
| 35 |
+
import pandas as pd
|
| 36 |
+
from io import BytesIO
|
| 37 |
+
import base64
|
| 38 |
+
from torchvision import transforms as TF
|
| 39 |
+
|
| 40 |
+
# Model imports
|
| 41 |
+
from unimodel.qwenflux.qwenflux_inference import QwenFluxForInferenceLM
|
| 42 |
+
from unimodel.qwenkontext.qwenkontext_inference import QwenKontextForInferenceLM
|
| 43 |
+
|
| 44 |
+
# Global configuration
|
| 45 |
+
NUM_DEVICE = 8
|
| 46 |
+
NUM_PROCESSES = 8
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# =============================================================================
|
| 50 |
+
# CoT Prompt Templates
|
| 51 |
+
# =============================================================================
|
| 52 |
+
COT_PROMPT_TEMPLATES = {
|
| 53 |
+
# General enhancement
|
| 54 |
+
"geneval": """Please provide an enhanced prompt for the following image generation prompt to make the image more realistic, detailed, with clear separation and precise alignment of all entities.
|
| 55 |
+
Original prompt: {original_prompt}. Directly provide the improved prompt in <answer> </answer> tags.""",
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
"ocr_clarity_v2": """Please enhance the following image generation prompt with specific focus on TEXT clarity and readability.
|
| 59 |
+
Original prompt: {original_prompt}. Directly provide the improved prompt in <answer> </answer> tags.""",
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
"quality_purev2": """Rewrite the following image generation prompt to improve its visual quality, detail level, realism, and artistic sophistication.
|
| 63 |
+
|
| 64 |
+
Original prompt: {original_prompt}
|
| 65 |
+
|
| 66 |
+
Directly provide the enhanced version directly in <answer></answer> tags.""",
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
"edit_general": """Please provide an enhanced prompt for the following image editing prompt.
|
| 70 |
+
Ensure the revised prompt is clear, specific, and includes detailed instructions to achieve the desired outcome while maintaining the original intent.
|
| 71 |
+
Original prompt: {original_prompt}. Directly provide the improved prompt in <answer> </answer> tags.""",
|
| 72 |
+
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# =============================================================================
|
| 77 |
+
# Utility Functions
|
| 78 |
+
# =============================================================================
|
| 79 |
+
def set_global_seed(seed):
|
| 80 |
+
"""Set global random seed for reproducibility."""
|
| 81 |
+
random.seed(seed)
|
| 82 |
+
np.random.seed(seed)
|
| 83 |
+
torch.manual_seed(seed)
|
| 84 |
+
torch.cuda.manual_seed(seed)
|
| 85 |
+
torch.cuda.manual_seed_all(seed)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# =============================================================================
|
| 91 |
+
# Model Loading
|
| 92 |
+
# =============================================================================
|
| 93 |
+
def load_model_pipeline(model_path, model_type, device):
|
| 94 |
+
"""Load model pipeline based on model type."""
|
| 95 |
+
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
|
| 96 |
+
subfolder = model_path.split('/')[-1]
|
| 97 |
+
model_path = model_path.replace(f"/{subfolder}", "")
|
| 98 |
+
if model_type == "flux":
|
| 99 |
+
model = QwenFluxForInferenceLM.from_pretrained(
|
| 100 |
+
model_path, torch_dtype=torch.bfloat16, subfolder=subfolder
|
| 101 |
+
)
|
| 102 |
+
elif model_type == "sana":
|
| 103 |
+
model = QwenSanaForInferenceLM.from_pretrained(
|
| 104 |
+
model_path, torch_dtype=torch.bfloat16, subfolder=subfolder
|
| 105 |
+
)
|
| 106 |
+
elif model_type == "sd3":
|
| 107 |
+
model = QwenSD3ForInferenceLM.from_pretrained(
|
| 108 |
+
model_path, torch_dtype=torch.bfloat16, subfolder=subfolder
|
| 109 |
+
)
|
| 110 |
+
elif model_type == "kontext":
|
| 111 |
+
model = QwenKontextForInferenceLM.from_pretrained(
|
| 112 |
+
model_path, torch_dtype=torch.bfloat16, subfolder=subfolder
|
| 113 |
+
)
|
| 114 |
+
else:
|
| 115 |
+
raise ValueError(f"Unknown model type: {model_type}")
|
| 116 |
+
|
| 117 |
+
processor.tokenizer.padding_side = "left" # for batch inference
|
| 118 |
+
model.to(device)
|
| 119 |
+
|
| 120 |
+
return model, processor
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
# =============================================================================
|
| 124 |
+
# Data Loading Functions
|
| 125 |
+
# =============================================================================
|
| 126 |
+
def load_prompts_from_txt(txt_file):
|
| 127 |
+
"""Load prompts from text file (one per line)."""
|
| 128 |
+
with open(txt_file, 'r', encoding='utf-8') as f:
|
| 129 |
+
prompts = [line.strip() for line in f if line.strip()]
|
| 130 |
+
return prompts
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def load_prompts_from_jsonl(metadata_file):
|
| 134 |
+
"""Load prompts and metadata from JSONL file."""
|
| 135 |
+
with open(metadata_file) as fp:
|
| 136 |
+
metadatas = [json.loads(line) for line in fp]
|
| 137 |
+
prompts = [metadata['prompt'].strip() for metadata in metadatas]
|
| 138 |
+
return prompts, metadatas
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def load_data_from_parquet(parquet_file):
|
| 142 |
+
"""Load images and prompts from parquet file."""
|
| 143 |
+
df = pd.read_parquet(parquet_file)
|
| 144 |
+
|
| 145 |
+
# Identify column names
|
| 146 |
+
image_col = None
|
| 147 |
+
prompt_col = None
|
| 148 |
+
id_col = None
|
| 149 |
+
|
| 150 |
+
for col in df.columns:
|
| 151 |
+
col_lower = col.lower()
|
| 152 |
+
if 'image' in col_lower and image_col is None:
|
| 153 |
+
image_col = col
|
| 154 |
+
elif any(kw in col_lower for kw in ['prompt', 'text', 'caption', 'instruction']) and prompt_col is None:
|
| 155 |
+
prompt_col = col
|
| 156 |
+
elif any(kw in col_lower for kw in ['id', 'index']) and id_col is None:
|
| 157 |
+
id_col = col
|
| 158 |
+
|
| 159 |
+
if image_col is None or prompt_col is None:
|
| 160 |
+
raise ValueError(
|
| 161 |
+
f"Cannot identify columns. Found: {df.columns.tolist()}\n"
|
| 162 |
+
f"Expected 'image' and 'prompt'/'text'/'caption'"
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
print(f"Using columns - Image: '{image_col}', Prompt: '{prompt_col}', ID: '{id_col}'")
|
| 166 |
+
|
| 167 |
+
data_list = []
|
| 168 |
+
for idx, row in tqdm(df.iterrows(), total=len(df), desc="Loading parquet"):
|
| 169 |
+
try:
|
| 170 |
+
image_data = row[image_col]["bytes"]
|
| 171 |
+
|
| 172 |
+
if isinstance(image_data, bytes):
|
| 173 |
+
image = Image.open(BytesIO(image_data)).convert('RGB')
|
| 174 |
+
elif isinstance(image_data, str):
|
| 175 |
+
if image_data.startswith('data:image') or image_data.startswith('/9j/') or image_data.startswith('iVBOR'):
|
| 176 |
+
if 'base64,' in image_data:
|
| 177 |
+
image_data = image_data.split('base64,')[1]
|
| 178 |
+
image_bytes = base64.b64decode(image_data)
|
| 179 |
+
image = Image.open(BytesIO(image_bytes)).convert('RGB')
|
| 180 |
+
else:
|
| 181 |
+
image = Image.open(image_data).convert('RGB')
|
| 182 |
+
else:
|
| 183 |
+
print(f"Warning: Skipping row {idx} - unsupported image format")
|
| 184 |
+
continue
|
| 185 |
+
|
| 186 |
+
prompt = str(row[prompt_col])
|
| 187 |
+
item_id = row[id_col] if id_col else idx
|
| 188 |
+
|
| 189 |
+
data_list.append({
|
| 190 |
+
'image': image,
|
| 191 |
+
'prompt': prompt,
|
| 192 |
+
'id': item_id,
|
| 193 |
+
'index': idx
|
| 194 |
+
})
|
| 195 |
+
except Exception as e:
|
| 196 |
+
print(f"Error loading row {idx}: {e}")
|
| 197 |
+
continue
|
| 198 |
+
|
| 199 |
+
print(f"Loaded {len(data_list)} samples from parquet")
|
| 200 |
+
return data_list
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
# =============================================================================
|
| 204 |
+
# Image Grid Utility
|
| 205 |
+
# =============================================================================
|
| 206 |
+
def create_image_grid(images, rows, cols):
|
| 207 |
+
"""Create a grid image from a list of images."""
|
| 208 |
+
assert len(images) == rows * cols
|
| 209 |
+
width, height = images[0].size
|
| 210 |
+
grid_width = width * cols
|
| 211 |
+
grid_height = height * rows
|
| 212 |
+
grid_image = Image.new('RGB', (grid_width, grid_height))
|
| 213 |
+
for i, image in enumerate(images):
|
| 214 |
+
x = (i % cols) * width
|
| 215 |
+
y = (i // cols) * height
|
| 216 |
+
grid_image.paste(image, (x, y))
|
| 217 |
+
return grid_image
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
# =============================================================================
|
| 221 |
+
# Generation Functions
|
| 222 |
+
# =============================================================================
|
| 223 |
+
def generate_t2i_batch(
|
| 224 |
+
prompts, start_idx, pipeline, processor, output_dir, batch_size,
|
| 225 |
+
guidance_scale, num_inference_steps, seed, use_cot, cot_template_name,
|
| 226 |
+
add_instruction, device_id
|
| 227 |
+
):
|
| 228 |
+
"""Generate images from text prompts (T2I mode)."""
|
| 229 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 230 |
+
|
| 231 |
+
for i in tqdm(range(0, len(prompts), batch_size), desc=f"GPU {device_id} T2I"):
|
| 232 |
+
batch_prompts = prompts[i:i + batch_size]
|
| 233 |
+
batch_start_idx = start_idx + i
|
| 234 |
+
original_prompts = batch_prompts.copy()
|
| 235 |
+
|
| 236 |
+
if add_instruction:
|
| 237 |
+
batch_prompts = [
|
| 238 |
+
f"Please generate image based on the following caption: {p}"
|
| 239 |
+
for p in batch_prompts
|
| 240 |
+
]
|
| 241 |
+
|
| 242 |
+
diffusion_kwargs = dict(
|
| 243 |
+
guidance_scale=guidance_scale,
|
| 244 |
+
num_inference_steps=num_inference_steps,
|
| 245 |
+
num_images_per_prompt=1,
|
| 246 |
+
generator=torch.Generator("cpu").manual_seed(seed)
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
try:
|
| 250 |
+
with torch.no_grad():
|
| 251 |
+
if use_cot:
|
| 252 |
+
llm_kwargs = dict(
|
| 253 |
+
max_new_tokens=256, temperature=0.7, top_p=0.9,
|
| 254 |
+
do_sample=False, num_return_sequences=1
|
| 255 |
+
)
|
| 256 |
+
cot_template = COT_PROMPT_TEMPLATES.get(cot_template_name)
|
| 257 |
+
outputs = pipeline.generate_image_cot(
|
| 258 |
+
texts=batch_prompts,
|
| 259 |
+
diffusion_kwargs=diffusion_kwargs,
|
| 260 |
+
processor=processor,
|
| 261 |
+
llm_kwargs=llm_kwargs,
|
| 262 |
+
cot_prompt_template=cot_template
|
| 263 |
+
)
|
| 264 |
+
images = outputs["images"]
|
| 265 |
+
thinking_prompts = outputs.get("improved_prompts", [])
|
| 266 |
+
else:
|
| 267 |
+
images = pipeline.generate_image(
|
| 268 |
+
texts=batch_prompts,
|
| 269 |
+
diffusion_kwargs=diffusion_kwargs
|
| 270 |
+
)
|
| 271 |
+
thinking_prompts = []
|
| 272 |
+
|
| 273 |
+
for j, img in enumerate(images):
|
| 274 |
+
img_idx = batch_start_idx + j
|
| 275 |
+
base_name = f"{img_idx:05d}"
|
| 276 |
+
|
| 277 |
+
img.save(os.path.join(output_dir, f"{base_name}.png"))
|
| 278 |
+
|
| 279 |
+
with open(os.path.join(output_dir, f"{base_name}_caption.txt"), 'w', encoding='utf-8') as f:
|
| 280 |
+
f.write(original_prompts[j])
|
| 281 |
+
|
| 282 |
+
if use_cot and j < len(thinking_prompts):
|
| 283 |
+
with open(os.path.join(output_dir, f"{base_name}_thinking.txt"), 'w', encoding='utf-8') as f:
|
| 284 |
+
f.write(thinking_prompts[j])
|
| 285 |
+
|
| 286 |
+
except Exception as e:
|
| 287 |
+
print(f"Error at batch {batch_start_idx}: {e}")
|
| 288 |
+
traceback.print_exc()
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def generate_geneval_batch(
|
| 292 |
+
prompts, metadatas, start_idx, pipeline, processor, output_dir, batch_size,
|
| 293 |
+
guidance_scale, num_inference_steps, seed, n_samples, use_cot,
|
| 294 |
+
cot_template_name, skip_grid, device_id
|
| 295 |
+
):
|
| 296 |
+
"""Generate multiple samples per prompt for evaluation (GenEval mode)."""
|
| 297 |
+
for prompt_idx, (prompt, metadata) in enumerate(zip(prompts, metadatas)):
|
| 298 |
+
global_idx = start_idx + prompt_idx
|
| 299 |
+
outpath = os.path.join(output_dir, f"{device_id}_{prompt_idx:0>5}")
|
| 300 |
+
os.makedirs(outpath, exist_ok=True)
|
| 301 |
+
sample_path = os.path.join(outpath, "samples")
|
| 302 |
+
os.makedirs(sample_path, exist_ok=True)
|
| 303 |
+
|
| 304 |
+
with open(os.path.join(outpath, "metadata.jsonl"), "w") as fp:
|
| 305 |
+
json.dump(metadata, fp)
|
| 306 |
+
|
| 307 |
+
sample_count = 0
|
| 308 |
+
all_samples = []
|
| 309 |
+
enhanced_prompts = []
|
| 310 |
+
total_batches = (n_samples + batch_size - 1) // batch_size
|
| 311 |
+
|
| 312 |
+
for batch_idx in tqdm(range(total_batches), desc=f"GPU {device_id} prompt {prompt_idx}"):
|
| 313 |
+
num_images = min(batch_size, n_samples - sample_count)
|
| 314 |
+
|
| 315 |
+
diffusion_kwargs = dict(
|
| 316 |
+
guidance_scale=guidance_scale,
|
| 317 |
+
num_inference_steps=num_inference_steps,
|
| 318 |
+
num_images_per_prompt=num_images,
|
| 319 |
+
generator=torch.Generator("cpu").manual_seed(seed)
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
try:
|
| 323 |
+
with torch.inference_mode():
|
| 324 |
+
if use_cot:
|
| 325 |
+
llm_kwargs = dict(
|
| 326 |
+
max_new_tokens=256, temperature=0.7, top_p=0.9,
|
| 327 |
+
do_sample=False, num_return_sequences=1
|
| 328 |
+
)
|
| 329 |
+
cot_template = COT_PROMPT_TEMPLATES.get(cot_template_name)
|
| 330 |
+
outputs = pipeline.generate_image_cot(
|
| 331 |
+
texts=prompt,
|
| 332 |
+
diffusion_kwargs=diffusion_kwargs,
|
| 333 |
+
processor=processor,
|
| 334 |
+
llm_kwargs=llm_kwargs,
|
| 335 |
+
cot_prompt_template=cot_template
|
| 336 |
+
)
|
| 337 |
+
images = outputs["images"]
|
| 338 |
+
enhanced_prompts.extend(outputs.get("improved_prompts", []))
|
| 339 |
+
else:
|
| 340 |
+
images = pipeline.generate_image(
|
| 341 |
+
texts=prompt,
|
| 342 |
+
diffusion_kwargs=diffusion_kwargs
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
for img in images:
|
| 346 |
+
img.save(os.path.join(sample_path, f"{sample_count:05}.png"))
|
| 347 |
+
sample_count += 1
|
| 348 |
+
if not skip_grid:
|
| 349 |
+
all_samples.append(img)
|
| 350 |
+
|
| 351 |
+
except Exception as e:
|
| 352 |
+
print(f"Error at prompt {prompt_idx}, batch {batch_idx}: {e}")
|
| 353 |
+
traceback.print_exc()
|
| 354 |
+
|
| 355 |
+
# Save enhanced prompts
|
| 356 |
+
with open(os.path.join(outpath, "thinking_prompts.txt"), "w") as fp:
|
| 357 |
+
for ep in enhanced_prompts:
|
| 358 |
+
fp.write(f"{ep}\n")
|
| 359 |
+
|
| 360 |
+
# Create grid
|
| 361 |
+
if not skip_grid and all_samples:
|
| 362 |
+
rows = int(np.sqrt(n_samples))
|
| 363 |
+
cols = (n_samples + rows - 1) // rows
|
| 364 |
+
if rows * cols >= len(all_samples):
|
| 365 |
+
grid_image = create_image_grid(all_samples[:rows * cols], rows, cols)
|
| 366 |
+
grid_image.save(os.path.join(outpath, "grid.jpg"))
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
def generate_edit_batch(
|
| 370 |
+
data_batch, start_idx, pipeline, processor, output_dir, batch_size,
|
| 371 |
+
guidance_scale, num_inference_steps, seed, use_cot, cot_template_name,
|
| 372 |
+
device_id, resolution
|
| 373 |
+
):
|
| 374 |
+
"""Edit images based on prompts (Edit mode)."""
|
| 375 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 376 |
+
|
| 377 |
+
transform = TF.Compose([
|
| 378 |
+
TF.Resize(resolution),
|
| 379 |
+
TF.CenterCrop(resolution)
|
| 380 |
+
])
|
| 381 |
+
|
| 382 |
+
for i in tqdm(range(0, len(data_batch), batch_size), desc=f"GPU {device_id} Edit"):
|
| 383 |
+
batch_data = data_batch[i:i + batch_size]
|
| 384 |
+
batch_start_idx = start_idx + i
|
| 385 |
+
|
| 386 |
+
batch_images = [transform(item['image']) for item in batch_data]
|
| 387 |
+
batch_prompts = [item['prompt'] for item in batch_data]
|
| 388 |
+
batch_ids = [item['id'] for item in batch_data]
|
| 389 |
+
|
| 390 |
+
diffusion_kwargs = dict(
|
| 391 |
+
guidance_scale=guidance_scale,
|
| 392 |
+
num_inference_steps=num_inference_steps,
|
| 393 |
+
num_images_per_prompt=1,
|
| 394 |
+
generator=torch.Generator("cpu").manual_seed(seed),
|
| 395 |
+
max_area=resolution ** 2
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
try:
|
| 399 |
+
with torch.no_grad():
|
| 400 |
+
if use_cot:
|
| 401 |
+
llm_kwargs = dict(
|
| 402 |
+
max_new_tokens=256, temperature=0.7, top_p=0.9,
|
| 403 |
+
do_sample=False, num_return_sequences=1
|
| 404 |
+
)
|
| 405 |
+
cot_template = COT_PROMPT_TEMPLATES.get(cot_template_name)
|
| 406 |
+
outputs = pipeline.generate_image_cot(
|
| 407 |
+
images=batch_images,
|
| 408 |
+
texts=batch_prompts,
|
| 409 |
+
diffusion_kwargs=diffusion_kwargs,
|
| 410 |
+
processor=processor,
|
| 411 |
+
llm_kwargs=llm_kwargs,
|
| 412 |
+
cot_prompt_template=cot_template
|
| 413 |
+
)
|
| 414 |
+
edited_images = outputs["images"]
|
| 415 |
+
improved_prompts = outputs.get("improved_prompts", [])
|
| 416 |
+
else:
|
| 417 |
+
edited_images = pipeline.generate_image(
|
| 418 |
+
images=batch_images,
|
| 419 |
+
texts=batch_prompts,
|
| 420 |
+
diffusion_kwargs=diffusion_kwargs
|
| 421 |
+
)
|
| 422 |
+
improved_prompts = []
|
| 423 |
+
|
| 424 |
+
for j, (edited_img, ref_img) in enumerate(zip(edited_images, batch_images)):
|
| 425 |
+
item_id = batch_ids[j]
|
| 426 |
+
base_name = f"{item_id}"
|
| 427 |
+
|
| 428 |
+
edited_img.save(os.path.join(output_dir, f"{base_name}_edited.png"))
|
| 429 |
+
ref_img.save(os.path.join(output_dir, f"{base_name}_reference.png"))
|
| 430 |
+
|
| 431 |
+
with open(os.path.join(output_dir, f"{base_name}_prompt.txt"), 'w', encoding='utf-8') as f:
|
| 432 |
+
f.write(batch_prompts[j])
|
| 433 |
+
|
| 434 |
+
if use_cot and j < len(improved_prompts):
|
| 435 |
+
with open(os.path.join(output_dir, f"{base_name}_improved_prompt.txt"), 'w', encoding='utf-8') as f:
|
| 436 |
+
f.write(improved_prompts[j])
|
| 437 |
+
|
| 438 |
+
except Exception as e:
|
| 439 |
+
print(f"Error at batch {batch_start_idx}: {e}")
|
| 440 |
+
traceback.print_exc()
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
# =============================================================================
|
| 444 |
+
# Worker Process
|
| 445 |
+
# =============================================================================
|
| 446 |
+
def worker_process(
|
| 447 |
+
device_id, mode, data, start_idx, pipeline, processor, output_dir,
|
| 448 |
+
batch_size, guidance_scale, num_inference_steps, seed, use_cot,
|
| 449 |
+
cot_template_name, add_instruction, n_samples, skip_grid, resolution, metadatas=None
|
| 450 |
+
):
|
| 451 |
+
"""Single GPU worker process."""
|
| 452 |
+
torch.cuda.set_device(f"cuda:{device_id % NUM_DEVICE}")
|
| 453 |
+
|
| 454 |
+
print(f"GPU {device_id}: Processing {len(data)} items (indices {start_idx} to {start_idx + len(data) - 1})")
|
| 455 |
+
|
| 456 |
+
if mode == "t2i":
|
| 457 |
+
generate_t2i_batch(
|
| 458 |
+
prompts=data, start_idx=start_idx, pipeline=pipeline,
|
| 459 |
+
processor=processor, output_dir=output_dir, batch_size=batch_size,
|
| 460 |
+
guidance_scale=guidance_scale, num_inference_steps=num_inference_steps,
|
| 461 |
+
seed=seed, use_cot=use_cot, cot_template_name=cot_template_name,
|
| 462 |
+
add_instruction=add_instruction, device_id=device_id
|
| 463 |
+
)
|
| 464 |
+
elif mode == "geneval":
|
| 465 |
+
generate_geneval_batch(
|
| 466 |
+
prompts=data, metadatas=metadatas, start_idx=start_idx,
|
| 467 |
+
pipeline=pipeline, processor=processor, output_dir=output_dir,
|
| 468 |
+
batch_size=batch_size, guidance_scale=guidance_scale,
|
| 469 |
+
num_inference_steps=num_inference_steps, seed=seed,
|
| 470 |
+
n_samples=n_samples, use_cot=use_cot, cot_template_name=cot_template_name,
|
| 471 |
+
skip_grid=skip_grid, device_id=device_id
|
| 472 |
+
)
|
| 473 |
+
elif mode == "edit":
|
| 474 |
+
generate_edit_batch(
|
| 475 |
+
data_batch=data, start_idx=start_idx, pipeline=pipeline,
|
| 476 |
+
processor=processor, output_dir=output_dir, batch_size=batch_size,
|
| 477 |
+
guidance_scale=guidance_scale, num_inference_steps=num_inference_steps,
|
| 478 |
+
seed=seed, use_cot=use_cot, cot_template_name=cot_template_name,
|
| 479 |
+
device_id=device_id, resolution=resolution
|
| 480 |
+
)
|
| 481 |
+
|
| 482 |
+
print(f"GPU {device_id}: Completed!")
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
# =============================================================================
|
| 486 |
+
# Argument Parser
|
| 487 |
+
# =============================================================================
|
| 488 |
+
def parse_args():
|
| 489 |
+
parser = argparse.ArgumentParser(
|
| 490 |
+
description="Unified Inference Script for Image Generation and Editing"
|
| 491 |
+
)
|
| 492 |
+
|
| 493 |
+
# Mode selection
|
| 494 |
+
parser.add_argument(
|
| 495 |
+
"--mode", type=str, required=True,
|
| 496 |
+
choices=["t2i", "geneval", "edit"],
|
| 497 |
+
help="Inference mode: t2i (text-to-image), geneval (evaluation), edit (image editing)"
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
# Input/Output
|
| 501 |
+
parser.add_argument("--prompt_file", type=str, help="Text file with prompts (for t2i mode)")
|
| 502 |
+
parser.add_argument("--metadata_file", type=str, help="JSONL metadata file (for geneval mode)")
|
| 503 |
+
parser.add_argument("--data_file", type=str, help="Parquet file with images and prompts (for edit mode)")
|
| 504 |
+
parser.add_argument("--output_dir", type=str, default="outputs", help="Output directory")
|
| 505 |
+
|
| 506 |
+
# Model configuration
|
| 507 |
+
parser.add_argument("--model_path", type=str, required=True, help="Model path")
|
| 508 |
+
parser.add_argument(
|
| 509 |
+
"--model_type", type=str, default="flux",
|
| 510 |
+
choices=["flux", "sana", "sd3", "kontext"],
|
| 511 |
+
help="Model type"
|
| 512 |
+
)
|
| 513 |
+
|
| 514 |
+
# Generation parameters
|
| 515 |
+
parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
|
| 516 |
+
parser.add_argument("--resolution", type=int, default=1024, help="Image resolution")
|
| 517 |
+
parser.add_argument("--guidance_scale", type=float, default=3.5, help="CFG guidance scale")
|
| 518 |
+
parser.add_argument("--num_inference_steps", type=int, default=40, help="Inference steps")
|
| 519 |
+
parser.add_argument("--seed", type=int, default=42, help="Random seed")
|
| 520 |
+
|
| 521 |
+
# CoT options
|
| 522 |
+
parser.add_argument("--use_cot", action="store_true", help="Use Chain of Thought")
|
| 523 |
+
parser.add_argument(
|
| 524 |
+
"--cot_template", type=str, default="general",
|
| 525 |
+
choices=list(COT_PROMPT_TEMPLATES.keys()),
|
| 526 |
+
help="CoT prompt template"
|
| 527 |
+
)
|
| 528 |
+
parser.add_argument("--add_instruction", action="store_true", help="Add instruction prefix (t2i mode)")
|
| 529 |
+
|
| 530 |
+
# GenEval specific
|
| 531 |
+
parser.add_argument("--n_samples", type=int, default=4, help="Samples per prompt (geneval mode)")
|
| 532 |
+
parser.add_argument("--skip_grid", action="store_true", help="Skip grid image (geneval mode)")
|
| 533 |
+
|
| 534 |
+
# Hardware
|
| 535 |
+
parser.add_argument("--num_gpus", type=int, default=None, help="Number of GPUs to use")
|
| 536 |
+
parser.add_argument("--max_samples", type=int, default=None, help="Max samples to process")
|
| 537 |
+
|
| 538 |
+
return parser.parse_args()
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
# =============================================================================
|
| 542 |
+
# Main Function
|
| 543 |
+
# =============================================================================
|
| 544 |
+
def main():
|
| 545 |
+
mp.set_start_method('spawn', force=True)
|
| 546 |
+
args = parse_args()
|
| 547 |
+
|
| 548 |
+
global NUM_PROCESSES
|
| 549 |
+
if args.num_gpus is not None:
|
| 550 |
+
NUM_PROCESSES = min(args.num_gpus, NUM_DEVICE)
|
| 551 |
+
|
| 552 |
+
# Validate mode-specific arguments
|
| 553 |
+
if args.mode == "t2i" and not args.prompt_file:
|
| 554 |
+
raise ValueError("--prompt_file is required for t2i mode")
|
| 555 |
+
if args.mode == "geneval" and not args.metadata_file:
|
| 556 |
+
raise ValueError("--metadata_file is required for geneval mode")
|
| 557 |
+
if args.mode == "edit" and not args.data_file:
|
| 558 |
+
raise ValueError("--data_file is required for edit mode")
|
| 559 |
+
if args.mode == "edit" and args.model_type != "kontext":
|
| 560 |
+
print(f"Warning: edit mode typically uses kontext model, but got {args.model_type}")
|
| 561 |
+
|
| 562 |
+
# Load data based on mode
|
| 563 |
+
print(f"Mode: {args.mode}")
|
| 564 |
+
metadatas = None
|
| 565 |
+
|
| 566 |
+
if args.mode == "t2i":
|
| 567 |
+
print(f"Loading prompts from {args.prompt_file}...")
|
| 568 |
+
data = load_prompts_from_txt(args.prompt_file)
|
| 569 |
+
elif args.mode == "geneval":
|
| 570 |
+
print(f"Loading metadata from {args.metadata_file}...")
|
| 571 |
+
data, metadatas = load_prompts_from_jsonl(args.metadata_file)
|
| 572 |
+
elif args.mode == "edit":
|
| 573 |
+
print(f"Loading data from {args.data_file}...")
|
| 574 |
+
data = load_data_from_parquet(args.data_file)
|
| 575 |
+
|
| 576 |
+
# Apply max_samples limit
|
| 577 |
+
if args.max_samples is not None:
|
| 578 |
+
if args.mode == "geneval":
|
| 579 |
+
data = data[:args.max_samples]
|
| 580 |
+
metadatas = metadatas[:args.max_samples]
|
| 581 |
+
else:
|
| 582 |
+
data = data[:args.max_samples]
|
| 583 |
+
print(f"Limited to {len(data)} samples")
|
| 584 |
+
|
| 585 |
+
print(f"Total samples: {len(data)}")
|
| 586 |
+
|
| 587 |
+
# Create output directory
|
| 588 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 589 |
+
|
| 590 |
+
# Save configuration
|
| 591 |
+
config_path = os.path.join(args.output_dir, "config.json")
|
| 592 |
+
config_dict = vars(args).copy()
|
| 593 |
+
with open(config_path, 'w') as f:
|
| 594 |
+
json.dump(config_dict, f, indent=2)
|
| 595 |
+
print(f"Config saved to {config_path}")
|
| 596 |
+
|
| 597 |
+
# Load models
|
| 598 |
+
print("Loading models...")
|
| 599 |
+
pipelines = []
|
| 600 |
+
processors = []
|
| 601 |
+
|
| 602 |
+
for i in range(NUM_DEVICE):
|
| 603 |
+
print(f"Loading model {i+1}/{NUM_DEVICE} on cuda:{i % NUM_DEVICE}...")
|
| 604 |
+
pipeline, processor = load_model_pipeline(
|
| 605 |
+
args.model_path, args.model_type, f"cuda:{i % NUM_DEVICE}"
|
| 606 |
+
)
|
| 607 |
+
pipelines.append(pipeline)
|
| 608 |
+
processors.append(processor)
|
| 609 |
+
|
| 610 |
+
print("All models loaded!")
|
| 611 |
+
|
| 612 |
+
# Distribute data across GPUs
|
| 613 |
+
samples_per_gpu = len(data) // NUM_PROCESSES
|
| 614 |
+
|
| 615 |
+
with ThreadPoolExecutor(max_workers=NUM_PROCESSES) as executor:
|
| 616 |
+
futures = []
|
| 617 |
+
|
| 618 |
+
for device_id in range(NUM_PROCESSES):
|
| 619 |
+
start_idx = device_id * samples_per_gpu
|
| 620 |
+
end_idx = len(data) if device_id == NUM_PROCESSES - 1 else start_idx + samples_per_gpu
|
| 621 |
+
|
| 622 |
+
gpu_data = data[start_idx:end_idx]
|
| 623 |
+
gpu_metadatas = metadatas[start_idx:end_idx] if metadatas else None
|
| 624 |
+
|
| 625 |
+
future = executor.submit(
|
| 626 |
+
worker_process,
|
| 627 |
+
device_id=device_id,
|
| 628 |
+
mode=args.mode,
|
| 629 |
+
data=gpu_data,
|
| 630 |
+
start_idx=start_idx,
|
| 631 |
+
pipeline=pipelines[device_id % NUM_DEVICE],
|
| 632 |
+
processor=processors[device_id % NUM_DEVICE],
|
| 633 |
+
output_dir=args.output_dir,
|
| 634 |
+
batch_size=args.batch_size,
|
| 635 |
+
guidance_scale=args.guidance_scale,
|
| 636 |
+
num_inference_steps=args.num_inference_steps,
|
| 637 |
+
seed=args.seed,
|
| 638 |
+
use_cot=args.use_cot,
|
| 639 |
+
cot_template_name=args.cot_template,
|
| 640 |
+
add_instruction=args.add_instruction,
|
| 641 |
+
n_samples=args.n_samples,
|
| 642 |
+
skip_grid=args.skip_grid,
|
| 643 |
+
resolution=args.resolution,
|
| 644 |
+
metadatas=gpu_metadatas
|
| 645 |
+
)
|
| 646 |
+
futures.append(future)
|
| 647 |
+
|
| 648 |
+
for future in as_completed(futures):
|
| 649 |
+
try:
|
| 650 |
+
future.result()
|
| 651 |
+
except Exception as e:
|
| 652 |
+
print(f"Worker failed: {e}")
|
| 653 |
+
traceback.print_exc()
|
| 654 |
+
|
| 655 |
+
print(f"\n✓ Done! Results saved to {args.output_dir}")
|
| 656 |
+
print(f" Total processed: {len(data)}")
|
| 657 |
+
|
| 658 |
+
|
| 659 |
+
if __name__ == "__main__":
|
| 660 |
+
main()
|
unimodel/qwenflux/fluxpipeline.py
ADDED
|
@@ -0,0 +1,1543 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
|
| 2 |
+
# Copyright 2025 Fu-Yun Wang
|
| 3 |
+
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import inspect
|
| 18 |
+
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
from transformers import (
|
| 23 |
+
CLIPImageProcessor,
|
| 24 |
+
CLIPTextModel,
|
| 25 |
+
CLIPTokenizer,
|
| 26 |
+
CLIPVisionModelWithProjection,
|
| 27 |
+
T5EncoderModel,
|
| 28 |
+
T5TokenizerFast,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
| 32 |
+
from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
|
| 33 |
+
from diffusers.models import AutoencoderKL, FluxTransformer2DModel
|
| 34 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 35 |
+
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteSchedulerOutput
|
| 36 |
+
from diffusers.utils import (
|
| 37 |
+
USE_PEFT_BACKEND,
|
| 38 |
+
is_torch_xla_available,
|
| 39 |
+
logging,
|
| 40 |
+
replace_example_docstring,
|
| 41 |
+
scale_lora_layers,
|
| 42 |
+
unscale_lora_layers,
|
| 43 |
+
)
|
| 44 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 45 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 46 |
+
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
|
| 47 |
+
import math
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
if is_torch_xla_available():
|
| 51 |
+
import torch_xla.core.xla_model as xm
|
| 52 |
+
|
| 53 |
+
XLA_AVAILABLE = True
|
| 54 |
+
else:
|
| 55 |
+
XLA_AVAILABLE = False
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 59 |
+
|
| 60 |
+
EXAMPLE_DOC_STRING = """
|
| 61 |
+
Examples:
|
| 62 |
+
```py
|
| 63 |
+
>>> import torch
|
| 64 |
+
>>> from diffusers import FluxPipeline
|
| 65 |
+
|
| 66 |
+
>>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
|
| 67 |
+
>>> pipe.to("cuda")
|
| 68 |
+
>>> prompt = "A cat holding a sign that says hello world"
|
| 69 |
+
>>> # Depending on the variant being used, the pipeline call will slightly vary.
|
| 70 |
+
>>> # Refer to the pipeline documentation for more details.
|
| 71 |
+
>>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
|
| 72 |
+
>>> image.save("flux.png")
|
| 73 |
+
```
|
| 74 |
+
"""
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def calculate_shift(
|
| 78 |
+
image_seq_len,
|
| 79 |
+
base_seq_len: int = 256,
|
| 80 |
+
max_seq_len: int = 4096,
|
| 81 |
+
base_shift: float = 0.5,
|
| 82 |
+
max_shift: float = 1.15,
|
| 83 |
+
):
|
| 84 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
| 85 |
+
b = base_shift - m * base_seq_len
|
| 86 |
+
mu = image_seq_len * m + b
|
| 87 |
+
return mu
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 91 |
+
def retrieve_timesteps(
|
| 92 |
+
scheduler,
|
| 93 |
+
num_inference_steps: Optional[int] = None,
|
| 94 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 95 |
+
timesteps: Optional[List[int]] = None,
|
| 96 |
+
sigmas: Optional[List[float]] = None,
|
| 97 |
+
**kwargs,
|
| 98 |
+
):
|
| 99 |
+
r"""
|
| 100 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 101 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
scheduler (`SchedulerMixin`):
|
| 105 |
+
The scheduler to get timesteps from.
|
| 106 |
+
num_inference_steps (`int`):
|
| 107 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 108 |
+
must be `None`.
|
| 109 |
+
device (`str` or `torch.device`, *optional*):
|
| 110 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 111 |
+
timesteps (`List[int]`, *optional*):
|
| 112 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 113 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 114 |
+
sigmas (`List[float]`, *optional*):
|
| 115 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 116 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 117 |
+
|
| 118 |
+
Returns:
|
| 119 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 120 |
+
second element is the number of inference steps.
|
| 121 |
+
"""
|
| 122 |
+
if timesteps is not None and sigmas is not None:
|
| 123 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 124 |
+
if timesteps is not None:
|
| 125 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 126 |
+
if not accepts_timesteps:
|
| 127 |
+
raise ValueError(
|
| 128 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 129 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 130 |
+
)
|
| 131 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 132 |
+
timesteps = scheduler.timesteps
|
| 133 |
+
num_inference_steps = len(timesteps)
|
| 134 |
+
elif sigmas is not None:
|
| 135 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 136 |
+
if not accept_sigmas:
|
| 137 |
+
raise ValueError(
|
| 138 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 139 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 140 |
+
)
|
| 141 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 142 |
+
timesteps = scheduler.timesteps
|
| 143 |
+
num_inference_steps = len(timesteps)
|
| 144 |
+
else:
|
| 145 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 146 |
+
timesteps = scheduler.timesteps
|
| 147 |
+
return timesteps, num_inference_steps
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class FluxPipeline(
|
| 151 |
+
DiffusionPipeline,
|
| 152 |
+
FluxLoraLoaderMixin,
|
| 153 |
+
FromSingleFileMixin,
|
| 154 |
+
TextualInversionLoaderMixin,
|
| 155 |
+
FluxIPAdapterMixin,
|
| 156 |
+
):
|
| 157 |
+
r"""
|
| 158 |
+
The Flux pipeline for text-to-image generation.
|
| 159 |
+
|
| 160 |
+
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
|
| 161 |
+
|
| 162 |
+
Args:
|
| 163 |
+
transformer ([`FluxTransformer2DModel`]):
|
| 164 |
+
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
| 165 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
| 166 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
| 167 |
+
vae ([`AutoencoderKL`]):
|
| 168 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
| 169 |
+
text_encoder ([`CLIPTextModel`]):
|
| 170 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
| 171 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
| 172 |
+
text_encoder_2 ([`T5EncoderModel`]):
|
| 173 |
+
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
| 174 |
+
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
| 175 |
+
tokenizer (`CLIPTokenizer`):
|
| 176 |
+
Tokenizer of class
|
| 177 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
| 178 |
+
tokenizer_2 (`T5TokenizerFast`):
|
| 179 |
+
Second Tokenizer of class
|
| 180 |
+
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
| 181 |
+
"""
|
| 182 |
+
|
| 183 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
|
| 184 |
+
_optional_components = ["image_encoder", "feature_extractor"]
|
| 185 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
| 186 |
+
|
| 187 |
+
def __init__(
|
| 188 |
+
self,
|
| 189 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 190 |
+
vae: AutoencoderKL,
|
| 191 |
+
text_encoder: CLIPTextModel,
|
| 192 |
+
tokenizer: CLIPTokenizer,
|
| 193 |
+
text_encoder_2: T5EncoderModel,
|
| 194 |
+
tokenizer_2: T5TokenizerFast,
|
| 195 |
+
transformer: FluxTransformer2DModel,
|
| 196 |
+
image_encoder: CLIPVisionModelWithProjection = None,
|
| 197 |
+
feature_extractor: CLIPImageProcessor = None,
|
| 198 |
+
):
|
| 199 |
+
super().__init__()
|
| 200 |
+
|
| 201 |
+
self.register_modules(
|
| 202 |
+
vae=vae,
|
| 203 |
+
text_encoder=text_encoder,
|
| 204 |
+
text_encoder_2=text_encoder_2,
|
| 205 |
+
tokenizer=tokenizer,
|
| 206 |
+
tokenizer_2=tokenizer_2,
|
| 207 |
+
transformer=transformer,
|
| 208 |
+
scheduler=scheduler,
|
| 209 |
+
image_encoder=image_encoder,
|
| 210 |
+
feature_extractor=feature_extractor,
|
| 211 |
+
)
|
| 212 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
| 213 |
+
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
|
| 214 |
+
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
|
| 215 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
| 216 |
+
self.tokenizer_max_length = (
|
| 217 |
+
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
| 218 |
+
)
|
| 219 |
+
self.default_sample_size = 128
|
| 220 |
+
|
| 221 |
+
def _get_t5_prompt_embeds(
|
| 222 |
+
self,
|
| 223 |
+
prompt: Union[str, List[str]] = None,
|
| 224 |
+
num_images_per_prompt: int = 1,
|
| 225 |
+
max_sequence_length: int = 512,
|
| 226 |
+
device: Optional[torch.device] = None,
|
| 227 |
+
dtype: Optional[torch.dtype] = None,
|
| 228 |
+
):
|
| 229 |
+
device = device or self._execution_device
|
| 230 |
+
dtype = dtype or self.text_encoder.dtype
|
| 231 |
+
|
| 232 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 233 |
+
batch_size = len(prompt)
|
| 234 |
+
|
| 235 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
| 236 |
+
prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
|
| 237 |
+
|
| 238 |
+
text_inputs = self.tokenizer_2(
|
| 239 |
+
prompt,
|
| 240 |
+
padding="max_length",
|
| 241 |
+
max_length=max_sequence_length,
|
| 242 |
+
truncation=True,
|
| 243 |
+
return_length=False,
|
| 244 |
+
return_overflowing_tokens=False,
|
| 245 |
+
return_tensors="pt",
|
| 246 |
+
)
|
| 247 |
+
text_input_ids = text_inputs.input_ids
|
| 248 |
+
untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
|
| 249 |
+
|
| 250 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 251 |
+
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
| 252 |
+
logger.warning(
|
| 253 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
| 254 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
|
| 258 |
+
|
| 259 |
+
dtype = self.text_encoder_2.dtype
|
| 260 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 261 |
+
|
| 262 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 263 |
+
|
| 264 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
| 265 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 266 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 267 |
+
|
| 268 |
+
return prompt_embeds
|
| 269 |
+
|
| 270 |
+
def _get_clip_prompt_embeds(
|
| 271 |
+
self,
|
| 272 |
+
prompt: Union[str, List[str]],
|
| 273 |
+
num_images_per_prompt: int = 1,
|
| 274 |
+
device: Optional[torch.device] = None,
|
| 275 |
+
):
|
| 276 |
+
device = device or self._execution_device
|
| 277 |
+
|
| 278 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 279 |
+
batch_size = len(prompt)
|
| 280 |
+
|
| 281 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
| 282 |
+
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
| 283 |
+
|
| 284 |
+
text_inputs = self.tokenizer(
|
| 285 |
+
prompt,
|
| 286 |
+
padding="max_length",
|
| 287 |
+
max_length=self.tokenizer_max_length,
|
| 288 |
+
truncation=True,
|
| 289 |
+
return_overflowing_tokens=False,
|
| 290 |
+
return_length=False,
|
| 291 |
+
return_tensors="pt",
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
text_input_ids = text_inputs.input_ids
|
| 295 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 296 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 297 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
| 298 |
+
logger.warning(
|
| 299 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
| 300 |
+
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
| 301 |
+
)
|
| 302 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
|
| 303 |
+
|
| 304 |
+
# Use pooled output of CLIPTextModel
|
| 305 |
+
prompt_embeds = prompt_embeds.pooler_output
|
| 306 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
| 307 |
+
|
| 308 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 309 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
|
| 310 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
| 311 |
+
|
| 312 |
+
return prompt_embeds
|
| 313 |
+
|
| 314 |
+
def encode_prompt(
|
| 315 |
+
self,
|
| 316 |
+
prompt: Union[str, List[str]],
|
| 317 |
+
prompt_2: Union[str, List[str]],
|
| 318 |
+
device: Optional[torch.device] = None,
|
| 319 |
+
num_images_per_prompt: int = 1,
|
| 320 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 321 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 322 |
+
max_sequence_length: int = 512,
|
| 323 |
+
lora_scale: Optional[float] = None,
|
| 324 |
+
):
|
| 325 |
+
r"""
|
| 326 |
+
|
| 327 |
+
Args:
|
| 328 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 329 |
+
prompt to be encoded
|
| 330 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
| 331 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
| 332 |
+
used in all text-encoders
|
| 333 |
+
device: (`torch.device`):
|
| 334 |
+
torch device
|
| 335 |
+
num_images_per_prompt (`int`):
|
| 336 |
+
number of images that should be generated per prompt
|
| 337 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 338 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 339 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 340 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 341 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
| 342 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
| 343 |
+
lora_scale (`float`, *optional*):
|
| 344 |
+
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
| 345 |
+
"""
|
| 346 |
+
device = device or self._execution_device
|
| 347 |
+
|
| 348 |
+
# set lora scale so that monkey patched LoRA
|
| 349 |
+
# function of text encoder can correctly access it
|
| 350 |
+
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
|
| 351 |
+
self._lora_scale = lora_scale
|
| 352 |
+
|
| 353 |
+
# dynamically adjust the LoRA scale
|
| 354 |
+
if self.text_encoder is not None and USE_PEFT_BACKEND:
|
| 355 |
+
scale_lora_layers(self.text_encoder, lora_scale)
|
| 356 |
+
if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
|
| 357 |
+
scale_lora_layers(self.text_encoder_2, lora_scale)
|
| 358 |
+
|
| 359 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 360 |
+
|
| 361 |
+
if prompt_embeds is None:
|
| 362 |
+
prompt_2 = prompt_2 or prompt
|
| 363 |
+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
| 364 |
+
|
| 365 |
+
# We only use the pooled prompt output from the CLIPTextModel
|
| 366 |
+
pooled_prompt_embeds = self._get_clip_prompt_embeds(
|
| 367 |
+
prompt=prompt,
|
| 368 |
+
device=device,
|
| 369 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 370 |
+
)
|
| 371 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
| 372 |
+
prompt=prompt_2,
|
| 373 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 374 |
+
max_sequence_length=max_sequence_length,
|
| 375 |
+
device=device,
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
if self.text_encoder is not None:
|
| 379 |
+
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
| 380 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
| 381 |
+
unscale_lora_layers(self.text_encoder, lora_scale)
|
| 382 |
+
|
| 383 |
+
if self.text_encoder_2 is not None:
|
| 384 |
+
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
| 385 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
| 386 |
+
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
| 387 |
+
|
| 388 |
+
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
|
| 389 |
+
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
| 390 |
+
|
| 391 |
+
return prompt_embeds, pooled_prompt_embeds, text_ids
|
| 392 |
+
|
| 393 |
+
def encode_image(self, image, device, num_images_per_prompt):
|
| 394 |
+
dtype = next(self.image_encoder.parameters()).dtype
|
| 395 |
+
|
| 396 |
+
if not isinstance(image, torch.Tensor):
|
| 397 |
+
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
| 398 |
+
|
| 399 |
+
image = image.to(device=device, dtype=dtype)
|
| 400 |
+
image_embeds = self.image_encoder(image).image_embeds
|
| 401 |
+
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
| 402 |
+
return image_embeds
|
| 403 |
+
|
| 404 |
+
def prepare_ip_adapter_image_embeds(
|
| 405 |
+
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
|
| 406 |
+
):
|
| 407 |
+
image_embeds = []
|
| 408 |
+
if ip_adapter_image_embeds is None:
|
| 409 |
+
if not isinstance(ip_adapter_image, list):
|
| 410 |
+
ip_adapter_image = [ip_adapter_image]
|
| 411 |
+
|
| 412 |
+
if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
|
| 413 |
+
raise ValueError(
|
| 414 |
+
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
for single_ip_adapter_image in ip_adapter_image:
|
| 418 |
+
single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
|
| 419 |
+
image_embeds.append(single_image_embeds[None, :])
|
| 420 |
+
else:
|
| 421 |
+
if not isinstance(ip_adapter_image_embeds, list):
|
| 422 |
+
ip_adapter_image_embeds = [ip_adapter_image_embeds]
|
| 423 |
+
|
| 424 |
+
if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
|
| 425 |
+
raise ValueError(
|
| 426 |
+
f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
for single_image_embeds in ip_adapter_image_embeds:
|
| 430 |
+
image_embeds.append(single_image_embeds)
|
| 431 |
+
|
| 432 |
+
ip_adapter_image_embeds = []
|
| 433 |
+
for single_image_embeds in image_embeds:
|
| 434 |
+
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
|
| 435 |
+
single_image_embeds = single_image_embeds.to(device=device)
|
| 436 |
+
ip_adapter_image_embeds.append(single_image_embeds)
|
| 437 |
+
|
| 438 |
+
return ip_adapter_image_embeds
|
| 439 |
+
|
| 440 |
+
def check_inputs(
|
| 441 |
+
self,
|
| 442 |
+
prompt,
|
| 443 |
+
prompt_2,
|
| 444 |
+
height,
|
| 445 |
+
width,
|
| 446 |
+
negative_prompt=None,
|
| 447 |
+
negative_prompt_2=None,
|
| 448 |
+
prompt_embeds=None,
|
| 449 |
+
negative_prompt_embeds=None,
|
| 450 |
+
pooled_prompt_embeds=None,
|
| 451 |
+
negative_pooled_prompt_embeds=None,
|
| 452 |
+
callback_on_step_end_tensor_inputs=None,
|
| 453 |
+
max_sequence_length=None,
|
| 454 |
+
):
|
| 455 |
+
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
|
| 456 |
+
logger.warning(
|
| 457 |
+
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
|
| 458 |
+
)
|
| 459 |
+
|
| 460 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 461 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 462 |
+
):
|
| 463 |
+
raise ValueError(
|
| 464 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
if prompt is not None and prompt_embeds is not None:
|
| 468 |
+
raise ValueError(
|
| 469 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 470 |
+
" only forward one of the two."
|
| 471 |
+
)
|
| 472 |
+
elif prompt_2 is not None and prompt_embeds is not None:
|
| 473 |
+
raise ValueError(
|
| 474 |
+
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 475 |
+
" only forward one of the two."
|
| 476 |
+
)
|
| 477 |
+
elif prompt is None and prompt_embeds is None:
|
| 478 |
+
raise ValueError(
|
| 479 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 480 |
+
)
|
| 481 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 482 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 483 |
+
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
| 484 |
+
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
| 485 |
+
|
| 486 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 487 |
+
raise ValueError(
|
| 488 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 489 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 490 |
+
)
|
| 491 |
+
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
|
| 492 |
+
raise ValueError(
|
| 493 |
+
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
|
| 494 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 495 |
+
)
|
| 496 |
+
|
| 497 |
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
| 498 |
+
raise ValueError(
|
| 499 |
+
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
| 500 |
+
)
|
| 501 |
+
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
|
| 502 |
+
raise ValueError(
|
| 503 |
+
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
| 504 |
+
)
|
| 505 |
+
|
| 506 |
+
if max_sequence_length is not None and max_sequence_length > 512:
|
| 507 |
+
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
| 508 |
+
|
| 509 |
+
@staticmethod
|
| 510 |
+
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
| 511 |
+
latent_image_ids = torch.zeros(height, width, 3)
|
| 512 |
+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
| 513 |
+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
|
| 514 |
+
|
| 515 |
+
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
| 516 |
+
|
| 517 |
+
latent_image_ids = latent_image_ids.reshape(
|
| 518 |
+
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
return latent_image_ids.to(device=device, dtype=dtype)
|
| 522 |
+
|
| 523 |
+
@staticmethod
|
| 524 |
+
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
| 525 |
+
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
| 526 |
+
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
| 527 |
+
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
| 528 |
+
|
| 529 |
+
return latents
|
| 530 |
+
|
| 531 |
+
@staticmethod
|
| 532 |
+
def _unpack_latents(latents, height, width, vae_scale_factor):
|
| 533 |
+
batch_size, num_patches, channels = latents.shape
|
| 534 |
+
|
| 535 |
+
# VAE applies 8x compression on images but we must also account for packing which requires
|
| 536 |
+
# latent height and width to be divisible by 2.
|
| 537 |
+
height = 2 * (int(height) // (vae_scale_factor * 2))
|
| 538 |
+
width = 2 * (int(width) // (vae_scale_factor * 2))
|
| 539 |
+
|
| 540 |
+
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
|
| 541 |
+
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
| 542 |
+
|
| 543 |
+
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
|
| 544 |
+
|
| 545 |
+
return latents
|
| 546 |
+
|
| 547 |
+
def enable_vae_slicing(self):
|
| 548 |
+
r"""
|
| 549 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
| 550 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
| 551 |
+
"""
|
| 552 |
+
self.vae.enable_slicing()
|
| 553 |
+
|
| 554 |
+
def disable_vae_slicing(self):
|
| 555 |
+
r"""
|
| 556 |
+
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
| 557 |
+
computing decoding in one step.
|
| 558 |
+
"""
|
| 559 |
+
self.vae.disable_slicing()
|
| 560 |
+
|
| 561 |
+
def enable_vae_tiling(self):
|
| 562 |
+
r"""
|
| 563 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
| 564 |
+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
| 565 |
+
processing larger images.
|
| 566 |
+
"""
|
| 567 |
+
self.vae.enable_tiling()
|
| 568 |
+
|
| 569 |
+
def disable_vae_tiling(self):
|
| 570 |
+
r"""
|
| 571 |
+
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
| 572 |
+
computing decoding in one step.
|
| 573 |
+
"""
|
| 574 |
+
self.vae.disable_tiling()
|
| 575 |
+
|
| 576 |
+
def prepare_latents(
|
| 577 |
+
self,
|
| 578 |
+
batch_size,
|
| 579 |
+
num_channels_latents,
|
| 580 |
+
height,
|
| 581 |
+
width,
|
| 582 |
+
dtype,
|
| 583 |
+
device,
|
| 584 |
+
generator,
|
| 585 |
+
latents=None,
|
| 586 |
+
):
|
| 587 |
+
# VAE applies 8x compression on images but we must also account for packing which requires
|
| 588 |
+
# latent height and width to be divisible by 2.
|
| 589 |
+
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
| 590 |
+
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
| 591 |
+
|
| 592 |
+
shape = (batch_size, num_channels_latents, height, width)
|
| 593 |
+
|
| 594 |
+
if latents is not None:
|
| 595 |
+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
| 596 |
+
return latents.to(device=device, dtype=dtype), latent_image_ids
|
| 597 |
+
|
| 598 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 599 |
+
raise ValueError(
|
| 600 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 601 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 602 |
+
)
|
| 603 |
+
|
| 604 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 605 |
+
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
| 606 |
+
|
| 607 |
+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
| 608 |
+
|
| 609 |
+
return latents, latent_image_ids
|
| 610 |
+
|
| 611 |
+
@property
|
| 612 |
+
def guidance_scale(self):
|
| 613 |
+
return self._guidance_scale
|
| 614 |
+
|
| 615 |
+
@property
|
| 616 |
+
def joint_attention_kwargs(self):
|
| 617 |
+
return self._joint_attention_kwargs
|
| 618 |
+
|
| 619 |
+
@property
|
| 620 |
+
def num_timesteps(self):
|
| 621 |
+
return self._num_timesteps
|
| 622 |
+
|
| 623 |
+
@property
|
| 624 |
+
def current_timestep(self):
|
| 625 |
+
return self._current_timestep
|
| 626 |
+
|
| 627 |
+
@property
|
| 628 |
+
def interrupt(self):
|
| 629 |
+
return self._interrupt
|
| 630 |
+
|
| 631 |
+
@torch.no_grad()
|
| 632 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 633 |
+
def __call__(
|
| 634 |
+
self,
|
| 635 |
+
prompt: Union[str, List[str]] = None,
|
| 636 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
| 637 |
+
negative_prompt: Union[str, List[str]] = None,
|
| 638 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
| 639 |
+
true_cfg_scale: float = 1.0,
|
| 640 |
+
height: Optional[int] = None,
|
| 641 |
+
width: Optional[int] = None,
|
| 642 |
+
num_inference_steps: int = 28,
|
| 643 |
+
sigmas: Optional[List[float]] = None,
|
| 644 |
+
guidance_scale: float = 3.5,
|
| 645 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 646 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 647 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 648 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 649 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 650 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
| 651 |
+
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
| 652 |
+
negative_ip_adapter_image: Optional[PipelineImageInput] = None,
|
| 653 |
+
negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
| 654 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 655 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 656 |
+
output_type: Optional[str] = "pil",
|
| 657 |
+
return_dict: bool = True,
|
| 658 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 659 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 660 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 661 |
+
max_sequence_length: int = 512,
|
| 662 |
+
):
|
| 663 |
+
r"""
|
| 664 |
+
Function invoked when calling the pipeline for generation.
|
| 665 |
+
|
| 666 |
+
Args:
|
| 667 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 668 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
| 669 |
+
instead.
|
| 670 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
| 671 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
| 672 |
+
will be used instead.
|
| 673 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 674 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 675 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
|
| 676 |
+
not greater than `1`).
|
| 677 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
| 678 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
| 679 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
|
| 680 |
+
true_cfg_scale (`float`, *optional*, defaults to 1.0):
|
| 681 |
+
When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
|
| 682 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 683 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 684 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 685 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 686 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 687 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 688 |
+
expense of slower inference.
|
| 689 |
+
sigmas (`List[float]`, *optional*):
|
| 690 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
| 691 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
| 692 |
+
will be used.
|
| 693 |
+
guidance_scale (`float`, *optional*, defaults to 3.5):
|
| 694 |
+
Guidance scale as defined in [Classifier-Free Diffusion
|
| 695 |
+
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
| 696 |
+
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
| 697 |
+
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
| 698 |
+
the text `prompt`, usually at the expense of lower image quality.
|
| 699 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 700 |
+
The number of images to generate per prompt.
|
| 701 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 702 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 703 |
+
to make generation deterministic.
|
| 704 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 705 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 706 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 707 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
| 708 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 709 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 710 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 711 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 712 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
| 713 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
| 714 |
+
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
| 715 |
+
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
|
| 716 |
+
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
| 717 |
+
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
|
| 718 |
+
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
| 719 |
+
negative_ip_adapter_image:
|
| 720 |
+
(`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
| 721 |
+
negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
|
| 722 |
+
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
| 723 |
+
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
|
| 724 |
+
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
| 725 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 726 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 727 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 728 |
+
argument.
|
| 729 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 730 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 731 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
| 732 |
+
input argument.
|
| 733 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 734 |
+
The output format of the generate image. Choose between
|
| 735 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 736 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 737 |
+
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
|
| 738 |
+
joint_attention_kwargs (`dict`, *optional*):
|
| 739 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 740 |
+
`self.processor` in
|
| 741 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 742 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 743 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
| 744 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 745 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 746 |
+
`callback_on_step_end_tensor_inputs`.
|
| 747 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 748 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 749 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 750 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 751 |
+
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
| 752 |
+
|
| 753 |
+
Examples:
|
| 754 |
+
|
| 755 |
+
Returns:
|
| 756 |
+
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
| 757 |
+
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
| 758 |
+
images.
|
| 759 |
+
"""
|
| 760 |
+
|
| 761 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
| 762 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
| 763 |
+
|
| 764 |
+
# 1. Check inputs. Raise error if not correct
|
| 765 |
+
self.check_inputs(
|
| 766 |
+
prompt,
|
| 767 |
+
prompt_2,
|
| 768 |
+
height,
|
| 769 |
+
width,
|
| 770 |
+
negative_prompt=negative_prompt,
|
| 771 |
+
negative_prompt_2=negative_prompt_2,
|
| 772 |
+
prompt_embeds=prompt_embeds,
|
| 773 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 774 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 775 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
| 776 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
| 777 |
+
max_sequence_length=max_sequence_length,
|
| 778 |
+
)
|
| 779 |
+
|
| 780 |
+
self._guidance_scale = guidance_scale
|
| 781 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
| 782 |
+
self._current_timestep = None
|
| 783 |
+
self._interrupt = False
|
| 784 |
+
|
| 785 |
+
# 2. Define call parameters
|
| 786 |
+
if prompt is not None and isinstance(prompt, str):
|
| 787 |
+
batch_size = 1
|
| 788 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 789 |
+
batch_size = len(prompt)
|
| 790 |
+
else:
|
| 791 |
+
batch_size = prompt_embeds.shape[0]
|
| 792 |
+
|
| 793 |
+
device = self._execution_device
|
| 794 |
+
|
| 795 |
+
lora_scale = (
|
| 796 |
+
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
| 797 |
+
)
|
| 798 |
+
has_neg_prompt = negative_prompt is not None or (
|
| 799 |
+
negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
|
| 800 |
+
)
|
| 801 |
+
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
|
| 802 |
+
(
|
| 803 |
+
prompt_embeds,
|
| 804 |
+
pooled_prompt_embeds,
|
| 805 |
+
text_ids,
|
| 806 |
+
) = self.encode_prompt(
|
| 807 |
+
prompt=prompt,
|
| 808 |
+
prompt_2=prompt_2,
|
| 809 |
+
prompt_embeds=prompt_embeds,
|
| 810 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 811 |
+
device=device,
|
| 812 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 813 |
+
max_sequence_length=max_sequence_length,
|
| 814 |
+
lora_scale=lora_scale,
|
| 815 |
+
)
|
| 816 |
+
if do_true_cfg:
|
| 817 |
+
(
|
| 818 |
+
negative_prompt_embeds,
|
| 819 |
+
negative_pooled_prompt_embeds,
|
| 820 |
+
negative_text_ids,
|
| 821 |
+
) = self.encode_prompt(
|
| 822 |
+
prompt=negative_prompt,
|
| 823 |
+
prompt_2=negative_prompt_2,
|
| 824 |
+
prompt_embeds=negative_prompt_embeds,
|
| 825 |
+
pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
| 826 |
+
device=device,
|
| 827 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 828 |
+
max_sequence_length=max_sequence_length,
|
| 829 |
+
lora_scale=lora_scale,
|
| 830 |
+
)
|
| 831 |
+
|
| 832 |
+
# 4. Prepare latent variables
|
| 833 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
| 834 |
+
latents, latent_image_ids = self.prepare_latents(
|
| 835 |
+
batch_size * num_images_per_prompt,
|
| 836 |
+
num_channels_latents,
|
| 837 |
+
height,
|
| 838 |
+
width,
|
| 839 |
+
prompt_embeds.dtype,
|
| 840 |
+
device,
|
| 841 |
+
generator,
|
| 842 |
+
latents,
|
| 843 |
+
)
|
| 844 |
+
|
| 845 |
+
# 5. Prepare timesteps
|
| 846 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
| 847 |
+
image_seq_len = latents.shape[1]
|
| 848 |
+
mu = calculate_shift(
|
| 849 |
+
image_seq_len,
|
| 850 |
+
self.scheduler.config.get("base_image_seq_len", 256),
|
| 851 |
+
self.scheduler.config.get("max_image_seq_len", 4096),
|
| 852 |
+
self.scheduler.config.get("base_shift", 0.5),
|
| 853 |
+
self.scheduler.config.get("max_shift", 1.15),
|
| 854 |
+
)
|
| 855 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 856 |
+
self.scheduler,
|
| 857 |
+
num_inference_steps,
|
| 858 |
+
device,
|
| 859 |
+
sigmas=sigmas,
|
| 860 |
+
mu=mu,
|
| 861 |
+
)
|
| 862 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 863 |
+
self._num_timesteps = len(timesteps)
|
| 864 |
+
|
| 865 |
+
# handle guidance
|
| 866 |
+
if self.transformer.config.guidance_embeds:
|
| 867 |
+
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
| 868 |
+
guidance = guidance.expand(latents.shape[0])
|
| 869 |
+
else:
|
| 870 |
+
guidance = None
|
| 871 |
+
|
| 872 |
+
if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
|
| 873 |
+
negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
|
| 874 |
+
):
|
| 875 |
+
negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
|
| 876 |
+
negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
|
| 877 |
+
|
| 878 |
+
elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
|
| 879 |
+
negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
|
| 880 |
+
):
|
| 881 |
+
ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
|
| 882 |
+
ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
|
| 883 |
+
|
| 884 |
+
if self.joint_attention_kwargs is None:
|
| 885 |
+
self._joint_attention_kwargs = {}
|
| 886 |
+
|
| 887 |
+
image_embeds = None
|
| 888 |
+
negative_image_embeds = None
|
| 889 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
| 890 |
+
image_embeds = self.prepare_ip_adapter_image_embeds(
|
| 891 |
+
ip_adapter_image,
|
| 892 |
+
ip_adapter_image_embeds,
|
| 893 |
+
device,
|
| 894 |
+
batch_size * num_images_per_prompt,
|
| 895 |
+
)
|
| 896 |
+
if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
|
| 897 |
+
negative_image_embeds = self.prepare_ip_adapter_image_embeds(
|
| 898 |
+
negative_ip_adapter_image,
|
| 899 |
+
negative_ip_adapter_image_embeds,
|
| 900 |
+
device,
|
| 901 |
+
batch_size * num_images_per_prompt,
|
| 902 |
+
)
|
| 903 |
+
|
| 904 |
+
# 6. Denoising loop
|
| 905 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 906 |
+
for i, t in enumerate(timesteps):
|
| 907 |
+
if self.interrupt:
|
| 908 |
+
continue
|
| 909 |
+
|
| 910 |
+
self._current_timestep = t
|
| 911 |
+
if image_embeds is not None:
|
| 912 |
+
self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
|
| 913 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 914 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
| 915 |
+
|
| 916 |
+
noise_pred = self.transformer(
|
| 917 |
+
hidden_states=latents,
|
| 918 |
+
timestep=timestep / 1000,
|
| 919 |
+
guidance=guidance,
|
| 920 |
+
pooled_projections=pooled_prompt_embeds,
|
| 921 |
+
encoder_hidden_states=prompt_embeds,
|
| 922 |
+
txt_ids=text_ids,
|
| 923 |
+
img_ids=latent_image_ids,
|
| 924 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
| 925 |
+
return_dict=False,
|
| 926 |
+
)[0]
|
| 927 |
+
|
| 928 |
+
if do_true_cfg:
|
| 929 |
+
if negative_image_embeds is not None:
|
| 930 |
+
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
|
| 931 |
+
neg_noise_pred = self.transformer(
|
| 932 |
+
hidden_states=latents,
|
| 933 |
+
timestep=timestep / 1000,
|
| 934 |
+
guidance=guidance,
|
| 935 |
+
pooled_projections=negative_pooled_prompt_embeds,
|
| 936 |
+
encoder_hidden_states=negative_prompt_embeds,
|
| 937 |
+
txt_ids=negative_text_ids,
|
| 938 |
+
img_ids=latent_image_ids,
|
| 939 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
| 940 |
+
return_dict=False,
|
| 941 |
+
)[0]
|
| 942 |
+
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
|
| 943 |
+
|
| 944 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 945 |
+
latents_dtype = latents.dtype
|
| 946 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
| 947 |
+
|
| 948 |
+
if latents.dtype != latents_dtype:
|
| 949 |
+
if torch.backends.mps.is_available():
|
| 950 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
| 951 |
+
latents = latents.to(latents_dtype)
|
| 952 |
+
|
| 953 |
+
if callback_on_step_end is not None:
|
| 954 |
+
callback_kwargs = {}
|
| 955 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 956 |
+
callback_kwargs[k] = locals()[k]
|
| 957 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 958 |
+
|
| 959 |
+
latents = callback_outputs.pop("latents", latents)
|
| 960 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 961 |
+
|
| 962 |
+
# call the callback, if provided
|
| 963 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 964 |
+
progress_bar.update()
|
| 965 |
+
|
| 966 |
+
if XLA_AVAILABLE:
|
| 967 |
+
xm.mark_step()
|
| 968 |
+
|
| 969 |
+
self._current_timestep = None
|
| 970 |
+
|
| 971 |
+
if output_type == "latent":
|
| 972 |
+
image = latents
|
| 973 |
+
else:
|
| 974 |
+
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
| 975 |
+
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
| 976 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
| 977 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
| 978 |
+
|
| 979 |
+
# Offload all models
|
| 980 |
+
self.maybe_free_model_hooks()
|
| 981 |
+
|
| 982 |
+
if not return_dict:
|
| 983 |
+
return (image,)
|
| 984 |
+
|
| 985 |
+
return FluxPipelineOutput(images=image)
|
| 986 |
+
|
| 987 |
+
|
| 988 |
+
@torch.no_grad()
|
| 989 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 990 |
+
def sde_sampling(
|
| 991 |
+
self,
|
| 992 |
+
prompt: Union[str, List[str]] = None,
|
| 993 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
| 994 |
+
negative_prompt: Union[str, List[str]] = None,
|
| 995 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
| 996 |
+
true_cfg_scale: float = 1.0,
|
| 997 |
+
height: Optional[int] = None,
|
| 998 |
+
width: Optional[int] = None,
|
| 999 |
+
num_inference_steps: int = 28,
|
| 1000 |
+
sigmas: Optional[List[float]] = None,
|
| 1001 |
+
guidance_scale: float = 3.5,
|
| 1002 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 1003 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 1004 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 1005 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 1006 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 1007 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
| 1008 |
+
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
| 1009 |
+
negative_ip_adapter_image: Optional[PipelineImageInput] = None,
|
| 1010 |
+
negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
| 1011 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 1012 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 1013 |
+
output_type: Optional[str] = "pil",
|
| 1014 |
+
return_dict: bool = True,
|
| 1015 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 1016 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 1017 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 1018 |
+
max_sequence_length: int = 512,
|
| 1019 |
+
):
|
| 1020 |
+
r"""
|
| 1021 |
+
Function invoked when calling the pipeline for generation.
|
| 1022 |
+
|
| 1023 |
+
Args:
|
| 1024 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 1025 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
| 1026 |
+
instead.
|
| 1027 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
| 1028 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
| 1029 |
+
will be used instead.
|
| 1030 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 1031 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 1032 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
|
| 1033 |
+
not greater than `1`).
|
| 1034 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
| 1035 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
| 1036 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
|
| 1037 |
+
true_cfg_scale (`float`, *optional*, defaults to 1.0):
|
| 1038 |
+
When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
|
| 1039 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 1040 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 1041 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 1042 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 1043 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 1044 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 1045 |
+
expense of slower inference.
|
| 1046 |
+
sigmas (`List[float]`, *optional*):
|
| 1047 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
| 1048 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
| 1049 |
+
will be used.
|
| 1050 |
+
guidance_scale (`float`, *optional*, defaults to 3.5):
|
| 1051 |
+
Guidance scale as defined in [Classifier-Free Diffusion
|
| 1052 |
+
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
| 1053 |
+
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
| 1054 |
+
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
| 1055 |
+
the text `prompt`, usually at the expense of lower image quality.
|
| 1056 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 1057 |
+
The number of images to generate per prompt.
|
| 1058 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 1059 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 1060 |
+
to make generation deterministic.
|
| 1061 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 1062 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 1063 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 1064 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
| 1065 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 1066 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 1067 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 1068 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 1069 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
| 1070 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
| 1071 |
+
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
| 1072 |
+
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
|
| 1073 |
+
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
| 1074 |
+
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
|
| 1075 |
+
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
| 1076 |
+
negative_ip_adapter_image:
|
| 1077 |
+
(`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
| 1078 |
+
negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
|
| 1079 |
+
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
| 1080 |
+
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
|
| 1081 |
+
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
| 1082 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 1083 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 1084 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 1085 |
+
argument.
|
| 1086 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 1087 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 1088 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
| 1089 |
+
input argument.
|
| 1090 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 1091 |
+
The output format of the generate image. Choose between
|
| 1092 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 1093 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 1094 |
+
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
|
| 1095 |
+
joint_attention_kwargs (`dict`, *optional*):
|
| 1096 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 1097 |
+
`self.processor` in
|
| 1098 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 1099 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 1100 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
| 1101 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 1102 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 1103 |
+
`callback_on_step_end_tensor_inputs`.
|
| 1104 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 1105 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 1106 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 1107 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 1108 |
+
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
| 1109 |
+
|
| 1110 |
+
Examples:
|
| 1111 |
+
|
| 1112 |
+
Returns:
|
| 1113 |
+
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
| 1114 |
+
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
| 1115 |
+
images.
|
| 1116 |
+
"""
|
| 1117 |
+
|
| 1118 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
| 1119 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
| 1120 |
+
|
| 1121 |
+
# 1. Check inputs. Raise error if not correct
|
| 1122 |
+
self.check_inputs(
|
| 1123 |
+
prompt,
|
| 1124 |
+
prompt_2,
|
| 1125 |
+
height,
|
| 1126 |
+
width,
|
| 1127 |
+
negative_prompt=negative_prompt,
|
| 1128 |
+
negative_prompt_2=negative_prompt_2,
|
| 1129 |
+
prompt_embeds=prompt_embeds,
|
| 1130 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 1131 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 1132 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
| 1133 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
| 1134 |
+
max_sequence_length=max_sequence_length,
|
| 1135 |
+
)
|
| 1136 |
+
|
| 1137 |
+
self._guidance_scale = guidance_scale
|
| 1138 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
| 1139 |
+
self._current_timestep = None
|
| 1140 |
+
self._interrupt = False
|
| 1141 |
+
|
| 1142 |
+
# 2. Define call parameters
|
| 1143 |
+
if prompt is not None and isinstance(prompt, str):
|
| 1144 |
+
batch_size = 1
|
| 1145 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 1146 |
+
batch_size = len(prompt)
|
| 1147 |
+
else:
|
| 1148 |
+
batch_size = prompt_embeds.shape[0]
|
| 1149 |
+
|
| 1150 |
+
device = self._execution_device
|
| 1151 |
+
|
| 1152 |
+
lora_scale = (
|
| 1153 |
+
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
| 1154 |
+
)
|
| 1155 |
+
has_neg_prompt = negative_prompt is not None or (
|
| 1156 |
+
negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
|
| 1157 |
+
)
|
| 1158 |
+
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
|
| 1159 |
+
(
|
| 1160 |
+
prompt_embeds,
|
| 1161 |
+
pooled_prompt_embeds,
|
| 1162 |
+
text_ids,
|
| 1163 |
+
) = self.encode_prompt(
|
| 1164 |
+
prompt=prompt,
|
| 1165 |
+
prompt_2=prompt_2,
|
| 1166 |
+
prompt_embeds=prompt_embeds,
|
| 1167 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 1168 |
+
device=device,
|
| 1169 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 1170 |
+
max_sequence_length=max_sequence_length,
|
| 1171 |
+
lora_scale=lora_scale,
|
| 1172 |
+
)
|
| 1173 |
+
if do_true_cfg:
|
| 1174 |
+
(
|
| 1175 |
+
negative_prompt_embeds,
|
| 1176 |
+
negative_pooled_prompt_embeds,
|
| 1177 |
+
negative_text_ids,
|
| 1178 |
+
) = self.encode_prompt(
|
| 1179 |
+
prompt=negative_prompt,
|
| 1180 |
+
prompt_2=negative_prompt_2,
|
| 1181 |
+
prompt_embeds=negative_prompt_embeds,
|
| 1182 |
+
pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
| 1183 |
+
device=device,
|
| 1184 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 1185 |
+
max_sequence_length=max_sequence_length,
|
| 1186 |
+
lora_scale=lora_scale,
|
| 1187 |
+
)
|
| 1188 |
+
|
| 1189 |
+
# 4. Prepare latent variables
|
| 1190 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
| 1191 |
+
latents, latent_image_ids = self.prepare_latents(
|
| 1192 |
+
batch_size * num_images_per_prompt,
|
| 1193 |
+
num_channels_latents,
|
| 1194 |
+
height,
|
| 1195 |
+
width,
|
| 1196 |
+
prompt_embeds.dtype,
|
| 1197 |
+
device,
|
| 1198 |
+
generator,
|
| 1199 |
+
latents,
|
| 1200 |
+
)
|
| 1201 |
+
|
| 1202 |
+
# 5. Prepare timesteps
|
| 1203 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
| 1204 |
+
image_seq_len = latents.shape[1]
|
| 1205 |
+
mu = calculate_shift(
|
| 1206 |
+
image_seq_len,
|
| 1207 |
+
self.scheduler.config.get("base_image_seq_len", 256),
|
| 1208 |
+
self.scheduler.config.get("max_image_seq_len", 4096),
|
| 1209 |
+
self.scheduler.config.get("base_shift", 0.5),
|
| 1210 |
+
self.scheduler.config.get("max_shift", 1.15),
|
| 1211 |
+
)
|
| 1212 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 1213 |
+
self.scheduler,
|
| 1214 |
+
num_inference_steps,
|
| 1215 |
+
device,
|
| 1216 |
+
sigmas=sigmas,
|
| 1217 |
+
mu=mu,
|
| 1218 |
+
)
|
| 1219 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 1220 |
+
self._num_timesteps = len(timesteps)
|
| 1221 |
+
|
| 1222 |
+
# handle guidance
|
| 1223 |
+
if self.transformer.config.guidance_embeds:
|
| 1224 |
+
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
| 1225 |
+
guidance = guidance.expand(latents.shape[0])
|
| 1226 |
+
else:
|
| 1227 |
+
guidance = None
|
| 1228 |
+
|
| 1229 |
+
if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
|
| 1230 |
+
negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
|
| 1231 |
+
):
|
| 1232 |
+
negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
|
| 1233 |
+
negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
|
| 1234 |
+
|
| 1235 |
+
elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
|
| 1236 |
+
negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
|
| 1237 |
+
):
|
| 1238 |
+
ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
|
| 1239 |
+
ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
|
| 1240 |
+
|
| 1241 |
+
if self.joint_attention_kwargs is None:
|
| 1242 |
+
self._joint_attention_kwargs = {}
|
| 1243 |
+
|
| 1244 |
+
image_embeds = None
|
| 1245 |
+
negative_image_embeds = None
|
| 1246 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
| 1247 |
+
image_embeds = self.prepare_ip_adapter_image_embeds(
|
| 1248 |
+
ip_adapter_image,
|
| 1249 |
+
ip_adapter_image_embeds,
|
| 1250 |
+
device,
|
| 1251 |
+
batch_size * num_images_per_prompt,
|
| 1252 |
+
)
|
| 1253 |
+
if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
|
| 1254 |
+
negative_image_embeds = self.prepare_ip_adapter_image_embeds(
|
| 1255 |
+
negative_ip_adapter_image,
|
| 1256 |
+
negative_ip_adapter_image_embeds,
|
| 1257 |
+
device,
|
| 1258 |
+
batch_size * num_images_per_prompt,
|
| 1259 |
+
)
|
| 1260 |
+
|
| 1261 |
+
# 6. Denoising loop
|
| 1262 |
+
prev_latents = []
|
| 1263 |
+
pred_latents = []
|
| 1264 |
+
# preds_lst = []
|
| 1265 |
+
states = {
|
| 1266 |
+
"timestep": [],
|
| 1267 |
+
"guidance": [],
|
| 1268 |
+
"pooled_projections": [],
|
| 1269 |
+
"encoder_hidden_states": [],
|
| 1270 |
+
"txt_ids": None,
|
| 1271 |
+
"img_ids": None,
|
| 1272 |
+
}
|
| 1273 |
+
log_probs = []
|
| 1274 |
+
ts = []
|
| 1275 |
+
states["txt_ids"] = text_ids if text_ids is not None else None
|
| 1276 |
+
states["img_ids"] = latent_image_ids if latent_image_ids is not None else None
|
| 1277 |
+
|
| 1278 |
+
# self.scheduler.set_begin_index(0)
|
| 1279 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 1280 |
+
for i, t in enumerate(timesteps):
|
| 1281 |
+
if self.interrupt:
|
| 1282 |
+
continue
|
| 1283 |
+
|
| 1284 |
+
self._current_timestep = t
|
| 1285 |
+
if image_embeds is not None:
|
| 1286 |
+
self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
|
| 1287 |
+
|
| 1288 |
+
|
| 1289 |
+
timestep = (t.expand(latents.shape[0])/ 1000.).to(latents.dtype)
|
| 1290 |
+
|
| 1291 |
+
states["timestep"].append(timestep.unsqueeze(1)) # Unsqueezed if needed for batch/timestep handling
|
| 1292 |
+
states["guidance"].append(guidance.unsqueeze(1) if torch.is_tensor(guidance) else guidance) # Handle if tensor
|
| 1293 |
+
states["pooled_projections"].append(pooled_prompt_embeds.unsqueeze(1) if pooled_prompt_embeds is not None else None) # Unsqueezed along seq/batch if applicable
|
| 1294 |
+
states["encoder_hidden_states"].append(prompt_embeds.unsqueeze(1) if prompt_embeds is not None else None) # Unsqueezed along seq dim if needed
|
| 1295 |
+
|
| 1296 |
+
ts.append(t.expand(latents.shape[0]).unsqueeze(1))
|
| 1297 |
+
prev_latents.append(latents.detach().clone().unsqueeze(1))
|
| 1298 |
+
|
| 1299 |
+
noise_pred = self.transformer(
|
| 1300 |
+
hidden_states=latents,
|
| 1301 |
+
timestep=timestep,
|
| 1302 |
+
guidance=guidance,
|
| 1303 |
+
pooled_projections=pooled_prompt_embeds,
|
| 1304 |
+
encoder_hidden_states=prompt_embeds,
|
| 1305 |
+
txt_ids=text_ids,
|
| 1306 |
+
img_ids=latent_image_ids,
|
| 1307 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
| 1308 |
+
return_dict=False,
|
| 1309 |
+
)[0]
|
| 1310 |
+
|
| 1311 |
+
if do_true_cfg:
|
| 1312 |
+
if negative_image_embeds is not None:
|
| 1313 |
+
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
|
| 1314 |
+
|
| 1315 |
+
neg_noise_pred = self.transformer(
|
| 1316 |
+
hidden_states=latents,
|
| 1317 |
+
timestep=timestep,
|
| 1318 |
+
guidance=guidance,
|
| 1319 |
+
pooled_projections=negative_pooled_prompt_embeds,
|
| 1320 |
+
encoder_hidden_states=negative_prompt_embeds,
|
| 1321 |
+
txt_ids=negative_text_ids,
|
| 1322 |
+
img_ids=latent_image_ids,
|
| 1323 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
| 1324 |
+
return_dict=False,
|
| 1325 |
+
)[0]
|
| 1326 |
+
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
|
| 1327 |
+
|
| 1328 |
+
latents_dtype = latents.dtype
|
| 1329 |
+
latents, log_prob, prev_latents_mean, std_dev_t = sde_step_with_logprob(self.scheduler, noise_pred.float(), t.expand(latents.shape[0]), latents.float())
|
| 1330 |
+
|
| 1331 |
+
log_probs.append(log_prob.detach().clone().unsqueeze(1))
|
| 1332 |
+
pred_latents.append(latents.detach().clone().unsqueeze(1))
|
| 1333 |
+
if latents.dtype != latents_dtype:
|
| 1334 |
+
latents = latents.to(latents_dtype)
|
| 1335 |
+
|
| 1336 |
+
if callback_on_step_end is not None:
|
| 1337 |
+
callback_kwargs = {}
|
| 1338 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 1339 |
+
callback_kwargs[k] = locals()[k]
|
| 1340 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 1341 |
+
|
| 1342 |
+
latents = callback_outputs.pop("latents", latents)
|
| 1343 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 1344 |
+
|
| 1345 |
+
# call the callback, if provided
|
| 1346 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 1347 |
+
progress_bar.update()
|
| 1348 |
+
|
| 1349 |
+
if XLA_AVAILABLE:
|
| 1350 |
+
xm.mark_step()
|
| 1351 |
+
|
| 1352 |
+
self._current_timestep = None
|
| 1353 |
+
|
| 1354 |
+
if output_type == "latent":
|
| 1355 |
+
image = latents
|
| 1356 |
+
else:
|
| 1357 |
+
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
| 1358 |
+
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
| 1359 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
| 1360 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
| 1361 |
+
|
| 1362 |
+
|
| 1363 |
+
batched_states = {}
|
| 1364 |
+
batch_size = latents.shape[0]
|
| 1365 |
+
num_steps = len(timesteps)
|
| 1366 |
+
|
| 1367 |
+
for key, value_list in states.items():
|
| 1368 |
+
if value_list is None or len(value_list) == 0: # Skip None or empty lists
|
| 1369 |
+
batched_states[key] = None
|
| 1370 |
+
continue
|
| 1371 |
+
if value_list[0] is None: # Handle lists of None (e.g., optional inputs)
|
| 1372 |
+
batched_states[key] = None
|
| 1373 |
+
continue
|
| 1374 |
+
# Concatenate along dim=1
|
| 1375 |
+
if isinstance(value_list, list):
|
| 1376 |
+
concatenated = torch.cat(value_list, dim=1) # Shape: (batch, steps, ...)
|
| 1377 |
+
if len(concatenated.shape) <= 2: # 1D tensors (e.g., timestep: batch, steps)
|
| 1378 |
+
# print(key, concatenated.shape)
|
| 1379 |
+
batched_states[key] = concatenated.view(-1)
|
| 1380 |
+
else: # Higher-dim tensors (e.g., latents: batch, steps, channels, h, w)
|
| 1381 |
+
batched_states[key] = concatenated.view(-1, *concatenated.shape[2:])
|
| 1382 |
+
else:
|
| 1383 |
+
batched_states[key] = value_list
|
| 1384 |
+
# assert 0
|
| 1385 |
+
prev_latents = torch.cat(prev_latents, dim=1)
|
| 1386 |
+
log_probs = torch.cat(log_probs, dim=1)
|
| 1387 |
+
pred_latents = torch.cat(pred_latents, dim=1)
|
| 1388 |
+
ts = torch.cat(ts, dim=1)
|
| 1389 |
+
|
| 1390 |
+
prev_latents = prev_latents.view(prev_latents.shape[0] * prev_latents.shape[1], *prev_latents.shape[2:])
|
| 1391 |
+
log_probs = log_probs.view(log_probs.shape[0] * log_probs.shape[1], *log_probs.shape[2:])
|
| 1392 |
+
pred_latents = pred_latents.view(pred_latents.shape[0] * pred_latents.shape[1], *pred_latents.shape[2:])
|
| 1393 |
+
ts = ts.view(-1)
|
| 1394 |
+
|
| 1395 |
+
# Offload all models
|
| 1396 |
+
self.maybe_free_model_hooks()
|
| 1397 |
+
|
| 1398 |
+
return (image, prev_latents, log_probs, pred_latents, ts, batched_states)
|
| 1399 |
+
|
| 1400 |
+
def sde_step_with_logprob(
|
| 1401 |
+
self: FlowMatchEulerDiscreteScheduler,
|
| 1402 |
+
model_output: torch.FloatTensor,
|
| 1403 |
+
timestep: Union[float, torch.FloatTensor],
|
| 1404 |
+
sample: torch.FloatTensor,
|
| 1405 |
+
prev_sample: Optional[torch.FloatTensor] = None,
|
| 1406 |
+
generator: Optional[torch.Generator] = None,
|
| 1407 |
+
) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
|
| 1408 |
+
"""
|
| 1409 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the flow
|
| 1410 |
+
process from the learned model outputs (most often the predicted velocity).
|
| 1411 |
+
|
| 1412 |
+
Args:
|
| 1413 |
+
model_output (`torch.FloatTensor`):
|
| 1414 |
+
The direct output from learned flow model.
|
| 1415 |
+
timestep (`float`):
|
| 1416 |
+
The current discrete timestep in the diffusion chain.
|
| 1417 |
+
sample (`torch.FloatTensor`):
|
| 1418 |
+
A current instance of a sample created by the diffusion process.
|
| 1419 |
+
generator (`torch.Generator`, *optional*):
|
| 1420 |
+
A random number generator.
|
| 1421 |
+
"""
|
| 1422 |
+
step_index = [self.index_for_timestep(t) for t in timestep]
|
| 1423 |
+
prev_step_index = [step+1 for step in step_index]
|
| 1424 |
+
sigma = self.sigmas[step_index].view(-1, 1, 1).to(model_output.device)
|
| 1425 |
+
sigma_prev = self.sigmas[prev_step_index].view(-1, 1, 1).to(model_output.device)
|
| 1426 |
+
sigma_max = self.sigmas[1].item()
|
| 1427 |
+
dt = sigma_prev - sigma
|
| 1428 |
+
|
| 1429 |
+
std_dev_t = torch.sqrt(sigma / (1 - torch.where(sigma == 1, sigma_max, sigma))) * 1.0
|
| 1430 |
+
|
| 1431 |
+
|
| 1432 |
+
# our sde
|
| 1433 |
+
prev_sample_mean = sample*(1+std_dev_t**2/(2*sigma)*dt)+model_output*(1+std_dev_t**2*(1-sigma)/(2*sigma))*dt
|
| 1434 |
+
|
| 1435 |
+
if prev_sample is not None and generator is not None:
|
| 1436 |
+
raise ValueError(
|
| 1437 |
+
"Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
|
| 1438 |
+
" `prev_sample` stays `None`."
|
| 1439 |
+
)
|
| 1440 |
+
|
| 1441 |
+
if prev_sample is None:
|
| 1442 |
+
variance_noise = randn_tensor(
|
| 1443 |
+
model_output.shape,
|
| 1444 |
+
generator=generator,
|
| 1445 |
+
device=model_output.device,
|
| 1446 |
+
dtype=model_output.dtype,
|
| 1447 |
+
)
|
| 1448 |
+
prev_sample = prev_sample_mean + std_dev_t * torch.sqrt(-1*dt) * variance_noise
|
| 1449 |
+
|
| 1450 |
+
|
| 1451 |
+
log_prob = (
|
| 1452 |
+
-((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * ((std_dev_t * torch.sqrt(-1*dt))**2))
|
| 1453 |
+
- torch.log(std_dev_t * torch.sqrt(-1*dt))
|
| 1454 |
+
- torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
|
| 1455 |
+
)
|
| 1456 |
+
|
| 1457 |
+
# mean along all but batch dimension
|
| 1458 |
+
log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
|
| 1459 |
+
|
| 1460 |
+
return prev_sample, log_prob, prev_sample_mean, std_dev_t * torch.sqrt(-1*dt)
|
| 1461 |
+
|
| 1462 |
+
|
| 1463 |
+
|
| 1464 |
+
# Copyright 2025 Fu-Yun Wang
|
| 1465 |
+
#
|
| 1466 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 1467 |
+
# you may not use this file except in compliance with the License.
|
| 1468 |
+
# You may obtain a copy of the License at
|
| 1469 |
+
#
|
| 1470 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 1471 |
+
#
|
| 1472 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 1473 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 1474 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 1475 |
+
# See the License for the specific language governing permissions and
|
| 1476 |
+
# limitations under the License.
|
| 1477 |
+
|
| 1478 |
+
def sde_step_with_logprob_simple(
|
| 1479 |
+
self: FlowMatchEulerDiscreteScheduler,
|
| 1480 |
+
model_output: torch.FloatTensor,
|
| 1481 |
+
timestep: Union[float, torch.FloatTensor],
|
| 1482 |
+
sample: torch.FloatTensor,
|
| 1483 |
+
prev_sample: Optional[torch.FloatTensor] = None,
|
| 1484 |
+
generator: Optional[torch.Generator] = None,
|
| 1485 |
+
):
|
| 1486 |
+
"""
|
| 1487 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the flow
|
| 1488 |
+
process from the learned model outputs (most often the predicted velocity).
|
| 1489 |
+
|
| 1490 |
+
Args:
|
| 1491 |
+
model_output (`torch.FloatTensor`):
|
| 1492 |
+
The direct output from learned flow model.
|
| 1493 |
+
timestep (`float`):
|
| 1494 |
+
The current discrete timestep in the diffusion chain.
|
| 1495 |
+
sample (`torch.FloatTensor`):
|
| 1496 |
+
A current instance of a sample created by the diffusion process.
|
| 1497 |
+
generator (`torch.Generator`, *optional*):
|
| 1498 |
+
A random number generator.
|
| 1499 |
+
"""
|
| 1500 |
+
|
| 1501 |
+
step_index = [self.index_for_timestep(t) for t in timestep]
|
| 1502 |
+
prev_step_index = [step+1 for step in step_index]
|
| 1503 |
+
sigma = self.sigmas[step_index].view(-1, 1, 1).to(model_output.device)
|
| 1504 |
+
sigma_prev = self.sigmas[prev_step_index].view(-1, 1, 1).to(model_output.device)
|
| 1505 |
+
sigma_max = self.sigmas[1].item()
|
| 1506 |
+
dt = sigma_prev - sigma
|
| 1507 |
+
|
| 1508 |
+
|
| 1509 |
+
eta = 0.5
|
| 1510 |
+
Dt = - dt * eta
|
| 1511 |
+
|
| 1512 |
+
prev_sample_mean = sample * (1 - Dt / (1 - torch.where(sigma == 1, sigma_max, sigma))) + model_output * (dt - Dt)
|
| 1513 |
+
|
| 1514 |
+
std_dev_t = torch.sqrt(2 * Dt * (sigma / (1 - torch.where(sigma == 1, sigma_max, sigma))))
|
| 1515 |
+
|
| 1516 |
+
if prev_sample is not None and generator is not None:
|
| 1517 |
+
raise ValueError(
|
| 1518 |
+
"Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
|
| 1519 |
+
" `prev_sample` stays `None`."
|
| 1520 |
+
)
|
| 1521 |
+
|
| 1522 |
+
if prev_sample is None:
|
| 1523 |
+
# Generate noise if not provided
|
| 1524 |
+
variance_noise = randn_tensor(
|
| 1525 |
+
model_output.shape,
|
| 1526 |
+
generator=generator,
|
| 1527 |
+
device=model_output.device,
|
| 1528 |
+
dtype=model_output.dtype,
|
| 1529 |
+
)
|
| 1530 |
+
|
| 1531 |
+
prev_sample = prev_sample_mean + std_dev_t * variance_noise
|
| 1532 |
+
|
| 1533 |
+
|
| 1534 |
+
log_prob = (
|
| 1535 |
+
-((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2))
|
| 1536 |
+
- torch.log(std_dev_t)
|
| 1537 |
+
- torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
|
| 1538 |
+
)
|
| 1539 |
+
|
| 1540 |
+
# mean along all but batch dimension
|
| 1541 |
+
log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
|
| 1542 |
+
|
| 1543 |
+
return prev_sample, log_prob, prev_sample_mean, std_dev_t
|
unimodel/qwenflux/qwenflux_inference.py
ADDED
|
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Fu-Yun Wang
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
from typing import List, Optional, Tuple, Union, Dict
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
from PIL import Image
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoProcessor
|
| 22 |
+
from transformers import Qwen2_5_VLConfig, Qwen2_5_VLModel, Qwen2_5_VLForConditionalGeneration
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 26 |
+
from diffusers.pipelines.pipeline_utils import numpy_to_pil
|
| 27 |
+
import numpy as np
|
| 28 |
+
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler, FlowMatchEulerDiscreteSchedulerOutput
|
| 29 |
+
from diffusers.schedulers import DPMSolverMultistepScheduler
|
| 30 |
+
import math
|
| 31 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 32 |
+
from diffusers import FluxTransformer2DModel, AutoencoderKL, FlowMatchEulerDiscreteScheduler
|
| 33 |
+
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast, CLIPTextConfig, T5Config
|
| 34 |
+
from .fluxpipeline import FluxPipeline
|
| 35 |
+
import re
|
| 36 |
+
import datetime
|
| 37 |
+
import os
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def save_grid_image(prompt, images, rows, cols):
|
| 41 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 42 |
+
base_dir = os.path.join("samples", timestamp, prompt[:100])
|
| 43 |
+
os.makedirs(base_dir, exist_ok=True)
|
| 44 |
+
|
| 45 |
+
filename = os.path.join(base_dir, "grid.jpg")
|
| 46 |
+
grid_image = create_image_grid(images, rows, cols)
|
| 47 |
+
grid_image.save(filename)
|
| 48 |
+
|
| 49 |
+
print(f"Saved: {filename}")
|
| 50 |
+
|
| 51 |
+
def create_image_grid(images, rows, cols):
|
| 52 |
+
"""Creates a grid of images and returns a single PIL Image."""
|
| 53 |
+
|
| 54 |
+
assert len(images) == rows * cols
|
| 55 |
+
|
| 56 |
+
width, height = images[0].size
|
| 57 |
+
grid_width = width * cols
|
| 58 |
+
grid_height = height * rows
|
| 59 |
+
|
| 60 |
+
grid_image = Image.new('RGB', (grid_width, grid_height))
|
| 61 |
+
|
| 62 |
+
for i, image in enumerate(images):
|
| 63 |
+
x = (i % cols) * width
|
| 64 |
+
y = (i // cols) * height
|
| 65 |
+
grid_image.paste(image, (x, y))
|
| 66 |
+
|
| 67 |
+
return grid_image
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def sde_step_with_logprob(
|
| 71 |
+
self: FlowMatchEulerDiscreteScheduler,
|
| 72 |
+
model_output: torch.FloatTensor,
|
| 73 |
+
timestep: Union[float, torch.FloatTensor],
|
| 74 |
+
sample: torch.FloatTensor,
|
| 75 |
+
prev_sample: Optional[torch.FloatTensor] = None,
|
| 76 |
+
generator: Optional[torch.Generator] = None,
|
| 77 |
+
) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
|
| 78 |
+
"""
|
| 79 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the flow
|
| 80 |
+
process from the learned model outputs (most often the predicted velocity).
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
model_output (`torch.FloatTensor`):
|
| 84 |
+
The direct output from learned flow model.
|
| 85 |
+
timestep (`float`):
|
| 86 |
+
The current discrete timestep in the diffusion chain.
|
| 87 |
+
sample (`torch.FloatTensor`):
|
| 88 |
+
A current instance of a sample created by the diffusion process.
|
| 89 |
+
generator (`torch.Generator`, *optional*):
|
| 90 |
+
A random number generator.
|
| 91 |
+
"""
|
| 92 |
+
step_index = [self.index_for_timestep(t) for t in timestep]
|
| 93 |
+
prev_step_index = [step+1 for step in step_index]
|
| 94 |
+
sigma = self.sigmas[step_index].view(-1, 1, 1).to(model_output.device)
|
| 95 |
+
sigma_prev = self.sigmas[prev_step_index].view(-1, 1, 1).to(model_output.device)
|
| 96 |
+
sigma_max = self.sigmas[1].item()
|
| 97 |
+
dt = sigma_prev - sigma
|
| 98 |
+
|
| 99 |
+
std_dev_t = torch.sqrt(sigma / (1 - torch.where(sigma == 1, sigma_max, sigma)))*1.0
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# our sde
|
| 103 |
+
prev_sample_mean = sample*(1+std_dev_t**2/(2*sigma)*dt)+model_output*(1+std_dev_t**2*(1-sigma)/(2*sigma))*dt
|
| 104 |
+
|
| 105 |
+
if prev_sample is not None and generator is not None:
|
| 106 |
+
raise ValueError(
|
| 107 |
+
"Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
|
| 108 |
+
" `prev_sample` stays `None`."
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
if prev_sample is None:
|
| 112 |
+
variance_noise = randn_tensor(
|
| 113 |
+
model_output.shape,
|
| 114 |
+
generator=generator,
|
| 115 |
+
device=model_output.device,
|
| 116 |
+
dtype=model_output.dtype,
|
| 117 |
+
)
|
| 118 |
+
prev_sample = prev_sample_mean + std_dev_t * torch.sqrt(-1*dt) * variance_noise
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
log_prob = (
|
| 122 |
+
-((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * ((std_dev_t * torch.sqrt(-1*dt))**2))
|
| 123 |
+
- torch.log(std_dev_t * torch.sqrt(-1*dt))
|
| 124 |
+
- torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
# mean along all but batch dimension
|
| 128 |
+
log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
|
| 129 |
+
|
| 130 |
+
return prev_sample, log_prob, prev_sample_mean, std_dev_t * torch.sqrt(-1*dt)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# Copyright 2025 Fu-Yun Wang
|
| 134 |
+
#
|
| 135 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 136 |
+
# you may not use this file except in compliance with the License.
|
| 137 |
+
# You may obtain a copy of the License at
|
| 138 |
+
#
|
| 139 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 140 |
+
#
|
| 141 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 142 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 143 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 144 |
+
# See the License for the specific language governing permissions and
|
| 145 |
+
# limitations under the License.
|
| 146 |
+
|
| 147 |
+
def sde_step_with_logprob_simple(
|
| 148 |
+
self: FlowMatchEulerDiscreteScheduler,
|
| 149 |
+
model_output: torch.FloatTensor,
|
| 150 |
+
timestep: Union[float, torch.FloatTensor],
|
| 151 |
+
sample: torch.FloatTensor,
|
| 152 |
+
prev_sample: Optional[torch.FloatTensor] = None,
|
| 153 |
+
generator: Optional[torch.Generator] = None,
|
| 154 |
+
):
|
| 155 |
+
"""
|
| 156 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the flow
|
| 157 |
+
process from the learned model outputs (most often the predicted velocity).
|
| 158 |
+
|
| 159 |
+
Args:
|
| 160 |
+
model_output (`torch.FloatTensor`):
|
| 161 |
+
The direct output from learned flow model.
|
| 162 |
+
timestep (`float`):
|
| 163 |
+
The current discrete timestep in the diffusion chain.
|
| 164 |
+
sample (`torch.FloatTensor`):
|
| 165 |
+
A current instance of a sample created by the diffusion process.
|
| 166 |
+
generator (`torch.Generator`, *optional*):
|
| 167 |
+
A random number generator.
|
| 168 |
+
"""
|
| 169 |
+
|
| 170 |
+
step_index = [self.index_for_timestep(t) for t in timestep]
|
| 171 |
+
prev_step_index = [step+1 for step in step_index]
|
| 172 |
+
sigma = self.sigmas[step_index].view(-1, 1, 1, 1).to(model_output.device)
|
| 173 |
+
sigma_prev = self.sigmas[prev_step_index].view(-1, 1, 1, 1).to(model_output.device)
|
| 174 |
+
sigma_max = self.sigmas[1].item()
|
| 175 |
+
dt = sigma_prev - sigma
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
eta = 0.5
|
| 179 |
+
Dt = - dt * eta
|
| 180 |
+
|
| 181 |
+
prev_sample_mean = sample * (1 - Dt / (1 - torch.where(sigma == 1, sigma_max, sigma))) + model_output * (dt - Dt)
|
| 182 |
+
|
| 183 |
+
std_dev_t = torch.sqrt(2 * Dt * (sigma / (1 - torch.where(sigma == 1, sigma_max, sigma))))
|
| 184 |
+
|
| 185 |
+
if prev_sample is not None and generator is not None:
|
| 186 |
+
raise ValueError(
|
| 187 |
+
"Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
|
| 188 |
+
" `prev_sample` stays `None`."
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
if prev_sample is None:
|
| 192 |
+
# Generate noise if not provided
|
| 193 |
+
variance_noise = randn_tensor(
|
| 194 |
+
model_output.shape,
|
| 195 |
+
generator=generator,
|
| 196 |
+
device=model_output.device,
|
| 197 |
+
dtype=model_output.dtype,
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
prev_sample = prev_sample_mean + std_dev_t * variance_noise
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
log_prob = (
|
| 204 |
+
-((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2))
|
| 205 |
+
- torch.log(std_dev_t)
|
| 206 |
+
- torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
# mean along all but batch dimension
|
| 210 |
+
log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
|
| 211 |
+
|
| 212 |
+
return prev_sample, log_prob, prev_sample_mean, std_dev_t
|
| 213 |
+
|
| 214 |
+
class QwenFluxMetaModel:
|
| 215 |
+
|
| 216 |
+
def __init__(self, config):
|
| 217 |
+
super(QwenFluxMetaModel, self).__init__(config)
|
| 218 |
+
|
| 219 |
+
if hasattr(config, "diffusion_expert"):
|
| 220 |
+
ckpt_id = "black-forest-labs/FLUX.1-dev"
|
| 221 |
+
# Load configuration for each component
|
| 222 |
+
transformer_config = FluxTransformer2DModel.load_config(ckpt_id, subfolder="transformer")
|
| 223 |
+
vae_config = AutoencoderKL.load_config(ckpt_id, subfolder="vae")
|
| 224 |
+
text_encoder_config = CLIPTextConfig.from_pretrained(ckpt_id, subfolder="text_encoder")
|
| 225 |
+
text_encoder_2_config = T5Config.from_pretrained(ckpt_id, subfolder="text_encoder_2")
|
| 226 |
+
|
| 227 |
+
# Initialize components from their configurations
|
| 228 |
+
self.transformer = FluxTransformer2DModel.from_config(transformer_config)
|
| 229 |
+
self.vae = AutoencoderKL.from_config(vae_config)
|
| 230 |
+
self.text_encoder = CLIPTextModel(text_encoder_config)
|
| 231 |
+
self.text_encoder_2 = T5EncoderModel(text_encoder_2_config)
|
| 232 |
+
|
| 233 |
+
# Initialize tokenizers (these don't use from_config as they are not models)
|
| 234 |
+
self.tokenizer = CLIPTokenizer.from_pretrained(ckpt_id, subfolder="tokenizer")
|
| 235 |
+
self.tokenizer_2 = T5TokenizerFast.from_pretrained(ckpt_id, subfolder="tokenizer_2")
|
| 236 |
+
|
| 237 |
+
self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(ckpt_id, subfolder="scheduler")
|
| 238 |
+
|
| 239 |
+
# Create the pipeline configuration dictionary
|
| 240 |
+
pipeline_config = {
|
| 241 |
+
"transformer": self.transformer,
|
| 242 |
+
"scheduler": self.scheduler,
|
| 243 |
+
"vae": self.vae,
|
| 244 |
+
"text_encoder": self.text_encoder,
|
| 245 |
+
"text_encoder_2": self.text_encoder_2,
|
| 246 |
+
"tokenizer": self.tokenizer,
|
| 247 |
+
"tokenizer_2": self.tokenizer_2,
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
self.diffusion_expert = FluxPipeline(**pipeline_config)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def initialize_diffusion_expert(self, fsdp=None):
|
| 254 |
+
|
| 255 |
+
if getattr(self, 'diffusion_expert', None) is None:
|
| 256 |
+
print("random initiation the diffusion expert !!!")
|
| 257 |
+
self.diffusion_expert = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", revision="main", torch_dtype=torch.bfloat16).to(torch.bfloat16)
|
| 258 |
+
self.text_encoder = self.diffusion_expert.text_encoder
|
| 259 |
+
self.text_encoder_2 = self.diffusion_expert.text_encoder_2
|
| 260 |
+
self.tokenizer = self.diffusion_expert.tokenizer
|
| 261 |
+
self.tokenizer_2 = self.diffusion_expert.tokenizer_2
|
| 262 |
+
self.vae = self.diffusion_expert.vae
|
| 263 |
+
self.transformer = self.diffusion_expert.transformer
|
| 264 |
+
self.scheduler = self.diffusion_expert.scheduler
|
| 265 |
+
|
| 266 |
+
self.config.diffusion_expert = "flux"
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
class QwenFluxConfig(Qwen2_5_VLConfig):
|
| 271 |
+
model_type = "QwenFlux"
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
class QwenFluxModel(QwenFluxMetaModel, Qwen2_5_VLModel):
|
| 275 |
+
config_class = QwenFluxConfig
|
| 276 |
+
|
| 277 |
+
def __init__(self, config: Qwen2_5_VLConfig):
|
| 278 |
+
super(QwenFluxModel, self).__init__(config)
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
class QwenFluxForInferenceLM(Qwen2_5_VLForConditionalGeneration):
|
| 282 |
+
config_class = QwenFluxConfig
|
| 283 |
+
|
| 284 |
+
def __init__(self, config):
|
| 285 |
+
Qwen2_5_VLForConditionalGeneration.__init__(self, config)
|
| 286 |
+
config.model_type = "QwenFlux"
|
| 287 |
+
|
| 288 |
+
self.model = QwenFluxModel(config)
|
| 289 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 290 |
+
self.post_init()
|
| 291 |
+
|
| 292 |
+
def get_model(self):
|
| 293 |
+
return self.model
|
| 294 |
+
|
| 295 |
+
@torch.no_grad()
|
| 296 |
+
def generate_image(
|
| 297 |
+
self,
|
| 298 |
+
texts: List[str],
|
| 299 |
+
diffusion_kwargs: Optional[Dict] = dict(guidance_scale = 3.5, num_inference_steps=25),
|
| 300 |
+
sde_sampling: Optional[bool] = False,
|
| 301 |
+
):
|
| 302 |
+
|
| 303 |
+
if isinstance(texts, str):
|
| 304 |
+
texts = [texts]
|
| 305 |
+
|
| 306 |
+
if not sde_sampling:
|
| 307 |
+
output_img = self.model.diffusion_expert(
|
| 308 |
+
texts,
|
| 309 |
+
max_sequence_length=512,
|
| 310 |
+
**diffusion_kwargs,
|
| 311 |
+
).images
|
| 312 |
+
return output_img
|
| 313 |
+
else:
|
| 314 |
+
return self.model.diffusion_expert.sde_sampling(
|
| 315 |
+
texts,
|
| 316 |
+
max_sequence_length=512,
|
| 317 |
+
**diffusion_kwargs,
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def extract_thinking_content(self, text: str) -> str:
|
| 322 |
+
pattern = r'<answer>(.*?)</answer>'
|
| 323 |
+
matches = re.findall(pattern, text, re.DOTALL)
|
| 324 |
+
|
| 325 |
+
if matches:
|
| 326 |
+
return matches[-1].strip().replace("<answer>", "").replace("</answer>", "")
|
| 327 |
+
else:
|
| 328 |
+
return text.strip().replace("<answer>", "").replace("</answer>", "")
|
| 329 |
+
|
| 330 |
+
@torch.no_grad()
|
| 331 |
+
def generate_image_cot(
|
| 332 |
+
self,
|
| 333 |
+
texts: List[str],
|
| 334 |
+
processor: Optional[object] = None,
|
| 335 |
+
diffusion_kwargs: Optional[Dict] = dict(guidance_scale = 3.5, num_inference_steps=25),
|
| 336 |
+
llm_kwargs: Optional[Dict] = dict(max_new_tokens=256, temperature=0.7, top_p=0.9, do_sample=True),
|
| 337 |
+
cot_prompt_template: Optional[str] = None,
|
| 338 |
+
):
|
| 339 |
+
|
| 340 |
+
if isinstance(texts, str):
|
| 341 |
+
texts = [texts]
|
| 342 |
+
|
| 343 |
+
if cot_prompt_template is None:
|
| 344 |
+
# cot_prompt_template = """Please improve the following image generation prompt to make it more detailed and specific for better image quality. Think step by step about what visual elements would make this image more compelling. Original prompt: {original_prompt}. Please provide the improved prompt in <thinking> </thinking> tags."""
|
| 345 |
+
cot_prompt_template = """Please provide an enhanced prompt for the following image generation prompt to make the image more realistic, detailed, with clear separation and precise alignment of all entities.
|
| 346 |
+
Original prompt: {original_prompt}. Directly provide the improved prompt in <answer> </answer> tags."""
|
| 347 |
+
|
| 348 |
+
improved_prompts = []
|
| 349 |
+
|
| 350 |
+
for text in texts:
|
| 351 |
+
cot_input = cot_prompt_template.format(original_prompt=text)
|
| 352 |
+
|
| 353 |
+
messages = [{"role": "user", "content": cot_input}]
|
| 354 |
+
input_text_formatted = processor.apply_chat_template(
|
| 355 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 356 |
+
)
|
| 357 |
+
model_inputs = processor(
|
| 358 |
+
text=[input_text_formatted],
|
| 359 |
+
return_tensors="pt"
|
| 360 |
+
).to(self.device)
|
| 361 |
+
|
| 362 |
+
generated_ids = self.generate(
|
| 363 |
+
**model_inputs,
|
| 364 |
+
**llm_kwargs,
|
| 365 |
+
eos_token_id=processor.tokenizer.eos_token_id,
|
| 366 |
+
pad_token_id=processor.tokenizer.pad_token_id
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
generated_text = processor.batch_decode(
|
| 370 |
+
generated_ids[:, model_inputs['input_ids'].shape[1]:],
|
| 371 |
+
skip_special_tokens=True
|
| 372 |
+
)
|
| 373 |
+
|
| 374 |
+
improved_prompt = [self.extract_thinking_content(decode_text) for decode_text in generated_text]
|
| 375 |
+
improved_prompts.extend(improved_prompt)
|
| 376 |
+
|
| 377 |
+
print(f"Original prompt: {text}")
|
| 378 |
+
print(f"Improved prompt: {improved_prompt}")
|
| 379 |
+
print("-" * 50)
|
| 380 |
+
|
| 381 |
+
output_images = self.generate_image(improved_prompts, diffusion_kwargs)
|
| 382 |
+
|
| 383 |
+
return {
|
| 384 |
+
'images': output_images,
|
| 385 |
+
'original_prompts': texts,
|
| 386 |
+
'improved_prompts': improved_prompts
|
| 387 |
+
}
|
| 388 |
+
|
| 389 |
+
AutoConfig.register("QwenFlux", QwenFluxConfig)
|
| 390 |
+
AutoModelForCausalLM.register(QwenFluxConfig, QwenFluxForInferenceLM)
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
if __name__ == "__main__":
|
| 394 |
+
|
| 395 |
+
model = QwenFluxForInferenceLM.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct",torch_dtype=torch.bfloat16)
|
| 396 |
+
model.model.initialize_diffusion_expert()
|
| 397 |
+
model.model.diffusion_expert.to("cuda:0")
|
| 398 |
+
model.to("cuda:0")
|
| 399 |
+
AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
|
| 400 |
+
text = ["a photo of a cat"]
|
| 401 |
+
images = model.generate_image(text)
|
| 402 |
+
images[0].save("test_flux.png")
|
| 403 |
+
|
| 404 |
+
model.save_pretrained("outputs/pretrain/qwenflux")
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
model = QwenFluxForInferenceLM.from_pretrained("outputs/pretrain/qwenflux", torch_dtype=torch.bfloat16)
|
| 408 |
+
model.to("cuda:0")
|
| 409 |
+
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
|
| 410 |
+
text = ["a photo of a cat"]
|
| 411 |
+
images = model.generate_image(text)
|
| 412 |
+
images[0].save("test_flux.jpg")
|
| 413 |
+
|
| 414 |
+
outputs = model.generate_image_cot(text, processor = processor)
|
| 415 |
+
outputs['images'][0].save("test_flux_cot.jpg")
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
|
unimodel/qwenkontext/fluxkontext_pipeline.py
ADDED
|
@@ -0,0 +1,1161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
|
| 2 |
+
# Copyright 2025 Fu-Yun Wang
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import inspect
|
| 17 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
from transformers import (
|
| 22 |
+
CLIPImageProcessor,
|
| 23 |
+
CLIPTextModel,
|
| 24 |
+
CLIPTokenizer,
|
| 25 |
+
CLIPVisionModelWithProjection,
|
| 26 |
+
T5EncoderModel,
|
| 27 |
+
T5TokenizerFast,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
| 31 |
+
from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
|
| 32 |
+
from diffusers.models import AutoencoderKL, FluxTransformer2DModel
|
| 33 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 34 |
+
from diffusers.utils import (
|
| 35 |
+
USE_PEFT_BACKEND,
|
| 36 |
+
deprecate,
|
| 37 |
+
is_torch_xla_available,
|
| 38 |
+
logging,
|
| 39 |
+
replace_example_docstring,
|
| 40 |
+
scale_lora_layers,
|
| 41 |
+
unscale_lora_layers,
|
| 42 |
+
)
|
| 43 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 44 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 45 |
+
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
if is_torch_xla_available():
|
| 49 |
+
import torch_xla.core.xla_model as xm
|
| 50 |
+
|
| 51 |
+
XLA_AVAILABLE = True
|
| 52 |
+
else:
|
| 53 |
+
XLA_AVAILABLE = False
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 57 |
+
|
| 58 |
+
EXAMPLE_DOC_STRING = """
|
| 59 |
+
Examples:
|
| 60 |
+
```py
|
| 61 |
+
>>> import torch
|
| 62 |
+
>>> from diffusers import FluxKontextPipeline
|
| 63 |
+
>>> from diffusers.utils import load_image
|
| 64 |
+
|
| 65 |
+
>>> pipe = FluxKontextPipeline.from_pretrained(
|
| 66 |
+
... "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
|
| 67 |
+
... )
|
| 68 |
+
>>> pipe.to("cuda")
|
| 69 |
+
|
| 70 |
+
>>> image = load_image(
|
| 71 |
+
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
|
| 72 |
+
... ).convert("RGB")
|
| 73 |
+
>>> prompt = "Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors"
|
| 74 |
+
>>> image = pipe(
|
| 75 |
+
... image=image,
|
| 76 |
+
... prompt=prompt,
|
| 77 |
+
... guidance_scale=2.5,
|
| 78 |
+
... generator=torch.Generator().manual_seed(42),
|
| 79 |
+
... ).images[0]
|
| 80 |
+
>>> image.save("output.png")
|
| 81 |
+
```
|
| 82 |
+
"""
|
| 83 |
+
|
| 84 |
+
PREFERRED_KONTEXT_RESOLUTIONS = [
|
| 85 |
+
(672, 1568),
|
| 86 |
+
(688, 1504),
|
| 87 |
+
(720, 1456),
|
| 88 |
+
(752, 1392),
|
| 89 |
+
(800, 1328),
|
| 90 |
+
(832, 1248),
|
| 91 |
+
(880, 1184),
|
| 92 |
+
(944, 1104),
|
| 93 |
+
(1024, 1024),
|
| 94 |
+
(1104, 944),
|
| 95 |
+
(1184, 880),
|
| 96 |
+
(1248, 832),
|
| 97 |
+
(1328, 800),
|
| 98 |
+
(1392, 752),
|
| 99 |
+
(1456, 720),
|
| 100 |
+
(1504, 688),
|
| 101 |
+
(1568, 672),
|
| 102 |
+
]
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def calculate_shift(
|
| 106 |
+
image_seq_len,
|
| 107 |
+
base_seq_len: int = 256,
|
| 108 |
+
max_seq_len: int = 4096,
|
| 109 |
+
base_shift: float = 0.5,
|
| 110 |
+
max_shift: float = 1.15,
|
| 111 |
+
):
|
| 112 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
| 113 |
+
b = base_shift - m * base_seq_len
|
| 114 |
+
mu = image_seq_len * m + b
|
| 115 |
+
return mu
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 119 |
+
def retrieve_timesteps(
|
| 120 |
+
scheduler,
|
| 121 |
+
num_inference_steps: Optional[int] = None,
|
| 122 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 123 |
+
timesteps: Optional[List[int]] = None,
|
| 124 |
+
sigmas: Optional[List[float]] = None,
|
| 125 |
+
**kwargs,
|
| 126 |
+
):
|
| 127 |
+
r"""
|
| 128 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 129 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 130 |
+
|
| 131 |
+
Args:
|
| 132 |
+
scheduler (`SchedulerMixin`):
|
| 133 |
+
The scheduler to get timesteps from.
|
| 134 |
+
num_inference_steps (`int`):
|
| 135 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 136 |
+
must be `None`.
|
| 137 |
+
device (`str` or `torch.device`, *optional*):
|
| 138 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 139 |
+
timesteps (`List[int]`, *optional*):
|
| 140 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 141 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 142 |
+
sigmas (`List[float]`, *optional*):
|
| 143 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 144 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 145 |
+
|
| 146 |
+
Returns:
|
| 147 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 148 |
+
second element is the number of inference steps.
|
| 149 |
+
"""
|
| 150 |
+
if timesteps is not None and sigmas is not None:
|
| 151 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 152 |
+
if timesteps is not None:
|
| 153 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 154 |
+
if not accepts_timesteps:
|
| 155 |
+
raise ValueError(
|
| 156 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 157 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 158 |
+
)
|
| 159 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 160 |
+
timesteps = scheduler.timesteps
|
| 161 |
+
num_inference_steps = len(timesteps)
|
| 162 |
+
elif sigmas is not None:
|
| 163 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 164 |
+
if not accept_sigmas:
|
| 165 |
+
raise ValueError(
|
| 166 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 167 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 168 |
+
)
|
| 169 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 170 |
+
timesteps = scheduler.timesteps
|
| 171 |
+
num_inference_steps = len(timesteps)
|
| 172 |
+
else:
|
| 173 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 174 |
+
timesteps = scheduler.timesteps
|
| 175 |
+
return timesteps, num_inference_steps
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
| 179 |
+
def retrieve_latents(
|
| 180 |
+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
| 181 |
+
):
|
| 182 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
| 183 |
+
return encoder_output.latent_dist.sample(generator)
|
| 184 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
| 185 |
+
return encoder_output.latent_dist.mode()
|
| 186 |
+
elif hasattr(encoder_output, "latents"):
|
| 187 |
+
return encoder_output.latents
|
| 188 |
+
else:
|
| 189 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
class FluxKontextPipeline(
|
| 193 |
+
DiffusionPipeline,
|
| 194 |
+
FluxLoraLoaderMixin,
|
| 195 |
+
FromSingleFileMixin,
|
| 196 |
+
TextualInversionLoaderMixin,
|
| 197 |
+
FluxIPAdapterMixin,
|
| 198 |
+
):
|
| 199 |
+
r"""
|
| 200 |
+
The Flux Kontext pipeline for image-to-image and text-to-image generation.
|
| 201 |
+
|
| 202 |
+
Reference: https://bfl.ai/announcements/flux-1-kontext-dev
|
| 203 |
+
|
| 204 |
+
Args:
|
| 205 |
+
transformer ([`FluxTransformer2DModel`]):
|
| 206 |
+
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
| 207 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
| 208 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
| 209 |
+
vae ([`AutoencoderKL`]):
|
| 210 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
| 211 |
+
text_encoder ([`CLIPTextModel`]):
|
| 212 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
| 213 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
| 214 |
+
text_encoder_2 ([`T5EncoderModel`]):
|
| 215 |
+
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
| 216 |
+
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
| 217 |
+
tokenizer (`CLIPTokenizer`):
|
| 218 |
+
Tokenizer of class
|
| 219 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
| 220 |
+
tokenizer_2 (`T5TokenizerFast`):
|
| 221 |
+
Second Tokenizer of class
|
| 222 |
+
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
| 223 |
+
"""
|
| 224 |
+
|
| 225 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
|
| 226 |
+
_optional_components = ["image_encoder", "feature_extractor"]
|
| 227 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
| 228 |
+
|
| 229 |
+
def __init__(
|
| 230 |
+
self,
|
| 231 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 232 |
+
vae: AutoencoderKL,
|
| 233 |
+
text_encoder: CLIPTextModel,
|
| 234 |
+
tokenizer: CLIPTokenizer,
|
| 235 |
+
text_encoder_2: T5EncoderModel,
|
| 236 |
+
tokenizer_2: T5TokenizerFast,
|
| 237 |
+
transformer: FluxTransformer2DModel,
|
| 238 |
+
image_encoder: CLIPVisionModelWithProjection = None,
|
| 239 |
+
feature_extractor: CLIPImageProcessor = None,
|
| 240 |
+
):
|
| 241 |
+
super().__init__()
|
| 242 |
+
|
| 243 |
+
self.register_modules(
|
| 244 |
+
vae=vae,
|
| 245 |
+
text_encoder=text_encoder,
|
| 246 |
+
text_encoder_2=text_encoder_2,
|
| 247 |
+
tokenizer=tokenizer,
|
| 248 |
+
tokenizer_2=tokenizer_2,
|
| 249 |
+
transformer=transformer,
|
| 250 |
+
scheduler=scheduler,
|
| 251 |
+
image_encoder=image_encoder,
|
| 252 |
+
feature_extractor=feature_extractor,
|
| 253 |
+
)
|
| 254 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
| 255 |
+
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
|
| 256 |
+
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
|
| 257 |
+
self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
|
| 258 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
| 259 |
+
self.tokenizer_max_length = (
|
| 260 |
+
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
| 261 |
+
)
|
| 262 |
+
self.default_sample_size = 128
|
| 263 |
+
|
| 264 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
|
| 265 |
+
def _get_t5_prompt_embeds(
|
| 266 |
+
self,
|
| 267 |
+
prompt: Union[str, List[str]] = None,
|
| 268 |
+
num_images_per_prompt: int = 1,
|
| 269 |
+
max_sequence_length: int = 512,
|
| 270 |
+
device: Optional[torch.device] = None,
|
| 271 |
+
dtype: Optional[torch.dtype] = None,
|
| 272 |
+
):
|
| 273 |
+
device = device or self._execution_device
|
| 274 |
+
dtype = dtype or self.text_encoder.dtype
|
| 275 |
+
|
| 276 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 277 |
+
batch_size = len(prompt)
|
| 278 |
+
|
| 279 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
| 280 |
+
prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
|
| 281 |
+
|
| 282 |
+
text_inputs = self.tokenizer_2(
|
| 283 |
+
prompt,
|
| 284 |
+
padding="max_length",
|
| 285 |
+
max_length=max_sequence_length,
|
| 286 |
+
truncation=True,
|
| 287 |
+
return_length=False,
|
| 288 |
+
return_overflowing_tokens=False,
|
| 289 |
+
return_tensors="pt",
|
| 290 |
+
)
|
| 291 |
+
text_input_ids = text_inputs.input_ids
|
| 292 |
+
untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
|
| 293 |
+
|
| 294 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 295 |
+
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
| 296 |
+
logger.warning(
|
| 297 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
| 298 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
|
| 302 |
+
|
| 303 |
+
dtype = self.text_encoder_2.dtype
|
| 304 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 305 |
+
|
| 306 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 307 |
+
|
| 308 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
| 309 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 310 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 311 |
+
|
| 312 |
+
return prompt_embeds
|
| 313 |
+
|
| 314 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
|
| 315 |
+
def _get_clip_prompt_embeds(
|
| 316 |
+
self,
|
| 317 |
+
prompt: Union[str, List[str]],
|
| 318 |
+
num_images_per_prompt: int = 1,
|
| 319 |
+
device: Optional[torch.device] = None,
|
| 320 |
+
):
|
| 321 |
+
device = device or self._execution_device
|
| 322 |
+
|
| 323 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 324 |
+
batch_size = len(prompt)
|
| 325 |
+
|
| 326 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
| 327 |
+
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
| 328 |
+
|
| 329 |
+
text_inputs = self.tokenizer(
|
| 330 |
+
prompt,
|
| 331 |
+
padding="max_length",
|
| 332 |
+
max_length=self.tokenizer_max_length,
|
| 333 |
+
truncation=True,
|
| 334 |
+
return_overflowing_tokens=False,
|
| 335 |
+
return_length=False,
|
| 336 |
+
return_tensors="pt",
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
text_input_ids = text_inputs.input_ids
|
| 340 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 341 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 342 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
| 343 |
+
logger.warning(
|
| 344 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
| 345 |
+
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
| 346 |
+
)
|
| 347 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
|
| 348 |
+
|
| 349 |
+
# Use pooled output of CLIPTextModel
|
| 350 |
+
prompt_embeds = prompt_embeds.pooler_output
|
| 351 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
| 352 |
+
|
| 353 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 354 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
|
| 355 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
| 356 |
+
|
| 357 |
+
return prompt_embeds
|
| 358 |
+
|
| 359 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
|
| 360 |
+
def encode_prompt(
|
| 361 |
+
self,
|
| 362 |
+
prompt: Union[str, List[str]],
|
| 363 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
| 364 |
+
device: Optional[torch.device] = None,
|
| 365 |
+
num_images_per_prompt: int = 1,
|
| 366 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 367 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 368 |
+
max_sequence_length: int = 512,
|
| 369 |
+
lora_scale: Optional[float] = None,
|
| 370 |
+
):
|
| 371 |
+
r"""
|
| 372 |
+
|
| 373 |
+
Args:
|
| 374 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 375 |
+
prompt to be encoded
|
| 376 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
| 377 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
| 378 |
+
used in all text-encoders
|
| 379 |
+
device: (`torch.device`):
|
| 380 |
+
torch device
|
| 381 |
+
num_images_per_prompt (`int`):
|
| 382 |
+
number of images that should be generated per prompt
|
| 383 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 384 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 385 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 386 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 387 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
| 388 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
| 389 |
+
lora_scale (`float`, *optional*):
|
| 390 |
+
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
| 391 |
+
"""
|
| 392 |
+
device = device or self._execution_device
|
| 393 |
+
|
| 394 |
+
# set lora scale so that monkey patched LoRA
|
| 395 |
+
# function of text encoder can correctly access it
|
| 396 |
+
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
|
| 397 |
+
self._lora_scale = lora_scale
|
| 398 |
+
|
| 399 |
+
# dynamically adjust the LoRA scale
|
| 400 |
+
if self.text_encoder is not None and USE_PEFT_BACKEND:
|
| 401 |
+
scale_lora_layers(self.text_encoder, lora_scale)
|
| 402 |
+
if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
|
| 403 |
+
scale_lora_layers(self.text_encoder_2, lora_scale)
|
| 404 |
+
|
| 405 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 406 |
+
|
| 407 |
+
if prompt_embeds is None:
|
| 408 |
+
prompt_2 = prompt_2 or prompt
|
| 409 |
+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
| 410 |
+
|
| 411 |
+
# We only use the pooled prompt output from the CLIPTextModel
|
| 412 |
+
pooled_prompt_embeds = self._get_clip_prompt_embeds(
|
| 413 |
+
prompt=prompt,
|
| 414 |
+
device=device,
|
| 415 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 416 |
+
)
|
| 417 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
| 418 |
+
prompt=prompt_2,
|
| 419 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 420 |
+
max_sequence_length=max_sequence_length,
|
| 421 |
+
device=device,
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
if self.text_encoder is not None:
|
| 425 |
+
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
| 426 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
| 427 |
+
unscale_lora_layers(self.text_encoder, lora_scale)
|
| 428 |
+
|
| 429 |
+
if self.text_encoder_2 is not None:
|
| 430 |
+
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
| 431 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
| 432 |
+
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
| 433 |
+
|
| 434 |
+
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
|
| 435 |
+
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
| 436 |
+
|
| 437 |
+
return prompt_embeds, pooled_prompt_embeds, text_ids
|
| 438 |
+
|
| 439 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
|
| 440 |
+
def encode_image(self, image, device, num_images_per_prompt):
|
| 441 |
+
dtype = next(self.image_encoder.parameters()).dtype
|
| 442 |
+
|
| 443 |
+
if not isinstance(image, torch.Tensor):
|
| 444 |
+
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
| 445 |
+
|
| 446 |
+
image = image.to(device=device, dtype=dtype)
|
| 447 |
+
image_embeds = self.image_encoder(image).image_embeds
|
| 448 |
+
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
| 449 |
+
return image_embeds
|
| 450 |
+
|
| 451 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
|
| 452 |
+
def prepare_ip_adapter_image_embeds(
|
| 453 |
+
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
|
| 454 |
+
):
|
| 455 |
+
image_embeds = []
|
| 456 |
+
if ip_adapter_image_embeds is None:
|
| 457 |
+
if not isinstance(ip_adapter_image, list):
|
| 458 |
+
ip_adapter_image = [ip_adapter_image]
|
| 459 |
+
|
| 460 |
+
if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
|
| 461 |
+
raise ValueError(
|
| 462 |
+
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
|
| 463 |
+
)
|
| 464 |
+
|
| 465 |
+
for single_ip_adapter_image in ip_adapter_image:
|
| 466 |
+
single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
|
| 467 |
+
image_embeds.append(single_image_embeds[None, :])
|
| 468 |
+
else:
|
| 469 |
+
if not isinstance(ip_adapter_image_embeds, list):
|
| 470 |
+
ip_adapter_image_embeds = [ip_adapter_image_embeds]
|
| 471 |
+
|
| 472 |
+
if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
|
| 473 |
+
raise ValueError(
|
| 474 |
+
f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
for single_image_embeds in ip_adapter_image_embeds:
|
| 478 |
+
image_embeds.append(single_image_embeds)
|
| 479 |
+
|
| 480 |
+
ip_adapter_image_embeds = []
|
| 481 |
+
for single_image_embeds in image_embeds:
|
| 482 |
+
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
|
| 483 |
+
single_image_embeds = single_image_embeds.to(device=device)
|
| 484 |
+
ip_adapter_image_embeds.append(single_image_embeds)
|
| 485 |
+
|
| 486 |
+
return ip_adapter_image_embeds
|
| 487 |
+
|
| 488 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs
|
| 489 |
+
def check_inputs(
|
| 490 |
+
self,
|
| 491 |
+
prompt,
|
| 492 |
+
prompt_2,
|
| 493 |
+
height,
|
| 494 |
+
width,
|
| 495 |
+
negative_prompt=None,
|
| 496 |
+
negative_prompt_2=None,
|
| 497 |
+
prompt_embeds=None,
|
| 498 |
+
negative_prompt_embeds=None,
|
| 499 |
+
pooled_prompt_embeds=None,
|
| 500 |
+
negative_pooled_prompt_embeds=None,
|
| 501 |
+
callback_on_step_end_tensor_inputs=None,
|
| 502 |
+
max_sequence_length=None,
|
| 503 |
+
):
|
| 504 |
+
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
|
| 505 |
+
logger.warning(
|
| 506 |
+
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
|
| 507 |
+
)
|
| 508 |
+
|
| 509 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 510 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 511 |
+
):
|
| 512 |
+
raise ValueError(
|
| 513 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 514 |
+
)
|
| 515 |
+
|
| 516 |
+
if prompt is not None and prompt_embeds is not None:
|
| 517 |
+
raise ValueError(
|
| 518 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 519 |
+
" only forward one of the two."
|
| 520 |
+
)
|
| 521 |
+
elif prompt_2 is not None and prompt_embeds is not None:
|
| 522 |
+
raise ValueError(
|
| 523 |
+
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 524 |
+
" only forward one of the two."
|
| 525 |
+
)
|
| 526 |
+
elif prompt is None and prompt_embeds is None:
|
| 527 |
+
raise ValueError(
|
| 528 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 529 |
+
)
|
| 530 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 531 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 532 |
+
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
| 533 |
+
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
| 534 |
+
|
| 535 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 536 |
+
raise ValueError(
|
| 537 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 538 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 539 |
+
)
|
| 540 |
+
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
|
| 541 |
+
raise ValueError(
|
| 542 |
+
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
|
| 543 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 544 |
+
)
|
| 545 |
+
|
| 546 |
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
| 547 |
+
raise ValueError(
|
| 548 |
+
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
| 549 |
+
)
|
| 550 |
+
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
|
| 551 |
+
raise ValueError(
|
| 552 |
+
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
| 553 |
+
)
|
| 554 |
+
|
| 555 |
+
if max_sequence_length is not None and max_sequence_length > 512:
|
| 556 |
+
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
| 557 |
+
|
| 558 |
+
@staticmethod
|
| 559 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
|
| 560 |
+
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
| 561 |
+
latent_image_ids = torch.zeros(height, width, 3)
|
| 562 |
+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
| 563 |
+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
|
| 564 |
+
|
| 565 |
+
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
| 566 |
+
|
| 567 |
+
latent_image_ids = latent_image_ids.reshape(
|
| 568 |
+
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
| 569 |
+
)
|
| 570 |
+
|
| 571 |
+
return latent_image_ids.to(device=device, dtype=dtype)
|
| 572 |
+
|
| 573 |
+
@staticmethod
|
| 574 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
|
| 575 |
+
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
| 576 |
+
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
| 577 |
+
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
| 578 |
+
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
| 579 |
+
|
| 580 |
+
return latents
|
| 581 |
+
|
| 582 |
+
@staticmethod
|
| 583 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
|
| 584 |
+
def _unpack_latents(latents, height, width, vae_scale_factor):
|
| 585 |
+
batch_size, num_patches, channels = latents.shape
|
| 586 |
+
|
| 587 |
+
# VAE applies 8x compression on images but we must also account for packing which requires
|
| 588 |
+
# latent height and width to be divisible by 2.
|
| 589 |
+
height = 2 * (int(height) // (vae_scale_factor * 2))
|
| 590 |
+
width = 2 * (int(width) // (vae_scale_factor * 2))
|
| 591 |
+
|
| 592 |
+
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
|
| 593 |
+
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
| 594 |
+
|
| 595 |
+
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
|
| 596 |
+
|
| 597 |
+
return latents
|
| 598 |
+
|
| 599 |
+
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
| 600 |
+
if isinstance(generator, list):
|
| 601 |
+
image_latents = [
|
| 602 |
+
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
|
| 603 |
+
for i in range(image.shape[0])
|
| 604 |
+
]
|
| 605 |
+
image_latents = torch.cat(image_latents, dim=0)
|
| 606 |
+
else:
|
| 607 |
+
image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
|
| 608 |
+
|
| 609 |
+
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
| 610 |
+
|
| 611 |
+
return image_latents
|
| 612 |
+
|
| 613 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing
|
| 614 |
+
def enable_vae_slicing(self):
|
| 615 |
+
r"""
|
| 616 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
| 617 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
| 618 |
+
"""
|
| 619 |
+
depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
|
| 620 |
+
deprecate(
|
| 621 |
+
"enable_vae_slicing",
|
| 622 |
+
"0.40.0",
|
| 623 |
+
depr_message,
|
| 624 |
+
)
|
| 625 |
+
self.vae.enable_slicing()
|
| 626 |
+
|
| 627 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing
|
| 628 |
+
def disable_vae_slicing(self):
|
| 629 |
+
r"""
|
| 630 |
+
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
| 631 |
+
computing decoding in one step.
|
| 632 |
+
"""
|
| 633 |
+
depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
|
| 634 |
+
deprecate(
|
| 635 |
+
"disable_vae_slicing",
|
| 636 |
+
"0.40.0",
|
| 637 |
+
depr_message,
|
| 638 |
+
)
|
| 639 |
+
self.vae.disable_slicing()
|
| 640 |
+
|
| 641 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling
|
| 642 |
+
def enable_vae_tiling(self):
|
| 643 |
+
r"""
|
| 644 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
| 645 |
+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
| 646 |
+
processing larger images.
|
| 647 |
+
"""
|
| 648 |
+
depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
|
| 649 |
+
deprecate(
|
| 650 |
+
"enable_vae_tiling",
|
| 651 |
+
"0.40.0",
|
| 652 |
+
depr_message,
|
| 653 |
+
)
|
| 654 |
+
self.vae.enable_tiling()
|
| 655 |
+
|
| 656 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling
|
| 657 |
+
def disable_vae_tiling(self):
|
| 658 |
+
r"""
|
| 659 |
+
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
| 660 |
+
computing decoding in one step.
|
| 661 |
+
"""
|
| 662 |
+
depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
|
| 663 |
+
deprecate(
|
| 664 |
+
"disable_vae_tiling",
|
| 665 |
+
"0.40.0",
|
| 666 |
+
depr_message,
|
| 667 |
+
)
|
| 668 |
+
self.vae.disable_tiling()
|
| 669 |
+
|
| 670 |
+
def prepare_latents(
|
| 671 |
+
self,
|
| 672 |
+
image: Optional[torch.Tensor],
|
| 673 |
+
batch_size: int,
|
| 674 |
+
num_channels_latents: int,
|
| 675 |
+
height: int,
|
| 676 |
+
width: int,
|
| 677 |
+
dtype: torch.dtype,
|
| 678 |
+
device: torch.device,
|
| 679 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 680 |
+
latents: Optional[torch.Tensor] = None,
|
| 681 |
+
):
|
| 682 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 683 |
+
raise ValueError(
|
| 684 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 685 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 686 |
+
)
|
| 687 |
+
|
| 688 |
+
# VAE applies 8x compression on images but we must also account for packing which requires
|
| 689 |
+
# latent height and width to be divisible by 2.
|
| 690 |
+
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
| 691 |
+
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
| 692 |
+
shape = (batch_size, num_channels_latents, height, width)
|
| 693 |
+
|
| 694 |
+
image_latents = image_ids = None
|
| 695 |
+
if image is not None:
|
| 696 |
+
image = image.to(device=device, dtype=dtype)
|
| 697 |
+
if image.shape[1] != self.latent_channels:
|
| 698 |
+
image_latents = self._encode_vae_image(image=image, generator=generator)
|
| 699 |
+
else:
|
| 700 |
+
image_latents = image
|
| 701 |
+
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
| 702 |
+
# expand init_latents for batch_size
|
| 703 |
+
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
| 704 |
+
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
|
| 705 |
+
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
|
| 706 |
+
raise ValueError(
|
| 707 |
+
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
| 708 |
+
)
|
| 709 |
+
else:
|
| 710 |
+
image_latents = torch.cat([image_latents], dim=0)
|
| 711 |
+
|
| 712 |
+
image_latent_height, image_latent_width = image_latents.shape[2:]
|
| 713 |
+
image_latents = self._pack_latents(
|
| 714 |
+
image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
|
| 715 |
+
)
|
| 716 |
+
image_ids = self._prepare_latent_image_ids(
|
| 717 |
+
batch_size, image_latent_height // 2, image_latent_width // 2, device, dtype
|
| 718 |
+
)
|
| 719 |
+
# image ids are the same as latent ids with the first dimension set to 1 instead of 0
|
| 720 |
+
image_ids[..., 0] = 1
|
| 721 |
+
|
| 722 |
+
latent_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
| 723 |
+
|
| 724 |
+
if latents is None:
|
| 725 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 726 |
+
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
| 727 |
+
else:
|
| 728 |
+
latents = latents.to(device=device, dtype=dtype)
|
| 729 |
+
|
| 730 |
+
return latents, image_latents, latent_ids, image_ids
|
| 731 |
+
|
| 732 |
+
@property
|
| 733 |
+
def guidance_scale(self):
|
| 734 |
+
return self._guidance_scale
|
| 735 |
+
|
| 736 |
+
@property
|
| 737 |
+
def joint_attention_kwargs(self):
|
| 738 |
+
return self._joint_attention_kwargs
|
| 739 |
+
|
| 740 |
+
@property
|
| 741 |
+
def num_timesteps(self):
|
| 742 |
+
return self._num_timesteps
|
| 743 |
+
|
| 744 |
+
@property
|
| 745 |
+
def current_timestep(self):
|
| 746 |
+
return self._current_timestep
|
| 747 |
+
|
| 748 |
+
@property
|
| 749 |
+
def interrupt(self):
|
| 750 |
+
return self._interrupt
|
| 751 |
+
|
| 752 |
+
@torch.no_grad()
|
| 753 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 754 |
+
def __call__(
|
| 755 |
+
self,
|
| 756 |
+
image: Optional[PipelineImageInput] = None,
|
| 757 |
+
prompt: Union[str, List[str]] = None,
|
| 758 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
| 759 |
+
negative_prompt: Union[str, List[str]] = None,
|
| 760 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
| 761 |
+
true_cfg_scale: float = 1.0,
|
| 762 |
+
height: Optional[int] = None,
|
| 763 |
+
width: Optional[int] = None,
|
| 764 |
+
num_inference_steps: int = 28,
|
| 765 |
+
sigmas: Optional[List[float]] = None,
|
| 766 |
+
guidance_scale: float = 3.5,
|
| 767 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 768 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 769 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 770 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 771 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 772 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
| 773 |
+
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
| 774 |
+
negative_ip_adapter_image: Optional[PipelineImageInput] = None,
|
| 775 |
+
negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
| 776 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 777 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 778 |
+
output_type: Optional[str] = "pil",
|
| 779 |
+
return_dict: bool = True,
|
| 780 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 781 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 782 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 783 |
+
max_sequence_length: int = 512,
|
| 784 |
+
max_area: int = 1024**2,
|
| 785 |
+
_auto_resize: bool = True,
|
| 786 |
+
):
|
| 787 |
+
r"""
|
| 788 |
+
Function invoked when calling the pipeline for generation.
|
| 789 |
+
|
| 790 |
+
Args:
|
| 791 |
+
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
| 792 |
+
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
| 793 |
+
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
|
| 794 |
+
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
|
| 795 |
+
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
|
| 796 |
+
latents as `image`, but if passing latents directly it is not encoded again.
|
| 797 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 798 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
| 799 |
+
instead.
|
| 800 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
| 801 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
| 802 |
+
will be used instead.
|
| 803 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 804 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 805 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
|
| 806 |
+
not greater than `1`).
|
| 807 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
| 808 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
| 809 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
|
| 810 |
+
true_cfg_scale (`float`, *optional*, defaults to 1.0):
|
| 811 |
+
When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
|
| 812 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 813 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 814 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 815 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 816 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 817 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 818 |
+
expense of slower inference.
|
| 819 |
+
sigmas (`List[float]`, *optional*):
|
| 820 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
| 821 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
| 822 |
+
will be used.
|
| 823 |
+
guidance_scale (`float`, *optional*, defaults to 3.5):
|
| 824 |
+
Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
|
| 825 |
+
a model to generate images more aligned with prompt at the expense of lower image quality.
|
| 826 |
+
|
| 827 |
+
Guidance-distilled models approximates true classifier-free guidance for `guidance_scale` > 1. Refer to
|
| 828 |
+
the [paper](https://huggingface.co/papers/2210.03142) to learn more.
|
| 829 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 830 |
+
The number of images to generate per prompt.
|
| 831 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 832 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 833 |
+
to make generation deterministic.
|
| 834 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 835 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 836 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 837 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
| 838 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 839 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 840 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 841 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 842 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
| 843 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
| 844 |
+
ip_adapter_image: (`PipelineImageInput`, *optional*):
|
| 845 |
+
Optional image input to work with IP Adapters.
|
| 846 |
+
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
|
| 847 |
+
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
| 848 |
+
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
|
| 849 |
+
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
| 850 |
+
negative_ip_adapter_image:
|
| 851 |
+
(`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
| 852 |
+
negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
|
| 853 |
+
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
| 854 |
+
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
|
| 855 |
+
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
| 856 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 857 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 858 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 859 |
+
argument.
|
| 860 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 861 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 862 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
| 863 |
+
input argument.
|
| 864 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 865 |
+
The output format of the generate image. Choose between
|
| 866 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 867 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 868 |
+
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
|
| 869 |
+
joint_attention_kwargs (`dict`, *optional*):
|
| 870 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 871 |
+
`self.processor` in
|
| 872 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 873 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 874 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
| 875 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 876 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 877 |
+
`callback_on_step_end_tensor_inputs`.
|
| 878 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 879 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 880 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 881 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 882 |
+
max_sequence_length (`int` defaults to 512):
|
| 883 |
+
Maximum sequence length to use with the `prompt`.
|
| 884 |
+
max_area (`int`, defaults to `1024 ** 2`):
|
| 885 |
+
The maximum area of the generated image in pixels. The height and width will be adjusted to fit this
|
| 886 |
+
area while maintaining the aspect ratio.
|
| 887 |
+
|
| 888 |
+
Examples:
|
| 889 |
+
|
| 890 |
+
Returns:
|
| 891 |
+
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
| 892 |
+
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
| 893 |
+
images.
|
| 894 |
+
"""
|
| 895 |
+
|
| 896 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
| 897 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
| 898 |
+
|
| 899 |
+
original_height, original_width = height, width
|
| 900 |
+
aspect_ratio = width / height
|
| 901 |
+
|
| 902 |
+
width = round((max_area * aspect_ratio) ** 0.5)
|
| 903 |
+
height = round((max_area / aspect_ratio) ** 0.5)
|
| 904 |
+
|
| 905 |
+
multiple_of = self.vae_scale_factor * 2
|
| 906 |
+
width = width // multiple_of * multiple_of
|
| 907 |
+
height = height // multiple_of * multiple_of
|
| 908 |
+
|
| 909 |
+
if height != original_height or width != original_width:
|
| 910 |
+
logger.warning(
|
| 911 |
+
f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
|
| 912 |
+
)
|
| 913 |
+
|
| 914 |
+
# 1. Check inputs. Raise error if not correct
|
| 915 |
+
self.check_inputs(
|
| 916 |
+
prompt,
|
| 917 |
+
prompt_2,
|
| 918 |
+
height,
|
| 919 |
+
width,
|
| 920 |
+
negative_prompt=negative_prompt,
|
| 921 |
+
negative_prompt_2=negative_prompt_2,
|
| 922 |
+
prompt_embeds=prompt_embeds,
|
| 923 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 924 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 925 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
| 926 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
| 927 |
+
max_sequence_length=max_sequence_length,
|
| 928 |
+
)
|
| 929 |
+
|
| 930 |
+
self._guidance_scale = guidance_scale
|
| 931 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
| 932 |
+
self._current_timestep = None
|
| 933 |
+
self._interrupt = False
|
| 934 |
+
|
| 935 |
+
# 2. Define call parameters
|
| 936 |
+
if prompt is not None and isinstance(prompt, str):
|
| 937 |
+
batch_size = 1
|
| 938 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 939 |
+
batch_size = len(prompt)
|
| 940 |
+
else:
|
| 941 |
+
batch_size = prompt_embeds.shape[0]
|
| 942 |
+
|
| 943 |
+
device = self._execution_device
|
| 944 |
+
|
| 945 |
+
lora_scale = (
|
| 946 |
+
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
| 947 |
+
)
|
| 948 |
+
has_neg_prompt = negative_prompt is not None or (
|
| 949 |
+
negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
|
| 950 |
+
)
|
| 951 |
+
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
|
| 952 |
+
(
|
| 953 |
+
prompt_embeds,
|
| 954 |
+
pooled_prompt_embeds,
|
| 955 |
+
text_ids,
|
| 956 |
+
) = self.encode_prompt(
|
| 957 |
+
prompt=prompt,
|
| 958 |
+
prompt_2=prompt_2,
|
| 959 |
+
prompt_embeds=prompt_embeds,
|
| 960 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 961 |
+
device=device,
|
| 962 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 963 |
+
max_sequence_length=max_sequence_length,
|
| 964 |
+
lora_scale=lora_scale,
|
| 965 |
+
)
|
| 966 |
+
if do_true_cfg:
|
| 967 |
+
(
|
| 968 |
+
negative_prompt_embeds,
|
| 969 |
+
negative_pooled_prompt_embeds,
|
| 970 |
+
negative_text_ids,
|
| 971 |
+
) = self.encode_prompt(
|
| 972 |
+
prompt=negative_prompt,
|
| 973 |
+
prompt_2=negative_prompt_2,
|
| 974 |
+
prompt_embeds=negative_prompt_embeds,
|
| 975 |
+
pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
| 976 |
+
device=device,
|
| 977 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 978 |
+
max_sequence_length=max_sequence_length,
|
| 979 |
+
lora_scale=lora_scale,
|
| 980 |
+
)
|
| 981 |
+
|
| 982 |
+
# 3. Preprocess image
|
| 983 |
+
if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
|
| 984 |
+
img = image[0] if isinstance(image, list) else image
|
| 985 |
+
image_height, image_width = self.image_processor.get_default_height_width(img)
|
| 986 |
+
aspect_ratio = image_width / image_height
|
| 987 |
+
if _auto_resize:
|
| 988 |
+
# Kontext is trained on specific resolutions, using one of them is recommended
|
| 989 |
+
_, image_width, image_height = min(
|
| 990 |
+
(abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
|
| 991 |
+
)
|
| 992 |
+
image_width = image_width // multiple_of * multiple_of
|
| 993 |
+
image_height = image_height // multiple_of * multiple_of
|
| 994 |
+
image = self.image_processor.resize(image, image_height, image_width)
|
| 995 |
+
image = self.image_processor.preprocess(image, image_height, image_width)
|
| 996 |
+
|
| 997 |
+
# 4. Prepare latent variables
|
| 998 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
| 999 |
+
latents, image_latents, latent_ids, image_ids = self.prepare_latents(
|
| 1000 |
+
image,
|
| 1001 |
+
batch_size * num_images_per_prompt,
|
| 1002 |
+
num_channels_latents,
|
| 1003 |
+
height,
|
| 1004 |
+
width,
|
| 1005 |
+
prompt_embeds.dtype,
|
| 1006 |
+
device,
|
| 1007 |
+
generator,
|
| 1008 |
+
latents,
|
| 1009 |
+
)
|
| 1010 |
+
if image_ids is not None:
|
| 1011 |
+
latent_ids = torch.cat([latent_ids, image_ids], dim=0) # dim 0 is sequence dimension
|
| 1012 |
+
|
| 1013 |
+
# 5. Prepare timesteps
|
| 1014 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
| 1015 |
+
image_seq_len = latents.shape[1]
|
| 1016 |
+
mu = calculate_shift(
|
| 1017 |
+
image_seq_len,
|
| 1018 |
+
self.scheduler.config.get("base_image_seq_len", 256),
|
| 1019 |
+
self.scheduler.config.get("max_image_seq_len", 4096),
|
| 1020 |
+
self.scheduler.config.get("base_shift", 0.5),
|
| 1021 |
+
self.scheduler.config.get("max_shift", 1.15),
|
| 1022 |
+
)
|
| 1023 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 1024 |
+
self.scheduler,
|
| 1025 |
+
num_inference_steps,
|
| 1026 |
+
device,
|
| 1027 |
+
sigmas=sigmas,
|
| 1028 |
+
mu=mu,
|
| 1029 |
+
)
|
| 1030 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 1031 |
+
self._num_timesteps = len(timesteps)
|
| 1032 |
+
|
| 1033 |
+
# handle guidance
|
| 1034 |
+
if self.transformer.config.guidance_embeds:
|
| 1035 |
+
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
| 1036 |
+
guidance = guidance.expand(latents.shape[0])
|
| 1037 |
+
else:
|
| 1038 |
+
guidance = None
|
| 1039 |
+
|
| 1040 |
+
if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
|
| 1041 |
+
negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
|
| 1042 |
+
):
|
| 1043 |
+
negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
|
| 1044 |
+
negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
|
| 1045 |
+
|
| 1046 |
+
elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
|
| 1047 |
+
negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
|
| 1048 |
+
):
|
| 1049 |
+
ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
|
| 1050 |
+
ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
|
| 1051 |
+
|
| 1052 |
+
if self.joint_attention_kwargs is None:
|
| 1053 |
+
self._joint_attention_kwargs = {}
|
| 1054 |
+
|
| 1055 |
+
image_embeds = None
|
| 1056 |
+
negative_image_embeds = None
|
| 1057 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
| 1058 |
+
image_embeds = self.prepare_ip_adapter_image_embeds(
|
| 1059 |
+
ip_adapter_image,
|
| 1060 |
+
ip_adapter_image_embeds,
|
| 1061 |
+
device,
|
| 1062 |
+
batch_size * num_images_per_prompt,
|
| 1063 |
+
)
|
| 1064 |
+
if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
|
| 1065 |
+
negative_image_embeds = self.prepare_ip_adapter_image_embeds(
|
| 1066 |
+
negative_ip_adapter_image,
|
| 1067 |
+
negative_ip_adapter_image_embeds,
|
| 1068 |
+
device,
|
| 1069 |
+
batch_size * num_images_per_prompt,
|
| 1070 |
+
)
|
| 1071 |
+
|
| 1072 |
+
# 6. Denoising loop
|
| 1073 |
+
# We set the index here to remove DtoH sync, helpful especially during compilation.
|
| 1074 |
+
# Check out more details here: https://github.com/huggingface/diffusers/pull/11696
|
| 1075 |
+
self.scheduler.set_begin_index(0)
|
| 1076 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 1077 |
+
for i, t in enumerate(timesteps):
|
| 1078 |
+
if self.interrupt:
|
| 1079 |
+
continue
|
| 1080 |
+
|
| 1081 |
+
self._current_timestep = t
|
| 1082 |
+
if image_embeds is not None:
|
| 1083 |
+
self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
|
| 1084 |
+
|
| 1085 |
+
latent_model_input = latents
|
| 1086 |
+
if image_latents is not None:
|
| 1087 |
+
latent_model_input = torch.cat([latents, image_latents], dim=1)
|
| 1088 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
| 1089 |
+
|
| 1090 |
+
noise_pred = self.transformer(
|
| 1091 |
+
hidden_states=latent_model_input,
|
| 1092 |
+
timestep=timestep / 1000,
|
| 1093 |
+
guidance=guidance,
|
| 1094 |
+
pooled_projections=pooled_prompt_embeds,
|
| 1095 |
+
encoder_hidden_states=prompt_embeds,
|
| 1096 |
+
txt_ids=text_ids,
|
| 1097 |
+
img_ids=latent_ids,
|
| 1098 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
| 1099 |
+
return_dict=False,
|
| 1100 |
+
)[0]
|
| 1101 |
+
noise_pred = noise_pred[:, : latents.size(1)]
|
| 1102 |
+
|
| 1103 |
+
if do_true_cfg:
|
| 1104 |
+
if negative_image_embeds is not None:
|
| 1105 |
+
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
|
| 1106 |
+
neg_noise_pred = self.transformer(
|
| 1107 |
+
hidden_states=latent_model_input,
|
| 1108 |
+
timestep=timestep / 1000,
|
| 1109 |
+
guidance=guidance,
|
| 1110 |
+
pooled_projections=negative_pooled_prompt_embeds,
|
| 1111 |
+
encoder_hidden_states=negative_prompt_embeds,
|
| 1112 |
+
txt_ids=negative_text_ids,
|
| 1113 |
+
img_ids=latent_ids,
|
| 1114 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
| 1115 |
+
return_dict=False,
|
| 1116 |
+
)[0]
|
| 1117 |
+
neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
|
| 1118 |
+
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
|
| 1119 |
+
|
| 1120 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 1121 |
+
latents_dtype = latents.dtype
|
| 1122 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
| 1123 |
+
|
| 1124 |
+
if latents.dtype != latents_dtype:
|
| 1125 |
+
if torch.backends.mps.is_available():
|
| 1126 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
| 1127 |
+
latents = latents.to(latents_dtype)
|
| 1128 |
+
|
| 1129 |
+
if callback_on_step_end is not None:
|
| 1130 |
+
callback_kwargs = {}
|
| 1131 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 1132 |
+
callback_kwargs[k] = locals()[k]
|
| 1133 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 1134 |
+
|
| 1135 |
+
latents = callback_outputs.pop("latents", latents)
|
| 1136 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 1137 |
+
|
| 1138 |
+
# call the callback, if provided
|
| 1139 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 1140 |
+
progress_bar.update()
|
| 1141 |
+
|
| 1142 |
+
if XLA_AVAILABLE:
|
| 1143 |
+
xm.mark_step()
|
| 1144 |
+
|
| 1145 |
+
self._current_timestep = None
|
| 1146 |
+
|
| 1147 |
+
if output_type == "latent":
|
| 1148 |
+
image = latents
|
| 1149 |
+
else:
|
| 1150 |
+
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
| 1151 |
+
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
| 1152 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
| 1153 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
| 1154 |
+
|
| 1155 |
+
# Offload all models
|
| 1156 |
+
self.maybe_free_model_hooks()
|
| 1157 |
+
|
| 1158 |
+
if not return_dict:
|
| 1159 |
+
return (image,)
|
| 1160 |
+
|
| 1161 |
+
return FluxPipelineOutput(images=image)
|
unimodel/qwenkontext/qwenkontext_inference.py
ADDED
|
@@ -0,0 +1,442 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Fu-Yun Wang
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from typing import List, Optional, Tuple, Union, Dict
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
from PIL import Image
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoProcessor
|
| 21 |
+
from transformers import Qwen2_5_VLConfig, Qwen2_5_VLModel, Qwen2_5_VLForConditionalGeneration
|
| 22 |
+
from qwen_vl_utils import process_vision_info
|
| 23 |
+
import torchvision.transforms as transforms
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 27 |
+
from diffusers.pipelines.pipeline_utils import numpy_to_pil
|
| 28 |
+
import numpy as np
|
| 29 |
+
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler, FlowMatchEulerDiscreteSchedulerOutput
|
| 30 |
+
from diffusers.schedulers import DPMSolverMultistepScheduler
|
| 31 |
+
import math
|
| 32 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 33 |
+
from diffusers import FluxTransformer2DModel, AutoencoderKL, FlowMatchEulerDiscreteScheduler #, FluxKontextPipeline
|
| 34 |
+
from .fluxkontext_pipeline import FluxKontextPipeline
|
| 35 |
+
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast, CLIPTextConfig, T5Config
|
| 36 |
+
import re
|
| 37 |
+
import datetime
|
| 38 |
+
import os
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def save_grid_image(prompt, images, rows, cols):
|
| 42 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 43 |
+
base_dir = os.path.join("samples", timestamp, prompt[:100])
|
| 44 |
+
os.makedirs(base_dir, exist_ok=True)
|
| 45 |
+
|
| 46 |
+
filename = os.path.join(base_dir, "grid.jpg")
|
| 47 |
+
grid_image = create_image_grid(images, rows, cols)
|
| 48 |
+
grid_image.save(filename)
|
| 49 |
+
|
| 50 |
+
print(f"Saved: {filename}")
|
| 51 |
+
|
| 52 |
+
def create_image_grid(images, rows, cols):
|
| 53 |
+
"""Creates a grid of images and returns a single PIL Image."""
|
| 54 |
+
|
| 55 |
+
assert len(images) == rows * cols
|
| 56 |
+
|
| 57 |
+
width, height = images[0].size
|
| 58 |
+
grid_width = width * cols
|
| 59 |
+
grid_height = height * rows
|
| 60 |
+
|
| 61 |
+
grid_image = Image.new('RGB', (grid_width, grid_height))
|
| 62 |
+
|
| 63 |
+
for i, image in enumerate(images):
|
| 64 |
+
x = (i % cols) * width
|
| 65 |
+
y = (i // cols) * height
|
| 66 |
+
grid_image.paste(image, (x, y))
|
| 67 |
+
|
| 68 |
+
return grid_image
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def sde_step_with_logprob(
|
| 72 |
+
self: FlowMatchEulerDiscreteScheduler,
|
| 73 |
+
model_output: torch.FloatTensor,
|
| 74 |
+
timestep: Union[float, torch.FloatTensor],
|
| 75 |
+
sample: torch.FloatTensor,
|
| 76 |
+
prev_sample: Optional[torch.FloatTensor] = None,
|
| 77 |
+
generator: Optional[torch.Generator] = None,
|
| 78 |
+
) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
|
| 79 |
+
"""
|
| 80 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the flow
|
| 81 |
+
process from the learned model outputs (most often the predicted velocity).
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
model_output (`torch.FloatTensor`):
|
| 85 |
+
The direct output from learned flow model.
|
| 86 |
+
timestep (`float`):
|
| 87 |
+
The current discrete timestep in the diffusion chain.
|
| 88 |
+
sample (`torch.FloatTensor`):
|
| 89 |
+
A current instance of a sample created by the diffusion process.
|
| 90 |
+
generator (`torch.Generator`, *optional*):
|
| 91 |
+
A random number generator.
|
| 92 |
+
"""
|
| 93 |
+
step_index = [self.index_for_timestep(t) for t in timestep]
|
| 94 |
+
prev_step_index = [step+1 for step in step_index]
|
| 95 |
+
sigma = self.sigmas[step_index].view(-1, 1, 1).to(model_output.device)
|
| 96 |
+
sigma_prev = self.sigmas[prev_step_index].view(-1, 1, 1).to(model_output.device)
|
| 97 |
+
sigma_max = self.sigmas[1].item()
|
| 98 |
+
dt = sigma_prev - sigma
|
| 99 |
+
|
| 100 |
+
std_dev_t = torch.sqrt(sigma / (1 - torch.where(sigma == 1, sigma_max, sigma)))*1.0
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# our sde
|
| 104 |
+
prev_sample_mean = sample*(1+std_dev_t**2/(2*sigma)*dt)+model_output*(1+std_dev_t**2*(1-sigma)/(2*sigma))*dt
|
| 105 |
+
|
| 106 |
+
if prev_sample is not None and generator is not None:
|
| 107 |
+
raise ValueError(
|
| 108 |
+
"Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
|
| 109 |
+
" `prev_sample` stays `None`."
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
if prev_sample is None:
|
| 113 |
+
variance_noise = randn_tensor(
|
| 114 |
+
model_output.shape,
|
| 115 |
+
generator=generator,
|
| 116 |
+
device=model_output.device,
|
| 117 |
+
dtype=model_output.dtype,
|
| 118 |
+
)
|
| 119 |
+
prev_sample = prev_sample_mean + std_dev_t * torch.sqrt(-1*dt) * variance_noise
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
log_prob = (
|
| 123 |
+
-((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * ((std_dev_t * torch.sqrt(-1*dt))**2))
|
| 124 |
+
- torch.log(std_dev_t * torch.sqrt(-1*dt))
|
| 125 |
+
- torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
# mean along all but batch dimension
|
| 129 |
+
log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
|
| 130 |
+
|
| 131 |
+
return prev_sample, log_prob, prev_sample_mean, std_dev_t * torch.sqrt(-1*dt)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
# Copyright 2025 Fu-Yun Wang
|
| 135 |
+
#
|
| 136 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 137 |
+
# you may not use this file except in compliance with the License.
|
| 138 |
+
# You may obtain a copy of the License at
|
| 139 |
+
#
|
| 140 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 141 |
+
#
|
| 142 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 143 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 144 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 145 |
+
# See the License for the specific language governing permissions and
|
| 146 |
+
# limitations under the License.
|
| 147 |
+
|
| 148 |
+
def sde_step_with_logprob_simple(
|
| 149 |
+
self: FlowMatchEulerDiscreteScheduler,
|
| 150 |
+
model_output: torch.FloatTensor,
|
| 151 |
+
timestep: Union[float, torch.FloatTensor],
|
| 152 |
+
sample: torch.FloatTensor,
|
| 153 |
+
prev_sample: Optional[torch.FloatTensor] = None,
|
| 154 |
+
generator: Optional[torch.Generator] = None,
|
| 155 |
+
):
|
| 156 |
+
"""
|
| 157 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the flow
|
| 158 |
+
process from the learned model outputs (most often the predicted velocity).
|
| 159 |
+
|
| 160 |
+
Args:
|
| 161 |
+
model_output (`torch.FloatTensor`):
|
| 162 |
+
The direct output from learned flow model.
|
| 163 |
+
timestep (`float`):
|
| 164 |
+
The current discrete timestep in the diffusion chain.
|
| 165 |
+
sample (`torch.FloatTensor`):
|
| 166 |
+
A current instance of a sample created by the diffusion process.
|
| 167 |
+
generator (`torch.Generator`, *optional*):
|
| 168 |
+
A random number generator.
|
| 169 |
+
"""
|
| 170 |
+
|
| 171 |
+
step_index = [self.index_for_timestep(t) for t in timestep]
|
| 172 |
+
prev_step_index = [step+1 for step in step_index]
|
| 173 |
+
sigma = self.sigmas[step_index].view(-1, 1, 1, 1).to(model_output.device)
|
| 174 |
+
sigma_prev = self.sigmas[prev_step_index].view(-1, 1, 1, 1).to(model_output.device)
|
| 175 |
+
sigma_max = self.sigmas[1].item()
|
| 176 |
+
dt = sigma_prev - sigma
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
eta = 0.5
|
| 180 |
+
Dt = - dt * eta
|
| 181 |
+
|
| 182 |
+
prev_sample_mean = sample * (1 - Dt / (1 - torch.where(sigma == 1, sigma_max, sigma))) + model_output * (dt - Dt)
|
| 183 |
+
|
| 184 |
+
std_dev_t = torch.sqrt(2 * Dt * (sigma / (1 - torch.where(sigma == 1, sigma_max, sigma))))
|
| 185 |
+
|
| 186 |
+
if prev_sample is not None and generator is not None:
|
| 187 |
+
raise ValueError(
|
| 188 |
+
"Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
|
| 189 |
+
" `prev_sample` stays `None`."
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
if prev_sample is None:
|
| 193 |
+
# Generate noise if not provided
|
| 194 |
+
variance_noise = randn_tensor(
|
| 195 |
+
model_output.shape,
|
| 196 |
+
generator=generator,
|
| 197 |
+
device=model_output.device,
|
| 198 |
+
dtype=model_output.dtype,
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
prev_sample = prev_sample_mean + std_dev_t * variance_noise
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
log_prob = (
|
| 205 |
+
-((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2))
|
| 206 |
+
- torch.log(std_dev_t)
|
| 207 |
+
- torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
# mean along all but batch dimension
|
| 211 |
+
log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
|
| 212 |
+
|
| 213 |
+
return prev_sample, log_prob, prev_sample_mean, std_dev_t
|
| 214 |
+
|
| 215 |
+
class QwenKontextMetaModel:
|
| 216 |
+
|
| 217 |
+
def __init__(self, config):
|
| 218 |
+
super(QwenKontextMetaModel, self).__init__(config)
|
| 219 |
+
|
| 220 |
+
if hasattr(config, "diffusion_expert"):
|
| 221 |
+
ckpt_id = "black-forest-labs/FLUX.1-Kontext-dev"
|
| 222 |
+
# Load configuration for each component
|
| 223 |
+
transformer_config = FluxTransformer2DModel.load_config(ckpt_id, subfolder="transformer")
|
| 224 |
+
vae_config = AutoencoderKL.load_config(ckpt_id, subfolder="vae")
|
| 225 |
+
text_encoder_config = CLIPTextConfig.from_pretrained(ckpt_id, subfolder="text_encoder")
|
| 226 |
+
text_encoder_2_config = T5Config.from_pretrained(ckpt_id, subfolder="text_encoder_2")
|
| 227 |
+
|
| 228 |
+
# Initialize components from their configurations
|
| 229 |
+
self.transformer = FluxTransformer2DModel.from_config(transformer_config)
|
| 230 |
+
self.vae = AutoencoderKL.from_config(vae_config)
|
| 231 |
+
self.text_encoder = CLIPTextModel(text_encoder_config)
|
| 232 |
+
self.text_encoder_2 = T5EncoderModel(text_encoder_2_config)
|
| 233 |
+
|
| 234 |
+
# Initialize tokenizers (these don't use from_config as they are not models)
|
| 235 |
+
self.tokenizer = CLIPTokenizer.from_pretrained(ckpt_id, subfolder="tokenizer")
|
| 236 |
+
self.tokenizer_2 = T5TokenizerFast.from_pretrained(ckpt_id, subfolder="tokenizer_2")
|
| 237 |
+
|
| 238 |
+
self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(ckpt_id, subfolder="scheduler")
|
| 239 |
+
|
| 240 |
+
# Create the pipeline configuration dictionary
|
| 241 |
+
pipeline_config = {
|
| 242 |
+
"transformer": self.transformer,
|
| 243 |
+
"scheduler": self.scheduler,
|
| 244 |
+
"vae": self.vae,
|
| 245 |
+
"text_encoder": self.text_encoder,
|
| 246 |
+
"text_encoder_2": self.text_encoder_2,
|
| 247 |
+
"tokenizer": self.tokenizer,
|
| 248 |
+
"tokenizer_2": self.tokenizer_2,
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
self.diffusion_expert = FluxKontextPipeline(**pipeline_config)
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def initialize_diffusion_expert(self, fsdp=None):
|
| 255 |
+
|
| 256 |
+
if getattr(self, 'diffusion_expert', None) is None:
|
| 257 |
+
print("random initiation the diffusion expert !!!")
|
| 258 |
+
self.diffusion_expert = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", revision="main", torch_dtype=torch.bfloat16).to(torch.bfloat16)
|
| 259 |
+
self.text_encoder = self.diffusion_expert.text_encoder
|
| 260 |
+
self.text_encoder_2 = self.diffusion_expert.text_encoder_2
|
| 261 |
+
self.tokenizer = self.diffusion_expert.tokenizer
|
| 262 |
+
self.tokenizer_2 = self.diffusion_expert.tokenizer_2
|
| 263 |
+
self.vae = self.diffusion_expert.vae
|
| 264 |
+
self.transformer = self.diffusion_expert.transformer
|
| 265 |
+
self.scheduler = self.diffusion_expert.scheduler
|
| 266 |
+
|
| 267 |
+
self.config.diffusion_expert = "flux"
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
class QwenKontextConfig(Qwen2_5_VLConfig):
|
| 272 |
+
model_type = "QwenKontext"
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
class QwenKontextModel(QwenKontextMetaModel, Qwen2_5_VLModel):
|
| 276 |
+
config_class = QwenKontextConfig
|
| 277 |
+
|
| 278 |
+
def __init__(self, config: Qwen2_5_VLConfig):
|
| 279 |
+
super(QwenKontextModel, self).__init__(config)
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
class QwenKontextForInferenceLM(Qwen2_5_VLForConditionalGeneration):
|
| 283 |
+
config_class = QwenKontextConfig
|
| 284 |
+
|
| 285 |
+
def __init__(self, config):
|
| 286 |
+
Qwen2_5_VLForConditionalGeneration.__init__(self, config)
|
| 287 |
+
config.model_type = "QwenKontext"
|
| 288 |
+
|
| 289 |
+
self.model = QwenKontextModel(config)
|
| 290 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 291 |
+
self.post_init()
|
| 292 |
+
|
| 293 |
+
def get_model(self):
|
| 294 |
+
return self.model
|
| 295 |
+
|
| 296 |
+
@torch.no_grad()
|
| 297 |
+
def generate_image(
|
| 298 |
+
self,
|
| 299 |
+
images: List[Image.Image],
|
| 300 |
+
texts: List[str],
|
| 301 |
+
diffusion_kwargs: Optional[Dict] = dict(guidance_scale = 3.5, num_inference_steps=25),
|
| 302 |
+
sde_sampling: Optional[bool] = False,
|
| 303 |
+
):
|
| 304 |
+
|
| 305 |
+
if isinstance(texts, str):
|
| 306 |
+
texts = [texts]
|
| 307 |
+
|
| 308 |
+
if not sde_sampling:
|
| 309 |
+
output_img = self.model.diffusion_expert(
|
| 310 |
+
images,
|
| 311 |
+
texts,
|
| 312 |
+
max_sequence_length=512,
|
| 313 |
+
**diffusion_kwargs,
|
| 314 |
+
).images
|
| 315 |
+
return output_img
|
| 316 |
+
else:
|
| 317 |
+
return self.model.diffusion_expert.sde_sampling(
|
| 318 |
+
images,
|
| 319 |
+
texts,
|
| 320 |
+
max_sequence_length=512,
|
| 321 |
+
**diffusion_kwargs,
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def extract_thinking_content(self, text: str) -> str:
|
| 326 |
+
pattern = r'<answer>(.*?)</answer>'
|
| 327 |
+
matches = re.findall(pattern, text, re.DOTALL)
|
| 328 |
+
|
| 329 |
+
if matches:
|
| 330 |
+
return matches[-1].strip().replace("<answer>", "").replace("</answer>", "")
|
| 331 |
+
else:
|
| 332 |
+
return text.strip().replace("<answer>", "").replace("</answer>", "")
|
| 333 |
+
|
| 334 |
+
@torch.no_grad()
|
| 335 |
+
def generate_image_cot(
|
| 336 |
+
self,
|
| 337 |
+
images: List[Image.Image],
|
| 338 |
+
texts: List[str],
|
| 339 |
+
processor: Optional[object] = None,
|
| 340 |
+
diffusion_kwargs: Optional[Dict] = dict(guidance_scale = 2.5, num_inference_steps=25),
|
| 341 |
+
llm_kwargs: Optional[Dict] = dict(max_new_tokens=256, temperature=0.7, top_p=0.9, do_sample=True),
|
| 342 |
+
cot_prompt_template: Optional[str] = None,
|
| 343 |
+
):
|
| 344 |
+
|
| 345 |
+
if isinstance(texts, str):
|
| 346 |
+
texts = [texts]
|
| 347 |
+
|
| 348 |
+
if cot_prompt_template is None:
|
| 349 |
+
cot_prompt_template = """Please provide an enhanced prompt for the following image editing prompt.
|
| 350 |
+
Ensure the revised prompt is clear, specific, and includes detailed instructions to achieve the desired outcome while maintaining the original intent.
|
| 351 |
+
Original prompt: {original_prompt}. Directly provide the improved prompt in <answer> </answer> tags."""
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
improved_prompts = []
|
| 355 |
+
|
| 356 |
+
for text, image in zip(texts, images):
|
| 357 |
+
cot_input = cot_prompt_template.format(original_prompt=text)
|
| 358 |
+
|
| 359 |
+
messages = [
|
| 360 |
+
{
|
| 361 |
+
"role": "user",
|
| 362 |
+
"content": [
|
| 363 |
+
{
|
| 364 |
+
"type": "image",
|
| 365 |
+
"image": image,
|
| 366 |
+
},
|
| 367 |
+
{"type": "text", "text": cot_input},
|
| 368 |
+
],
|
| 369 |
+
}
|
| 370 |
+
]
|
| 371 |
+
|
| 372 |
+
input_text_formatted = processor.apply_chat_template(
|
| 373 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 374 |
+
)
|
| 375 |
+
image_inputs, video_inputs = process_vision_info(messages)
|
| 376 |
+
model_inputs = processor(
|
| 377 |
+
images=image_inputs,
|
| 378 |
+
text=[input_text_formatted],
|
| 379 |
+
return_tensors="pt"
|
| 380 |
+
).to(self.device)
|
| 381 |
+
|
| 382 |
+
generated_ids = self.generate(
|
| 383 |
+
**model_inputs,
|
| 384 |
+
**llm_kwargs,
|
| 385 |
+
eos_token_id=processor.tokenizer.eos_token_id,
|
| 386 |
+
pad_token_id=processor.tokenizer.pad_token_id
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
+
generated_text = processor.batch_decode(
|
| 390 |
+
generated_ids[:, model_inputs['input_ids'].shape[1]:],
|
| 391 |
+
skip_special_tokens=True
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
improved_prompt = [self.extract_thinking_content(decode_text) for decode_text in generated_text]
|
| 395 |
+
improved_prompts.extend(improved_prompt)
|
| 396 |
+
|
| 397 |
+
print(f"Original prompt: {text}")
|
| 398 |
+
print(f"Improved prompt: {improved_prompt}")
|
| 399 |
+
print("-" * 50)
|
| 400 |
+
|
| 401 |
+
output_images = self.generate_image(images, improved_prompts, diffusion_kwargs)
|
| 402 |
+
|
| 403 |
+
return {
|
| 404 |
+
'ref_images': images,
|
| 405 |
+
'images': output_images,
|
| 406 |
+
'original_prompts': texts,
|
| 407 |
+
'improved_prompts': improved_prompts
|
| 408 |
+
}
|
| 409 |
+
|
| 410 |
+
AutoConfig.register("QwenKontext", QwenKontextConfig)
|
| 411 |
+
AutoModelForCausalLM.register(QwenKontextConfig, QwenKontextForInferenceLM)
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
if __name__ == "__main__":
|
| 415 |
+
model = QwenKontextForInferenceLM.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct",torch_dtype=torch.bfloat16)
|
| 416 |
+
model.model.initialize_diffusion_expert()
|
| 417 |
+
model.model.diffusion_expert.to("cuda:0")
|
| 418 |
+
model.to("cuda:0")
|
| 419 |
+
AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
|
| 420 |
+
text = ["add a hat to him"]
|
| 421 |
+
ref_image = [Image.open("assets/images/cat.jpg").convert("RGB")]
|
| 422 |
+
images = model.generate_image(ref_image, text)
|
| 423 |
+
images[0].save("test_flux.jpg")
|
| 424 |
+
model.save_pretrained("outputs/pretrain/qwenkontext")
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
# model = QwenKontextForInferenceLM.from_pretrained("outputs/pretrain/qwenkontext", torch_dtype=torch.bfloat16)
|
| 428 |
+
# model.to("cuda:0")
|
| 429 |
+
# transform = transforms.Compose([
|
| 430 |
+
# transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR), # Shortest side to 512
|
| 431 |
+
# transforms.CenterCrop((512, 512)) # Center crop to 512x512
|
| 432 |
+
# ])
|
| 433 |
+
# processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
|
| 434 |
+
# text = ["add a hat to him"]
|
| 435 |
+
# ref_image = [transform(Image.open("assets/images/cat.jpg").convert("RGB"))]
|
| 436 |
+
# ref_image[0].save("ref.jpg")
|
| 437 |
+
# images = model.generate_image(ref_image, text)
|
| 438 |
+
# images[0].save("test_flux.jpg")
|
| 439 |
+
|
| 440 |
+
# outputs = model.generate_image_cot(ref_image, text, processor = processor)
|
| 441 |
+
# outputs['images'][0].save("test_flux_cot.jpg")
|
| 442 |
+
|
unimodel/qwensana/qwensana_inference.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Fu-Yun Wang
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
from typing import List, Optional, Tuple, Union, Dict
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
from PIL import Image
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoProcessor
|
| 22 |
+
from transformers import Qwen2_5_VLConfig, Qwen2_5_VLModel, Qwen2_5_VLForConditionalGeneration, T5Config, Gemma2Model, GemmaTokenizer, GemmaTokenizerFast, Gemma2Config, AutoConfig
|
| 23 |
+
from diffusers import SanaPipeline, AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaTransformer2DModel, DPMSolverMultistepScheduler
|
| 24 |
+
import re
|
| 25 |
+
import datetime
|
| 26 |
+
import os
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def save_grid_image(prompt, images, rows, cols):
|
| 30 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 31 |
+
base_dir = os.path.join("samples", timestamp, prompt[:100])
|
| 32 |
+
os.makedirs(base_dir, exist_ok=True)
|
| 33 |
+
|
| 34 |
+
filename = os.path.join(base_dir, "grid.jpg")
|
| 35 |
+
grid_image = create_image_grid(images, rows, cols)
|
| 36 |
+
grid_image.save(filename)
|
| 37 |
+
|
| 38 |
+
print(f"Saved: {filename}")
|
| 39 |
+
|
| 40 |
+
def create_image_grid(images, rows, cols):
|
| 41 |
+
"""Creates a grid of images and returns a single PIL Image."""
|
| 42 |
+
assert len(images) == rows * cols
|
| 43 |
+
|
| 44 |
+
width, height = images[0].size
|
| 45 |
+
grid_width = width * cols
|
| 46 |
+
grid_height = height * rows
|
| 47 |
+
|
| 48 |
+
grid_image = Image.new('RGB', (grid_width, grid_height))
|
| 49 |
+
|
| 50 |
+
for i, image in enumerate(images):
|
| 51 |
+
x = (i % cols) * width
|
| 52 |
+
y = (i // cols) * height
|
| 53 |
+
grid_image.paste(image, (x, y))
|
| 54 |
+
|
| 55 |
+
return grid_image
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class QwenSanaMetaModel:
|
| 59 |
+
|
| 60 |
+
def __init__(self, config):
|
| 61 |
+
super(QwenSanaMetaModel, self).__init__(config)
|
| 62 |
+
if hasattr(config, "diffusion_expert"):
|
| 63 |
+
ckpt_id = "Efficient-Large-Model/SANA1.5_1.6B_1024px_diffusers"
|
| 64 |
+
|
| 65 |
+
# Load configuration for each component
|
| 66 |
+
transformer_config = SanaTransformer2DModel.load_config(ckpt_id, subfolder="transformer")
|
| 67 |
+
vae_config = AutoencoderDC.load_config(ckpt_id, subfolder="vae")
|
| 68 |
+
text_encoder_config = Gemma2Config.from_pretrained(ckpt_id, subfolder="text_encoder")
|
| 69 |
+
scheduler_config = DPMSolverMultistepScheduler.load_config(ckpt_id, subfolder="scheduler")
|
| 70 |
+
# Initialize components from their configurations
|
| 71 |
+
self.transformer = SanaTransformer2DModel.from_config(transformer_config)
|
| 72 |
+
self.vae = AutoencoderDC.from_config(vae_config)
|
| 73 |
+
self.text_encoder = Gemma2Model(text_encoder_config)
|
| 74 |
+
self.scheduler = DPMSolverMultistepScheduler.from_config(scheduler_config)
|
| 75 |
+
|
| 76 |
+
# Initialize tokenizer
|
| 77 |
+
self.tokenizer = GemmaTokenizerFast.from_pretrained(ckpt_id, subfolder="tokenizer")
|
| 78 |
+
|
| 79 |
+
# Create the pipeline configuration dictionary
|
| 80 |
+
pipeline_config = {
|
| 81 |
+
"transformer": self.transformer,
|
| 82 |
+
"scheduler": self.scheduler,
|
| 83 |
+
"vae": self.vae,
|
| 84 |
+
"text_encoder": self.text_encoder,
|
| 85 |
+
"tokenizer": self.tokenizer,
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
self.diffusion_expert = SanaPipeline(**pipeline_config)
|
| 89 |
+
|
| 90 |
+
def initialize_diffusion_expert(self, fsdp=None):
|
| 91 |
+
|
| 92 |
+
if getattr(self, 'diffusion_expert', None) is None:
|
| 93 |
+
print("Random initiation the Sana diffusion expert !!!")
|
| 94 |
+
self.diffusion_expert = SanaPipeline.from_pretrained(
|
| 95 |
+
"Efficient-Large-Model/SANA1.5_1.6B_1024px_diffusers",
|
| 96 |
+
torch_dtype=torch.bfloat16
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
# Store references to components for easier access
|
| 100 |
+
self.transformer = self.diffusion_expert.transformer
|
| 101 |
+
self.vae = self.diffusion_expert.vae
|
| 102 |
+
self.text_encoder = self.diffusion_expert.text_encoder
|
| 103 |
+
self.tokenizer = self.diffusion_expert.tokenizer
|
| 104 |
+
self.scheduler = self.diffusion_expert.scheduler
|
| 105 |
+
|
| 106 |
+
self.config.diffusion_expert = "Sana"
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class QwenSanaConfig(Qwen2_5_VLConfig):
|
| 110 |
+
model_type = "QwenSana"
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class QwenSanaModel(QwenSanaMetaModel, Qwen2_5_VLModel):
|
| 114 |
+
config_class = QwenSanaConfig
|
| 115 |
+
|
| 116 |
+
def __init__(self, config: Qwen2_5_VLConfig):
|
| 117 |
+
super(QwenSanaModel, self).__init__(config)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class QwenSanaForInferenceLM(Qwen2_5_VLForConditionalGeneration):
|
| 121 |
+
config_class = QwenSanaConfig
|
| 122 |
+
|
| 123 |
+
def __init__(self, config):
|
| 124 |
+
Qwen2_5_VLForConditionalGeneration.__init__(self, config)
|
| 125 |
+
config.model_type = "QwenSana"
|
| 126 |
+
|
| 127 |
+
self.model = QwenSanaModel(config)
|
| 128 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 129 |
+
self.post_init()
|
| 130 |
+
|
| 131 |
+
def get_model(self):
|
| 132 |
+
return self.model
|
| 133 |
+
|
| 134 |
+
@torch.no_grad()
|
| 135 |
+
def generate_image(
|
| 136 |
+
self,
|
| 137 |
+
texts: List[str],
|
| 138 |
+
diffusion_kwargs: Optional[Dict] = None,
|
| 139 |
+
):
|
| 140 |
+
|
| 141 |
+
if isinstance(texts, str):
|
| 142 |
+
texts = [texts]
|
| 143 |
+
|
| 144 |
+
# Default parameters for Sana
|
| 145 |
+
default_kwargs = dict(
|
| 146 |
+
guidance_scale=3.5,
|
| 147 |
+
num_inference_steps=20,
|
| 148 |
+
height=1024,
|
| 149 |
+
width=1024
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
if diffusion_kwargs:
|
| 153 |
+
default_kwargs.update(diffusion_kwargs)
|
| 154 |
+
|
| 155 |
+
output_img = self.model.diffusion_expert(
|
| 156 |
+
texts,
|
| 157 |
+
**default_kwargs,
|
| 158 |
+
).images
|
| 159 |
+
|
| 160 |
+
return output_img
|
| 161 |
+
|
| 162 |
+
def extract_thinking_content(self, text: str) -> str:
|
| 163 |
+
pattern = r'<answer>(.*?)</answer>'
|
| 164 |
+
matches = re.findall(pattern, text, re.DOTALL)
|
| 165 |
+
|
| 166 |
+
if matches:
|
| 167 |
+
return matches[-1].strip().replace("<answer>", "").replace("</answer>", "")
|
| 168 |
+
else:
|
| 169 |
+
return text.strip().replace("<answer>", "").replace("</answer>", "")
|
| 170 |
+
|
| 171 |
+
@torch.no_grad()
|
| 172 |
+
def generate_image_cot(
|
| 173 |
+
self,
|
| 174 |
+
texts: List[str],
|
| 175 |
+
processor: Optional[object] = None,
|
| 176 |
+
diffusion_kwargs: Optional[Dict] = None,
|
| 177 |
+
llm_kwargs: Optional[Dict] = None,
|
| 178 |
+
cot_prompt_template: Optional[str] = None,
|
| 179 |
+
):
|
| 180 |
+
|
| 181 |
+
if isinstance(texts, str):
|
| 182 |
+
texts = [texts]
|
| 183 |
+
|
| 184 |
+
# Default parameters
|
| 185 |
+
default_diffusion_kwargs = dict(
|
| 186 |
+
guidance_scale=5.0,
|
| 187 |
+
num_inference_steps=20,
|
| 188 |
+
height=1024,
|
| 189 |
+
width=1024
|
| 190 |
+
)
|
| 191 |
+
if diffusion_kwargs:
|
| 192 |
+
default_diffusion_kwargs.update(diffusion_kwargs)
|
| 193 |
+
|
| 194 |
+
default_llm_kwargs = dict(
|
| 195 |
+
max_new_tokens=256,
|
| 196 |
+
temperature=0.7,
|
| 197 |
+
top_p=0.9,
|
| 198 |
+
do_sample=True
|
| 199 |
+
)
|
| 200 |
+
if llm_kwargs:
|
| 201 |
+
default_llm_kwargs.update(llm_kwargs)
|
| 202 |
+
|
| 203 |
+
if cot_prompt_template is None:
|
| 204 |
+
cot_prompt_template = """Please provide an enhanced prompt for the following image generation prompt to make the image more realistic, detailed, with clear separation and precise alignment of all entities.
|
| 205 |
+
Original prompt: {original_prompt}. Directly provide the improved prompt in <answer> </answer> tags."""
|
| 206 |
+
|
| 207 |
+
improved_prompts = []
|
| 208 |
+
|
| 209 |
+
for text in texts:
|
| 210 |
+
cot_input = cot_prompt_template.format(original_prompt=text)
|
| 211 |
+
|
| 212 |
+
messages = [{"role": "user", "content": cot_input}]
|
| 213 |
+
input_text_formatted = processor.apply_chat_template(
|
| 214 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 215 |
+
)
|
| 216 |
+
model_inputs = processor(
|
| 217 |
+
text=[input_text_formatted],
|
| 218 |
+
return_tensors="pt"
|
| 219 |
+
).to(self.device)
|
| 220 |
+
|
| 221 |
+
generated_ids = self.generate(
|
| 222 |
+
**model_inputs,
|
| 223 |
+
**default_llm_kwargs,
|
| 224 |
+
eos_token_id=processor.tokenizer.eos_token_id,
|
| 225 |
+
pad_token_id=processor.tokenizer.pad_token_id
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
generated_text = processor.batch_decode(
|
| 229 |
+
generated_ids[:, model_inputs['input_ids'].shape[1]:],
|
| 230 |
+
skip_special_tokens=True
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
improved_prompt = [self.extract_thinking_content(decode_text) for decode_text in generated_text]
|
| 234 |
+
improved_prompts.extend(improved_prompt)
|
| 235 |
+
|
| 236 |
+
print(f"Original prompt: {text}")
|
| 237 |
+
print(f"Improved prompt: {improved_prompt}")
|
| 238 |
+
print("-" * 50)
|
| 239 |
+
|
| 240 |
+
output_images = self.generate_image(improved_prompts, default_diffusion_kwargs)
|
| 241 |
+
|
| 242 |
+
return {
|
| 243 |
+
'images': output_images,
|
| 244 |
+
'original_prompts': texts,
|
| 245 |
+
'improved_prompts': improved_prompts
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
AutoConfig.register("QwenSana", QwenSanaConfig)
|
| 250 |
+
AutoModelForCausalLM.register(QwenSanaConfig, QwenSanaForInferenceLM)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
if __name__ == "__main__":
|
| 254 |
+
model = QwenSanaForInferenceLM.from_pretrained(
|
| 255 |
+
"Qwen/Qwen2.5-VL-3B-Instruct",
|
| 256 |
+
torch_dtype=torch.bfloat16
|
| 257 |
+
)
|
| 258 |
+
model.model.initialize_diffusion_expert()
|
| 259 |
+
model.model.diffusion_expert.to("cuda:0")
|
| 260 |
+
model.to("cuda:0")
|
| 261 |
+
|
| 262 |
+
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
|
| 263 |
+
|
| 264 |
+
# Test basic image generation
|
| 265 |
+
text = ["a photo of a cat"]
|
| 266 |
+
diffusion_kwargs = dict(
|
| 267 |
+
guidance_scale=3.5,
|
| 268 |
+
num_inference_steps=20,
|
| 269 |
+
width=1024,
|
| 270 |
+
height=1024,
|
| 271 |
+
generator=torch.manual_seed(0)
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
images = model.generate_image(text, diffusion_kwargs=diffusion_kwargs)
|
| 275 |
+
images[0].save("test_Sana.jpg")
|
| 276 |
+
|
| 277 |
+
# Test chain-of-thought image generation
|
| 278 |
+
outputs = model.generate_image_cot(text, processor=processor, diffusion_kwargs=diffusion_kwargs)
|
| 279 |
+
outputs['images'][0].save("test_Sana_cot.jpg")
|
| 280 |
+
|
| 281 |
+
# Save the model
|
| 282 |
+
model.save_pretrained("outputs/pretrain/qwenSana-1.5")
|
| 283 |
+
|
| 284 |
+
# print("Sana model integration completed successfully!")
|
| 285 |
+
|
| 286 |
+
# model = QwenSanaForInferenceLM.from_pretrained(
|
| 287 |
+
# "outputs/pretrain/qwenSana-1.5",
|
| 288 |
+
# torch_dtype=torch.bfloat16
|
| 289 |
+
# ).to("cuda")
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
# processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
|
| 293 |
+
|
| 294 |
+
# # Test basic image generation
|
| 295 |
+
# text = ["a photo of a cat"]
|
| 296 |
+
# diffusion_kwargs = dict(
|
| 297 |
+
# guidance_scale=5.0,
|
| 298 |
+
# num_inference_steps=20,
|
| 299 |
+
# width=1024,
|
| 300 |
+
# height=1024,
|
| 301 |
+
# generator=torch.manual_seed(0)
|
| 302 |
+
# )
|
| 303 |
+
|
| 304 |
+
# images = model.generate_image(text, diffusion_kwargs=diffusion_kwargs)
|
| 305 |
+
# images[0].save("test_Sana.jpg")
|
| 306 |
+
|
| 307 |
+
# # Test chain-of-thought image generation
|
| 308 |
+
# outputs = model.generate_image_cot(text, processor=processor, diffusion_kwargs=diffusion_kwargs)
|
| 309 |
+
# outputs['images'][0].save("test_Sana_cot.jpg")
|
| 310 |
+
|
unimodel/qwensd3/qwensd3_inference.py
ADDED
|
@@ -0,0 +1,447 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Fu-Yun Wang
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from typing import List, Optional, Tuple, Union, Dict
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
from PIL import Image
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoProcessor
|
| 21 |
+
from transformers import Qwen2_5_VLConfig, Qwen2_5_VLModel, Qwen2_5_VLForConditionalGeneration
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 25 |
+
from diffusers.pipelines.pipeline_utils import numpy_to_pil
|
| 26 |
+
import numpy as np
|
| 27 |
+
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler, FlowMatchEulerDiscreteSchedulerOutput
|
| 28 |
+
from diffusers.schedulers import DPMSolverMultistepScheduler
|
| 29 |
+
import math
|
| 30 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 31 |
+
from diffusers import SD3Transformer2DModel, AutoencoderKL, FlowMatchEulerDiscreteScheduler
|
| 32 |
+
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast, CLIPTextConfig, T5Config, CLIPTextModelWithProjection
|
| 33 |
+
try:
|
| 34 |
+
from .sd3pipeline import StableDiffusion3Pipeline as SD3Pipeline
|
| 35 |
+
except:
|
| 36 |
+
from sd3pipeline import StableDiffusion3Pipeline as SD3Pipeline
|
| 37 |
+
# from diffusers import StableDiffusion3Pipeline as SD3Pipeline
|
| 38 |
+
import re
|
| 39 |
+
import datetime
|
| 40 |
+
import os
|
| 41 |
+
from transformers import GenerationConfig
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def save_grid_image(prompt, images, rows, cols):
|
| 45 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 46 |
+
base_dir = os.path.join("samples", timestamp, prompt[:100])
|
| 47 |
+
os.makedirs(base_dir, exist_ok=True)
|
| 48 |
+
|
| 49 |
+
filename = os.path.join(base_dir, "grid.jpg")
|
| 50 |
+
grid_image = create_image_grid(images, rows, cols)
|
| 51 |
+
grid_image.save(filename)
|
| 52 |
+
|
| 53 |
+
print(f"Saved: {filename}")
|
| 54 |
+
|
| 55 |
+
def create_image_grid(images, rows, cols):
|
| 56 |
+
"""Creates a grid of images and returns a single PIL Image."""
|
| 57 |
+
|
| 58 |
+
assert len(images) == rows * cols
|
| 59 |
+
|
| 60 |
+
width, height = images[0].size
|
| 61 |
+
grid_width = width * cols
|
| 62 |
+
grid_height = height * rows
|
| 63 |
+
|
| 64 |
+
grid_image = Image.new('RGB', (grid_width, grid_height))
|
| 65 |
+
|
| 66 |
+
for i, image in enumerate(images):
|
| 67 |
+
x = (i % cols) * width
|
| 68 |
+
y = (i // cols) * height
|
| 69 |
+
grid_image.paste(image, (x, y))
|
| 70 |
+
|
| 71 |
+
return grid_image
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def sde_step_with_logprob(
|
| 75 |
+
self: FlowMatchEulerDiscreteScheduler,
|
| 76 |
+
model_output: torch.FloatTensor,
|
| 77 |
+
timestep: Union[float, torch.FloatTensor],
|
| 78 |
+
sample: torch.FloatTensor,
|
| 79 |
+
prev_sample: Optional[torch.FloatTensor] = None,
|
| 80 |
+
generator: Optional[torch.Generator] = None,
|
| 81 |
+
) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
|
| 82 |
+
"""
|
| 83 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the flow
|
| 84 |
+
process from the learned model outputs (most often the predicted velocity).
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
model_output (`torch.FloatTensor`):
|
| 88 |
+
The direct output from learned flow model.
|
| 89 |
+
timestep (`float`):
|
| 90 |
+
The current discrete timestep in the diffusion chain.
|
| 91 |
+
sample (`torch.FloatTensor`):
|
| 92 |
+
A current instance of a sample created by the diffusion process.
|
| 93 |
+
generator (`torch.Generator`, *optional*):
|
| 94 |
+
A random number generator.
|
| 95 |
+
"""
|
| 96 |
+
step_index = [self.index_for_timestep(t) for t in timestep]
|
| 97 |
+
prev_step_index = [step+1 for step in step_index]
|
| 98 |
+
sigma = self.sigmas[step_index].view(-1, 1, 1, 1).to(model_output.device)
|
| 99 |
+
sigma_prev = self.sigmas[prev_step_index].view(-1, 1, 1, 1).to(model_output.device)
|
| 100 |
+
sigma_max = self.sigmas[1].item()
|
| 101 |
+
dt = sigma_prev - sigma
|
| 102 |
+
|
| 103 |
+
std_dev_t = torch.sqrt(sigma / (1 - torch.where(sigma == 1, sigma_max, sigma)))*0.7
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# our sde
|
| 107 |
+
prev_sample_mean = sample*(1+std_dev_t**2/(2*sigma)*dt)+model_output*(1+std_dev_t**2*(1-sigma)/(2*sigma))*dt
|
| 108 |
+
|
| 109 |
+
if prev_sample is not None and generator is not None:
|
| 110 |
+
raise ValueError(
|
| 111 |
+
"Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
|
| 112 |
+
" `prev_sample` stays `None`."
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
if prev_sample is None:
|
| 116 |
+
variance_noise = randn_tensor(
|
| 117 |
+
model_output.shape,
|
| 118 |
+
generator=generator,
|
| 119 |
+
device=model_output.device,
|
| 120 |
+
dtype=model_output.dtype,
|
| 121 |
+
)
|
| 122 |
+
prev_sample = prev_sample_mean + std_dev_t * torch.sqrt(-1*dt) * variance_noise
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
log_prob = (
|
| 126 |
+
-((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * ((std_dev_t * torch.sqrt(-1*dt))**2))
|
| 127 |
+
- torch.log(std_dev_t * torch.sqrt(-1*dt))
|
| 128 |
+
- torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
# mean along all but batch dimension
|
| 132 |
+
log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
|
| 133 |
+
|
| 134 |
+
return prev_sample, log_prob, prev_sample_mean, std_dev_t * torch.sqrt(-1*dt)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
# Copyright 2025 Fu-Yun Wang
|
| 139 |
+
#
|
| 140 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 141 |
+
# you may not use this file except in compliance with the License.
|
| 142 |
+
# You may obtain a copy of the License at
|
| 143 |
+
#
|
| 144 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 145 |
+
#
|
| 146 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 147 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 148 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 149 |
+
# See the License for the specific language governing permissions and
|
| 150 |
+
# limitations under the License.
|
| 151 |
+
|
| 152 |
+
def sde_step_with_logprob_simple(
|
| 153 |
+
self: FlowMatchEulerDiscreteScheduler,
|
| 154 |
+
model_output: torch.FloatTensor,
|
| 155 |
+
timestep: Union[float, torch.FloatTensor],
|
| 156 |
+
sample: torch.FloatTensor,
|
| 157 |
+
prev_sample: Optional[torch.FloatTensor] = None,
|
| 158 |
+
generator: Optional[torch.Generator] = None,
|
| 159 |
+
):
|
| 160 |
+
"""
|
| 161 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the flow
|
| 162 |
+
process from the learned model outputs (most often the predicted velocity).
|
| 163 |
+
|
| 164 |
+
Args:
|
| 165 |
+
model_output (`torch.FloatTensor`):
|
| 166 |
+
The direct output from learned flow model.
|
| 167 |
+
timestep (`float`):
|
| 168 |
+
The current discrete timestep in the diffusion chain.
|
| 169 |
+
sample (`torch.FloatTensor`):
|
| 170 |
+
A current instance of a sample created by the diffusion process.
|
| 171 |
+
generator (`torch.Generator`, *optional*):
|
| 172 |
+
A random number generator.
|
| 173 |
+
"""
|
| 174 |
+
|
| 175 |
+
step_index = [self.index_for_timestep(t) for t in timestep]
|
| 176 |
+
prev_step_index = [step+1 for step in step_index]
|
| 177 |
+
sigma = self.sigmas[step_index].view(-1, 1, 1, 1).to(model_output.device)
|
| 178 |
+
sigma_prev = self.sigmas[prev_step_index].view(-1, 1, 1, 1).to(model_output.device)
|
| 179 |
+
sigma_max = self.sigmas[1].item()
|
| 180 |
+
dt = sigma_prev - sigma
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
eta = 0.5
|
| 184 |
+
Dt = - dt * eta
|
| 185 |
+
|
| 186 |
+
prev_sample_mean = sample * (1 - Dt / (1 - torch.where(sigma == 1, sigma_max, sigma))) + model_output * (dt - Dt)
|
| 187 |
+
|
| 188 |
+
std_dev_t = torch.sqrt(2 * Dt * (sigma / (1 - torch.where(sigma == 1, sigma_max, sigma))))
|
| 189 |
+
|
| 190 |
+
if prev_sample is not None and generator is not None:
|
| 191 |
+
raise ValueError(
|
| 192 |
+
"Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
|
| 193 |
+
" `prev_sample` stays `None`."
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
if prev_sample is None:
|
| 197 |
+
# Generate noise if not provided
|
| 198 |
+
variance_noise = randn_tensor(
|
| 199 |
+
model_output.shape,
|
| 200 |
+
generator=generator,
|
| 201 |
+
device=model_output.device,
|
| 202 |
+
dtype=model_output.dtype,
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
prev_sample = prev_sample_mean + std_dev_t * variance_noise
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
log_prob = (
|
| 209 |
+
-((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2))
|
| 210 |
+
- torch.log(std_dev_t)
|
| 211 |
+
- torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
# mean along all but batch dimension
|
| 215 |
+
log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
|
| 216 |
+
|
| 217 |
+
return prev_sample, log_prob, prev_sample_mean, std_dev_t
|
| 218 |
+
|
| 219 |
+
class QwenSD3MetaModel:
|
| 220 |
+
|
| 221 |
+
def __init__(self, config):
|
| 222 |
+
super(QwenSD3MetaModel, self).__init__(config)
|
| 223 |
+
if hasattr(config, "diffusion_expert"):
|
| 224 |
+
ckpt_id = "stabilityai/stable-diffusion-3.5-medium"
|
| 225 |
+
|
| 226 |
+
transformer_config = SD3Transformer2DModel.load_config(ckpt_id, subfolder="transformer")
|
| 227 |
+
vae_config = AutoencoderKL.load_config(ckpt_id, subfolder="vae")
|
| 228 |
+
text_encoder_config = CLIPTextConfig.from_pretrained(ckpt_id, subfolder="text_encoder", torch_dtype=config.torch_dtype)
|
| 229 |
+
text_encoder_2_config = CLIPTextConfig.from_pretrained(ckpt_id, subfolder="text_encoder_2", torch_dtype=config.torch_dtype)
|
| 230 |
+
text_encoder_3_config = T5Config.from_pretrained(ckpt_id, subfolder="text_encoder_3", torch_dtype=config.torch_dtype)
|
| 231 |
+
|
| 232 |
+
# Initialize components from their configurations
|
| 233 |
+
self.transformer = SD3Transformer2DModel.from_config(transformer_config)
|
| 234 |
+
self.vae = AutoencoderKL.from_config(vae_config)
|
| 235 |
+
self.text_encoder = CLIPTextModelWithProjection(text_encoder_config)
|
| 236 |
+
self.text_encoder_2 = CLIPTextModelWithProjection(text_encoder_2_config)
|
| 237 |
+
self.text_encoder_3 = T5EncoderModel(text_encoder_3_config)
|
| 238 |
+
|
| 239 |
+
# Initialize tokenizers (these don't use from_config as they are not models)
|
| 240 |
+
self.tokenizer = CLIPTokenizer.from_pretrained(ckpt_id, subfolder="tokenizer")
|
| 241 |
+
self.tokenizer_2 = CLIPTokenizer.from_pretrained(ckpt_id, subfolder="tokenizer_2")
|
| 242 |
+
self.tokenizer_3 = T5TokenizerFast.from_pretrained(ckpt_id, subfolder="tokenizer_3")
|
| 243 |
+
self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(ckpt_id, subfolder="scheduler")
|
| 244 |
+
|
| 245 |
+
# Create the pipeline configuration dictionary
|
| 246 |
+
pipeline_config = {
|
| 247 |
+
"transformer": self.transformer,
|
| 248 |
+
"scheduler": self.scheduler,
|
| 249 |
+
"vae": self.vae,
|
| 250 |
+
"text_encoder": self.text_encoder,
|
| 251 |
+
"text_encoder_2": self.text_encoder_2,
|
| 252 |
+
"text_encoder_3": self.text_encoder_3,
|
| 253 |
+
"tokenizer": self.tokenizer,
|
| 254 |
+
"tokenizer_2": self.tokenizer_2,
|
| 255 |
+
"tokenizer_3": self.tokenizer_3,
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
self.diffusion_expert = SD3Pipeline(**pipeline_config)
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def initialize_diffusion_expert(self, fsdp=None):
|
| 262 |
+
|
| 263 |
+
print("random initiation the diffusion expert !!!")
|
| 264 |
+
self.diffusion_expert = SD3Pipeline.from_pretrained("stabilityai/stable-diffusion-3.5-medium", revision="main", torch_dtype=torch.bfloat16)
|
| 265 |
+
self.text_encoder = self.diffusion_expert.text_encoder
|
| 266 |
+
self.text_encoder_model = self.diffusion_expert.text_encoder.text_model
|
| 267 |
+
self.text_encoder_2 = self.diffusion_expert.text_encoder_2
|
| 268 |
+
self.text_encoder_2_model = self.diffusion_expert.text_encoder_2.text_model
|
| 269 |
+
self.text_encoder_3 = self.diffusion_expert.text_encoder_3
|
| 270 |
+
self.tokenizer = self.diffusion_expert.tokenizer
|
| 271 |
+
self.tokenizer_2 = self.diffusion_expert.tokenizer_2
|
| 272 |
+
self.tokenizer_3 = self.diffusion_expert.tokenizer_3
|
| 273 |
+
self.vae = self.diffusion_expert.vae
|
| 274 |
+
self.transformer = self.diffusion_expert.transformer
|
| 275 |
+
self.scheduler = self.diffusion_expert.scheduler
|
| 276 |
+
|
| 277 |
+
self.config.diffusion_expert = "SD3"
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
class QwenSD3Config(Qwen2_5_VLConfig):
|
| 282 |
+
model_type = "QwenSD3"
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
class QwenSD3Model(QwenSD3MetaModel, Qwen2_5_VLModel):
|
| 286 |
+
config_class = QwenSD3Config
|
| 287 |
+
|
| 288 |
+
def __init__(self, config: Qwen2_5_VLConfig):
|
| 289 |
+
super(QwenSD3Model, self).__init__(config)
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
class QwenSD3ForInferenceLM(Qwen2_5_VLForConditionalGeneration):
|
| 293 |
+
config_class = QwenSD3Config
|
| 294 |
+
|
| 295 |
+
def __init__(self, config):
|
| 296 |
+
Qwen2_5_VLForConditionalGeneration.__init__(self, config)
|
| 297 |
+
config.model_type = "QwenSD3"
|
| 298 |
+
|
| 299 |
+
self.model = QwenSD3Model(config)
|
| 300 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 301 |
+
self.post_init()
|
| 302 |
+
|
| 303 |
+
def get_model(self):
|
| 304 |
+
return self.model
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
@torch.no_grad()
|
| 309 |
+
def generate_image(
|
| 310 |
+
self,
|
| 311 |
+
texts: List[str],
|
| 312 |
+
diffusion_kwargs: Optional[Dict] = dict(guidance_scale = 3.5, num_inference_steps=25),
|
| 313 |
+
sde_sampling: Optional[bool] = False,
|
| 314 |
+
):
|
| 315 |
+
|
| 316 |
+
if isinstance(texts, str):
|
| 317 |
+
texts = [texts]
|
| 318 |
+
|
| 319 |
+
if not sde_sampling:
|
| 320 |
+
output_img = self.model.diffusion_expert(
|
| 321 |
+
texts,
|
| 322 |
+
max_sequence_length=512,
|
| 323 |
+
**diffusion_kwargs,
|
| 324 |
+
).images
|
| 325 |
+
return output_img
|
| 326 |
+
else:
|
| 327 |
+
return self.model.diffusion_expert.sde_sampling(
|
| 328 |
+
texts,
|
| 329 |
+
max_sequence_length=512,
|
| 330 |
+
**diffusion_kwargs,
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def extract_thinking_content(self, text: str) -> str:
|
| 335 |
+
pattern = r'<answer>(.*?)</answer>'
|
| 336 |
+
matches = re.findall(pattern, text, re.DOTALL)
|
| 337 |
+
|
| 338 |
+
if matches:
|
| 339 |
+
return matches[-1].strip().replace("<answer>", "").replace("</answer>", "")
|
| 340 |
+
else:
|
| 341 |
+
return text.strip().replace("<answer>", "").replace("</answer>", "")
|
| 342 |
+
|
| 343 |
+
@torch.no_grad()
|
| 344 |
+
def generate_image_cot(
|
| 345 |
+
self,
|
| 346 |
+
texts: List[str],
|
| 347 |
+
processor: Optional[object] = None,
|
| 348 |
+
diffusion_kwargs: Optional[Dict] = dict(guidance_scale = 3.5, num_inference_steps=25),
|
| 349 |
+
llm_kwargs: Optional[Dict] = dict(max_new_tokens=256, temperature=0.7, top_p=0.9, do_sample=True),
|
| 350 |
+
cot_prompt_template: Optional[str] = None,
|
| 351 |
+
):
|
| 352 |
+
|
| 353 |
+
if isinstance(texts, str):
|
| 354 |
+
texts = [texts]
|
| 355 |
+
|
| 356 |
+
if cot_prompt_template is None:
|
| 357 |
+
# cot_prompt_template = """Please improve the following image generation prompt to make it more detailed and specific for better image quality. Think step by step about what visual elements would make this image more compelling. Original prompt: {original_prompt}. Please provide the improved prompt in <thinking> </thinking> tags."""
|
| 358 |
+
cot_prompt_template = """Please provide an enhanced prompt for the following image generation prompt to make the image more realistic, detailed, with clear separation and precise alignment of all entities.
|
| 359 |
+
Original prompt: {original_prompt}. Directly provide the improved prompt in <answer> </answer> tags."""
|
| 360 |
+
|
| 361 |
+
improved_prompts = []
|
| 362 |
+
|
| 363 |
+
for text in texts:
|
| 364 |
+
cot_input = cot_prompt_template.format(original_prompt=text)
|
| 365 |
+
|
| 366 |
+
messages = [{"role": "user", "content": cot_input}]
|
| 367 |
+
input_text_formatted = processor.apply_chat_template(
|
| 368 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 369 |
+
)
|
| 370 |
+
model_inputs = processor(
|
| 371 |
+
text=[input_text_formatted],
|
| 372 |
+
return_tensors="pt"
|
| 373 |
+
).to(self.device)
|
| 374 |
+
|
| 375 |
+
generated_ids = self.generate(
|
| 376 |
+
**model_inputs,
|
| 377 |
+
**llm_kwargs,
|
| 378 |
+
eos_token_id=processor.tokenizer.eos_token_id,
|
| 379 |
+
pad_token_id=processor.tokenizer.pad_token_id
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
generated_text = processor.batch_decode(
|
| 383 |
+
generated_ids[:, model_inputs['input_ids'].shape[1]:],
|
| 384 |
+
skip_special_tokens=True
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
improved_prompt = [self.extract_thinking_content(decode_text) for decode_text in generated_text]
|
| 388 |
+
improved_prompts.extend(improved_prompt)
|
| 389 |
+
|
| 390 |
+
print(f"Original prompt: {text}")
|
| 391 |
+
print(f"Improved prompt: {improved_prompt}")
|
| 392 |
+
print("-" * 50)
|
| 393 |
+
|
| 394 |
+
output_images = self.generate_image(improved_prompts, diffusion_kwargs)
|
| 395 |
+
|
| 396 |
+
return {
|
| 397 |
+
'images': output_images,
|
| 398 |
+
'original_prompts': texts,
|
| 399 |
+
'improved_prompts': improved_prompts
|
| 400 |
+
}
|
| 401 |
+
|
| 402 |
+
AutoConfig.register("QwenSD3", QwenSD3Config)
|
| 403 |
+
AutoModelForCausalLM.register(QwenSD3Config, QwenSD3ForInferenceLM)
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
if __name__ == "__main__":
|
| 407 |
+
pass
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
model = QwenSD3ForInferenceLM.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct",torch_dtype=torch.bfloat16)
|
| 411 |
+
model.model.initialize_diffusion_expert()
|
| 412 |
+
model.model.diffusion_expert.to("cuda:0")
|
| 413 |
+
model.to("cuda:0")
|
| 414 |
+
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
|
| 415 |
+
text = ["a photo of a cat"]
|
| 416 |
+
images = model.generate_image(text)
|
| 417 |
+
images[0].save("test_SD3.jpg")
|
| 418 |
+
outputs = model.generate_image_cot(text, processor = processor)
|
| 419 |
+
outputs['images'][0].save("test_SD3_cot.jpg")
|
| 420 |
+
|
| 421 |
+
model.save_pretrained("qwensd3")
|
| 422 |
+
|
| 423 |
+
# model = QwenSD3ForInferenceLM.from_pretrained("qwenSD3.0", torch_dtype=torch.bfloat16)
|
| 424 |
+
# model.to("cuda:0")
|
| 425 |
+
# model.save_pretrained("qwenSD3-test-2", torch_dtype=torch.bfloat16)
|
| 426 |
+
|
| 427 |
+
# model = QwenSD3ForInferenceLM.from_pretrained("qwenSD3-test", torch_dtype=torch.float16)
|
| 428 |
+
# # model.to("cuda:0")
|
| 429 |
+
# for n, p in model.named_parameters():
|
| 430 |
+
# if not p.dtype == torch.float16:
|
| 431 |
+
# print(n)
|
| 432 |
+
# processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
|
| 433 |
+
# text = ["a photo of a cat"]
|
| 434 |
+
# diffusion_kwargs = dict(guidance_scale = 5., num_inference_steps=20, width = 512, height = 512, generator = torch.manual_seed(0))
|
| 435 |
+
# images = model.generate_image(text, diffusion_kwargs=diffusion_kwargs)
|
| 436 |
+
# images[0].save("test_SD3.jpg")
|
| 437 |
+
|
| 438 |
+
# llm_kwargs = dict(max_new_tokens=256, temperature=0.7, top_p=0.9, do_sample=True, num_return_sequences=8)
|
| 439 |
+
# # generation_config = GenerationConfig.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", trust_remote_code=True)
|
| 440 |
+
# # generation_config.num_return_sequences = 8
|
| 441 |
+
# # print(generation_config)
|
| 442 |
+
# # llm_kwargs = dict(max_new_tokens=256, temperature=0.7, top_p=0.9, do_sample=True, generation_config=generation_config)
|
| 443 |
+
|
| 444 |
+
# outputs = model.generate_image_cot(text, processor = processor, llm_kwargs = llm_kwargs)
|
| 445 |
+
# # save_grid_image("cat", images['images'], 2, 2)
|
| 446 |
+
# for idx, image in enumerate(outputs['images']):
|
| 447 |
+
# image.save(f"test_SD3_cot_{idx}.jpg")
|
unimodel/qwensd3/sd3pipeline.py
ADDED
|
@@ -0,0 +1,1162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
|
| 2 |
+
# Copyright 2025 Fu-Yun Wang
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import inspect
|
| 17 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from transformers import (
|
| 21 |
+
CLIPTextModelWithProjection,
|
| 22 |
+
CLIPTokenizer,
|
| 23 |
+
SiglipImageProcessor,
|
| 24 |
+
SiglipVisionModel,
|
| 25 |
+
T5EncoderModel,
|
| 26 |
+
T5TokenizerFast,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
| 30 |
+
from diffusers.loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
|
| 31 |
+
from diffusers.models.autoencoders import AutoencoderKL
|
| 32 |
+
from diffusers.models.transformers import SD3Transformer2DModel
|
| 33 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 34 |
+
from diffusers.utils import (
|
| 35 |
+
USE_PEFT_BACKEND,
|
| 36 |
+
is_torch_xla_available,
|
| 37 |
+
logging,
|
| 38 |
+
replace_example_docstring,
|
| 39 |
+
scale_lora_layers,
|
| 40 |
+
unscale_lora_layers,
|
| 41 |
+
)
|
| 42 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 43 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 44 |
+
from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
|
| 45 |
+
import deepspeed
|
| 46 |
+
from PIL import Image
|
| 47 |
+
import numpy as np
|
| 48 |
+
if is_torch_xla_available():
|
| 49 |
+
import torch_xla.core.xla_model as xm
|
| 50 |
+
|
| 51 |
+
XLA_AVAILABLE = True
|
| 52 |
+
else:
|
| 53 |
+
XLA_AVAILABLE = False
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 57 |
+
|
| 58 |
+
EXAMPLE_DOC_STRING = """
|
| 59 |
+
Examples:
|
| 60 |
+
```py
|
| 61 |
+
>>> import torch
|
| 62 |
+
>>> from diffusers import StableDiffusion3Pipeline
|
| 63 |
+
|
| 64 |
+
>>> pipe = StableDiffusion3Pipeline.from_pretrained(
|
| 65 |
+
... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
|
| 66 |
+
... )
|
| 67 |
+
>>> pipe.to("cuda")
|
| 68 |
+
>>> prompt = "A cat holding a sign that says hello world"
|
| 69 |
+
>>> image = pipe(prompt).images[0]
|
| 70 |
+
>>> image.save("sd3.png")
|
| 71 |
+
```
|
| 72 |
+
"""
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
| 76 |
+
def calculate_shift(
|
| 77 |
+
image_seq_len,
|
| 78 |
+
base_seq_len: int = 256,
|
| 79 |
+
max_seq_len: int = 4096,
|
| 80 |
+
base_shift: float = 0.5,
|
| 81 |
+
max_shift: float = 1.15,
|
| 82 |
+
):
|
| 83 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
| 84 |
+
b = base_shift - m * base_seq_len
|
| 85 |
+
mu = image_seq_len * m + b
|
| 86 |
+
return mu
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 90 |
+
def retrieve_timesteps(
|
| 91 |
+
scheduler,
|
| 92 |
+
num_inference_steps: Optional[int] = None,
|
| 93 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 94 |
+
timesteps: Optional[List[int]] = None,
|
| 95 |
+
sigmas: Optional[List[float]] = None,
|
| 96 |
+
**kwargs,
|
| 97 |
+
):
|
| 98 |
+
r"""
|
| 99 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 100 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 101 |
+
|
| 102 |
+
Args:
|
| 103 |
+
scheduler (`SchedulerMixin`):
|
| 104 |
+
The scheduler to get timesteps from.
|
| 105 |
+
num_inference_steps (`int`):
|
| 106 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 107 |
+
must be `None`.
|
| 108 |
+
device (`str` or `torch.device`, *optional*):
|
| 109 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 110 |
+
timesteps (`List[int]`, *optional*):
|
| 111 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 112 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 113 |
+
sigmas (`List[float]`, *optional*):
|
| 114 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 115 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 116 |
+
|
| 117 |
+
Returns:
|
| 118 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 119 |
+
second element is the number of inference steps.
|
| 120 |
+
"""
|
| 121 |
+
if timesteps is not None and sigmas is not None:
|
| 122 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 123 |
+
if timesteps is not None:
|
| 124 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 125 |
+
if not accepts_timesteps:
|
| 126 |
+
raise ValueError(
|
| 127 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 128 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 129 |
+
)
|
| 130 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 131 |
+
timesteps = scheduler.timesteps
|
| 132 |
+
num_inference_steps = len(timesteps)
|
| 133 |
+
elif sigmas is not None:
|
| 134 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 135 |
+
if not accept_sigmas:
|
| 136 |
+
raise ValueError(
|
| 137 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 138 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 139 |
+
)
|
| 140 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 141 |
+
timesteps = scheduler.timesteps
|
| 142 |
+
num_inference_steps = len(timesteps)
|
| 143 |
+
else:
|
| 144 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 145 |
+
timesteps = scheduler.timesteps
|
| 146 |
+
return timesteps, num_inference_steps
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
|
| 150 |
+
r"""
|
| 151 |
+
Args:
|
| 152 |
+
transformer ([`SD3Transformer2DModel`]):
|
| 153 |
+
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
| 154 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
| 155 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
| 156 |
+
vae ([`AutoencoderKL`]):
|
| 157 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
| 158 |
+
text_encoder ([`CLIPTextModelWithProjection`]):
|
| 159 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
|
| 160 |
+
specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
|
| 161 |
+
with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
|
| 162 |
+
as its dimension.
|
| 163 |
+
text_encoder_2 ([`CLIPTextModelWithProjection`]):
|
| 164 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
|
| 165 |
+
specifically the
|
| 166 |
+
[laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
|
| 167 |
+
variant.
|
| 168 |
+
text_encoder_3 ([`T5EncoderModel`]):
|
| 169 |
+
Frozen text-encoder. Stable Diffusion 3 uses
|
| 170 |
+
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
|
| 171 |
+
[t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
| 172 |
+
tokenizer (`CLIPTokenizer`):
|
| 173 |
+
Tokenizer of class
|
| 174 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
| 175 |
+
tokenizer_2 (`CLIPTokenizer`):
|
| 176 |
+
Second Tokenizer of class
|
| 177 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
| 178 |
+
tokenizer_3 (`T5TokenizerFast`):
|
| 179 |
+
Tokenizer of class
|
| 180 |
+
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
| 181 |
+
image_encoder (`SiglipVisionModel`, *optional*):
|
| 182 |
+
Pre-trained Vision Model for IP Adapter.
|
| 183 |
+
feature_extractor (`SiglipImageProcessor`, *optional*):
|
| 184 |
+
Image processor for IP Adapter.
|
| 185 |
+
"""
|
| 186 |
+
|
| 187 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
|
| 188 |
+
_optional_components = ["image_encoder", "feature_extractor"]
|
| 189 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
|
| 190 |
+
|
| 191 |
+
def __init__(
|
| 192 |
+
self,
|
| 193 |
+
transformer: SD3Transformer2DModel,
|
| 194 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 195 |
+
vae: AutoencoderKL,
|
| 196 |
+
text_encoder: CLIPTextModelWithProjection,
|
| 197 |
+
tokenizer: CLIPTokenizer,
|
| 198 |
+
text_encoder_2: CLIPTextModelWithProjection,
|
| 199 |
+
tokenizer_2: CLIPTokenizer,
|
| 200 |
+
text_encoder_3: T5EncoderModel,
|
| 201 |
+
tokenizer_3: T5TokenizerFast,
|
| 202 |
+
image_encoder: SiglipVisionModel = None,
|
| 203 |
+
feature_extractor: SiglipImageProcessor = None,
|
| 204 |
+
):
|
| 205 |
+
super().__init__()
|
| 206 |
+
|
| 207 |
+
self.register_modules(
|
| 208 |
+
vae=vae,
|
| 209 |
+
text_encoder=text_encoder,
|
| 210 |
+
text_encoder_2=text_encoder_2,
|
| 211 |
+
text_encoder_3=text_encoder_3,
|
| 212 |
+
tokenizer=tokenizer,
|
| 213 |
+
tokenizer_2=tokenizer_2,
|
| 214 |
+
tokenizer_3=tokenizer_3,
|
| 215 |
+
transformer=transformer,
|
| 216 |
+
scheduler=scheduler,
|
| 217 |
+
image_encoder=image_encoder,
|
| 218 |
+
feature_extractor=feature_extractor,
|
| 219 |
+
)
|
| 220 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
| 221 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
| 222 |
+
self.tokenizer_max_length = (
|
| 223 |
+
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
| 224 |
+
)
|
| 225 |
+
self.default_sample_size = (
|
| 226 |
+
self.transformer.config.sample_size
|
| 227 |
+
if hasattr(self, "transformer") and self.transformer is not None
|
| 228 |
+
else 128
|
| 229 |
+
)
|
| 230 |
+
self.patch_size = (
|
| 231 |
+
self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
def _get_t5_prompt_embeds(
|
| 235 |
+
self,
|
| 236 |
+
prompt: Union[str, List[str]] = None,
|
| 237 |
+
num_images_per_prompt: int = 1,
|
| 238 |
+
max_sequence_length: int = 256,
|
| 239 |
+
device: Optional[torch.device] = None,
|
| 240 |
+
dtype: Optional[torch.dtype] = None,
|
| 241 |
+
):
|
| 242 |
+
device = device or self._execution_device
|
| 243 |
+
dtype = dtype or self.text_encoder.dtype
|
| 244 |
+
|
| 245 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 246 |
+
batch_size = len(prompt)
|
| 247 |
+
|
| 248 |
+
if self.text_encoder_3 is None:
|
| 249 |
+
return torch.zeros(
|
| 250 |
+
(
|
| 251 |
+
batch_size * num_images_per_prompt,
|
| 252 |
+
self.tokenizer_max_length,
|
| 253 |
+
self.transformer.config.joint_attention_dim,
|
| 254 |
+
),
|
| 255 |
+
device=device,
|
| 256 |
+
dtype=dtype,
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
text_inputs = self.tokenizer_3(
|
| 260 |
+
prompt,
|
| 261 |
+
padding="max_length",
|
| 262 |
+
max_length=max_sequence_length,
|
| 263 |
+
truncation=True,
|
| 264 |
+
add_special_tokens=True,
|
| 265 |
+
return_tensors="pt",
|
| 266 |
+
)
|
| 267 |
+
text_input_ids = text_inputs.input_ids
|
| 268 |
+
untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
|
| 269 |
+
|
| 270 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 271 |
+
removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
| 272 |
+
logger.warning(
|
| 273 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
| 274 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
|
| 278 |
+
|
| 279 |
+
dtype = self.text_encoder_3.dtype
|
| 280 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 281 |
+
|
| 282 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 283 |
+
|
| 284 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
| 285 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 286 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 287 |
+
|
| 288 |
+
return prompt_embeds
|
| 289 |
+
|
| 290 |
+
def _get_clip_prompt_embeds(
|
| 291 |
+
self,
|
| 292 |
+
prompt: Union[str, List[str]],
|
| 293 |
+
num_images_per_prompt: int = 1,
|
| 294 |
+
device: Optional[torch.device] = None,
|
| 295 |
+
clip_skip: Optional[int] = None,
|
| 296 |
+
clip_model_index: int = 0,
|
| 297 |
+
):
|
| 298 |
+
device = device or self._execution_device
|
| 299 |
+
|
| 300 |
+
clip_tokenizers = [self.tokenizer, self.tokenizer_2]
|
| 301 |
+
clip_text_encoders = [self.text_encoder, self.text_encoder_2]
|
| 302 |
+
|
| 303 |
+
tokenizer = clip_tokenizers[clip_model_index]
|
| 304 |
+
text_encoder = clip_text_encoders[clip_model_index]
|
| 305 |
+
|
| 306 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 307 |
+
batch_size = len(prompt)
|
| 308 |
+
|
| 309 |
+
text_inputs = tokenizer(
|
| 310 |
+
prompt,
|
| 311 |
+
padding="max_length",
|
| 312 |
+
max_length=self.tokenizer_max_length,
|
| 313 |
+
truncation=True,
|
| 314 |
+
return_tensors="pt",
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
text_input_ids = text_inputs.input_ids
|
| 318 |
+
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 319 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 320 |
+
removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
| 321 |
+
logger.warning(
|
| 322 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
| 323 |
+
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
| 324 |
+
)
|
| 325 |
+
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
|
| 326 |
+
pooled_prompt_embeds = prompt_embeds[0]
|
| 327 |
+
|
| 328 |
+
if clip_skip is None:
|
| 329 |
+
prompt_embeds = prompt_embeds.hidden_states[-2]
|
| 330 |
+
else:
|
| 331 |
+
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
|
| 332 |
+
|
| 333 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
| 334 |
+
|
| 335 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 336 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 337 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 338 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 339 |
+
|
| 340 |
+
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 341 |
+
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
| 342 |
+
|
| 343 |
+
return prompt_embeds, pooled_prompt_embeds
|
| 344 |
+
|
| 345 |
+
def encode_prompt(
|
| 346 |
+
self,
|
| 347 |
+
prompt: Union[str, List[str]],
|
| 348 |
+
prompt_2: Union[str, List[str]],
|
| 349 |
+
prompt_3: Union[str, List[str]],
|
| 350 |
+
device: Optional[torch.device] = None,
|
| 351 |
+
num_images_per_prompt: int = 1,
|
| 352 |
+
do_classifier_free_guidance: bool = True,
|
| 353 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 354 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
| 355 |
+
negative_prompt_3: Optional[Union[str, List[str]]] = None,
|
| 356 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 357 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 358 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 359 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 360 |
+
clip_skip: Optional[int] = None,
|
| 361 |
+
max_sequence_length: int = 256,
|
| 362 |
+
lora_scale: Optional[float] = None,
|
| 363 |
+
):
|
| 364 |
+
r"""
|
| 365 |
+
|
| 366 |
+
Args:
|
| 367 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 368 |
+
prompt to be encoded
|
| 369 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
| 370 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
| 371 |
+
used in all text-encoders
|
| 372 |
+
prompt_3 (`str` or `List[str]`, *optional*):
|
| 373 |
+
The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
|
| 374 |
+
used in all text-encoders
|
| 375 |
+
device: (`torch.device`):
|
| 376 |
+
torch device
|
| 377 |
+
num_images_per_prompt (`int`):
|
| 378 |
+
number of images that should be generated per prompt
|
| 379 |
+
do_classifier_free_guidance (`bool`):
|
| 380 |
+
whether to use classifier free guidance or not
|
| 381 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 382 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 383 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 384 |
+
less than `1`).
|
| 385 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
| 386 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
| 387 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
|
| 388 |
+
negative_prompt_3 (`str` or `List[str]`, *optional*):
|
| 389 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
|
| 390 |
+
`text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.
|
| 391 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 392 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 393 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 394 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 395 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 396 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 397 |
+
argument.
|
| 398 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 399 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
| 400 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
| 401 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 402 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 403 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
| 404 |
+
input argument.
|
| 405 |
+
clip_skip (`int`, *optional*):
|
| 406 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
| 407 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
| 408 |
+
lora_scale (`float`, *optional*):
|
| 409 |
+
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
| 410 |
+
"""
|
| 411 |
+
device = device or self._execution_device
|
| 412 |
+
|
| 413 |
+
# set lora scale so that monkey patched LoRA
|
| 414 |
+
# function of text encoder can correctly access it
|
| 415 |
+
if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
|
| 416 |
+
self._lora_scale = lora_scale
|
| 417 |
+
|
| 418 |
+
# dynamically adjust the LoRA scale
|
| 419 |
+
if self.text_encoder is not None and USE_PEFT_BACKEND:
|
| 420 |
+
scale_lora_layers(self.text_encoder, lora_scale)
|
| 421 |
+
if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
|
| 422 |
+
scale_lora_layers(self.text_encoder_2, lora_scale)
|
| 423 |
+
|
| 424 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 425 |
+
if prompt is not None:
|
| 426 |
+
batch_size = len(prompt)
|
| 427 |
+
else:
|
| 428 |
+
batch_size = prompt_embeds.shape[0]
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
if prompt_embeds is None:
|
| 432 |
+
prompt_2 = prompt_2 or prompt
|
| 433 |
+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
| 434 |
+
|
| 435 |
+
prompt_3 = prompt_3 or prompt
|
| 436 |
+
prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
|
| 437 |
+
|
| 438 |
+
prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
|
| 439 |
+
prompt=prompt,
|
| 440 |
+
device=device,
|
| 441 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 442 |
+
clip_skip=clip_skip,
|
| 443 |
+
clip_model_index=0,
|
| 444 |
+
)
|
| 445 |
+
|
| 446 |
+
prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
|
| 447 |
+
prompt=prompt_2,
|
| 448 |
+
device=device,
|
| 449 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 450 |
+
clip_skip=clip_skip,
|
| 451 |
+
clip_model_index=1,
|
| 452 |
+
)
|
| 453 |
+
|
| 454 |
+
clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
|
| 455 |
+
|
| 456 |
+
t5_prompt_embed = self._get_t5_prompt_embeds(
|
| 457 |
+
prompt=prompt_3,
|
| 458 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 459 |
+
max_sequence_length=max_sequence_length,
|
| 460 |
+
device=device,
|
| 461 |
+
|
| 462 |
+
)
|
| 463 |
+
|
| 464 |
+
clip_prompt_embeds = torch.nn.functional.pad(
|
| 465 |
+
clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
|
| 466 |
+
)
|
| 467 |
+
|
| 468 |
+
prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
|
| 469 |
+
pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 473 |
+
negative_prompt = negative_prompt or ""
|
| 474 |
+
negative_prompt_2 = negative_prompt_2 or negative_prompt
|
| 475 |
+
negative_prompt_3 = negative_prompt_3 or negative_prompt
|
| 476 |
+
|
| 477 |
+
# normalize str to list
|
| 478 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
| 479 |
+
negative_prompt_2 = (
|
| 480 |
+
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
|
| 481 |
+
)
|
| 482 |
+
negative_prompt_3 = (
|
| 483 |
+
batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
| 487 |
+
raise TypeError(
|
| 488 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 489 |
+
f" {type(prompt)}."
|
| 490 |
+
)
|
| 491 |
+
elif batch_size != len(negative_prompt):
|
| 492 |
+
raise ValueError(
|
| 493 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 494 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 495 |
+
" the batch size of `prompt`."
|
| 496 |
+
)
|
| 497 |
+
# with deepspeed.zero.GatheredParameters(self.text_encoder.parameters()):
|
| 498 |
+
# negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
|
| 499 |
+
# negative_prompt,
|
| 500 |
+
# device=device,
|
| 501 |
+
# num_images_per_prompt=num_images_per_prompt,
|
| 502 |
+
# clip_skip=None,
|
| 503 |
+
# clip_model_index=0,
|
| 504 |
+
# )
|
| 505 |
+
# with deepspeed.zero.GatheredParameters(self.text_encoder_2.parameters()):
|
| 506 |
+
# negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
|
| 507 |
+
# negative_prompt_2,
|
| 508 |
+
# device=device,
|
| 509 |
+
# num_images_per_prompt=num_images_per_prompt,
|
| 510 |
+
# clip_skip=None,
|
| 511 |
+
# clip_model_index=1,
|
| 512 |
+
# )
|
| 513 |
+
negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
|
| 514 |
+
negative_prompt,
|
| 515 |
+
device=device,
|
| 516 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 517 |
+
clip_skip=None,
|
| 518 |
+
clip_model_index=0,
|
| 519 |
+
)
|
| 520 |
+
negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
|
| 521 |
+
negative_prompt_2,
|
| 522 |
+
device=device,
|
| 523 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 524 |
+
clip_skip=None,
|
| 525 |
+
clip_model_index=1,
|
| 526 |
+
)
|
| 527 |
+
negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
|
| 528 |
+
|
| 529 |
+
t5_negative_prompt_embed = self._get_t5_prompt_embeds(
|
| 530 |
+
prompt=negative_prompt_3,
|
| 531 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 532 |
+
max_sequence_length=max_sequence_length,
|
| 533 |
+
device=device,
|
| 534 |
+
)
|
| 535 |
+
|
| 536 |
+
negative_clip_prompt_embeds = torch.nn.functional.pad(
|
| 537 |
+
negative_clip_prompt_embeds,
|
| 538 |
+
(0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
|
| 539 |
+
)
|
| 540 |
+
|
| 541 |
+
negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
|
| 542 |
+
negative_pooled_prompt_embeds = torch.cat(
|
| 543 |
+
[negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
|
| 544 |
+
)
|
| 545 |
+
|
| 546 |
+
if self.text_encoder is not None:
|
| 547 |
+
if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
|
| 548 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
| 549 |
+
unscale_lora_layers(self.text_encoder, lora_scale)
|
| 550 |
+
|
| 551 |
+
if self.text_encoder_2 is not None:
|
| 552 |
+
if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
|
| 553 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
| 554 |
+
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
| 555 |
+
|
| 556 |
+
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
| 557 |
+
|
| 558 |
+
def check_inputs(
|
| 559 |
+
self,
|
| 560 |
+
prompt,
|
| 561 |
+
prompt_2,
|
| 562 |
+
prompt_3,
|
| 563 |
+
height,
|
| 564 |
+
width,
|
| 565 |
+
negative_prompt=None,
|
| 566 |
+
negative_prompt_2=None,
|
| 567 |
+
negative_prompt_3=None,
|
| 568 |
+
prompt_embeds=None,
|
| 569 |
+
negative_prompt_embeds=None,
|
| 570 |
+
pooled_prompt_embeds=None,
|
| 571 |
+
negative_pooled_prompt_embeds=None,
|
| 572 |
+
callback_on_step_end_tensor_inputs=None,
|
| 573 |
+
max_sequence_length=None,
|
| 574 |
+
):
|
| 575 |
+
if (
|
| 576 |
+
height % (self.vae_scale_factor * self.patch_size) != 0
|
| 577 |
+
or width % (self.vae_scale_factor * self.patch_size) != 0
|
| 578 |
+
):
|
| 579 |
+
raise ValueError(
|
| 580 |
+
f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}."
|
| 581 |
+
f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}."
|
| 582 |
+
)
|
| 583 |
+
|
| 584 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 585 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 586 |
+
):
|
| 587 |
+
raise ValueError(
|
| 588 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 589 |
+
)
|
| 590 |
+
|
| 591 |
+
if prompt is not None and prompt_embeds is not None:
|
| 592 |
+
raise ValueError(
|
| 593 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 594 |
+
" only forward one of the two."
|
| 595 |
+
)
|
| 596 |
+
elif prompt_2 is not None and prompt_embeds is not None:
|
| 597 |
+
raise ValueError(
|
| 598 |
+
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 599 |
+
" only forward one of the two."
|
| 600 |
+
)
|
| 601 |
+
elif prompt_3 is not None and prompt_embeds is not None:
|
| 602 |
+
raise ValueError(
|
| 603 |
+
f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 604 |
+
" only forward one of the two."
|
| 605 |
+
)
|
| 606 |
+
elif prompt is None and prompt_embeds is None:
|
| 607 |
+
raise ValueError(
|
| 608 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 609 |
+
)
|
| 610 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 611 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 612 |
+
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
| 613 |
+
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
| 614 |
+
elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
|
| 615 |
+
raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
|
| 616 |
+
|
| 617 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 618 |
+
raise ValueError(
|
| 619 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 620 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 621 |
+
)
|
| 622 |
+
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
|
| 623 |
+
raise ValueError(
|
| 624 |
+
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
|
| 625 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 626 |
+
)
|
| 627 |
+
elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
|
| 628 |
+
raise ValueError(
|
| 629 |
+
f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
|
| 630 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 631 |
+
)
|
| 632 |
+
|
| 633 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 634 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 635 |
+
raise ValueError(
|
| 636 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 637 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 638 |
+
f" {negative_prompt_embeds.shape}."
|
| 639 |
+
)
|
| 640 |
+
|
| 641 |
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
| 642 |
+
raise ValueError(
|
| 643 |
+
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
| 644 |
+
)
|
| 645 |
+
|
| 646 |
+
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
|
| 647 |
+
raise ValueError(
|
| 648 |
+
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
| 649 |
+
)
|
| 650 |
+
|
| 651 |
+
if max_sequence_length is not None and max_sequence_length > 512:
|
| 652 |
+
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
| 653 |
+
|
| 654 |
+
def prepare_latents(
|
| 655 |
+
self,
|
| 656 |
+
batch_size,
|
| 657 |
+
num_channels_latents,
|
| 658 |
+
height,
|
| 659 |
+
width,
|
| 660 |
+
dtype,
|
| 661 |
+
device,
|
| 662 |
+
generator,
|
| 663 |
+
latents=None,
|
| 664 |
+
):
|
| 665 |
+
if latents is not None:
|
| 666 |
+
return latents.to(device=device, dtype=dtype)
|
| 667 |
+
|
| 668 |
+
shape = (
|
| 669 |
+
batch_size,
|
| 670 |
+
num_channels_latents,
|
| 671 |
+
int(height) // self.vae_scale_factor,
|
| 672 |
+
int(width) // self.vae_scale_factor,
|
| 673 |
+
)
|
| 674 |
+
|
| 675 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 676 |
+
raise ValueError(
|
| 677 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 678 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 679 |
+
)
|
| 680 |
+
|
| 681 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 682 |
+
|
| 683 |
+
return latents
|
| 684 |
+
|
| 685 |
+
@property
|
| 686 |
+
def guidance_scale(self):
|
| 687 |
+
return self._guidance_scale
|
| 688 |
+
|
| 689 |
+
@property
|
| 690 |
+
def skip_guidance_layers(self):
|
| 691 |
+
return self._skip_guidance_layers
|
| 692 |
+
|
| 693 |
+
@property
|
| 694 |
+
def clip_skip(self):
|
| 695 |
+
return self._clip_skip
|
| 696 |
+
|
| 697 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 698 |
+
# of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
|
| 699 |
+
# corresponds to doing no classifier free guidance.
|
| 700 |
+
@property
|
| 701 |
+
def do_classifier_free_guidance(self):
|
| 702 |
+
return self._guidance_scale > 1
|
| 703 |
+
|
| 704 |
+
@property
|
| 705 |
+
def joint_attention_kwargs(self):
|
| 706 |
+
return self._joint_attention_kwargs
|
| 707 |
+
|
| 708 |
+
@property
|
| 709 |
+
def num_timesteps(self):
|
| 710 |
+
return self._num_timesteps
|
| 711 |
+
|
| 712 |
+
@property
|
| 713 |
+
def interrupt(self):
|
| 714 |
+
return self._interrupt
|
| 715 |
+
|
| 716 |
+
def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
|
| 717 |
+
"""Encodes the given image into a feature representation using a pre-trained image encoder.
|
| 718 |
+
|
| 719 |
+
Args:
|
| 720 |
+
image (`PipelineImageInput`):
|
| 721 |
+
Input image to be encoded.
|
| 722 |
+
device: (`torch.device`):
|
| 723 |
+
Torch device.
|
| 724 |
+
|
| 725 |
+
Returns:
|
| 726 |
+
`torch.Tensor`: The encoded image feature representation.
|
| 727 |
+
"""
|
| 728 |
+
if not isinstance(image, torch.Tensor):
|
| 729 |
+
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
| 730 |
+
|
| 731 |
+
image = image.to(device=device, dtype=self.dtype)
|
| 732 |
+
|
| 733 |
+
return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
| 734 |
+
|
| 735 |
+
def prepare_ip_adapter_image_embeds(
|
| 736 |
+
self,
|
| 737 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
| 738 |
+
ip_adapter_image_embeds: Optional[torch.Tensor] = None,
|
| 739 |
+
device: Optional[torch.device] = None,
|
| 740 |
+
num_images_per_prompt: int = 1,
|
| 741 |
+
do_classifier_free_guidance: bool = True,
|
| 742 |
+
) -> torch.Tensor:
|
| 743 |
+
"""Prepares image embeddings for use in the IP-Adapter.
|
| 744 |
+
|
| 745 |
+
Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
|
| 746 |
+
|
| 747 |
+
Args:
|
| 748 |
+
ip_adapter_image (`PipelineImageInput`, *optional*):
|
| 749 |
+
The input image to extract features from for IP-Adapter.
|
| 750 |
+
ip_adapter_image_embeds (`torch.Tensor`, *optional*):
|
| 751 |
+
Precomputed image embeddings.
|
| 752 |
+
device: (`torch.device`, *optional*):
|
| 753 |
+
Torch device.
|
| 754 |
+
num_images_per_prompt (`int`, defaults to 1):
|
| 755 |
+
Number of images that should be generated per prompt.
|
| 756 |
+
do_classifier_free_guidance (`bool`, defaults to True):
|
| 757 |
+
Whether to use classifier free guidance or not.
|
| 758 |
+
"""
|
| 759 |
+
device = device or self._execution_device
|
| 760 |
+
|
| 761 |
+
if ip_adapter_image_embeds is not None:
|
| 762 |
+
if do_classifier_free_guidance:
|
| 763 |
+
single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
|
| 764 |
+
else:
|
| 765 |
+
single_image_embeds = ip_adapter_image_embeds
|
| 766 |
+
elif ip_adapter_image is not None:
|
| 767 |
+
single_image_embeds = self.encode_image(ip_adapter_image, device)
|
| 768 |
+
if do_classifier_free_guidance:
|
| 769 |
+
single_negative_image_embeds = torch.zeros_like(single_image_embeds)
|
| 770 |
+
else:
|
| 771 |
+
raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.")
|
| 772 |
+
|
| 773 |
+
image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
|
| 774 |
+
|
| 775 |
+
if do_classifier_free_guidance:
|
| 776 |
+
negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
|
| 777 |
+
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
|
| 778 |
+
|
| 779 |
+
return image_embeds.to(device=device)
|
| 780 |
+
|
| 781 |
+
def enable_sequential_cpu_offload(self, *args, **kwargs):
|
| 782 |
+
if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
|
| 783 |
+
logger.warning(
|
| 784 |
+
"`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
|
| 785 |
+
"`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
|
| 786 |
+
"`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
|
| 787 |
+
)
|
| 788 |
+
|
| 789 |
+
super().enable_sequential_cpu_offload(*args, **kwargs)
|
| 790 |
+
|
| 791 |
+
@torch.no_grad()
|
| 792 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 793 |
+
def __call__(
|
| 794 |
+
self,
|
| 795 |
+
prompt: Union[str, List[str]] = None,
|
| 796 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
| 797 |
+
prompt_3: Optional[Union[str, List[str]]] = None,
|
| 798 |
+
height: Optional[int] = None,
|
| 799 |
+
width: Optional[int] = None,
|
| 800 |
+
num_inference_steps: int = 28,
|
| 801 |
+
sigmas: Optional[List[float]] = None,
|
| 802 |
+
guidance_scale: float = 7.0,
|
| 803 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 804 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
| 805 |
+
negative_prompt_3: Optional[Union[str, List[str]]] = None,
|
| 806 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 807 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 808 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 809 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 810 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 811 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 812 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 813 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
| 814 |
+
ip_adapter_image_embeds: Optional[torch.Tensor] = None,
|
| 815 |
+
output_type: Optional[str] = "pil",
|
| 816 |
+
return_dict: bool = True,
|
| 817 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 818 |
+
clip_skip: Optional[int] = None,
|
| 819 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 820 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 821 |
+
max_sequence_length: int = 256,
|
| 822 |
+
skip_guidance_layers: List[int] = None,
|
| 823 |
+
skip_layer_guidance_scale: float = 2.8,
|
| 824 |
+
skip_layer_guidance_stop: float = 0.2,
|
| 825 |
+
skip_layer_guidance_start: float = 0.01,
|
| 826 |
+
mu: Optional[float] = None,
|
| 827 |
+
):
|
| 828 |
+
r"""
|
| 829 |
+
Function invoked when calling the pipeline for generation.
|
| 830 |
+
|
| 831 |
+
Args:
|
| 832 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 833 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
| 834 |
+
instead.
|
| 835 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
| 836 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
| 837 |
+
will be used instead
|
| 838 |
+
prompt_3 (`str` or `List[str]`, *optional*):
|
| 839 |
+
The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
|
| 840 |
+
will be used instead
|
| 841 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 842 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 843 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 844 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 845 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 846 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 847 |
+
expense of slower inference.
|
| 848 |
+
sigmas (`List[float]`, *optional*):
|
| 849 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
| 850 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
| 851 |
+
will be used.
|
| 852 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
| 853 |
+
Guidance scale as defined in [Classifier-Free Diffusion
|
| 854 |
+
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
| 855 |
+
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
| 856 |
+
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
| 857 |
+
the text `prompt`, usually at the expense of lower image quality.
|
| 858 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 859 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 860 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 861 |
+
less than `1`).
|
| 862 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
| 863 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
| 864 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used instead
|
| 865 |
+
negative_prompt_3 (`str` or `List[str]`, *optional*):
|
| 866 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
|
| 867 |
+
`text_encoder_3`. If not defined, `negative_prompt` is used instead
|
| 868 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 869 |
+
The number of images to generate per prompt.
|
| 870 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 871 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 872 |
+
to make generation deterministic.
|
| 873 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 874 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 875 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 876 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
| 877 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 878 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 879 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 880 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 881 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 882 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 883 |
+
argument.
|
| 884 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 885 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
| 886 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
| 887 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 888 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 889 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
| 890 |
+
input argument.
|
| 891 |
+
ip_adapter_image (`PipelineImageInput`, *optional*):
|
| 892 |
+
Optional image input to work with IP Adapters.
|
| 893 |
+
ip_adapter_image_embeds (`torch.Tensor`, *optional*):
|
| 894 |
+
Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
|
| 895 |
+
emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
|
| 896 |
+
`True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
|
| 897 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 898 |
+
The output format of the generate image. Choose between
|
| 899 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 900 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 901 |
+
Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
|
| 902 |
+
a plain tuple.
|
| 903 |
+
joint_attention_kwargs (`dict`, *optional*):
|
| 904 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 905 |
+
`self.processor` in
|
| 906 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 907 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 908 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
| 909 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 910 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 911 |
+
`callback_on_step_end_tensor_inputs`.
|
| 912 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 913 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 914 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 915 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 916 |
+
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
|
| 917 |
+
skip_guidance_layers (`List[int]`, *optional*):
|
| 918 |
+
A list of integers that specify layers to skip during guidance. If not provided, all layers will be
|
| 919 |
+
used for guidance. If provided, the guidance will only be applied to the layers specified in the list.
|
| 920 |
+
Recommended value by StabiltyAI for Stable Diffusion 3.5 Medium is [7, 8, 9].
|
| 921 |
+
skip_layer_guidance_scale (`int`, *optional*): The scale of the guidance for the layers specified in
|
| 922 |
+
`skip_guidance_layers`. The guidance will be applied to the layers specified in `skip_guidance_layers`
|
| 923 |
+
with a scale of `skip_layer_guidance_scale`. The guidance will be applied to the rest of the layers
|
| 924 |
+
with a scale of `1`.
|
| 925 |
+
skip_layer_guidance_stop (`int`, *optional*): The step at which the guidance for the layers specified in
|
| 926 |
+
`skip_guidance_layers` will stop. The guidance will be applied to the layers specified in
|
| 927 |
+
`skip_guidance_layers` until the fraction specified in `skip_layer_guidance_stop`. Recommended value by
|
| 928 |
+
StabiltyAI for Stable Diffusion 3.5 Medium is 0.2.
|
| 929 |
+
skip_layer_guidance_start (`int`, *optional*): The step at which the guidance for the layers specified in
|
| 930 |
+
`skip_guidance_layers` will start. The guidance will be applied to the layers specified in
|
| 931 |
+
`skip_guidance_layers` from the fraction specified in `skip_layer_guidance_start`. Recommended value by
|
| 932 |
+
StabiltyAI for Stable Diffusion 3.5 Medium is 0.01.
|
| 933 |
+
mu (`float`, *optional*): `mu` value used for `dynamic_shifting`.
|
| 934 |
+
|
| 935 |
+
Examples:
|
| 936 |
+
|
| 937 |
+
Returns:
|
| 938 |
+
[`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
|
| 939 |
+
[`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
|
| 940 |
+
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
| 941 |
+
"""
|
| 942 |
+
|
| 943 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
| 944 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
| 945 |
+
|
| 946 |
+
# 1. Check inputs. Raise error if not correct
|
| 947 |
+
self.check_inputs(
|
| 948 |
+
prompt,
|
| 949 |
+
prompt_2,
|
| 950 |
+
prompt_3,
|
| 951 |
+
height,
|
| 952 |
+
width,
|
| 953 |
+
negative_prompt=negative_prompt,
|
| 954 |
+
negative_prompt_2=negative_prompt_2,
|
| 955 |
+
negative_prompt_3=negative_prompt_3,
|
| 956 |
+
prompt_embeds=prompt_embeds,
|
| 957 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 958 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 959 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
| 960 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
| 961 |
+
max_sequence_length=max_sequence_length,
|
| 962 |
+
)
|
| 963 |
+
|
| 964 |
+
self._guidance_scale = guidance_scale
|
| 965 |
+
self._skip_layer_guidance_scale = skip_layer_guidance_scale
|
| 966 |
+
self._clip_skip = clip_skip
|
| 967 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
| 968 |
+
self._interrupt = False
|
| 969 |
+
|
| 970 |
+
# 2. Define call parameters
|
| 971 |
+
if prompt is not None and isinstance(prompt, str):
|
| 972 |
+
batch_size = 1
|
| 973 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 974 |
+
batch_size = len(prompt)
|
| 975 |
+
else:
|
| 976 |
+
batch_size = prompt_embeds.shape[0]
|
| 977 |
+
|
| 978 |
+
device = self._execution_device
|
| 979 |
+
|
| 980 |
+
lora_scale = (
|
| 981 |
+
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
| 982 |
+
)
|
| 983 |
+
(
|
| 984 |
+
prompt_embeds,
|
| 985 |
+
negative_prompt_embeds,
|
| 986 |
+
pooled_prompt_embeds,
|
| 987 |
+
negative_pooled_prompt_embeds,
|
| 988 |
+
) = self.encode_prompt(
|
| 989 |
+
prompt=prompt,
|
| 990 |
+
prompt_2=prompt_2,
|
| 991 |
+
prompt_3=prompt_3,
|
| 992 |
+
negative_prompt=negative_prompt,
|
| 993 |
+
negative_prompt_2=negative_prompt_2,
|
| 994 |
+
negative_prompt_3=negative_prompt_3,
|
| 995 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
| 996 |
+
prompt_embeds=prompt_embeds,
|
| 997 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 998 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 999 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
| 1000 |
+
device=device,
|
| 1001 |
+
clip_skip=self.clip_skip,
|
| 1002 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 1003 |
+
max_sequence_length=max_sequence_length,
|
| 1004 |
+
lora_scale=lora_scale,
|
| 1005 |
+
)
|
| 1006 |
+
|
| 1007 |
+
if self.do_classifier_free_guidance:
|
| 1008 |
+
if skip_guidance_layers is not None:
|
| 1009 |
+
original_prompt_embeds = prompt_embeds
|
| 1010 |
+
original_pooled_prompt_embeds = pooled_prompt_embeds
|
| 1011 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
| 1012 |
+
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
|
| 1013 |
+
|
| 1014 |
+
# 4. Prepare latent variables
|
| 1015 |
+
num_channels_latents = self.transformer.config.in_channels
|
| 1016 |
+
latents = self.prepare_latents(
|
| 1017 |
+
batch_size * num_images_per_prompt,
|
| 1018 |
+
num_channels_latents,
|
| 1019 |
+
height,
|
| 1020 |
+
width,
|
| 1021 |
+
prompt_embeds.dtype,
|
| 1022 |
+
device,
|
| 1023 |
+
generator,
|
| 1024 |
+
latents,
|
| 1025 |
+
)
|
| 1026 |
+
|
| 1027 |
+
# 5. Prepare timesteps
|
| 1028 |
+
scheduler_kwargs = {}
|
| 1029 |
+
if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
|
| 1030 |
+
_, _, height, width = latents.shape
|
| 1031 |
+
image_seq_len = (height // self.transformer.config.patch_size) * (
|
| 1032 |
+
width // self.transformer.config.patch_size
|
| 1033 |
+
)
|
| 1034 |
+
mu = calculate_shift(
|
| 1035 |
+
image_seq_len,
|
| 1036 |
+
self.scheduler.config.get("base_image_seq_len", 256),
|
| 1037 |
+
self.scheduler.config.get("max_image_seq_len", 4096),
|
| 1038 |
+
self.scheduler.config.get("base_shift", 0.5),
|
| 1039 |
+
self.scheduler.config.get("max_shift", 1.16),
|
| 1040 |
+
)
|
| 1041 |
+
scheduler_kwargs["mu"] = mu
|
| 1042 |
+
elif mu is not None:
|
| 1043 |
+
scheduler_kwargs["mu"] = mu
|
| 1044 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 1045 |
+
self.scheduler,
|
| 1046 |
+
num_inference_steps,
|
| 1047 |
+
device,
|
| 1048 |
+
sigmas=sigmas,
|
| 1049 |
+
**scheduler_kwargs,
|
| 1050 |
+
)
|
| 1051 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 1052 |
+
self._num_timesteps = len(timesteps)
|
| 1053 |
+
|
| 1054 |
+
# 6. Prepare image embeddings
|
| 1055 |
+
if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
|
| 1056 |
+
ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
|
| 1057 |
+
ip_adapter_image,
|
| 1058 |
+
ip_adapter_image_embeds,
|
| 1059 |
+
device,
|
| 1060 |
+
batch_size * num_images_per_prompt,
|
| 1061 |
+
self.do_classifier_free_guidance,
|
| 1062 |
+
)
|
| 1063 |
+
|
| 1064 |
+
if self.joint_attention_kwargs is None:
|
| 1065 |
+
self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
|
| 1066 |
+
else:
|
| 1067 |
+
self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
|
| 1068 |
+
|
| 1069 |
+
|
| 1070 |
+
# 7. Denoising loop
|
| 1071 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 1072 |
+
for i, t in enumerate(timesteps):
|
| 1073 |
+
if self.interrupt:
|
| 1074 |
+
continue
|
| 1075 |
+
|
| 1076 |
+
# expand the latents if we are doing classifier free guidance
|
| 1077 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
| 1078 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 1079 |
+
timestep = t.expand(latent_model_input.shape[0])
|
| 1080 |
+
|
| 1081 |
+
noise_pred = self.transformer(
|
| 1082 |
+
hidden_states=latent_model_input,
|
| 1083 |
+
timestep=timestep,
|
| 1084 |
+
encoder_hidden_states=prompt_embeds,
|
| 1085 |
+
pooled_projections=pooled_prompt_embeds,
|
| 1086 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
| 1087 |
+
return_dict=False,
|
| 1088 |
+
)[0]
|
| 1089 |
+
|
| 1090 |
+
# perform guidance
|
| 1091 |
+
if self.do_classifier_free_guidance:
|
| 1092 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 1093 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 1094 |
+
should_skip_layers = (
|
| 1095 |
+
True
|
| 1096 |
+
if i > num_inference_steps * skip_layer_guidance_start
|
| 1097 |
+
and i < num_inference_steps * skip_layer_guidance_stop
|
| 1098 |
+
else False
|
| 1099 |
+
)
|
| 1100 |
+
if skip_guidance_layers is not None and should_skip_layers:
|
| 1101 |
+
timestep = t.expand(latents.shape[0])
|
| 1102 |
+
latent_model_input = latents
|
| 1103 |
+
noise_pred_skip_layers = self.transformer(
|
| 1104 |
+
hidden_states=latent_model_input,
|
| 1105 |
+
timestep=timestep,
|
| 1106 |
+
encoder_hidden_states=original_prompt_embeds,
|
| 1107 |
+
pooled_projections=original_pooled_prompt_embeds,
|
| 1108 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
| 1109 |
+
return_dict=False,
|
| 1110 |
+
skip_layers=skip_guidance_layers,
|
| 1111 |
+
)[0]
|
| 1112 |
+
noise_pred = (
|
| 1113 |
+
noise_pred + (noise_pred_text - noise_pred_skip_layers) * self._skip_layer_guidance_scale
|
| 1114 |
+
)
|
| 1115 |
+
|
| 1116 |
+
latents_dtype = latents.dtype
|
| 1117 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
| 1118 |
+
|
| 1119 |
+
if latents.dtype != latents_dtype:
|
| 1120 |
+
if torch.backends.mps.is_available():
|
| 1121 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
| 1122 |
+
latents = latents.to(latents_dtype)
|
| 1123 |
+
|
| 1124 |
+
if callback_on_step_end is not None:
|
| 1125 |
+
callback_kwargs = {}
|
| 1126 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 1127 |
+
callback_kwargs[k] = locals()[k]
|
| 1128 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 1129 |
+
|
| 1130 |
+
latents = callback_outputs.pop("latents", latents)
|
| 1131 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 1132 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 1133 |
+
negative_pooled_prompt_embeds = callback_outputs.pop(
|
| 1134 |
+
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
|
| 1135 |
+
)
|
| 1136 |
+
|
| 1137 |
+
# call the callback, if provided
|
| 1138 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 1139 |
+
progress_bar.update()
|
| 1140 |
+
|
| 1141 |
+
if XLA_AVAILABLE:
|
| 1142 |
+
xm.mark_step()
|
| 1143 |
+
|
| 1144 |
+
if output_type == "latent":
|
| 1145 |
+
image = latents
|
| 1146 |
+
|
| 1147 |
+
else:
|
| 1148 |
+
mean_img = torch.mean(latents[0], dim=0).cpu().float().numpy()
|
| 1149 |
+
Image.fromarray(((mean_img - mean_img.min()) / (mean_img.max() - mean_img.min()) * 255).astype(np.uint8)).save('mean.png')
|
| 1150 |
+
|
| 1151 |
+
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
| 1152 |
+
|
| 1153 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
| 1154 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
| 1155 |
+
|
| 1156 |
+
# Offload all models
|
| 1157 |
+
self.maybe_free_model_hooks()
|
| 1158 |
+
|
| 1159 |
+
if not return_dict:
|
| 1160 |
+
return (image,)
|
| 1161 |
+
|
| 1162 |
+
return StableDiffusion3PipelineOutput(images=image)
|