Spaces:
Running on Zero
Running on Zero
Commit ·
90ce156
0
Parent(s):
Upload Space with Xet-managed assets
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .dockerignore +22 -0
- .gitattributes +6 -0
- .gitignore +34 -0
- Dockerfile +36 -0
- LICENSE +201 -0
- README.md +670 -0
- README_zh.md +660 -0
- RIFE +1 -0
- SECURITY.md +11 -0
- SPACE_DEPLOYMENT.md +76 -0
- app.py +0 -0
- app_save.py +2064 -0
- app_wrong.py +2247 -0
- assets/image-understanding/cases/image-understanding-case-02.png +3 -0
- assets/image-understanding/cases/image-understanding-case-05.png +3 -0
- assets/image-understanding/cases/image-understanding-case-06.png +3 -0
- assets/logo/lance-logo.webp +3 -0
- assets/video-understanding/videos/video-understanding-caption-long-01.mp4 +3 -0
- assets/video-understanding/videos/video-understanding-caption-short-01.mp4 +3 -0
- assets/video-understanding/videos/video-understanding-vqa-01.mp4 +3 -0
- benchmarks/image_gen/DPG/DPG.jsonl +0 -0
- benchmarks/image_gen/DPG/README.md +57 -0
- benchmarks/image_gen/DPG/README_zh.md +57 -0
- benchmarks/image_gen/DPG/sample_DPG.py +509 -0
- benchmarks/image_gen/DPG/sample_DPG.sh +113 -0
- benchmarks/image_gen/GEdit/GEdit_en.json +0 -0
- benchmarks/image_gen/GEdit/README.md +68 -0
- benchmarks/image_gen/GEdit/README_zh.md +67 -0
- benchmarks/image_gen/GEdit/sample_GEdit.py +425 -0
- benchmarks/image_gen/GEdit/sample_GEdit.sh +106 -0
- benchmarks/image_gen/GenEVAL/GenEVAL.jsonl +0 -0
- benchmarks/image_gen/GenEVAL/README.md +73 -0
- benchmarks/image_gen/GenEVAL/README_zh.md +73 -0
- benchmarks/image_gen/GenEVAL/sample_GenEVAL.py +463 -0
- benchmarks/image_gen/GenEVAL/sample_GenEVAL.sh +110 -0
- benchmarks/sample_env.sh +107 -0
- benchmarks/video_gen/Vbench/README.md +72 -0
- benchmarks/video_gen/Vbench/README_zh.md +72 -0
- benchmarks/video_gen/Vbench/Vbench_recaption.jsonl +0 -0
- benchmarks/video_gen/Vbench/sample_vbench.py +559 -0
- benchmarks/video_gen/Vbench/sample_vbench.sh +127 -0
- benchmarks/video_gen/Vbench/temporal_flickering_prompts.json +77 -0
- common/__init__.py +16 -0
- common/model/__init__.py +20 -0
- common/model/checks.py +14 -0
- common/model/hacks.py +54 -0
- common/utils/__init__.py +55 -0
- common/utils/distributed.py +62 -0
- common/utils/logging.py +44 -0
- common/utils/misc.py +40 -0
.dockerignore
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.git
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*.pyo
|
| 5 |
+
*.pyd
|
| 6 |
+
*.so
|
| 7 |
+
|
| 8 |
+
.venv/
|
| 9 |
+
venv/
|
| 10 |
+
env/
|
| 11 |
+
ENV/
|
| 12 |
+
|
| 13 |
+
.pytest_cache/
|
| 14 |
+
.mypy_cache/
|
| 15 |
+
.ruff_cache/
|
| 16 |
+
|
| 17 |
+
downloads/
|
| 18 |
+
results/
|
| 19 |
+
tmps/
|
| 20 |
+
*.log
|
| 21 |
+
|
| 22 |
+
.DS_Store
|
.gitattributes
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.webp filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.gif filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.py[cod]
|
| 3 |
+
*.pyo
|
| 4 |
+
*.pyd
|
| 5 |
+
*.so
|
| 6 |
+
|
| 7 |
+
.Python
|
| 8 |
+
.python-version
|
| 9 |
+
.venv/
|
| 10 |
+
venv/
|
| 11 |
+
env/
|
| 12 |
+
ENV/
|
| 13 |
+
|
| 14 |
+
.pytest_cache/
|
| 15 |
+
.mypy_cache/
|
| 16 |
+
.ruff_cache/
|
| 17 |
+
.coverage
|
| 18 |
+
.coverage.*
|
| 19 |
+
htmlcov/
|
| 20 |
+
|
| 21 |
+
build/
|
| 22 |
+
dist/
|
| 23 |
+
*.egg-info/
|
| 24 |
+
.eggs/
|
| 25 |
+
|
| 26 |
+
.ipynb_checkpoints/
|
| 27 |
+
|
| 28 |
+
.DS_Store
|
| 29 |
+
|
| 30 |
+
# custom ignore
|
| 31 |
+
results/
|
| 32 |
+
downloads/
|
| 33 |
+
tmps/
|
| 34 |
+
*.log
|
Dockerfile
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM pytorch/pytorch:2.5.1-cuda12.4-cudnn9-devel
|
| 2 |
+
|
| 3 |
+
ENV DEBIAN_FRONTEND=noninteractive \
|
| 4 |
+
PYTHONUNBUFFERED=1 \
|
| 5 |
+
PIP_NO_CACHE_DIR=1 \
|
| 6 |
+
GRADIO_SERVER_NAME=0.0.0.0 \
|
| 7 |
+
GRADIO_SERVER_PORT=7860 \
|
| 8 |
+
LANCE_AUTO_DOWNLOAD=1 \
|
| 9 |
+
LANCE_MODEL_BASE_DIR=/data/lance_models \
|
| 10 |
+
LANCE_GRADIO_TMP_ROOT=/tmp/lance_gradio
|
| 11 |
+
|
| 12 |
+
WORKDIR /app
|
| 13 |
+
|
| 14 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 15 |
+
ffmpeg \
|
| 16 |
+
git \
|
| 17 |
+
libgl1 \
|
| 18 |
+
libglib2.0-0 \
|
| 19 |
+
libsndfile1 \
|
| 20 |
+
libsm6 \
|
| 21 |
+
libxext6 \
|
| 22 |
+
ninja-build \
|
| 23 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 24 |
+
|
| 25 |
+
COPY requirements.txt /app/requirements.txt
|
| 26 |
+
|
| 27 |
+
RUN python -m pip install --upgrade pip setuptools wheel \
|
| 28 |
+
&& grep -v '^flash-attn==' requirements.txt > /tmp/requirements-no-flash-attn.txt \
|
| 29 |
+
&& python -m pip install -r /tmp/requirements-no-flash-attn.txt \
|
| 30 |
+
&& python -m pip install flash-attn==2.6.3 --no-build-isolation
|
| 31 |
+
|
| 32 |
+
COPY . /app
|
| 33 |
+
|
| 34 |
+
EXPOSE 7860
|
| 35 |
+
|
| 36 |
+
CMD ["python", "app.py"]
|
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
README.md
ADDED
|
@@ -0,0 +1,670 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Lance
|
| 3 |
+
emoji: 🎬
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
suggested_hardware: l40s
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
<div align="center">
|
| 12 |
+
<img src="assets/logo/lance-logo.webp" alt="Lance logo" width="300">
|
| 13 |
+
|
| 14 |
+
<h1 align="center"><sup>Lance: Unified Multimodal Modeling by Multi-Task Synergy</sup></h1>
|
| 15 |
+
<p>
|
| 16 |
+
<strong>
|
| 17 |
+
<a href="https://scholar.google.com.hk/citations?user=FXxoQlsAAAAJ&hl=zh-CN&oi=ao" style="text-decoration: none; color: inherit;">Fengyi Fu</a><sup>*</sup>,
|
| 18 |
+
<a href="https://corleone-huang.github.io/" style="text-decoration: none; color: inherit;">Mengqi Huang</a><sup>*,✉</sup>,
|
| 19 |
+
<a href="https://scholar.google.com.hk/citations?user=9ER6nVkAAAAJ&hl=zh-CN&oi=ao" style="text-decoration: none; color: inherit;">Shaojin Wu</a><sup>*</sup>,
|
| 20 |
+
Yunsheng Jiang<sup>*</sup>,
|
| 21 |
+
Yufei Huo,
|
| 22 |
+
<a href="https://guojianzhu.com/" style="text-decoration: none; color: inherit;">Jianzhu Guo</a><sup>✉,§</sup>
|
| 23 |
+
</strong><br>
|
| 24 |
+
Hao Li,
|
| 25 |
+
Yinghang Song,
|
| 26 |
+
Fei Ding,
|
| 27 |
+
Qian He,
|
| 28 |
+
Zheren Fu,
|
| 29 |
+
Zhendong Mao,
|
| 30 |
+
Yongdong Zhang
|
| 31 |
+
<br>
|
| 32 |
+
<em>ByteDance</em>
|
| 33 |
+
<br>
|
| 34 |
+
<sup>*</sup> Equal contribution <sup>✉</sup> Corresponding authors <sup>§</sup> Project lead
|
| 35 |
+
</p>
|
| 36 |
+
<p>
|
| 37 |
+
<a href="https://lance-project.github.io/" style="text-decoration: none; margin: 0 8px;"><img src="https://img.shields.io/badge/Homepage-Lance-blue?style=flat" alt="Homepage"></a>
|
| 38 |
+
<a href="http://arxiv.org/abs/2605.18678" style="text-decoration: none; margin: 0 8px;"><img src="https://img.shields.io/badge/Paper-arXiv-red?style=flat&logo=arxiv" alt="arXiv"></a>
|
| 39 |
+
<a href="https://huggingface.co/bytedance-research/Lance" style="text-decoration: none; margin: 0 8px;"><img src="https://img.shields.io/badge/Model-HuggingFace-yellow?style=flat&logo=huggingface" alt="Model"></a>
|
| 40 |
+
<br>
|
| 41 |
+
English | <a href="./README_zh.md"><ins>简体中文</ins></a>
|
| 42 |
+
</p>
|
| 43 |
+
</div>
|
| 44 |
+
|
| 45 |
+
## 🌟 Highlights
|
| 46 |
+
|
| 47 |
+
**Lance** is a 3B native unified multimodal model that supports **image and video understanding, generation, and editing** within a single framework.
|
| 48 |
+
|
| 49 |
+
- **Efficient at 3B scale.** With only **3B active parameters**, Lance delivers strong performance across image generation, image editing, and video generation benchmarks.
|
| 50 |
+
- **Trained from scratch.** Lance is built with a staged multi-task recipe and trained entirely from scratch (except for the ViT and VAE encoders; the transformer backbone is trained entirely from scratch) within a **128-A100-GPU** budget.
|
| 51 |
+
|
| 52 |
+
<div align="center">
|
| 53 |
+
<img src="assets/benchmarks/benchmark-overview.png" alt="Lance benchmark overview across image generation, image editing, video generation, and video understanding" width="980">
|
| 54 |
+
</div>
|
| 55 |
+
|
| 56 |
+
## 🎨 Demo
|
| 57 |
+
|
| 58 |
+
### Text-to-Video
|
| 59 |
+
|
| 60 |
+
<table align="center">
|
| 61 |
+
<tr>
|
| 62 |
+
<td><a href="assets/text-to-video/videos/text-to-video-demo-01.mp4"><img src="assets/text-to-video/previews/text-to-video-demo-01.gif" width="100%"></a></td>
|
| 63 |
+
<td><a href="assets/text-to-video/videos/text-to-video-demo-02.mp4"><img src="assets/text-to-video/previews/text-to-video-demo-02.gif" width="100%"></a></td>
|
| 64 |
+
<td><a href="assets/text-to-video/videos/text-to-video-demo-03.mp4"><img src="assets/text-to-video/previews/text-to-video-demo-03.gif" width="100%"></a></td>
|
| 65 |
+
<td><a href="assets/text-to-video/videos/text-to-video-demo-04.mp4"><img src="assets/text-to-video/previews/text-to-video-demo-04.gif" width="100%"></a></td>
|
| 66 |
+
</tr>
|
| 67 |
+
<tr>
|
| 68 |
+
<td><a href="assets/text-to-video/videos/text-to-video-demo-05.mp4"><img src="assets/text-to-video/previews/text-to-video-demo-05.gif" width="100%"></a></td>
|
| 69 |
+
<td><a href="assets/text-to-video/videos/text-to-video-demo-06.mp4"><img src="assets/text-to-video/previews/text-to-video-demo-06.gif" width="100%"></a></td>
|
| 70 |
+
<td><a href="assets/text-to-video/videos/text-to-video-demo-07.mp4"><img src="assets/text-to-video/previews/text-to-video-demo-07.gif" width="100%"></a></td>
|
| 71 |
+
<td><a href="assets/text-to-video/videos/text-to-video-demo-08.mp4"><img src="assets/text-to-video/previews/text-to-video-demo-08.gif" width="100%"></a></td>
|
| 72 |
+
</tr>
|
| 73 |
+
</table>
|
| 74 |
+
|
| 75 |
+
### Video Editing
|
| 76 |
+
|
| 77 |
+
<table align="center">
|
| 78 |
+
<tr>
|
| 79 |
+
<td><a href="assets/video-editing/videos/video-editing-demo-01.mp4"><img src="assets/video-editing/previews/video-editing-demo-01.gif" width="100%"></a></td>
|
| 80 |
+
<td><a href="assets/video-editing/videos/video-editing-demo-02.mp4"><img src="assets/video-editing/previews/video-editing-demo-02.gif" width="100%"></a></td>
|
| 81 |
+
<td><a href="assets/video-editing/videos/video-editing-demo-03.mp4"><img src="assets/video-editing/previews/video-editing-demo-03.gif" width="100%"></a></td>
|
| 82 |
+
<td><a href="assets/video-editing/videos/video-editing-demo-04.mp4"><img src="assets/video-editing/previews/video-editing-demo-04.gif" width="100%"></a></td>
|
| 83 |
+
</tr>
|
| 84 |
+
<tr>
|
| 85 |
+
<td><a href="assets/video-editing/videos/video-editing-demo-05.mp4"><img src="assets/video-editing/previews/video-editing-demo-05.gif" width="100%"></a></td>
|
| 86 |
+
<td><a href="assets/video-editing/videos/video-editing-demo-06.mp4"><img src="assets/video-editing/previews/video-editing-demo-06.gif" width="100%"></a></td>
|
| 87 |
+
<td><a href="assets/video-editing/videos/video-editing-demo-07.mp4"><img src="assets/video-editing/previews/video-editing-demo-07.gif" width="100%"></a></td>
|
| 88 |
+
<td><a href="assets/video-editing/videos/video-editing-demo-08.mp4"><img src="assets/video-editing/previews/video-editing-demo-08.gif" width="100%"></a></td>
|
| 89 |
+
</tr>
|
| 90 |
+
</table>
|
| 91 |
+
|
| 92 |
+
### Multi-turn Consistency Editing
|
| 93 |
+
|
| 94 |
+
<div align="center">
|
| 95 |
+
<a href="assets/multi-turn-editing/videos/multi-turn-editing-demo-01.mp4">
|
| 96 |
+
<img src="assets/multi-turn-editing/previews/multi-turn-editing-demo-01.gif" width="100%">
|
| 97 |
+
</a>
|
| 98 |
+
</div>
|
| 99 |
+
|
| 100 |
+
### Intelligent Video Generation
|
| 101 |
+
|
| 102 |
+
<table align="center">
|
| 103 |
+
<tr>
|
| 104 |
+
<td><a href="assets/intelligent-video/videos/intelligent-video-demo-01.mp4"><img src="assets/intelligent-video/previews/intelligent-video-demo-01.gif" width="100%"></a></td>
|
| 105 |
+
<td><a href="assets/intelligent-video/videos/intelligent-video-demo-02.mp4"><img src="assets/intelligent-video/previews/intelligent-video-demo-02.gif" width="100%"></a></td>
|
| 106 |
+
<td><a href="assets/intelligent-video/videos/intelligent-video-demo-03.mp4"><img src="assets/intelligent-video/previews/intelligent-video-demo-03.gif" width="100%"></a></td>
|
| 107 |
+
<td><a href="assets/intelligent-video/videos/intelligent-video-demo-04.mp4"><img src="assets/intelligent-video/previews/intelligent-video-demo-04.gif" width="100%"></a></td>
|
| 108 |
+
</tr>
|
| 109 |
+
</table>
|
| 110 |
+
|
| 111 |
+
### Video Understanding
|
| 112 |
+
|
| 113 |
+
<div align="center">
|
| 114 |
+
<table align="center">
|
| 115 |
+
<tr>
|
| 116 |
+
<td align="left" valign="top" width="33%">
|
| 117 |
+
<a href="assets/video-understanding/videos/video-understanding-vqa-01.mp4">
|
| 118 |
+
<img src="assets/video-understanding/previews/video-understanding-vqa-01.gif" width="100%">
|
| 119 |
+
</a>
|
| 120 |
+
<p><strong>Question:</strong> How many times did the person launch objects on the table? Options: (A) 3 (B) 2 (C) 4</p>
|
| 121 |
+
<p><strong>Response:</strong> (A) 3</p>
|
| 122 |
+
</td>
|
| 123 |
+
<td align="left" valign="top" width="33%">
|
| 124 |
+
<a href="assets/video-understanding/videos/video-understanding-vqa-02.mp4">
|
| 125 |
+
<img src="assets/video-understanding/previews/video-understanding-vqa-02.gif" width="100%">
|
| 126 |
+
</a>
|
| 127 |
+
<p><strong>Question:</strong> The person makes sets of repeated actions. How many distinct repeated actions did the person do? Options: (A) 2 (B) 3 (C) 4</p>
|
| 128 |
+
<p><strong>Response:</strong> (A) 2</p>
|
| 129 |
+
</td>
|
| 130 |
+
<td align="left" valign="top" width="33%">
|
| 131 |
+
<a href="assets/video-understanding/videos/video-understanding-vqa-03.mp4">
|
| 132 |
+
<img src="assets/video-understanding/previews/video-understanding-vqa-03.gif" width="100%">
|
| 133 |
+
</a>
|
| 134 |
+
<p><strong>Question:</strong> In which direction does the purple sphere move in the video? Options: (A) Down and to the right. (B) Up and to the left. (C) Up and to the right. (D) The object is stationary.</p>
|
| 135 |
+
<p><strong>Response:</strong> (A) Down and to the right.</p>
|
| 136 |
+
</td>
|
| 137 |
+
</tr>
|
| 138 |
+
<tr>
|
| 139 |
+
<td align="left" valign="top" width="33%">
|
| 140 |
+
<a href="assets/video-understanding/videos/video-understanding-vqa-04.mp4">
|
| 141 |
+
<img src="assets/video-understanding/previews/video-understanding-vqa-04.gif" width="100%">
|
| 142 |
+
</a>
|
| 143 |
+
<p><strong>Question:</strong> What is the unrealistic phenomenon displayed in the video? Options: (A) The man can manipulate time via phone. (B) Man grabs an object through a phone screen. (C) Chocolate transforms into different objects. (D) Visible means of propulsion enables flight.</p>
|
| 144 |
+
<p><strong>Response:</strong> (B) Man grabs an object through a phone screen.</p>
|
| 145 |
+
</td>
|
| 146 |
+
<td align="left" valign="top" width="33%">
|
| 147 |
+
<a href="assets/video-understanding/videos/video-understanding-caption-short-01.mp4">
|
| 148 |
+
<img src="assets/video-understanding/previews/video-understanding-caption-short-01.gif" width="100%">
|
| 149 |
+
</a>
|
| 150 |
+
<p><strong>Question:</strong> Offer a succinct account of the culinary process shown in this video.</p>
|
| 151 |
+
<p><strong>Response:</strong> Add tomato puree and mix it well with chicken pieces.</p>
|
| 152 |
+
</td>
|
| 153 |
+
<td align="left" valign="top" width="33%">
|
| 154 |
+
<a href="assets/video-understanding/videos/video-understanding-caption-long-01.mp4">
|
| 155 |
+
<img src="assets/video-understanding/previews/video-understanding-caption-long-01.gif" width="100%">
|
| 156 |
+
</a>
|
| 157 |
+
<p><strong>Question:</strong> Provide a detailed description of the given video, capturing its key moments.</p>
|
| 158 |
+
<p><strong>Response:</strong> In a sunlit meadow, a small tortoiseshell butterfly rests on a purple flower. A bee, with black and yellow stripes, lands on the same flower. The butterfly flaps its wings gently, while the bee busies itself, collecting nectar. The flower sways slightly in the breeze. The butterfly then takes off, and the bee follows, both heading to the next flower. The scene is a vivid display of insect interaction in a natural setting, with the colors of the butterfly and the bee contrasting against the green background of the meadow. The video captures this peaceful moment in a short 6-second duration.</p>
|
| 159 |
+
</td>
|
| 160 |
+
</tr>
|
| 161 |
+
</table>
|
| 162 |
+
</div>
|
| 163 |
+
|
| 164 |
+
### Text-to-Image Generation
|
| 165 |
+
|
| 166 |
+
<div align="center">
|
| 167 |
+
<img src="assets/text-to-image/text-to-image-overview.webp" alt="Lance text-to-image generation examples" width="100%">
|
| 168 |
+
</div>
|
| 169 |
+
|
| 170 |
+
### Image Editing
|
| 171 |
+
|
| 172 |
+
<div align="center">
|
| 173 |
+
<img src="assets/image-editing/image-editing-overview.webp" alt="Lance image editing examples" width="100%">
|
| 174 |
+
</div>
|
| 175 |
+
|
| 176 |
+
### Image Understanding
|
| 177 |
+
|
| 178 |
+
<div align="center">
|
| 179 |
+
<table align="center">
|
| 180 |
+
<tr>
|
| 181 |
+
<td align="left" valign="top" width="33%">
|
| 182 |
+
<img src="assets/image-understanding/cases/image-understanding-case-01.png" alt="Pie chart reasoning case" width="100%">
|
| 183 |
+
<p><strong>Question:</strong> Is the largest segment greater than sum of all the other segments?</p>
|
| 184 |
+
<p><strong>Response:</strong> Yes, the largest segment in the pie chart is the blue one, which is the most populous segment.</p>
|
| 185 |
+
</td>
|
| 186 |
+
<td align="left" valign="top" width="33%">
|
| 187 |
+
<img src="assets/image-understanding/cases/image-understanding-case-02.png" alt="Survey chart about border security priorities" width="100%">
|
| 188 |
+
<p><strong>Question:</strong> What percentage of respondents want better border security?</p>
|
| 189 |
+
<p><strong>Response:</strong> 29%</p>
|
| 190 |
+
</td>
|
| 191 |
+
<td align="left" valign="top" width="33%">
|
| 192 |
+
<img src="assets/image-understanding/cases/image-understanding-case-03.png" alt="License plate recognition case" width="100%">
|
| 193 |
+
<p><strong>Question:</strong> What is the license plate number of the car?</p>
|
| 194 |
+
<p><strong>Response:</strong> The license plate number of the car is BX62 BFY.</p>
|
| 195 |
+
</td>
|
| 196 |
+
</tr>
|
| 197 |
+
<tr>
|
| 198 |
+
<td align="left" valign="top" width="33%">
|
| 199 |
+
<img src="assets/image-understanding/cases/image-understanding-case-04.png" alt="Market research chart reasoning case" width="100%">
|
| 200 |
+
<p><strong>Question:</strong> According to the data from the proprietary market research, how much amount was spent on the promotional meetings and events during 1998?</p>
|
| 201 |
+
<p><strong>Response:</strong> According to the data from the proprietary market research, the total amount spent on the promotional meetings and events during 1998 was approximately $1.3 billion.</p>
|
| 202 |
+
</td>
|
| 203 |
+
<td align="left" valign="top" width="33%">
|
| 204 |
+
<img src="assets/image-understanding/cases/image-understanding-case-05.png" alt="Colosseum description case" width="100%">
|
| 205 |
+
<p><strong>Question:</strong> What is the appearance of the Colosseum in Rome, Italy?</p>
|
| 206 |
+
<p><strong>Response:</strong> The Colosseum in Rome, Italy is a magnificent ancient amphitheater made of stone and concrete. It is characterized by its large, arched openings and its semi-circular shape. The Colosseum has been partially restored, and its original color has faded, but the architectural details, like the columns and the semi-circular arches, are still visible. The lighting in the image highlights the grandeur and historical significance of the Colosseum, making it stand out against the backdrop of the evening sky.</p>
|
| 207 |
+
</td>
|
| 208 |
+
<td align="left" valign="top" width="33%">
|
| 209 |
+
<img src="assets/image-understanding/cases/image-understanding-case-06.png" alt="Total solar eclipse case" width="100%">
|
| 210 |
+
<p><strong>Question:</strong> How does a total solar eclipse look like from Earth?</p>
|
| 211 |
+
<p><strong>Response:</strong> A total solar eclipse looks like a dark circle with a bright, white outer edge. This is a result of the Earth's rotation and the Earth's shadow, which partially obscures the sun's light. The bright white edge of the eclipse is caused by the sun's high energy and its interaction with the Earth's atmosphere, while the dark part of the eclipse is due to the Earth's shadow and the surrounding air currents. The solar eclipse's shape, with its bright white edge and dark center, is similar to the shape of a full moon or a dark disk. It is a natural phenomenon that occurs in the atmosphere of the Earth and is an important part of the solar system.</p>
|
| 212 |
+
</td>
|
| 213 |
+
</tr>
|
| 214 |
+
</table>
|
| 215 |
+
</div>
|
| 216 |
+
|
| 217 |
+
## 🚀 Installation
|
| 218 |
+
|
| 219 |
+
### Recommended Environment
|
| 220 |
+
|
| 221 |
+
- **Software:** Python 3.10+, CUDA 12.4+ (required)
|
| 222 |
+
- **Hardware:** A GPU with at least 40GB VRAM is required for inference
|
| 223 |
+
|
| 224 |
+
### Installation Steps
|
| 225 |
+
```bash
|
| 226 |
+
bash ./setup_env.sh
|
| 227 |
+
```
|
| 228 |
+
|
| 229 |
+
### Download Model Weights
|
| 230 |
+
|
| 231 |
+
Please download all necessary model checkpoints from [Lance-3B on Hugging Face](https://huggingface.co/bytedance-research/Lance) and place them in the `downloads/` directory.
|
| 232 |
+
|
| 233 |
+
## 📚 Usage
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
### Inference
|
| 237 |
+
|
| 238 |
+
We provide a unified command-line interface for all generation / editing / understanding tasks:
|
| 239 |
+
|
| 240 |
+
#### Option 1: Configure and Run the Unified Script
|
| 241 |
+
|
| 242 |
+
```bash
|
| 243 |
+
bash inference_lance.sh
|
| 244 |
+
```
|
| 245 |
+
|
| 246 |
+
- Before running, please configure the inference parameters at the top of `inference_lance.sh`.
|
| 247 |
+
- **Supported tasks:** `t2i`, `t2v`, `image_edit`, `video_edit`, `x2t_image`, and `x2t_video`. You can modify `TASK_DEFAULT_CONFIGS` in `inference_lance.py` to customize the default data samples for each task.
|
| 248 |
+
- **Note:** For all tasks, we recommend following the `prompt` format used in the provided examples when writing input prompts, as this typically leads to better generation quality.
|
| 249 |
+
|
| 250 |
+
#### Option 2: Configure and Run the Unified Script
|
| 251 |
+
|
| 252 |
+
We provide task-specific one-click commands for different generation, editing, and understanding tasks.
|
| 253 |
+
|
| 254 |
+
##### Text-to-Video Generation
|
| 255 |
+
|
| 256 |
+
```bash
|
| 257 |
+
bash inference_lance.sh \
|
| 258 |
+
--TASK_NAME t2v \
|
| 259 |
+
--MODEL_PATH downloads/Lance_3B_Video \
|
| 260 |
+
--RESOLUTION video_480p \
|
| 261 |
+
--NUM_FRAMES 121 \
|
| 262 |
+
--VIDEO_HEIGHT 480 \
|
| 263 |
+
--VIDEO_WIDTH 848 \
|
| 264 |
+
--SAVE_PATH_GEN results/t2v
|
| 265 |
+
```
|
| 266 |
+
|
| 267 |
+
##### Text-to-Image Generation
|
| 268 |
+
|
| 269 |
+
```bash
|
| 270 |
+
bash inference_lance.sh \
|
| 271 |
+
--TASK_NAME t2i \
|
| 272 |
+
--MODEL_PATH downloads/Lance_3B \
|
| 273 |
+
--RESOLUTION image_768res \
|
| 274 |
+
--VIDEO_HEIGHT 768 \
|
| 275 |
+
--VIDEO_WIDTH 768 \
|
| 276 |
+
--SAVE_PATH_GEN results/t2i
|
| 277 |
+
```
|
| 278 |
+
|
| 279 |
+
##### Video Editing
|
| 280 |
+
|
| 281 |
+
```bash
|
| 282 |
+
bash inference_lance.sh \
|
| 283 |
+
--TASK_NAME video_edit \
|
| 284 |
+
--MODEL_PATH downloads/Lance_3B_Video \
|
| 285 |
+
--RESOLUTION video_480p \
|
| 286 |
+
--SAVE_PATH_GEN results/video_edit
|
| 287 |
+
```
|
| 288 |
+
|
| 289 |
+
##### Image Editing
|
| 290 |
+
|
| 291 |
+
```bash
|
| 292 |
+
bash inference_lance.sh \
|
| 293 |
+
--TASK_NAME image_edit \
|
| 294 |
+
--MODEL_PATH downloads/Lance_3B \
|
| 295 |
+
--RESOLUTION image_768res \
|
| 296 |
+
--SAVE_PATH_GEN results/image_edit
|
| 297 |
+
```
|
| 298 |
+
|
| 299 |
+
##### Video Understanding
|
| 300 |
+
|
| 301 |
+
```bash
|
| 302 |
+
bash inference_lance.sh \
|
| 303 |
+
--TASK_NAME x2t_video \
|
| 304 |
+
--MODEL_PATH downloads/Lance_3B_Video \
|
| 305 |
+
--RESOLUTION video_480p \
|
| 306 |
+
--NUM_FRAMES 50 \
|
| 307 |
+
--SAVE_PATH_GEN results/x2t_video
|
| 308 |
+
```
|
| 309 |
+
|
| 310 |
+
##### Image Understanding
|
| 311 |
+
|
| 312 |
+
```bash
|
| 313 |
+
bash inference_lance.sh \
|
| 314 |
+
--TASK_NAME x2t_image \
|
| 315 |
+
--MODEL_PATH downloads/Lance_3B \
|
| 316 |
+
--RESOLUTION image_768res \
|
| 317 |
+
--SAVE_PATH_GEN results/x2t_image
|
| 318 |
+
```
|
| 319 |
+
|
| 320 |
+
#### Available Tasks
|
| 321 |
+
|
| 322 |
+
| Task Name | Description | Example JSON |
|
| 323 |
+
|------------------------|--------------------------------------------------|----------------------------------------------|
|
| 324 |
+
| `t2v` | Text-to-Video generation | `config/examples/t2v_example.json` |
|
| 325 |
+
| `t2i` | Text-to-Image generation | `config/examples/t2i_example.json` |
|
| 326 |
+
| `image_edit` | Image editing | `config/examples/image_edit_example.json` |
|
| 327 |
+
| `video_edit` | Video editing | `config/examples/video_edit_example.json` |
|
| 328 |
+
| `x2t_image` | Image understanding | `config/examples/x2t_image_example.json` |
|
| 329 |
+
| `x2t_video` | Video understanding | `config/examples/x2t_video_example.json` |
|
| 330 |
+
|
| 331 |
+
For understanding examples:
|
| 332 |
+
|
| 333 |
+
- `config/examples/x2t_image_example.json`: image understanding examples for visual question answering and image-based reasoning.
|
| 334 |
+
- `config/examples/x2t_video_example.json`: video understanding examples for video question answering and video captioning.
|
| 335 |
+
|
| 336 |
+
#### Parameters
|
| 337 |
+
|
| 338 |
+
You can configure the following hyperparameters at the top of the `inference_lance.sh` script:
|
| 339 |
+
|
| 340 |
+
| Parameter | Default Value | Description |
|
| 341 |
+
| --- | --- | --- |
|
| 342 |
+
| `MODEL_PATH` | `"downloads/Lance_3B"` | Path to the downloaded Lance model weights (`Lance_3B` or `Lance_3B_Video`). |
|
| 343 |
+
| `NUM_GPUS` | `1` | Number of GPUs to use for inference. |
|
| 344 |
+
| `VALIDATION_NUM_TIMESTEPS` | `30` | Number of denoising steps (e.g., 30 or 50). |
|
| 345 |
+
| `VALIDATION_TIMESTEP_SHIFT` | `3.5` | Timestep shift parameter for flow matching scheduling. |
|
| 346 |
+
| `CFG_TEXT_SCALE` | `4.0` | Classifier-Free Guidance (CFG) scale for text conditioning. |
|
| 347 |
+
| `VALIDATION_DATA_SEED` | `42` | Random seed for generation reproducibility. |
|
| 348 |
+
| `NUM_FRAMES` | `50` | Number of frames for video generation (Max: 121). *Unused for image tasks.* |
|
| 349 |
+
| `VIDEO_HEIGHT` / `VIDEO_WIDTH`| `768` | Spatial resolution. *Unused for editing tasks (determined by input image/video).* |
|
| 350 |
+
| `RESOLUTION` | `"video_480p"` | Base resolution preset (`image_768res` or `video_480p`). |
|
| 351 |
+
|
| 352 |
+
### Gradio
|
| 353 |
+
```bash
|
| 354 |
+
python lance_gradio_t2v_v2t.py --gpus 0 --server-port 7860
|
| 355 |
+
```
|
| 356 |
+
|
| 357 |
+
### Benchmarks
|
| 358 |
+
|
| 359 |
+
#### DPG-Bench Evaluation
|
| 360 |
+
|
| 361 |
+
<div align="center">
|
| 362 |
+
<table align="center">
|
| 363 |
+
<thead>
|
| 364 |
+
<tr>
|
| 365 |
+
<th align="left">Models</th>
|
| 366 |
+
<th align="center"># Params.</th>
|
| 367 |
+
<th align="center">Global</th>
|
| 368 |
+
<th align="center">Entity</th>
|
| 369 |
+
<th align="center">Attribute</th>
|
| 370 |
+
<th align="center">Relation</th>
|
| 371 |
+
<th align="center">Other</th>
|
| 372 |
+
<th align="center">Overall</th>
|
| 373 |
+
</tr>
|
| 374 |
+
</thead>
|
| 375 |
+
<tbody>
|
| 376 |
+
<tr>
|
| 377 |
+
<td align="center" colspan="8"><i>Generation-only Models</i></td>
|
| 378 |
+
</tr>
|
| 379 |
+
<tr>
|
| 380 |
+
<td align="left">SDXL</td><td align="center">3.5B</td><td align="center">83.27</td><td align="center">82.43</td><td align="center">80.91</td><td align="center">86.76</td><td align="center">80.41</td><td align="center">74.65</td>
|
| 381 |
+
</tr>
|
| 382 |
+
<tr>
|
| 383 |
+
<td align="left">DALL-E 3</td><td align="center">-</td><td align="center">90.97</td><td align="center">89.61</td><td align="center">88.39</td><td align="center">90.58</td><td align="center">89.83</td><td align="center">83.50</td>
|
| 384 |
+
</tr>
|
| 385 |
+
<tr>
|
| 386 |
+
<td align="left">SD3-Medium</td><td align="center">2B</td><td align="center">87.90</td><td align="center">91.01</td><td align="center">88.83</td><td align="center">80.70</td><td align="center">88.68</td><td align="center">84.08</td>
|
| 387 |
+
</tr>
|
| 388 |
+
<tr>
|
| 389 |
+
<td align="left">FLUX.1-dev</td><td align="center">12B</td><td align="center">74.35</td><td align="center">90.00</td><td align="center">88.96</td><td align="center">90.87</td><td align="center">88.33</td><td align="center">83.84</td>
|
| 390 |
+
</tr>
|
| 391 |
+
<tr>
|
| 392 |
+
<td align="left">Qwen-Image</td><td align="center">20B</td><td align="center">91.32</td><td align="center">91.56</td><td align="center">92.02</td><td align="center">94.31</td><td align="center">92.73</td><td align="center">88.32</td>
|
| 393 |
+
</tr>
|
| 394 |
+
<tr>
|
| 395 |
+
<td align="center" colspan="8"><i>Unified Models</i></td>
|
| 396 |
+
</tr>
|
| 397 |
+
<tr>
|
| 398 |
+
<td align="left">Janus-Pro-7B</td><td align="center">7B</td><td align="center">86.90</td><td align="center">88.90</td><td align="center">89.40</td><td align="center">89.32</td><td align="center">89.48</td><td align="center">84.19</td>
|
| 399 |
+
</tr>
|
| 400 |
+
<tr>
|
| 401 |
+
<td align="left">OmniGen2</td><td align="center">4B</td><td align="center">88.81</td><td align="center">88.83</td><td align="center">90.18</td><td align="center">89.37</td><td align="center">90.27</td><td align="center">83.57</td>
|
| 402 |
+
</tr>
|
| 403 |
+
<tr>
|
| 404 |
+
<td align="left">Show-o2</td><td align="center">7B</td><td align="center">89.00</td><td align="center"><b>91.78</b></td><td align="center">89.96</td><td align="center">91.81</td><td align="center"><b>91.64</b></td><td align="center">86.14</td>
|
| 405 |
+
</tr>
|
| 406 |
+
<tr>
|
| 407 |
+
<td align="left">BAGEL<sup>†</sup></td><td align="center">7B</td><td align="center">88.94</td><td align="center">90.37</td><td align="center"><u>91.29</u></td><td align="center">90.82</td><td align="center">88.67</td><td align="center">85.07</td>
|
| 408 |
+
</tr>
|
| 409 |
+
<tr>
|
| 410 |
+
<td align="left">InternVL-U</td><td align="center">1.7B</td><td align="center"><u>90.39</u></td><td align="center">90.78</td><td align="center">90.68</td><td align="center">90.29</td><td align="center">88.77</td><td align="center">85.18</td>
|
| 411 |
+
</tr>
|
| 412 |
+
<tr>
|
| 413 |
+
<td align="left">TUNA</td><td align="center">7B</td><td align="center"><b>90.42</b></td><td align="center"><u>91.68</u></td><td align="center">90.94</td><td align="center"><u>91.87</u></td><td align="center"><u>90.73</u></td><td align="center"><b>86.76</b></td>
|
| 414 |
+
</tr>
|
| 415 |
+
<tr>
|
| 416 |
+
<td align="left">TUNA-2</td><td align="center">7B</td><td align="center">89.50</td><td align="center">91.40</td><td align="center"><b>92.07</b></td><td align="center">91.91</td><td align="center">88.81</td><td align="center"><u>86.54</u></td>
|
| 417 |
+
</tr>
|
| 418 |
+
<tr>
|
| 419 |
+
<td align="left">🌟 <b>Lance (Ours)</b></td><td align="center"><b>3B</b></td><td align="center"><b>83.89</b></td><td align="center"><b>91.07</b></td><td align="center"><b>89.36</b></td><td align="center"><b>93.38</b></td><td align="center"><b>80.80</b></td><td align="center"><b>84.67</b></td>
|
| 420 |
+
</tr>
|
| 421 |
+
</tbody>
|
| 422 |
+
</table>
|
| 423 |
+
</div>
|
| 424 |
+
|
| 425 |
+
<p align="center"><em><sup>†</sup> indicates methods that use LLM rewriters for prompt rewriting before generation.</em></p>
|
| 426 |
+
|
| 427 |
+
#### GenEval Evaluation
|
| 428 |
+
|
| 429 |
+
<div align="center">
|
| 430 |
+
<table align="center">
|
| 431 |
+
<thead>
|
| 432 |
+
<tr>
|
| 433 |
+
<th align="left">Models</th>
|
| 434 |
+
<th align="center"># Params.</th>
|
| 435 |
+
<th align="center">1-Obj.</th>
|
| 436 |
+
<th align="center">2-Obj.</th>
|
| 437 |
+
<th align="center">Count</th>
|
| 438 |
+
<th align="center">Colors</th>
|
| 439 |
+
<th align="center">Position</th>
|
| 440 |
+
<th align="center">Attr.</th>
|
| 441 |
+
<th align="center">Overall</th>
|
| 442 |
+
</tr>
|
| 443 |
+
</thead>
|
| 444 |
+
<tbody>
|
| 445 |
+
<tr>
|
| 446 |
+
<td align="center" colspan="9"><i>Generation-only Models</i></td>
|
| 447 |
+
</tr>
|
| 448 |
+
<tr>
|
| 449 |
+
<td align="left">SDXL</td><td align="center">3.5B</td><td align="center">0.98</td><td align="center">0.74</td><td align="center">0.39</td><td align="center">0.85</td><td align="center">0.15</td><td align="center">0.23</td><td align="center">0.55</td>
|
| 450 |
+
</tr>
|
| 451 |
+
<tr>
|
| 452 |
+
<td align="left">DALL-E 3</td><td align="center">-</td><td align="center">0.96</td><td align="center">0.87</td><td align="center">0.47</td><td align="center">0.83</td><td align="center">0.43</td><td align="center">0.45</td><td align="center">0.67</td>
|
| 453 |
+
</tr>
|
| 454 |
+
<tr>
|
| 455 |
+
<td align="left">SD3-Medium</td><td align="center">2B</td><td align="center">0.99</td><td align="center">0.94</td><td align="center">0.72</td><td align="center">0.89</td><td align="center">0.33</td><td align="center">0.60</td><td align="center">0.74</td>
|
| 456 |
+
</tr>
|
| 457 |
+
<tr>
|
| 458 |
+
<td align="left">FLUX.1-dev</td><td align="center">12B</td><td align="center">0.98</td><td align="center">0.93</td><td align="center">0.75</td><td align="center">0.93</td><td align="center">0.68</td><td align="center">0.65</td><td align="center">0.82</td>
|
| 459 |
+
</tr>
|
| 460 |
+
<tr>
|
| 461 |
+
<td align="left">Qwen-Image</td><td align="center">20B</td><td align="center">0.99</td><td align="center">0.92</td><td align="center">0.89</td><td align="center">0.88</td><td align="center">0.76</td><td align="center">0.77</td><td align="center">0.87</td>
|
| 462 |
+
</tr>
|
| 463 |
+
<tr>
|
| 464 |
+
<td align="center" colspan="9"><i>Unified Models</i></td>
|
| 465 |
+
</tr>
|
| 466 |
+
<tr>
|
| 467 |
+
<td align="left">Janus-Pro-7B</td><td align="center">7B</td><td align="center"><u>0.99</u></td><td align="center">0.89</td><td align="center">0.59</td><td align="center">0.90</td><td align="center">0.79</td><td align="center">0.66</td><td align="center">0.80</td>
|
| 468 |
+
</tr>
|
| 469 |
+
<tr>
|
| 470 |
+
<td align="left">OmniGen2</td><td align="center">4B</td><td align="center"><b>1.00</b></td><td align="center">0.95</td><td align="center">0.64</td><td align="center">0.88</td><td align="center">0.55</td><td align="center">0.76</td><td align="center">0.80</td>
|
| 471 |
+
</tr>
|
| 472 |
+
<tr>
|
| 473 |
+
<td align="left">Show-o2</td><td align="center">7B</td><td align="center"><b>1.00</b></td><td align="center">0.87</td><td align="center">0.58</td><td align="center">0.92</td><td align="center">0.52</td><td align="center">0.62</td><td align="center">0.76</td>
|
| 474 |
+
</tr>
|
| 475 |
+
<tr>
|
| 476 |
+
<td align="left">BAGEL<sup>†</sup></td><td align="center">7B</td><td align="center">0.98</td><td align="center">0.95</td><td align="center"><b>0.84</b></td><td align="center"><u>0.95</u></td><td align="center">0.78</td><td align="center">0.77</td><td align="center">0.88</td>
|
| 477 |
+
</tr>
|
| 478 |
+
<tr>
|
| 479 |
+
<td align="left">Mogao</td><td align="center">7B</td><td align="center"><b>1.00</b></td><td align="center"><b>0.97</b></td><td align="center"><u>0.83</u></td><td align="center">0.93</td><td align="center">0.84</td><td align="center">0.80</td><td align="center"><u>0.89</u></td>
|
| 480 |
+
</tr>
|
| 481 |
+
<tr>
|
| 482 |
+
<td align="left">InternVL-U</td><td align="center">1.7B</td><td align="center"><u>0.99</u></td><td align="center">0.94</td><td align="center">0.74</td><td align="center">0.91</td><td align="center">0.77</td><td align="center">0.74</td><td align="center">0.85</td>
|
| 483 |
+
</tr>
|
| 484 |
+
<tr>
|
| 485 |
+
<td align="left">TUNA</td><td align="center">7B</td><td align="center"><b>1.00</b></td><td align="center"><b>0.97</b></td><td align="center">0.81</td><td align="center">0.91</td><td align="center"><b>0.88</b></td><td align="center"><b>0.83</b></td><td align="center"><b>0.90</b></td>
|
| 486 |
+
</tr>
|
| 487 |
+
<tr>
|
| 488 |
+
<td align="left">TUNA-2</td><td align="center">7B</td><td align="center"><u>0.99</u></td><td align="center"><u>0.96</u></td><td align="center">0.80</td><td align="center">0.91</td><td align="center">0.84</td><td align="center">0.76</td><td align="center">0.87</td>
|
| 489 |
+
</tr>
|
| 490 |
+
<tr>
|
| 491 |
+
<td align="left">🌟 <b>Lance (Ours)</b></td><td align="center"><b>3B</b></td><td align="center"><b>1.00</b></td><td align="center"><b>0.94</b></td><td align="center"><b>0.84</b></td><td align="center"><b>0.97</b></td><td align="center"><b>0.87</b></td><td align="center"><b>0.81</b></td><td align="center"><b>0.90</b></td>
|
| 492 |
+
</tr>
|
| 493 |
+
</tbody>
|
| 494 |
+
</table>
|
| 495 |
+
</div>
|
| 496 |
+
|
| 497 |
+
<p align="center"><em><sup>†</sup> indicates methods that use LLM rewriters for prompt rewriting before generation.</em></p>
|
| 498 |
+
|
| 499 |
+
#### GEdit-Bench Evaluation
|
| 500 |
+
|
| 501 |
+
<div align="center">
|
| 502 |
+
<table align="center">
|
| 503 |
+
<thead>
|
| 504 |
+
<tr>
|
| 505 |
+
<th align="left">Models</th>
|
| 506 |
+
<th align="center"># Params.</th>
|
| 507 |
+
<th align="center">BC</th>
|
| 508 |
+
<th align="center">CA</th>
|
| 509 |
+
<th align="center">MM</th>
|
| 510 |
+
<th align="center">MC</th>
|
| 511 |
+
<th align="center">PB</th>
|
| 512 |
+
<th align="center">ST</th>
|
| 513 |
+
<th align="center">SA</th>
|
| 514 |
+
<th align="center">SR</th>
|
| 515 |
+
<th align="center">SRp</th>
|
| 516 |
+
<th align="center">TM</th>
|
| 517 |
+
<th align="center">TT</th>
|
| 518 |
+
<th align="center">Avg/G_O</th>
|
| 519 |
+
</tr>
|
| 520 |
+
</thead>
|
| 521 |
+
<tbody>
|
| 522 |
+
<tr>
|
| 523 |
+
<td align="center" colspan="14"><i>Generation-only Models</i></td>
|
| 524 |
+
</tr>
|
| 525 |
+
<tr>
|
| 526 |
+
<td align="left">Gemini 2.0</td><td align="center">-</td><td align="center">-</td><td align="center">-</td><td align="center">-</td><td align="center">-</td><td align="center">-</td><td align="center">-</td><td align="center">-</td><td align="center">-</td><td align="center">-</td><td align="center">-</td><td align="center">-</td><td align="center">6.32</td>
|
| 527 |
+
</tr>
|
| 528 |
+
<tr>
|
| 529 |
+
<td align="left">GPT Image 1</td><td align="center">-</td><td align="center">6.96</td><td align="center">6.85</td><td align="center">7.10</td><td align="center">5.41</td><td align="center">6.74</td><td align="center">7.44</td><td align="center">7.51</td><td align="center">8.73</td><td align="center">8.55</td><td align="center">8.45</td><td align="center">8.69</td><td align="center">7.49</td>
|
| 530 |
+
</tr>
|
| 531 |
+
<tr>
|
| 532 |
+
<td align="left">Qwen-Image-Edit</td><td align="center">20B</td><td align="center">8.23</td><td align="center">8.30</td><td align="center">7.33</td><td align="center">8.05</td><td align="center">7.49</td><td align="center">6.74</td><td align="center">8.57</td><td align="center">8.09</td><td align="center">8.29</td><td align="center">8.48</td><td align="center">8.50</td><td align="center">8.01</td>
|
| 533 |
+
</tr>
|
| 534 |
+
<tr>
|
| 535 |
+
<td align="center" colspan="14"><i>Unified Models</i></td>
|
| 536 |
+
</tr>
|
| 537 |
+
<tr>
|
| 538 |
+
<td align="left">Lumina-DiMOO</td><td align="center">8B</td><td align="center">3.43</td><td align="center">4.27</td><td align="center">3.08</td><td align="center">2.77</td><td align="center">4.74</td><td align="center">5.19</td><td align="center">4.44</td><td align="center">3.80</td><td align="center">4.38</td><td align="center">2.68</td><td align="center">4.20</td><td align="center">3.91</td>
|
| 539 |
+
</tr>
|
| 540 |
+
<tr>
|
| 541 |
+
<td align="left">Ovis-U1</td><td align="center">1.2B</td><td align="center"><u>7.49</u></td><td align="center">6.88</td><td align="center">6.21</td><td align="center">4.79</td><td align="center">5.98</td><td align="center"><u>6.46</u></td><td align="center">7.49</td><td align="center"><u>7.25</u></td><td align="center"><u>7.27</u></td><td align="center">4.48</td><td align="center">6.31</td><td align="center">6.42</td>
|
| 542 |
+
</tr>
|
| 543 |
+
<tr>
|
| 544 |
+
<td align="left">BAGEL</td><td align="center">7B</td><td align="center">7.32</td><td align="center">6.91</td><td align="center">6.38</td><td align="center">4.75</td><td align="center">4.57</td><td align="center">6.15</td><td align="center"><b>7.90</b></td><td align="center">7.16</td><td align="center">7.02</td><td align="center"><u>7.32</u></td><td align="center">6.22</td><td align="center">6.52</td>
|
| 545 |
+
</tr>
|
| 546 |
+
<tr>
|
| 547 |
+
<td align="left">InternVL-U</td><td align="center">1.7B</td><td align="center">7.08</td><td align="center">7.05</td><td align="center">6.38</td><td align="center"><u>7.02</u></td><td align="center"><u>6.03</u></td><td align="center">6.27</td><td align="center">7.13</td><td align="center">6.55</td><td align="center">6.33</td><td align="center">6.59</td><td align="center"><u>6.85</u></td><td align="center">6.66</td>
|
| 548 |
+
</tr>
|
| 549 |
+
<tr>
|
| 550 |
+
<td align="left">InternVL-U (w/ CoT)</td><td align="center">1.7B</td><td align="center">7.05</td><td align="center"><b>7.87</b></td><td align="center"><u>6.50</u></td><td align="center">6.99</td><td align="center">5.77</td><td align="center">6.10</td><td align="center">7.33</td><td align="center">7.16</td><td align="center">7.12</td><td align="center"><b>7.36</b></td><td align="center">6.46</td><td align="center"><u>6.88</u></td>
|
| 551 |
+
</tr>
|
| 552 |
+
<tr>
|
| 553 |
+
<td align="left">🌟 <b>Lance (Ours)</b></td><td align="center"><b>3B</b></td><td align="center"><b>7.73</b></td><td align="center"><u>7.74</u></td><td align="center"><b>7.28</b></td><td align="center"><b>7.83</b></td><td align="center"><b>7.50</b></td><td align="center"><b>7.03</b></td><td align="center"><u>7.64</u></td><td align="center"><b>7.85</b></td><td align="center"><b>7.71</b></td><td align="center">4.46</td><td align="center"><b>7.57</b></td><td align="center"><b>7.30</b></td>
|
| 554 |
+
</tr>
|
| 555 |
+
</tbody>
|
| 556 |
+
</table>
|
| 557 |
+
</div>
|
| 558 |
+
|
| 559 |
+
#### VBench Evaluation (Video Generation)
|
| 560 |
+
|
| 561 |
+
<div align="center">
|
| 562 |
+
<table align="center">
|
| 563 |
+
<thead>
|
| 564 |
+
<tr>
|
| 565 |
+
<th align="left">Type</th>
|
| 566 |
+
<th align="left">Model</th>
|
| 567 |
+
<th align="center"># Params.</th>
|
| 568 |
+
<th align="center">Total Score ↑</th>
|
| 569 |
+
</tr>
|
| 570 |
+
</thead>
|
| 571 |
+
<tbody>
|
| 572 |
+
<tr>
|
| 573 |
+
<td align="center" rowspan="12"><i>Gen. Only</i></td>
|
| 574 |
+
<td align="left">ModelScope</td><td align="center">1.7B</td><td align="center">75.75</td>
|
| 575 |
+
</tr>
|
| 576 |
+
<tr>
|
| 577 |
+
<td align="left">LaVie</td><td align="center">3B</td><td align="center">77.08</td>
|
| 578 |
+
</tr>
|
| 579 |
+
<tr>
|
| 580 |
+
<td align="left">Show-1</td><td align="center">6B</td><td align="center">78.93</td>
|
| 581 |
+
</tr>
|
| 582 |
+
<tr>
|
| 583 |
+
<td align="left">AnimateDiff-V2</td><td align="center">-</td><td align="center">80.27</td>
|
| 584 |
+
</tr>
|
| 585 |
+
<tr>
|
| 586 |
+
<td align="left">VideoCrafter-2.0</td><td align="center">-</td><td align="center">80.44</td>
|
| 587 |
+
</tr>
|
| 588 |
+
<tr>
|
| 589 |
+
<td align="left">CogVideoX</td><td align="center">5B</td><td align="center">81.61</td>
|
| 590 |
+
</tr>
|
| 591 |
+
<tr>
|
| 592 |
+
<td align="left">Kling</td><td align="center">-</td><td align="center">81.85</td>
|
| 593 |
+
</tr>
|
| 594 |
+
<tr>
|
| 595 |
+
<td align="left">Open-Sora-2.0</td><td align="center">-</td><td align="center">81.71</td>
|
| 596 |
+
</tr>
|
| 597 |
+
<tr>
|
| 598 |
+
<td align="left">Gen-3</td><td align="center">-</td><td align="center">82.32</td>
|
| 599 |
+
</tr>
|
| 600 |
+
<tr>
|
| 601 |
+
<td align="left">Step-Video-T2V</td><td align="center">30B</td><td align="center">81.83</td>
|
| 602 |
+
</tr>
|
| 603 |
+
<tr>
|
| 604 |
+
<td align="left">Hunyuan Video</td><td align="center">-</td><td align="center">83.43</td>
|
| 605 |
+
</tr>
|
| 606 |
+
<tr>
|
| 607 |
+
<td align="left">Wan2.1-T2V</td><td align="center">14B</td><td align="center">83.69</td>
|
| 608 |
+
</tr>
|
| 609 |
+
<tr>
|
| 610 |
+
<td align="center" rowspan="6"><i>Unified</i></td>
|
| 611 |
+
<td align="left">HaproOmni</td><td align="center">7B</td><td align="center">78.10</td>
|
| 612 |
+
</tr>
|
| 613 |
+
<tr>
|
| 614 |
+
<td align="left">Emu3</td><td align="center">8B</td><td align="center">80.96</td>
|
| 615 |
+
</tr>
|
| 616 |
+
<tr>
|
| 617 |
+
<td align="left">VILA-U</td><td align="center">7B</td><td align="center">74.01</td>
|
| 618 |
+
</tr>
|
| 619 |
+
<tr>
|
| 620 |
+
<td align="left">Show-o2</td><td align="center">2B</td><td align="center">81.34</td>
|
| 621 |
+
</tr>
|
| 622 |
+
<tr>
|
| 623 |
+
<td align="left">TUNA</td><td align="center">1.5B</td><td align="center"><u>84.06</u></td>
|
| 624 |
+
</tr>
|
| 625 |
+
<tr>
|
| 626 |
+
<td align="left">🌟 <b>Lance (Ours)</b></td><td align="center"><b>3B</b></td><td align="center"><b>85.11</b></td>
|
| 627 |
+
</tr>
|
| 628 |
+
</tbody>
|
| 629 |
+
</table>
|
| 630 |
+
</div>
|
| 631 |
+
|
| 632 |
+
#### Running Benchmarks
|
| 633 |
+
|
| 634 |
+
Ready-to-run benchmark scripts are provided under `benchmarks/`:
|
| 635 |
+
|
| 636 |
+
| Benchmark | Modality | Script |
|
| 637 |
+
|------------------------|----------|---------------------------------------------------------------|
|
| 638 |
+
| GenEVAL (image gen) | Image | `benchmarks/image_gen/GenEVAL/sample_GenEVAL.sh` |
|
| 639 |
+
| DPG (image gen) | Image | `benchmarks/image_gen/DPG/sample_DPG.sh` |
|
| 640 |
+
| GEdit (image edit) | Image | `benchmarks/image_gen/GEdit/sample_GEdit.sh` |
|
| 641 |
+
| VBench (video gen) | Video | `benchmarks/video_gen/Vbench/sample_vbench.sh` |
|
| 642 |
+
|
| 643 |
+
|
| 644 |
+
## 📄 License
|
| 645 |
+
|
| 646 |
+
Copyright 2025 Bytedance Ltd. and/or its affiliates.
|
| 647 |
+
|
| 648 |
+
## 🙏 Acknowledgements
|
| 649 |
+
|
| 650 |
+
We would like to thank the contributors of [BAGEL](https://github.com/ByteDance-Seed/bagel), [Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct), and [Wan2.2](https://github.com/Wan-Video/Wan2.2) for their open research and contributions.
|
| 651 |
+
|
| 652 |
+
## 💖 Citation
|
| 653 |
+
|
| 654 |
+
If you find **Lance** useful for your project or research, welcome to 🌟 this repo and cite our work using the following BibTeX:
|
| 655 |
+
|
| 656 |
+
```bibtex
|
| 657 |
+
@misc{fu2026lanceunifiedmultimodalmodeling,
|
| 658 |
+
title = {Lance: Unified Multimodal Modeling by Multi-Task Synergy},
|
| 659 |
+
author = {Fengyi Fu and Mengqi Huang and Shaojin Wu and Yunsheng Jiang and Yufei Huo and Hao Li and Yinghang Song and Fei Ding and Jianzhu Guo and Qian He and Zheren Fu and Zhendong Mao and Yongdong Zhang},
|
| 660 |
+
year = {2026},
|
| 661 |
+
eprint = {2605.18678},
|
| 662 |
+
archivePrefix = {arXiv},
|
| 663 |
+
primaryClass = {cs.CV},
|
| 664 |
+
url = {https://arxiv.org/abs/2605.18678},
|
| 665 |
+
}
|
| 666 |
+
```
|
| 667 |
+
|
| 668 |
+
## 📞 Contact
|
| 669 |
+
|
| 670 |
+
For questions, issues, or collaborations, please contact [Mengqi Huang](https://corleone-huang.github.io/) and [Jianzhu Guo](https://guojianzhu.com/).
|
README_zh.md
ADDED
|
@@ -0,0 +1,660 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div align="center">
|
| 2 |
+
<img src="assets/logo/lance-logo.webp" alt="Lance logo" width="300">
|
| 3 |
+
|
| 4 |
+
<h1 align="center"><sup>Lance: Unified Multimodal Modeling by Multi-Task Synergy</sup></h1>
|
| 5 |
+
<p>
|
| 6 |
+
<strong>
|
| 7 |
+
<a href="https://scholar.google.com.hk/citations?user=FXxoQlsAAAAJ&hl=zh-CN&oi=ao" style="text-decoration: none; color: inherit;">Fengyi Fu</a><sup>*</sup>,
|
| 8 |
+
<a href="https://corleone-huang.github.io/" style="text-decoration: none; color: inherit;">Mengqi Huang</a><sup>*,✉</sup>,
|
| 9 |
+
<a href="https://scholar.google.com.hk/citations?user=9ER6nVkAAAAJ&hl=zh-CN&oi=ao" style="text-decoration: none; color: inherit;">Shaojin Wu</a><sup>*</sup>,
|
| 10 |
+
Yunsheng Jiang<sup>*</sup>,
|
| 11 |
+
Yufei Huo,
|
| 12 |
+
<a href="https://guojianzhu.com/" style="text-decoration: none; color: inherit;">Jianzhu Guo</a><sup>✉,§</sup>
|
| 13 |
+
</strong><br>
|
| 14 |
+
Hao Li,
|
| 15 |
+
Yinghang Song,
|
| 16 |
+
Fei Ding,
|
| 17 |
+
Qian He,
|
| 18 |
+
Zheren Fu,
|
| 19 |
+
Zhendong Mao,
|
| 20 |
+
Yongdong Zhang
|
| 21 |
+
<br>
|
| 22 |
+
<em>ByteDance</em>
|
| 23 |
+
<br>
|
| 24 |
+
<sup>*</sup> 共同一作 <sup>✉</sup> 通讯作者 <sup>§</sup> Project lead
|
| 25 |
+
</p>
|
| 26 |
+
<p>
|
| 27 |
+
<a href="https://lance-project.github.io/" style="text-decoration: none; margin: 0 8px;"><img src="https://img.shields.io/badge/Homepage-Lance-blue?style=flat" alt="Homepage"></a>
|
| 28 |
+
<a href="http://arxiv.org/abs/2605.18678" style="text-decoration: none; margin: 0 8px;"><img src="https://img.shields.io/badge/Paper-arXiv-red?style=flat&logo=arxiv" alt="arXiv"></a>
|
| 29 |
+
<a href="https://huggingface.co/bytedance-research/Lance" style="text-decoration: none; margin: 0 8px;"><img src="https://img.shields.io/badge/Model-HuggingFace-yellow?style=flat&logo=huggingface" alt="Model"></a>
|
| 30 |
+
<br>
|
| 31 |
+
<a href="./README.md"><ins>English</ins></a> | 简体中文
|
| 32 |
+
</p>
|
| 33 |
+
</div>
|
| 34 |
+
|
| 35 |
+
## 🌟 亮点
|
| 36 |
+
|
| 37 |
+
**Lance** 是一个3B参数、原生统一的多模态模型,在单一框架下同时支持 **图像与视频的理解、生成和编辑**。
|
| 38 |
+
|
| 39 |
+
- **3B 规模高效强大。** 仅使用 **3B active parameters**,Lance 即可在图像生成、图像编辑和视频生成等基准上取得强劲表现。
|
| 40 |
+
- **从零训练。** Lance 采用分阶段多任务训练配方,在 **128 张 A100 GPU** 的预算内从零完成训练。
|
| 41 |
+
|
| 42 |
+
<div align="center">
|
| 43 |
+
<img src="assets/benchmarks/benchmark-overview.png" alt="Lance benchmark overview across image generation, image editing, video generation, and video understanding" width="980">
|
| 44 |
+
</div>
|
| 45 |
+
|
| 46 |
+
## 🎨 演示
|
| 47 |
+
|
| 48 |
+
### 文生视频
|
| 49 |
+
|
| 50 |
+
<table align="center">
|
| 51 |
+
<tr>
|
| 52 |
+
<td><a href="assets/text-to-video/videos/text-to-video-demo-01.mp4"><img src="assets/text-to-video/previews/text-to-video-demo-01.gif" width="100%"></a></td>
|
| 53 |
+
<td><a href="assets/text-to-video/videos/text-to-video-demo-02.mp4"><img src="assets/text-to-video/previews/text-to-video-demo-02.gif" width="100%"></a></td>
|
| 54 |
+
<td><a href="assets/text-to-video/videos/text-to-video-demo-03.mp4"><img src="assets/text-to-video/previews/text-to-video-demo-03.gif" width="100%"></a></td>
|
| 55 |
+
<td><a href="assets/text-to-video/videos/text-to-video-demo-04.mp4"><img src="assets/text-to-video/previews/text-to-video-demo-04.gif" width="100%"></a></td>
|
| 56 |
+
</tr>
|
| 57 |
+
<tr>
|
| 58 |
+
<td><a href="assets/text-to-video/videos/text-to-video-demo-05.mp4"><img src="assets/text-to-video/previews/text-to-video-demo-05.gif" width="100%"></a></td>
|
| 59 |
+
<td><a href="assets/text-to-video/videos/text-to-video-demo-06.mp4"><img src="assets/text-to-video/previews/text-to-video-demo-06.gif" width="100%"></a></td>
|
| 60 |
+
<td><a href="assets/text-to-video/videos/text-to-video-demo-07.mp4"><img src="assets/text-to-video/previews/text-to-video-demo-07.gif" width="100%"></a></td>
|
| 61 |
+
<td><a href="assets/text-to-video/videos/text-to-video-demo-08.mp4"><img src="assets/text-to-video/previews/text-to-video-demo-08.gif" width="100%"></a></td>
|
| 62 |
+
</tr>
|
| 63 |
+
</table>
|
| 64 |
+
|
| 65 |
+
### 视频编辑
|
| 66 |
+
|
| 67 |
+
<table align="center">
|
| 68 |
+
<tr>
|
| 69 |
+
<td><a href="assets/video-editing/videos/video-editing-demo-01.mp4"><img src="assets/video-editing/previews/video-editing-demo-01.gif" width="100%"></a></td>
|
| 70 |
+
<td><a href="assets/video-editing/videos/video-editing-demo-02.mp4"><img src="assets/video-editing/previews/video-editing-demo-02.gif" width="100%"></a></td>
|
| 71 |
+
<td><a href="assets/video-editing/videos/video-editing-demo-03.mp4"><img src="assets/video-editing/previews/video-editing-demo-03.gif" width="100%"></a></td>
|
| 72 |
+
<td><a href="assets/video-editing/videos/video-editing-demo-04.mp4"><img src="assets/video-editing/previews/video-editing-demo-04.gif" width="100%"></a></td>
|
| 73 |
+
</tr>
|
| 74 |
+
<tr>
|
| 75 |
+
<td><a href="assets/video-editing/videos/video-editing-demo-05.mp4"><img src="assets/video-editing/previews/video-editing-demo-05.gif" width="100%"></a></td>
|
| 76 |
+
<td><a href="assets/video-editing/videos/video-editing-demo-06.mp4"><img src="assets/video-editing/previews/video-editing-demo-06.gif" width="100%"></a></td>
|
| 77 |
+
<td><a href="assets/video-editing/videos/video-editing-demo-07.mp4"><img src="assets/video-editing/previews/video-editing-demo-07.gif" width="100%"></a></td>
|
| 78 |
+
<td><a href="assets/video-editing/videos/video-editing-demo-08.mp4"><img src="assets/video-editing/previews/video-editing-demo-08.gif" width="100%"></a></td>
|
| 79 |
+
</tr>
|
| 80 |
+
</table>
|
| 81 |
+
|
| 82 |
+
### 多轮一致性编辑
|
| 83 |
+
|
| 84 |
+
<div align="center">
|
| 85 |
+
<a href="assets/multi-turn-editing/videos/multi-turn-editing-demo-01.mp4">
|
| 86 |
+
<img src="assets/multi-turn-editing/previews/multi-turn-editing-demo-01.gif" width="100%">
|
| 87 |
+
</a>
|
| 88 |
+
</div>
|
| 89 |
+
|
| 90 |
+
### 智能视频生成
|
| 91 |
+
|
| 92 |
+
<table align="center">
|
| 93 |
+
<tr>
|
| 94 |
+
<td><a href="assets/intelligent-video/videos/intelligent-video-demo-01.mp4"><img src="assets/intelligent-video/previews/intelligent-video-demo-01.gif" width="100%"></a></td>
|
| 95 |
+
<td><a href="assets/intelligent-video/videos/intelligent-video-demo-02.mp4"><img src="assets/intelligent-video/previews/intelligent-video-demo-02.gif" width="100%"></a></td>
|
| 96 |
+
<td><a href="assets/intelligent-video/videos/intelligent-video-demo-03.mp4"><img src="assets/intelligent-video/previews/intelligent-video-demo-03.gif" width="100%"></a></td>
|
| 97 |
+
<td><a href="assets/intelligent-video/videos/intelligent-video-demo-04.mp4"><img src="assets/intelligent-video/previews/intelligent-video-demo-04.gif" width="100%"></a></td>
|
| 98 |
+
</tr>
|
| 99 |
+
</table>
|
| 100 |
+
|
| 101 |
+
### 视频理解
|
| 102 |
+
|
| 103 |
+
<div align="center">
|
| 104 |
+
<table align="center">
|
| 105 |
+
<tr>
|
| 106 |
+
<td align="left" valign="top" width="33%">
|
| 107 |
+
<a href="assets/video-understanding/videos/video-understanding-vqa-01.mp4">
|
| 108 |
+
<img src="assets/video-understanding/previews/video-understanding-vqa-01.gif" width="100%">
|
| 109 |
+
</a>
|
| 110 |
+
<p><strong>问题:</strong> How many times did the person launch objects on the table? Options: (A) 3 (B) 2 (C) 4</p>
|
| 111 |
+
<p><strong>Response:</strong> (A) 3</p>
|
| 112 |
+
</td>
|
| 113 |
+
<td align="left" valign="top" width="33%">
|
| 114 |
+
<a href="assets/video-understanding/videos/video-understanding-vqa-02.mp4">
|
| 115 |
+
<img src="assets/video-understanding/previews/video-understanding-vqa-02.gif" width="100%">
|
| 116 |
+
</a>
|
| 117 |
+
<p><strong>问题:</strong> The person makes sets of repeated actions. How many distinct repeated actions did the person do? Options: (A) 2 (B) 3 (C) 4</p>
|
| 118 |
+
<p><strong>Response:</strong> (A) 2</p>
|
| 119 |
+
</td>
|
| 120 |
+
<td align="left" valign="top" width="33%">
|
| 121 |
+
<a href="assets/video-understanding/videos/video-understanding-vqa-03.mp4">
|
| 122 |
+
<img src="assets/video-understanding/previews/video-understanding-vqa-03.gif" width="100%">
|
| 123 |
+
</a>
|
| 124 |
+
<p><strong>问题:</strong> In which direction does the purple sphere move in the video? Options: (A) Down and to the right. (B) Up and to the left. (C) Up and to the right. (D) The object is stationary.</p>
|
| 125 |
+
<p><strong>Response:</strong> (A) Down and to the right.</p>
|
| 126 |
+
</td>
|
| 127 |
+
</tr>
|
| 128 |
+
<tr>
|
| 129 |
+
<td align="left" valign="top" width="33%">
|
| 130 |
+
<a href="assets/video-understanding/videos/video-understanding-vqa-04.mp4">
|
| 131 |
+
<img src="assets/video-understanding/previews/video-understanding-vqa-04.gif" width="100%">
|
| 132 |
+
</a>
|
| 133 |
+
<p><strong>问题:</strong> What is the unrealistic phenomenon displayed in the video? Options: (A) The man can manipulate time via phone. (B) Man grabs an object through a phone screen. (C) Chocolate transforms into different objects. (D) Visible means of propulsion enables flight.</p>
|
| 134 |
+
<p><strong>Response:</strong> (B) Man grabs an object through a phone screen.</p>
|
| 135 |
+
</td>
|
| 136 |
+
<td align="left" valign="top" width="33%">
|
| 137 |
+
<a href="assets/video-understanding/videos/video-understanding-caption-short-01.mp4">
|
| 138 |
+
<img src="assets/video-understanding/previews/video-understanding-caption-short-01.gif" width="100%">
|
| 139 |
+
</a>
|
| 140 |
+
<p><strong>问题:</strong> Offer a succinct account of the culinary process shown in this video.</p>
|
| 141 |
+
<p><strong>Response:</strong> Add tomato puree and mix it well with chicken pieces.</p>
|
| 142 |
+
</td>
|
| 143 |
+
<td align="left" valign="top" width="33%">
|
| 144 |
+
<a href="assets/video-understanding/videos/video-understanding-caption-long-01.mp4">
|
| 145 |
+
<img src="assets/video-understanding/previews/video-understanding-caption-long-01.gif" width="100%">
|
| 146 |
+
</a>
|
| 147 |
+
<p><strong>问题:</strong> Provide a detailed description of the given video, capturing its key moments.</p>
|
| 148 |
+
<p><strong>Response:</strong> In a sunlit meadow, a small tortoiseshell butterfly rests on a purple flower. A bee, with black and yellow stripes, lands on the same flower. The butterfly flaps its wings gently, while the bee busies itself, collecting nectar. The flower sways slightly in the breeze. The butterfly then takes off, and the bee follows, both heading to the next flower. The scene is a vivid display of insect interaction in a natural setting, with the colors of the butterfly and the bee contrasting against the green background of the meadow. The video captures this peaceful moment in a short 6-second duration.</p>
|
| 149 |
+
</td>
|
| 150 |
+
</tr>
|
| 151 |
+
</table>
|
| 152 |
+
</div>
|
| 153 |
+
|
| 154 |
+
### 文生图
|
| 155 |
+
|
| 156 |
+
<div align="center">
|
| 157 |
+
<img src="assets/text-to-image/text-to-image-overview.webp" alt="Lance text-to-image generation examples" width="100%">
|
| 158 |
+
</div>
|
| 159 |
+
|
| 160 |
+
### 图像编辑
|
| 161 |
+
|
| 162 |
+
<div align="center">
|
| 163 |
+
<img src="assets/image-editing/image-editing-overview.webp" alt="Lance image editing examples" width="100%">
|
| 164 |
+
</div>
|
| 165 |
+
|
| 166 |
+
### 图像理解
|
| 167 |
+
|
| 168 |
+
<div align="center">
|
| 169 |
+
<table align="center">
|
| 170 |
+
<tr>
|
| 171 |
+
<td align="left" valign="top" width="33%">
|
| 172 |
+
<img src="assets/image-understanding/cases/image-understanding-case-01.png" alt="Pie chart reasoning case" width="100%">
|
| 173 |
+
<p><strong>问题:</strong> Is the largest segment greater than sum of all the other segments?</p>
|
| 174 |
+
<p><strong>Response:</strong> Yes, the largest segment in the pie chart is the blue one, which is the most populous segment.</p>
|
| 175 |
+
</td>
|
| 176 |
+
<td align="left" valign="top" width="33%">
|
| 177 |
+
<img src="assets/image-understanding/cases/image-understanding-case-02.png" alt="Survey chart about border security priorities" width="100%">
|
| 178 |
+
<p><strong>问题:</strong> What percentage of respondents want better border security?</p>
|
| 179 |
+
<p><strong>Response:</strong> 29%</p>
|
| 180 |
+
</td>
|
| 181 |
+
<td align="left" valign="top" width="33%">
|
| 182 |
+
<img src="assets/image-understanding/cases/image-understanding-case-03.png" alt="License plate recognition case" width="100%">
|
| 183 |
+
<p><strong>问题:</strong> What is the license plate number of the car?</p>
|
| 184 |
+
<p><strong>Response:</strong> The license plate number of the car is BX62 BFY.</p>
|
| 185 |
+
</td>
|
| 186 |
+
</tr>
|
| 187 |
+
<tr>
|
| 188 |
+
<td align="left" valign="top" width="33%">
|
| 189 |
+
<img src="assets/image-understanding/cases/image-understanding-case-04.png" alt="Market research chart reasoning case" width="100%">
|
| 190 |
+
<p><strong>问题:</strong> According to the data from the proprietary market research, how much amount was spent on the promotional meetings and events during 1998?</p>
|
| 191 |
+
<p><strong>Response:</strong> According to the data from the proprietary market research, the total amount spent on the promotional meetings and events during 1998 was approximately $1.3 billion.</p>
|
| 192 |
+
</td>
|
| 193 |
+
<td align="left" valign="top" width="33%">
|
| 194 |
+
<img src="assets/image-understanding/cases/image-understanding-case-05.png" alt="Colosseum description case" width="100%">
|
| 195 |
+
<p><strong>问题:</strong> What is the appearance of the Colosseum in Rome, Italy?</p>
|
| 196 |
+
<p><strong>Response:</strong> The Colosseum in Rome, Italy is a magnificent ancient amphitheater made of stone and concrete. It is characterized by its large, arched openings and its semi-circular shape. The Colosseum has been partially restored, and its original color has faded, but the architectural details, like the columns and the semi-circular arches, are still visible. The lighting in the image highlights the grandeur and historical significance of the Colosseum, making it stand out against the backdrop of the evening sky.</p>
|
| 197 |
+
</td>
|
| 198 |
+
<td align="left" valign="top" width="33%">
|
| 199 |
+
<img src="assets/image-understanding/cases/image-understanding-case-06.png" alt="Total solar eclipse case" width="100%">
|
| 200 |
+
<p><strong>问题:</strong> How does a total solar eclipse look like from Earth?</p>
|
| 201 |
+
<p><strong>Response:</strong> A total solar eclipse looks like a dark circle with a bright, white outer edge. This is a result of the Earth's rotation and the Earth's shadow, which partially obscures the sun's light. The bright white edge of the eclipse is caused by the sun's high energy and its interaction with the Earth's atmosphere, while the dark part of the eclipse is due to the Earth's shadow and the surrounding air currents. The solar eclipse's shape, with its bright white edge and dark center, is similar to the shape of a full moon or a dark disk. It is a natural phenomenon that occurs in the atmosphere of the Earth and is an important part of the solar system.</p>
|
| 202 |
+
</td>
|
| 203 |
+
</tr>
|
| 204 |
+
</table>
|
| 205 |
+
</div>
|
| 206 |
+
|
| 207 |
+
## 🚀 安装
|
| 208 |
+
|
| 209 |
+
### 推荐环境
|
| 210 |
+
|
| 211 |
+
- **软件环境:** Python 3.10+,CUDA 12.4+(必需)
|
| 212 |
+
- **硬件环境:** 推理至少需要一张显存不低于 40GB 的 GPU
|
| 213 |
+
|
| 214 |
+
### 安装步骤
|
| 215 |
+
```bash
|
| 216 |
+
bash ./setup_env.sh
|
| 217 |
+
```
|
| 218 |
+
|
| 219 |
+
### 下载模型权重
|
| 220 |
+
|
| 221 |
+
请从 [Hugging Face 上的 Lance-3B](https://huggingface.co/bytedance-research/Lance) 下载所需的全部模型权重,并放置到 `downloads/` 目录下。
|
| 222 |
+
|
| 223 |
+
## 📚 使用方法
|
| 224 |
+
|
| 225 |
+
### 推理
|
| 226 |
+
|
| 227 |
+
Lance 为生成、编辑和理解任务提供了统一的命令行入口:
|
| 228 |
+
|
| 229 |
+
#### 方式一:配置并运行统一推理脚本
|
| 230 |
+
|
| 231 |
+
```bash
|
| 232 |
+
bash inference_lance.sh
|
| 233 |
+
```
|
| 234 |
+
|
| 235 |
+
- 运行前,请先在 `inference_lance.sh` 顶部配置推理参数。
|
| 236 |
+
- **支持任务:** `t2i`、`t2v`、`image_edit`、`video_edit`、`x2t_image` 和 `x2t_video`。你也可以在 `inference_lance.py` 中修改 `TASK_DEFAULT_CONFIGS`,自定义每个任务默认使用的数据样例。
|
| 237 |
+
- **注意:** 对于所有任务,建议在编写输入 prompt 时参考提供示例中的 `prompt` 格式,这通常有助于获得更好的生成效果。
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
#### Option 2: 运行任务专用一键脚本
|
| 241 |
+
|
| 242 |
+
我们提供了面向不同生成、编辑和理解任务的一键启动命令,便于快速运行指定任务类型。
|
| 243 |
+
|
| 244 |
+
##### 文本-视频生成
|
| 245 |
+
|
| 246 |
+
```bash
|
| 247 |
+
bash inference_lance.sh \
|
| 248 |
+
--TASK_NAME t2v \
|
| 249 |
+
--MODEL_PATH downloads/Lance_3B_Video \
|
| 250 |
+
--RESOLUTION video_480p \
|
| 251 |
+
--NUM_FRAMES 121 \
|
| 252 |
+
--VIDEO_HEIGHT 480 \
|
| 253 |
+
--VIDEO_WIDTH 848 \
|
| 254 |
+
--SAVE_PATH_GEN results/t2v
|
| 255 |
+
```
|
| 256 |
+
|
| 257 |
+
##### 文本-图像生成
|
| 258 |
+
|
| 259 |
+
```bash
|
| 260 |
+
bash inference_lance.sh \
|
| 261 |
+
--TASK_NAME t2i \
|
| 262 |
+
--MODEL_PATH downloads/Lance_3B \
|
| 263 |
+
--RESOLUTION image_768res \
|
| 264 |
+
--VIDEO_HEIGHT 768 \
|
| 265 |
+
--VIDEO_WIDTH 768 \
|
| 266 |
+
--SAVE_PATH_GEN results/t2i
|
| 267 |
+
```
|
| 268 |
+
|
| 269 |
+
##### 视频编辑
|
| 270 |
+
|
| 271 |
+
```bash
|
| 272 |
+
bash inference_lance.sh \
|
| 273 |
+
--TASK_NAME video_edit \
|
| 274 |
+
--MODEL_PATH downloads/Lance_3B_Video \
|
| 275 |
+
--RESOLUTION video_480p \
|
| 276 |
+
--SAVE_PATH_GEN results/video_edit
|
| 277 |
+
```
|
| 278 |
+
|
| 279 |
+
##### 图像编辑
|
| 280 |
+
|
| 281 |
+
```bash
|
| 282 |
+
bash inference_lance.sh \
|
| 283 |
+
--TASK_NAME image_edit \
|
| 284 |
+
--MODEL_PATH downloads/Lance_3B \
|
| 285 |
+
--RESOLUTION image_768res \
|
| 286 |
+
--SAVE_PATH_GEN results/image_edit
|
| 287 |
+
```
|
| 288 |
+
|
| 289 |
+
##### 视频理解
|
| 290 |
+
|
| 291 |
+
```bash
|
| 292 |
+
bash inference_lance.sh \
|
| 293 |
+
--TASK_NAME x2t_video \
|
| 294 |
+
--MODEL_PATH downloads/Lance_3B_Video \
|
| 295 |
+
--RESOLUTION video_480p \
|
| 296 |
+
--NUM_FRAMES 50 \
|
| 297 |
+
--SAVE_PATH_GEN results/x2t_video
|
| 298 |
+
```
|
| 299 |
+
|
| 300 |
+
##### 图像理解
|
| 301 |
+
|
| 302 |
+
```bash
|
| 303 |
+
bash inference_lance.sh \
|
| 304 |
+
--TASK_NAME x2t_image \
|
| 305 |
+
--MODEL_PATH downloads/Lance_3B \
|
| 306 |
+
--RESOLUTION image_768res \
|
| 307 |
+
--SAVE_PATH_GEN results/x2t_image
|
| 308 |
+
```
|
| 309 |
+
|
| 310 |
+
#### 可用任务
|
| 311 |
+
|
| 312 |
+
| 任务名 | 说明 | 示例 JSON |
|
| 313 |
+
|------------------------|--------------------------------------------------|----------------------------------------------|
|
| 314 |
+
| `t2v` | 文生视频 | `config/examples/t2v_example.json` |
|
| 315 |
+
| `t2i` | 文生图 | `config/examples/t2i_example.json` |
|
| 316 |
+
| `image_edit` | 图像编辑 | `config/examples/image_edit_example.json` |
|
| 317 |
+
| `video_edit` | 视频编辑 | `config/examples/video_edit_example.json` |
|
| 318 |
+
| `x2t_image` | 图像理解 | `config/examples/x2t_image_example.json` |
|
| 319 |
+
| `x2t_video` | 视频理解 | `config/examples/x2t_video_example.json` |
|
| 320 |
+
|
| 321 |
+
关于理解任务的示例文件:
|
| 322 |
+
|
| 323 |
+
- `config/examples/x2t_image_example.json`:用于图像理解示例,包括视觉问答和基于图像的推理。
|
| 324 |
+
- `config/examples/x2t_video_example.json`:用于视频理解示例,包括视频问答和视频描述。
|
| 325 |
+
|
| 326 |
+
#### 参数说明
|
| 327 |
+
|
| 328 |
+
你可以在 `inference_lance.sh` 顶部配置以下超参数:
|
| 329 |
+
|
| 330 |
+
| 参数 | 默认值 | 说明 |
|
| 331 |
+
| --- | --- | --- |
|
| 332 |
+
| `MODEL_PATH` | `"downloads/Lance_3B"` | 下载后的 Lance 模型权重路径(如 `Lance_3B` 或 `Lance_3B_Video`)。 |
|
| 333 |
+
| `NUM_GPUS` | `1` | 用于推理的 GPU 数量。 |
|
| 334 |
+
| `VALIDATION_NUM_TIMESTEPS` | `30` | 去噪步数(例如 30 或 50)。 |
|
| 335 |
+
| `VALIDATION_TIMESTEP_SHIFT` | `3.5` | Flow matching 调度中的 timestep shift 参数。 |
|
| 336 |
+
| `CFG_TEXT_SCALE` | `4.0` | 文本条件的 CFG(Classifier-Free Guidance)系数。 |
|
| 337 |
+
| `VALIDATION_DATA_SEED` | `42` | 用于复现实验的随机种子。 |
|
| 338 |
+
| `NUM_FRAMES` | `50` | 视频生成帧数(最大 121)。*图像任务不使用该参数。* |
|
| 339 |
+
| `VIDEO_HEIGHT` / `VIDEO_WIDTH`| `768` | 空间分辨率。*编辑任务不使用该参数(由输入图像/视频决定)。* |
|
| 340 |
+
| `RESOLUTION` | `"video_480p"` | 基础分辨率预设(如 `image_768res` 或 `video_480p`)。 |
|
| 341 |
+
|
| 342 |
+
### Gradio
|
| 343 |
+
```bash
|
| 344 |
+
python lance_gradio_t2v_v2t.py --gpus 0 --server-port 7860
|
| 345 |
+
```
|
| 346 |
+
|
| 347 |
+
### 基准评测
|
| 348 |
+
|
| 349 |
+
#### DPG-Bench 评测
|
| 350 |
+
|
| 351 |
+
<div align="center">
|
| 352 |
+
<table align="center">
|
| 353 |
+
<thead>
|
| 354 |
+
<tr>
|
| 355 |
+
<th align="left">模型</th>
|
| 356 |
+
<th align="center"># Params.</th>
|
| 357 |
+
<th align="center">Global</th>
|
| 358 |
+
<th align="center">Entity</th>
|
| 359 |
+
<th align="center">Attribute</th>
|
| 360 |
+
<th align="center">Relation</th>
|
| 361 |
+
<th align="center">Other</th>
|
| 362 |
+
<th align="center">Overall</th>
|
| 363 |
+
</tr>
|
| 364 |
+
</thead>
|
| 365 |
+
<tbody>
|
| 366 |
+
<tr>
|
| 367 |
+
<td align="center" colspan="8"><i>仅生成模型</i></td>
|
| 368 |
+
</tr>
|
| 369 |
+
<tr>
|
| 370 |
+
<td align="left">SDXL</td><td align="center">3.5B</td><td align="center">83.27</td><td align="center">82.43</td><td align="center">80.91</td><td align="center">86.76</td><td align="center">80.41</td><td align="center">74.65</td>
|
| 371 |
+
</tr>
|
| 372 |
+
<tr>
|
| 373 |
+
<td align="left">DALL-E 3</td><td align="center">-</td><td align="center">90.97</td><td align="center">89.61</td><td align="center">88.39</td><td align="center">90.58</td><td align="center">89.83</td><td align="center">83.50</td>
|
| 374 |
+
</tr>
|
| 375 |
+
<tr>
|
| 376 |
+
<td align="left">SD3-Medium</td><td align="center">2B</td><td align="center">87.90</td><td align="center">91.01</td><td align="center">88.83</td><td align="center">80.70</td><td align="center">88.68</td><td align="center">84.08</td>
|
| 377 |
+
</tr>
|
| 378 |
+
<tr>
|
| 379 |
+
<td align="left">FLUX.1-dev</td><td align="center">12B</td><td align="center">74.35</td><td align="center">90.00</td><td align="center">88.96</td><td align="center">90.87</td><td align="center">88.33</td><td align="center">83.84</td>
|
| 380 |
+
</tr>
|
| 381 |
+
<tr>
|
| 382 |
+
<td align="left">Qwen-Image</td><td align="center">20B</td><td align="center">91.32</td><td align="center">91.56</td><td align="center">92.02</td><td align="center">94.31</td><td align="center">92.73</td><td align="center">88.32</td>
|
| 383 |
+
</tr>
|
| 384 |
+
<tr>
|
| 385 |
+
<td align="center" colspan="8"><i>统一模型</i></td>
|
| 386 |
+
</tr>
|
| 387 |
+
<tr>
|
| 388 |
+
<td align="left">Janus-Pro-7B</td><td align="center">7B</td><td align="center">86.90</td><td align="center">88.90</td><td align="center">89.40</td><td align="center">89.32</td><td align="center">89.48</td><td align="center">84.19</td>
|
| 389 |
+
</tr>
|
| 390 |
+
<tr>
|
| 391 |
+
<td align="left">OmniGen2</td><td align="center">4B</td><td align="center">88.81</td><td align="center">88.83</td><td align="center">90.18</td><td align="center">89.37</td><td align="center">90.27</td><td align="center">83.57</td>
|
| 392 |
+
</tr>
|
| 393 |
+
<tr>
|
| 394 |
+
<td align="left">Show-o2</td><td align="center">7B</td><td align="center">89.00</td><td align="center"><b>91.78</b></td><td align="center">89.96</td><td align="center">91.81</td><td align="center"><b>91.64</b></td><td align="center">86.14</td>
|
| 395 |
+
</tr>
|
| 396 |
+
<tr>
|
| 397 |
+
<td align="left">BAGEL<sup>†</sup></td><td align="center">7B</td><td align="center">88.94</td><td align="center">90.37</td><td align="center"><u>91.29</u></td><td align="center">90.82</td><td align="center">88.67</td><td align="center">85.07</td>
|
| 398 |
+
</tr>
|
| 399 |
+
<tr>
|
| 400 |
+
<td align="left">InternVL-U</td><td align="center">1.7B</td><td align="center"><u>90.39</u></td><td align="center">90.78</td><td align="center">90.68</td><td align="center">90.29</td><td align="center">88.77</td><td align="center">85.18</td>
|
| 401 |
+
</tr>
|
| 402 |
+
<tr>
|
| 403 |
+
<td align="left">TUNA</td><td align="center">7B</td><td align="center"><b>90.42</b></td><td align="center"><u>91.68</u></td><td align="center">90.94</td><td align="center"><u>91.87</u></td><td align="center"><u>90.73</u></td><td align="center"><b>86.76</b></td>
|
| 404 |
+
</tr>
|
| 405 |
+
<tr>
|
| 406 |
+
<td align="left">TUNA-2</td><td align="center">7B</td><td align="center">89.50</td><td align="center">91.40</td><td align="center"><b>92.07</b></td><td align="center">91.91</td><td align="center">88.81</td><td align="center"><u>86.54</u></td>
|
| 407 |
+
</tr>
|
| 408 |
+
<tr>
|
| 409 |
+
<td align="left">🌟 <b>Lance (Ours)</b></td><td align="center"><b>3B</b></td><td align="center"><b>83.89</b></td><td align="center"><b>91.07</b></td><td align="center"><b>89.36</b></td><td align="center"><b>93.38</b></td><td align="center"><b>80.80</b></td><td align="center"><b>84.67</b></td>
|
| 410 |
+
</tr>
|
| 411 |
+
</tbody>
|
| 412 |
+
</table>
|
| 413 |
+
</div>
|
| 414 |
+
|
| 415 |
+
<p align="center"><em><sup>†</sup> 表示该方法在生成前使用 LLM rewriter 进行提示词改写。</em></p>
|
| 416 |
+
|
| 417 |
+
#### GenEval 评测
|
| 418 |
+
|
| 419 |
+
<div align="center">
|
| 420 |
+
<table align="center">
|
| 421 |
+
<thead>
|
| 422 |
+
<tr>
|
| 423 |
+
<th align="left">模型</th>
|
| 424 |
+
<th align="center"># Params.</th>
|
| 425 |
+
<th align="center">1-Obj.</th>
|
| 426 |
+
<th align="center">2-Obj.</th>
|
| 427 |
+
<th align="center">Count</th>
|
| 428 |
+
<th align="center">Colors</th>
|
| 429 |
+
<th align="center">Position</th>
|
| 430 |
+
<th align="center">Attr.</th>
|
| 431 |
+
<th align="center">Overall</th>
|
| 432 |
+
</tr>
|
| 433 |
+
</thead>
|
| 434 |
+
<tbody>
|
| 435 |
+
<tr>
|
| 436 |
+
<td align="center" colspan="9"><i>仅生成模型</i></td>
|
| 437 |
+
</tr>
|
| 438 |
+
<tr>
|
| 439 |
+
<td align="left">SDXL</td><td align="center">3.5B</td><td align="center">0.98</td><td align="center">0.74</td><td align="center">0.39</td><td align="center">0.85</td><td align="center">0.15</td><td align="center">0.23</td><td align="center">0.55</td>
|
| 440 |
+
</tr>
|
| 441 |
+
<tr>
|
| 442 |
+
<td align="left">DALL-E 3</td><td align="center">-</td><td align="center">0.96</td><td align="center">0.87</td><td align="center">0.47</td><td align="center">0.83</td><td align="center">0.43</td><td align="center">0.45</td><td align="center">0.67</td>
|
| 443 |
+
</tr>
|
| 444 |
+
<tr>
|
| 445 |
+
<td align="left">SD3-Medium</td><td align="center">2B</td><td align="center">0.99</td><td align="center">0.94</td><td align="center">0.72</td><td align="center">0.89</td><td align="center">0.33</td><td align="center">0.60</td><td align="center">0.74</td>
|
| 446 |
+
</tr>
|
| 447 |
+
<tr>
|
| 448 |
+
<td align="left">FLUX.1-dev</td><td align="center">12B</td><td align="center">0.98</td><td align="center">0.93</td><td align="center">0.75</td><td align="center">0.93</td><td align="center">0.68</td><td align="center">0.65</td><td align="center">0.82</td>
|
| 449 |
+
</tr>
|
| 450 |
+
<tr>
|
| 451 |
+
<td align="left">Qwen-Image</td><td align="center">20B</td><td align="center">0.99</td><td align="center">0.92</td><td align="center">0.89</td><td align="center">0.88</td><td align="center">0.76</td><td align="center">0.77</td><td align="center">0.87</td>
|
| 452 |
+
</tr>
|
| 453 |
+
<tr>
|
| 454 |
+
<td align="center" colspan="9"><i>统一模型</i></td>
|
| 455 |
+
</tr>
|
| 456 |
+
<tr>
|
| 457 |
+
<td align="left">Janus-Pro-7B</td><td align="center">7B</td><td align="center"><u>0.99</u></td><td align="center">0.89</td><td align="center">0.59</td><td align="center">0.90</td><td align="center">0.79</td><td align="center">0.66</td><td align="center">0.80</td>
|
| 458 |
+
</tr>
|
| 459 |
+
<tr>
|
| 460 |
+
<td align="left">OmniGen2</td><td align="center">4B</td><td align="center"><b>1.00</b></td><td align="center">0.95</td><td align="center">0.64</td><td align="center">0.88</td><td align="center">0.55</td><td align="center">0.76</td><td align="center">0.80</td>
|
| 461 |
+
</tr>
|
| 462 |
+
<tr>
|
| 463 |
+
<td align="left">Show-o2</td><td align="center">7B</td><td align="center"><b>1.00</b></td><td align="center">0.87</td><td align="center">0.58</td><td align="center">0.92</td><td align="center">0.52</td><td align="center">0.62</td><td align="center">0.76</td>
|
| 464 |
+
</tr>
|
| 465 |
+
<tr>
|
| 466 |
+
<td align="left">BAGEL<sup>†</sup></td><td align="center">7B</td><td align="center">0.98</td><td align="center">0.95</td><td align="center"><b>0.84</b></td><td align="center"><u>0.95</u></td><td align="center">0.78</td><td align="center">0.77</td><td align="center">0.88</td>
|
| 467 |
+
</tr>
|
| 468 |
+
<tr>
|
| 469 |
+
<td align="left">Mogao</td><td align="center">7B</td><td align="center"><b>1.00</b></td><td align="center"><b>0.97</b></td><td align="center"><u>0.83</u></td><td align="center">0.93</td><td align="center">0.84</td><td align="center">0.80</td><td align="center"><u>0.89</u></td>
|
| 470 |
+
</tr>
|
| 471 |
+
<tr>
|
| 472 |
+
<td align="left">InternVL-U</td><td align="center">1.7B</td><td align="center"><u>0.99</u></td><td align="center">0.94</td><td align="center">0.74</td><td align="center">0.91</td><td align="center">0.77</td><td align="center">0.74</td><td align="center">0.85</td>
|
| 473 |
+
</tr>
|
| 474 |
+
<tr>
|
| 475 |
+
<td align="left">TUNA</td><td align="center">7B</td><td align="center"><b>1.00</b></td><td align="center"><b>0.97</b></td><td align="center">0.81</td><td align="center">0.91</td><td align="center"><b>0.88</b></td><td align="center"><b>0.83</b></td><td align="center"><b>0.90</b></td>
|
| 476 |
+
</tr>
|
| 477 |
+
<tr>
|
| 478 |
+
<td align="left">TUNA-2</td><td align="center">7B</td><td align="center"><u>0.99</u></td><td align="center"><u>0.96</u></td><td align="center">0.80</td><td align="center">0.91</td><td align="center">0.84</td><td align="center">0.76</td><td align="center">0.87</td>
|
| 479 |
+
</tr>
|
| 480 |
+
<tr>
|
| 481 |
+
<td align="left">🌟 <b>Lance (Ours)</b></td><td align="center"><b>3B</b></td><td align="center"><b>1.00</b></td><td align="center"><b>0.94</b></td><td align="center"><b>0.84</b></td><td align="center"><b>0.97</b></td><td align="center"><b>0.87</b></td><td align="center"><b>0.81</b></td><td align="center"><b>0.90</b></td>
|
| 482 |
+
</tr>
|
| 483 |
+
</tbody>
|
| 484 |
+
</table>
|
| 485 |
+
</div>
|
| 486 |
+
|
| 487 |
+
<p align="center"><em><sup>†</sup> 表示该方法在生成前使用 LLM rewriter 进行提示词改写。</em></p>
|
| 488 |
+
|
| 489 |
+
#### GEdit-Bench 评测
|
| 490 |
+
|
| 491 |
+
<div align="center">
|
| 492 |
+
<table align="center">
|
| 493 |
+
<thead>
|
| 494 |
+
<tr>
|
| 495 |
+
<th align="left">模型</th>
|
| 496 |
+
<th align="center"># Params.</th>
|
| 497 |
+
<th align="center">BC</th>
|
| 498 |
+
<th align="center">CA</th>
|
| 499 |
+
<th align="center">MM</th>
|
| 500 |
+
<th align="center">MC</th>
|
| 501 |
+
<th align="center">PB</th>
|
| 502 |
+
<th align="center">ST</th>
|
| 503 |
+
<th align="center">SA</th>
|
| 504 |
+
<th align="center">SR</th>
|
| 505 |
+
<th align="center">SRp</th>
|
| 506 |
+
<th align="center">TM</th>
|
| 507 |
+
<th align="center">TT</th>
|
| 508 |
+
<th align="center">Avg/G_O</th>
|
| 509 |
+
</tr>
|
| 510 |
+
</thead>
|
| 511 |
+
<tbody>
|
| 512 |
+
<tr>
|
| 513 |
+
<td align="center" colspan="14"><i>仅生成模型</i></td>
|
| 514 |
+
</tr>
|
| 515 |
+
<tr>
|
| 516 |
+
<td align="left">Gemini 2.0</td><td align="center">-</td><td align="center">-</td><td align="center">-</td><td align="center">-</td><td align="center">-</td><td align="center">-</td><td align="center">-</td><td align="center">-</td><td align="center">-</td><td align="center">-</td><td align="center">-</td><td align="center">-</td><td align="center">6.32</td>
|
| 517 |
+
</tr>
|
| 518 |
+
<tr>
|
| 519 |
+
<td align="left">GPT Image 1</td><td align="center">-</td><td align="center">6.96</td><td align="center">6.85</td><td align="center">7.10</td><td align="center">5.41</td><td align="center">6.74</td><td align="center">7.44</td><td align="center">7.51</td><td align="center">8.73</td><td align="center">8.55</td><td align="center">8.45</td><td align="center">8.69</td><td align="center">7.49</td>
|
| 520 |
+
</tr>
|
| 521 |
+
<tr>
|
| 522 |
+
<td align="left">Qwen-Image-Edit</td><td align="center">20B</td><td align="center">8.23</td><td align="center">8.30</td><td align="center">7.33</td><td align="center">8.05</td><td align="center">7.49</td><td align="center">6.74</td><td align="center">8.57</td><td align="center">8.09</td><td align="center">8.29</td><td align="center">8.48</td><td align="center">8.50</td><td align="center">8.01</td>
|
| 523 |
+
</tr>
|
| 524 |
+
<tr>
|
| 525 |
+
<td align="center" colspan="14"><i>统一模型</i></td>
|
| 526 |
+
</tr>
|
| 527 |
+
<tr>
|
| 528 |
+
<td align="left">Lumina-DiMOO</td><td align="center">8B</td><td align="center">3.43</td><td align="center">4.27</td><td align="center">3.08</td><td align="center">2.77</td><td align="center">4.74</td><td align="center">5.19</td><td align="center">4.44</td><td align="center">3.80</td><td align="center">4.38</td><td align="center">2.68</td><td align="center">4.20</td><td align="center">3.91</td>
|
| 529 |
+
</tr>
|
| 530 |
+
<tr>
|
| 531 |
+
<td align="left">Ovis-U1</td><td align="center">1.2B</td><td align="center"><u>7.49</u></td><td align="center">6.88</td><td align="center">6.21</td><td align="center">4.79</td><td align="center">5.98</td><td align="center"><u>6.46</u></td><td align="center">7.49</td><td align="center"><u>7.25</u></td><td align="center"><u>7.27</u></td><td align="center">4.48</td><td align="center">6.31</td><td align="center">6.42</td>
|
| 532 |
+
</tr>
|
| 533 |
+
<tr>
|
| 534 |
+
<td align="left">BAGEL</td><td align="center">7B</td><td align="center">7.32</td><td align="center">6.91</td><td align="center">6.38</td><td align="center">4.75</td><td align="center">4.57</td><td align="center">6.15</td><td align="center"><b>7.90</b></td><td align="center">7.16</td><td align="center">7.02</td><td align="center"><u>7.32</u></td><td align="center">6.22</td><td align="center">6.52</td>
|
| 535 |
+
</tr>
|
| 536 |
+
<tr>
|
| 537 |
+
<td align="left">InternVL-U</td><td align="center">1.7B</td><td align="center">7.08</td><td align="center">7.05</td><td align="center">6.38</td><td align="center"><u>7.02</u></td><td align="center"><u>6.03</u></td><td align="center">6.27</td><td align="center">7.13</td><td align="center">6.55</td><td align="center">6.33</td><td align="center">6.59</td><td align="center"><u>6.85</u></td><td align="center">6.66</td>
|
| 538 |
+
</tr>
|
| 539 |
+
<tr>
|
| 540 |
+
<td align="left">InternVL-U (w/ CoT)</td><td align="center">1.7B</td><td align="center">7.05</td><td align="center"><b>7.87</b></td><td align="center"><u>6.50</u></td><td align="center">6.99</td><td align="center">5.77</td><td align="center">6.10</td><td align="center">7.33</td><td align="center">7.16</td><td align="center">7.12</td><td align="center"><b>7.36</b></td><td align="center">6.46</td><td align="center"><u>6.88</u></td>
|
| 541 |
+
</tr>
|
| 542 |
+
<tr>
|
| 543 |
+
<td align="left">🌟 <b>Lance (Ours)</b></td><td align="center"><b>3B</b></td><td align="center"><b>7.73</b></td><td align="center"><u>7.74</u></td><td align="center"><b>7.28</b></td><td align="center"><b>7.83</b></td><td align="center"><b>7.50</b></td><td align="center"><b>7.03</b></td><td align="center"><u>7.64</u></td><td align="center"><b>7.85</b></td><td align="center"><b>7.71</b></td><td align="center">4.46</td><td align="center"><b>7.57</b></td><td align="center"><b>7.30</b></td>
|
| 544 |
+
</tr>
|
| 545 |
+
</tbody>
|
| 546 |
+
</table>
|
| 547 |
+
</div>
|
| 548 |
+
|
| 549 |
+
#### VBench 评测(视频生成)
|
| 550 |
+
|
| 551 |
+
<div align="center">
|
| 552 |
+
<table align="center">
|
| 553 |
+
<thead>
|
| 554 |
+
<tr>
|
| 555 |
+
<th align="left">类型</th>
|
| 556 |
+
<th align="left">Model</th>
|
| 557 |
+
<th align="center"># Params.</th>
|
| 558 |
+
<th align="center">Total Score ↑</th>
|
| 559 |
+
</tr>
|
| 560 |
+
</thead>
|
| 561 |
+
<tbody>
|
| 562 |
+
<tr>
|
| 563 |
+
<td align="center" rowspan="12"><i>Gen. Only</i></td>
|
| 564 |
+
<td align="left">ModelScope</td><td align="center">1.7B</td><td align="center">75.75</td>
|
| 565 |
+
</tr>
|
| 566 |
+
<tr>
|
| 567 |
+
<td align="left">LaVie</td><td align="center">3B</td><td align="center">77.08</td>
|
| 568 |
+
</tr>
|
| 569 |
+
<tr>
|
| 570 |
+
<td align="left">Show-1</td><td align="center">6B</td><td align="center">78.93</td>
|
| 571 |
+
</tr>
|
| 572 |
+
<tr>
|
| 573 |
+
<td align="left">AnimateDiff-V2</td><td align="center">-</td><td align="center">80.27</td>
|
| 574 |
+
</tr>
|
| 575 |
+
<tr>
|
| 576 |
+
<td align="left">VideoCrafter-2.0</td><td align="center">-</td><td align="center">80.44</td>
|
| 577 |
+
</tr>
|
| 578 |
+
<tr>
|
| 579 |
+
<td align="left">CogVideoX</td><td align="center">5B</td><td align="center">81.61</td>
|
| 580 |
+
</tr>
|
| 581 |
+
<tr>
|
| 582 |
+
<td align="left">Kling</td><td align="center">-</td><td align="center">81.85</td>
|
| 583 |
+
</tr>
|
| 584 |
+
<tr>
|
| 585 |
+
<td align="left">Open-Sora-2.0</td><td align="center">-</td><td align="center">81.71</td>
|
| 586 |
+
</tr>
|
| 587 |
+
<tr>
|
| 588 |
+
<td align="left">Gen-3</td><td align="center">-</td><td align="center">82.32</td>
|
| 589 |
+
</tr>
|
| 590 |
+
<tr>
|
| 591 |
+
<td align="left">Step-Video-T2V</td><td align="center">30B</td><td align="center">81.83</td>
|
| 592 |
+
</tr>
|
| 593 |
+
<tr>
|
| 594 |
+
<td align="left">Hunyuan Video</td><td align="center">-</td><td align="center">83.43</td>
|
| 595 |
+
</tr>
|
| 596 |
+
<tr>
|
| 597 |
+
<td align="left">Wan2.1-T2V</td><td align="center">14B</td><td align="center">83.69</td>
|
| 598 |
+
</tr>
|
| 599 |
+
<tr>
|
| 600 |
+
<td align="center" rowspan="6"><i>Unified</i></td>
|
| 601 |
+
<td align="left">HaproOmni</td><td align="center">7B</td><td align="center">78.10</td>
|
| 602 |
+
</tr>
|
| 603 |
+
<tr>
|
| 604 |
+
<td align="left">Emu3</td><td align="center">8B</td><td align="center">80.96</td>
|
| 605 |
+
</tr>
|
| 606 |
+
<tr>
|
| 607 |
+
<td align="left">VILA-U</td><td align="center">7B</td><td align="center">74.01</td>
|
| 608 |
+
</tr>
|
| 609 |
+
<tr>
|
| 610 |
+
<td align="left">Show-o2</td><td align="center">2B</td><td align="center">81.34</td>
|
| 611 |
+
</tr>
|
| 612 |
+
<tr>
|
| 613 |
+
<td align="left">TUNA</td><td align="center">1.5B</td><td align="center"><u>84.06</u></td>
|
| 614 |
+
</tr>
|
| 615 |
+
<tr>
|
| 616 |
+
<td align="left">🌟 <b>Lance (Ours)</b></td><td align="center"><b>3B</b></td><td align="center"><b>85.11</b></td>
|
| 617 |
+
</tr>
|
| 618 |
+
</tbody>
|
| 619 |
+
</table>
|
| 620 |
+
</div>
|
| 621 |
+
|
| 622 |
+
#### 运行基准评测
|
| 623 |
+
|
| 624 |
+
`benchmarks/` 目录下提供了可直接运行的基准评测脚本:
|
| 625 |
+
|
| 626 |
+
| 基准 | 模态 | 脚本 |
|
| 627 |
+
|------------------------|----------|---------------------------------------------------------------|
|
| 628 |
+
| GenEVAL(图像生成) | 图像 | `benchmarks/image_gen/GenEVAL/sample_GenEVAL.sh` |
|
| 629 |
+
| DPG(图像生成) | 图像 | `benchmarks/image_gen/DPG/sample_DPG.sh` |
|
| 630 |
+
| GEdit(��像编辑) | 图像 | `benchmarks/image_gen/GEdit/sample_GEdit.sh` |
|
| 631 |
+
| VBench(视频生成) | 视频 | `benchmarks/video_gen/Vbench/sample_vbench.sh` |
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
## 📄 许可证
|
| 635 |
+
|
| 636 |
+
Copyright 2025 Bytedance Ltd. and/or its affiliates.
|
| 637 |
+
|
| 638 |
+
## 🙏 致谢
|
| 639 |
+
|
| 640 |
+
我们感谢 [BAGEL](https://github.com/ByteDance-Seed/bagel)、[Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct) 和 [Wan2.2](https://github.com/Wan-Video/Wan2.2) 的贡献者,感谢他们开放的研究与社区贡献。
|
| 641 |
+
|
| 642 |
+
## 💖 引用
|
| 643 |
+
|
| 644 |
+
如果 **Lance** 对您的项目或研究有帮助,欢迎 🌟 本仓库,并使用以下 BibTeX 引用我们的工作:
|
| 645 |
+
|
| 646 |
+
```bibtex
|
| 647 |
+
@misc{fu2026lanceunifiedmultimodalmodeling,
|
| 648 |
+
title = {Lance: Unified Multimodal Modeling by Multi-Task Synergy},
|
| 649 |
+
author = {Fengyi Fu and Mengqi Huang and Shaojin Wu and Yunsheng Jiang and Yufei Huo and Hao Li and Yinghang Song and Fei Ding and Jianzhu Guo and Qian He and Zheren Fu and Zhendong Mao and Yongdong Zhang},
|
| 650 |
+
year = {2026},
|
| 651 |
+
eprint = {2605.18678},
|
| 652 |
+
archivePrefix = {arXiv},
|
| 653 |
+
primaryClass = {cs.CV},
|
| 654 |
+
url = {https://arxiv.org/abs/2605.18678},
|
| 655 |
+
}
|
| 656 |
+
```
|
| 657 |
+
|
| 658 |
+
## 📞 联系方式
|
| 659 |
+
|
| 660 |
+
如有问题、反馈或合作需求,请联系 [Mengqi Huang](https://corleone-huang.github.io/) 和 [Jianzhu Guo](https://guojianzhu.com/)。
|
RIFE
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Subproject commit 5d8adbdd40e12c2c8f91930eff838aebe561c086
|
SECURITY.md
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Security and privacy
|
| 2 |
+
If you discover potential security issues in the project, or believe you may have found a security issue, please notify the ByteDance security team through our [security center](https://security.bytedance.com/src) or [vulnerability reporting email](mailto:src@bytedance.com). Please **do not** create public GitHub Issues.
|
| 3 |
+
|
| 4 |
+
We will assess the vulnerability based on the Common Vulnerability Scoring System (CVSS 3.1). The security team will keep you updated on key progress and may request further information or guidance from you. You are welcome to contact us via the email or website mentioned above to ask questions or discuss disclosure matters.
|
| 5 |
+
|
| 6 |
+
To protect the security of our customers, ByteDance requests that you do not publish or share information regarding the vulnerability in any public forum, nor publish or share data involving users, until the vulnerability has been remediated and our users have been notified. Please understand that the time required for remediation depends on the severity of the vulnerability and the scope of the impact.
|
| 7 |
+
|
| 8 |
+
Individuals, companies, and security teams may wish to publish security advisories on their own websites or other forums. Please contact us via the email or website mentioned above prior to publication to discuss the information that can be disclosed and to coordinate the disclosure timeline.
|
| 9 |
+
|
| 10 |
+
# Bug Bounty Reward
|
| 11 |
+
[For the policy of bug bounty reward](https://bytedance.larkoffice.com/docx/ZstQd7bbooDctqxBCAmcFasOngd), if you have any questions about the rules, please contact [https://src.bytedance.com/home](https://src.bytedance.com/home) for consultation.
|
SPACE_DEPLOYMENT.md
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hugging Face Space Deployment
|
| 2 |
+
|
| 3 |
+
This repository is prepared for a Docker-based Hugging Face Space.
|
| 4 |
+
|
| 5 |
+
## Runtime
|
| 6 |
+
|
| 7 |
+
- Space SDK: Docker
|
| 8 |
+
- Public port: `7860`
|
| 9 |
+
- Entrypoint: `python app.py`
|
| 10 |
+
- Recommended hardware: GPU, preferably `l40s` or stronger
|
| 11 |
+
|
| 12 |
+
## Model Assets
|
| 13 |
+
|
| 14 |
+
The app first checks local model files under `LANCE_MODEL_BASE_DIR`.
|
| 15 |
+
|
| 16 |
+
Default behavior:
|
| 17 |
+
|
| 18 |
+
- Local checkout with `downloads/`: use `./downloads`
|
| 19 |
+
- Hugging Face Space without local assets: download from `bytedance-research/Lance` into `/data/lance_models`
|
| 20 |
+
- Video tasks preload `Lance_3B_Video` at startup.
|
| 21 |
+
- Image tasks unload the active video model first, then load `Lance_3B`.
|
| 22 |
+
- Switching back to a video task unloads `Lance_3B`, then reloads `Lance_3B_Video`.
|
| 23 |
+
|
| 24 |
+
Useful environment variables:
|
| 25 |
+
|
| 26 |
+
- `LANCE_MODEL_REPO_ID`: Hugging Face model repo to download from. Default: `bytedance-research/Lance`
|
| 27 |
+
- `LANCE_MODEL_BASE_DIR`: directory containing `Lance_3B_Video`, `Qwen2.5-VL-ViT`, and `Wan2.2_VAE.pth`
|
| 28 |
+
- `LANCE_VIDEO_MODEL_PATH`: explicit video model directory override
|
| 29 |
+
- `LANCE_IMAGE_MODEL_PATH`: explicit image model directory override
|
| 30 |
+
- `LANCE_MODEL_PATH`: legacy explicit model directory override used for both task families if the family-specific override is unset
|
| 31 |
+
- `LANCE_MODEL_VARIANT`: `video` or `image`; default is `video`
|
| 32 |
+
- `LANCE_AUTO_DOWNLOAD`: set to `1` to download missing assets from the Hub
|
| 33 |
+
- `LANCE_GPUS`: comma-separated GPU IDs, for example `0` or `0,1`
|
| 34 |
+
- `LANCE_QUEUE_SIZE`: Gradio queue size
|
| 35 |
+
- `LANCE_GRADIO_TMP_ROOT`: output and temporary file directory
|
| 36 |
+
|
| 37 |
+
Expected model layout:
|
| 38 |
+
|
| 39 |
+
```text
|
| 40 |
+
${LANCE_MODEL_BASE_DIR}/
|
| 41 |
+
Lance_3B_Video/
|
| 42 |
+
llm_config.json
|
| 43 |
+
model.safetensors
|
| 44 |
+
tokenizer.json
|
| 45 |
+
...
|
| 46 |
+
Lance_3B/
|
| 47 |
+
llm_config.json
|
| 48 |
+
model.safetensors
|
| 49 |
+
tokenizer.json
|
| 50 |
+
...
|
| 51 |
+
Qwen2.5-VL-ViT/
|
| 52 |
+
config.json
|
| 53 |
+
vit.safetensors
|
| 54 |
+
Wan2.2_VAE.pth
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
## Local Docker Check
|
| 58 |
+
|
| 59 |
+
```bash
|
| 60 |
+
docker build -t lance-space .
|
| 61 |
+
docker run --gpus all -p 7860:7860 \
|
| 62 |
+
-e LANCE_MODEL_BASE_DIR=/models/lance \
|
| 63 |
+
-v /path/to/lance/downloads:/models/lance \
|
| 64 |
+
lance-space
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
Open `http://localhost:7860`.
|
| 68 |
+
|
| 69 |
+
## Files Not Uploaded
|
| 70 |
+
|
| 71 |
+
The Space build excludes generated or heavyweight local files through `.dockerignore`:
|
| 72 |
+
|
| 73 |
+
- `downloads/`
|
| 74 |
+
- `results/`
|
| 75 |
+
- `tmps/`
|
| 76 |
+
- Python cache files
|
app.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
app_save.py
ADDED
|
@@ -0,0 +1,2064 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import base64
|
| 5 |
+
import concurrent.futures
|
| 6 |
+
import gc
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
import random
|
| 10 |
+
import subprocess
|
| 11 |
+
import threading
|
| 12 |
+
import time
|
| 13 |
+
import traceback
|
| 14 |
+
from collections import deque
|
| 15 |
+
from copy import deepcopy
|
| 16 |
+
from datetime import datetime
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from typing import Optional
|
| 19 |
+
|
| 20 |
+
import gradio as gr
|
| 21 |
+
import torch
|
| 22 |
+
from huggingface_hub import snapshot_download
|
| 23 |
+
from safetensors.torch import load_file
|
| 24 |
+
from transformers import set_seed
|
| 25 |
+
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLVisionConfig
|
| 26 |
+
|
| 27 |
+
from common.utils.logging import get_logger
|
| 28 |
+
from common.utils.misc import AutoEncoderParams, tuple_mul
|
| 29 |
+
from config.config_factory import DataArguments, InferenceArguments, ModelArguments
|
| 30 |
+
from data.data_utils import add_special_tokens
|
| 31 |
+
from data.dataset_base import DataConfig, simple_custom_collate
|
| 32 |
+
from data.datasets_custom import ValidationDataset
|
| 33 |
+
from inference_lance import (
|
| 34 |
+
PROMPT_JSON_FILENAME,
|
| 35 |
+
apply_inference_defaults,
|
| 36 |
+
clean_memory,
|
| 37 |
+
init_from_model_path_if_needed,
|
| 38 |
+
save_prompt_results,
|
| 39 |
+
validate_on_fixed_batch,
|
| 40 |
+
)
|
| 41 |
+
from modeling.lance import Lance, LanceConfig, Qwen2ForCausalLM
|
| 42 |
+
from modeling.qwen2 import Qwen2Tokenizer
|
| 43 |
+
from modeling.qwen2.modeling_qwen2 import Qwen2Config
|
| 44 |
+
from modeling.vae.wan.model import WanVideoVAE
|
| 45 |
+
from modeling.vit.qwen2_5_vl_vit import Qwen2_5_VisionTransformerPretrainedModel
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
REPO_ROOT = Path(__file__).resolve().parent
|
| 49 |
+
GRADIO_TMP_ROOT = Path(os.getenv("LANCE_GRADIO_TMP_ROOT", "/tmp/lance_gradio")).expanduser()
|
| 50 |
+
TMP_INPUT_DIR = GRADIO_TMP_ROOT / "inputs"
|
| 51 |
+
RESULTS_ROOT = GRADIO_TMP_ROOT / "results"
|
| 52 |
+
GLOBAL_RECORDS_FILE = GRADIO_TMP_ROOT / "generation_records.jsonl"
|
| 53 |
+
RUN_RECORD_FILENAME = "generation_record.json"
|
| 54 |
+
|
| 55 |
+
LOCAL_MODEL_BASE_DIR = Path("downloads")
|
| 56 |
+
SPACE_MODEL_BASE_DIR = Path("/data/lance_models")
|
| 57 |
+
DEFAULT_MODEL_REPO_ID = "bytedance-research/Lance"
|
| 58 |
+
DEFAULT_MODEL_VARIANT = "video"
|
| 59 |
+
MODEL_VARIANT_VIDEO = "video"
|
| 60 |
+
MODEL_VARIANT_IMAGE = "image"
|
| 61 |
+
MODEL_VARIANT_TO_DIR = {
|
| 62 |
+
MODEL_VARIANT_VIDEO: "Lance_3B_Video",
|
| 63 |
+
MODEL_VARIANT_IMAGE: "Lance_3B",
|
| 64 |
+
}
|
| 65 |
+
DEFAULT_MODEL_PATH = LOCAL_MODEL_BASE_DIR / MODEL_VARIANT_TO_DIR[MODEL_VARIANT_VIDEO]
|
| 66 |
+
DEFAULT_VIT_TYPE = "qwen_2_5_vl_original"
|
| 67 |
+
DEFAULT_TASK = "t2v"
|
| 68 |
+
DEFAULT_TIMESTEPS = 30
|
| 69 |
+
DEFAULT_TIMESTEP_SHIFT = 3.5
|
| 70 |
+
DEFAULT_CFG_TEXT_SCALE = 4.0
|
| 71 |
+
DEFAULT_RESOLUTION = "video_848x480"
|
| 72 |
+
DEFAULT_IMAGE_RESOLUTION = "image_768x768"
|
| 73 |
+
DEFAULT_BASIC_SEED = 42
|
| 74 |
+
DEFAULT_HEIGHT = 480
|
| 75 |
+
DEFAULT_WIDTH = 848
|
| 76 |
+
DEFAULT_IMAGE_SIZE = 768
|
| 77 |
+
DEFAULT_NUM_FRAMES = 101
|
| 78 |
+
DEFAULT_VIDEO_ASPECT_RATIO = "16:9"
|
| 79 |
+
DEFAULT_IMAGE_ASPECT_RATIO = "1:1"
|
| 80 |
+
DEFAULT_FRAME_INTERPOLATION = True
|
| 81 |
+
ASPECT_RATIO_CHOICES = ["21:9", "16:9", "3:2", "4:3", "1:1", "3:4", "2:3", "9:16", "9:21"]
|
| 82 |
+
|
| 83 |
+
VIDEO_ASPECT_RATIO_TO_SIZE = {
|
| 84 |
+
"21:9": (976, 416),
|
| 85 |
+
"16:9": (848, 480),
|
| 86 |
+
"3:2": (784, 528),
|
| 87 |
+
"4:3": (736, 560),
|
| 88 |
+
"1:1": (640, 640),
|
| 89 |
+
"3:4": (560, 736),
|
| 90 |
+
"2:3": (528, 784),
|
| 91 |
+
"9:16": (480, 848),
|
| 92 |
+
"9:21": (416, 976),
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
IMAGE_ASPECT_RATIO_TO_SIZE = {
|
| 96 |
+
"21:9": (1168, 496),
|
| 97 |
+
"16:9": (1024, 576),
|
| 98 |
+
"3:2": (944, 624),
|
| 99 |
+
"4:3": (880, 672),
|
| 100 |
+
"1:1": (768, 768),
|
| 101 |
+
"3:4": (672, 880),
|
| 102 |
+
"2:3": (624, 944),
|
| 103 |
+
"9:16": (576, 1024),
|
| 104 |
+
"9:21": (496, 1168),
|
| 105 |
+
}
|
| 106 |
+
DEFAULT_GPUS = "0"
|
| 107 |
+
DEFAULT_QUEUE_SIZE = 32
|
| 108 |
+
USE_KVCACHE = True
|
| 109 |
+
TEXT_TEMPLATE = True
|
| 110 |
+
RECORD_WRITE_LOCK = threading.Lock()
|
| 111 |
+
|
| 112 |
+
LANCE_HOMEPAGE_URL = "https://lance-project.github.io/"
|
| 113 |
+
LANCE_PAPER_URL = "http://arxiv.org/abs/2605.18678"
|
| 114 |
+
LANCE_HUGGING_FACE_URL = "https://huggingface.co/bytedance-research/Lance"
|
| 115 |
+
LANCE_GITHUB_URL = "https://github.com/bytedance/Lance"
|
| 116 |
+
LANCE_LOGO_PATH = REPO_ROOT / "assets" / "logo" / "lance-logo.webp"
|
| 117 |
+
|
| 118 |
+
APP_CSS = """
|
| 119 |
+
.gradio-container {
|
| 120 |
+
max-width: 1680px !important;
|
| 121 |
+
margin-left: auto !important;
|
| 122 |
+
margin-right: auto !important;
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
.contain {
|
| 126 |
+
max-width: 1680px !important;
|
| 127 |
+
margin-left: auto !important;
|
| 128 |
+
margin-right: auto !important;
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
.lance-hero {
|
| 132 |
+
text-align: center;
|
| 133 |
+
padding: 8px 12px 6px;
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
.lance-logo {
|
| 137 |
+
width: min(160px, 36vw);
|
| 138 |
+
height: auto;
|
| 139 |
+
display: block;
|
| 140 |
+
margin: 0 auto 4px;
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
.lance-title {
|
| 144 |
+
margin: 0 auto 5px;
|
| 145 |
+
font-size: clamp(20px, 2.4vw, 30px);
|
| 146 |
+
line-height: 1.08;
|
| 147 |
+
font-weight: 800;
|
| 148 |
+
letter-spacing: 0;
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
.lance-authors {
|
| 152 |
+
margin: 0 auto 6px;
|
| 153 |
+
max-width: 980px;
|
| 154 |
+
font-size: 13px;
|
| 155 |
+
line-height: 1.35;
|
| 156 |
+
color: var(--body-text-color-subdued);
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
.lance-authors a {
|
| 160 |
+
color: inherit;
|
| 161 |
+
text-decoration: none;
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
.lance-authors a:hover {
|
| 165 |
+
text-decoration: underline;
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
.lance-badges {
|
| 169 |
+
display: flex;
|
| 170 |
+
flex-wrap: wrap;
|
| 171 |
+
justify-content: center;
|
| 172 |
+
gap: 5px;
|
| 173 |
+
margin: 4px auto 0;
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
.lance-badges a {
|
| 177 |
+
line-height: 0;
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
.lance-badges img {
|
| 181 |
+
height: 20px;
|
| 182 |
+
width: auto;
|
| 183 |
+
display: block;
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
.lance-status {
|
| 187 |
+
max-width: 1180px;
|
| 188 |
+
margin: 0 auto 18px;
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
.task-selector {
|
| 192 |
+
overflow-x: auto;
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
.task-selector .wrap {
|
| 196 |
+
display: grid;
|
| 197 |
+
grid-template-columns: repeat(3, minmax(220px, 1fr));
|
| 198 |
+
gap: 8px;
|
| 199 |
+
min-width: 680px;
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
.task-selector label {
|
| 203 |
+
justify-content: center;
|
| 204 |
+
min-height: 38px;
|
| 205 |
+
white-space: nowrap;
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
.prompt-examples table,
|
| 209 |
+
.prompt-examples th,
|
| 210 |
+
.prompt-examples td {
|
| 211 |
+
border: 1px solid var(--border-color-primary) !important;
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
.prompt-examples table {
|
| 215 |
+
border-collapse: collapse !important;
|
| 216 |
+
width: 100% !important;
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
.prompt-examples td {
|
| 220 |
+
border-bottom: 1px solid var(--border-color-primary) !important;
|
| 221 |
+
padding: 12px !important;
|
| 222 |
+
vertical-align: top !important;
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
.prompt-example-proxy {
|
| 226 |
+
display: none !important;
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
.lance-main-row {
|
| 230 |
+
display: grid !important;
|
| 231 |
+
grid-template-columns: minmax(0, 1fr) minmax(0, 1fr) !important;
|
| 232 |
+
gap: 16px !important;
|
| 233 |
+
align-items: start !important;
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
.lance-main-column {
|
| 237 |
+
min-width: 0 !important;
|
| 238 |
+
width: 100% !important;
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
.lance-display-frame,
|
| 242 |
+
.lance-display-frame > div,
|
| 243 |
+
.lance-display-frame textarea {
|
| 244 |
+
width: 100% !important;
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
.lance-display-frame textarea {
|
| 248 |
+
min-height: 360px !important;
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
@media (max-width: 900px) {
|
| 252 |
+
.lance-main-row {
|
| 253 |
+
grid-template-columns: minmax(0, 1fr) !important;
|
| 254 |
+
}
|
| 255 |
+
}
|
| 256 |
+
"""
|
| 257 |
+
|
| 258 |
+
TASK_T2V = "t2v"
|
| 259 |
+
TASK_T2I = "t2i"
|
| 260 |
+
TASK_V2T = "v2t"
|
| 261 |
+
TASK_X2T = "x2t"
|
| 262 |
+
TASK_X2T_VIDEO = "x2t_video"
|
| 263 |
+
TASK_X2T_IMAGE = "x2t_image"
|
| 264 |
+
TASK_IMAGE_EDIT = "image_edit"
|
| 265 |
+
TASK_VIDEO_EDIT = "video_edit"
|
| 266 |
+
TASK_LABEL_VIDEO_GENERATION = "Video Generation"
|
| 267 |
+
TASK_LABEL_VIDEO_EDIT = "Video Edit"
|
| 268 |
+
TASK_LABEL_VIDEO_UNDERSTANDING = "Video Understanding"
|
| 269 |
+
TASK_LABEL_IMAGE_GENERATION = "Image Generation"
|
| 270 |
+
TASK_LABEL_IMAGE_EDIT = "Image Edit"
|
| 271 |
+
TASK_LABEL_IMAGE_UNDERSTANDING = "Image Understanding"
|
| 272 |
+
TASK_CHOICES = [
|
| 273 |
+
TASK_LABEL_VIDEO_GENERATION,
|
| 274 |
+
TASK_LABEL_VIDEO_EDIT,
|
| 275 |
+
TASK_LABEL_VIDEO_UNDERSTANDING,
|
| 276 |
+
TASK_LABEL_IMAGE_GENERATION,
|
| 277 |
+
TASK_LABEL_IMAGE_EDIT,
|
| 278 |
+
TASK_LABEL_IMAGE_UNDERSTANDING,
|
| 279 |
+
]
|
| 280 |
+
TASK_LABEL_TO_INTERNAL = {
|
| 281 |
+
TASK_LABEL_VIDEO_GENERATION: TASK_T2V,
|
| 282 |
+
TASK_LABEL_VIDEO_EDIT: TASK_VIDEO_EDIT,
|
| 283 |
+
TASK_LABEL_VIDEO_UNDERSTANDING: TASK_X2T_VIDEO,
|
| 284 |
+
TASK_LABEL_IMAGE_GENERATION: TASK_T2I,
|
| 285 |
+
TASK_LABEL_IMAGE_EDIT: TASK_IMAGE_EDIT,
|
| 286 |
+
TASK_LABEL_IMAGE_UNDERSTANDING: TASK_X2T_IMAGE,
|
| 287 |
+
TASK_T2V: TASK_T2V,
|
| 288 |
+
TASK_VIDEO_EDIT: TASK_VIDEO_EDIT,
|
| 289 |
+
TASK_V2T: TASK_X2T_VIDEO,
|
| 290 |
+
TASK_X2T: TASK_X2T_VIDEO,
|
| 291 |
+
TASK_X2T_VIDEO: TASK_X2T_VIDEO,
|
| 292 |
+
TASK_T2I: TASK_T2I,
|
| 293 |
+
TASK_IMAGE_EDIT: TASK_IMAGE_EDIT,
|
| 294 |
+
TASK_X2T_IMAGE: TASK_X2T_IMAGE,
|
| 295 |
+
}
|
| 296 |
+
GENERATION_TASKS = {TASK_T2V, TASK_T2I, TASK_IMAGE_EDIT, TASK_VIDEO_EDIT}
|
| 297 |
+
UNDERSTANDING_TASKS = {TASK_X2T_VIDEO, TASK_X2T_IMAGE}
|
| 298 |
+
IMAGE_TASKS = {TASK_T2I, TASK_IMAGE_EDIT, TASK_X2T_IMAGE}
|
| 299 |
+
VIDEO_TASKS = {TASK_T2V, TASK_VIDEO_EDIT, TASK_X2T_VIDEO}
|
| 300 |
+
EDIT_TASKS = {TASK_IMAGE_EDIT, TASK_VIDEO_EDIT}
|
| 301 |
+
VIDEO_RESOLUTION_CHOICES = [DEFAULT_RESOLUTION]
|
| 302 |
+
IMAGE_RESOLUTION_CHOICES = [DEFAULT_IMAGE_RESOLUTION]
|
| 303 |
+
RESOLUTION_CHOICES = VIDEO_RESOLUTION_CHOICES + IMAGE_RESOLUTION_CHOICES
|
| 304 |
+
CAPTION_SYSTEM_PROMPT_TEMPLATE = (
|
| 305 |
+
"Describe the key features of the input {vision_type}, including color, shape, size, texture, objects, background."
|
| 306 |
+
)
|
| 307 |
+
V2T_CAPTION_SYSTEM_PROMPT = CAPTION_SYSTEM_PROMPT_TEMPLATE.format(vision_type="video")
|
| 308 |
+
I2T_CAPTION_SYSTEM_PROMPT = CAPTION_SYSTEM_PROMPT_TEMPLATE.format(vision_type="image")
|
| 309 |
+
V2T_QA_SYSTEM_PROMPT = "View the video attentively and provide a suitable answer to the posed question."
|
| 310 |
+
I2T_QA_SYSTEM_PROMPT = "View the image attentively and provide a suitable answer to the posed question."
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def get_aspect_ratio_choices_for_task(task: str) -> list[tuple[str, str]]:
|
| 314 |
+
"""Get Aspect Ratio choices with default/recommended marker for the given task."""
|
| 315 |
+
internal_task = normalize_task(task)
|
| 316 |
+
default_ratio = DEFAULT_IMAGE_ASPECT_RATIO if internal_task in IMAGE_TASKS else DEFAULT_VIDEO_ASPECT_RATIO
|
| 317 |
+
return [
|
| 318 |
+
(f"{ratio} (default)" if ratio == default_ratio else ratio, ratio)
|
| 319 |
+
for ratio in ASPECT_RATIO_CHOICES
|
| 320 |
+
]
|
| 321 |
+
|
| 322 |
+
def env_flag(name: str, default: bool) -> bool:
|
| 323 |
+
value = os.getenv(name)
|
| 324 |
+
if value is None:
|
| 325 |
+
return default
|
| 326 |
+
return value.strip().lower() in {"1", "true", "yes", "on"}
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
def running_on_space() -> bool:
|
| 330 |
+
return bool(os.getenv("SPACE_ID") or os.getenv("SPACE_HOST"))
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def display_path(path: Path) -> str:
|
| 334 |
+
path_text = path.as_posix()
|
| 335 |
+
if path.is_absolute():
|
| 336 |
+
try:
|
| 337 |
+
path_text = path.relative_to(Path.cwd()).as_posix()
|
| 338 |
+
except ValueError:
|
| 339 |
+
return path_text
|
| 340 |
+
if path_text == "." or path_text.startswith("./"):
|
| 341 |
+
return path_text
|
| 342 |
+
return f"./{path_text}"
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def get_model_base_dir() -> Path:
|
| 346 |
+
configured = os.getenv("LANCE_MODEL_BASE_DIR")
|
| 347 |
+
if configured:
|
| 348 |
+
return Path(configured).expanduser()
|
| 349 |
+
if LOCAL_MODEL_BASE_DIR.exists():
|
| 350 |
+
return LOCAL_MODEL_BASE_DIR
|
| 351 |
+
return SPACE_MODEL_BASE_DIR if running_on_space() else LOCAL_MODEL_BASE_DIR
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def normalize_model_variant(model_variant: Optional[str] = None) -> str:
|
| 355 |
+
variant = (model_variant or os.getenv("LANCE_MODEL_VARIANT", DEFAULT_MODEL_VARIANT)).strip().lower()
|
| 356 |
+
if variant in {"image", "t2i", "i2t"}:
|
| 357 |
+
return MODEL_VARIANT_IMAGE
|
| 358 |
+
return MODEL_VARIANT_VIDEO
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def get_model_path(model_variant: Optional[str] = None) -> Path:
|
| 362 |
+
variant = normalize_model_variant(model_variant)
|
| 363 |
+
variant_env_name = "LANCE_IMAGE_MODEL_PATH" if variant == MODEL_VARIANT_IMAGE else "LANCE_VIDEO_MODEL_PATH"
|
| 364 |
+
variant_configured = os.getenv(variant_env_name)
|
| 365 |
+
if variant_configured:
|
| 366 |
+
return Path(variant_configured).expanduser()
|
| 367 |
+
|
| 368 |
+
configured = os.getenv("LANCE_MODEL_PATH")
|
| 369 |
+
if configured:
|
| 370 |
+
return Path(configured).expanduser()
|
| 371 |
+
|
| 372 |
+
model_dir_name = MODEL_VARIANT_TO_DIR[variant]
|
| 373 |
+
return get_model_base_dir() / model_dir_name
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def get_required_model_asset_paths(model_base_dir: Path, model_path: Path) -> list[Path]:
|
| 377 |
+
return [
|
| 378 |
+
model_path / "llm_config.json",
|
| 379 |
+
model_path / "model.safetensors",
|
| 380 |
+
model_base_dir / "Qwen2.5-VL-ViT" / "vit.safetensors",
|
| 381 |
+
model_base_dir / "Wan2.2_VAE.pth",
|
| 382 |
+
]
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def ensure_model_assets(model_variant: Optional[str] = None) -> Path:
|
| 386 |
+
model_base_dir = get_model_base_dir()
|
| 387 |
+
os.environ["LANCE_MODEL_BASE_DIR"] = display_path(model_base_dir)
|
| 388 |
+
model_path = get_model_path(model_variant)
|
| 389 |
+
|
| 390 |
+
required_paths = get_required_model_asset_paths(model_base_dir, model_path)
|
| 391 |
+
if all(path.exists() for path in required_paths):
|
| 392 |
+
return model_path
|
| 393 |
+
|
| 394 |
+
downloads_model_base_dir = Path("downloads")
|
| 395 |
+
if model_base_dir == Path(".") and downloads_model_base_dir.exists():
|
| 396 |
+
downloads_model_path = downloads_model_base_dir / MODEL_VARIANT_TO_DIR[normalize_model_variant(model_variant)]
|
| 397 |
+
downloads_required_paths = get_required_model_asset_paths(downloads_model_base_dir, downloads_model_path)
|
| 398 |
+
if all(path.exists() for path in downloads_required_paths):
|
| 399 |
+
model_base_dir = downloads_model_base_dir
|
| 400 |
+
model_path = downloads_model_path
|
| 401 |
+
required_paths = downloads_required_paths
|
| 402 |
+
os.environ["LANCE_MODEL_BASE_DIR"] = display_path(model_base_dir)
|
| 403 |
+
return model_path
|
| 404 |
+
|
| 405 |
+
auto_download = env_flag("LANCE_AUTO_DOWNLOAD", running_on_space())
|
| 406 |
+
if not auto_download:
|
| 407 |
+
missing = "\n".join(f"- {display_path(path)}" for path in required_paths if not path.exists())
|
| 408 |
+
raise FileNotFoundError(
|
| 409 |
+
"Lance model assets are missing. Set LANCE_MODEL_BASE_DIR or enable "
|
| 410 |
+
f"LANCE_AUTO_DOWNLOAD=1.\nMissing files:\n{missing}"
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
+
model_base_dir.mkdir(parents=True, exist_ok=True)
|
| 414 |
+
repo_id = os.getenv("LANCE_MODEL_REPO_ID", DEFAULT_MODEL_REPO_ID)
|
| 415 |
+
print(f"[startup] Downloading Lance model assets from {repo_id} to {display_path(model_base_dir)}", flush=True)
|
| 416 |
+
snapshot_path = Path(
|
| 417 |
+
snapshot_download(
|
| 418 |
+
repo_id=repo_id,
|
| 419 |
+
local_dir=str(model_base_dir),
|
| 420 |
+
local_dir_use_symlinks=False,
|
| 421 |
+
resume_download=True,
|
| 422 |
+
)
|
| 423 |
+
)
|
| 424 |
+
if snapshot_path != model_base_dir and not model_path.exists():
|
| 425 |
+
os.environ["LANCE_MODEL_BASE_DIR"] = display_path(snapshot_path)
|
| 426 |
+
model_path = get_model_path(model_variant)
|
| 427 |
+
return model_path
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
def ensure_dirs() -> None:
|
| 431 |
+
TMP_INPUT_DIR.mkdir(parents=True, exist_ok=True)
|
| 432 |
+
RESULTS_ROOT.mkdir(parents=True, exist_ok=True)
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
def save_generation_record(record: dict, save_dir: Path) -> None:
|
| 436 |
+
ensure_dirs()
|
| 437 |
+
run_record_path = save_dir / RUN_RECORD_FILENAME
|
| 438 |
+
with run_record_path.open("w", encoding="utf-8") as f:
|
| 439 |
+
json.dump(record, f, ensure_ascii=False, indent=2)
|
| 440 |
+
|
| 441 |
+
with RECORD_WRITE_LOCK:
|
| 442 |
+
with GLOBAL_RECORDS_FILE.open("a", encoding="utf-8") as f:
|
| 443 |
+
f.write(json.dumps(record, ensure_ascii=False) + "\n")
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
def normalize_seed(seed: int) -> int:
|
| 447 |
+
return random.randint(0, 2**31 - 1) if seed == -1 else seed
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
def normalize_task(task: str) -> str:
|
| 451 |
+
task_key = (task or TASK_LABEL_VIDEO_GENERATION).strip()
|
| 452 |
+
task = TASK_LABEL_TO_INTERNAL.get(task_key, TASK_LABEL_TO_INTERNAL.get(task_key.lower(), ""))
|
| 453 |
+
if task not in GENERATION_TASKS | UNDERSTANDING_TASKS:
|
| 454 |
+
raise ValueError(f"Unsupported task type: {task}")
|
| 455 |
+
return task
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
def normalize_resolution_for_backend(resolution: str, task: str) -> str:
|
| 459 |
+
internal_task = normalize_task(task)
|
| 460 |
+
if internal_task in IMAGE_TASKS:
|
| 461 |
+
return DEFAULT_IMAGE_RESOLUTION
|
| 462 |
+
if internal_task in VIDEO_TASKS:
|
| 463 |
+
return DEFAULT_RESOLUTION
|
| 464 |
+
return str(resolution)
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
def get_default_aspect_ratio(task: str) -> str:
|
| 468 |
+
internal_task = normalize_task(task)
|
| 469 |
+
return DEFAULT_IMAGE_ASPECT_RATIO if internal_task in IMAGE_TASKS else DEFAULT_VIDEO_ASPECT_RATIO
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
def get_size_for_aspect_ratio(task: str, aspect_ratio: str) -> tuple[int, int]:
|
| 473 |
+
internal_task = normalize_task(task)
|
| 474 |
+
aspect_ratio = aspect_ratio if aspect_ratio in ASPECT_RATIO_CHOICES else get_default_aspect_ratio(internal_task)
|
| 475 |
+
size_map = IMAGE_ASPECT_RATIO_TO_SIZE if internal_task in IMAGE_TASKS else VIDEO_ASPECT_RATIO_TO_SIZE
|
| 476 |
+
return size_map[aspect_ratio]
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
def format_size_markdown(task: str, width: int, height: int) -> str:
|
| 480 |
+
internal_task = normalize_task(task)
|
| 481 |
+
if internal_task in UNDERSTANDING_TASKS:
|
| 482 |
+
return ""
|
| 483 |
+
#return f"**Output Resolution:** `{width} x {height}`"
|
| 484 |
+
return f"{width} x {height}"
|
| 485 |
+
|
| 486 |
+
def update_size_from_aspect_ratio(task: str, aspect_ratio: str):
|
| 487 |
+
width, height = get_size_for_aspect_ratio(task, aspect_ratio)
|
| 488 |
+
return height, width, format_size_markdown(task, width, height)
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
def reset_generation_defaults_for_task(task: str):
|
| 492 |
+
internal_task = normalize_task(task)
|
| 493 |
+
aspect_ratio = get_default_aspect_ratio(internal_task)
|
| 494 |
+
width, height = get_size_for_aspect_ratio(internal_task, aspect_ratio)
|
| 495 |
+
resolution = DEFAULT_IMAGE_RESOLUTION if internal_task in IMAGE_TASKS else DEFAULT_RESOLUTION
|
| 496 |
+
num_frames = DEFAULT_NUM_FRAMES if internal_task == TASK_T2V else 1
|
| 497 |
+
return aspect_ratio, height, width, num_frames, resolution, format_size_markdown(internal_task, width, height)
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def apply_prompt_example(task: str, evt: gr.SelectData):
|
| 501 |
+
prompt_text = ""
|
| 502 |
+
if isinstance(evt.row_value, list) and evt.row_value:
|
| 503 |
+
prompt_text = str(evt.row_value[0])
|
| 504 |
+
elif evt.value is not None:
|
| 505 |
+
prompt_text = str(evt.value)
|
| 506 |
+
defaults = reset_generation_defaults_for_task(task)
|
| 507 |
+
return (prompt_text, *defaults)
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
def get_understanding_system_prompt_choices(task: str) -> list[str]:
|
| 511 |
+
internal_task = normalize_task(task)
|
| 512 |
+
if internal_task == TASK_X2T_IMAGE:
|
| 513 |
+
return [I2T_QA_SYSTEM_PROMPT]
|
| 514 |
+
return [V2T_QA_SYSTEM_PROMPT]
|
| 515 |
+
|
| 516 |
+
|
| 517 |
+
def normalize_understanding_system_prompt(task: str, system_prompt: Optional[str]) -> str:
|
| 518 |
+
return get_understanding_system_prompt_choices(task)[0]
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
def create_request_json(
|
| 522 |
+
task: str,
|
| 523 |
+
prompt: str,
|
| 524 |
+
input_video: Optional[str],
|
| 525 |
+
input_image: Optional[str],
|
| 526 |
+
system_prompt: Optional[str] = None,
|
| 527 |
+
) -> Path:
|
| 528 |
+
ensure_dirs()
|
| 529 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
| 530 |
+
prompt_file = TMP_INPUT_DIR / f"{task}_{timestamp}.json"
|
| 531 |
+
|
| 532 |
+
if task == TASK_T2V:
|
| 533 |
+
payload = {"000000.mp4": prompt}
|
| 534 |
+
elif task == TASK_T2I:
|
| 535 |
+
payload = {"000000.png": prompt}
|
| 536 |
+
elif task == TASK_VIDEO_EDIT:
|
| 537 |
+
if not input_video:
|
| 538 |
+
raise ValueError("The video edit task requires an input video.")
|
| 539 |
+
payload = {
|
| 540 |
+
"000000": {
|
| 541 |
+
"interleave_array": [prompt, input_video, input_video],
|
| 542 |
+
"element_dtype_array": ["text", "video", "video"],
|
| 543 |
+
"istarget_in_interleave": [0, 0, 1],
|
| 544 |
+
}
|
| 545 |
+
}
|
| 546 |
+
elif task == TASK_IMAGE_EDIT:
|
| 547 |
+
if not input_image:
|
| 548 |
+
raise ValueError("The image edit task requires an input image.")
|
| 549 |
+
payload = {
|
| 550 |
+
"000000": {
|
| 551 |
+
"interleave_array": [prompt, input_image, input_image],
|
| 552 |
+
"element_dtype_array": ["text", "image", "image"],
|
| 553 |
+
"istarget_in_interleave": [0, 0, 1],
|
| 554 |
+
}
|
| 555 |
+
}
|
| 556 |
+
elif task == TASK_X2T_VIDEO:
|
| 557 |
+
if not input_video:
|
| 558 |
+
raise ValueError("The video understanding task requires an input video.")
|
| 559 |
+
system_prompt = normalize_understanding_system_prompt(task, system_prompt)
|
| 560 |
+
payload = {
|
| 561 |
+
"000000": {
|
| 562 |
+
"interleave_array": [input_video, [system_prompt, prompt, ""]],
|
| 563 |
+
"element_dtype_array": ["video", "text"],
|
| 564 |
+
"istarget_in_interleave": [0, 1],
|
| 565 |
+
}
|
| 566 |
+
}
|
| 567 |
+
elif task == TASK_X2T_IMAGE:
|
| 568 |
+
if not input_image:
|
| 569 |
+
raise ValueError("The image understanding task requires an input image.")
|
| 570 |
+
system_prompt = normalize_understanding_system_prompt(task, system_prompt)
|
| 571 |
+
payload = {
|
| 572 |
+
"000000": {
|
| 573 |
+
"interleave_array": [input_image, [system_prompt, prompt, ""]],
|
| 574 |
+
"element_dtype_array": ["image", "text"],
|
| 575 |
+
"istarget_in_interleave": [0, 1],
|
| 576 |
+
}
|
| 577 |
+
}
|
| 578 |
+
else:
|
| 579 |
+
raise ValueError(f"Unsupported task type: {task}")
|
| 580 |
+
|
| 581 |
+
with prompt_file.open("w", encoding="utf-8") as f:
|
| 582 |
+
json.dump(payload, f, ensure_ascii=False, indent=2)
|
| 583 |
+
return prompt_file
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
def resolve_example_path(path: str) -> str:
|
| 587 |
+
candidate = Path(path)
|
| 588 |
+
if candidate.is_absolute():
|
| 589 |
+
return str(candidate)
|
| 590 |
+
repo_candidate = (REPO_ROOT / candidate)
|
| 591 |
+
if repo_candidate.exists():
|
| 592 |
+
return str(repo_candidate.resolve())
|
| 593 |
+
if candidate.exists():
|
| 594 |
+
return str(candidate.resolve())
|
| 595 |
+
return path
|
| 596 |
+
|
| 597 |
+
|
| 598 |
+
def resolve_browser_video_example_path(path: str) -> str:
|
| 599 |
+
candidate = Path(path)
|
| 600 |
+
compatible_candidate = candidate.with_name(f"{candidate.stem}_h264{candidate.suffix}")
|
| 601 |
+
repo_compatible_candidate = REPO_ROOT / compatible_candidate
|
| 602 |
+
if not compatible_candidate.is_absolute() and repo_compatible_candidate.exists():
|
| 603 |
+
return str(repo_compatible_candidate.resolve())
|
| 604 |
+
if compatible_candidate.is_absolute() and compatible_candidate.exists():
|
| 605 |
+
return str(compatible_candidate.resolve())
|
| 606 |
+
repo_candidate = REPO_ROOT / candidate
|
| 607 |
+
if not candidate.is_absolute() and repo_candidate.exists():
|
| 608 |
+
return str(repo_candidate.resolve())
|
| 609 |
+
if candidate.is_absolute() and candidate.exists():
|
| 610 |
+
return str(candidate.resolve())
|
| 611 |
+
return resolve_example_path(path)
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
def load_json_examples(relative_path: str) -> dict:
|
| 615 |
+
path = REPO_ROOT / relative_path
|
| 616 |
+
with path.open("r", encoding="utf-8") as f:
|
| 617 |
+
return json.load(f)
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
T2V_EXAMPLE_SUMMARIES = {
|
| 621 |
+
"000000.mp4": "Red panda surfing on a bright seaside wave.",
|
| 622 |
+
"000002.mp4": "Panda cub skateboarding in a creative loft.",
|
| 623 |
+
"000004.mp4": "Young woman shaping clay in a sunlit pottery workshop.",
|
| 624 |
+
"000005.mp4": "Panda boxing a robot in a luxurious palace ring.",
|
| 625 |
+
"000008.mp4": "Fantasy pastel horse stepping through a glowing cloud valley.",
|
| 626 |
+
}
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
def make_generation_examples(
|
| 630 |
+
task_label: str,
|
| 631 |
+
relative_path: str,
|
| 632 |
+
limit: int,
|
| 633 |
+
image_task: bool,
|
| 634 |
+
selected_keys: Optional[list[str]] = None,
|
| 635 |
+
summaries: Optional[dict[str, str]] = None,
|
| 636 |
+
) -> list[list]:
|
| 637 |
+
data = load_json_examples(relative_path)
|
| 638 |
+
items = [(key, data[key]) for key in selected_keys if key in data] if selected_keys else list(data.items())[:limit]
|
| 639 |
+
examples = []
|
| 640 |
+
for output_name, prompt in items:
|
| 641 |
+
examples.append([prompt])
|
| 642 |
+
return examples
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
def make_edit_examples(task_label: str, relative_path: str, limit: int, media_type: str) -> list[list]:
|
| 646 |
+
data = load_json_examples(relative_path)
|
| 647 |
+
examples = []
|
| 648 |
+
for sample in list(data.values())[:limit]:
|
| 649 |
+
interleave = sample["interleave_array"]
|
| 650 |
+
prompt = interleave[0]
|
| 651 |
+
media_path = resolve_example_path(interleave[1])
|
| 652 |
+
examples.append([
|
| 653 |
+
prompt,
|
| 654 |
+
media_path if media_type == "video" else None,
|
| 655 |
+
media_path if media_type == "image" else None,
|
| 656 |
+
])
|
| 657 |
+
return examples
|
| 658 |
+
|
| 659 |
+
|
| 660 |
+
def make_understanding_examples(task_label: str, relative_path: str, limit: int, media_type: str) -> list[list]:
|
| 661 |
+
data = load_json_examples(relative_path)
|
| 662 |
+
examples = []
|
| 663 |
+
for sample in list(data.values())[:limit]:
|
| 664 |
+
interleave = sample["interleave_array"]
|
| 665 |
+
media_path = (
|
| 666 |
+
resolve_browser_video_example_path(interleave[0])
|
| 667 |
+
if media_type == "video"
|
| 668 |
+
else resolve_example_path(interleave[0])
|
| 669 |
+
)
|
| 670 |
+
text_payload = interleave[1]
|
| 671 |
+
question = text_payload[1] if isinstance(text_payload, list) and len(text_payload) > 1 else ""
|
| 672 |
+
examples.append([
|
| 673 |
+
question,
|
| 674 |
+
media_path if media_type == "video" else None,
|
| 675 |
+
media_path if media_type == "image" else None,
|
| 676 |
+
])
|
| 677 |
+
return examples
|
| 678 |
+
|
| 679 |
+
|
| 680 |
+
def make_understanding_system_prompt_map(relative_path: str, task: str) -> dict[str, str]:
|
| 681 |
+
data = load_json_examples(relative_path)
|
| 682 |
+
system_prompts = {}
|
| 683 |
+
for sample in data.values():
|
| 684 |
+
interleave = sample["interleave_array"]
|
| 685 |
+
text_payload = interleave[1]
|
| 686 |
+
if not isinstance(text_payload, list) or len(text_payload) < 2:
|
| 687 |
+
continue
|
| 688 |
+
system_prompts[text_payload[1]] = normalize_understanding_system_prompt(task, text_payload[0])
|
| 689 |
+
return system_prompts
|
| 690 |
+
|
| 691 |
+
|
| 692 |
+
VIDEO_GENERATION_EXAMPLES = make_generation_examples(
|
| 693 |
+
TASK_LABEL_VIDEO_GENERATION,
|
| 694 |
+
"config/examples/t2v_example.json",
|
| 695 |
+
limit=6,
|
| 696 |
+
image_task=False,
|
| 697 |
+
#selected_keys=["000000.mp4", "000002.mp4", "000005.mp4", "000004.mp4", "000008.mp4"],
|
| 698 |
+
selected_keys=["000004.mp4", "000005.mp4", "000002.mp4", "000000.mp4", "000008.mp4", "000007.mp4"],
|
| 699 |
+
summaries=T2V_EXAMPLE_SUMMARIES,
|
| 700 |
+
)
|
| 701 |
+
VIDEO_EDIT_EXAMPLES = make_edit_examples(
|
| 702 |
+
TASK_LABEL_VIDEO_EDIT,
|
| 703 |
+
"config/examples/video_edit_example.json",
|
| 704 |
+
limit=3,
|
| 705 |
+
media_type="video",
|
| 706 |
+
)
|
| 707 |
+
VIDEO_UNDERSTANDING_EXAMPLES = make_understanding_examples(
|
| 708 |
+
TASK_LABEL_VIDEO_UNDERSTANDING,
|
| 709 |
+
"config/examples/x2t_video_example.json",
|
| 710 |
+
limit=3,
|
| 711 |
+
media_type="video",
|
| 712 |
+
)
|
| 713 |
+
VIDEO_UNDERSTANDING_SYSTEM_PROMPTS = make_understanding_system_prompt_map(
|
| 714 |
+
"config/examples/x2t_video_example.json",
|
| 715 |
+
TASK_X2T_VIDEO,
|
| 716 |
+
)
|
| 717 |
+
IMAGE_GENERATION_EXAMPLES = make_generation_examples(
|
| 718 |
+
TASK_LABEL_IMAGE_GENERATION,
|
| 719 |
+
"config/examples/t2i_example.json",
|
| 720 |
+
limit=5,
|
| 721 |
+
image_task=True,
|
| 722 |
+
selected_keys=["000000.png", "000003.png", "000006.png", "000008.png", "000009.png"],
|
| 723 |
+
)
|
| 724 |
+
IMAGE_EDIT_EXAMPLES = make_edit_examples(
|
| 725 |
+
TASK_LABEL_IMAGE_EDIT,
|
| 726 |
+
"config/examples/image_edit_example.json",
|
| 727 |
+
limit=5,
|
| 728 |
+
media_type="image",
|
| 729 |
+
)
|
| 730 |
+
IMAGE_UNDERSTANDING_EXAMPLES = make_understanding_examples(
|
| 731 |
+
TASK_LABEL_IMAGE_UNDERSTANDING,
|
| 732 |
+
"config/examples/x2t_image_example.json",
|
| 733 |
+
limit=3,
|
| 734 |
+
media_type="image",
|
| 735 |
+
)
|
| 736 |
+
IMAGE_UNDERSTANDING_SYSTEM_PROMPTS = make_understanding_system_prompt_map(
|
| 737 |
+
"config/examples/x2t_image_example.json",
|
| 738 |
+
TASK_X2T_IMAGE,
|
| 739 |
+
)
|
| 740 |
+
|
| 741 |
+
|
| 742 |
+
def build_save_dir(task: str) -> Path:
|
| 743 |
+
ensure_dirs()
|
| 744 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 745 |
+
return RESULTS_ROOT / f"{task}_{timestamp}_{int(time.time() * 1000) % 1000:03d}"
|
| 746 |
+
|
| 747 |
+
|
| 748 |
+
def find_generated_video(save_dir: Path) -> Optional[Path]:
|
| 749 |
+
videos = sorted(save_dir.glob("*.mp4"), key=lambda p: p.stat().st_mtime, reverse=True)
|
| 750 |
+
return videos[0] if videos else None
|
| 751 |
+
|
| 752 |
+
|
| 753 |
+
def find_generated_image(save_dir: Path) -> Optional[Path]:
|
| 754 |
+
images = sorted(save_dir.glob("*.png"), key=lambda p: p.stat().st_mtime, reverse=True)
|
| 755 |
+
return images[0] if images else None
|
| 756 |
+
|
| 757 |
+
|
| 758 |
+
def run_rife_interpolation(video_path: Path, device_id: int, exp: int = 1) -> tuple[Path, str]:
|
| 759 |
+
rife_dir = REPO_ROOT / "RIFE"
|
| 760 |
+
rife_script = rife_dir / "inference_video.py"
|
| 761 |
+
if not rife_script.exists():
|
| 762 |
+
raise FileNotFoundError(f"RIFE inference script not found: {rife_script}")
|
| 763 |
+
|
| 764 |
+
output_path = video_path.with_name(f"{video_path.stem}_rife_{2 ** exp}x{video_path.suffix}")
|
| 765 |
+
env = os.environ.copy()
|
| 766 |
+
env["CUDA_VISIBLE_DEVICES"] = str(device_id)
|
| 767 |
+
command = [
|
| 768 |
+
"python3",
|
| 769 |
+
str(rife_script),
|
| 770 |
+
"--exp",
|
| 771 |
+
str(exp),
|
| 772 |
+
"--video",
|
| 773 |
+
str(video_path),
|
| 774 |
+
"--output",
|
| 775 |
+
str(output_path),
|
| 776 |
+
"--model",
|
| 777 |
+
str(rife_dir / "train_log"),
|
| 778 |
+
]
|
| 779 |
+
rife_start = time.perf_counter()
|
| 780 |
+
try:
|
| 781 |
+
completed = subprocess.run(
|
| 782 |
+
command,
|
| 783 |
+
cwd=str(video_path.parent),
|
| 784 |
+
env=env,
|
| 785 |
+
check=True,
|
| 786 |
+
capture_output=True,
|
| 787 |
+
text=True,
|
| 788 |
+
)
|
| 789 |
+
except subprocess.CalledProcessError as exc:
|
| 790 |
+
raise RuntimeError(
|
| 791 |
+
"\n".join(
|
| 792 |
+
[
|
| 793 |
+
f"RIFE failed with exit code {exc.returncode}.",
|
| 794 |
+
f"command=CUDA_VISIBLE_DEVICES={device_id} {' '.join(command)}",
|
| 795 |
+
exc.stdout.strip() if exc.stdout else "",
|
| 796 |
+
exc.stderr.strip() if exc.stderr else "",
|
| 797 |
+
]
|
| 798 |
+
).strip()
|
| 799 |
+
) from exc
|
| 800 |
+
if not output_path.exists():
|
| 801 |
+
raise FileNotFoundError(f"RIFE completed but output video was not found: {output_path}")
|
| 802 |
+
elapsed = time.perf_counter() - rife_start
|
| 803 |
+
log = "\n".join(
|
| 804 |
+
[
|
| 805 |
+
"[rife] Frame interpolation finished.",
|
| 806 |
+
f"command=CUDA_VISIBLE_DEVICES={device_id} {' '.join(command)}",
|
| 807 |
+
f"elapsed={elapsed:.2f}s",
|
| 808 |
+
f"output={output_path}",
|
| 809 |
+
completed.stdout.strip(),
|
| 810 |
+
completed.stderr.strip(),
|
| 811 |
+
]
|
| 812 |
+
).strip()
|
| 813 |
+
return output_path, log
|
| 814 |
+
|
| 815 |
+
|
| 816 |
+
def extract_text_result(save_dir: Path) -> str:
|
| 817 |
+
prompt_result_path = save_dir / PROMPT_JSON_FILENAME
|
| 818 |
+
if not prompt_result_path.exists():
|
| 819 |
+
return ""
|
| 820 |
+
with prompt_result_path.open("r", encoding="utf-8") as f:
|
| 821 |
+
data = json.load(f)
|
| 822 |
+
if not data:
|
| 823 |
+
return ""
|
| 824 |
+
first_value = next(iter(data.values()))
|
| 825 |
+
return first_value if isinstance(first_value, str) else json.dumps(first_value, ensure_ascii=False)
|
| 826 |
+
|
| 827 |
+
|
| 828 |
+
class LanceT2VV2TPipeline:
|
| 829 |
+
def __init__(self, device_id: int, model_variant: str = MODEL_VARIANT_VIDEO) -> None:
|
| 830 |
+
self._init_lock = threading.Lock()
|
| 831 |
+
self._generate_lock = threading.Lock()
|
| 832 |
+
self.initialized = False
|
| 833 |
+
self.device = device_id
|
| 834 |
+
self.model_variant = normalize_model_variant(model_variant)
|
| 835 |
+
self.logger = get_logger(f"lance_{self.model_variant}_gpu{device_id}")
|
| 836 |
+
|
| 837 |
+
self.model: Optional[Lance] = None
|
| 838 |
+
self.vae_model: Optional[WanVideoVAE] = None
|
| 839 |
+
self.vae_config: Optional[AutoEncoderParams] = None
|
| 840 |
+
self.tokenizer: Optional[Qwen2Tokenizer] = None
|
| 841 |
+
self.new_token_ids: Optional[dict] = None
|
| 842 |
+
self.image_token_id: Optional[int] = None
|
| 843 |
+
self.base_model_args: Optional[ModelArguments] = None
|
| 844 |
+
self.base_data_args: Optional[DataArguments] = None
|
| 845 |
+
self.base_inference_args: Optional[InferenceArguments] = None
|
| 846 |
+
|
| 847 |
+
def _log_stage(self, stage_name: str, start_time: float, extra: str = "") -> None:
|
| 848 |
+
elapsed = time.perf_counter() - start_time
|
| 849 |
+
suffix = f" | {extra}" if extra else ""
|
| 850 |
+
print(f"[startup][gpu:{self.device}] {stage_name} done in {elapsed:.2f}s{suffix}", flush=True)
|
| 851 |
+
|
| 852 |
+
def _build_base_model_args(self) -> ModelArguments:
|
| 853 |
+
model_path = str(get_model_path(self.model_variant))
|
| 854 |
+
return ModelArguments(
|
| 855 |
+
model_path=model_path,
|
| 856 |
+
vit_type=DEFAULT_VIT_TYPE,
|
| 857 |
+
llm_qk_norm=True,
|
| 858 |
+
llm_qk_norm_und=True,
|
| 859 |
+
llm_qk_norm_gen=True,
|
| 860 |
+
tie_word_embeddings=False,
|
| 861 |
+
max_num_frames=121,
|
| 862 |
+
max_latent_size=64,
|
| 863 |
+
latent_patch_size=[1, 1, 1],
|
| 864 |
+
)
|
| 865 |
+
|
| 866 |
+
def _build_base_inference_args(self) -> InferenceArguments:
|
| 867 |
+
return InferenceArguments(
|
| 868 |
+
validation_num_timesteps=DEFAULT_TIMESTEPS,
|
| 869 |
+
validation_timestep_shift=DEFAULT_TIMESTEP_SHIFT,
|
| 870 |
+
copy_init_moe=True,
|
| 871 |
+
visual_und=True,
|
| 872 |
+
visual_gen=True,
|
| 873 |
+
vae_model_type="wan",
|
| 874 |
+
apply_qwen_2_5_vl_pos_emb=True,
|
| 875 |
+
apply_chat_template=False,
|
| 876 |
+
cfg_type=0,
|
| 877 |
+
validation_data_seed=42,
|
| 878 |
+
video_height=DEFAULT_HEIGHT,
|
| 879 |
+
video_width=DEFAULT_WIDTH,
|
| 880 |
+
num_frames=DEFAULT_NUM_FRAMES,
|
| 881 |
+
task=DEFAULT_TASK,
|
| 882 |
+
save_path_gen=str(RESULTS_ROOT),
|
| 883 |
+
resolution=DEFAULT_RESOLUTION,
|
| 884 |
+
text_template=TEXT_TEMPLATE,
|
| 885 |
+
use_KVcache=USE_KVCACHE,
|
| 886 |
+
)
|
| 887 |
+
|
| 888 |
+
def initialize(self) -> None:
|
| 889 |
+
with self._init_lock:
|
| 890 |
+
if self.initialized:
|
| 891 |
+
return
|
| 892 |
+
|
| 893 |
+
ensure_dirs()
|
| 894 |
+
resolved_model_path = ensure_model_assets(self.model_variant)
|
| 895 |
+
print(
|
| 896 |
+
f"[startup][gpu:{self.device}][{self.model_variant}] Using Lance model path: {resolved_model_path}",
|
| 897 |
+
flush=True,
|
| 898 |
+
)
|
| 899 |
+
if not torch.cuda.is_available():
|
| 900 |
+
raise RuntimeError("CUDA is unavailable. Lance T2V/V2T Gradio requires a GPU environment.")
|
| 901 |
+
if self.device >= torch.cuda.device_count():
|
| 902 |
+
raise RuntimeError(
|
| 903 |
+
f"GPU {self.device} is unavailable. Detected {torch.cuda.device_count()} GPU(s)."
|
| 904 |
+
)
|
| 905 |
+
torch.cuda.set_device(self.device)
|
| 906 |
+
|
| 907 |
+
model_args = self._build_base_model_args()
|
| 908 |
+
data_args = DataArguments()
|
| 909 |
+
inference_args = self._build_base_inference_args()
|
| 910 |
+
apply_inference_defaults(model_args, data_args, inference_args)
|
| 911 |
+
inference_args.validation_noise_seed = inference_args.validation_data_seed
|
| 912 |
+
|
| 913 |
+
self.base_model_args = model_args
|
| 914 |
+
self.base_data_args = data_args
|
| 915 |
+
self.base_inference_args = inference_args
|
| 916 |
+
|
| 917 |
+
set_seed(inference_args.global_seed)
|
| 918 |
+
|
| 919 |
+
stage_start = time.perf_counter()
|
| 920 |
+
print(
|
| 921 |
+
f"[startup][gpu:{self.device}] Loading LLM config: {Path(model_args.model_path) / 'llm_config.json'}",
|
| 922 |
+
flush=True,
|
| 923 |
+
)
|
| 924 |
+
llm_config: Qwen2Config = Qwen2Config.from_json_file(str(Path(model_args.model_path) / "llm_config.json"))
|
| 925 |
+
self._log_stage("LLM config load", stage_start)
|
| 926 |
+
|
| 927 |
+
llm_config.layer_module = model_args.layer_module
|
| 928 |
+
llm_config.qk_norm = model_args.llm_qk_norm
|
| 929 |
+
llm_config.qk_norm_und = model_args.llm_qk_norm_und
|
| 930 |
+
llm_config.qk_norm_gen = model_args.llm_qk_norm_gen
|
| 931 |
+
llm_config.tie_word_embeddings = model_args.tie_word_embeddings
|
| 932 |
+
llm_config.freeze_und = inference_args.freeze_und
|
| 933 |
+
llm_config.apply_qwen_2_5_vl_pos_emb = inference_args.apply_qwen_2_5_vl_pos_emb
|
| 934 |
+
|
| 935 |
+
stage_start = time.perf_counter()
|
| 936 |
+
print(f"[startup][gpu:{self.device}] Initializing LLM weights: {model_args.model_path}", flush=True)
|
| 937 |
+
language_model: Qwen2ForCausalLM = Qwen2ForCausalLM(llm_config)
|
| 938 |
+
self._log_stage("LLM weight init", stage_start)
|
| 939 |
+
|
| 940 |
+
vit_model = None
|
| 941 |
+
vit_config = None
|
| 942 |
+
if inference_args.visual_und:
|
| 943 |
+
if model_args.vit_type not in ("qwen2_5_vl", "qwen_2_5_vl_original"):
|
| 944 |
+
raise ValueError(f"Unsupported vit_type: {model_args.vit_type}")
|
| 945 |
+
stage_start = time.perf_counter()
|
| 946 |
+
print(f"[startup][gpu:{self.device}] Loading VIT config: {model_args.vit_path}", flush=True)
|
| 947 |
+
vit_config = Qwen2_5_VLVisionConfig.from_pretrained(model_args.vit_path)
|
| 948 |
+
self._log_stage("VIT config load", stage_start)
|
| 949 |
+
|
| 950 |
+
stage_start = time.perf_counter()
|
| 951 |
+
print(
|
| 952 |
+
f"[startup][gpu:{self.device}] Loading VIT weights: {Path(model_args.vit_path) / 'vit.safetensors'}",
|
| 953 |
+
flush=True,
|
| 954 |
+
)
|
| 955 |
+
vit_model = Qwen2_5_VisionTransformerPretrainedModel(vit_config)
|
| 956 |
+
vit_weights = load_file(str(Path(model_args.vit_path) / "vit.safetensors"))
|
| 957 |
+
vit_model.load_state_dict(vit_weights, strict=True)
|
| 958 |
+
self._log_stage("VIT weight load", stage_start)
|
| 959 |
+
clean_memory(vit_weights)
|
| 960 |
+
|
| 961 |
+
if inference_args.visual_gen:
|
| 962 |
+
stage_start = time.perf_counter()
|
| 963 |
+
print(f"[startup][gpu:{self.device}] Initializing VAE", flush=True)
|
| 964 |
+
vae_model = WanVideoVAE()
|
| 965 |
+
vae_config = deepcopy(vae_model.vae_config)
|
| 966 |
+
self._log_stage("VAE init", stage_start)
|
| 967 |
+
else:
|
| 968 |
+
vae_model = None
|
| 969 |
+
vae_config = None
|
| 970 |
+
|
| 971 |
+
config = LanceConfig(
|
| 972 |
+
visual_gen=inference_args.visual_gen,
|
| 973 |
+
visual_und=inference_args.visual_und,
|
| 974 |
+
llm_config=llm_config,
|
| 975 |
+
vit_config=vit_config if inference_args.visual_und else None,
|
| 976 |
+
vae_config=vae_config if inference_args.visual_gen else None,
|
| 977 |
+
latent_patch_size=model_args.latent_patch_size,
|
| 978 |
+
max_num_frames=model_args.max_num_frames,
|
| 979 |
+
max_latent_size=model_args.max_latent_size,
|
| 980 |
+
vit_max_num_patch_per_side=model_args.vit_max_num_patch_per_side,
|
| 981 |
+
connector_act=model_args.connector_act,
|
| 982 |
+
interpolate_pos=model_args.interpolate_pos,
|
| 983 |
+
timestep_shift=inference_args.timestep_shift,
|
| 984 |
+
)
|
| 985 |
+
model: Lance = Lance(
|
| 986 |
+
language_model=language_model,
|
| 987 |
+
vit_model=vit_model if inference_args.visual_und else None,
|
| 988 |
+
vit_type=model_args.vit_type,
|
| 989 |
+
config=config,
|
| 990 |
+
training_args=inference_args,
|
| 991 |
+
)
|
| 992 |
+
|
| 993 |
+
stage_start = time.perf_counter()
|
| 994 |
+
print(f"[startup][gpu:{self.device}] Moving Lance model to GPU {self.device}", flush=True)
|
| 995 |
+
model = model.to(self.device)
|
| 996 |
+
self._log_stage("Lance model move to GPU", stage_start)
|
| 997 |
+
|
| 998 |
+
stage_start = time.perf_counter()
|
| 999 |
+
print(f"[startup][gpu:{self.device}] Loading tokenizer: {model_args.model_path}", flush=True)
|
| 1000 |
+
tokenizer: Qwen2Tokenizer = Qwen2Tokenizer.from_pretrained(model_args.model_path)
|
| 1001 |
+
tokenizer, new_token_ids, num_new_tokens = add_special_tokens(tokenizer)
|
| 1002 |
+
self._log_stage("tokenizer load and special token init", stage_start, extra=f"num_new_tokens={num_new_tokens}")
|
| 1003 |
+
|
| 1004 |
+
if inference_args.copy_init_moe:
|
| 1005 |
+
language_model.init_moe()
|
| 1006 |
+
|
| 1007 |
+
init_from_model_path_if_needed(model, model_args)
|
| 1008 |
+
|
| 1009 |
+
if num_new_tokens > 0:
|
| 1010 |
+
model.language_model.resize_token_embeddings(len(tokenizer))
|
| 1011 |
+
model.config.llm_config.vocab_size = len(tokenizer)
|
| 1012 |
+
model.language_model.config.vocab_size = len(tokenizer)
|
| 1013 |
+
|
| 1014 |
+
if model_args.vit_type.lower() == "qwen2_5_vl":
|
| 1015 |
+
from common.model.hacks import hack_qwen2_5_vl_config
|
| 1016 |
+
|
| 1017 |
+
language_model = hack_qwen2_5_vl_config(language_model)
|
| 1018 |
+
|
| 1019 |
+
image_token_id = language_model.config.video_token_id
|
| 1020 |
+
new_token_ids.update({"image_token_id": image_token_id})
|
| 1021 |
+
model.update_tokenizer(tokenizer=tokenizer)
|
| 1022 |
+
|
| 1023 |
+
if model_args.tie_word_embeddings:
|
| 1024 |
+
model.language_model.untie_lm_head()
|
| 1025 |
+
model.language_model.copy_new_token_rows_to_lm_head(num_new_tokens)
|
| 1026 |
+
model_args.tie_word_embeddings = False
|
| 1027 |
+
llm_config.tie_word_embeddings = False
|
| 1028 |
+
else:
|
| 1029 |
+
assert (
|
| 1030 |
+
model.language_model.get_input_embeddings().weight.data.data_ptr()
|
| 1031 |
+
!= model.language_model.get_output_embeddings().weight.data.data_ptr()
|
| 1032 |
+
), "tie_word_embeddings conflict"
|
| 1033 |
+
|
| 1034 |
+
model = model.to(device=self.device, dtype=torch.bfloat16)
|
| 1035 |
+
model.eval()
|
| 1036 |
+
if vae_model is not None and hasattr(vae_model, "eval"):
|
| 1037 |
+
vae_model.eval()
|
| 1038 |
+
|
| 1039 |
+
self.model = model
|
| 1040 |
+
self.vae_model = vae_model
|
| 1041 |
+
self.vae_config = vae_config
|
| 1042 |
+
self.tokenizer = tokenizer
|
| 1043 |
+
self.new_token_ids = new_token_ids
|
| 1044 |
+
self.image_token_id = image_token_id
|
| 1045 |
+
self.initialized = True
|
| 1046 |
+
print(
|
| 1047 |
+
f"[startup][gpu:{self.device}][{self.model_variant}] Lance multimodal Gradio model loaded and ready for reuse.",
|
| 1048 |
+
flush=True,
|
| 1049 |
+
)
|
| 1050 |
+
|
| 1051 |
+
def unload(self) -> None:
|
| 1052 |
+
with self._init_lock:
|
| 1053 |
+
if self.model is not None:
|
| 1054 |
+
self.model.cpu()
|
| 1055 |
+
if self.vae_model is not None and hasattr(self.vae_model, "vae"):
|
| 1056 |
+
vae_inner = self.vae_model.vae
|
| 1057 |
+
if hasattr(vae_inner, "model"):
|
| 1058 |
+
vae_inner.model.cpu()
|
| 1059 |
+
|
| 1060 |
+
self.model = None
|
| 1061 |
+
self.vae_model = None
|
| 1062 |
+
self.vae_config = None
|
| 1063 |
+
self.tokenizer = None
|
| 1064 |
+
self.new_token_ids = None
|
| 1065 |
+
self.image_token_id = None
|
| 1066 |
+
self.base_model_args = None
|
| 1067 |
+
self.base_data_args = None
|
| 1068 |
+
self.base_inference_args = None
|
| 1069 |
+
self.initialized = False
|
| 1070 |
+
gc.collect()
|
| 1071 |
+
if torch.cuda.is_available():
|
| 1072 |
+
with torch.cuda.device(self.device):
|
| 1073 |
+
torch.cuda.empty_cache()
|
| 1074 |
+
torch.cuda.ipc_collect()
|
| 1075 |
+
|
| 1076 |
+
def _build_request_batch(
|
| 1077 |
+
self,
|
| 1078 |
+
prompt_file: Path,
|
| 1079 |
+
model_args: ModelArguments,
|
| 1080 |
+
data_args: DataArguments,
|
| 1081 |
+
inference_args: InferenceArguments,
|
| 1082 |
+
):
|
| 1083 |
+
assert self.tokenizer is not None
|
| 1084 |
+
assert self.new_token_ids is not None
|
| 1085 |
+
assert self.vae_config is not None
|
| 1086 |
+
|
| 1087 |
+
dataset_config = DataConfig.from_yaml(str(prompt_file))
|
| 1088 |
+
if inference_args.visual_und:
|
| 1089 |
+
dataset_config.vit_patch_size = model_args.vit_patch_size
|
| 1090 |
+
dataset_config.vit_patch_size_temporal = model_args.vit_patch_size_temporal
|
| 1091 |
+
dataset_config.vit_max_num_patch_per_side = model_args.vit_max_num_patch_per_side
|
| 1092 |
+
if inference_args.visual_gen:
|
| 1093 |
+
vae_downsample = tuple_mul(
|
| 1094 |
+
tuple(model_args.latent_patch_size),
|
| 1095 |
+
(
|
| 1096 |
+
self.vae_config.downsample_temporal,
|
| 1097 |
+
self.vae_config.downsample_spatial,
|
| 1098 |
+
self.vae_config.downsample_spatial,
|
| 1099 |
+
),
|
| 1100 |
+
)
|
| 1101 |
+
dataset_config.latent_patch_size = model_args.latent_patch_size
|
| 1102 |
+
dataset_config.vae_downsample = vae_downsample
|
| 1103 |
+
dataset_config.max_latent_size = model_args.max_latent_size
|
| 1104 |
+
dataset_config.max_num_frames = model_args.max_num_frames
|
| 1105 |
+
|
| 1106 |
+
dataset_config.text_cond_dropout_prob = model_args.text_cond_dropout_prob
|
| 1107 |
+
dataset_config.vae_cond_dropout_prob = model_args.vae_cond_dropout_prob
|
| 1108 |
+
dataset_config.vit_cond_dropout_prob = model_args.vit_cond_dropout_prob
|
| 1109 |
+
|
| 1110 |
+
dataset_config.num_frames = inference_args.num_frames
|
| 1111 |
+
dataset_config.H = inference_args.video_height
|
| 1112 |
+
dataset_config.W = inference_args.video_width
|
| 1113 |
+
dataset_config.task = inference_args.task
|
| 1114 |
+
dataset_config.resolution = inference_args.resolution
|
| 1115 |
+
dataset_config.text_template = inference_args.text_template
|
| 1116 |
+
|
| 1117 |
+
val_dataset = ValidationDataset(
|
| 1118 |
+
jsonl_path=str(prompt_file),
|
| 1119 |
+
tokenizer=self.tokenizer,
|
| 1120 |
+
data_args=data_args,
|
| 1121 |
+
model_args=model_args,
|
| 1122 |
+
training_args=inference_args,
|
| 1123 |
+
new_token_ids=self.new_token_ids,
|
| 1124 |
+
dataset_config=dataset_config,
|
| 1125 |
+
local_rank=0,
|
| 1126 |
+
world_size=1,
|
| 1127 |
+
)
|
| 1128 |
+
return simple_custom_collate([val_dataset[0]])
|
| 1129 |
+
|
| 1130 |
+
def generate(
|
| 1131 |
+
self,
|
| 1132 |
+
task: str,
|
| 1133 |
+
prompt: str,
|
| 1134 |
+
system_prompt: Optional[str],
|
| 1135 |
+
input_video: Optional[str],
|
| 1136 |
+
input_image: Optional[str],
|
| 1137 |
+
height: int,
|
| 1138 |
+
width: int,
|
| 1139 |
+
num_frames: int,
|
| 1140 |
+
seed: int,
|
| 1141 |
+
resolution: str,
|
| 1142 |
+
validation_num_timesteps: int,
|
| 1143 |
+
validation_timestep_shift: float,
|
| 1144 |
+
cfg_text_scale: float,
|
| 1145 |
+
enable_frame_interpolation: bool,
|
| 1146 |
+
):
|
| 1147 |
+
self.initialize()
|
| 1148 |
+
internal_task = normalize_task(task)
|
| 1149 |
+
prompt = (prompt or "").strip()
|
| 1150 |
+
input_video = str(input_video).strip() if input_video else ""
|
| 1151 |
+
input_image = str(input_image).strip() if input_image else ""
|
| 1152 |
+
|
| 1153 |
+
if internal_task in GENERATION_TASKS and not prompt:
|
| 1154 |
+
return None, None, "", "Please enter a prompt.", ""
|
| 1155 |
+
if internal_task in UNDERSTANDING_TASKS and not prompt:
|
| 1156 |
+
return None, None, "", "Please enter a question.", ""
|
| 1157 |
+
if internal_task in {TASK_VIDEO_EDIT, TASK_X2T_VIDEO} and not input_video:
|
| 1158 |
+
return None, None, "", "Please upload an input video.", ""
|
| 1159 |
+
if internal_task in {TASK_IMAGE_EDIT, TASK_X2T_IMAGE} and not input_image:
|
| 1160 |
+
return None, None, "", "Please upload an input image.", ""
|
| 1161 |
+
if height <= 0 or width <= 0:
|
| 1162 |
+
return None, None, "", "Height and width must be greater than 0.", ""
|
| 1163 |
+
if num_frames <= 0:
|
| 1164 |
+
return None, None, "", "The number of frames must be greater than 0.", ""
|
| 1165 |
+
|
| 1166 |
+
assert self.model is not None
|
| 1167 |
+
assert self.tokenizer is not None
|
| 1168 |
+
assert self.new_token_ids is not None
|
| 1169 |
+
assert self.image_token_id is not None
|
| 1170 |
+
assert self.base_model_args is not None
|
| 1171 |
+
assert self.base_data_args is not None
|
| 1172 |
+
assert self.base_inference_args is not None
|
| 1173 |
+
active_model_path = self.base_model_args.model_path
|
| 1174 |
+
|
| 1175 |
+
with self._generate_lock:
|
| 1176 |
+
torch.cuda.set_device(self.device)
|
| 1177 |
+
actual_seed = normalize_seed(int(seed))
|
| 1178 |
+
prompt_file = create_request_json(
|
| 1179 |
+
task=internal_task,
|
| 1180 |
+
prompt=prompt,
|
| 1181 |
+
input_video=input_video,
|
| 1182 |
+
input_image=input_image,
|
| 1183 |
+
system_prompt=system_prompt,
|
| 1184 |
+
)
|
| 1185 |
+
save_dir = build_save_dir(internal_task)
|
| 1186 |
+
save_dir.mkdir(parents=True, exist_ok=True)
|
| 1187 |
+
request_started_at = datetime.now().isoformat(timespec="seconds")
|
| 1188 |
+
|
| 1189 |
+
request_model_args = deepcopy(self.base_model_args)
|
| 1190 |
+
request_model_args.cfg_text_scale = float(cfg_text_scale)
|
| 1191 |
+
|
| 1192 |
+
request_data_args = deepcopy(self.base_data_args)
|
| 1193 |
+
request_data_args.val_dataset_config_file = str(prompt_file)
|
| 1194 |
+
|
| 1195 |
+
request_inference_args = deepcopy(self.base_inference_args)
|
| 1196 |
+
request_inference_args.validation_num_timesteps = int(validation_num_timesteps)
|
| 1197 |
+
request_inference_args.validation_timestep_shift = float(validation_timestep_shift)
|
| 1198 |
+
request_inference_args.validation_data_seed = actual_seed
|
| 1199 |
+
request_inference_args.validation_noise_seed = actual_seed
|
| 1200 |
+
request_inference_args.video_height = int(height)
|
| 1201 |
+
request_inference_args.video_width = int(width)
|
| 1202 |
+
request_inference_args.num_frames = int(num_frames)
|
| 1203 |
+
display_resolution = str(resolution)
|
| 1204 |
+
backend_resolution = normalize_resolution_for_backend(display_resolution, internal_task)
|
| 1205 |
+
request_inference_args.resolution = backend_resolution
|
| 1206 |
+
request_inference_args.save_path_gen = str(save_dir)
|
| 1207 |
+
request_inference_args.task = internal_task
|
| 1208 |
+
request_inference_args.text_template = TEXT_TEMPLATE
|
| 1209 |
+
request_inference_args.prompt_data_dict = {}
|
| 1210 |
+
|
| 1211 |
+
try:
|
| 1212 |
+
print(
|
| 1213 |
+
"[lance_gradio_t2v_v2t] Start generation "
|
| 1214 |
+
f"| task={internal_task} | gpu={self.device} | seed={actual_seed} | "
|
| 1215 |
+
f"size={height}x{width} | frames={num_frames} | resolution={display_resolution}",
|
| 1216 |
+
flush=True,
|
| 1217 |
+
)
|
| 1218 |
+
val_data_cpu = self._build_request_batch(
|
| 1219 |
+
prompt_file=prompt_file,
|
| 1220 |
+
model_args=request_model_args,
|
| 1221 |
+
data_args=request_data_args,
|
| 1222 |
+
inference_args=request_inference_args,
|
| 1223 |
+
)
|
| 1224 |
+
generate_start = time.perf_counter()
|
| 1225 |
+
validate_on_fixed_batch(
|
| 1226 |
+
fsdp_model=self.model,
|
| 1227 |
+
vae_model=self.vae_model,
|
| 1228 |
+
tokenizer=self.tokenizer,
|
| 1229 |
+
val_data_cpu=val_data_cpu,
|
| 1230 |
+
training_args=request_inference_args,
|
| 1231 |
+
model_args=request_model_args,
|
| 1232 |
+
inference_args=request_inference_args,
|
| 1233 |
+
new_token_ids=self.new_token_ids,
|
| 1234 |
+
image_token_id=self.image_token_id,
|
| 1235 |
+
device=self.device,
|
| 1236 |
+
save_source_video=False,
|
| 1237 |
+
save_path_gen=request_inference_args.save_path_gen,
|
| 1238 |
+
save_path_gt="",
|
| 1239 |
+
)
|
| 1240 |
+
elapsed = time.perf_counter() - generate_start
|
| 1241 |
+
save_prompt_results(request_inference_args.prompt_data_dict, request_inference_args.save_path_gen, self.logger)
|
| 1242 |
+
clean_memory()
|
| 1243 |
+
|
| 1244 |
+
video_path = find_generated_video(save_dir) if internal_task in {TASK_T2V, TASK_VIDEO_EDIT} else None
|
| 1245 |
+
original_video_path = video_path
|
| 1246 |
+
rife_log = ""
|
| 1247 |
+
rife_error = ""
|
| 1248 |
+
frame_interpolation_enabled = bool(enable_frame_interpolation) and internal_task in {TASK_T2V, TASK_VIDEO_EDIT}
|
| 1249 |
+
if frame_interpolation_enabled and video_path is not None:
|
| 1250 |
+
try:
|
| 1251 |
+
clean_memory()
|
| 1252 |
+
print(
|
| 1253 |
+
"[rife] Start frame interpolation "
|
| 1254 |
+
f"| task={internal_task} | gpu={self.device} | input={video_path}",
|
| 1255 |
+
flush=True,
|
| 1256 |
+
)
|
| 1257 |
+
video_path, rife_log = run_rife_interpolation(video_path, self.device, exp=1)
|
| 1258 |
+
except Exception:
|
| 1259 |
+
rife_error = traceback.format_exc()
|
| 1260 |
+
print(rife_error, flush=True)
|
| 1261 |
+
image_path = find_generated_image(save_dir) if internal_task in {TASK_T2I, TASK_IMAGE_EDIT} else None
|
| 1262 |
+
text_result = extract_text_result(save_dir) if internal_task in UNDERSTANDING_TASKS else ""
|
| 1263 |
+
record = {
|
| 1264 |
+
"request_started_at": request_started_at,
|
| 1265 |
+
"request_finished_at": datetime.now().isoformat(timespec="seconds"),
|
| 1266 |
+
"status": "success",
|
| 1267 |
+
"task": internal_task,
|
| 1268 |
+
"model_variant": self.model_variant,
|
| 1269 |
+
"model_path": active_model_path,
|
| 1270 |
+
"gpu": self.device,
|
| 1271 |
+
"prompt": prompt,
|
| 1272 |
+
"system_prompt": normalize_understanding_system_prompt(internal_task, system_prompt)
|
| 1273 |
+
if internal_task in UNDERSTANDING_TASKS
|
| 1274 |
+
else "",
|
| 1275 |
+
"input_video": input_video,
|
| 1276 |
+
"input_image": input_image,
|
| 1277 |
+
"seed": actual_seed,
|
| 1278 |
+
"height": int(height),
|
| 1279 |
+
"width": int(width),
|
| 1280 |
+
"num_frames": int(num_frames),
|
| 1281 |
+
"resolution": display_resolution,
|
| 1282 |
+
"backend_resolution": backend_resolution,
|
| 1283 |
+
"validation_num_timesteps": int(validation_num_timesteps),
|
| 1284 |
+
"validation_timestep_shift": float(validation_timestep_shift),
|
| 1285 |
+
"cfg_text_scale": float(cfg_text_scale),
|
| 1286 |
+
"frame_interpolation": frame_interpolation_enabled,
|
| 1287 |
+
"elapsed_seconds": round(elapsed, 3),
|
| 1288 |
+
"prompt_file": str(prompt_file),
|
| 1289 |
+
"output_dir": str(save_dir),
|
| 1290 |
+
"original_video_path": str(original_video_path) if original_video_path is not None else "",
|
| 1291 |
+
"video_path": str(video_path) if video_path is not None else "",
|
| 1292 |
+
"image_path": str(image_path) if image_path is not None else "",
|
| 1293 |
+
"text_result": text_result,
|
| 1294 |
+
"rife_error": rife_error,
|
| 1295 |
+
}
|
| 1296 |
+
if internal_task in {TASK_T2V, TASK_VIDEO_EDIT} and video_path is None:
|
| 1297 |
+
record["status"] = "completed_without_video"
|
| 1298 |
+
if internal_task in {TASK_T2I, TASK_IMAGE_EDIT} and image_path is None:
|
| 1299 |
+
record["status"] = "completed_without_image"
|
| 1300 |
+
if internal_task in UNDERSTANDING_TASKS and not text_result:
|
| 1301 |
+
record["status"] = "completed_without_text"
|
| 1302 |
+
save_generation_record(record, save_dir)
|
| 1303 |
+
|
| 1304 |
+
logs = "\n".join(
|
| 1305 |
+
[
|
| 1306 |
+
"[lance_gradio_t2v_v2t] Inference finished in-process.",
|
| 1307 |
+
f"task={internal_task}",
|
| 1308 |
+
f"model_variant={self.model_variant}",
|
| 1309 |
+
f"model_path={active_model_path}",
|
| 1310 |
+
f"gpu={self.device}",
|
| 1311 |
+
f"seed={actual_seed}",
|
| 1312 |
+
f"height={height}",
|
| 1313 |
+
f"width={width}",
|
| 1314 |
+
f"num_frames={num_frames}",
|
| 1315 |
+
f"resolution={display_resolution}",
|
| 1316 |
+
f"backend_resolution={backend_resolution}",
|
| 1317 |
+
f"validation_num_timesteps={validation_num_timesteps}",
|
| 1318 |
+
f"validation_timestep_shift={validation_timestep_shift}",
|
| 1319 |
+
f"cfg_text_scale={cfg_text_scale}",
|
| 1320 |
+
f"frame_interpolation={frame_interpolation_enabled}",
|
| 1321 |
+
f"original_video_path={original_video_path or ''}",
|
| 1322 |
+
f"rife_error={rife_error.strip() if rife_error else ''}",
|
| 1323 |
+
f"elapsed={elapsed:.2f}s",
|
| 1324 |
+
f"output_dir={save_dir}",
|
| 1325 |
+
rife_log,
|
| 1326 |
+
]
|
| 1327 |
+
)
|
| 1328 |
+
|
| 1329 |
+
if internal_task in {TASK_T2V, TASK_VIDEO_EDIT}:
|
| 1330 |
+
if video_path is None:
|
| 1331 |
+
status = (
|
| 1332 |
+
"Inference completed, but no output video was found.\n\n"
|
| 1333 |
+
f"- Task: `{internal_task}`\n"
|
| 1334 |
+
f"- Model: `{self.model_variant}`\n"
|
| 1335 |
+
f"- Model path: `{active_model_path}`\n"
|
| 1336 |
+
f"- GPU: `{self.device}`\n"
|
| 1337 |
+
f"- Actual seed: `{actual_seed}`\n"
|
| 1338 |
+
f"- Output directory: `{save_dir}`"
|
| 1339 |
+
)
|
| 1340 |
+
return None, None, "", status, logs
|
| 1341 |
+
# status = (
|
| 1342 |
+
# "Inference completed.\n\n"
|
| 1343 |
+
# f"- Task: `{internal_task}`\n"
|
| 1344 |
+
# f"- Model: `{self.model_variant}`\n"
|
| 1345 |
+
# f"- Model path: `{active_model_path}`\n"
|
| 1346 |
+
# f"- GPU: `{self.device}`\n"
|
| 1347 |
+
# f"- Actual seed: `{actual_seed}`\n"
|
| 1348 |
+
# f"- Output directory: `{save_dir}`\n"
|
| 1349 |
+
# f"- Result file: `{video_path}`"
|
| 1350 |
+
# )
|
| 1351 |
+
status = ""
|
| 1352 |
+
return str(video_path), None, "", status, logs
|
| 1353 |
+
|
| 1354 |
+
if internal_task in {TASK_T2I, TASK_IMAGE_EDIT}:
|
| 1355 |
+
if image_path is None:
|
| 1356 |
+
status = (
|
| 1357 |
+
"Inference completed, but no output image was found.\n\n"
|
| 1358 |
+
f"- Task: `{internal_task}`\n"
|
| 1359 |
+
f"- Model: `{self.model_variant}`\n"
|
| 1360 |
+
f"- Model path: `{active_model_path}`\n"
|
| 1361 |
+
f"- GPU: `{self.device}`\n"
|
| 1362 |
+
f"- Actual seed: `{actual_seed}`\n"
|
| 1363 |
+
f"- Output directory: `{save_dir}`"
|
| 1364 |
+
)
|
| 1365 |
+
return None, None, "", status, logs
|
| 1366 |
+
# status = (
|
| 1367 |
+
# "Inference completed.\n\n"
|
| 1368 |
+
# f"- Task: `{internal_task}`\n"
|
| 1369 |
+
# f"- Model: `{self.model_variant}`\n"
|
| 1370 |
+
# f"- Model path: `{active_model_path}`\n"
|
| 1371 |
+
# f"- GPU: `{self.device}`\n"
|
| 1372 |
+
# f"- Actual seed: `{actual_seed}`\n"
|
| 1373 |
+
# f"- Output directory: `{save_dir}`\n"
|
| 1374 |
+
# f"- Result file: `{image_path}`"
|
| 1375 |
+
# )
|
| 1376 |
+
status = ""
|
| 1377 |
+
return None, str(image_path), "", status, logs
|
| 1378 |
+
|
| 1379 |
+
# status = (
|
| 1380 |
+
# "Understanding completed.\n\n"
|
| 1381 |
+
# f"- Task: `{task}`\n"
|
| 1382 |
+
# f"- Model: `{self.model_variant}`\n"
|
| 1383 |
+
# f"- Model path: `{active_model_path}`\n"
|
| 1384 |
+
# f"- GPU: `{self.device}`\n"
|
| 1385 |
+
# f"- Actual seed: `{actual_seed}`\n"
|
| 1386 |
+
# f"- Output directory: `{save_dir}`"
|
| 1387 |
+
# )
|
| 1388 |
+
status = ""
|
| 1389 |
+
return None, None, text_result, status, logs
|
| 1390 |
+
except Exception:
|
| 1391 |
+
error_trace = traceback.format_exc()
|
| 1392 |
+
print(error_trace, flush=True)
|
| 1393 |
+
record = {
|
| 1394 |
+
"request_started_at": request_started_at,
|
| 1395 |
+
"request_finished_at": datetime.now().isoformat(timespec="seconds"),
|
| 1396 |
+
"status": "failed",
|
| 1397 |
+
"task": internal_task,
|
| 1398 |
+
"model_variant": self.model_variant,
|
| 1399 |
+
"model_path": active_model_path,
|
| 1400 |
+
"gpu": self.device,
|
| 1401 |
+
"prompt": prompt,
|
| 1402 |
+
"input_video": input_video,
|
| 1403 |
+
"input_image": input_image,
|
| 1404 |
+
"seed": actual_seed,
|
| 1405 |
+
"height": int(height),
|
| 1406 |
+
"width": int(width),
|
| 1407 |
+
"num_frames": int(num_frames),
|
| 1408 |
+
"resolution": display_resolution,
|
| 1409 |
+
"backend_resolution": backend_resolution,
|
| 1410 |
+
"validation_num_timesteps": int(validation_num_timesteps),
|
| 1411 |
+
"validation_timestep_shift": float(validation_timestep_shift),
|
| 1412 |
+
"cfg_text_scale": float(cfg_text_scale),
|
| 1413 |
+
"prompt_file": str(prompt_file),
|
| 1414 |
+
"output_dir": str(save_dir),
|
| 1415 |
+
"video_path": "",
|
| 1416 |
+
"image_path": "",
|
| 1417 |
+
"text_result": "",
|
| 1418 |
+
"error": error_trace,
|
| 1419 |
+
}
|
| 1420 |
+
save_generation_record(record, save_dir)
|
| 1421 |
+
status = (
|
| 1422 |
+
"Inference failed.\n\n"
|
| 1423 |
+
f"- Task: `{internal_task}`\n"
|
| 1424 |
+
f"- Model: `{self.model_variant}`\n"
|
| 1425 |
+
f"- Model path: `{active_model_path}`\n"
|
| 1426 |
+
f"- GPU: `{self.device}`\n"
|
| 1427 |
+
f"- Actual seed: `{actual_seed}`\n"
|
| 1428 |
+
f"- Resolution: `{display_resolution}`\n"
|
| 1429 |
+
f"- Output directory: `{save_dir}`"
|
| 1430 |
+
)
|
| 1431 |
+
return None, None, "", status, error_trace
|
| 1432 |
+
|
| 1433 |
+
|
| 1434 |
+
class PipelinePool:
|
| 1435 |
+
def __init__(self, gpu_ids: list[int], model_variant: str = MODEL_VARIANT_VIDEO) -> None:
|
| 1436 |
+
if not gpu_ids:
|
| 1437 |
+
raise ValueError("At least one GPU must be configured.")
|
| 1438 |
+
self.gpu_ids = gpu_ids
|
| 1439 |
+
self.model_variant = normalize_model_variant(model_variant)
|
| 1440 |
+
self.pipelines = [
|
| 1441 |
+
LanceT2VV2TPipeline(device_id=gpu_id, model_variant=self.model_variant)
|
| 1442 |
+
for gpu_id in gpu_ids
|
| 1443 |
+
]
|
| 1444 |
+
self._available = deque(self.pipelines)
|
| 1445 |
+
self._condition = threading.Condition()
|
| 1446 |
+
|
| 1447 |
+
@property
|
| 1448 |
+
def size(self) -> int:
|
| 1449 |
+
return len(self.pipelines)
|
| 1450 |
+
|
| 1451 |
+
@property
|
| 1452 |
+
def gpu_summary(self) -> str:
|
| 1453 |
+
return ",".join(str(gpu_id) for gpu_id in self.gpu_ids)
|
| 1454 |
+
|
| 1455 |
+
def initialize_all(self) -> None:
|
| 1456 |
+
print(f"[startup][{self.model_variant}] Preparing parallel GPU preload: {self.gpu_ids}", flush=True)
|
| 1457 |
+
exceptions: list[Exception] = []
|
| 1458 |
+
with concurrent.futures.ThreadPoolExecutor(max_workers=self.size) as executor:
|
| 1459 |
+
futures = {
|
| 1460 |
+
executor.submit(pipeline.initialize): pipeline.device for pipeline in self.pipelines
|
| 1461 |
+
}
|
| 1462 |
+
for future in concurrent.futures.as_completed(futures):
|
| 1463 |
+
gpu_id = futures[future]
|
| 1464 |
+
try:
|
| 1465 |
+
future.result()
|
| 1466 |
+
except Exception as exc:
|
| 1467 |
+
print(f"[startup][gpu:{gpu_id}][{self.model_variant}] Preload failed: {exc}", flush=True)
|
| 1468 |
+
exceptions.append(exc)
|
| 1469 |
+
if exceptions:
|
| 1470 |
+
raise RuntimeError(
|
| 1471 |
+
f"{self.model_variant} preload failed on {len(exceptions)} GPU(s). Please check the terminal logs."
|
| 1472 |
+
) from exceptions[0]
|
| 1473 |
+
print(
|
| 1474 |
+
f"[startup][{self.model_variant}] GPU preload finished. Ready to handle {self.size} concurrent request(s).",
|
| 1475 |
+
flush=True,
|
| 1476 |
+
)
|
| 1477 |
+
|
| 1478 |
+
def acquire(self) -> LanceT2VV2TPipeline:
|
| 1479 |
+
with self._condition:
|
| 1480 |
+
while not self._available:
|
| 1481 |
+
self._condition.wait()
|
| 1482 |
+
return self._available.popleft()
|
| 1483 |
+
|
| 1484 |
+
def release(self, pipeline: LanceT2VV2TPipeline) -> None:
|
| 1485 |
+
with self._condition:
|
| 1486 |
+
self._available.append(pipeline)
|
| 1487 |
+
self._condition.notify()
|
| 1488 |
+
|
| 1489 |
+
def unload_all(self) -> None:
|
| 1490 |
+
print(f"[runtime][{self.model_variant}] Unloading model pool from GPU(s): {self.gpu_ids}", flush=True)
|
| 1491 |
+
with self._condition:
|
| 1492 |
+
while len(self._available) != len(self.pipelines):
|
| 1493 |
+
self._condition.wait()
|
| 1494 |
+
|
| 1495 |
+
for pipeline in self.pipelines:
|
| 1496 |
+
pipeline.unload()
|
| 1497 |
+
|
| 1498 |
+
gc.collect()
|
| 1499 |
+
if torch.cuda.is_available():
|
| 1500 |
+
torch.cuda.empty_cache()
|
| 1501 |
+
torch.cuda.ipc_collect()
|
| 1502 |
+
print(f"[runtime][{self.model_variant}] Model pool unloaded.", flush=True)
|
| 1503 |
+
|
| 1504 |
+
def generate(
|
| 1505 |
+
self,
|
| 1506 |
+
task: str,
|
| 1507 |
+
prompt: str,
|
| 1508 |
+
system_prompt: Optional[str],
|
| 1509 |
+
input_video: Optional[str],
|
| 1510 |
+
input_image: Optional[str],
|
| 1511 |
+
height: int,
|
| 1512 |
+
width: int,
|
| 1513 |
+
num_frames: int,
|
| 1514 |
+
seed: int,
|
| 1515 |
+
resolution: str,
|
| 1516 |
+
validation_num_timesteps: int,
|
| 1517 |
+
validation_timestep_shift: float,
|
| 1518 |
+
cfg_text_scale: float,
|
| 1519 |
+
enable_frame_interpolation: bool,
|
| 1520 |
+
):
|
| 1521 |
+
pipeline = self.acquire()
|
| 1522 |
+
try:
|
| 1523 |
+
return pipeline.generate(
|
| 1524 |
+
task=task,
|
| 1525 |
+
prompt=prompt,
|
| 1526 |
+
system_prompt=system_prompt,
|
| 1527 |
+
input_video=input_video,
|
| 1528 |
+
input_image=input_image,
|
| 1529 |
+
height=height,
|
| 1530 |
+
width=width,
|
| 1531 |
+
num_frames=num_frames,
|
| 1532 |
+
seed=seed,
|
| 1533 |
+
resolution=resolution,
|
| 1534 |
+
validation_num_timesteps=validation_num_timesteps,
|
| 1535 |
+
validation_timestep_shift=validation_timestep_shift,
|
| 1536 |
+
cfg_text_scale=cfg_text_scale,
|
| 1537 |
+
enable_frame_interpolation=enable_frame_interpolation,
|
| 1538 |
+
)
|
| 1539 |
+
finally:
|
| 1540 |
+
self.release(pipeline)
|
| 1541 |
+
|
| 1542 |
+
|
| 1543 |
+
ACTIVE_PIPELINE_POOL: Optional[PipelinePool] = None
|
| 1544 |
+
ACTIVE_POOL_LOCK = threading.Lock()
|
| 1545 |
+
QUEUE_MAX_SIZE = DEFAULT_QUEUE_SIZE
|
| 1546 |
+
|
| 1547 |
+
|
| 1548 |
+
def get_task_model_variant(task: str) -> str:
|
| 1549 |
+
internal_task = normalize_task(task)
|
| 1550 |
+
return MODEL_VARIANT_IMAGE if internal_task in IMAGE_TASKS else MODEL_VARIANT_VIDEO
|
| 1551 |
+
|
| 1552 |
+
|
| 1553 |
+
def get_pipeline_pool(task: str) -> PipelinePool:
|
| 1554 |
+
global ACTIVE_PIPELINE_POOL
|
| 1555 |
+
model_variant = get_task_model_variant(task)
|
| 1556 |
+
with ACTIVE_POOL_LOCK:
|
| 1557 |
+
if ACTIVE_PIPELINE_POOL is not None and ACTIVE_PIPELINE_POOL.model_variant == model_variant:
|
| 1558 |
+
return ACTIVE_PIPELINE_POOL
|
| 1559 |
+
|
| 1560 |
+
gpu_ids = parse_gpu_ids(os.getenv("LANCE_GPUS", DEFAULT_GPUS))
|
| 1561 |
+
if ACTIVE_PIPELINE_POOL is not None:
|
| 1562 |
+
previous_variant = ACTIVE_PIPELINE_POOL.model_variant
|
| 1563 |
+
print(
|
| 1564 |
+
f"[runtime] Switching Lance model from {previous_variant} to {model_variant}.",
|
| 1565 |
+
flush=True,
|
| 1566 |
+
)
|
| 1567 |
+
ACTIVE_PIPELINE_POOL.unload_all()
|
| 1568 |
+
ACTIVE_PIPELINE_POOL = None
|
| 1569 |
+
|
| 1570 |
+
ACTIVE_PIPELINE_POOL = PipelinePool(gpu_ids, model_variant=model_variant)
|
| 1571 |
+
ACTIVE_PIPELINE_POOL.initialize_all()
|
| 1572 |
+
return ACTIVE_PIPELINE_POOL
|
| 1573 |
+
|
| 1574 |
+
|
| 1575 |
+
def run_task(
|
| 1576 |
+
task: str,
|
| 1577 |
+
prompt: str,
|
| 1578 |
+
system_prompt: Optional[str],
|
| 1579 |
+
input_video: Optional[str],
|
| 1580 |
+
input_image: Optional[str],
|
| 1581 |
+
height: int,
|
| 1582 |
+
width: int,
|
| 1583 |
+
num_frames: int,
|
| 1584 |
+
seed: int,
|
| 1585 |
+
resolution: str,
|
| 1586 |
+
validation_num_timesteps: int,
|
| 1587 |
+
validation_timestep_shift: float,
|
| 1588 |
+
cfg_text_scale: float,
|
| 1589 |
+
enable_frame_interpolation: bool,
|
| 1590 |
+
):
|
| 1591 |
+
pipeline_pool = get_pipeline_pool(task)
|
| 1592 |
+
return pipeline_pool.generate(
|
| 1593 |
+
task=task,
|
| 1594 |
+
prompt=prompt,
|
| 1595 |
+
system_prompt=system_prompt,
|
| 1596 |
+
input_video=input_video,
|
| 1597 |
+
input_image=input_image,
|
| 1598 |
+
height=height,
|
| 1599 |
+
width=width,
|
| 1600 |
+
num_frames=num_frames,
|
| 1601 |
+
seed=seed,
|
| 1602 |
+
resolution=resolution,
|
| 1603 |
+
validation_num_timesteps=validation_num_timesteps,
|
| 1604 |
+
validation_timestep_shift=validation_timestep_shift,
|
| 1605 |
+
cfg_text_scale=cfg_text_scale,
|
| 1606 |
+
enable_frame_interpolation=enable_frame_interpolation,
|
| 1607 |
+
)
|
| 1608 |
+
|
| 1609 |
+
|
| 1610 |
+
def build_status_markdown() -> str:
|
| 1611 |
+
gpu_text = "unknown"
|
| 1612 |
+
concurrency = 1
|
| 1613 |
+
active_variant = "none"
|
| 1614 |
+
if ACTIVE_PIPELINE_POOL is not None:
|
| 1615 |
+
active_variant = ACTIVE_PIPELINE_POOL.model_variant
|
| 1616 |
+
gpu_text = ACTIVE_PIPELINE_POOL.gpu_summary
|
| 1617 |
+
concurrency = ACTIVE_PIPELINE_POOL.size
|
| 1618 |
+
return (
|
| 1619 |
+
f"**Status** GPU: `{gpu_text}` | Max concurrency: `{concurrency}` | "
|
| 1620 |
+
f"Queue limit: `{QUEUE_MAX_SIZE}` | Active model: `{active_variant}` | "
|
| 1621 |
+
f"Switch mode: `unload then load`"
|
| 1622 |
+
)
|
| 1623 |
+
|
| 1624 |
+
|
| 1625 |
+
def get_logo_data_uri() -> str:
|
| 1626 |
+
if not LANCE_LOGO_PATH.exists():
|
| 1627 |
+
return ""
|
| 1628 |
+
encoded_logo = base64.b64encode(LANCE_LOGO_PATH.read_bytes()).decode("ascii")
|
| 1629 |
+
return f"data:image/webp;base64,{encoded_logo}"
|
| 1630 |
+
|
| 1631 |
+
|
| 1632 |
+
def build_header_html() -> str:
|
| 1633 |
+
logo_data_uri = get_logo_data_uri()
|
| 1634 |
+
logo_html = (
|
| 1635 |
+
f'<img class="lance-logo" src="{logo_data_uri}" alt="Lance logo">'
|
| 1636 |
+
if logo_data_uri
|
| 1637 |
+
else ""
|
| 1638 |
+
)
|
| 1639 |
+
return f"""
|
| 1640 |
+
<div class="lance-hero">
|
| 1641 |
+
{logo_html}
|
| 1642 |
+
<h1 class="lance-title">Lance: Unified Multimodal Modeling by Multi-Task Synergy</h1>
|
| 1643 |
+
<div class="lance-authors">
|
| 1644 |
+
<strong>
|
| 1645 |
+
<a href="https://scholar.google.com.hk/citations?user=FXxoQlsAAAAJ&hl=zh-CN&oi=ao" target="_blank">Fengyi Fu</a><sup>*</sup>,
|
| 1646 |
+
<a href="https://corleone-huang.github.io/" target="_blank">Mengqi Huang</a><sup>*,✉</sup>,
|
| 1647 |
+
<a href="https://scholar.google.com.hk/citations?user=9ER6nVkAAAAJ&hl=zh-CN&oi=ao" target="_blank">Shaojin Wu</a><sup>*</sup>,
|
| 1648 |
+
Yunsheng Jiang<sup>*</sup>,
|
| 1649 |
+
Yufei Huo,
|
| 1650 |
+
<a href="https://guojianzhu.com/" target="_blank">Jianzhu Guo</a><sup>✉,§</sup>
|
| 1651 |
+
</strong><br>
|
| 1652 |
+
Hao Li, Yinghang Song, Fei Ding, Qian He, Zheren Fu, Zhendong Mao, Yongdong Zhang<br>
|
| 1653 |
+
<em>ByteDance</em>
|
| 1654 |
+
</div>
|
| 1655 |
+
<div class="lance-badges">
|
| 1656 |
+
<a href="{LANCE_HOMEPAGE_URL}" target="_blank" rel="noopener noreferrer">
|
| 1657 |
+
<img alt="Homepage" src="https://img.shields.io/badge/Homepage-Lance-blue?style=flat">
|
| 1658 |
+
</a>
|
| 1659 |
+
<a href="{LANCE_PAPER_URL}" target="_blank" rel="noopener noreferrer">
|
| 1660 |
+
<img alt="Paper" src="https://img.shields.io/badge/Paper-arXiv-red?style=flat&logo=arxiv">
|
| 1661 |
+
</a>
|
| 1662 |
+
<a href="{LANCE_HUGGING_FACE_URL}" target="_blank" rel="noopener noreferrer">
|
| 1663 |
+
<img alt="Hugging Face" src="https://img.shields.io/badge/Model-HuggingFace-yellow?style=flat&logo=huggingface">
|
| 1664 |
+
</a>
|
| 1665 |
+
<a href="{LANCE_GITHUB_URL}" target="_blank" rel="noopener noreferrer">
|
| 1666 |
+
<img alt="GitHub" src="https://img.shields.io/badge/Code-GitHub-536af5?color=536af5&logo=github">
|
| 1667 |
+
</a>
|
| 1668 |
+
</div>
|
| 1669 |
+
</div>
|
| 1670 |
+
"""
|
| 1671 |
+
|
| 1672 |
+
|
| 1673 |
+
def update_task_ui(task: str):
|
| 1674 |
+
internal_task = normalize_task(task)
|
| 1675 |
+
is_image_task = internal_task in IMAGE_TASKS
|
| 1676 |
+
is_video_task = internal_task in VIDEO_TASKS
|
| 1677 |
+
is_edit_task = internal_task in EDIT_TASKS
|
| 1678 |
+
is_understanding_task = internal_task in UNDERSTANDING_TASKS
|
| 1679 |
+
is_generation_task = internal_task in GENERATION_TASKS
|
| 1680 |
+
show_media_input = is_edit_task or is_understanding_task
|
| 1681 |
+
resolution_choices = IMAGE_RESOLUTION_CHOICES if is_image_task else VIDEO_RESOLUTION_CHOICES
|
| 1682 |
+
resolution_value = DEFAULT_IMAGE_RESOLUTION if is_image_task else DEFAULT_RESOLUTION
|
| 1683 |
+
aspect_ratio_value = DEFAULT_IMAGE_ASPECT_RATIO if is_image_task else DEFAULT_VIDEO_ASPECT_RATIO
|
| 1684 |
+
width_value, height_value = get_size_for_aspect_ratio(internal_task, aspect_ratio_value)
|
| 1685 |
+
size_markdown = format_size_markdown(internal_task, width_value, height_value)
|
| 1686 |
+
system_prompt_choices = get_understanding_system_prompt_choices(internal_task)
|
| 1687 |
+
|
| 1688 |
+
if is_generation_task:
|
| 1689 |
+
text_label = "Prompt"
|
| 1690 |
+
text_placeholder = "Describe what you want to generate..."
|
| 1691 |
+
elif is_edit_task:
|
| 1692 |
+
text_label = "Instruction"
|
| 1693 |
+
text_placeholder = "Describe the edit you want..."
|
| 1694 |
+
else:
|
| 1695 |
+
text_label = "Question"
|
| 1696 |
+
text_placeholder = "Ask a question about the input..."
|
| 1697 |
+
|
| 1698 |
+
return (
|
| 1699 |
+
gr.update(
|
| 1700 |
+
label=text_label,
|
| 1701 |
+
placeholder=text_placeholder,
|
| 1702 |
+
visible=True,
|
| 1703 |
+
),
|
| 1704 |
+
gr.update(
|
| 1705 |
+
choices=system_prompt_choices,
|
| 1706 |
+
value=system_prompt_choices[0],
|
| 1707 |
+
visible=False,
|
| 1708 |
+
),
|
| 1709 |
+
gr.update(label="Input Video", visible=show_media_input and is_video_task),
|
| 1710 |
+
gr.update(label="Input Image", visible=show_media_input and is_image_task),
|
| 1711 |
+
gr.update(value=aspect_ratio_value, visible=is_generation_task or is_edit_task),
|
| 1712 |
+
gr.update(value=height_value),
|
| 1713 |
+
gr.update(value=width_value),
|
| 1714 |
+
gr.update(value=size_markdown, visible=is_generation_task or is_edit_task),
|
| 1715 |
+
gr.update(visible=internal_task == TASK_T2V, value=DEFAULT_NUM_FRAMES),
|
| 1716 |
+
gr.update(visible=internal_task in {TASK_T2V, TASK_VIDEO_EDIT}, value=DEFAULT_FRAME_INTERPOLATION),
|
| 1717 |
+
gr.update(choices=resolution_choices, value=resolution_value, visible=False),
|
| 1718 |
+
gr.update(visible=internal_task in {TASK_T2V, TASK_VIDEO_EDIT}),
|
| 1719 |
+
gr.update(visible=internal_task in {TASK_T2I, TASK_IMAGE_EDIT}),
|
| 1720 |
+
gr.update(visible=is_understanding_task, value=""),
|
| 1721 |
+
gr.update(visible=internal_task == TASK_T2V),
|
| 1722 |
+
gr.update(visible=internal_task == TASK_VIDEO_EDIT),
|
| 1723 |
+
gr.update(visible=internal_task == TASK_X2T_VIDEO),
|
| 1724 |
+
gr.update(visible=internal_task == TASK_T2I),
|
| 1725 |
+
gr.update(visible=internal_task == TASK_IMAGE_EDIT),
|
| 1726 |
+
gr.update(visible=internal_task == TASK_X2T_IMAGE),
|
| 1727 |
+
)
|
| 1728 |
+
|
| 1729 |
+
|
| 1730 |
+
def keep_example_clicks_from_changing_visibility(*examples_components) -> None:
|
| 1731 |
+
for examples_component in examples_components:
|
| 1732 |
+
dataset = getattr(examples_component, "dataset", None)
|
| 1733 |
+
component_props = getattr(dataset, "component_props", None)
|
| 1734 |
+
if not component_props:
|
| 1735 |
+
continue
|
| 1736 |
+
for props in component_props:
|
| 1737 |
+
props.pop("visible", None)
|
| 1738 |
+
|
| 1739 |
+
|
| 1740 |
+
def build_demo() -> gr.Blocks:
|
| 1741 |
+
with gr.Blocks(title="Lance", css=APP_CSS) as demo:
|
| 1742 |
+
gr.HTML(build_header_html())
|
| 1743 |
+
gr.Markdown(build_status_markdown(), elem_classes=["lance-status"], visible=False)
|
| 1744 |
+
|
| 1745 |
+
with gr.Row(elem_classes=["lance-main-row"]):
|
| 1746 |
+
with gr.Column(scale=1, elem_classes=["lance-main-column"]):
|
| 1747 |
+
task = gr.Radio(
|
| 1748 |
+
label="Task",
|
| 1749 |
+
choices=TASK_CHOICES,
|
| 1750 |
+
value=TASK_LABEL_VIDEO_GENERATION,
|
| 1751 |
+
elem_classes=["task-selector"],
|
| 1752 |
+
)
|
| 1753 |
+
prompt = gr.Textbox(
|
| 1754 |
+
label="Prompt",
|
| 1755 |
+
lines=6,
|
| 1756 |
+
placeholder="Describe the video you want to generate...",
|
| 1757 |
+
)
|
| 1758 |
+
system_prompt = gr.Dropdown(
|
| 1759 |
+
label="System Prompt",
|
| 1760 |
+
choices=get_understanding_system_prompt_choices(TASK_X2T_VIDEO),
|
| 1761 |
+
value=V2T_QA_SYSTEM_PROMPT,
|
| 1762 |
+
visible=False,
|
| 1763 |
+
)
|
| 1764 |
+
input_video = gr.Video(label="Input Video", visible=False, elem_classes=["lance-display-frame"])
|
| 1765 |
+
input_image = gr.Image(label="Input Image", type="filepath", visible=False, elem_classes=["lance-display-frame"])
|
| 1766 |
+
with gr.Row():
|
| 1767 |
+
seed = gr.Number(
|
| 1768 |
+
label="Seed (-1 for random seed)",
|
| 1769 |
+
value=DEFAULT_BASIC_SEED,
|
| 1770 |
+
precision=0,
|
| 1771 |
+
# info="-1 for random seed",
|
| 1772 |
+
)
|
| 1773 |
+
aspect_ratio = gr.Dropdown(
|
| 1774 |
+
label="Aspect Ratio",
|
| 1775 |
+
# choices=ASPECT_RATIO_CHOICES, # 原始版本,不显示 是否为 default
|
| 1776 |
+
choices=get_aspect_ratio_choices_for_task(TASK_T2V),
|
| 1777 |
+
value=DEFAULT_VIDEO_ASPECT_RATIO,
|
| 1778 |
+
)
|
| 1779 |
+
# real_size = gr.Markdown(format_size_markdown(TASK_T2V, DEFAULT_WIDTH, DEFAULT_HEIGHT))
|
| 1780 |
+
real_size = gr.Textbox(
|
| 1781 |
+
label="Output Resolution",
|
| 1782 |
+
value=format_size_markdown(TASK_T2V, DEFAULT_WIDTH, DEFAULT_HEIGHT),
|
| 1783 |
+
interactive=False,
|
| 1784 |
+
)
|
| 1785 |
+
enable_frame_interpolation = gr.Checkbox(
|
| 1786 |
+
label="Frame Interpolation",
|
| 1787 |
+
value=DEFAULT_FRAME_INTERPOLATION,
|
| 1788 |
+
)
|
| 1789 |
+
resolution = gr.Dropdown(
|
| 1790 |
+
label="Resolution",
|
| 1791 |
+
choices=RESOLUTION_CHOICES,
|
| 1792 |
+
value=DEFAULT_RESOLUTION,
|
| 1793 |
+
visible=False,
|
| 1794 |
+
)
|
| 1795 |
+
height = gr.Number(value=DEFAULT_HEIGHT, precision=0, visible=False)
|
| 1796 |
+
width = gr.Number(value=DEFAULT_WIDTH, precision=0, visible=False)
|
| 1797 |
+
num_frames = gr.Slider(
|
| 1798 |
+
minimum=1,
|
| 1799 |
+
maximum=121,
|
| 1800 |
+
step=1,
|
| 1801 |
+
value=DEFAULT_NUM_FRAMES,
|
| 1802 |
+
label="Output Frames",
|
| 1803 |
+
)
|
| 1804 |
+
# seed = gr.Number(
|
| 1805 |
+
# label="Seed",
|
| 1806 |
+
# value=DEFAULT_BASIC_SEED,
|
| 1807 |
+
# precision=0,
|
| 1808 |
+
# info="-1 means using a random seed each time",
|
| 1809 |
+
# )
|
| 1810 |
+
|
| 1811 |
+
with gr.Accordion("Advanced Parameters", open=False):
|
| 1812 |
+
validation_num_timesteps = gr.Slider(
|
| 1813 |
+
minimum=1,
|
| 1814 |
+
maximum=50,
|
| 1815 |
+
step=1,
|
| 1816 |
+
value=DEFAULT_TIMESTEPS,
|
| 1817 |
+
label="Validation Num Timesteps",
|
| 1818 |
+
)
|
| 1819 |
+
with gr.Row():
|
| 1820 |
+
validation_timestep_shift = gr.Number(
|
| 1821 |
+
label="Validation Timestep Shift",
|
| 1822 |
+
value=DEFAULT_TIMESTEP_SHIFT,
|
| 1823 |
+
)
|
| 1824 |
+
cfg_text_scale = gr.Number(
|
| 1825 |
+
label="CFG Text Scale",
|
| 1826 |
+
value=DEFAULT_CFG_TEXT_SCALE,
|
| 1827 |
+
)
|
| 1828 |
+
|
| 1829 |
+
generation_example_inputs = [
|
| 1830 |
+
prompt,
|
| 1831 |
+
input_video,
|
| 1832 |
+
input_image,
|
| 1833 |
+
]
|
| 1834 |
+
|
| 1835 |
+
with gr.Column(scale=1, elem_classes=["lance-main-column"]):
|
| 1836 |
+
output_video = gr.Video(label="Output Video", elem_classes=["lance-display-frame"])
|
| 1837 |
+
output_image = gr.Image(label="Output Image", type="filepath", visible=False, elem_classes=["lance-display-frame"])
|
| 1838 |
+
output_text = gr.Textbox(label="Output Text", lines=8, visible=False, elem_classes=["lance-display-frame"])
|
| 1839 |
+
status = gr.Markdown("WAITING TO RUN.")
|
| 1840 |
+
logs = gr.Textbox(label="Run Logs", lines=22, max_lines=30)
|
| 1841 |
+
|
| 1842 |
+
run_button = gr.Button("RUN", variant="primary")
|
| 1843 |
+
|
| 1844 |
+
with gr.Group(visible=True, elem_classes=["prompt-examples"]) as video_generation_examples_group:
|
| 1845 |
+
gr.Markdown("### Video generation recommended cases")
|
| 1846 |
+
video_generation_examples = gr.Dataframe(
|
| 1847 |
+
value=VIDEO_GENERATION_EXAMPLES,
|
| 1848 |
+
headers=["Prompt"],
|
| 1849 |
+
datatype=["str"],
|
| 1850 |
+
interactive=False,
|
| 1851 |
+
show_row_numbers=False,
|
| 1852 |
+
wrap=True,
|
| 1853 |
+
line_breaks=True,
|
| 1854 |
+
row_count=(len(VIDEO_GENERATION_EXAMPLES), "fixed"),
|
| 1855 |
+
col_count=(1, "fixed"),
|
| 1856 |
+
max_height=420,
|
| 1857 |
+
elem_classes=["prompt-table"],
|
| 1858 |
+
)
|
| 1859 |
+
|
| 1860 |
+
with gr.Group(visible=False) as video_edit_examples_group:
|
| 1861 |
+
gr.Markdown("### Video edit recommended cases")
|
| 1862 |
+
video_edit_examples = gr.Examples(
|
| 1863 |
+
examples=VIDEO_EDIT_EXAMPLES,
|
| 1864 |
+
inputs=generation_example_inputs,
|
| 1865 |
+
label="",
|
| 1866 |
+
examples_per_page=3,
|
| 1867 |
+
cache_examples=False,
|
| 1868 |
+
preprocess=False,
|
| 1869 |
+
postprocess=False,
|
| 1870 |
+
)
|
| 1871 |
+
|
| 1872 |
+
with gr.Group(visible=False) as video_understanding_examples_group:
|
| 1873 |
+
gr.Markdown("### Video understanding recommended cases")
|
| 1874 |
+
video_understanding_examples = gr.Examples(
|
| 1875 |
+
examples=VIDEO_UNDERSTANDING_EXAMPLES,
|
| 1876 |
+
inputs=generation_example_inputs,
|
| 1877 |
+
label="",
|
| 1878 |
+
examples_per_page=4,
|
| 1879 |
+
cache_examples=False,
|
| 1880 |
+
preprocess=False,
|
| 1881 |
+
postprocess=False,
|
| 1882 |
+
)
|
| 1883 |
+
|
| 1884 |
+
with gr.Group(visible=False, elem_classes=["prompt-examples"]) as image_generation_examples_group:
|
| 1885 |
+
gr.Markdown("### Image generation recommended cases")
|
| 1886 |
+
image_generation_examples = gr.Dataframe(
|
| 1887 |
+
value=IMAGE_GENERATION_EXAMPLES,
|
| 1888 |
+
headers=["Prompt"],
|
| 1889 |
+
datatype=["str"],
|
| 1890 |
+
interactive=False,
|
| 1891 |
+
show_row_numbers=False,
|
| 1892 |
+
wrap=True,
|
| 1893 |
+
line_breaks=True,
|
| 1894 |
+
row_count=(len(IMAGE_GENERATION_EXAMPLES), "fixed"),
|
| 1895 |
+
col_count=(1, "fixed"),
|
| 1896 |
+
max_height=420,
|
| 1897 |
+
elem_classes=["prompt-table"],
|
| 1898 |
+
)
|
| 1899 |
+
|
| 1900 |
+
with gr.Group(visible=False) as image_edit_examples_group:
|
| 1901 |
+
gr.Markdown("### Image edit recommended cases")
|
| 1902 |
+
image_edit_examples = gr.Examples(
|
| 1903 |
+
examples=IMAGE_EDIT_EXAMPLES,
|
| 1904 |
+
inputs=generation_example_inputs,
|
| 1905 |
+
label="",
|
| 1906 |
+
examples_per_page=5,
|
| 1907 |
+
cache_examples=False,
|
| 1908 |
+
preprocess=False,
|
| 1909 |
+
postprocess=False,
|
| 1910 |
+
)
|
| 1911 |
+
|
| 1912 |
+
with gr.Group(visible=False) as image_understanding_examples_group:
|
| 1913 |
+
gr.Markdown("### Image understanding recommended cases")
|
| 1914 |
+
image_understanding_examples = gr.Examples(
|
| 1915 |
+
examples=IMAGE_UNDERSTANDING_EXAMPLES,
|
| 1916 |
+
inputs=generation_example_inputs,
|
| 1917 |
+
label="",
|
| 1918 |
+
examples_per_page=4,
|
| 1919 |
+
cache_examples=False,
|
| 1920 |
+
preprocess=False,
|
| 1921 |
+
postprocess=False,
|
| 1922 |
+
)
|
| 1923 |
+
|
| 1924 |
+
keep_example_clicks_from_changing_visibility(
|
| 1925 |
+
video_generation_examples,
|
| 1926 |
+
video_edit_examples,
|
| 1927 |
+
video_understanding_examples,
|
| 1928 |
+
image_generation_examples,
|
| 1929 |
+
image_edit_examples,
|
| 1930 |
+
image_understanding_examples,
|
| 1931 |
+
)
|
| 1932 |
+
|
| 1933 |
+
task.change(
|
| 1934 |
+
fn=update_task_ui,
|
| 1935 |
+
inputs=[task],
|
| 1936 |
+
outputs=[
|
| 1937 |
+
prompt,
|
| 1938 |
+
system_prompt,
|
| 1939 |
+
input_video,
|
| 1940 |
+
input_image,
|
| 1941 |
+
aspect_ratio,
|
| 1942 |
+
height,
|
| 1943 |
+
width,
|
| 1944 |
+
real_size,
|
| 1945 |
+
num_frames,
|
| 1946 |
+
enable_frame_interpolation,
|
| 1947 |
+
resolution,
|
| 1948 |
+
output_video,
|
| 1949 |
+
output_image,
|
| 1950 |
+
output_text,
|
| 1951 |
+
video_generation_examples_group,
|
| 1952 |
+
video_edit_examples_group,
|
| 1953 |
+
video_understanding_examples_group,
|
| 1954 |
+
image_generation_examples_group,
|
| 1955 |
+
image_edit_examples_group,
|
| 1956 |
+
image_understanding_examples_group,
|
| 1957 |
+
],
|
| 1958 |
+
)
|
| 1959 |
+
|
| 1960 |
+
aspect_ratio.change(
|
| 1961 |
+
fn=update_size_from_aspect_ratio,
|
| 1962 |
+
inputs=[task, aspect_ratio],
|
| 1963 |
+
outputs=[height, width, real_size],
|
| 1964 |
+
queue=False,
|
| 1965 |
+
show_api=False,
|
| 1966 |
+
)
|
| 1967 |
+
|
| 1968 |
+
for examples_component in (video_edit_examples, video_understanding_examples, image_edit_examples, image_understanding_examples):
|
| 1969 |
+
examples_component.load_input_event.then(
|
| 1970 |
+
fn=reset_generation_defaults_for_task,
|
| 1971 |
+
inputs=[task],
|
| 1972 |
+
outputs=[aspect_ratio, height, width, num_frames, resolution, real_size],
|
| 1973 |
+
queue=False,
|
| 1974 |
+
show_api=False,
|
| 1975 |
+
)
|
| 1976 |
+
|
| 1977 |
+
video_generation_examples.select(
|
| 1978 |
+
fn=apply_prompt_example,
|
| 1979 |
+
inputs=[task],
|
| 1980 |
+
outputs=[prompt, aspect_ratio, height, width, num_frames, resolution, real_size],
|
| 1981 |
+
queue=False,
|
| 1982 |
+
show_api=False,
|
| 1983 |
+
)
|
| 1984 |
+
image_generation_examples.select(
|
| 1985 |
+
fn=apply_prompt_example,
|
| 1986 |
+
inputs=[task],
|
| 1987 |
+
outputs=[prompt, aspect_ratio, height, width, num_frames, resolution, real_size],
|
| 1988 |
+
queue=False,
|
| 1989 |
+
show_api=False,
|
| 1990 |
+
)
|
| 1991 |
+
|
| 1992 |
+
run_button.click(
|
| 1993 |
+
fn=run_task,
|
| 1994 |
+
inputs=[
|
| 1995 |
+
task,
|
| 1996 |
+
prompt,
|
| 1997 |
+
system_prompt,
|
| 1998 |
+
input_video,
|
| 1999 |
+
input_image,
|
| 2000 |
+
height,
|
| 2001 |
+
width,
|
| 2002 |
+
num_frames,
|
| 2003 |
+
seed,
|
| 2004 |
+
resolution,
|
| 2005 |
+
validation_num_timesteps,
|
| 2006 |
+
validation_timestep_shift,
|
| 2007 |
+
cfg_text_scale,
|
| 2008 |
+
enable_frame_interpolation,
|
| 2009 |
+
],
|
| 2010 |
+
outputs=[output_video, output_image, output_text, status, logs],
|
| 2011 |
+
)
|
| 2012 |
+
|
| 2013 |
+
return demo
|
| 2014 |
+
|
| 2015 |
+
|
| 2016 |
+
def parse_args() -> argparse.Namespace:
|
| 2017 |
+
parser = argparse.ArgumentParser(description="Lance multimodal Gradio")
|
| 2018 |
+
parser.add_argument("--server-name", default=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"))
|
| 2019 |
+
parser.add_argument("--server-port", type=int, default=int(os.getenv("GRADIO_SERVER_PORT", "7860")))
|
| 2020 |
+
parser.add_argument("--share", action="store_true", default=env_flag("GRADIO_SHARE", False))
|
| 2021 |
+
parser.add_argument(
|
| 2022 |
+
"--gpus",
|
| 2023 |
+
default=os.getenv("LANCE_GPUS", DEFAULT_GPUS),
|
| 2024 |
+
help="Comma-separated GPU list, for example: 0,1,2,3,4,5,6",
|
| 2025 |
+
)
|
| 2026 |
+
parser.add_argument(
|
| 2027 |
+
"--queue-size",
|
| 2028 |
+
type=int,
|
| 2029 |
+
default=int(os.getenv("LANCE_QUEUE_SIZE", str(DEFAULT_QUEUE_SIZE))),
|
| 2030 |
+
help="Maximum number of queued Gradio requests.",
|
| 2031 |
+
)
|
| 2032 |
+
return parser.parse_args()
|
| 2033 |
+
|
| 2034 |
+
|
| 2035 |
+
def parse_gpu_ids(gpu_string: str) -> list[int]:
|
| 2036 |
+
gpu_ids: list[int] = []
|
| 2037 |
+
for item in gpu_string.split(","):
|
| 2038 |
+
item = item.strip()
|
| 2039 |
+
if not item:
|
| 2040 |
+
continue
|
| 2041 |
+
gpu_ids.append(int(item))
|
| 2042 |
+
if not gpu_ids:
|
| 2043 |
+
raise ValueError("No valid GPU IDs were parsed.")
|
| 2044 |
+
return gpu_ids
|
| 2045 |
+
|
| 2046 |
+
|
| 2047 |
+
if __name__ == "__main__":
|
| 2048 |
+
args = parse_args()
|
| 2049 |
+
os.environ["LANCE_GPUS"] = args.gpus
|
| 2050 |
+
resolved_model_path = ensure_model_assets(MODEL_VARIANT_VIDEO)
|
| 2051 |
+
print(f"[startup] Using Lance model path: {resolved_model_path}", flush=True)
|
| 2052 |
+
QUEUE_MAX_SIZE = args.queue_size
|
| 2053 |
+
gpu_ids = parse_gpu_ids(args.gpus)
|
| 2054 |
+
ACTIVE_PIPELINE_POOL = PipelinePool(gpu_ids, model_variant=MODEL_VARIANT_VIDEO)
|
| 2055 |
+
ACTIVE_PIPELINE_POOL.initialize_all()
|
| 2056 |
+
demo = build_demo()
|
| 2057 |
+
demo.queue(
|
| 2058 |
+
max_size=args.queue_size,
|
| 2059 |
+
default_concurrency_limit=ACTIVE_PIPELINE_POOL.size,
|
| 2060 |
+
).launch(
|
| 2061 |
+
server_name=args.server_name,
|
| 2062 |
+
server_port=args.server_port,
|
| 2063 |
+
share=args.share,
|
| 2064 |
+
)
|
app_wrong.py
ADDED
|
@@ -0,0 +1,2247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import base64
|
| 5 |
+
import concurrent.futures
|
| 6 |
+
import gc
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
import random
|
| 10 |
+
import subprocess
|
| 11 |
+
import threading
|
| 12 |
+
import time
|
| 13 |
+
import traceback
|
| 14 |
+
from collections import deque
|
| 15 |
+
from copy import deepcopy
|
| 16 |
+
from datetime import datetime
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from typing import Optional
|
| 19 |
+
|
| 20 |
+
import gradio as gr
|
| 21 |
+
import torch
|
| 22 |
+
from huggingface_hub import snapshot_download
|
| 23 |
+
from safetensors.torch import load_file
|
| 24 |
+
from transformers import set_seed
|
| 25 |
+
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLVisionConfig
|
| 26 |
+
|
| 27 |
+
from common.utils.logging import get_logger
|
| 28 |
+
from common.utils.misc import AutoEncoderParams, tuple_mul
|
| 29 |
+
from config.config_factory import DataArguments, InferenceArguments, ModelArguments
|
| 30 |
+
from data.data_utils import add_special_tokens
|
| 31 |
+
from data.dataset_base import DataConfig, simple_custom_collate
|
| 32 |
+
from data.datasets_custom import ValidationDataset
|
| 33 |
+
from inference_lance import (
|
| 34 |
+
PROMPT_JSON_FILENAME,
|
| 35 |
+
apply_inference_defaults,
|
| 36 |
+
clean_memory,
|
| 37 |
+
init_from_model_path_if_needed,
|
| 38 |
+
save_prompt_results,
|
| 39 |
+
validate_on_fixed_batch,
|
| 40 |
+
)
|
| 41 |
+
from modeling.lance import Lance, LanceConfig, Qwen2ForCausalLM
|
| 42 |
+
from modeling.qwen2 import Qwen2Tokenizer
|
| 43 |
+
from modeling.qwen2.modeling_qwen2 import Qwen2Config
|
| 44 |
+
from modeling.vae.wan.model import WanVideoVAE
|
| 45 |
+
from modeling.vit.qwen2_5_vl_vit import Qwen2_5_VisionTransformerPretrainedModel
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
REPO_ROOT = Path(__file__).resolve().parent
|
| 49 |
+
GRADIO_TMP_ROOT = Path(os.getenv("LANCE_GRADIO_TMP_ROOT", "/tmp/lance_gradio")).expanduser()
|
| 50 |
+
TMP_INPUT_DIR = GRADIO_TMP_ROOT / "inputs"
|
| 51 |
+
RESULTS_ROOT = GRADIO_TMP_ROOT / "results"
|
| 52 |
+
GLOBAL_RECORDS_FILE = GRADIO_TMP_ROOT / "generation_records.jsonl"
|
| 53 |
+
RUN_RECORD_FILENAME = "generation_record.json"
|
| 54 |
+
|
| 55 |
+
LOCAL_MODEL_BASE_DIR = Path("downloads")
|
| 56 |
+
SPACE_MODEL_BASE_DIR = Path("/data/lance_models")
|
| 57 |
+
DEFAULT_MODEL_REPO_ID = "bytedance-research/Lance"
|
| 58 |
+
DEFAULT_MODEL_VARIANT = "video"
|
| 59 |
+
MODEL_VARIANT_VIDEO = "video"
|
| 60 |
+
MODEL_VARIANT_IMAGE = "image"
|
| 61 |
+
MODEL_VARIANT_TO_DIR = {
|
| 62 |
+
MODEL_VARIANT_VIDEO: "Lance_3B_Video",
|
| 63 |
+
MODEL_VARIANT_IMAGE: "Lance_3B",
|
| 64 |
+
}
|
| 65 |
+
DEFAULT_MODEL_PATH = LOCAL_MODEL_BASE_DIR / MODEL_VARIANT_TO_DIR[MODEL_VARIANT_VIDEO]
|
| 66 |
+
DEFAULT_VIT_TYPE = "qwen_2_5_vl_original"
|
| 67 |
+
DEFAULT_TASK = "t2v"
|
| 68 |
+
DEFAULT_TIMESTEPS = 30
|
| 69 |
+
DEFAULT_TIMESTEP_SHIFT = 3.5
|
| 70 |
+
DEFAULT_CFG_TEXT_SCALE = 4.0
|
| 71 |
+
DEFAULT_RESOLUTION = "video_848x480"
|
| 72 |
+
DEFAULT_IMAGE_RESOLUTION = "image_768x768"
|
| 73 |
+
DEFAULT_BASIC_SEED = 42
|
| 74 |
+
DEFAULT_HEIGHT = 480
|
| 75 |
+
DEFAULT_WIDTH = 848
|
| 76 |
+
DEFAULT_IMAGE_SIZE = 768
|
| 77 |
+
DEFAULT_VIDEO_DURATION_SECONDS = 5
|
| 78 |
+
DEFAULT_NUM_FRAMES = 12 * DEFAULT_VIDEO_DURATION_SECONDS + 1
|
| 79 |
+
DEFAULT_VIDEO_ASPECT_RATIO = "16:9"
|
| 80 |
+
DEFAULT_IMAGE_ASPECT_RATIO = "1:1"
|
| 81 |
+
FRAME_INTERPOLATION_YES = "Yes"
|
| 82 |
+
FRAME_INTERPOLATION_NO = "No"
|
| 83 |
+
DEFAULT_FRAME_INTERPOLATION = FRAME_INTERPOLATION_YES
|
| 84 |
+
ASPECT_RATIO_CHOICES = ["21:9", "16:9", "3:2", "4:3", "1:1", "3:4", "2:3", "9:16", "9:21"]
|
| 85 |
+
|
| 86 |
+
VIDEO_ASPECT_RATIO_TO_SIZE = {
|
| 87 |
+
"21:9": (976, 416),
|
| 88 |
+
"16:9": (848, 480),
|
| 89 |
+
"3:2": (784, 528),
|
| 90 |
+
"4:3": (736, 560),
|
| 91 |
+
"1:1": (640, 640),
|
| 92 |
+
"3:4": (560, 736),
|
| 93 |
+
"2:3": (528, 784),
|
| 94 |
+
"9:16": (480, 848),
|
| 95 |
+
"9:21": (416, 976),
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
IMAGE_ASPECT_RATIO_TO_SIZE = {
|
| 99 |
+
"21:9": (1168, 496),
|
| 100 |
+
"16:9": (1024, 576),
|
| 101 |
+
"3:2": (944, 624),
|
| 102 |
+
"4:3": (880, 672),
|
| 103 |
+
"1:1": (768, 768),
|
| 104 |
+
"3:4": (672, 880),
|
| 105 |
+
"2:3": (624, 944),
|
| 106 |
+
"9:16": (576, 1024),
|
| 107 |
+
"9:21": (496, 1168),
|
| 108 |
+
}
|
| 109 |
+
DEFAULT_GPUS = "0"
|
| 110 |
+
DEFAULT_QUEUE_SIZE = 32
|
| 111 |
+
USE_KVCACHE = True
|
| 112 |
+
TEXT_TEMPLATE = True
|
| 113 |
+
RECORD_WRITE_LOCK = threading.Lock()
|
| 114 |
+
|
| 115 |
+
LANCE_HOMEPAGE_URL = "https://lance-project.github.io/"
|
| 116 |
+
LANCE_PAPER_URL = "http://arxiv.org/abs/2605.18678"
|
| 117 |
+
LANCE_HUGGING_FACE_URL = "https://huggingface.co/bytedance-research/Lance"
|
| 118 |
+
LANCE_GITHUB_URL = "https://github.com/bytedance/Lance"
|
| 119 |
+
LANCE_LOGO_PATH = REPO_ROOT / "assets" / "logo" / "lance-logo.webp"
|
| 120 |
+
|
| 121 |
+
APP_CSS = """
|
| 122 |
+
.gradio-container {
|
| 123 |
+
max-width: 1680px !important;
|
| 124 |
+
margin-left: auto !important;
|
| 125 |
+
margin-right: auto !important;
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
.contain {
|
| 129 |
+
max-width: 1680px !important;
|
| 130 |
+
margin-left: auto !important;
|
| 131 |
+
margin-right: auto !important;
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
.lance-hero {
|
| 135 |
+
text-align: center;
|
| 136 |
+
padding: 8px 12px 6px;
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
.lance-logo {
|
| 140 |
+
width: min(160px, 36vw);
|
| 141 |
+
height: auto;
|
| 142 |
+
display: block;
|
| 143 |
+
margin: 0 auto 4px;
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
.lance-title {
|
| 147 |
+
margin: 0 auto 5px;
|
| 148 |
+
font-size: clamp(20px, 2.4vw, 30px);
|
| 149 |
+
line-height: 1.08;
|
| 150 |
+
font-weight: 800;
|
| 151 |
+
letter-spacing: 0;
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
.lance-authors {
|
| 155 |
+
margin: 0 auto 6px;
|
| 156 |
+
max-width: 1280px;
|
| 157 |
+
font-size: 20px;
|
| 158 |
+
line-height: 1.24;
|
| 159 |
+
color: var(--body-text-color-subdued);
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
.lance-authors a {
|
| 163 |
+
color: inherit;
|
| 164 |
+
text-decoration: none;
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
.lance-authors a:hover {
|
| 168 |
+
text-decoration: underline;
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
.lance-badges {
|
| 172 |
+
display: flex;
|
| 173 |
+
flex-wrap: wrap;
|
| 174 |
+
justify-content: center;
|
| 175 |
+
gap: 5px;
|
| 176 |
+
margin: 4px auto 0;
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
.lance-badges a {
|
| 180 |
+
line-height: 0;
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
.lance-badges img {
|
| 184 |
+
height: 20px;
|
| 185 |
+
width: auto;
|
| 186 |
+
display: block;
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
.lance-status {
|
| 190 |
+
max-width: 1180px;
|
| 191 |
+
margin: 0 auto 18px;
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
.task-selector {
|
| 195 |
+
overflow-x: auto;
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
.lance-main-column > label span,
|
| 199 |
+
.lance-main-column > .block-title,
|
| 200 |
+
.lance-main-column > .label-wrap span,
|
| 201 |
+
.lance-main-column > .form > label span,
|
| 202 |
+
.lance-main-column > .form > .block-title,
|
| 203 |
+
.lance-main-column > .form > .label-wrap span {
|
| 204 |
+
font-size: 20px !important;
|
| 205 |
+
font-weight: 700 !important;
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
.task-selector .wrap {
|
| 209 |
+
display: grid;
|
| 210 |
+
grid-template-columns: repeat(3, minmax(220px, 1fr));
|
| 211 |
+
gap: 8px;
|
| 212 |
+
min-width: 680px;
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
.task-selector label {
|
| 216 |
+
justify-content: center;
|
| 217 |
+
min-height: 38px;
|
| 218 |
+
white-space: nowrap;
|
| 219 |
+
border-radius: 10px !important;
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
.task-selector span {
|
| 223 |
+
font-size: 20px !important;
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
.recommended-title {
|
| 227 |
+
text-align: center !important;
|
| 228 |
+
margin: 14px auto 10px !important;
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
.recommended-title h3,
|
| 232 |
+
.recommended-title p {
|
| 233 |
+
text-align: center !important;
|
| 234 |
+
font-size: 22px !important;
|
| 235 |
+
font-weight: 800 !important;
|
| 236 |
+
color: var(--body-text-color) !important;
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
.example-panel {
|
| 240 |
+
margin-top: 14px !important;
|
| 241 |
+
padding: 10px 12px !important;
|
| 242 |
+
border-radius: 8px !important;
|
| 243 |
+
background: rgba(248, 250, 252, 0.72) !important;
|
| 244 |
+
border: 1px solid var(--border-color-primary) !important;
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
.prompt-examples table,
|
| 248 |
+
.prompt-examples th,
|
| 249 |
+
.prompt-examples td {
|
| 250 |
+
border: 1px solid var(--border-color-primary) !important;
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
.prompt-examples table {
|
| 254 |
+
border-collapse: collapse !important;
|
| 255 |
+
width: 100% !important;
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
.prompt-examples td {
|
| 259 |
+
border-bottom: 1px solid var(--border-color-primary) !important;
|
| 260 |
+
padding: 12px !important;
|
| 261 |
+
vertical-align: top !important;
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
.example-panel th,
|
| 265 |
+
.example-panel .block-label,
|
| 266 |
+
.example-panel label span,
|
| 267 |
+
.example-panel .label-wrap span {
|
| 268 |
+
font-size: 18px !important;
|
| 269 |
+
font-weight: 700 !important;
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
.prompt-dataset {
|
| 273 |
+
max-height: 420px !important;
|
| 274 |
+
overflow-y: auto !important;
|
| 275 |
+
overscroll-behavior: contain !important;
|
| 276 |
+
scrollbar-gutter: stable !important;
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
.prompt-dataset button {
|
| 280 |
+
height: auto !important;
|
| 281 |
+
min-height: 48px !important;
|
| 282 |
+
white-space: normal !important;
|
| 283 |
+
text-align: left !important;
|
| 284 |
+
align-items: flex-start !important;
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
.prompt-dataset .paginate {
|
| 288 |
+
display: none !important;
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
.prompt-example-proxy {
|
| 292 |
+
display: none !important;
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
.lance-main-row {
|
| 296 |
+
display: grid !important;
|
| 297 |
+
grid-template-columns: minmax(0, 1fr) minmax(0, 1fr) !important;
|
| 298 |
+
gap: 16px !important;
|
| 299 |
+
align-items: start !important;
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
.lance-main-column {
|
| 303 |
+
min-width: 0 !important;
|
| 304 |
+
width: 100% !important;
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
.lance-display-frame,
|
| 308 |
+
.lance-display-frame > div,
|
| 309 |
+
.lance-display-frame textarea {
|
| 310 |
+
width: 100% !important;
|
| 311 |
+
}
|
| 312 |
+
|
| 313 |
+
.lance-display-frame textarea {
|
| 314 |
+
min-height: 360px !important;
|
| 315 |
+
}
|
| 316 |
+
|
| 317 |
+
.lance-run-button {
|
| 318 |
+
font-size: 18px !important;
|
| 319 |
+
font-weight: 800 !important;
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
.generation-controls-row {
|
| 323 |
+
width: 100% !important;
|
| 324 |
+
max-width: 100% !important;
|
| 325 |
+
overflow-x: hidden !important;
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
.generation-controls-row > .form {
|
| 329 |
+
display: grid !important;
|
| 330 |
+
grid-template-columns:
|
| 331 |
+
minmax(0, 1.25fr)
|
| 332 |
+
minmax(0, 1.3fr)
|
| 333 |
+
minmax(0, 1fr)
|
| 334 |
+
minmax(0, 1.25fr) !important;
|
| 335 |
+
gap: 12px !important;
|
| 336 |
+
align-items: start !important;
|
| 337 |
+
width: 100% !important;
|
| 338 |
+
max-width: 100% !important;
|
| 339 |
+
overflow: visible !important;
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
.generation-control,
|
| 343 |
+
.generation-control > div,
|
| 344 |
+
.generation-controls-row > .form > div {
|
| 345 |
+
min-width: 0 !important;
|
| 346 |
+
max-width: 100% !important;
|
| 347 |
+
}
|
| 348 |
+
|
| 349 |
+
.generation-controls-row .generation-control label,
|
| 350 |
+
.generation-controls-row .generation-control label span,
|
| 351 |
+
.generation-controls-row .generation-control .block-label,
|
| 352 |
+
.generation-controls-row .generation-control .block-title,
|
| 353 |
+
.generation-controls-row .generation-control > label,
|
| 354 |
+
.generation-controls-row .generation-control .label-wrap,
|
| 355 |
+
.generation-controls-row .generation-control .label-wrap span {
|
| 356 |
+
font-size: 22px !important;
|
| 357 |
+
font-weight: 700 !important;
|
| 358 |
+
line-height: 1.15 !important;
|
| 359 |
+
letter-spacing: 0 !important;
|
| 360 |
+
white-space: normal !important;
|
| 361 |
+
}
|
| 362 |
+
|
| 363 |
+
.generation-controls-row .generation-value-control input,
|
| 364 |
+
.generation-controls-row .generation-value-control textarea,
|
| 365 |
+
.generation-controls-row .generation-value-control [data-testid="textbox"],
|
| 366 |
+
.generation-controls-row .generation-dropdown-control input[role="listbox"],
|
| 367 |
+
.generation-controls-row .generation-dropdown-control input.border-none[role="listbox"],
|
| 368 |
+
.generation-controls-row .generation-dropdown-control .secondary-wrap input {
|
| 369 |
+
font-size: 22px !important;
|
| 370 |
+
font-weight: 700 !important;
|
| 371 |
+
line-height: 1.2 !important;
|
| 372 |
+
letter-spacing: 0 !important;
|
| 373 |
+
text-align: left !important;
|
| 374 |
+
}
|
| 375 |
+
|
| 376 |
+
.generation-controls-row .generation-value-control input,
|
| 377 |
+
.generation-controls-row .generation-value-control textarea,
|
| 378 |
+
.generation-controls-row .generation-dropdown-control input[role="listbox"],
|
| 379 |
+
.generation-controls-row .generation-dropdown-control input.border-none[role="listbox"],
|
| 380 |
+
.generation-controls-row .generation-dropdown-control .secondary-wrap input {
|
| 381 |
+
min-height: 64px !important;
|
| 382 |
+
width: 100% !important;
|
| 383 |
+
box-sizing: border-box !important;
|
| 384 |
+
}
|
| 385 |
+
|
| 386 |
+
@media (max-width: 1100px) {
|
| 387 |
+
.generation-controls-row > .form {
|
| 388 |
+
grid-template-columns: repeat(2, minmax(0, 1fr)) !important;
|
| 389 |
+
}
|
| 390 |
+
}
|
| 391 |
+
|
| 392 |
+
@media (max-width: 900px) {
|
| 393 |
+
.lance-main-row {
|
| 394 |
+
grid-template-columns: minmax(0, 1fr) !important;
|
| 395 |
+
}
|
| 396 |
+
}
|
| 397 |
+
"""
|
| 398 |
+
|
| 399 |
+
TASK_T2V = "t2v"
|
| 400 |
+
TASK_T2I = "t2i"
|
| 401 |
+
TASK_V2T = "v2t"
|
| 402 |
+
TASK_X2T = "x2t"
|
| 403 |
+
TASK_X2T_VIDEO = "x2t_video"
|
| 404 |
+
TASK_X2T_IMAGE = "x2t_image"
|
| 405 |
+
TASK_IMAGE_EDIT = "image_edit"
|
| 406 |
+
TASK_VIDEO_EDIT = "video_edit"
|
| 407 |
+
TASK_LABEL_VIDEO_GENERATION = "Video Generation"
|
| 408 |
+
TASK_LABEL_VIDEO_EDIT = "Video Edit"
|
| 409 |
+
TASK_LABEL_VIDEO_UNDERSTANDING = "Video Understanding"
|
| 410 |
+
TASK_LABEL_IMAGE_GENERATION = "Image Generation"
|
| 411 |
+
TASK_LABEL_IMAGE_EDIT = "Image Edit"
|
| 412 |
+
TASK_LABEL_IMAGE_UNDERSTANDING = "Image Understanding"
|
| 413 |
+
TASK_CHOICES = [
|
| 414 |
+
TASK_LABEL_VIDEO_GENERATION,
|
| 415 |
+
TASK_LABEL_VIDEO_EDIT,
|
| 416 |
+
TASK_LABEL_VIDEO_UNDERSTANDING,
|
| 417 |
+
TASK_LABEL_IMAGE_GENERATION,
|
| 418 |
+
TASK_LABEL_IMAGE_EDIT,
|
| 419 |
+
TASK_LABEL_IMAGE_UNDERSTANDING,
|
| 420 |
+
]
|
| 421 |
+
TASK_LABEL_TO_INTERNAL = {
|
| 422 |
+
TASK_LABEL_VIDEO_GENERATION: TASK_T2V,
|
| 423 |
+
TASK_LABEL_VIDEO_EDIT: TASK_VIDEO_EDIT,
|
| 424 |
+
TASK_LABEL_VIDEO_UNDERSTANDING: TASK_X2T_VIDEO,
|
| 425 |
+
TASK_LABEL_IMAGE_GENERATION: TASK_T2I,
|
| 426 |
+
TASK_LABEL_IMAGE_EDIT: TASK_IMAGE_EDIT,
|
| 427 |
+
TASK_LABEL_IMAGE_UNDERSTANDING: TASK_X2T_IMAGE,
|
| 428 |
+
TASK_T2V: TASK_T2V,
|
| 429 |
+
TASK_VIDEO_EDIT: TASK_VIDEO_EDIT,
|
| 430 |
+
TASK_V2T: TASK_X2T_VIDEO,
|
| 431 |
+
TASK_X2T: TASK_X2T_VIDEO,
|
| 432 |
+
TASK_X2T_VIDEO: TASK_X2T_VIDEO,
|
| 433 |
+
TASK_T2I: TASK_T2I,
|
| 434 |
+
TASK_IMAGE_EDIT: TASK_IMAGE_EDIT,
|
| 435 |
+
TASK_X2T_IMAGE: TASK_X2T_IMAGE,
|
| 436 |
+
}
|
| 437 |
+
GENERATION_TASKS = {TASK_T2V, TASK_T2I, TASK_IMAGE_EDIT, TASK_VIDEO_EDIT}
|
| 438 |
+
UNDERSTANDING_TASKS = {TASK_X2T_VIDEO, TASK_X2T_IMAGE}
|
| 439 |
+
IMAGE_TASKS = {TASK_T2I, TASK_IMAGE_EDIT, TASK_X2T_IMAGE}
|
| 440 |
+
VIDEO_TASKS = {TASK_T2V, TASK_VIDEO_EDIT, TASK_X2T_VIDEO}
|
| 441 |
+
EDIT_TASKS = {TASK_IMAGE_EDIT, TASK_VIDEO_EDIT}
|
| 442 |
+
VIDEO_RESOLUTION_CHOICES = [DEFAULT_RESOLUTION]
|
| 443 |
+
IMAGE_RESOLUTION_CHOICES = [DEFAULT_IMAGE_RESOLUTION]
|
| 444 |
+
RESOLUTION_CHOICES = VIDEO_RESOLUTION_CHOICES + IMAGE_RESOLUTION_CHOICES
|
| 445 |
+
CAPTION_SYSTEM_PROMPT_TEMPLATE = (
|
| 446 |
+
"Describe the key features of the input {vision_type}, including color, shape, size, texture, objects, background."
|
| 447 |
+
)
|
| 448 |
+
V2T_CAPTION_SYSTEM_PROMPT = CAPTION_SYSTEM_PROMPT_TEMPLATE.format(vision_type="video")
|
| 449 |
+
I2T_CAPTION_SYSTEM_PROMPT = CAPTION_SYSTEM_PROMPT_TEMPLATE.format(vision_type="image")
|
| 450 |
+
V2T_QA_SYSTEM_PROMPT = "View the video attentively and provide a suitable answer to the posed question."
|
| 451 |
+
I2T_QA_SYSTEM_PROMPT = "View the image attentively and provide a suitable answer to the posed question."
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
def get_aspect_ratio_choices_for_task(task: str) -> list[tuple[str, str]]:
|
| 455 |
+
"""Get Aspect Ratio choices with default/recommended marker for the given task."""
|
| 456 |
+
internal_task = normalize_task(task)
|
| 457 |
+
default_ratio = DEFAULT_IMAGE_ASPECT_RATIO if internal_task in IMAGE_TASKS else DEFAULT_VIDEO_ASPECT_RATIO
|
| 458 |
+
return [
|
| 459 |
+
(f"{ratio} (default)" if ratio == default_ratio else ratio, ratio)
|
| 460 |
+
for ratio in ASPECT_RATIO_CHOICES
|
| 461 |
+
]
|
| 462 |
+
|
| 463 |
+
def env_flag(name: str, default: bool) -> bool:
|
| 464 |
+
value = os.getenv(name)
|
| 465 |
+
if value is None:
|
| 466 |
+
return default
|
| 467 |
+
return value.strip().lower() in {"1", "true", "yes", "on"}
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
def running_on_space() -> bool:
|
| 471 |
+
return bool(os.getenv("SPACE_ID") or os.getenv("SPACE_HOST"))
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
def display_path(path: Path) -> str:
|
| 475 |
+
path_text = path.as_posix()
|
| 476 |
+
if path.is_absolute():
|
| 477 |
+
try:
|
| 478 |
+
path_text = path.relative_to(Path.cwd()).as_posix()
|
| 479 |
+
except ValueError:
|
| 480 |
+
return path_text
|
| 481 |
+
if path_text == "." or path_text.startswith("./"):
|
| 482 |
+
return path_text
|
| 483 |
+
return f"./{path_text}"
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
def get_model_base_dir() -> Path:
|
| 487 |
+
configured = os.getenv("LANCE_MODEL_BASE_DIR")
|
| 488 |
+
if configured:
|
| 489 |
+
return Path(configured).expanduser()
|
| 490 |
+
if LOCAL_MODEL_BASE_DIR.exists():
|
| 491 |
+
return LOCAL_MODEL_BASE_DIR
|
| 492 |
+
return SPACE_MODEL_BASE_DIR if running_on_space() else LOCAL_MODEL_BASE_DIR
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
def normalize_model_variant(model_variant: Optional[str] = None) -> str:
|
| 496 |
+
variant = (model_variant or os.getenv("LANCE_MODEL_VARIANT", DEFAULT_MODEL_VARIANT)).strip().lower()
|
| 497 |
+
if variant in {"image", "t2i", "i2t"}:
|
| 498 |
+
return MODEL_VARIANT_IMAGE
|
| 499 |
+
return MODEL_VARIANT_VIDEO
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
def get_model_path(model_variant: Optional[str] = None) -> Path:
|
| 503 |
+
variant = normalize_model_variant(model_variant)
|
| 504 |
+
variant_env_name = "LANCE_IMAGE_MODEL_PATH" if variant == MODEL_VARIANT_IMAGE else "LANCE_VIDEO_MODEL_PATH"
|
| 505 |
+
variant_configured = os.getenv(variant_env_name)
|
| 506 |
+
if variant_configured:
|
| 507 |
+
return Path(variant_configured).expanduser()
|
| 508 |
+
|
| 509 |
+
configured = os.getenv("LANCE_MODEL_PATH")
|
| 510 |
+
if configured:
|
| 511 |
+
return Path(configured).expanduser()
|
| 512 |
+
|
| 513 |
+
model_dir_name = MODEL_VARIANT_TO_DIR[variant]
|
| 514 |
+
return get_model_base_dir() / model_dir_name
|
| 515 |
+
|
| 516 |
+
|
| 517 |
+
def get_required_model_asset_paths(model_base_dir: Path, model_path: Path) -> list[Path]:
|
| 518 |
+
return [
|
| 519 |
+
model_path / "llm_config.json",
|
| 520 |
+
model_path / "model.safetensors",
|
| 521 |
+
model_base_dir / "Qwen2.5-VL-ViT" / "vit.safetensors",
|
| 522 |
+
model_base_dir / "Wan2.2_VAE.pth",
|
| 523 |
+
]
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
def ensure_model_assets(model_variant: Optional[str] = None) -> Path:
|
| 527 |
+
model_base_dir = get_model_base_dir()
|
| 528 |
+
os.environ["LANCE_MODEL_BASE_DIR"] = display_path(model_base_dir)
|
| 529 |
+
model_path = get_model_path(model_variant)
|
| 530 |
+
|
| 531 |
+
required_paths = get_required_model_asset_paths(model_base_dir, model_path)
|
| 532 |
+
if all(path.exists() for path in required_paths):
|
| 533 |
+
return model_path
|
| 534 |
+
|
| 535 |
+
downloads_model_base_dir = Path("downloads")
|
| 536 |
+
if model_base_dir == Path(".") and downloads_model_base_dir.exists():
|
| 537 |
+
downloads_model_path = downloads_model_base_dir / MODEL_VARIANT_TO_DIR[normalize_model_variant(model_variant)]
|
| 538 |
+
downloads_required_paths = get_required_model_asset_paths(downloads_model_base_dir, downloads_model_path)
|
| 539 |
+
if all(path.exists() for path in downloads_required_paths):
|
| 540 |
+
model_base_dir = downloads_model_base_dir
|
| 541 |
+
model_path = downloads_model_path
|
| 542 |
+
required_paths = downloads_required_paths
|
| 543 |
+
os.environ["LANCE_MODEL_BASE_DIR"] = display_path(model_base_dir)
|
| 544 |
+
return model_path
|
| 545 |
+
|
| 546 |
+
auto_download = env_flag("LANCE_AUTO_DOWNLOAD", running_on_space())
|
| 547 |
+
if not auto_download:
|
| 548 |
+
missing = "\n".join(f"- {display_path(path)}" for path in required_paths if not path.exists())
|
| 549 |
+
raise FileNotFoundError(
|
| 550 |
+
"Lance model assets are missing. Set LANCE_MODEL_BASE_DIR or enable "
|
| 551 |
+
f"LANCE_AUTO_DOWNLOAD=1.\nMissing files:\n{missing}"
|
| 552 |
+
)
|
| 553 |
+
|
| 554 |
+
model_base_dir.mkdir(parents=True, exist_ok=True)
|
| 555 |
+
repo_id = os.getenv("LANCE_MODEL_REPO_ID", DEFAULT_MODEL_REPO_ID)
|
| 556 |
+
print(f"[startup] Downloading Lance model assets from {repo_id} to {display_path(model_base_dir)}", flush=True)
|
| 557 |
+
snapshot_path = Path(
|
| 558 |
+
snapshot_download(
|
| 559 |
+
repo_id=repo_id,
|
| 560 |
+
local_dir=str(model_base_dir),
|
| 561 |
+
local_dir_use_symlinks=False,
|
| 562 |
+
resume_download=True,
|
| 563 |
+
)
|
| 564 |
+
)
|
| 565 |
+
if snapshot_path != model_base_dir and not model_path.exists():
|
| 566 |
+
os.environ["LANCE_MODEL_BASE_DIR"] = display_path(snapshot_path)
|
| 567 |
+
model_path = get_model_path(model_variant)
|
| 568 |
+
return model_path
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
def ensure_dirs() -> None:
|
| 572 |
+
TMP_INPUT_DIR.mkdir(parents=True, exist_ok=True)
|
| 573 |
+
RESULTS_ROOT.mkdir(parents=True, exist_ok=True)
|
| 574 |
+
|
| 575 |
+
|
| 576 |
+
def save_generation_record(record: dict, save_dir: Path) -> None:
|
| 577 |
+
ensure_dirs()
|
| 578 |
+
run_record_path = save_dir / RUN_RECORD_FILENAME
|
| 579 |
+
with run_record_path.open("w", encoding="utf-8") as f:
|
| 580 |
+
json.dump(record, f, ensure_ascii=False, indent=2)
|
| 581 |
+
|
| 582 |
+
with RECORD_WRITE_LOCK:
|
| 583 |
+
with GLOBAL_RECORDS_FILE.open("a", encoding="utf-8") as f:
|
| 584 |
+
f.write(json.dumps(record, ensure_ascii=False) + "\n")
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
def normalize_seed(seed: int) -> int:
|
| 588 |
+
return random.randint(0, 2**31 - 1) if seed == -1 else seed
|
| 589 |
+
|
| 590 |
+
|
| 591 |
+
def normalize_task(task: str) -> str:
|
| 592 |
+
task_key = (task or TASK_LABEL_VIDEO_GENERATION).strip()
|
| 593 |
+
task = TASK_LABEL_TO_INTERNAL.get(task_key, TASK_LABEL_TO_INTERNAL.get(task_key.lower(), ""))
|
| 594 |
+
if task not in GENERATION_TASKS | UNDERSTANDING_TASKS:
|
| 595 |
+
raise ValueError(f"Unsupported task type: {task}")
|
| 596 |
+
return task
|
| 597 |
+
|
| 598 |
+
|
| 599 |
+
def normalize_resolution_for_backend(resolution: str, task: str) -> str:
|
| 600 |
+
internal_task = normalize_task(task)
|
| 601 |
+
if internal_task in IMAGE_TASKS:
|
| 602 |
+
return DEFAULT_IMAGE_RESOLUTION
|
| 603 |
+
if internal_task in VIDEO_TASKS:
|
| 604 |
+
return DEFAULT_RESOLUTION
|
| 605 |
+
return str(resolution)
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
def get_default_aspect_ratio(task: str) -> str:
|
| 609 |
+
internal_task = normalize_task(task)
|
| 610 |
+
return DEFAULT_IMAGE_ASPECT_RATIO if internal_task in IMAGE_TASKS else DEFAULT_VIDEO_ASPECT_RATIO
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
def get_size_for_aspect_ratio(task: str, aspect_ratio: str) -> tuple[int, int]:
|
| 614 |
+
internal_task = normalize_task(task)
|
| 615 |
+
aspect_ratio = aspect_ratio if aspect_ratio in ASPECT_RATIO_CHOICES else get_default_aspect_ratio(internal_task)
|
| 616 |
+
size_map = IMAGE_ASPECT_RATIO_TO_SIZE if internal_task in IMAGE_TASKS else VIDEO_ASPECT_RATIO_TO_SIZE
|
| 617 |
+
return size_map[aspect_ratio]
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
def format_size_markdown(task: str, width: int, height: int) -> str:
|
| 621 |
+
internal_task = normalize_task(task)
|
| 622 |
+
if internal_task in UNDERSTANDING_TASKS:
|
| 623 |
+
return ""
|
| 624 |
+
return f"{width} x {height}"
|
| 625 |
+
|
| 626 |
+
|
| 627 |
+
def normalize_frame_interpolation(value) -> bool:
|
| 628 |
+
if isinstance(value, bool):
|
| 629 |
+
return value
|
| 630 |
+
return str(value or "").strip().lower() in {"1", "true", "yes", "on", "open"}
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
def video_seconds_to_num_frames(seconds: int) -> int:
|
| 634 |
+
seconds = max(1, min(10, int(seconds)))
|
| 635 |
+
return 12 * seconds + 1
|
| 636 |
+
|
| 637 |
+
|
| 638 |
+
def update_size_from_aspect_ratio(task: str, aspect_ratio: str):
|
| 639 |
+
width, height = get_size_for_aspect_ratio(task, aspect_ratio)
|
| 640 |
+
return height, width, format_size_markdown(task, width, height)
|
| 641 |
+
|
| 642 |
+
|
| 643 |
+
def reset_generation_defaults_for_task(task: str):
|
| 644 |
+
internal_task = normalize_task(task)
|
| 645 |
+
aspect_ratio = get_default_aspect_ratio(internal_task)
|
| 646 |
+
width, height = get_size_for_aspect_ratio(internal_task, aspect_ratio)
|
| 647 |
+
resolution = DEFAULT_IMAGE_RESOLUTION if internal_task in IMAGE_TASKS else DEFAULT_RESOLUTION
|
| 648 |
+
num_frames = DEFAULT_VIDEO_DURATION_SECONDS if internal_task == TASK_T2V else 1
|
| 649 |
+
return aspect_ratio, height, width, num_frames, resolution, format_size_markdown(internal_task, width, height)
|
| 650 |
+
|
| 651 |
+
|
| 652 |
+
def apply_prompt_example(task: str, evt: gr.SelectData):
|
| 653 |
+
prompt_text = ""
|
| 654 |
+
if isinstance(evt.row_value, list) and evt.row_value:
|
| 655 |
+
prompt_text = str(evt.row_value[0])
|
| 656 |
+
elif isinstance(evt.value, list) and evt.value:
|
| 657 |
+
prompt_text = str(evt.value[0])
|
| 658 |
+
elif evt.value is not None:
|
| 659 |
+
prompt_text = str(evt.value)
|
| 660 |
+
defaults = reset_generation_defaults_for_task(task)
|
| 661 |
+
return (prompt_text, *defaults)
|
| 662 |
+
|
| 663 |
+
|
| 664 |
+
def get_understanding_system_prompt_choices(task: str) -> list[str]:
|
| 665 |
+
internal_task = normalize_task(task)
|
| 666 |
+
if internal_task == TASK_X2T_IMAGE:
|
| 667 |
+
return [I2T_QA_SYSTEM_PROMPT]
|
| 668 |
+
return [V2T_QA_SYSTEM_PROMPT]
|
| 669 |
+
|
| 670 |
+
|
| 671 |
+
def normalize_understanding_system_prompt(task: str, system_prompt: Optional[str]) -> str:
|
| 672 |
+
return get_understanding_system_prompt_choices(task)[0]
|
| 673 |
+
|
| 674 |
+
|
| 675 |
+
def create_request_json(
|
| 676 |
+
task: str,
|
| 677 |
+
prompt: str,
|
| 678 |
+
input_video: Optional[str],
|
| 679 |
+
input_image: Optional[str],
|
| 680 |
+
system_prompt: Optional[str] = None,
|
| 681 |
+
) -> Path:
|
| 682 |
+
ensure_dirs()
|
| 683 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
| 684 |
+
prompt_file = TMP_INPUT_DIR / f"{task}_{timestamp}.json"
|
| 685 |
+
|
| 686 |
+
if task == TASK_T2V:
|
| 687 |
+
payload = {"000000.mp4": prompt}
|
| 688 |
+
elif task == TASK_T2I:
|
| 689 |
+
payload = {"000000.png": prompt}
|
| 690 |
+
elif task == TASK_VIDEO_EDIT:
|
| 691 |
+
if not input_video:
|
| 692 |
+
raise ValueError("The video edit task requires an input video.")
|
| 693 |
+
payload = {
|
| 694 |
+
"000000": {
|
| 695 |
+
"interleave_array": [prompt, input_video, input_video],
|
| 696 |
+
"element_dtype_array": ["text", "video", "video"],
|
| 697 |
+
"istarget_in_interleave": [0, 0, 1],
|
| 698 |
+
}
|
| 699 |
+
}
|
| 700 |
+
elif task == TASK_IMAGE_EDIT:
|
| 701 |
+
if not input_image:
|
| 702 |
+
raise ValueError("The image edit task requires an input image.")
|
| 703 |
+
payload = {
|
| 704 |
+
"000000": {
|
| 705 |
+
"interleave_array": [prompt, input_image, input_image],
|
| 706 |
+
"element_dtype_array": ["text", "image", "image"],
|
| 707 |
+
"istarget_in_interleave": [0, 0, 1],
|
| 708 |
+
}
|
| 709 |
+
}
|
| 710 |
+
elif task == TASK_X2T_VIDEO:
|
| 711 |
+
if not input_video:
|
| 712 |
+
raise ValueError("The video understanding task requires an input video.")
|
| 713 |
+
system_prompt = normalize_understanding_system_prompt(task, system_prompt)
|
| 714 |
+
payload = {
|
| 715 |
+
"000000": {
|
| 716 |
+
"interleave_array": [input_video, [system_prompt, prompt, ""]],
|
| 717 |
+
"element_dtype_array": ["video", "text"],
|
| 718 |
+
"istarget_in_interleave": [0, 1],
|
| 719 |
+
}
|
| 720 |
+
}
|
| 721 |
+
elif task == TASK_X2T_IMAGE:
|
| 722 |
+
if not input_image:
|
| 723 |
+
raise ValueError("The image understanding task requires an input image.")
|
| 724 |
+
system_prompt = normalize_understanding_system_prompt(task, system_prompt)
|
| 725 |
+
payload = {
|
| 726 |
+
"000000": {
|
| 727 |
+
"interleave_array": [input_image, [system_prompt, prompt, ""]],
|
| 728 |
+
"element_dtype_array": ["image", "text"],
|
| 729 |
+
"istarget_in_interleave": [0, 1],
|
| 730 |
+
}
|
| 731 |
+
}
|
| 732 |
+
else:
|
| 733 |
+
raise ValueError(f"Unsupported task type: {task}")
|
| 734 |
+
|
| 735 |
+
with prompt_file.open("w", encoding="utf-8") as f:
|
| 736 |
+
json.dump(payload, f, ensure_ascii=False, indent=2)
|
| 737 |
+
return prompt_file
|
| 738 |
+
|
| 739 |
+
|
| 740 |
+
def resolve_example_path(path: str) -> str:
|
| 741 |
+
candidate = Path(path)
|
| 742 |
+
if candidate.is_absolute():
|
| 743 |
+
return str(candidate)
|
| 744 |
+
repo_candidate = (REPO_ROOT / candidate)
|
| 745 |
+
if repo_candidate.exists():
|
| 746 |
+
return str(repo_candidate.resolve())
|
| 747 |
+
if candidate.exists():
|
| 748 |
+
return str(candidate.resolve())
|
| 749 |
+
return path
|
| 750 |
+
|
| 751 |
+
|
| 752 |
+
def resolve_browser_video_example_path(path: str) -> str:
|
| 753 |
+
candidate = Path(path)
|
| 754 |
+
compatible_candidate = candidate.with_name(f"{candidate.stem}_h264{candidate.suffix}")
|
| 755 |
+
repo_compatible_candidate = REPO_ROOT / compatible_candidate
|
| 756 |
+
if not compatible_candidate.is_absolute() and repo_compatible_candidate.exists():
|
| 757 |
+
return str(repo_compatible_candidate.resolve())
|
| 758 |
+
if compatible_candidate.is_absolute() and compatible_candidate.exists():
|
| 759 |
+
return str(compatible_candidate.resolve())
|
| 760 |
+
repo_candidate = REPO_ROOT / candidate
|
| 761 |
+
if not candidate.is_absolute() and repo_candidate.exists():
|
| 762 |
+
return str(repo_candidate.resolve())
|
| 763 |
+
if candidate.is_absolute() and candidate.exists():
|
| 764 |
+
return str(candidate.resolve())
|
| 765 |
+
return resolve_example_path(path)
|
| 766 |
+
|
| 767 |
+
|
| 768 |
+
def load_json_examples(relative_path: str) -> dict:
|
| 769 |
+
path = REPO_ROOT / relative_path
|
| 770 |
+
with path.open("r", encoding="utf-8") as f:
|
| 771 |
+
return json.load(f)
|
| 772 |
+
|
| 773 |
+
|
| 774 |
+
T2V_EXAMPLE_SUMMARIES = {
|
| 775 |
+
"000000.mp4": "Red panda surfing on a bright seaside wave.",
|
| 776 |
+
"000002.mp4": "Panda cub skateboarding in a creative loft.",
|
| 777 |
+
"000004.mp4": "Young woman shaping clay in a sunlit pottery workshop.",
|
| 778 |
+
"000005.mp4": "Panda boxing a robot in a luxurious palace ring.",
|
| 779 |
+
"000008.mp4": "Fantasy pastel horse stepping through a glowing cloud valley.",
|
| 780 |
+
}
|
| 781 |
+
|
| 782 |
+
|
| 783 |
+
def make_generation_examples(
|
| 784 |
+
task_label: str,
|
| 785 |
+
relative_path: str,
|
| 786 |
+
limit: int,
|
| 787 |
+
image_task: bool,
|
| 788 |
+
selected_keys: Optional[list[str]] = None,
|
| 789 |
+
summaries: Optional[dict[str, str]] = None,
|
| 790 |
+
) -> list[list]:
|
| 791 |
+
data = load_json_examples(relative_path)
|
| 792 |
+
items = [(key, data[key]) for key in selected_keys if key in data] if selected_keys else list(data.items())[:limit]
|
| 793 |
+
examples = []
|
| 794 |
+
for output_name, prompt in items:
|
| 795 |
+
examples.append([prompt])
|
| 796 |
+
return examples
|
| 797 |
+
|
| 798 |
+
|
| 799 |
+
def make_edit_examples(task_label: str, relative_path: str, limit: int, media_type: str) -> list[list]:
|
| 800 |
+
data = load_json_examples(relative_path)
|
| 801 |
+
examples = []
|
| 802 |
+
for sample in list(data.values())[:limit]:
|
| 803 |
+
interleave = sample["interleave_array"]
|
| 804 |
+
prompt = interleave[0]
|
| 805 |
+
media_path = resolve_example_path(interleave[1])
|
| 806 |
+
examples.append([
|
| 807 |
+
prompt,
|
| 808 |
+
media_path if media_type == "video" else None,
|
| 809 |
+
media_path if media_type == "image" else None,
|
| 810 |
+
])
|
| 811 |
+
return examples
|
| 812 |
+
|
| 813 |
+
|
| 814 |
+
def make_understanding_examples(task_label: str, relative_path: str, limit: int, media_type: str) -> list[list]:
|
| 815 |
+
data = load_json_examples(relative_path)
|
| 816 |
+
examples = []
|
| 817 |
+
for sample in list(data.values())[:limit]:
|
| 818 |
+
interleave = sample["interleave_array"]
|
| 819 |
+
media_path = (
|
| 820 |
+
resolve_browser_video_example_path(interleave[0])
|
| 821 |
+
if media_type == "video"
|
| 822 |
+
else resolve_example_path(interleave[0])
|
| 823 |
+
)
|
| 824 |
+
text_payload = interleave[1]
|
| 825 |
+
question = text_payload[1] if isinstance(text_payload, list) and len(text_payload) > 1 else ""
|
| 826 |
+
examples.append([
|
| 827 |
+
question,
|
| 828 |
+
media_path if media_type == "video" else None,
|
| 829 |
+
media_path if media_type == "image" else None,
|
| 830 |
+
])
|
| 831 |
+
return examples
|
| 832 |
+
|
| 833 |
+
|
| 834 |
+
def make_understanding_system_prompt_map(relative_path: str, task: str) -> dict[str, str]:
|
| 835 |
+
data = load_json_examples(relative_path)
|
| 836 |
+
system_prompts = {}
|
| 837 |
+
for sample in data.values():
|
| 838 |
+
interleave = sample["interleave_array"]
|
| 839 |
+
text_payload = interleave[1]
|
| 840 |
+
if not isinstance(text_payload, list) or len(text_payload) < 2:
|
| 841 |
+
continue
|
| 842 |
+
system_prompts[text_payload[1]] = normalize_understanding_system_prompt(task, text_payload[0])
|
| 843 |
+
return system_prompts
|
| 844 |
+
|
| 845 |
+
|
| 846 |
+
VIDEO_GENERATION_EXAMPLES = make_generation_examples(
|
| 847 |
+
TASK_LABEL_VIDEO_GENERATION,
|
| 848 |
+
"config/examples/t2v_example.json",
|
| 849 |
+
limit=6,
|
| 850 |
+
image_task=False,
|
| 851 |
+
#selected_keys=["000000.mp4", "000002.mp4", "000005.mp4", "000004.mp4", "000008.mp4"],
|
| 852 |
+
selected_keys=["000004.mp4", "000002.mp4", "000000.mp4", "000005.mp4", "000008.mp4", "000007.mp4"],
|
| 853 |
+
summaries=T2V_EXAMPLE_SUMMARIES,
|
| 854 |
+
)
|
| 855 |
+
VIDEO_EDIT_EXAMPLES = make_edit_examples(
|
| 856 |
+
TASK_LABEL_VIDEO_EDIT,
|
| 857 |
+
"config/examples/video_edit_example.json",
|
| 858 |
+
limit=3,
|
| 859 |
+
media_type="video",
|
| 860 |
+
)
|
| 861 |
+
VIDEO_UNDERSTANDING_EXAMPLES = make_understanding_examples(
|
| 862 |
+
TASK_LABEL_VIDEO_UNDERSTANDING,
|
| 863 |
+
"config/examples/x2t_video_example.json",
|
| 864 |
+
limit=3,
|
| 865 |
+
media_type="video",
|
| 866 |
+
)
|
| 867 |
+
VIDEO_UNDERSTANDING_SYSTEM_PROMPTS = make_understanding_system_prompt_map(
|
| 868 |
+
"config/examples/x2t_video_example.json",
|
| 869 |
+
TASK_X2T_VIDEO,
|
| 870 |
+
)
|
| 871 |
+
IMAGE_GENERATION_EXAMPLES = make_generation_examples(
|
| 872 |
+
TASK_LABEL_IMAGE_GENERATION,
|
| 873 |
+
"config/examples/t2i_example.json",
|
| 874 |
+
limit=5,
|
| 875 |
+
image_task=True,
|
| 876 |
+
selected_keys=["000000.png", "000003.png", "000006.png", "000008.png", "000009.png"],
|
| 877 |
+
)
|
| 878 |
+
IMAGE_EDIT_EXAMPLES = make_edit_examples(
|
| 879 |
+
TASK_LABEL_IMAGE_EDIT,
|
| 880 |
+
"config/examples/image_edit_example.json",
|
| 881 |
+
limit=5,
|
| 882 |
+
media_type="image",
|
| 883 |
+
)
|
| 884 |
+
IMAGE_UNDERSTANDING_EXAMPLES = make_understanding_examples(
|
| 885 |
+
TASK_LABEL_IMAGE_UNDERSTANDING,
|
| 886 |
+
"config/examples/x2t_image_example.json",
|
| 887 |
+
limit=3,
|
| 888 |
+
media_type="image",
|
| 889 |
+
)
|
| 890 |
+
IMAGE_UNDERSTANDING_SYSTEM_PROMPTS = make_understanding_system_prompt_map(
|
| 891 |
+
"config/examples/x2t_image_example.json",
|
| 892 |
+
TASK_X2T_IMAGE,
|
| 893 |
+
)
|
| 894 |
+
|
| 895 |
+
|
| 896 |
+
def build_save_dir(task: str) -> Path:
|
| 897 |
+
ensure_dirs()
|
| 898 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 899 |
+
return RESULTS_ROOT / f"{task}_{timestamp}_{int(time.time() * 1000) % 1000:03d}"
|
| 900 |
+
|
| 901 |
+
|
| 902 |
+
def find_generated_video(save_dir: Path) -> Optional[Path]:
|
| 903 |
+
videos = sorted(save_dir.glob("*.mp4"), key=lambda p: p.stat().st_mtime, reverse=True)
|
| 904 |
+
return videos[0] if videos else None
|
| 905 |
+
|
| 906 |
+
|
| 907 |
+
def find_generated_image(save_dir: Path) -> Optional[Path]:
|
| 908 |
+
images = sorted(save_dir.glob("*.png"), key=lambda p: p.stat().st_mtime, reverse=True)
|
| 909 |
+
return images[0] if images else None
|
| 910 |
+
|
| 911 |
+
|
| 912 |
+
def run_rife_interpolation(video_path: Path, device_id: int, exp: int = 1) -> tuple[Path, str]:
|
| 913 |
+
rife_dir = REPO_ROOT / "RIFE"
|
| 914 |
+
rife_script = rife_dir / "inference_video.py"
|
| 915 |
+
if not rife_script.exists():
|
| 916 |
+
raise FileNotFoundError(f"RIFE inference script not found: {rife_script}")
|
| 917 |
+
|
| 918 |
+
output_path = video_path.with_name(f"{video_path.stem}_rife_{2 ** exp}x{video_path.suffix}")
|
| 919 |
+
env = os.environ.copy()
|
| 920 |
+
env["CUDA_VISIBLE_DEVICES"] = str(device_id)
|
| 921 |
+
command = [
|
| 922 |
+
"python3",
|
| 923 |
+
str(rife_script),
|
| 924 |
+
"--exp",
|
| 925 |
+
str(exp),
|
| 926 |
+
"--video",
|
| 927 |
+
str(video_path),
|
| 928 |
+
"--output",
|
| 929 |
+
str(output_path),
|
| 930 |
+
"--model",
|
| 931 |
+
str(rife_dir / "train_log"),
|
| 932 |
+
]
|
| 933 |
+
rife_start = time.perf_counter()
|
| 934 |
+
try:
|
| 935 |
+
completed = subprocess.run(
|
| 936 |
+
command,
|
| 937 |
+
cwd=str(video_path.parent),
|
| 938 |
+
env=env,
|
| 939 |
+
check=True,
|
| 940 |
+
capture_output=True,
|
| 941 |
+
text=True,
|
| 942 |
+
)
|
| 943 |
+
except subprocess.CalledProcessError as exc:
|
| 944 |
+
raise RuntimeError(
|
| 945 |
+
"\n".join(
|
| 946 |
+
[
|
| 947 |
+
f"RIFE failed with exit code {exc.returncode}.",
|
| 948 |
+
f"command=CUDA_VISIBLE_DEVICES={device_id} {' '.join(command)}",
|
| 949 |
+
exc.stdout.strip() if exc.stdout else "",
|
| 950 |
+
exc.stderr.strip() if exc.stderr else "",
|
| 951 |
+
]
|
| 952 |
+
).strip()
|
| 953 |
+
) from exc
|
| 954 |
+
if not output_path.exists():
|
| 955 |
+
raise FileNotFoundError(f"RIFE completed but output video was not found: {output_path}")
|
| 956 |
+
elapsed = time.perf_counter() - rife_start
|
| 957 |
+
log = "\n".join(
|
| 958 |
+
[
|
| 959 |
+
"[rife] Frame interpolation finished.",
|
| 960 |
+
f"command=CUDA_VISIBLE_DEVICES={device_id} {' '.join(command)}",
|
| 961 |
+
f"elapsed={elapsed:.2f}s",
|
| 962 |
+
f"output={output_path}",
|
| 963 |
+
completed.stdout.strip(),
|
| 964 |
+
completed.stderr.strip(),
|
| 965 |
+
]
|
| 966 |
+
).strip()
|
| 967 |
+
return output_path, log
|
| 968 |
+
|
| 969 |
+
|
| 970 |
+
def extract_text_result(save_dir: Path) -> str:
|
| 971 |
+
prompt_result_path = save_dir / PROMPT_JSON_FILENAME
|
| 972 |
+
if not prompt_result_path.exists():
|
| 973 |
+
return ""
|
| 974 |
+
with prompt_result_path.open("r", encoding="utf-8") as f:
|
| 975 |
+
data = json.load(f)
|
| 976 |
+
if not data:
|
| 977 |
+
return ""
|
| 978 |
+
first_value = next(iter(data.values()))
|
| 979 |
+
return first_value if isinstance(first_value, str) else json.dumps(first_value, ensure_ascii=False)
|
| 980 |
+
|
| 981 |
+
|
| 982 |
+
class LanceT2VV2TPipeline:
|
| 983 |
+
def __init__(self, device_id: int, model_variant: str = MODEL_VARIANT_VIDEO) -> None:
|
| 984 |
+
self._init_lock = threading.Lock()
|
| 985 |
+
self._generate_lock = threading.Lock()
|
| 986 |
+
self.initialized = False
|
| 987 |
+
self.device = device_id
|
| 988 |
+
self.model_variant = normalize_model_variant(model_variant)
|
| 989 |
+
self.logger = get_logger(f"lance_{self.model_variant}_gpu{device_id}")
|
| 990 |
+
|
| 991 |
+
self.model: Optional[Lance] = None
|
| 992 |
+
self.vae_model: Optional[WanVideoVAE] = None
|
| 993 |
+
self.vae_config: Optional[AutoEncoderParams] = None
|
| 994 |
+
self.tokenizer: Optional[Qwen2Tokenizer] = None
|
| 995 |
+
self.new_token_ids: Optional[dict] = None
|
| 996 |
+
self.image_token_id: Optional[int] = None
|
| 997 |
+
self.base_model_args: Optional[ModelArguments] = None
|
| 998 |
+
self.base_data_args: Optional[DataArguments] = None
|
| 999 |
+
self.base_inference_args: Optional[InferenceArguments] = None
|
| 1000 |
+
|
| 1001 |
+
def _log_stage(self, stage_name: str, start_time: float, extra: str = "") -> None:
|
| 1002 |
+
elapsed = time.perf_counter() - start_time
|
| 1003 |
+
suffix = f" | {extra}" if extra else ""
|
| 1004 |
+
print(f"[startup][gpu:{self.device}] {stage_name} done in {elapsed:.2f}s{suffix}", flush=True)
|
| 1005 |
+
|
| 1006 |
+
def _build_base_model_args(self) -> ModelArguments:
|
| 1007 |
+
model_path = str(get_model_path(self.model_variant))
|
| 1008 |
+
return ModelArguments(
|
| 1009 |
+
model_path=model_path,
|
| 1010 |
+
vit_type=DEFAULT_VIT_TYPE,
|
| 1011 |
+
llm_qk_norm=True,
|
| 1012 |
+
llm_qk_norm_und=True,
|
| 1013 |
+
llm_qk_norm_gen=True,
|
| 1014 |
+
tie_word_embeddings=False,
|
| 1015 |
+
max_num_frames=121,
|
| 1016 |
+
max_latent_size=64,
|
| 1017 |
+
latent_patch_size=[1, 1, 1],
|
| 1018 |
+
)
|
| 1019 |
+
|
| 1020 |
+
def _build_base_inference_args(self) -> InferenceArguments:
|
| 1021 |
+
return InferenceArguments(
|
| 1022 |
+
validation_num_timesteps=DEFAULT_TIMESTEPS,
|
| 1023 |
+
validation_timestep_shift=DEFAULT_TIMESTEP_SHIFT,
|
| 1024 |
+
copy_init_moe=True,
|
| 1025 |
+
visual_und=True,
|
| 1026 |
+
visual_gen=True,
|
| 1027 |
+
vae_model_type="wan",
|
| 1028 |
+
apply_qwen_2_5_vl_pos_emb=True,
|
| 1029 |
+
apply_chat_template=False,
|
| 1030 |
+
cfg_type=0,
|
| 1031 |
+
validation_data_seed=42,
|
| 1032 |
+
video_height=DEFAULT_HEIGHT,
|
| 1033 |
+
video_width=DEFAULT_WIDTH,
|
| 1034 |
+
num_frames=DEFAULT_NUM_FRAMES,
|
| 1035 |
+
task=DEFAULT_TASK,
|
| 1036 |
+
save_path_gen=str(RESULTS_ROOT),
|
| 1037 |
+
resolution=DEFAULT_RESOLUTION,
|
| 1038 |
+
text_template=TEXT_TEMPLATE,
|
| 1039 |
+
use_KVcache=USE_KVCACHE,
|
| 1040 |
+
)
|
| 1041 |
+
|
| 1042 |
+
def initialize(self) -> None:
|
| 1043 |
+
with self._init_lock:
|
| 1044 |
+
if self.initialized:
|
| 1045 |
+
return
|
| 1046 |
+
|
| 1047 |
+
ensure_dirs()
|
| 1048 |
+
resolved_model_path = ensure_model_assets(self.model_variant)
|
| 1049 |
+
print(
|
| 1050 |
+
f"[startup][gpu:{self.device}][{self.model_variant}] Using Lance model path: {resolved_model_path}",
|
| 1051 |
+
flush=True,
|
| 1052 |
+
)
|
| 1053 |
+
if not torch.cuda.is_available():
|
| 1054 |
+
raise RuntimeError("CUDA is unavailable. Lance T2V/V2T Gradio requires a GPU environment.")
|
| 1055 |
+
if self.device >= torch.cuda.device_count():
|
| 1056 |
+
raise RuntimeError(
|
| 1057 |
+
f"GPU {self.device} is unavailable. Detected {torch.cuda.device_count()} GPU(s)."
|
| 1058 |
+
)
|
| 1059 |
+
torch.cuda.set_device(self.device)
|
| 1060 |
+
|
| 1061 |
+
model_args = self._build_base_model_args()
|
| 1062 |
+
data_args = DataArguments()
|
| 1063 |
+
inference_args = self._build_base_inference_args()
|
| 1064 |
+
apply_inference_defaults(model_args, data_args, inference_args)
|
| 1065 |
+
inference_args.validation_noise_seed = inference_args.validation_data_seed
|
| 1066 |
+
|
| 1067 |
+
self.base_model_args = model_args
|
| 1068 |
+
self.base_data_args = data_args
|
| 1069 |
+
self.base_inference_args = inference_args
|
| 1070 |
+
|
| 1071 |
+
set_seed(inference_args.global_seed)
|
| 1072 |
+
|
| 1073 |
+
stage_start = time.perf_counter()
|
| 1074 |
+
print(
|
| 1075 |
+
f"[startup][gpu:{self.device}] Loading LLM config: {Path(model_args.model_path) / 'llm_config.json'}",
|
| 1076 |
+
flush=True,
|
| 1077 |
+
)
|
| 1078 |
+
llm_config: Qwen2Config = Qwen2Config.from_json_file(str(Path(model_args.model_path) / "llm_config.json"))
|
| 1079 |
+
self._log_stage("LLM config load", stage_start)
|
| 1080 |
+
|
| 1081 |
+
llm_config.layer_module = model_args.layer_module
|
| 1082 |
+
llm_config.qk_norm = model_args.llm_qk_norm
|
| 1083 |
+
llm_config.qk_norm_und = model_args.llm_qk_norm_und
|
| 1084 |
+
llm_config.qk_norm_gen = model_args.llm_qk_norm_gen
|
| 1085 |
+
llm_config.tie_word_embeddings = model_args.tie_word_embeddings
|
| 1086 |
+
llm_config.freeze_und = inference_args.freeze_und
|
| 1087 |
+
llm_config.apply_qwen_2_5_vl_pos_emb = inference_args.apply_qwen_2_5_vl_pos_emb
|
| 1088 |
+
|
| 1089 |
+
stage_start = time.perf_counter()
|
| 1090 |
+
print(f"[startup][gpu:{self.device}] Initializing LLM weights: {model_args.model_path}", flush=True)
|
| 1091 |
+
language_model: Qwen2ForCausalLM = Qwen2ForCausalLM(llm_config)
|
| 1092 |
+
self._log_stage("LLM weight init", stage_start)
|
| 1093 |
+
|
| 1094 |
+
vit_model = None
|
| 1095 |
+
vit_config = None
|
| 1096 |
+
if inference_args.visual_und:
|
| 1097 |
+
if model_args.vit_type not in ("qwen2_5_vl", "qwen_2_5_vl_original"):
|
| 1098 |
+
raise ValueError(f"Unsupported vit_type: {model_args.vit_type}")
|
| 1099 |
+
stage_start = time.perf_counter()
|
| 1100 |
+
print(f"[startup][gpu:{self.device}] Loading VIT config: {model_args.vit_path}", flush=True)
|
| 1101 |
+
vit_config = Qwen2_5_VLVisionConfig.from_pretrained(model_args.vit_path)
|
| 1102 |
+
self._log_stage("VIT config load", stage_start)
|
| 1103 |
+
|
| 1104 |
+
stage_start = time.perf_counter()
|
| 1105 |
+
print(
|
| 1106 |
+
f"[startup][gpu:{self.device}] Loading VIT weights: {Path(model_args.vit_path) / 'vit.safetensors'}",
|
| 1107 |
+
flush=True,
|
| 1108 |
+
)
|
| 1109 |
+
vit_model = Qwen2_5_VisionTransformerPretrainedModel(vit_config)
|
| 1110 |
+
vit_weights = load_file(str(Path(model_args.vit_path) / "vit.safetensors"))
|
| 1111 |
+
vit_model.load_state_dict(vit_weights, strict=True)
|
| 1112 |
+
self._log_stage("VIT weight load", stage_start)
|
| 1113 |
+
clean_memory(vit_weights)
|
| 1114 |
+
|
| 1115 |
+
if inference_args.visual_gen:
|
| 1116 |
+
stage_start = time.perf_counter()
|
| 1117 |
+
print(f"[startup][gpu:{self.device}] Initializing VAE", flush=True)
|
| 1118 |
+
vae_model = WanVideoVAE()
|
| 1119 |
+
vae_config = deepcopy(vae_model.vae_config)
|
| 1120 |
+
self._log_stage("VAE init", stage_start)
|
| 1121 |
+
else:
|
| 1122 |
+
vae_model = None
|
| 1123 |
+
vae_config = None
|
| 1124 |
+
|
| 1125 |
+
config = LanceConfig(
|
| 1126 |
+
visual_gen=inference_args.visual_gen,
|
| 1127 |
+
visual_und=inference_args.visual_und,
|
| 1128 |
+
llm_config=llm_config,
|
| 1129 |
+
vit_config=vit_config if inference_args.visual_und else None,
|
| 1130 |
+
vae_config=vae_config if inference_args.visual_gen else None,
|
| 1131 |
+
latent_patch_size=model_args.latent_patch_size,
|
| 1132 |
+
max_num_frames=model_args.max_num_frames,
|
| 1133 |
+
max_latent_size=model_args.max_latent_size,
|
| 1134 |
+
vit_max_num_patch_per_side=model_args.vit_max_num_patch_per_side,
|
| 1135 |
+
connector_act=model_args.connector_act,
|
| 1136 |
+
interpolate_pos=model_args.interpolate_pos,
|
| 1137 |
+
timestep_shift=inference_args.timestep_shift,
|
| 1138 |
+
)
|
| 1139 |
+
model: Lance = Lance(
|
| 1140 |
+
language_model=language_model,
|
| 1141 |
+
vit_model=vit_model if inference_args.visual_und else None,
|
| 1142 |
+
vit_type=model_args.vit_type,
|
| 1143 |
+
config=config,
|
| 1144 |
+
training_args=inference_args,
|
| 1145 |
+
)
|
| 1146 |
+
|
| 1147 |
+
stage_start = time.perf_counter()
|
| 1148 |
+
print(f"[startup][gpu:{self.device}] Moving Lance model to GPU {self.device}", flush=True)
|
| 1149 |
+
model = model.to(self.device)
|
| 1150 |
+
self._log_stage("Lance model move to GPU", stage_start)
|
| 1151 |
+
|
| 1152 |
+
stage_start = time.perf_counter()
|
| 1153 |
+
print(f"[startup][gpu:{self.device}] Loading tokenizer: {model_args.model_path}", flush=True)
|
| 1154 |
+
tokenizer: Qwen2Tokenizer = Qwen2Tokenizer.from_pretrained(model_args.model_path)
|
| 1155 |
+
tokenizer, new_token_ids, num_new_tokens = add_special_tokens(tokenizer)
|
| 1156 |
+
self._log_stage("tokenizer load and special token init", stage_start, extra=f"num_new_tokens={num_new_tokens}")
|
| 1157 |
+
|
| 1158 |
+
if inference_args.copy_init_moe:
|
| 1159 |
+
language_model.init_moe()
|
| 1160 |
+
|
| 1161 |
+
init_from_model_path_if_needed(model, model_args)
|
| 1162 |
+
|
| 1163 |
+
if num_new_tokens > 0:
|
| 1164 |
+
model.language_model.resize_token_embeddings(len(tokenizer))
|
| 1165 |
+
model.config.llm_config.vocab_size = len(tokenizer)
|
| 1166 |
+
model.language_model.config.vocab_size = len(tokenizer)
|
| 1167 |
+
|
| 1168 |
+
if model_args.vit_type.lower() == "qwen2_5_vl":
|
| 1169 |
+
from common.model.hacks import hack_qwen2_5_vl_config
|
| 1170 |
+
|
| 1171 |
+
language_model = hack_qwen2_5_vl_config(language_model)
|
| 1172 |
+
|
| 1173 |
+
image_token_id = language_model.config.video_token_id
|
| 1174 |
+
new_token_ids.update({"image_token_id": image_token_id})
|
| 1175 |
+
model.update_tokenizer(tokenizer=tokenizer)
|
| 1176 |
+
|
| 1177 |
+
if model_args.tie_word_embeddings:
|
| 1178 |
+
model.language_model.untie_lm_head()
|
| 1179 |
+
model.language_model.copy_new_token_rows_to_lm_head(num_new_tokens)
|
| 1180 |
+
model_args.tie_word_embeddings = False
|
| 1181 |
+
llm_config.tie_word_embeddings = False
|
| 1182 |
+
else:
|
| 1183 |
+
assert (
|
| 1184 |
+
model.language_model.get_input_embeddings().weight.data.data_ptr()
|
| 1185 |
+
!= model.language_model.get_output_embeddings().weight.data.data_ptr()
|
| 1186 |
+
), "tie_word_embeddings conflict"
|
| 1187 |
+
|
| 1188 |
+
model = model.to(device=self.device, dtype=torch.bfloat16)
|
| 1189 |
+
model.eval()
|
| 1190 |
+
if vae_model is not None and hasattr(vae_model, "eval"):
|
| 1191 |
+
vae_model.eval()
|
| 1192 |
+
|
| 1193 |
+
self.model = model
|
| 1194 |
+
self.vae_model = vae_model
|
| 1195 |
+
self.vae_config = vae_config
|
| 1196 |
+
self.tokenizer = tokenizer
|
| 1197 |
+
self.new_token_ids = new_token_ids
|
| 1198 |
+
self.image_token_id = image_token_id
|
| 1199 |
+
self.initialized = True
|
| 1200 |
+
print(
|
| 1201 |
+
f"[startup][gpu:{self.device}][{self.model_variant}] Lance multimodal Gradio model loaded and ready for reuse.",
|
| 1202 |
+
flush=True,
|
| 1203 |
+
)
|
| 1204 |
+
|
| 1205 |
+
def unload(self) -> None:
|
| 1206 |
+
with self._init_lock:
|
| 1207 |
+
if self.model is not None:
|
| 1208 |
+
self.model.cpu()
|
| 1209 |
+
if self.vae_model is not None and hasattr(self.vae_model, "vae"):
|
| 1210 |
+
vae_inner = self.vae_model.vae
|
| 1211 |
+
if hasattr(vae_inner, "model"):
|
| 1212 |
+
vae_inner.model.cpu()
|
| 1213 |
+
|
| 1214 |
+
self.model = None
|
| 1215 |
+
self.vae_model = None
|
| 1216 |
+
self.vae_config = None
|
| 1217 |
+
self.tokenizer = None
|
| 1218 |
+
self.new_token_ids = None
|
| 1219 |
+
self.image_token_id = None
|
| 1220 |
+
self.base_model_args = None
|
| 1221 |
+
self.base_data_args = None
|
| 1222 |
+
self.base_inference_args = None
|
| 1223 |
+
self.initialized = False
|
| 1224 |
+
gc.collect()
|
| 1225 |
+
if torch.cuda.is_available():
|
| 1226 |
+
with torch.cuda.device(self.device):
|
| 1227 |
+
torch.cuda.empty_cache()
|
| 1228 |
+
torch.cuda.ipc_collect()
|
| 1229 |
+
|
| 1230 |
+
def _build_request_batch(
|
| 1231 |
+
self,
|
| 1232 |
+
prompt_file: Path,
|
| 1233 |
+
model_args: ModelArguments,
|
| 1234 |
+
data_args: DataArguments,
|
| 1235 |
+
inference_args: InferenceArguments,
|
| 1236 |
+
):
|
| 1237 |
+
assert self.tokenizer is not None
|
| 1238 |
+
assert self.new_token_ids is not None
|
| 1239 |
+
assert self.vae_config is not None
|
| 1240 |
+
|
| 1241 |
+
dataset_config = DataConfig.from_yaml(str(prompt_file))
|
| 1242 |
+
if inference_args.visual_und:
|
| 1243 |
+
dataset_config.vit_patch_size = model_args.vit_patch_size
|
| 1244 |
+
dataset_config.vit_patch_size_temporal = model_args.vit_patch_size_temporal
|
| 1245 |
+
dataset_config.vit_max_num_patch_per_side = model_args.vit_max_num_patch_per_side
|
| 1246 |
+
if inference_args.visual_gen:
|
| 1247 |
+
vae_downsample = tuple_mul(
|
| 1248 |
+
tuple(model_args.latent_patch_size),
|
| 1249 |
+
(
|
| 1250 |
+
self.vae_config.downsample_temporal,
|
| 1251 |
+
self.vae_config.downsample_spatial,
|
| 1252 |
+
self.vae_config.downsample_spatial,
|
| 1253 |
+
),
|
| 1254 |
+
)
|
| 1255 |
+
dataset_config.latent_patch_size = model_args.latent_patch_size
|
| 1256 |
+
dataset_config.vae_downsample = vae_downsample
|
| 1257 |
+
dataset_config.max_latent_size = model_args.max_latent_size
|
| 1258 |
+
dataset_config.max_num_frames = model_args.max_num_frames
|
| 1259 |
+
|
| 1260 |
+
dataset_config.text_cond_dropout_prob = model_args.text_cond_dropout_prob
|
| 1261 |
+
dataset_config.vae_cond_dropout_prob = model_args.vae_cond_dropout_prob
|
| 1262 |
+
dataset_config.vit_cond_dropout_prob = model_args.vit_cond_dropout_prob
|
| 1263 |
+
|
| 1264 |
+
dataset_config.num_frames = inference_args.num_frames
|
| 1265 |
+
dataset_config.H = inference_args.video_height
|
| 1266 |
+
dataset_config.W = inference_args.video_width
|
| 1267 |
+
dataset_config.task = inference_args.task
|
| 1268 |
+
dataset_config.resolution = inference_args.resolution
|
| 1269 |
+
dataset_config.text_template = inference_args.text_template
|
| 1270 |
+
|
| 1271 |
+
val_dataset = ValidationDataset(
|
| 1272 |
+
jsonl_path=str(prompt_file),
|
| 1273 |
+
tokenizer=self.tokenizer,
|
| 1274 |
+
data_args=data_args,
|
| 1275 |
+
model_args=model_args,
|
| 1276 |
+
training_args=inference_args,
|
| 1277 |
+
new_token_ids=self.new_token_ids,
|
| 1278 |
+
dataset_config=dataset_config,
|
| 1279 |
+
local_rank=0,
|
| 1280 |
+
world_size=1,
|
| 1281 |
+
)
|
| 1282 |
+
return simple_custom_collate([val_dataset[0]])
|
| 1283 |
+
|
| 1284 |
+
def generate(
|
| 1285 |
+
self,
|
| 1286 |
+
task: str,
|
| 1287 |
+
prompt: str,
|
| 1288 |
+
system_prompt: Optional[str],
|
| 1289 |
+
input_video: Optional[str],
|
| 1290 |
+
input_image: Optional[str],
|
| 1291 |
+
height: int,
|
| 1292 |
+
width: int,
|
| 1293 |
+
num_frames: int,
|
| 1294 |
+
seed: int,
|
| 1295 |
+
resolution: str,
|
| 1296 |
+
validation_num_timesteps: int,
|
| 1297 |
+
validation_timestep_shift: float,
|
| 1298 |
+
cfg_text_scale: float,
|
| 1299 |
+
enable_frame_interpolation: bool,
|
| 1300 |
+
):
|
| 1301 |
+
self.initialize()
|
| 1302 |
+
internal_task = normalize_task(task)
|
| 1303 |
+
prompt = (prompt or "").strip()
|
| 1304 |
+
input_video = str(input_video).strip() if input_video else ""
|
| 1305 |
+
input_image = str(input_image).strip() if input_image else ""
|
| 1306 |
+
|
| 1307 |
+
if internal_task in GENERATION_TASKS and not prompt:
|
| 1308 |
+
return None, None, "", "Please enter a prompt.", ""
|
| 1309 |
+
if internal_task in UNDERSTANDING_TASKS and not prompt:
|
| 1310 |
+
return None, None, "", "Please enter a question.", ""
|
| 1311 |
+
if internal_task in {TASK_VIDEO_EDIT, TASK_X2T_VIDEO} and not input_video:
|
| 1312 |
+
return None, None, "", "Please upload an input video.", ""
|
| 1313 |
+
if internal_task in {TASK_IMAGE_EDIT, TASK_X2T_IMAGE} and not input_image:
|
| 1314 |
+
return None, None, "", "Please upload an input image.", ""
|
| 1315 |
+
if height <= 0 or width <= 0:
|
| 1316 |
+
return None, None, "", "Height and width must be greater than 0.", ""
|
| 1317 |
+
if num_frames <= 0:
|
| 1318 |
+
return None, None, "", "The number of frames must be greater than 0.", ""
|
| 1319 |
+
|
| 1320 |
+
assert self.model is not None
|
| 1321 |
+
assert self.tokenizer is not None
|
| 1322 |
+
assert self.new_token_ids is not None
|
| 1323 |
+
assert self.image_token_id is not None
|
| 1324 |
+
assert self.base_model_args is not None
|
| 1325 |
+
assert self.base_data_args is not None
|
| 1326 |
+
assert self.base_inference_args is not None
|
| 1327 |
+
active_model_path = self.base_model_args.model_path
|
| 1328 |
+
|
| 1329 |
+
with self._generate_lock:
|
| 1330 |
+
torch.cuda.set_device(self.device)
|
| 1331 |
+
actual_seed = normalize_seed(int(seed))
|
| 1332 |
+
prompt_file = create_request_json(
|
| 1333 |
+
task=internal_task,
|
| 1334 |
+
prompt=prompt,
|
| 1335 |
+
input_video=input_video,
|
| 1336 |
+
input_image=input_image,
|
| 1337 |
+
system_prompt=system_prompt,
|
| 1338 |
+
)
|
| 1339 |
+
save_dir = build_save_dir(internal_task)
|
| 1340 |
+
save_dir.mkdir(parents=True, exist_ok=True)
|
| 1341 |
+
request_started_at = datetime.now().isoformat(timespec="seconds")
|
| 1342 |
+
|
| 1343 |
+
request_model_args = deepcopy(self.base_model_args)
|
| 1344 |
+
request_model_args.cfg_text_scale = float(cfg_text_scale)
|
| 1345 |
+
|
| 1346 |
+
request_data_args = deepcopy(self.base_data_args)
|
| 1347 |
+
request_data_args.val_dataset_config_file = str(prompt_file)
|
| 1348 |
+
|
| 1349 |
+
request_inference_args = deepcopy(self.base_inference_args)
|
| 1350 |
+
request_inference_args.validation_num_timesteps = int(validation_num_timesteps)
|
| 1351 |
+
request_inference_args.validation_timestep_shift = float(validation_timestep_shift)
|
| 1352 |
+
request_inference_args.validation_data_seed = actual_seed
|
| 1353 |
+
request_inference_args.validation_noise_seed = actual_seed
|
| 1354 |
+
request_inference_args.video_height = int(height)
|
| 1355 |
+
request_inference_args.video_width = int(width)
|
| 1356 |
+
request_inference_args.num_frames = int(num_frames)
|
| 1357 |
+
display_resolution = str(resolution)
|
| 1358 |
+
backend_resolution = normalize_resolution_for_backend(display_resolution, internal_task)
|
| 1359 |
+
request_inference_args.resolution = backend_resolution
|
| 1360 |
+
request_inference_args.save_path_gen = str(save_dir)
|
| 1361 |
+
request_inference_args.task = internal_task
|
| 1362 |
+
request_inference_args.text_template = TEXT_TEMPLATE
|
| 1363 |
+
request_inference_args.prompt_data_dict = {}
|
| 1364 |
+
|
| 1365 |
+
try:
|
| 1366 |
+
print(
|
| 1367 |
+
"[lance_gradio_t2v_v2t] Start generation "
|
| 1368 |
+
f"| task={internal_task} | gpu={self.device} | seed={actual_seed} | "
|
| 1369 |
+
f"size={height}x{width} | frames={num_frames} | resolution={display_resolution}",
|
| 1370 |
+
flush=True,
|
| 1371 |
+
)
|
| 1372 |
+
val_data_cpu = self._build_request_batch(
|
| 1373 |
+
prompt_file=prompt_file,
|
| 1374 |
+
model_args=request_model_args,
|
| 1375 |
+
data_args=request_data_args,
|
| 1376 |
+
inference_args=request_inference_args,
|
| 1377 |
+
)
|
| 1378 |
+
generate_start = time.perf_counter()
|
| 1379 |
+
validate_on_fixed_batch(
|
| 1380 |
+
fsdp_model=self.model,
|
| 1381 |
+
vae_model=self.vae_model,
|
| 1382 |
+
tokenizer=self.tokenizer,
|
| 1383 |
+
val_data_cpu=val_data_cpu,
|
| 1384 |
+
training_args=request_inference_args,
|
| 1385 |
+
model_args=request_model_args,
|
| 1386 |
+
inference_args=request_inference_args,
|
| 1387 |
+
new_token_ids=self.new_token_ids,
|
| 1388 |
+
image_token_id=self.image_token_id,
|
| 1389 |
+
device=self.device,
|
| 1390 |
+
save_source_video=False,
|
| 1391 |
+
save_path_gen=request_inference_args.save_path_gen,
|
| 1392 |
+
save_path_gt="",
|
| 1393 |
+
)
|
| 1394 |
+
elapsed = time.perf_counter() - generate_start
|
| 1395 |
+
save_prompt_results(request_inference_args.prompt_data_dict, request_inference_args.save_path_gen, self.logger)
|
| 1396 |
+
clean_memory()
|
| 1397 |
+
|
| 1398 |
+
video_path = find_generated_video(save_dir) if internal_task in {TASK_T2V, TASK_VIDEO_EDIT} else None
|
| 1399 |
+
original_video_path = video_path
|
| 1400 |
+
rife_log = ""
|
| 1401 |
+
rife_error = ""
|
| 1402 |
+
frame_interpolation_enabled = normalize_frame_interpolation(enable_frame_interpolation) and internal_task in {TASK_T2V, TASK_VIDEO_EDIT}
|
| 1403 |
+
if frame_interpolation_enabled and video_path is not None:
|
| 1404 |
+
try:
|
| 1405 |
+
clean_memory()
|
| 1406 |
+
print(
|
| 1407 |
+
"[rife] Start frame interpolation "
|
| 1408 |
+
f"| task={internal_task} | gpu={self.device} | input={video_path}",
|
| 1409 |
+
flush=True,
|
| 1410 |
+
)
|
| 1411 |
+
video_path, rife_log = run_rife_interpolation(video_path, self.device, exp=1)
|
| 1412 |
+
except Exception:
|
| 1413 |
+
rife_error = traceback.format_exc()
|
| 1414 |
+
print(rife_error, flush=True)
|
| 1415 |
+
image_path = find_generated_image(save_dir) if internal_task in {TASK_T2I, TASK_IMAGE_EDIT} else None
|
| 1416 |
+
text_result = extract_text_result(save_dir) if internal_task in UNDERSTANDING_TASKS else ""
|
| 1417 |
+
record = {
|
| 1418 |
+
"request_started_at": request_started_at,
|
| 1419 |
+
"request_finished_at": datetime.now().isoformat(timespec="seconds"),
|
| 1420 |
+
"status": "success",
|
| 1421 |
+
"task": internal_task,
|
| 1422 |
+
"model_variant": self.model_variant,
|
| 1423 |
+
"model_path": active_model_path,
|
| 1424 |
+
"gpu": self.device,
|
| 1425 |
+
"prompt": prompt,
|
| 1426 |
+
"system_prompt": normalize_understanding_system_prompt(internal_task, system_prompt)
|
| 1427 |
+
if internal_task in UNDERSTANDING_TASKS
|
| 1428 |
+
else "",
|
| 1429 |
+
"input_video": input_video,
|
| 1430 |
+
"input_image": input_image,
|
| 1431 |
+
"seed": actual_seed,
|
| 1432 |
+
"height": int(height),
|
| 1433 |
+
"width": int(width),
|
| 1434 |
+
"num_frames": int(num_frames),
|
| 1435 |
+
"resolution": display_resolution,
|
| 1436 |
+
"backend_resolution": backend_resolution,
|
| 1437 |
+
"validation_num_timesteps": int(validation_num_timesteps),
|
| 1438 |
+
"validation_timestep_shift": float(validation_timestep_shift),
|
| 1439 |
+
"cfg_text_scale": float(cfg_text_scale),
|
| 1440 |
+
"frame_interpolation": frame_interpolation_enabled,
|
| 1441 |
+
"elapsed_seconds": round(elapsed, 3),
|
| 1442 |
+
"prompt_file": str(prompt_file),
|
| 1443 |
+
"output_dir": str(save_dir),
|
| 1444 |
+
"original_video_path": str(original_video_path) if original_video_path is not None else "",
|
| 1445 |
+
"video_path": str(video_path) if video_path is not None else "",
|
| 1446 |
+
"image_path": str(image_path) if image_path is not None else "",
|
| 1447 |
+
"text_result": text_result,
|
| 1448 |
+
"rife_error": rife_error,
|
| 1449 |
+
}
|
| 1450 |
+
if internal_task in {TASK_T2V, TASK_VIDEO_EDIT} and video_path is None:
|
| 1451 |
+
record["status"] = "completed_without_video"
|
| 1452 |
+
if internal_task in {TASK_T2I, TASK_IMAGE_EDIT} and image_path is None:
|
| 1453 |
+
record["status"] = "completed_without_image"
|
| 1454 |
+
if internal_task in UNDERSTANDING_TASKS and not text_result:
|
| 1455 |
+
record["status"] = "completed_without_text"
|
| 1456 |
+
save_generation_record(record, save_dir)
|
| 1457 |
+
|
| 1458 |
+
logs = "\n".join(
|
| 1459 |
+
[
|
| 1460 |
+
"[lance_gradio_t2v_v2t] Inference finished in-process.",
|
| 1461 |
+
f"task={internal_task}",
|
| 1462 |
+
f"model_variant={self.model_variant}",
|
| 1463 |
+
f"model_path={active_model_path}",
|
| 1464 |
+
f"gpu={self.device}",
|
| 1465 |
+
f"seed={actual_seed}",
|
| 1466 |
+
f"height={height}",
|
| 1467 |
+
f"width={width}",
|
| 1468 |
+
f"num_frames={num_frames}",
|
| 1469 |
+
f"resolution={display_resolution}",
|
| 1470 |
+
f"backend_resolution={backend_resolution}",
|
| 1471 |
+
f"validation_num_timesteps={validation_num_timesteps}",
|
| 1472 |
+
f"validation_timestep_shift={validation_timestep_shift}",
|
| 1473 |
+
f"cfg_text_scale={cfg_text_scale}",
|
| 1474 |
+
f"frame_interpolation={frame_interpolation_enabled}",
|
| 1475 |
+
f"original_video_path={original_video_path or ''}",
|
| 1476 |
+
f"rife_error={rife_error.strip() if rife_error else ''}",
|
| 1477 |
+
f"elapsed={elapsed:.2f}s",
|
| 1478 |
+
f"output_dir={save_dir}",
|
| 1479 |
+
rife_log,
|
| 1480 |
+
]
|
| 1481 |
+
)
|
| 1482 |
+
|
| 1483 |
+
if internal_task in {TASK_T2V, TASK_VIDEO_EDIT}:
|
| 1484 |
+
if video_path is None:
|
| 1485 |
+
status = (
|
| 1486 |
+
"Inference completed, but no output video was found.\n\n"
|
| 1487 |
+
f"- Task: `{internal_task}`\n"
|
| 1488 |
+
f"- Model: `{self.model_variant}`\n"
|
| 1489 |
+
f"- Model path: `{active_model_path}`\n"
|
| 1490 |
+
f"- GPU: `{self.device}`\n"
|
| 1491 |
+
f"- Actual seed: `{actual_seed}`\n"
|
| 1492 |
+
f"- Output directory: `{save_dir}`"
|
| 1493 |
+
)
|
| 1494 |
+
return None, None, "", status, logs
|
| 1495 |
+
# status = (
|
| 1496 |
+
# "Inference completed.\n\n"
|
| 1497 |
+
# f"- Task: `{internal_task}`\n"
|
| 1498 |
+
# f"- Model: `{self.model_variant}`\n"
|
| 1499 |
+
# f"- Model path: `{active_model_path}`\n"
|
| 1500 |
+
# f"- GPU: `{self.device}`\n"
|
| 1501 |
+
# f"- Actual seed: `{actual_seed}`\n"
|
| 1502 |
+
# f"- Output directory: `{save_dir}`\n"
|
| 1503 |
+
# f"- Result file: `{video_path}`"
|
| 1504 |
+
# )
|
| 1505 |
+
status = ""
|
| 1506 |
+
return str(video_path), None, "", status, logs
|
| 1507 |
+
|
| 1508 |
+
if internal_task in {TASK_T2I, TASK_IMAGE_EDIT}:
|
| 1509 |
+
if image_path is None:
|
| 1510 |
+
status = (
|
| 1511 |
+
"Inference completed, but no output image was found.\n\n"
|
| 1512 |
+
f"- Task: `{internal_task}`\n"
|
| 1513 |
+
f"- Model: `{self.model_variant}`\n"
|
| 1514 |
+
f"- Model path: `{active_model_path}`\n"
|
| 1515 |
+
f"- GPU: `{self.device}`\n"
|
| 1516 |
+
f"- Actual seed: `{actual_seed}`\n"
|
| 1517 |
+
f"- Output directory: `{save_dir}`"
|
| 1518 |
+
)
|
| 1519 |
+
return None, None, "", status, logs
|
| 1520 |
+
# status = (
|
| 1521 |
+
# "Inference completed.\n\n"
|
| 1522 |
+
# f"- Task: `{internal_task}`\n"
|
| 1523 |
+
# f"- Model: `{self.model_variant}`\n"
|
| 1524 |
+
# f"- Model path: `{active_model_path}`\n"
|
| 1525 |
+
# f"- GPU: `{self.device}`\n"
|
| 1526 |
+
# f"- Actual seed: `{actual_seed}`\n"
|
| 1527 |
+
# f"- Output directory: `{save_dir}`\n"
|
| 1528 |
+
# f"- Result file: `{image_path}`"
|
| 1529 |
+
# )
|
| 1530 |
+
status = ""
|
| 1531 |
+
return None, str(image_path), "", status, logs
|
| 1532 |
+
|
| 1533 |
+
# status = (
|
| 1534 |
+
# "Understanding completed.\n\n"
|
| 1535 |
+
# f"- Task: `{task}`\n"
|
| 1536 |
+
# f"- Model: `{self.model_variant}`\n"
|
| 1537 |
+
# f"- Model path: `{active_model_path}`\n"
|
| 1538 |
+
# f"- GPU: `{self.device}`\n"
|
| 1539 |
+
# f"- Actual seed: `{actual_seed}`\n"
|
| 1540 |
+
# f"- Output directory: `{save_dir}`"
|
| 1541 |
+
# )
|
| 1542 |
+
status = ""
|
| 1543 |
+
return None, None, text_result, status, logs
|
| 1544 |
+
except Exception:
|
| 1545 |
+
error_trace = traceback.format_exc()
|
| 1546 |
+
print(error_trace, flush=True)
|
| 1547 |
+
record = {
|
| 1548 |
+
"request_started_at": request_started_at,
|
| 1549 |
+
"request_finished_at": datetime.now().isoformat(timespec="seconds"),
|
| 1550 |
+
"status": "failed",
|
| 1551 |
+
"task": internal_task,
|
| 1552 |
+
"model_variant": self.model_variant,
|
| 1553 |
+
"model_path": active_model_path,
|
| 1554 |
+
"gpu": self.device,
|
| 1555 |
+
"prompt": prompt,
|
| 1556 |
+
"input_video": input_video,
|
| 1557 |
+
"input_image": input_image,
|
| 1558 |
+
"seed": actual_seed,
|
| 1559 |
+
"height": int(height),
|
| 1560 |
+
"width": int(width),
|
| 1561 |
+
"num_frames": int(num_frames),
|
| 1562 |
+
"resolution": display_resolution,
|
| 1563 |
+
"backend_resolution": backend_resolution,
|
| 1564 |
+
"validation_num_timesteps": int(validation_num_timesteps),
|
| 1565 |
+
"validation_timestep_shift": float(validation_timestep_shift),
|
| 1566 |
+
"cfg_text_scale": float(cfg_text_scale),
|
| 1567 |
+
"prompt_file": str(prompt_file),
|
| 1568 |
+
"output_dir": str(save_dir),
|
| 1569 |
+
"video_path": "",
|
| 1570 |
+
"image_path": "",
|
| 1571 |
+
"text_result": "",
|
| 1572 |
+
"error": error_trace,
|
| 1573 |
+
}
|
| 1574 |
+
save_generation_record(record, save_dir)
|
| 1575 |
+
status = (
|
| 1576 |
+
"Inference failed.\n\n"
|
| 1577 |
+
f"- Task: `{internal_task}`\n"
|
| 1578 |
+
f"- Model: `{self.model_variant}`\n"
|
| 1579 |
+
f"- Model path: `{active_model_path}`\n"
|
| 1580 |
+
f"- GPU: `{self.device}`\n"
|
| 1581 |
+
f"- Actual seed: `{actual_seed}`\n"
|
| 1582 |
+
f"- Resolution: `{display_resolution}`\n"
|
| 1583 |
+
f"- Output directory: `{save_dir}`"
|
| 1584 |
+
)
|
| 1585 |
+
return None, None, "", status, error_trace
|
| 1586 |
+
|
| 1587 |
+
|
| 1588 |
+
class PipelinePool:
|
| 1589 |
+
def __init__(self, gpu_ids: list[int], model_variant: str = MODEL_VARIANT_VIDEO) -> None:
|
| 1590 |
+
if not gpu_ids:
|
| 1591 |
+
raise ValueError("At least one GPU must be configured.")
|
| 1592 |
+
self.gpu_ids = gpu_ids
|
| 1593 |
+
self.model_variant = normalize_model_variant(model_variant)
|
| 1594 |
+
self.pipelines = [
|
| 1595 |
+
LanceT2VV2TPipeline(device_id=gpu_id, model_variant=self.model_variant)
|
| 1596 |
+
for gpu_id in gpu_ids
|
| 1597 |
+
]
|
| 1598 |
+
self._available = deque(self.pipelines)
|
| 1599 |
+
self._condition = threading.Condition()
|
| 1600 |
+
|
| 1601 |
+
@property
|
| 1602 |
+
def size(self) -> int:
|
| 1603 |
+
return len(self.pipelines)
|
| 1604 |
+
|
| 1605 |
+
@property
|
| 1606 |
+
def gpu_summary(self) -> str:
|
| 1607 |
+
return ",".join(str(gpu_id) for gpu_id in self.gpu_ids)
|
| 1608 |
+
|
| 1609 |
+
def initialize_all(self) -> None:
|
| 1610 |
+
print(f"[startup][{self.model_variant}] Preparing parallel GPU preload: {self.gpu_ids}", flush=True)
|
| 1611 |
+
exceptions: list[Exception] = []
|
| 1612 |
+
with concurrent.futures.ThreadPoolExecutor(max_workers=self.size) as executor:
|
| 1613 |
+
futures = {
|
| 1614 |
+
executor.submit(pipeline.initialize): pipeline.device for pipeline in self.pipelines
|
| 1615 |
+
}
|
| 1616 |
+
for future in concurrent.futures.as_completed(futures):
|
| 1617 |
+
gpu_id = futures[future]
|
| 1618 |
+
try:
|
| 1619 |
+
future.result()
|
| 1620 |
+
except Exception as exc:
|
| 1621 |
+
print(f"[startup][gpu:{gpu_id}][{self.model_variant}] Preload failed: {exc}", flush=True)
|
| 1622 |
+
exceptions.append(exc)
|
| 1623 |
+
if exceptions:
|
| 1624 |
+
raise RuntimeError(
|
| 1625 |
+
f"{self.model_variant} preload failed on {len(exceptions)} GPU(s). Please check the terminal logs."
|
| 1626 |
+
) from exceptions[0]
|
| 1627 |
+
print(
|
| 1628 |
+
f"[startup][{self.model_variant}] GPU preload finished. Ready to handle {self.size} concurrent request(s).",
|
| 1629 |
+
flush=True,
|
| 1630 |
+
)
|
| 1631 |
+
|
| 1632 |
+
def acquire(self) -> LanceT2VV2TPipeline:
|
| 1633 |
+
with self._condition:
|
| 1634 |
+
while not self._available:
|
| 1635 |
+
self._condition.wait()
|
| 1636 |
+
return self._available.popleft()
|
| 1637 |
+
|
| 1638 |
+
def release(self, pipeline: LanceT2VV2TPipeline) -> None:
|
| 1639 |
+
with self._condition:
|
| 1640 |
+
self._available.append(pipeline)
|
| 1641 |
+
self._condition.notify()
|
| 1642 |
+
|
| 1643 |
+
def unload_all(self) -> None:
|
| 1644 |
+
print(f"[runtime][{self.model_variant}] Unloading model pool from GPU(s): {self.gpu_ids}", flush=True)
|
| 1645 |
+
with self._condition:
|
| 1646 |
+
while len(self._available) != len(self.pipelines):
|
| 1647 |
+
self._condition.wait()
|
| 1648 |
+
|
| 1649 |
+
for pipeline in self.pipelines:
|
| 1650 |
+
pipeline.unload()
|
| 1651 |
+
|
| 1652 |
+
gc.collect()
|
| 1653 |
+
if torch.cuda.is_available():
|
| 1654 |
+
torch.cuda.empty_cache()
|
| 1655 |
+
torch.cuda.ipc_collect()
|
| 1656 |
+
print(f"[runtime][{self.model_variant}] Model pool unloaded.", flush=True)
|
| 1657 |
+
|
| 1658 |
+
def generate(
|
| 1659 |
+
self,
|
| 1660 |
+
task: str,
|
| 1661 |
+
prompt: str,
|
| 1662 |
+
system_prompt: Optional[str],
|
| 1663 |
+
input_video: Optional[str],
|
| 1664 |
+
input_image: Optional[str],
|
| 1665 |
+
height: int,
|
| 1666 |
+
width: int,
|
| 1667 |
+
num_frames: int,
|
| 1668 |
+
seed: int,
|
| 1669 |
+
resolution: str,
|
| 1670 |
+
validation_num_timesteps: int,
|
| 1671 |
+
validation_timestep_shift: float,
|
| 1672 |
+
cfg_text_scale: float,
|
| 1673 |
+
enable_frame_interpolation: bool,
|
| 1674 |
+
):
|
| 1675 |
+
pipeline = self.acquire()
|
| 1676 |
+
try:
|
| 1677 |
+
with get_gpu_runtime_lock(pipeline.device):
|
| 1678 |
+
return pipeline.generate(
|
| 1679 |
+
task=task,
|
| 1680 |
+
prompt=prompt,
|
| 1681 |
+
system_prompt=system_prompt,
|
| 1682 |
+
input_video=input_video,
|
| 1683 |
+
input_image=input_image,
|
| 1684 |
+
height=height,
|
| 1685 |
+
width=width,
|
| 1686 |
+
num_frames=num_frames,
|
| 1687 |
+
seed=seed,
|
| 1688 |
+
resolution=resolution,
|
| 1689 |
+
validation_num_timesteps=validation_num_timesteps,
|
| 1690 |
+
validation_timestep_shift=validation_timestep_shift,
|
| 1691 |
+
cfg_text_scale=cfg_text_scale,
|
| 1692 |
+
enable_frame_interpolation=enable_frame_interpolation,
|
| 1693 |
+
)
|
| 1694 |
+
finally:
|
| 1695 |
+
self.release(pipeline)
|
| 1696 |
+
|
| 1697 |
+
|
| 1698 |
+
ACTIVE_PIPELINE_POOLS: dict[str, PipelinePool] = {}
|
| 1699 |
+
ACTIVE_POOL_LOCK = threading.Lock()
|
| 1700 |
+
GPU_RUNTIME_LOCKS: dict[int, threading.Lock] = {}
|
| 1701 |
+
GPU_RUNTIME_LOCKS_LOCK = threading.Lock()
|
| 1702 |
+
QUEUE_MAX_SIZE = DEFAULT_QUEUE_SIZE
|
| 1703 |
+
PRELOAD_MODEL_VARIANTS = [MODEL_VARIANT_VIDEO, MODEL_VARIANT_IMAGE]
|
| 1704 |
+
|
| 1705 |
+
|
| 1706 |
+
def get_gpu_runtime_lock(device_id: int) -> threading.Lock:
|
| 1707 |
+
with GPU_RUNTIME_LOCKS_LOCK:
|
| 1708 |
+
lock = GPU_RUNTIME_LOCKS.get(device_id)
|
| 1709 |
+
if lock is None:
|
| 1710 |
+
lock = threading.Lock()
|
| 1711 |
+
GPU_RUNTIME_LOCKS[device_id] = lock
|
| 1712 |
+
return lock
|
| 1713 |
+
|
| 1714 |
+
|
| 1715 |
+
def get_task_model_variant(task: str) -> str:
|
| 1716 |
+
internal_task = normalize_task(task)
|
| 1717 |
+
return MODEL_VARIANT_IMAGE if internal_task in IMAGE_TASKS else MODEL_VARIANT_VIDEO
|
| 1718 |
+
|
| 1719 |
+
|
| 1720 |
+
def get_pipeline_pool(task: str) -> PipelinePool:
|
| 1721 |
+
model_variant = get_task_model_variant(task)
|
| 1722 |
+
with ACTIVE_POOL_LOCK:
|
| 1723 |
+
pipeline_pool = ACTIVE_PIPELINE_POOLS.get(model_variant)
|
| 1724 |
+
if pipeline_pool is not None:
|
| 1725 |
+
return pipeline_pool
|
| 1726 |
+
|
| 1727 |
+
gpu_ids = parse_gpu_ids(os.getenv("LANCE_GPUS", DEFAULT_GPUS))
|
| 1728 |
+
print(
|
| 1729 |
+
f"[runtime] Loading Lance {model_variant} model pool without unloading existing pools.",
|
| 1730 |
+
flush=True,
|
| 1731 |
+
)
|
| 1732 |
+
pipeline_pool = PipelinePool(gpu_ids, model_variant=model_variant)
|
| 1733 |
+
pipeline_pool.initialize_all()
|
| 1734 |
+
ACTIVE_PIPELINE_POOLS[model_variant] = pipeline_pool
|
| 1735 |
+
return pipeline_pool
|
| 1736 |
+
|
| 1737 |
+
|
| 1738 |
+
def preload_pipeline_pools(gpu_ids: list[int], model_variants: list[str]) -> None:
|
| 1739 |
+
for model_variant in model_variants:
|
| 1740 |
+
normalized_variant = normalize_model_variant(model_variant)
|
| 1741 |
+
if normalized_variant in ACTIVE_PIPELINE_POOLS:
|
| 1742 |
+
continue
|
| 1743 |
+
resolved_model_path = ensure_model_assets(normalized_variant)
|
| 1744 |
+
print(
|
| 1745 |
+
f"[startup][{normalized_variant}] Using Lance model path: {resolved_model_path}",
|
| 1746 |
+
flush=True,
|
| 1747 |
+
)
|
| 1748 |
+
pipeline_pool = PipelinePool(gpu_ids, model_variant=normalized_variant)
|
| 1749 |
+
pipeline_pool.initialize_all()
|
| 1750 |
+
ACTIVE_PIPELINE_POOLS[normalized_variant] = pipeline_pool
|
| 1751 |
+
|
| 1752 |
+
|
| 1753 |
+
def run_task(
|
| 1754 |
+
task: str,
|
| 1755 |
+
prompt: str,
|
| 1756 |
+
system_prompt: Optional[str],
|
| 1757 |
+
input_video: Optional[str],
|
| 1758 |
+
input_image: Optional[str],
|
| 1759 |
+
height: int,
|
| 1760 |
+
width: int,
|
| 1761 |
+
num_frames: int,
|
| 1762 |
+
seed: int,
|
| 1763 |
+
resolution: str,
|
| 1764 |
+
validation_num_timesteps: int,
|
| 1765 |
+
validation_timestep_shift: float,
|
| 1766 |
+
cfg_text_scale: float,
|
| 1767 |
+
enable_frame_interpolation: bool,
|
| 1768 |
+
):
|
| 1769 |
+
internal_task = normalize_task(task)
|
| 1770 |
+
if internal_task == TASK_T2V:
|
| 1771 |
+
num_frames = video_seconds_to_num_frames(num_frames)
|
| 1772 |
+
pipeline_pool = get_pipeline_pool(task)
|
| 1773 |
+
return pipeline_pool.generate(
|
| 1774 |
+
task=task,
|
| 1775 |
+
prompt=prompt,
|
| 1776 |
+
system_prompt=system_prompt,
|
| 1777 |
+
input_video=input_video,
|
| 1778 |
+
input_image=input_image,
|
| 1779 |
+
height=height,
|
| 1780 |
+
width=width,
|
| 1781 |
+
num_frames=num_frames,
|
| 1782 |
+
seed=seed,
|
| 1783 |
+
resolution=resolution,
|
| 1784 |
+
validation_num_timesteps=validation_num_timesteps,
|
| 1785 |
+
validation_timestep_shift=validation_timestep_shift,
|
| 1786 |
+
cfg_text_scale=cfg_text_scale,
|
| 1787 |
+
enable_frame_interpolation=enable_frame_interpolation,
|
| 1788 |
+
)
|
| 1789 |
+
|
| 1790 |
+
|
| 1791 |
+
def build_status_markdown() -> str:
|
| 1792 |
+
gpu_text = "unknown"
|
| 1793 |
+
concurrency = 1
|
| 1794 |
+
loaded_variants = "none"
|
| 1795 |
+
if ACTIVE_PIPELINE_POOLS:
|
| 1796 |
+
loaded_variants = ",".join(sorted(ACTIVE_PIPELINE_POOLS))
|
| 1797 |
+
gpu_ids = sorted({gpu_id for pool in ACTIVE_PIPELINE_POOLS.values() for gpu_id in pool.gpu_ids})
|
| 1798 |
+
gpu_text = ",".join(str(gpu_id) for gpu_id in gpu_ids)
|
| 1799 |
+
concurrency = len(gpu_ids)
|
| 1800 |
+
return (
|
| 1801 |
+
f"**Status** GPU: `{gpu_text}` | Max concurrency: `{concurrency}` | "
|
| 1802 |
+
f"Queue limit: `{QUEUE_MAX_SIZE}` | Loaded models: `{loaded_variants}` | "
|
| 1803 |
+
f"Switch mode: `dual resident`"
|
| 1804 |
+
)
|
| 1805 |
+
|
| 1806 |
+
|
| 1807 |
+
def get_logo_data_uri() -> str:
|
| 1808 |
+
if not LANCE_LOGO_PATH.exists():
|
| 1809 |
+
return ""
|
| 1810 |
+
encoded_logo = base64.b64encode(LANCE_LOGO_PATH.read_bytes()).decode("ascii")
|
| 1811 |
+
return f"data:image/webp;base64,{encoded_logo}"
|
| 1812 |
+
|
| 1813 |
+
|
| 1814 |
+
def build_header_html() -> str:
|
| 1815 |
+
logo_data_uri = get_logo_data_uri()
|
| 1816 |
+
logo_html = (
|
| 1817 |
+
f'<img class="lance-logo" src="{logo_data_uri}" alt="Lance logo">'
|
| 1818 |
+
if logo_data_uri
|
| 1819 |
+
else ""
|
| 1820 |
+
)
|
| 1821 |
+
return f"""
|
| 1822 |
+
<div class="lance-hero">
|
| 1823 |
+
{logo_html}
|
| 1824 |
+
<h1 class="lance-title">Lance: Unified Multimodal Modeling by Multi-Task Synergy</h1>
|
| 1825 |
+
<div class="lance-authors">
|
| 1826 |
+
<strong>
|
| 1827 |
+
<a href="https://scholar.google.com.hk/citations?user=FXxoQlsAAAAJ&hl=zh-CN&oi=ao" target="_blank">Fengyi Fu</a><sup>*</sup>,
|
| 1828 |
+
<a href="https://corleone-huang.github.io/" target="_blank">Mengqi Huang</a><sup>*,✉</sup>,
|
| 1829 |
+
<a href="https://scholar.google.com.hk/citations?user=9ER6nVkAAAAJ&hl=zh-CN&oi=ao" target="_blank">Shaojin Wu</a><sup>*</sup>,
|
| 1830 |
+
Yunsheng Jiang<sup>*</sup>,
|
| 1831 |
+
Yufei Huo,
|
| 1832 |
+
<a href="https://guojianzhu.com/" target="_blank">Jianzhu Guo</a><sup>✉,§</sup>
|
| 1833 |
+
</strong><br>
|
| 1834 |
+
Hao Li, Yinghang Song, Fei Ding, Qian He, Zheren Fu, Zhendong Mao, Yongdong Zhang<br>
|
| 1835 |
+
<em>ByteDance</em>
|
| 1836 |
+
</div>
|
| 1837 |
+
<div class="lance-badges">
|
| 1838 |
+
<a href="{LANCE_HOMEPAGE_URL}" target="_blank" rel="noopener noreferrer">
|
| 1839 |
+
<img alt="Homepage" src="https://img.shields.io/badge/Homepage-Lance-blue?style=flat">
|
| 1840 |
+
</a>
|
| 1841 |
+
<a href="{LANCE_PAPER_URL}" target="_blank" rel="noopener noreferrer">
|
| 1842 |
+
<img alt="Paper" src="https://img.shields.io/badge/Paper-arXiv-red?style=flat&logo=arxiv">
|
| 1843 |
+
</a>
|
| 1844 |
+
<a href="{LANCE_HUGGING_FACE_URL}" target="_blank" rel="noopener noreferrer">
|
| 1845 |
+
<img alt="Hugging Face" src="https://img.shields.io/badge/Model-HuggingFace-yellow?style=flat&logo=huggingface">
|
| 1846 |
+
</a>
|
| 1847 |
+
<a href="{LANCE_GITHUB_URL}" target="_blank" rel="noopener noreferrer">
|
| 1848 |
+
<img alt="GitHub" src="https://img.shields.io/badge/Code-GitHub-536af5?color=536af5&logo=github">
|
| 1849 |
+
</a>
|
| 1850 |
+
</div>
|
| 1851 |
+
</div>
|
| 1852 |
+
"""
|
| 1853 |
+
|
| 1854 |
+
|
| 1855 |
+
def update_task_ui(task: str):
|
| 1856 |
+
internal_task = normalize_task(task)
|
| 1857 |
+
is_image_task = internal_task in IMAGE_TASKS
|
| 1858 |
+
is_video_task = internal_task in VIDEO_TASKS
|
| 1859 |
+
is_edit_task = internal_task in EDIT_TASKS
|
| 1860 |
+
is_understanding_task = internal_task in UNDERSTANDING_TASKS
|
| 1861 |
+
is_generation_task = internal_task in GENERATION_TASKS
|
| 1862 |
+
show_media_input = is_edit_task or is_understanding_task
|
| 1863 |
+
resolution_choices = IMAGE_RESOLUTION_CHOICES if is_image_task else VIDEO_RESOLUTION_CHOICES
|
| 1864 |
+
resolution_value = DEFAULT_IMAGE_RESOLUTION if is_image_task else DEFAULT_RESOLUTION
|
| 1865 |
+
aspect_ratio_value = DEFAULT_IMAGE_ASPECT_RATIO if is_image_task else DEFAULT_VIDEO_ASPECT_RATIO
|
| 1866 |
+
width_value, height_value = get_size_for_aspect_ratio(internal_task, aspect_ratio_value)
|
| 1867 |
+
size_markdown = format_size_markdown(internal_task, width_value, height_value)
|
| 1868 |
+
system_prompt_choices = get_understanding_system_prompt_choices(internal_task)
|
| 1869 |
+
|
| 1870 |
+
if is_generation_task:
|
| 1871 |
+
text_label = "Prompt"
|
| 1872 |
+
text_placeholder = "Describe what you want to generate..."
|
| 1873 |
+
elif is_edit_task:
|
| 1874 |
+
text_label = "Instruction"
|
| 1875 |
+
text_placeholder = "Describe the edit you want..."
|
| 1876 |
+
else:
|
| 1877 |
+
text_label = "Question"
|
| 1878 |
+
text_placeholder = "Ask a question about the input..."
|
| 1879 |
+
|
| 1880 |
+
return (
|
| 1881 |
+
gr.update(
|
| 1882 |
+
label=text_label,
|
| 1883 |
+
placeholder=text_placeholder,
|
| 1884 |
+
visible=True,
|
| 1885 |
+
),
|
| 1886 |
+
gr.update(
|
| 1887 |
+
choices=system_prompt_choices,
|
| 1888 |
+
value=system_prompt_choices[0],
|
| 1889 |
+
visible=False,
|
| 1890 |
+
),
|
| 1891 |
+
gr.update(label="Input Video", visible=show_media_input and is_video_task),
|
| 1892 |
+
gr.update(label="Input Image", visible=show_media_input and is_image_task),
|
| 1893 |
+
gr.update(value=aspect_ratio_value, visible=is_generation_task or is_edit_task),
|
| 1894 |
+
gr.update(value=height_value),
|
| 1895 |
+
gr.update(value=width_value),
|
| 1896 |
+
gr.update(value=size_markdown, visible=is_generation_task or is_edit_task),
|
| 1897 |
+
gr.update(visible=internal_task == TASK_T2V, value=DEFAULT_VIDEO_DURATION_SECONDS if internal_task == TASK_T2V else 1),
|
| 1898 |
+
gr.update(visible=internal_task in {TASK_T2V, TASK_VIDEO_EDIT}, value=DEFAULT_FRAME_INTERPOLATION),
|
| 1899 |
+
gr.update(choices=resolution_choices, value=resolution_value, visible=False),
|
| 1900 |
+
gr.update(visible=internal_task in {TASK_T2V, TASK_VIDEO_EDIT}),
|
| 1901 |
+
gr.update(visible=internal_task in {TASK_T2I, TASK_IMAGE_EDIT}),
|
| 1902 |
+
gr.update(visible=is_understanding_task, value=""),
|
| 1903 |
+
gr.update(visible=internal_task == TASK_T2V),
|
| 1904 |
+
gr.update(visible=internal_task == TASK_VIDEO_EDIT),
|
| 1905 |
+
gr.update(visible=internal_task == TASK_X2T_VIDEO),
|
| 1906 |
+
gr.update(visible=internal_task == TASK_T2I),
|
| 1907 |
+
gr.update(visible=internal_task == TASK_IMAGE_EDIT),
|
| 1908 |
+
gr.update(visible=internal_task == TASK_X2T_IMAGE),
|
| 1909 |
+
)
|
| 1910 |
+
|
| 1911 |
+
|
| 1912 |
+
def keep_example_clicks_from_changing_visibility(*examples_components) -> None:
|
| 1913 |
+
for examples_component in examples_components:
|
| 1914 |
+
dataset = getattr(examples_component, "dataset", None)
|
| 1915 |
+
component_props = getattr(dataset, "component_props", None)
|
| 1916 |
+
if not component_props:
|
| 1917 |
+
continue
|
| 1918 |
+
for props in component_props:
|
| 1919 |
+
props.pop("visible", None)
|
| 1920 |
+
|
| 1921 |
+
|
| 1922 |
+
def build_demo() -> gr.Blocks:
|
| 1923 |
+
with gr.Blocks(title="Lance", css=APP_CSS) as demo:
|
| 1924 |
+
gr.HTML(build_header_html())
|
| 1925 |
+
gr.Markdown(build_status_markdown(), elem_classes=["lance-status"], visible=False)
|
| 1926 |
+
|
| 1927 |
+
with gr.Row(elem_classes=["lance-main-row"]):
|
| 1928 |
+
with gr.Column(scale=1, elem_classes=["lance-main-column"]):
|
| 1929 |
+
task = gr.Radio(
|
| 1930 |
+
label="Task",
|
| 1931 |
+
choices=TASK_CHOICES,
|
| 1932 |
+
value=TASK_LABEL_VIDEO_GENERATION,
|
| 1933 |
+
elem_classes=["task-selector"],
|
| 1934 |
+
)
|
| 1935 |
+
prompt = gr.Textbox(
|
| 1936 |
+
label="Prompt",
|
| 1937 |
+
lines=6,
|
| 1938 |
+
placeholder="Describe the video you want to generate...",
|
| 1939 |
+
)
|
| 1940 |
+
system_prompt = gr.Dropdown(
|
| 1941 |
+
label="System Prompt",
|
| 1942 |
+
choices=get_understanding_system_prompt_choices(TASK_X2T_VIDEO),
|
| 1943 |
+
value=V2T_QA_SYSTEM_PROMPT,
|
| 1944 |
+
visible=False,
|
| 1945 |
+
)
|
| 1946 |
+
input_video = gr.Video(label="Input Video", visible=False, elem_classes=["lance-display-frame"])
|
| 1947 |
+
input_image = gr.Image(label="Input Image", type="filepath", visible=False, elem_classes=["lance-display-frame"])
|
| 1948 |
+
with gr.Row(elem_classes=["generation-controls-row"]):
|
| 1949 |
+
enable_frame_interpolation = gr.Dropdown(
|
| 1950 |
+
label="Frame Interpolation",
|
| 1951 |
+
choices=[FRAME_INTERPOLATION_YES, FRAME_INTERPOLATION_NO],
|
| 1952 |
+
value=DEFAULT_FRAME_INTERPOLATION,
|
| 1953 |
+
elem_classes=["generation-control", "generation-dropdown-control"],
|
| 1954 |
+
min_width=0,
|
| 1955 |
+
)
|
| 1956 |
+
seed = gr.Number(
|
| 1957 |
+
label="Seed (-1 for random seed)",
|
| 1958 |
+
value=DEFAULT_BASIC_SEED,
|
| 1959 |
+
precision=0,
|
| 1960 |
+
elem_classes=["generation-control", "generation-value-control"],
|
| 1961 |
+
min_width=0,
|
| 1962 |
+
# info="-1 for random seed",
|
| 1963 |
+
)
|
| 1964 |
+
aspect_ratio = gr.Dropdown(
|
| 1965 |
+
label="Aspect Ratio",
|
| 1966 |
+
# choices=ASPECT_RATIO_CHOICES, # 原始版本,不显示 是否为 default
|
| 1967 |
+
choices=get_aspect_ratio_choices_for_task(TASK_T2V),
|
| 1968 |
+
value=DEFAULT_VIDEO_ASPECT_RATIO,
|
| 1969 |
+
elem_classes=["generation-control", "generation-dropdown-control"],
|
| 1970 |
+
min_width=0,
|
| 1971 |
+
)
|
| 1972 |
+
# real_size = gr.Markdown(format_size_markdown(TASK_T2V, DEFAULT_WIDTH, DEFAULT_HEIGHT))
|
| 1973 |
+
real_size = gr.Textbox(
|
| 1974 |
+
label="Output Resolution",
|
| 1975 |
+
value=format_size_markdown(TASK_T2V, DEFAULT_WIDTH, DEFAULT_HEIGHT),
|
| 1976 |
+
interactive=False,
|
| 1977 |
+
elem_classes=["generation-control", "generation-value-control"],
|
| 1978 |
+
min_width=0,
|
| 1979 |
+
)
|
| 1980 |
+
resolution = gr.Dropdown(
|
| 1981 |
+
label="Resolution",
|
| 1982 |
+
choices=RESOLUTION_CHOICES,
|
| 1983 |
+
value=DEFAULT_RESOLUTION,
|
| 1984 |
+
visible=False,
|
| 1985 |
+
)
|
| 1986 |
+
height = gr.Number(value=DEFAULT_HEIGHT, precision=0, visible=False)
|
| 1987 |
+
width = gr.Number(value=DEFAULT_WIDTH, precision=0, visible=False)
|
| 1988 |
+
num_frames = gr.Slider(
|
| 1989 |
+
minimum=1,
|
| 1990 |
+
maximum=10,
|
| 1991 |
+
step=1,
|
| 1992 |
+
value=DEFAULT_VIDEO_DURATION_SECONDS,
|
| 1993 |
+
label="Video Duration (seconds)",
|
| 1994 |
+
)
|
| 1995 |
+
# seed = gr.Number(
|
| 1996 |
+
# label="Seed",
|
| 1997 |
+
# value=DEFAULT_BASIC_SEED,
|
| 1998 |
+
# precision=0,
|
| 1999 |
+
# info="-1 means using a random seed each time",
|
| 2000 |
+
# )
|
| 2001 |
+
|
| 2002 |
+
with gr.Accordion("Advanced Parameters", open=False):
|
| 2003 |
+
validation_num_timesteps = gr.Slider(
|
| 2004 |
+
minimum=1,
|
| 2005 |
+
maximum=100,
|
| 2006 |
+
step=1,
|
| 2007 |
+
value=DEFAULT_TIMESTEPS,
|
| 2008 |
+
label="Validation Num Timesteps",
|
| 2009 |
+
)
|
| 2010 |
+
with gr.Row():
|
| 2011 |
+
validation_timestep_shift = gr.Number(
|
| 2012 |
+
label="Validation Timestep Shift",
|
| 2013 |
+
value=DEFAULT_TIMESTEP_SHIFT,
|
| 2014 |
+
)
|
| 2015 |
+
cfg_text_scale = gr.Number(
|
| 2016 |
+
label="CFG Text Scale",
|
| 2017 |
+
value=DEFAULT_CFG_TEXT_SCALE,
|
| 2018 |
+
)
|
| 2019 |
+
|
| 2020 |
+
generation_example_inputs = [
|
| 2021 |
+
prompt,
|
| 2022 |
+
input_video,
|
| 2023 |
+
input_image,
|
| 2024 |
+
]
|
| 2025 |
+
|
| 2026 |
+
with gr.Column(scale=1, elem_classes=["lance-main-column"]):
|
| 2027 |
+
output_video = gr.Video(label="Output Video", elem_classes=["lance-display-frame"])
|
| 2028 |
+
output_image = gr.Image(label="Output Image", type="filepath", visible=False, elem_classes=["lance-display-frame"])
|
| 2029 |
+
output_text = gr.Textbox(label="Output Text", lines=8, visible=False, elem_classes=["lance-display-frame"])
|
| 2030 |
+
status = gr.Markdown("WAITING TO RUN.")
|
| 2031 |
+
logs = gr.Textbox(label="Run Logs", lines=22, max_lines=30)
|
| 2032 |
+
|
| 2033 |
+
run_button = gr.Button("🚀 Generate", variant="primary", elem_classes=["lance-run-button"])
|
| 2034 |
+
|
| 2035 |
+
with gr.Group(visible=True, elem_classes=["prompt-examples", "example-panel"]) as video_generation_examples_group:
|
| 2036 |
+
gr.Markdown("### Video generation recommended cases", elem_classes=["recommended-title"])
|
| 2037 |
+
video_generation_examples = gr.Dataset(
|
| 2038 |
+
samples=VIDEO_GENERATION_EXAMPLES,
|
| 2039 |
+
components=[gr.Textbox(label="Prompt", visible=False)],
|
| 2040 |
+
headers=["Prompt"],
|
| 2041 |
+
show_label=False,
|
| 2042 |
+
type="values",
|
| 2043 |
+
layout="table",
|
| 2044 |
+
samples_per_page=len(VIDEO_GENERATION_EXAMPLES),
|
| 2045 |
+
elem_classes=["prompt-dataset"],
|
| 2046 |
+
)
|
| 2047 |
+
|
| 2048 |
+
with gr.Group(visible=False, elem_classes=["example-panel"]) as video_edit_examples_group:
|
| 2049 |
+
gr.Markdown("### Video edit recommended cases", elem_classes=["recommended-title"])
|
| 2050 |
+
video_edit_examples = gr.Examples(
|
| 2051 |
+
examples=VIDEO_EDIT_EXAMPLES,
|
| 2052 |
+
inputs=generation_example_inputs,
|
| 2053 |
+
label="",
|
| 2054 |
+
examples_per_page=3,
|
| 2055 |
+
cache_examples=False,
|
| 2056 |
+
preprocess=False,
|
| 2057 |
+
postprocess=False,
|
| 2058 |
+
)
|
| 2059 |
+
|
| 2060 |
+
with gr.Group(visible=False, elem_classes=["example-panel"]) as video_understanding_examples_group:
|
| 2061 |
+
gr.Markdown("### Video understanding recommended cases", elem_classes=["recommended-title"])
|
| 2062 |
+
video_understanding_examples = gr.Examples(
|
| 2063 |
+
examples=VIDEO_UNDERSTANDING_EXAMPLES,
|
| 2064 |
+
inputs=generation_example_inputs,
|
| 2065 |
+
label="",
|
| 2066 |
+
examples_per_page=4,
|
| 2067 |
+
cache_examples=False,
|
| 2068 |
+
preprocess=False,
|
| 2069 |
+
postprocess=False,
|
| 2070 |
+
)
|
| 2071 |
+
|
| 2072 |
+
with gr.Group(visible=False, elem_classes=["prompt-examples", "example-panel"]) as image_generation_examples_group:
|
| 2073 |
+
gr.Markdown("### Image generation recommended cases", elem_classes=["recommended-title"])
|
| 2074 |
+
image_generation_examples = gr.Dataset(
|
| 2075 |
+
samples=IMAGE_GENERATION_EXAMPLES,
|
| 2076 |
+
components=[gr.Textbox(label="Prompt", visible=False)],
|
| 2077 |
+
headers=["Prompt"],
|
| 2078 |
+
show_label=False,
|
| 2079 |
+
type="values",
|
| 2080 |
+
layout="table",
|
| 2081 |
+
samples_per_page=len(IMAGE_GENERATION_EXAMPLES),
|
| 2082 |
+
elem_classes=["prompt-dataset"],
|
| 2083 |
+
)
|
| 2084 |
+
|
| 2085 |
+
with gr.Group(visible=False, elem_classes=["example-panel"]) as image_edit_examples_group:
|
| 2086 |
+
gr.Markdown("### Image edit recommended cases", elem_classes=["recommended-title"])
|
| 2087 |
+
image_edit_examples = gr.Examples(
|
| 2088 |
+
examples=IMAGE_EDIT_EXAMPLES,
|
| 2089 |
+
inputs=generation_example_inputs,
|
| 2090 |
+
label="",
|
| 2091 |
+
examples_per_page=5,
|
| 2092 |
+
cache_examples=False,
|
| 2093 |
+
preprocess=False,
|
| 2094 |
+
postprocess=False,
|
| 2095 |
+
)
|
| 2096 |
+
|
| 2097 |
+
with gr.Group(visible=False, elem_classes=["example-panel"]) as image_understanding_examples_group:
|
| 2098 |
+
gr.Markdown("### Image understanding recommended cases", elem_classes=["recommended-title"])
|
| 2099 |
+
image_understanding_examples = gr.Examples(
|
| 2100 |
+
examples=IMAGE_UNDERSTANDING_EXAMPLES,
|
| 2101 |
+
inputs=generation_example_inputs,
|
| 2102 |
+
label="",
|
| 2103 |
+
examples_per_page=4,
|
| 2104 |
+
cache_examples=False,
|
| 2105 |
+
preprocess=False,
|
| 2106 |
+
postprocess=False,
|
| 2107 |
+
)
|
| 2108 |
+
|
| 2109 |
+
keep_example_clicks_from_changing_visibility(
|
| 2110 |
+
video_generation_examples,
|
| 2111 |
+
video_edit_examples,
|
| 2112 |
+
video_understanding_examples,
|
| 2113 |
+
image_generation_examples,
|
| 2114 |
+
image_edit_examples,
|
| 2115 |
+
image_understanding_examples,
|
| 2116 |
+
)
|
| 2117 |
+
|
| 2118 |
+
task.change(
|
| 2119 |
+
fn=update_task_ui,
|
| 2120 |
+
inputs=[task],
|
| 2121 |
+
outputs=[
|
| 2122 |
+
prompt,
|
| 2123 |
+
system_prompt,
|
| 2124 |
+
input_video,
|
| 2125 |
+
input_image,
|
| 2126 |
+
aspect_ratio,
|
| 2127 |
+
height,
|
| 2128 |
+
width,
|
| 2129 |
+
real_size,
|
| 2130 |
+
num_frames,
|
| 2131 |
+
enable_frame_interpolation,
|
| 2132 |
+
resolution,
|
| 2133 |
+
output_video,
|
| 2134 |
+
output_image,
|
| 2135 |
+
output_text,
|
| 2136 |
+
video_generation_examples_group,
|
| 2137 |
+
video_edit_examples_group,
|
| 2138 |
+
video_understanding_examples_group,
|
| 2139 |
+
image_generation_examples_group,
|
| 2140 |
+
image_edit_examples_group,
|
| 2141 |
+
image_understanding_examples_group,
|
| 2142 |
+
],
|
| 2143 |
+
)
|
| 2144 |
+
|
| 2145 |
+
aspect_ratio.change(
|
| 2146 |
+
fn=update_size_from_aspect_ratio,
|
| 2147 |
+
inputs=[task, aspect_ratio],
|
| 2148 |
+
outputs=[height, width, real_size],
|
| 2149 |
+
queue=False,
|
| 2150 |
+
show_api=False,
|
| 2151 |
+
)
|
| 2152 |
+
|
| 2153 |
+
for examples_component in (video_edit_examples, video_understanding_examples, image_edit_examples, image_understanding_examples):
|
| 2154 |
+
examples_component.load_input_event.then(
|
| 2155 |
+
fn=reset_generation_defaults_for_task,
|
| 2156 |
+
inputs=[task],
|
| 2157 |
+
outputs=[aspect_ratio, height, width, num_frames, resolution, real_size],
|
| 2158 |
+
queue=False,
|
| 2159 |
+
show_api=False,
|
| 2160 |
+
)
|
| 2161 |
+
|
| 2162 |
+
video_generation_examples.select(
|
| 2163 |
+
fn=apply_prompt_example,
|
| 2164 |
+
inputs=[task],
|
| 2165 |
+
outputs=[prompt, aspect_ratio, height, width, num_frames, resolution, real_size],
|
| 2166 |
+
queue=False,
|
| 2167 |
+
show_api=False,
|
| 2168 |
+
)
|
| 2169 |
+
image_generation_examples.select(
|
| 2170 |
+
fn=apply_prompt_example,
|
| 2171 |
+
inputs=[task],
|
| 2172 |
+
outputs=[prompt, aspect_ratio, height, width, num_frames, resolution, real_size],
|
| 2173 |
+
queue=False,
|
| 2174 |
+
show_api=False,
|
| 2175 |
+
)
|
| 2176 |
+
|
| 2177 |
+
run_button.click(
|
| 2178 |
+
fn=run_task,
|
| 2179 |
+
inputs=[
|
| 2180 |
+
task,
|
| 2181 |
+
prompt,
|
| 2182 |
+
system_prompt,
|
| 2183 |
+
input_video,
|
| 2184 |
+
input_image,
|
| 2185 |
+
height,
|
| 2186 |
+
width,
|
| 2187 |
+
num_frames,
|
| 2188 |
+
seed,
|
| 2189 |
+
resolution,
|
| 2190 |
+
validation_num_timesteps,
|
| 2191 |
+
validation_timestep_shift,
|
| 2192 |
+
cfg_text_scale,
|
| 2193 |
+
enable_frame_interpolation,
|
| 2194 |
+
],
|
| 2195 |
+
outputs=[output_video, output_image, output_text, status, logs],
|
| 2196 |
+
)
|
| 2197 |
+
|
| 2198 |
+
return demo
|
| 2199 |
+
|
| 2200 |
+
|
| 2201 |
+
def parse_args() -> argparse.Namespace:
|
| 2202 |
+
parser = argparse.ArgumentParser(description="Lance multimodal Gradio")
|
| 2203 |
+
parser.add_argument("--server-name", default=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"))
|
| 2204 |
+
parser.add_argument("--server-port", type=int, default=int(os.getenv("GRADIO_SERVER_PORT", "7860")))
|
| 2205 |
+
parser.add_argument("--share", action="store_true", default=env_flag("GRADIO_SHARE", False))
|
| 2206 |
+
parser.add_argument(
|
| 2207 |
+
"--gpus",
|
| 2208 |
+
default=os.getenv("LANCE_GPUS", DEFAULT_GPUS),
|
| 2209 |
+
help="Comma-separated GPU list, for example: 0,1,2,3,4,5,6",
|
| 2210 |
+
)
|
| 2211 |
+
parser.add_argument(
|
| 2212 |
+
"--queue-size",
|
| 2213 |
+
type=int,
|
| 2214 |
+
default=int(os.getenv("LANCE_QUEUE_SIZE", str(DEFAULT_QUEUE_SIZE))),
|
| 2215 |
+
help="Maximum number of queued Gradio requests.",
|
| 2216 |
+
)
|
| 2217 |
+
return parser.parse_args()
|
| 2218 |
+
|
| 2219 |
+
|
| 2220 |
+
def parse_gpu_ids(gpu_string: str) -> list[int]:
|
| 2221 |
+
gpu_ids: list[int] = []
|
| 2222 |
+
for item in gpu_string.split(","):
|
| 2223 |
+
item = item.strip()
|
| 2224 |
+
if not item:
|
| 2225 |
+
continue
|
| 2226 |
+
gpu_ids.append(int(item))
|
| 2227 |
+
if not gpu_ids:
|
| 2228 |
+
raise ValueError("No valid GPU IDs were parsed.")
|
| 2229 |
+
return gpu_ids
|
| 2230 |
+
|
| 2231 |
+
|
| 2232 |
+
if __name__ == "__main__":
|
| 2233 |
+
args = parse_args()
|
| 2234 |
+
os.environ["LANCE_GPUS"] = args.gpus
|
| 2235 |
+
QUEUE_MAX_SIZE = args.queue_size
|
| 2236 |
+
gpu_ids = parse_gpu_ids(args.gpus)
|
| 2237 |
+
preload_pipeline_pools(gpu_ids, PRELOAD_MODEL_VARIANTS)
|
| 2238 |
+
default_concurrency_limit = max(1, len(gpu_ids))
|
| 2239 |
+
demo = build_demo()
|
| 2240 |
+
demo.queue(
|
| 2241 |
+
max_size=args.queue_size,
|
| 2242 |
+
default_concurrency_limit=default_concurrency_limit,
|
| 2243 |
+
).launch(
|
| 2244 |
+
server_name=args.server_name,
|
| 2245 |
+
server_port=args.server_port,
|
| 2246 |
+
share=args.share,
|
| 2247 |
+
)
|
assets/image-understanding/cases/image-understanding-case-02.png
ADDED
|
Git LFS Details
|
assets/image-understanding/cases/image-understanding-case-05.png
ADDED
|
Git LFS Details
|
assets/image-understanding/cases/image-understanding-case-06.png
ADDED
|
Git LFS Details
|
assets/logo/lance-logo.webp
ADDED
|
Git LFS Details
|
assets/video-understanding/videos/video-understanding-caption-long-01.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5f070eefe27dc3f3d065926345299b996124dc1ee4372c223164ddfd0792ce1a
|
| 3 |
+
size 5318845
|
assets/video-understanding/videos/video-understanding-caption-short-01.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5fcb4c18846571444ae024331a64e8740716e3b151f3e05a0d901b405b608da6
|
| 3 |
+
size 2209818
|
assets/video-understanding/videos/video-understanding-vqa-01.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4f382ee52b21942d7840eef2843bf5c57ed4e5ff4bb958e2c4fa23635030c02b
|
| 3 |
+
size 2673972
|
benchmarks/image_gen/DPG/DPG.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
benchmarks/image_gen/DPG/README.md
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[Chinese Version](./README_zh.md)
|
| 2 |
+
|
| 3 |
+
# DPG Image Generation Evaluation
|
| 4 |
+
|
| 5 |
+
Benchmark evaluation scripts for DPG based on the Lance model.
|
| 6 |
+
|
| 7 |
+
## Files
|
| 8 |
+
|
| 9 |
+
- `sample_DPG.py` - Python inference script
|
| 10 |
+
- `sample_DPG.sh` - Launch script
|
| 11 |
+
- `DPG.jsonl` - Evaluation dataset
|
| 12 |
+
|
| 13 |
+
## Quick Start
|
| 14 |
+
|
| 15 |
+
### Basic Usage
|
| 16 |
+
|
| 17 |
+
```bash
|
| 18 |
+
bash benchmarks/image_gen/DPG/sample_DPG.sh
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
Before running, edit the "Inference Parameters" section at the top of `benchmarks/image_gen/DPG/sample_DPG.sh`.
|
| 22 |
+
|
| 23 |
+
## Parameters
|
| 24 |
+
|
| 25 |
+
| Parameter | Default | Description |
|
| 26 |
+
|------|--------|------|
|
| 27 |
+
| `TASK_NAME` | `t2i` | Task type. DPG is fixed to image generation. |
|
| 28 |
+
| `VALIDATION_NUM_TIMESTEPS` | 50 | Number of inference steps. |
|
| 29 |
+
| `VALIDATION_TIMESTEP_SHIFT` | 3.5 | Timestep shift. |
|
| 30 |
+
| `EVALUATION_SEED` | 42 | Random seed. |
|
| 31 |
+
| `CFG_TEXT_SCALE` | 4.0 | CFG scale. |
|
| 32 |
+
| `CFG_INTERVAL_START` | 0.4 | Start of the CFG interval. |
|
| 33 |
+
| `CFG_INTERVAL_END` | 1.0 | End of the CFG interval. |
|
| 34 |
+
| `SAMPLE_NUM_PER_PROMPT` | 4 | Number of images generated per case for the final grid. |
|
| 35 |
+
| `USE_KVCACHE` | `true` | Whether to enable KV cache. |
|
| 36 |
+
| `NUM_GPUS` | 8 | Number of GPUs. |
|
| 37 |
+
| `VIDEO_HEIGHT`/`VIDEO_WIDTH` | 768 | Image resolution. |
|
| 38 |
+
| `MODEL_PATH` | `downloads/Lance_3B` | Path to the Lance checkpoint. |
|
| 39 |
+
| `VAL_DATASET_CONFIG_FILE` | `benchmarks/image_gen/DPG/DPG.jsonl` | Path to the evaluation data. |
|
| 40 |
+
|
| 41 |
+
## How To Modify
|
| 42 |
+
|
| 43 |
+
- Edit the "Inference Parameters" section at the top of `benchmarks/image_gen/DPG/sample_DPG.sh`.
|
| 44 |
+
- After updating the parameters, run `bash benchmarks/image_gen/DPG/sample_DPG.sh` directly.
|
| 45 |
+
- `SAVE_PATH_GEN` is generated automatically from the script parameters and does not need to be set manually.
|
| 46 |
+
|
| 47 |
+
## Output Format
|
| 48 |
+
|
| 49 |
+
Results are saved in a structure like this:
|
| 50 |
+
|
| 51 |
+
```
|
| 52 |
+
results/DPG_ts50_tss3.5_seed42_cfg4.0_kvcache_20260507_120000/
|
| 53 |
+
├── 0.png
|
| 54 |
+
├── 1.png
|
| 55 |
+
├── 2.png
|
| 56 |
+
└── ...
|
| 57 |
+
```
|
benchmarks/image_gen/DPG/README_zh.md
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[English Version](./README.md)
|
| 2 |
+
|
| 3 |
+
# DPG 图像生成评估
|
| 4 |
+
|
| 5 |
+
基于 Lance 模型的 DPG 评估基准测试脚本。
|
| 6 |
+
|
| 7 |
+
## 文件说明
|
| 8 |
+
|
| 9 |
+
- `sample_DPG.py` - 推理 Python 脚本
|
| 10 |
+
- `sample_DPG.sh` - 启动脚本
|
| 11 |
+
- `DPG.jsonl` - 评估数据集
|
| 12 |
+
|
| 13 |
+
## 快速开始
|
| 14 |
+
|
| 15 |
+
### 基本用法
|
| 16 |
+
|
| 17 |
+
```bash
|
| 18 |
+
bash benchmarks/image_gen/DPG/sample_DPG.sh
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
运行前请直接修改 `benchmarks/image_gen/DPG/sample_DPG.sh` 顶部的“推理参数配置”区。
|
| 22 |
+
|
| 23 |
+
## 参数说明
|
| 24 |
+
|
| 25 |
+
| 参数 | 默认值 | 说明 |
|
| 26 |
+
|------|--------|------|
|
| 27 |
+
| `TASK_NAME` | `t2i` | 任务类型,DPG 固定为图像生成 |
|
| 28 |
+
| `VALIDATION_NUM_TIMESTEPS` | 50 | 推理步数 |
|
| 29 |
+
| `VALIDATION_TIMESTEP_SHIFT` | 3.5 | Timestep shift |
|
| 30 |
+
| `EVALUATION_SEED` | 42 | 随机种子 |
|
| 31 |
+
| `CFG_TEXT_SCALE` | 4.0 | CFG scale |
|
| 32 |
+
| `CFG_INTERVAL_START` | 0.4 | CFG 区间起点 |
|
| 33 |
+
| `CFG_INTERVAL_END` | 1.0 | CFG 区间终点 |
|
| 34 |
+
| `SAMPLE_NUM_PER_PROMPT` | 4 | 每个 case 生成的图像数量,用于拼接最终网格图 |
|
| 35 |
+
| `USE_KVCACHE` | `true` | 是否启用 KV cache |
|
| 36 |
+
| `NUM_GPUS` | 8 | GPU 数量 |
|
| 37 |
+
| `VIDEO_HEIGHT`/`VIDEO_WIDTH` | 768 | 图像分辨率 |
|
| 38 |
+
| `MODEL_PATH` | `downloads/Lance_3B` | Lance checkpoint 路径 |
|
| 39 |
+
| `VAL_DATASET_CONFIG_FILE` | `benchmarks/image_gen/DPG/DPG.jsonl` | 评估数据路径 |
|
| 40 |
+
|
| 41 |
+
## 修改方式
|
| 42 |
+
|
| 43 |
+
- 请手动编辑 `benchmarks/image_gen/DPG/sample_DPG.sh` 顶部的“推理参数配置”区。
|
| 44 |
+
- 修改完成后,直接运行 `bash benchmarks/image_gen/DPG/sample_DPG.sh`。
|
| 45 |
+
- `SAVE_PATH_GEN` 由脚本根据顶部参数自动生成,不需要手动设置。
|
| 46 |
+
|
| 47 |
+
## 保存格式
|
| 48 |
+
|
| 49 |
+
结果会按照以下结构保存:
|
| 50 |
+
|
| 51 |
+
```
|
| 52 |
+
results/DPG_ts50_tss3.5_seed42_cfg4.0_kvcache_20260507_120000/
|
| 53 |
+
├── 0.png
|
| 54 |
+
├── 1.png
|
| 55 |
+
├── 2.png
|
| 56 |
+
└── ...
|
| 57 |
+
```
|
benchmarks/image_gen/DPG/sample_DPG.py
ADDED
|
@@ -0,0 +1,509 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# coding: utf-8
|
| 15 |
+
|
| 16 |
+
import warnings
|
| 17 |
+
warnings.filterwarnings("ignore", message=".*pkg_resources is deprecated.*", category=UserWarning)
|
| 18 |
+
warnings.filterwarnings("ignore", category=FutureWarning, module="diffusers.models.transformers.transformer_2d")
|
| 19 |
+
import os
|
| 20 |
+
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
|
| 21 |
+
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
| 22 |
+
|
| 23 |
+
import os.path as osp
|
| 24 |
+
from copy import deepcopy
|
| 25 |
+
from typing import Tuple, cast, Optional
|
| 26 |
+
import torch
|
| 27 |
+
import torch.distributed as dist
|
| 28 |
+
from torch.utils.data import DataLoader
|
| 29 |
+
from transformers import HfArgumentParser, set_seed
|
| 30 |
+
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLVisionConfig
|
| 31 |
+
from safetensors.torch import load_file
|
| 32 |
+
from PIL import Image
|
| 33 |
+
from torchvision.utils import make_grid
|
| 34 |
+
import numpy as np
|
| 35 |
+
from tqdm import trange
|
| 36 |
+
|
| 37 |
+
from data.dataset_base import DataConfig, simple_custom_collate
|
| 38 |
+
from data.data_utils import add_special_tokens
|
| 39 |
+
from modeling.vae.wan.model import WanVideoVAE
|
| 40 |
+
from modeling.lance import LanceConfig, Lance, Qwen2ForCausalLM
|
| 41 |
+
from modeling.qwen2 import Qwen2Tokenizer
|
| 42 |
+
from modeling.qwen2.modeling_qwen2 import Qwen2Config
|
| 43 |
+
from modeling.vit.qwen2_5_vl_vit import Qwen2_5_VisionTransformerPretrainedModel
|
| 44 |
+
from common.utils.misc import tuple_mul, AutoEncoderParams
|
| 45 |
+
from common.utils.logging import get_logger
|
| 46 |
+
from common.val.utils import make_padded_latent
|
| 47 |
+
from data.datasets_custom import ValidationDataset
|
| 48 |
+
from config.config_factory import ModelArguments, DataArguments, TrainingArguments, EvaluationArguments, get_model_path
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def init_from_vlm_if_needed(model: Qwen2ForCausalLM, model_args: ModelArguments, log_rank0):
|
| 52 |
+
# NOTE: 初始化加载VLM模型走这里
|
| 53 |
+
def load_safetensors_state_dict(folder_path):
|
| 54 |
+
# 只选取safetensors文件,按文件名排序保证顺序
|
| 55 |
+
safetensor_files = sorted(
|
| 56 |
+
f for f in os.listdir(folder_path) if f.endswith(".safetensors")
|
| 57 |
+
)
|
| 58 |
+
state_dict = {}
|
| 59 |
+
for filename in safetensor_files:
|
| 60 |
+
file_path = osp.join(folder_path, filename)
|
| 61 |
+
state_dict.update(load_file(file_path))
|
| 62 |
+
return state_dict
|
| 63 |
+
|
| 64 |
+
state_dict = load_safetensors_state_dict(model_args.llm_path)
|
| 65 |
+
|
| 66 |
+
# 参数名的更改以适配Lance的参数名
|
| 67 |
+
for k in list(state_dict.keys()):
|
| 68 |
+
if "visual" in k: # ViT and connector
|
| 69 |
+
state_dict[k.replace("visual", "vit_model")] = state_dict.pop(k)
|
| 70 |
+
else:
|
| 71 |
+
# 添加language_model前缀
|
| 72 |
+
state_dict["language_model." + k] = state_dict.pop(k)
|
| 73 |
+
|
| 74 |
+
result = model.load_state_dict(state_dict, strict=False)
|
| 75 |
+
|
| 76 |
+
clean_memory(state_dict)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def init_from_model_path_if_needed(model: Qwen2ForCausalLM, model_args: ModelArguments):
|
| 80 |
+
# 统一从 model_path 加载训练好的 Lance checkpoint。
|
| 81 |
+
path_dir = model_args.model_path
|
| 82 |
+
ema_path = osp.join(path_dir, "ema.safetensors")
|
| 83 |
+
model_path = osp.join(path_dir, "model.safetensors")
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
model_path_ft = None
|
| 87 |
+
if osp.exists(model_path):
|
| 88 |
+
model_path_ft = model_path
|
| 89 |
+
elif osp.exists(ema_path):
|
| 90 |
+
model_path_ft = ema_path
|
| 91 |
+
|
| 92 |
+
if model_path_ft:
|
| 93 |
+
model_state_dict = load_file(model_path_ft, device="cpu")
|
| 94 |
+
else:
|
| 95 |
+
raise FileNotFoundError(
|
| 96 |
+
f"Fine-tuning failed: No valid checkpoint ('ema.safetensors' or 'model.safetensors') found in {path_dir}"
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
# NOTE: position embeds are fixed sinusoidal embeddings, so we can just pop it off,
|
| 100 |
+
# which makes it easier to adapt to different resolutions.
|
| 101 |
+
if 'latent_pos_embed.pos_embed' in model_state_dict:
|
| 102 |
+
model_state_dict.pop('latent_pos_embed.pos_embed')
|
| 103 |
+
|
| 104 |
+
msg = model.load_state_dict(model_state_dict, strict=False)
|
| 105 |
+
|
| 106 |
+
clean_memory(model_state_dict)
|
| 107 |
+
|
| 108 |
+
return msg
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def clean_memory(*objects):
|
| 112 |
+
"""清理内存并释放 GPU 缓存"""
|
| 113 |
+
for obj in objects:
|
| 114 |
+
del obj
|
| 115 |
+
import gc
|
| 116 |
+
gc.collect()
|
| 117 |
+
if torch.cuda.is_available():
|
| 118 |
+
torch.cuda.empty_cache()
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def decode_video_tensor_for_dpg(v_list):
|
| 122 |
+
"""
|
| 123 |
+
专门为 DPG 解码视频张量,保持原有的保存格式
|
| 124 |
+
"""
|
| 125 |
+
N_target = len(v_list)
|
| 126 |
+
if N_target != 1:
|
| 127 |
+
from einops import rearrange
|
| 128 |
+
padded_videos_latent = [v.permute(1, 0, 2, 3) for v in v_list]
|
| 129 |
+
v_tc_hw = rearrange(padded_videos_latent, "n t c h w -> t c h (n w)")
|
| 130 |
+
else:
|
| 131 |
+
v_tc_hw = v_list[0].permute(1, 0, 2, 3)
|
| 132 |
+
|
| 133 |
+
v_tc_hw = v_tc_hw.float().clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round().clamp(0, 255).to(torch.uint8)
|
| 134 |
+
return v_tc_hw
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def resolve_dpg_paths(
|
| 138 |
+
model_args: ModelArguments,
|
| 139 |
+
data_args: DataArguments,
|
| 140 |
+
) -> None:
|
| 141 |
+
if not model_args.model_path:
|
| 142 |
+
raise ValueError("DPG requires --model_path to be provided explicitly.")
|
| 143 |
+
|
| 144 |
+
if not model_args.llm_path:
|
| 145 |
+
model_args.llm_path = model_args.model_path
|
| 146 |
+
|
| 147 |
+
if not model_args.vit_path:
|
| 148 |
+
model_args.vit_path = get_model_path("vit.qwen2_5_vl")
|
| 149 |
+
|
| 150 |
+
if not data_args.val_dataset_config_file:
|
| 151 |
+
data_args.val_dataset_config_file = get_model_path("dpg.data")
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def validate_on_fixed_batch(
|
| 155 |
+
fsdp_model: Lance,
|
| 156 |
+
vae_model: Optional[WanVideoVAE],
|
| 157 |
+
tokenizer: Qwen2Tokenizer,
|
| 158 |
+
val_data_cpu: dict,
|
| 159 |
+
training_args: TrainingArguments,
|
| 160 |
+
model_args: ModelArguments,
|
| 161 |
+
data_args: DataArguments,
|
| 162 |
+
inference_args: EvaluationArguments,
|
| 163 |
+
curr_step: int,
|
| 164 |
+
logger,
|
| 165 |
+
new_token_ids,
|
| 166 |
+
image_token_id: int,
|
| 167 |
+
device: int,
|
| 168 |
+
save_source_video: bool = False,
|
| 169 |
+
save_path_gen: str = "",
|
| 170 |
+
save_path_gt: str = "",
|
| 171 |
+
sample_num_per_prompt: int = 1,
|
| 172 |
+
):
|
| 173 |
+
"""
|
| 174 |
+
验证逻辑,保持与原文件相同的保存格式
|
| 175 |
+
"""
|
| 176 |
+
# 检查是否初始化了分布式环境
|
| 177 |
+
if dist.is_initialized():
|
| 178 |
+
is_rank0 = (dist.get_rank() == 0)
|
| 179 |
+
else:
|
| 180 |
+
is_rank0 = True
|
| 181 |
+
|
| 182 |
+
log_rank0 = logger.info if is_rank0 else (lambda *_: None)
|
| 183 |
+
val_data = val_data_cpu.cuda(device).to_dict()
|
| 184 |
+
|
| 185 |
+
with torch.no_grad(), torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
|
| 186 |
+
# 计算 padded_latent
|
| 187 |
+
if "padded_videos" in val_data.keys():
|
| 188 |
+
val_data["padded_latent"] = make_padded_latent(val_data["padded_videos"], val_data["vae_data_mode"], vae_model)
|
| 189 |
+
|
| 190 |
+
# -------------------- GEN 分支 --------------------
|
| 191 |
+
tensor_list_for_grid = []
|
| 192 |
+
loop_iterator = trange(sample_num_per_prompt) if is_rank0 else range(sample_num_per_prompt)
|
| 193 |
+
|
| 194 |
+
# 支持断点重新生成
|
| 195 |
+
save_name = f"{save_path_gen}/{val_data['index']}.png"
|
| 196 |
+
if os.path.exists(save_name):
|
| 197 |
+
return None
|
| 198 |
+
|
| 199 |
+
for sample_num_per_prompt_index in loop_iterator:
|
| 200 |
+
# 采样生成(保持原参数)
|
| 201 |
+
params = {
|
| 202 |
+
"val_packed_text_ids": val_data["packed_text_ids"],
|
| 203 |
+
"val_packed_text_indexes": val_data["packed_text_indexes"],
|
| 204 |
+
"val_sample_lens": val_data["sample_lens"],
|
| 205 |
+
"val_packed_position_ids": val_data["packed_position_ids"],
|
| 206 |
+
"val_split_lens": val_data["split_lens"],
|
| 207 |
+
"val_attn_modes": val_data["attn_modes"],
|
| 208 |
+
"val_sample_N_target": val_data["sample_N_target"],
|
| 209 |
+
"val_packed_vae_token_indexes": val_data["packed_vae_token_indexes"],
|
| 210 |
+
"timestep_shift": training_args.validation_timestep_shift,
|
| 211 |
+
"num_timesteps": training_args.validation_num_timesteps,
|
| 212 |
+
"val_mse_loss_indexes": val_data.get("mse_loss_indexes", None),
|
| 213 |
+
"val_padded_latent": val_data["padded_latent"],
|
| 214 |
+
"video_sizes": val_data["video_sizes"],
|
| 215 |
+
"cfg_text_scale": model_args.cfg_text_scale,
|
| 216 |
+
"cfg_interval": training_args.cfg_interval,
|
| 217 |
+
"cfg_renorm_min": training_args.cfg_renorm_min,
|
| 218 |
+
"cfg_renorm_type": training_args.cfg_renorm_type,
|
| 219 |
+
"device": device,
|
| 220 |
+
"dtype": torch.bfloat16,
|
| 221 |
+
"new_token_ids": new_token_ids,
|
| 222 |
+
"max_samples": training_args.validation_max_samples,
|
| 223 |
+
"validation_noise_seed": training_args.validation_noise_seed + sample_num_per_prompt_index,
|
| 224 |
+
"apply_chat_template": training_args.apply_chat_template,
|
| 225 |
+
"apply_qwen_2_5_vl_pos_emb": training_args.apply_qwen_2_5_vl_pos_emb,
|
| 226 |
+
"image_token_id": image_token_id,
|
| 227 |
+
"val_packed_vit_token_indexes": val_data.get("packed_vit_token_indexes", None),
|
| 228 |
+
"val_packed_vit_tokens": val_data.get("packed_vit_tokens", None),
|
| 229 |
+
"vit_video_grid_thw": val_data.get("vit_video_grid_thw", None),
|
| 230 |
+
"vae_video_grid_thw": val_data["vae_video_grid_thw"],
|
| 231 |
+
"video_grid_thw": val_data.get("video_grid_thw", None),
|
| 232 |
+
"caption": val_data.get("caption", None),
|
| 233 |
+
"sample_task": val_data["sample_task"],
|
| 234 |
+
"sample_modality": val_data["sample_modality"],
|
| 235 |
+
"cfg_type": training_args.cfg_type,
|
| 236 |
+
"cfg_uncond_token_id": training_args.cfg_uncond_token_id,
|
| 237 |
+
"index": val_data["index"],
|
| 238 |
+
"val_padded_videos": val_data["padded_videos"] if save_source_video else None,
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
if training_args.use_KVcache:
|
| 242 |
+
denoise_latent, captions, padded_videos, index = fsdp_model.validation_gen_KVcache(**params)
|
| 243 |
+
else:
|
| 244 |
+
denoise_latent, captions, padded_videos, index = fsdp_model.validation_gen(**params)
|
| 245 |
+
|
| 246 |
+
# 解码 + 保存
|
| 247 |
+
for i_val, latent in enumerate(denoise_latent):
|
| 248 |
+
v_list = [vae_model.vae_decode([latent_])[0] for latent_ in latent]
|
| 249 |
+
|
| 250 |
+
# 保持与原文件相同的保存格式
|
| 251 |
+
v_thwc = decode_video_tensor_for_dpg(v_list)
|
| 252 |
+
|
| 253 |
+
# 直接取第0帧
|
| 254 |
+
if v_thwc.shape[0] == 1:
|
| 255 |
+
tensor_list_for_grid.append(v_thwc.squeeze(0).cpu())
|
| 256 |
+
else:
|
| 257 |
+
raise NotImplementedError("需要保存图像")
|
| 258 |
+
|
| 259 |
+
# 保持原有的保存格式
|
| 260 |
+
grid_tensor = make_grid(tensor_list_for_grid, nrow=int(np.sqrt(sample_num_per_prompt)), padding=0, pad_value=255)
|
| 261 |
+
grid_numpy = grid_tensor.permute(1, 2, 0).numpy()
|
| 262 |
+
Image.fromarray(grid_numpy).save(save_name)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def main():
|
| 266 |
+
# ========================= Env setup ==============================
|
| 267 |
+
assert torch.cuda.is_available()
|
| 268 |
+
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
| 269 |
+
dist.init_process_group("nccl")
|
| 270 |
+
GLOBAL_RANK = dist.get_rank()
|
| 271 |
+
WORLD_SIZE = dist.get_world_size()
|
| 272 |
+
else:
|
| 273 |
+
GLOBAL_RANK = 0
|
| 274 |
+
WORLD_SIZE = 1
|
| 275 |
+
|
| 276 |
+
LOCAL_RANK = GLOBAL_RANK % torch.cuda.device_count()
|
| 277 |
+
DEVICE = LOCAL_RANK
|
| 278 |
+
torch.cuda.set_device(DEVICE)
|
| 279 |
+
|
| 280 |
+
# ========================= Args and logger setup ==============================
|
| 281 |
+
parser = HfArgumentParser((ModelArguments, DataArguments, EvaluationArguments))
|
| 282 |
+
model_args, data_args, inference_args = cast(
|
| 283 |
+
Tuple[ModelArguments, DataArguments, EvaluationArguments],
|
| 284 |
+
parser.parse_args_into_dataclasses(),
|
| 285 |
+
)
|
| 286 |
+
training_args = inference_args
|
| 287 |
+
|
| 288 |
+
# ========================= DPG 路径解析 ==============================
|
| 289 |
+
resolve_dpg_paths(model_args, data_args)
|
| 290 |
+
|
| 291 |
+
# NOTE validation_noise_seed 与 validation_data_seed 相同
|
| 292 |
+
training_args.validation_noise_seed = inference_args.evaluation_seed
|
| 293 |
+
training_args.validation_data_seed = inference_args.evaluation_seed
|
| 294 |
+
logger = get_logger()
|
| 295 |
+
log_rank0 = print if GLOBAL_RANK == 0 else (lambda *_: None)
|
| 296 |
+
|
| 297 |
+
# Set seed:
|
| 298 |
+
seed = training_args.global_seed * WORLD_SIZE + GLOBAL_RANK
|
| 299 |
+
set_seed(seed)
|
| 300 |
+
|
| 301 |
+
# ========================= LLM model setup ==============================
|
| 302 |
+
llm_config: Qwen2Config = Qwen2Config.from_json_file(osp.join(model_args.model_path, "llm_config.json"))
|
| 303 |
+
|
| 304 |
+
llm_config.layer_module = model_args.layer_module
|
| 305 |
+
llm_config.qk_norm = model_args.llm_qk_norm
|
| 306 |
+
llm_config.qk_norm_und = model_args.llm_qk_norm_und
|
| 307 |
+
llm_config.qk_norm_gen = model_args.llm_qk_norm_gen
|
| 308 |
+
|
| 309 |
+
llm_config.tie_word_embeddings = model_args.tie_word_embeddings
|
| 310 |
+
llm_config.freeze_und = training_args.freeze_und
|
| 311 |
+
llm_config.apply_qwen_2_5_vl_pos_emb = training_args.apply_qwen_2_5_vl_pos_emb
|
| 312 |
+
|
| 313 |
+
language_model: Qwen2ForCausalLM = Qwen2ForCausalLM(llm_config)
|
| 314 |
+
|
| 315 |
+
if training_args.visual_und:
|
| 316 |
+
if model_args.vit_type in ("qwen2_5_vl", "qwen_2_5_vl_original"):
|
| 317 |
+
vit_config = Qwen2_5_VLVisionConfig.from_pretrained(model_args.vit_path)
|
| 318 |
+
vit_model = Qwen2_5_VisionTransformerPretrainedModel(vit_config)
|
| 319 |
+
vit_weights = load_file(osp.join(model_args.vit_path, "vit.safetensors"))
|
| 320 |
+
vit_model.load_state_dict(vit_weights, strict=True)
|
| 321 |
+
else:
|
| 322 |
+
raise ValueError(f"Unsupported vit_type: {model_args.vit_type}")
|
| 323 |
+
|
| 324 |
+
clean_memory(vit_weights)
|
| 325 |
+
|
| 326 |
+
if training_args.visual_gen:
|
| 327 |
+
vae_model = WanVideoVAE()
|
| 328 |
+
vae_config: AutoEncoderParams = deepcopy(vae_model.vae_config)
|
| 329 |
+
else:
|
| 330 |
+
vae_model = None
|
| 331 |
+
vae_config = None
|
| 332 |
+
|
| 333 |
+
# Lance的配置
|
| 334 |
+
config = LanceConfig(
|
| 335 |
+
visual_gen=training_args.visual_gen,
|
| 336 |
+
visual_und=training_args.visual_und,
|
| 337 |
+
llm_config=llm_config,
|
| 338 |
+
vit_config=vit_config if training_args.visual_und else None,
|
| 339 |
+
vae_config=vae_config if training_args.visual_gen else None,
|
| 340 |
+
latent_patch_size=model_args.latent_patch_size,
|
| 341 |
+
max_num_frames=model_args.max_num_frames,
|
| 342 |
+
max_latent_size=model_args.max_latent_size,
|
| 343 |
+
vit_max_num_patch_per_side=model_args.vit_max_num_patch_per_side,
|
| 344 |
+
connector_act=model_args.connector_act,
|
| 345 |
+
interpolate_pos=model_args.interpolate_pos,
|
| 346 |
+
timestep_shift=training_args.timestep_shift,
|
| 347 |
+
)
|
| 348 |
+
model: Lance = Lance(
|
| 349 |
+
language_model=language_model,
|
| 350 |
+
vit_model=vit_model if training_args.visual_und else None,
|
| 351 |
+
vit_type=model_args.vit_type,
|
| 352 |
+
config=config,
|
| 353 |
+
training_args=training_args,
|
| 354 |
+
)
|
| 355 |
+
model = model.to(DEVICE)
|
| 356 |
+
|
| 357 |
+
# Setup tokenizer for model:
|
| 358 |
+
tokenizer: Qwen2Tokenizer = Qwen2Tokenizer.from_pretrained(model_args.model_path)
|
| 359 |
+
|
| 360 |
+
tokenizer, new_token_ids, num_new_tokens = add_special_tokens(tokenizer)
|
| 361 |
+
|
| 362 |
+
# 在加载ckpt前,初始化moe
|
| 363 |
+
if training_args.copy_init_moe:
|
| 364 |
+
language_model.init_moe()
|
| 365 |
+
|
| 366 |
+
init_from_model_path_if_needed(model, model_args)
|
| 367 |
+
|
| 368 |
+
# 现在再 resize
|
| 369 |
+
if num_new_tokens > 0:
|
| 370 |
+
model.language_model.resize_token_embeddings(len(tokenizer))
|
| 371 |
+
model.config.llm_config.vocab_size = len(tokenizer)
|
| 372 |
+
model.language_model.config.vocab_size = len(tokenizer)
|
| 373 |
+
|
| 374 |
+
if model_args.vit_type.lower() == "qwen2_5_vl":
|
| 375 |
+
from common.model.hacks import hack_qwen2_5_vl_config
|
| 376 |
+
language_model = hack_qwen2_5_vl_config(language_model)
|
| 377 |
+
|
| 378 |
+
image_token_id = language_model.config.video_token_id
|
| 379 |
+
new_token_ids.update({"image_token_id": image_token_id})
|
| 380 |
+
model.update_tokenizer(tokenizer=tokenizer)
|
| 381 |
+
|
| 382 |
+
if model_args.tie_word_embeddings:
|
| 383 |
+
model.language_model.untie_lm_head()
|
| 384 |
+
model.language_model.copy_new_token_rows_to_lm_head(num_new_tokens)
|
| 385 |
+
|
| 386 |
+
model_args.tie_word_embeddings = False
|
| 387 |
+
llm_config.tie_word_embeddings = False
|
| 388 |
+
else:
|
| 389 |
+
assert model.language_model.get_input_embeddings().weight.data.data_ptr() != model.language_model.get_output_embeddings().weight.data.data_ptr(), 'tie_world_embeddings 冲突'
|
| 390 |
+
|
| 391 |
+
model = model.to(device=DEVICE, dtype=torch.bfloat16)
|
| 392 |
+
model.eval()
|
| 393 |
+
if vae_model is not None and hasattr(vae_model, "eval"):
|
| 394 |
+
vae_model.eval()
|
| 395 |
+
|
| 396 |
+
# Setup packed dataloader - 直接初始化简单的 DataConfig 对象
|
| 397 |
+
dataset_config = DataConfig(grouped_datasets={})
|
| 398 |
+
|
| 399 |
+
# 配置基本参数
|
| 400 |
+
dataset_config.num_frames = inference_args.num_frames
|
| 401 |
+
dataset_config.H = inference_args.video_height
|
| 402 |
+
dataset_config.W = inference_args.video_width
|
| 403 |
+
dataset_config.task = inference_args.task
|
| 404 |
+
dataset_config.resolution = inference_args.resolution
|
| 405 |
+
dataset_config.text_template = inference_args.text_template
|
| 406 |
+
|
| 407 |
+
# 配置 VIT 相关参数
|
| 408 |
+
if training_args.visual_und:
|
| 409 |
+
dataset_config.vit_patch_size = model_args.vit_patch_size
|
| 410 |
+
dataset_config.vit_patch_size_temporal = model_args.vit_patch_size_temporal
|
| 411 |
+
dataset_config.vit_max_num_patch_per_side = model_args.vit_max_num_patch_per_side
|
| 412 |
+
|
| 413 |
+
# 配置 VAE 相关参数
|
| 414 |
+
if training_args.visual_gen and vae_config:
|
| 415 |
+
assert len(model_args.latent_patch_size) == 3, "len(latent_patch_size) must be 3"
|
| 416 |
+
vae_downsample = tuple_mul(
|
| 417 |
+
model_args.latent_patch_size, (vae_config.downsample_temporal, vae_config.downsample_spatial, vae_config.downsample_spatial)
|
| 418 |
+
)
|
| 419 |
+
dataset_config.latent_patch_size = model_args.latent_patch_size
|
| 420 |
+
dataset_config.vae_downsample = vae_downsample
|
| 421 |
+
dataset_config.max_latent_size = model_args.max_latent_size
|
| 422 |
+
dataset_config.max_num_frames = model_args.max_num_frames
|
| 423 |
+
|
| 424 |
+
# fix: 共享dropout
|
| 425 |
+
dataset_config.text_cond_dropout_prob = model_args.text_cond_dropout_prob
|
| 426 |
+
dataset_config.vae_cond_dropout_prob = model_args.vae_cond_dropout_prob
|
| 427 |
+
dataset_config.vit_cond_dropout_prob = model_args.vit_cond_dropout_prob
|
| 428 |
+
|
| 429 |
+
# 创建数据集
|
| 430 |
+
val_dataset = ValidationDataset(
|
| 431 |
+
jsonl_path= data_args.val_dataset_config_file,
|
| 432 |
+
tokenizer=tokenizer,
|
| 433 |
+
data_args=data_args,
|
| 434 |
+
model_args=model_args,
|
| 435 |
+
training_args=training_args,
|
| 436 |
+
new_token_ids=new_token_ids,
|
| 437 |
+
dataset_config=dataset_config,
|
| 438 |
+
local_rank=GLOBAL_RANK,
|
| 439 |
+
world_size=WORLD_SIZE,
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
val_loader = DataLoader(
|
| 443 |
+
val_dataset,
|
| 444 |
+
batch_size=1,
|
| 445 |
+
num_workers=0,
|
| 446 |
+
pin_memory=True,
|
| 447 |
+
collate_fn=simple_custom_collate,
|
| 448 |
+
drop_last=True,
|
| 449 |
+
prefetch_factor=None,
|
| 450 |
+
persistent_workers=False,
|
| 451 |
+
multiprocessing_context=None,
|
| 452 |
+
)
|
| 453 |
+
|
| 454 |
+
val_loader_iter = iter(val_loader)
|
| 455 |
+
|
| 456 |
+
if not os.path.exists(inference_args.save_path_gen):
|
| 457 |
+
os.makedirs(inference_args.save_path_gen, exist_ok=True)
|
| 458 |
+
|
| 459 |
+
# 主循环
|
| 460 |
+
from tqdm import tqdm
|
| 461 |
+
import time
|
| 462 |
+
from datetime import datetime, timedelta
|
| 463 |
+
|
| 464 |
+
total_batches = len(val_loader)
|
| 465 |
+
pbar = tqdm(total=total_batches, desc="Validating", unit="batch", leave=True, ncols=120, disable=(GLOBAL_RANK != 0))
|
| 466 |
+
start_time = time.time()
|
| 467 |
+
|
| 468 |
+
for i in range(total_batches):
|
| 469 |
+
val_data_cpu = next(val_loader_iter)
|
| 470 |
+
|
| 471 |
+
validate_on_fixed_batch(
|
| 472 |
+
fsdp_model=model,
|
| 473 |
+
vae_model=vae_model,
|
| 474 |
+
tokenizer=tokenizer,
|
| 475 |
+
val_data_cpu=val_data_cpu,
|
| 476 |
+
training_args=training_args,
|
| 477 |
+
model_args=model_args,
|
| 478 |
+
data_args=data_args,
|
| 479 |
+
inference_args=inference_args,
|
| 480 |
+
curr_step=0,
|
| 481 |
+
logger=logger,
|
| 482 |
+
new_token_ids=new_token_ids,
|
| 483 |
+
image_token_id=image_token_id,
|
| 484 |
+
device=DEVICE,
|
| 485 |
+
save_source_video=False,
|
| 486 |
+
save_path_gen=inference_args.save_path_gen,
|
| 487 |
+
save_path_gt="",
|
| 488 |
+
sample_num_per_prompt=inference_args.sample_num_per_prompt,
|
| 489 |
+
)
|
| 490 |
+
|
| 491 |
+
if GLOBAL_RANK == 0:
|
| 492 |
+
elapsed = time.time() - start_time
|
| 493 |
+
avg_time = elapsed / (i + 1)
|
| 494 |
+
eta_seconds = avg_time * (total_batches - i - 1)
|
| 495 |
+
expected_finish = datetime.now() + timedelta(seconds=eta_seconds)
|
| 496 |
+
finish_str = expected_finish.strftime('%Y-%m-%d %H:%M:%S')
|
| 497 |
+
|
| 498 |
+
pbar.set_postfix_str(f"ETA: {timedelta(seconds=int(eta_seconds))} | Finish: {finish_str}")
|
| 499 |
+
pbar.update(1)
|
| 500 |
+
|
| 501 |
+
if GLOBAL_RANK == 0:
|
| 502 |
+
pbar.close()
|
| 503 |
+
|
| 504 |
+
if dist.is_initialized():
|
| 505 |
+
dist.destroy_process_group()
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
if __name__ == "__main__":
|
| 509 |
+
main()
|
benchmarks/image_gen/DPG/sample_DPG.sh
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
|
| 4 |
+
source "$SCRIPT_DIR/../../sample_env.sh"
|
| 5 |
+
|
| 6 |
+
# ========================= 推理参数配置 =========================
|
| 7 |
+
TASK_NAME="t2i"
|
| 8 |
+
NUM_GPUS=8
|
| 9 |
+
|
| 10 |
+
VALIDATION_NUM_TIMESTEPS=50
|
| 11 |
+
VALIDATION_TIMESTEP_SHIFT=3.5
|
| 12 |
+
EVALUATION_SEED=42
|
| 13 |
+
CFG_TEXT_SCALE=4.0
|
| 14 |
+
CFG_INTERVAL_START=0.4
|
| 15 |
+
CFG_INTERVAL_END=1.0
|
| 16 |
+
SAMPLE_NUM_PER_PROMPT=4
|
| 17 |
+
USE_KVCACHE=true
|
| 18 |
+
|
| 19 |
+
VIDEO_HEIGHT=768
|
| 20 |
+
VIDEO_WIDTH=768
|
| 21 |
+
|
| 22 |
+
MODEL_PATH="downloads/Lance_3B"
|
| 23 |
+
VAL_DATASET_CONFIG_FILE="benchmarks/image_gen/DPG/DPG.jsonl"
|
| 24 |
+
|
| 25 |
+
# ========================= 自动生成路径 =========================
|
| 26 |
+
TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
|
| 27 |
+
KVCACHE_TAG=""
|
| 28 |
+
if [ "$USE_KVCACHE" = "true" ]; then
|
| 29 |
+
KVCACHE_TAG="kvcache_"
|
| 30 |
+
fi
|
| 31 |
+
SAVE_PATH_GEN="results/DPG_ts${VALIDATION_NUM_TIMESTEPS}_tss${VALIDATION_TIMESTEP_SHIFT}_seed${EVALUATION_SEED}_cfg${CFG_TEXT_SCALE}_${KVCACHE_TAG}${TIMESTAMP}"
|
| 32 |
+
|
| 33 |
+
if [ -z "$MODEL_PATH" ]; then
|
| 34 |
+
echo "错误: 请在脚本顶部配置区手动设置 MODEL_PATH"
|
| 35 |
+
exit 1
|
| 36 |
+
fi
|
| 37 |
+
|
| 38 |
+
# ============================== 环境与分布式配置 ==============================
|
| 39 |
+
lance_setup_common_env
|
| 40 |
+
lance_setup_distributed_env "$NUM_GPUS"
|
| 41 |
+
lance_setup_shard_env 1
|
| 42 |
+
|
| 43 |
+
# ========================= 显示任务配置 =========================
|
| 44 |
+
echo "================================================"
|
| 45 |
+
echo "DPG T2I 推理"
|
| 46 |
+
echo "================================================"
|
| 47 |
+
echo "GPU数量: ${NUM_GPUS}"
|
| 48 |
+
echo "保存路径: ${SAVE_PATH_GEN}"
|
| 49 |
+
echo "分辨率: ${VIDEO_HEIGHT}x${VIDEO_WIDTH}"
|
| 50 |
+
echo "模型路径: ${MODEL_PATH}"
|
| 51 |
+
if [ -n "$VAL_DATASET_CONFIG_FILE" ]; then
|
| 52 |
+
echo "数据路径: ${VAL_DATASET_CONFIG_FILE}"
|
| 53 |
+
fi
|
| 54 |
+
echo ""
|
| 55 |
+
echo "关键参数:"
|
| 56 |
+
echo " - validation_num_timesteps: ${VALIDATION_NUM_TIMESTEPS}"
|
| 57 |
+
echo " - validation_timestep_shift: ${VALIDATION_TIMESTEP_SHIFT}"
|
| 58 |
+
echo " - evaluation_seed: ${EVALUATION_SEED}"
|
| 59 |
+
echo " - cfg_text_scale: ${CFG_TEXT_SCALE}"
|
| 60 |
+
echo " - cfg_interval: [${CFG_INTERVAL_START}, ${CFG_INTERVAL_END}]"
|
| 61 |
+
echo " - sample_num_per_prompt: ${SAMPLE_NUM_PER_PROMPT}"
|
| 62 |
+
echo " - use_KVcache: ${USE_KVCACHE}"
|
| 63 |
+
echo "================================================"
|
| 64 |
+
echo ""
|
| 65 |
+
|
| 66 |
+
# ============================== 执行推理 ==============================
|
| 67 |
+
# 注意:请直接修改本脚本顶部的“推理参数配置”区
|
| 68 |
+
accelerate launch \
|
| 69 |
+
--num_machines $NUM_MACHINES \
|
| 70 |
+
--num_processes $TOTAL_RANK \
|
| 71 |
+
--machine_rank $MACHINE_RANK \
|
| 72 |
+
--main_process_ip $MAIN_PROCESS_IP \
|
| 73 |
+
--main_process_port $MAIN_PROCESS_PORT \
|
| 74 |
+
--mixed_precision bf16 \
|
| 75 |
+
benchmarks/image_gen/DPG/sample_DPG.py \
|
| 76 |
+
--model_path "$MODEL_PATH" \
|
| 77 |
+
--val_dataset_config_file "$VAL_DATASET_CONFIG_FILE" \
|
| 78 |
+
--vit_type qwen_2_5_vl_original \
|
| 79 |
+
--llm_qk_norm true \
|
| 80 |
+
--llm_qk_norm_und true \
|
| 81 |
+
--llm_qk_norm_gen true \
|
| 82 |
+
--tie_word_embeddings false \
|
| 83 |
+
--validation_num_timesteps $VALIDATION_NUM_TIMESTEPS \
|
| 84 |
+
--validation_timestep_shift $VALIDATION_TIMESTEP_SHIFT \
|
| 85 |
+
--copy_init_moe true \
|
| 86 |
+
--use_flex true \
|
| 87 |
+
--max_num_frames 1 \
|
| 88 |
+
--max_latent_size 64 \
|
| 89 |
+
--latent_patch_size 1 1 1 \
|
| 90 |
+
--num_replicate $NUM_REPLICATE \
|
| 91 |
+
--num_shard $NUM_SHARD \
|
| 92 |
+
--visual_und true \
|
| 93 |
+
--visual_gen true \
|
| 94 |
+
--vae_model_type wan \
|
| 95 |
+
--apply_qwen_2_5_vl_pos_emb true \
|
| 96 |
+
--apply_chat_template false \
|
| 97 |
+
--cfg_type 0 \
|
| 98 |
+
--validation_data_seed $EVALUATION_SEED \
|
| 99 |
+
--video_height $VIDEO_HEIGHT \
|
| 100 |
+
--video_width $VIDEO_WIDTH \
|
| 101 |
+
--task $TASK_NAME \
|
| 102 |
+
--save_path_gen $SAVE_PATH_GEN \
|
| 103 |
+
--resolution image_768res \
|
| 104 |
+
--text_template true \
|
| 105 |
+
--sample_num_per_prompt $SAMPLE_NUM_PER_PROMPT \
|
| 106 |
+
--cfg_text_scale $CFG_TEXT_SCALE \
|
| 107 |
+
--cfg_interval $CFG_INTERVAL_START $CFG_INTERVAL_END \
|
| 108 |
+
--use_KVcache $USE_KVCACHE
|
| 109 |
+
|
| 110 |
+
echo ""
|
| 111 |
+
echo "================================================"
|
| 112 |
+
echo "完成! 结果: ${SAVE_PATH_GEN}"
|
| 113 |
+
echo "================================================"
|
benchmarks/image_gen/GEdit/GEdit_en.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
benchmarks/image_gen/GEdit/README.md
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[Chinese Version](./README_zh.md)
|
| 2 |
+
|
| 3 |
+
# GEdit Image Editing Evaluation
|
| 4 |
+
|
| 5 |
+
Benchmark evaluation scripts for GEdit based on the Lance model.
|
| 6 |
+
|
| 7 |
+
## Files
|
| 8 |
+
|
| 9 |
+
- `sample_GEdit.py` - Python inference script
|
| 10 |
+
- `sample_GEdit.sh` - Launch script
|
| 11 |
+
- `GEdit_en.json` - Evaluation dataset
|
| 12 |
+
|
| 13 |
+
## Quick Start
|
| 14 |
+
|
| 15 |
+
### Basic Usage
|
| 16 |
+
|
| 17 |
+
```bash
|
| 18 |
+
bash benchmarks/image_gen/GEdit/sample_GEdit.sh
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
Before running, edit the "Inference Parameters" section at the top of `benchmarks/image_gen/GEdit/sample_GEdit.sh`.
|
| 22 |
+
Please follow `https://github.com/stepfun-ai/Step1X-Edit` to download the source images in GEdit-Bench and put all images in `benchmarks/image_gen/GEdit/images/`.
|
| 23 |
+
|
| 24 |
+
## Parameters
|
| 25 |
+
|
| 26 |
+
| Parameter | Default | Description |
|
| 27 |
+
|------|--------|------|
|
| 28 |
+
| `TASK_NAME` | `image_edit` | Task type. GEdit is fixed to image editing. |
|
| 29 |
+
| `VALIDATION_NUM_TIMESTEPS` | 50 | Number of inference steps. |
|
| 30 |
+
| `VALIDATION_TIMESTEP_SHIFT` | 3.5 | Timestep shift. |
|
| 31 |
+
| `EVALUATION_SEED` | 42 | Random seed. |
|
| 32 |
+
| `CFG_TEXT_SCALE` | 4.0 | CFG scale. |
|
| 33 |
+
| `CFG_INTERVAL_START` | 0.4 | Start of the CFG interval. |
|
| 34 |
+
| `CFG_INTERVAL_END` | 1.0 | End of the CFG interval. |
|
| 35 |
+
| `USE_KVCACHE` | `true` | Whether to enable KV cache. |
|
| 36 |
+
| `NUM_GPUS` | 8 | Number of GPUs. |
|
| 37 |
+
| `MODEL_PATH` | `downloads/Lance_3B` | Path to the Lance checkpoint. |
|
| 38 |
+
| `VAL_DATASET_CONFIG_FILE` | `benchmarks/image_gen/GEdit/GEdit_en.json` | Path to the evaluation data. |
|
| 39 |
+
|
| 40 |
+
## How To Modify
|
| 41 |
+
|
| 42 |
+
- Edit the "Inference Parameters" section at the top of `benchmarks/image_gen/GEdit/sample_GEdit.sh`.
|
| 43 |
+
- After updating the parameters, run `bash benchmarks/image_gen/GEdit/sample_GEdit.sh` directly.
|
| 44 |
+
- `SAVE_PATH_GEN` is generated automatically from the script parameters and does not need to be set manually.
|
| 45 |
+
|
| 46 |
+
## Output Format
|
| 47 |
+
|
| 48 |
+
Results are saved in a structure like this:
|
| 49 |
+
|
| 50 |
+
```
|
| 51 |
+
results/GEdit_ts50_tss3.5_seed42_cfg4.0_kvcache_20260507_120000/
|
| 52 |
+
├── fullset/
|
| 53 |
+
│ ├── add/
|
| 54 |
+
│ │ ├── en/
|
| 55 |
+
│ │ │ ├── 000001.webp
|
| 56 |
+
│ │ │ └── ...
|
| 57 |
+
│ ├── remove/
|
| 58 |
+
│ │ └── en/
|
| 59 |
+
│ │ └── ...
|
| 60 |
+
├── prompt.json
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
Each case generates one edited image by default and stores it as a `.webp` file under `task_type/instruction_language/key`. A `prompt.json` file is also written to record the generated text.
|
| 64 |
+
|
| 65 |
+
## Notes
|
| 66 |
+
|
| 67 |
+
- If you need to switch the model, dataset, or resolution, edit the script configuration at the top directly.
|
| 68 |
+
- The default result directory automatically includes key parameters and a timestamp for easier experiment tracking.
|
benchmarks/image_gen/GEdit/README_zh.md
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[English Version](./README.md)
|
| 2 |
+
|
| 3 |
+
# GEdit 图像编辑评估
|
| 4 |
+
|
| 5 |
+
基于 Lance 模型的 GEdit 评估基准测试脚本。
|
| 6 |
+
|
| 7 |
+
## 文件说明
|
| 8 |
+
|
| 9 |
+
- `sample_GEdit.py` - 推理 Python 脚本
|
| 10 |
+
- `sample_GEdit.sh` - 启动脚本
|
| 11 |
+
- `GEdit_en.json` - 评估数据集
|
| 12 |
+
|
| 13 |
+
## 快速开始
|
| 14 |
+
|
| 15 |
+
### 基本用法
|
| 16 |
+
|
| 17 |
+
```bash
|
| 18 |
+
bash benchmarks/image_gen/GEdit/sample_GEdit.sh
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
运行前请直接修改 `benchmarks/image_gen/GEdit/sample_GEdit.sh` 顶部的“推理参数配置”区。
|
| 22 |
+
请参考 `https://github.com/stepfun-ai/Step1X-Edit` 下载 GEdit-Bench 的源图,并将所有图片放到 `benchmarks/image_gen/GEdit/images/` 中。
|
| 23 |
+
|
| 24 |
+
## 参数说明
|
| 25 |
+
|
| 26 |
+
| 参数 | 默认值 | 说明 |
|
| 27 |
+
|------|--------|------|
|
| 28 |
+
| `TASK_NAME` | `image_edit` | 任务类型,GEdit 固定为图像编辑 |
|
| 29 |
+
| `VALIDATION_NUM_TIMESTEPS` | 50 | 推理步数 |
|
| 30 |
+
| `VALIDATION_TIMESTEP_SHIFT` | 3.5 | Timestep shift |
|
| 31 |
+
| `EVALUATION_SEED` | 42 | 随机种子 |
|
| 32 |
+
| `CFG_TEXT_SCALE` | 4.0 | CFG scale |
|
| 33 |
+
| `CFG_INTERVAL_START` | 0.4 | CFG 区间起点 |
|
| 34 |
+
| `CFG_INTERVAL_END` | 1.0 | CFG 区间终点 |
|
| 35 |
+
| `USE_KVCACHE` | `true` | 是否启用 KV cache |
|
| 36 |
+
| `NUM_GPUS` | 8 | GPU 数量 |
|
| 37 |
+
| `MODEL_PATH` | `downloads/Lance_3B` | Lance checkpoint 路径 |
|
| 38 |
+
| `VAL_DATASET_CONFIG_FILE` | `benchmarks/image_gen/GEdit/GEdit_en.json` | 评估数据路径 |
|
| 39 |
+
|
| 40 |
+
## 修改方式
|
| 41 |
+
|
| 42 |
+
- 请手动编辑 `benchmarks/image_gen/GEdit/sample_GEdit.sh` 顶部的“推理参数配置”区。
|
| 43 |
+
- 修改完成后,直接运行 `bash benchmarks/image_gen/GEdit/sample_GEdit.sh`。
|
| 44 |
+
- `SAVE_PATH_GEN` 由脚本根据顶部参数自动生成,不需要手动设置。
|
| 45 |
+
|
| 46 |
+
## 保存格式
|
| 47 |
+
|
| 48 |
+
结果会按照以下结构保存:
|
| 49 |
+
|
| 50 |
+
```
|
| 51 |
+
results/GEdit_ts50_tss3.5_seed42_cfg4.0_kvcache_20260507_120000/
|
| 52 |
+
├── fullset/
|
| 53 |
+
│ ├── add/
|
| 54 |
+
│ │ ├── en/
|
| 55 |
+
│ │ │ ├── 000001.webp
|
| 56 |
+
│ │ │ └── ...
|
| 57 |
+
│ ├── remove/
|
| 58 |
+
│ │ └── en/
|
| 59 |
+
│ │ └── ...
|
| 60 |
+
├── prompt.json
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
每个 case 默认生成 1 张编辑结果图,并按 `task_type/instruction_language/key` 分目录保存为 `.webp` 文件;同时会额外写出 `prompt.json` 用于记录生成文本。
|
| 64 |
+
## 注意事项
|
| 65 |
+
|
| 66 |
+
- 如果需要切换模型、数据集或分辨率,请直接修改脚本顶部配置。
|
| 67 |
+
- 默认结果目录会自动包含关键参数和时间戳,方便区分不同实验。
|
benchmarks/image_gen/GEdit/sample_GEdit.py
ADDED
|
@@ -0,0 +1,425 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# coding: utf-8
|
| 15 |
+
|
| 16 |
+
import warnings
|
| 17 |
+
warnings.filterwarnings("ignore", message=".*pkg_resources is deprecated.*", category=UserWarning)
|
| 18 |
+
warnings.filterwarnings("ignore", category=FutureWarning, module="diffusers.models.transformers.transformer_2d")
|
| 19 |
+
import os
|
| 20 |
+
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
|
| 21 |
+
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
| 22 |
+
|
| 23 |
+
import os.path as osp
|
| 24 |
+
from copy import deepcopy
|
| 25 |
+
import json
|
| 26 |
+
from typing import Tuple, cast, Optional
|
| 27 |
+
import torch
|
| 28 |
+
import torch.distributed as dist
|
| 29 |
+
from torch.utils.data import DataLoader
|
| 30 |
+
from transformers import HfArgumentParser, set_seed
|
| 31 |
+
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLVisionConfig
|
| 32 |
+
from safetensors.torch import load_file
|
| 33 |
+
from PIL import Image
|
| 34 |
+
from tqdm import trange
|
| 35 |
+
|
| 36 |
+
from data.dataset_base import DataConfig, simple_custom_collate
|
| 37 |
+
from data.data_utils import add_special_tokens
|
| 38 |
+
from modeling.vae.wan.model import WanVideoVAE
|
| 39 |
+
from modeling.lance import LanceConfig, Lance, Qwen2ForCausalLM
|
| 40 |
+
from modeling.qwen2 import Qwen2Tokenizer
|
| 41 |
+
from modeling.qwen2.modeling_qwen2 import Qwen2Config
|
| 42 |
+
from modeling.vit.qwen2_5_vl_vit import Qwen2_5_VisionTransformerPretrainedModel
|
| 43 |
+
from common.utils.misc import tuple_mul, AutoEncoderParams
|
| 44 |
+
from common.val.utils import make_padded_latent, decode_video_tensor
|
| 45 |
+
from data.datasets_custom import ValidationDataset
|
| 46 |
+
from config.config_factory import ModelArguments, DataArguments, TrainingArguments, EvaluationArguments, get_model_path
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def init_from_vlm_if_needed(model: Qwen2ForCausalLM, model_args: ModelArguments, log_rank0):
|
| 50 |
+
def load_safetensors_state_dict(folder_path):
|
| 51 |
+
safetensor_files = sorted(
|
| 52 |
+
f for f in os.listdir(folder_path) if f.endswith(".safetensors")
|
| 53 |
+
)
|
| 54 |
+
state_dict = {}
|
| 55 |
+
for filename in safetensor_files:
|
| 56 |
+
file_path = osp.join(folder_path, filename)
|
| 57 |
+
state_dict.update(load_file(file_path))
|
| 58 |
+
return state_dict
|
| 59 |
+
|
| 60 |
+
state_dict = load_safetensors_state_dict(model_args.llm_path)
|
| 61 |
+
|
| 62 |
+
for k in list(state_dict.keys()):
|
| 63 |
+
if "visual" in k:
|
| 64 |
+
state_dict[k.replace("visual", "vit_model")] = state_dict.pop(k)
|
| 65 |
+
else:
|
| 66 |
+
state_dict["language_model." + k] = state_dict.pop(k)
|
| 67 |
+
|
| 68 |
+
result = model.load_state_dict(state_dict, strict=False)
|
| 69 |
+
del state_dict
|
| 70 |
+
import gc; gc.collect(); torch.cuda.empty_cache()
|
| 71 |
+
return result
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def init_from_model_path_if_needed(model: Qwen2ForCausalLM, model_args: ModelArguments):
|
| 75 |
+
path_dir = model_args.model_path
|
| 76 |
+
ema_path = osp.join(path_dir, "ema.safetensors")
|
| 77 |
+
model_path = osp.join(path_dir, "model.safetensors")
|
| 78 |
+
|
| 79 |
+
model_path_ft = None
|
| 80 |
+
if osp.exists(model_path):
|
| 81 |
+
model_path_ft = model_path
|
| 82 |
+
elif osp.exists(ema_path):
|
| 83 |
+
model_path_ft = ema_path
|
| 84 |
+
|
| 85 |
+
if model_path_ft:
|
| 86 |
+
model_state_dict = load_file(model_path_ft, device="cpu")
|
| 87 |
+
else:
|
| 88 |
+
raise FileNotFoundError(
|
| 89 |
+
f"Fine-tuning failed: No valid checkpoint ('ema.safetensors' or 'model.safetensors') found in {path_dir}"
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
if 'latent_pos_embed.pos_embed' in model_state_dict:
|
| 93 |
+
model_state_dict.pop('latent_pos_embed.pos_embed')
|
| 94 |
+
|
| 95 |
+
msg = model.load_state_dict(model_state_dict, strict=False)
|
| 96 |
+
del model_state_dict
|
| 97 |
+
import gc; gc.collect(); torch.cuda.empty_cache()
|
| 98 |
+
return msg
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def save_prompt_results(prompt_data_dict, save_path_gen):
|
| 102 |
+
prompt_json_path = os.path.join(save_path_gen, "prompt.json")
|
| 103 |
+
with open(prompt_json_path, 'w', encoding='utf-8') as f:
|
| 104 |
+
json.dump(prompt_data_dict, f, ensure_ascii=False, indent=2)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def resolve_gedit_paths(
|
| 108 |
+
model_args: ModelArguments,
|
| 109 |
+
data_args: DataArguments,
|
| 110 |
+
) -> None:
|
| 111 |
+
if not model_args.model_path:
|
| 112 |
+
raise ValueError("GEdit requires --model_path to be provided explicitly.")
|
| 113 |
+
|
| 114 |
+
if not model_args.llm_path:
|
| 115 |
+
model_args.llm_path = model_args.model_path
|
| 116 |
+
|
| 117 |
+
if not model_args.vit_path:
|
| 118 |
+
model_args.vit_path = get_model_path("vit.qwen2_5_vl")
|
| 119 |
+
|
| 120 |
+
if not data_args.val_dataset_config_file:
|
| 121 |
+
data_args.val_dataset_config_file = get_model_path("gedit.data")
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def validate_on_fixed_batch(
|
| 125 |
+
fsdp_model: Lance,
|
| 126 |
+
vae_model: Optional[WanVideoVAE],
|
| 127 |
+
val_data_cpu: dict,
|
| 128 |
+
training_args: TrainingArguments,
|
| 129 |
+
model_args: ModelArguments,
|
| 130 |
+
inference_args: EvaluationArguments,
|
| 131 |
+
new_token_ids,
|
| 132 |
+
image_token_id: int,
|
| 133 |
+
device: int,
|
| 134 |
+
save_path_gen: str = "",
|
| 135 |
+
):
|
| 136 |
+
val_data = val_data_cpu.cuda(device).to_dict()
|
| 137 |
+
fsdp_model = fsdp_model.to(device=device, dtype=torch.bfloat16)
|
| 138 |
+
|
| 139 |
+
with torch.no_grad(), torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
|
| 140 |
+
if "padded_videos" in val_data.keys():
|
| 141 |
+
val_data["padded_latent"] = make_padded_latent(val_data["padded_videos"], val_data["vae_data_mode"], vae_model)
|
| 142 |
+
|
| 143 |
+
metadata = val_data["additional_info"]
|
| 144 |
+
task_type = metadata["task_type"]
|
| 145 |
+
instruction_language = metadata["instruction_language"]
|
| 146 |
+
save_key = metadata["key"]
|
| 147 |
+
save_dir_current = os.path.join(save_path_gen, "fullset/{}/{}".format(task_type, instruction_language))
|
| 148 |
+
os.makedirs(save_dir_current, exist_ok=True)
|
| 149 |
+
|
| 150 |
+
# -------------------- GEN 分支 --------------------
|
| 151 |
+
params = {
|
| 152 |
+
"val_packed_text_ids": val_data["packed_text_ids"],
|
| 153 |
+
"val_packed_text_indexes": val_data["packed_text_indexes"],
|
| 154 |
+
"val_sample_lens": val_data["sample_lens"],
|
| 155 |
+
"val_packed_position_ids": val_data["packed_position_ids"],
|
| 156 |
+
"val_split_lens": val_data["split_lens"],
|
| 157 |
+
"val_attn_modes": val_data["attn_modes"],
|
| 158 |
+
"val_sample_N_target": val_data["sample_N_target"],
|
| 159 |
+
"val_packed_vae_token_indexes": val_data["packed_vae_token_indexes"],
|
| 160 |
+
"timestep_shift": training_args.validation_timestep_shift,
|
| 161 |
+
"num_timesteps": training_args.validation_num_timesteps,
|
| 162 |
+
"val_mse_loss_indexes": val_data.get("mse_loss_indexes", None),
|
| 163 |
+
"val_padded_latent": val_data["padded_latent"],
|
| 164 |
+
"video_sizes": val_data["video_sizes"],
|
| 165 |
+
"cfg_text_scale": model_args.cfg_text_scale,
|
| 166 |
+
"cfg_interval": training_args.cfg_interval,
|
| 167 |
+
"cfg_renorm_min": training_args.cfg_renorm_min,
|
| 168 |
+
"cfg_renorm_type": training_args.cfg_renorm_type,
|
| 169 |
+
"device": device,
|
| 170 |
+
"dtype": torch.bfloat16,
|
| 171 |
+
"new_token_ids": new_token_ids,
|
| 172 |
+
"max_samples": training_args.validation_max_samples,
|
| 173 |
+
"validation_noise_seed": training_args.validation_noise_seed,
|
| 174 |
+
"apply_chat_template": training_args.apply_chat_template,
|
| 175 |
+
"apply_qwen_2_5_vl_pos_emb": training_args.apply_qwen_2_5_vl_pos_emb,
|
| 176 |
+
"image_token_id": image_token_id,
|
| 177 |
+
"val_packed_vit_token_indexes": val_data.get("packed_vit_token_indexes", None),
|
| 178 |
+
"val_packed_vit_tokens": val_data.get("packed_vit_tokens", None),
|
| 179 |
+
"vit_video_grid_thw": val_data.get("vit_video_grid_thw", None),
|
| 180 |
+
"vae_video_grid_thw": val_data["vae_video_grid_thw"],
|
| 181 |
+
"video_grid_thw": val_data.get("video_grid_thw", None),
|
| 182 |
+
"caption": val_data.get("caption", None),
|
| 183 |
+
"sample_task": val_data["sample_task"],
|
| 184 |
+
"sample_modality": val_data["sample_modality"],
|
| 185 |
+
"cfg_type": training_args.cfg_type,
|
| 186 |
+
"cfg_uncond_token_id": training_args.cfg_uncond_token_id,
|
| 187 |
+
"index": val_data["index"],
|
| 188 |
+
"val_padded_videos": None,
|
| 189 |
+
}
|
| 190 |
+
if inference_args.use_KVcache:
|
| 191 |
+
denoise_latent, captions, _, _ = fsdp_model.validation_gen_KVcache(**params)
|
| 192 |
+
else:
|
| 193 |
+
denoise_latent, captions, _, _ = fsdp_model.validation_gen(**params)
|
| 194 |
+
|
| 195 |
+
for i_val, latent in enumerate(denoise_latent):
|
| 196 |
+
target_latent = latent[-1]
|
| 197 |
+
v_target = vae_model.vae_decode([target_latent])[0]
|
| 198 |
+
|
| 199 |
+
v_thwc = decode_video_tensor([v_target], save_path="", save_half=False)
|
| 200 |
+
|
| 201 |
+
if v_thwc.shape[0] != 1:
|
| 202 |
+
raise NotImplementedError(
|
| 203 |
+
"GEdit benchmark only supports image output (max_num_frames=1), "
|
| 204 |
+
f"but got {v_thwc.shape[0]} frames."
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
save_name = f'{save_dir_current}/{save_key}.webp'
|
| 208 |
+
Image.fromarray(v_thwc[0]).save(save_name)
|
| 209 |
+
inference_args.prompt_data_dict[save_name] = captions[i_val]
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def main():
|
| 213 |
+
assert torch.cuda.is_available()
|
| 214 |
+
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
| 215 |
+
dist.init_process_group("nccl")
|
| 216 |
+
GLOBAL_RANK = dist.get_rank()
|
| 217 |
+
WORLD_SIZE = dist.get_world_size()
|
| 218 |
+
else:
|
| 219 |
+
GLOBAL_RANK = 0
|
| 220 |
+
WORLD_SIZE = 1
|
| 221 |
+
|
| 222 |
+
LOCAL_RANK = GLOBAL_RANK % torch.cuda.device_count()
|
| 223 |
+
DEVICE = LOCAL_RANK
|
| 224 |
+
torch.cuda.set_device(DEVICE)
|
| 225 |
+
|
| 226 |
+
parser = HfArgumentParser((ModelArguments, DataArguments, EvaluationArguments))
|
| 227 |
+
model_args, data_args, inference_args = cast(
|
| 228 |
+
Tuple[ModelArguments, DataArguments, EvaluationArguments],
|
| 229 |
+
parser.parse_args_into_dataclasses(),
|
| 230 |
+
)
|
| 231 |
+
training_args = inference_args
|
| 232 |
+
|
| 233 |
+
training_args.validation_noise_seed = training_args.validation_data_seed
|
| 234 |
+
|
| 235 |
+
log_rank0 = print if GLOBAL_RANK == 0 else (lambda *_: None)
|
| 236 |
+
|
| 237 |
+
seed = training_args.global_seed * WORLD_SIZE + GLOBAL_RANK
|
| 238 |
+
set_seed(seed)
|
| 239 |
+
|
| 240 |
+
resolve_gedit_paths(model_args, data_args)
|
| 241 |
+
|
| 242 |
+
llm_config: Qwen2Config = Qwen2Config.from_json_file(osp.join(model_args.model_path, "llm_config.json"))
|
| 243 |
+
|
| 244 |
+
llm_config.layer_module = model_args.layer_module
|
| 245 |
+
llm_config.qk_norm = model_args.llm_qk_norm
|
| 246 |
+
llm_config.qk_norm_und = model_args.llm_qk_norm_und
|
| 247 |
+
llm_config.qk_norm_gen = model_args.llm_qk_norm_gen
|
| 248 |
+
llm_config.tie_word_embeddings = model_args.tie_word_embeddings
|
| 249 |
+
llm_config.freeze_und = training_args.freeze_und
|
| 250 |
+
llm_config.apply_qwen_2_5_vl_pos_emb = training_args.apply_qwen_2_5_vl_pos_emb
|
| 251 |
+
|
| 252 |
+
language_model: Qwen2ForCausalLM = Qwen2ForCausalLM(llm_config)
|
| 253 |
+
|
| 254 |
+
if training_args.visual_und:
|
| 255 |
+
if model_args.vit_type in ("qwen2_5_vl", "qwen_2_5_vl_original"):
|
| 256 |
+
vit_config = Qwen2_5_VLVisionConfig.from_pretrained(model_args.vit_path)
|
| 257 |
+
vit_model = Qwen2_5_VisionTransformerPretrainedModel(vit_config)
|
| 258 |
+
vit_weights = load_file(osp.join(model_args.vit_path, "vit.safetensors"))
|
| 259 |
+
vit_model.load_state_dict(vit_weights, strict=True)
|
| 260 |
+
else:
|
| 261 |
+
raise ValueError(f"Unsupported vit_type: {model_args.vit_type}")
|
| 262 |
+
|
| 263 |
+
del vit_weights
|
| 264 |
+
import gc; gc.collect(); torch.cuda.empty_cache()
|
| 265 |
+
|
| 266 |
+
if training_args.visual_gen:
|
| 267 |
+
vae_model = WanVideoVAE()
|
| 268 |
+
vae_config: AutoEncoderParams = deepcopy(vae_model.vae_config)
|
| 269 |
+
else:
|
| 270 |
+
vae_model = None
|
| 271 |
+
vae_config = None
|
| 272 |
+
|
| 273 |
+
config = LanceConfig(
|
| 274 |
+
visual_gen=training_args.visual_gen,
|
| 275 |
+
visual_und=training_args.visual_und,
|
| 276 |
+
llm_config=llm_config,
|
| 277 |
+
vit_config=vit_config if training_args.visual_und else None,
|
| 278 |
+
vae_config=vae_config if training_args.visual_gen else None,
|
| 279 |
+
latent_patch_size=model_args.latent_patch_size,
|
| 280 |
+
max_num_frames=model_args.max_num_frames,
|
| 281 |
+
max_latent_size=model_args.max_latent_size,
|
| 282 |
+
vit_max_num_patch_per_side=model_args.vit_max_num_patch_per_side,
|
| 283 |
+
connector_act=model_args.connector_act,
|
| 284 |
+
interpolate_pos=model_args.interpolate_pos,
|
| 285 |
+
timestep_shift=training_args.timestep_shift,
|
| 286 |
+
)
|
| 287 |
+
model: Lance = Lance(
|
| 288 |
+
language_model=language_model,
|
| 289 |
+
vit_model=vit_model if training_args.visual_und else None,
|
| 290 |
+
vit_type=model_args.vit_type,
|
| 291 |
+
config=config,
|
| 292 |
+
training_args=training_args,
|
| 293 |
+
)
|
| 294 |
+
model = model.to(DEVICE)
|
| 295 |
+
|
| 296 |
+
tokenizer: Qwen2Tokenizer = Qwen2Tokenizer.from_pretrained(model_args.model_path)
|
| 297 |
+
|
| 298 |
+
tokenizer, new_token_ids, num_new_tokens = add_special_tokens(tokenizer)
|
| 299 |
+
|
| 300 |
+
if training_args.copy_init_moe:
|
| 301 |
+
language_model.init_moe()
|
| 302 |
+
|
| 303 |
+
init_from_model_path_if_needed(model, model_args)
|
| 304 |
+
|
| 305 |
+
if num_new_tokens > 0:
|
| 306 |
+
model.language_model.resize_token_embeddings(len(tokenizer))
|
| 307 |
+
model.config.llm_config.vocab_size = len(tokenizer)
|
| 308 |
+
model.language_model.config.vocab_size = len(tokenizer)
|
| 309 |
+
|
| 310 |
+
if model_args.vit_type.lower() == "qwen2_5_vl":
|
| 311 |
+
from common.model.hacks import hack_qwen2_5_vl_config
|
| 312 |
+
language_model = hack_qwen2_5_vl_config(language_model)
|
| 313 |
+
|
| 314 |
+
image_token_id = language_model.config.video_token_id
|
| 315 |
+
new_token_ids.update({"image_token_id": image_token_id})
|
| 316 |
+
model.update_tokenizer(tokenizer=tokenizer)
|
| 317 |
+
|
| 318 |
+
if model_args.tie_word_embeddings:
|
| 319 |
+
model.language_model.untie_lm_head()
|
| 320 |
+
model.language_model.copy_new_token_rows_to_lm_head(num_new_tokens)
|
| 321 |
+
|
| 322 |
+
model_args.tie_word_embeddings = False
|
| 323 |
+
llm_config.tie_word_embeddings = False
|
| 324 |
+
else:
|
| 325 |
+
assert model.language_model.get_input_embeddings().weight.data.data_ptr() != model.language_model.get_output_embeddings().weight.data.data_ptr(), 'tie_world_embeddings 冲突'
|
| 326 |
+
|
| 327 |
+
model = model.to(device=DEVICE, dtype=torch.bfloat16)
|
| 328 |
+
model.eval()
|
| 329 |
+
if vae_model is not None and hasattr(vae_model, "eval"):
|
| 330 |
+
vae_model.eval()
|
| 331 |
+
|
| 332 |
+
dataset_config = DataConfig(grouped_datasets={})
|
| 333 |
+
|
| 334 |
+
if training_args.visual_und:
|
| 335 |
+
dataset_config.vit_patch_size = model_args.vit_patch_size
|
| 336 |
+
dataset_config.vit_patch_size_temporal = model_args.vit_patch_size_temporal
|
| 337 |
+
dataset_config.vit_max_num_patch_per_side = model_args.vit_max_num_patch_per_side
|
| 338 |
+
if training_args.visual_gen:
|
| 339 |
+
assert len(model_args.latent_patch_size) == 3, "len(latent_patch_size) must be 3"
|
| 340 |
+
vae_downsample = tuple_mul(
|
| 341 |
+
model_args.latent_patch_size, (vae_config.downsample_temporal, vae_config.downsample_spatial, vae_config.downsample_spatial)
|
| 342 |
+
)
|
| 343 |
+
dataset_config.latent_patch_size = model_args.latent_patch_size
|
| 344 |
+
dataset_config.vae_downsample = vae_downsample
|
| 345 |
+
dataset_config.max_latent_size = model_args.max_latent_size
|
| 346 |
+
dataset_config.max_num_frames = model_args.max_num_frames
|
| 347 |
+
|
| 348 |
+
dataset_config.text_cond_dropout_prob = model_args.text_cond_dropout_prob
|
| 349 |
+
dataset_config.vae_cond_dropout_prob = model_args.vae_cond_dropout_prob
|
| 350 |
+
dataset_config.vit_cond_dropout_prob = model_args.vit_cond_dropout_prob
|
| 351 |
+
|
| 352 |
+
dataset_config.num_frames = inference_args.num_frames
|
| 353 |
+
dataset_config.H = inference_args.video_height
|
| 354 |
+
dataset_config.W = inference_args.video_width
|
| 355 |
+
dataset_config.task = inference_args.task
|
| 356 |
+
dataset_config.resolution = inference_args.resolution
|
| 357 |
+
dataset_config.text_template = inference_args.text_template
|
| 358 |
+
|
| 359 |
+
val_dataset = ValidationDataset(
|
| 360 |
+
jsonl_path=data_args.val_dataset_config_file,
|
| 361 |
+
tokenizer=tokenizer,
|
| 362 |
+
data_args=data_args,
|
| 363 |
+
model_args=model_args,
|
| 364 |
+
training_args=training_args,
|
| 365 |
+
new_token_ids=new_token_ids,
|
| 366 |
+
dataset_config=dataset_config,
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
val_loader = DataLoader(
|
| 370 |
+
val_dataset,
|
| 371 |
+
batch_size=1,
|
| 372 |
+
num_workers=0,
|
| 373 |
+
pin_memory=True,
|
| 374 |
+
collate_fn=simple_custom_collate,
|
| 375 |
+
drop_last=True,
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
val_loader_iter = iter(val_loader)
|
| 379 |
+
|
| 380 |
+
if not hasattr(inference_args, "prompt_data_dict"):
|
| 381 |
+
inference_args.prompt_data_dict = {}
|
| 382 |
+
|
| 383 |
+
if not os.path.exists(inference_args.save_path_gen):
|
| 384 |
+
os.makedirs(inference_args.save_path_gen)
|
| 385 |
+
|
| 386 |
+
for epoch in trange(len(val_loader), desc="Validating", unit="batch", leave=True, ncols=80, disable=(GLOBAL_RANK != 0)):
|
| 387 |
+
try:
|
| 388 |
+
val_data_cpu = next(val_loader_iter)
|
| 389 |
+
except StopIteration:
|
| 390 |
+
break
|
| 391 |
+
|
| 392 |
+
validate_on_fixed_batch(
|
| 393 |
+
fsdp_model=model,
|
| 394 |
+
vae_model=vae_model,
|
| 395 |
+
val_data_cpu=val_data_cpu,
|
| 396 |
+
training_args=training_args,
|
| 397 |
+
model_args=model_args,
|
| 398 |
+
inference_args=inference_args,
|
| 399 |
+
new_token_ids=new_token_ids,
|
| 400 |
+
image_token_id=image_token_id,
|
| 401 |
+
device=DEVICE,
|
| 402 |
+
save_path_gen=inference_args.save_path_gen,
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
if dist.is_initialized():
|
| 406 |
+
dist.barrier()
|
| 407 |
+
gathered = [None for _ in range(dist.get_world_size())]
|
| 408 |
+
dist.all_gather_object(gathered, inference_args.prompt_data_dict)
|
| 409 |
+
|
| 410 |
+
if GLOBAL_RANK == 0:
|
| 411 |
+
merged = {}
|
| 412 |
+
for d in gathered:
|
| 413 |
+
merged.update(d)
|
| 414 |
+
inference_args.prompt_data_dict = merged
|
| 415 |
+
save_prompt_results(inference_args.prompt_data_dict, inference_args.save_path_gen)
|
| 416 |
+
|
| 417 |
+
elif GLOBAL_RANK == 0:
|
| 418 |
+
save_prompt_results(inference_args.prompt_data_dict, inference_args.save_path_gen)
|
| 419 |
+
|
| 420 |
+
if dist.is_initialized():
|
| 421 |
+
dist.destroy_process_group()
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
if __name__ == "__main__":
|
| 425 |
+
main()
|
benchmarks/image_gen/GEdit/sample_GEdit.sh
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
|
| 4 |
+
source "$SCRIPT_DIR/../../sample_env.sh"
|
| 5 |
+
|
| 6 |
+
# ========================= 推理参数配置 =========================
|
| 7 |
+
TASK_NAME="image_edit"
|
| 8 |
+
NUM_GPUS=8
|
| 9 |
+
|
| 10 |
+
VALIDATION_NUM_TIMESTEPS=50
|
| 11 |
+
VALIDATION_TIMESTEP_SHIFT=3.5
|
| 12 |
+
EVALUATION_SEED=42
|
| 13 |
+
CFG_TEXT_SCALE=4.0
|
| 14 |
+
CFG_INTERVAL_START=0.4
|
| 15 |
+
CFG_INTERVAL_END=1.0
|
| 16 |
+
USE_KVCACHE=true
|
| 17 |
+
|
| 18 |
+
MODEL_PATH="downloads/Lance_3B"
|
| 19 |
+
VAL_DATASET_CONFIG_FILE="benchmarks/image_gen/GEdit/GEdit_en.json"
|
| 20 |
+
|
| 21 |
+
# ========================= 自动生成路径 =========================
|
| 22 |
+
TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
|
| 23 |
+
KVCACHE_TAG=""
|
| 24 |
+
if [ "$USE_KVCACHE" = "true" ]; then
|
| 25 |
+
KVCACHE_TAG="kvcache_"
|
| 26 |
+
fi
|
| 27 |
+
SAVE_PATH_GEN="results/GEdit_ts${VALIDATION_NUM_TIMESTEPS}_tss${VALIDATION_TIMESTEP_SHIFT}_seed${EVALUATION_SEED}_cfg${CFG_TEXT_SCALE}_${KVCACHE_TAG}${TIMESTAMP}"
|
| 28 |
+
|
| 29 |
+
if [ -z "$MODEL_PATH" ]; then
|
| 30 |
+
echo "错误: 请在脚本顶部配置区手动设置 MODEL_PATH"
|
| 31 |
+
exit 1
|
| 32 |
+
fi
|
| 33 |
+
|
| 34 |
+
# ============================== 环境与分布式配置 ==============================
|
| 35 |
+
lance_setup_common_env
|
| 36 |
+
lance_setup_distributed_env "$NUM_GPUS"
|
| 37 |
+
lance_setup_shard_env 1
|
| 38 |
+
|
| 39 |
+
# ========================= 显示任务配置 =========================
|
| 40 |
+
echo "================================================"
|
| 41 |
+
echo "GEdit 图像编辑评估"
|
| 42 |
+
echo "================================================"
|
| 43 |
+
echo "GPU数量: ${NUM_GPUS}"
|
| 44 |
+
echo "保存路径: ${SAVE_PATH_GEN}"
|
| 45 |
+
echo "模型路径: ${MODEL_PATH}"
|
| 46 |
+
if [ -n "$VAL_DATASET_CONFIG_FILE" ]; then
|
| 47 |
+
echo "数据路径: ${VAL_DATASET_CONFIG_FILE}"
|
| 48 |
+
fi
|
| 49 |
+
echo ""
|
| 50 |
+
echo "关键参数:"
|
| 51 |
+
echo " - validation_num_timesteps: ${VALIDATION_NUM_TIMESTEPS}"
|
| 52 |
+
echo " - validation_timestep_shift: ${VALIDATION_TIMESTEP_SHIFT}"
|
| 53 |
+
echo " - evaluation_seed: ${EVALUATION_SEED}"
|
| 54 |
+
echo " - cfg_text_scale: ${CFG_TEXT_SCALE}"
|
| 55 |
+
echo " - cfg_interval: [${CFG_INTERVAL_START}, ${CFG_INTERVAL_END}]"
|
| 56 |
+
echo " - use_KVcache: ${USE_KVCACHE}"
|
| 57 |
+
echo "================================================"
|
| 58 |
+
echo ""
|
| 59 |
+
|
| 60 |
+
# ============================== 执行推理 ==============================
|
| 61 |
+
# 注意:请直接修改本脚本顶部的“推理参数配置”区
|
| 62 |
+
accelerate launch \
|
| 63 |
+
--num_machines $NUM_MACHINES \
|
| 64 |
+
--num_processes $TOTAL_RANK \
|
| 65 |
+
--machine_rank $MACHINE_RANK \
|
| 66 |
+
--main_process_ip $MAIN_PROCESS_IP \
|
| 67 |
+
--main_process_port $MAIN_PROCESS_PORT \
|
| 68 |
+
--mixed_precision bf16 \
|
| 69 |
+
benchmarks/image_gen/GEdit/sample_GEdit.py \
|
| 70 |
+
--model_path "$MODEL_PATH" \
|
| 71 |
+
--val_dataset_config_file "$VAL_DATASET_CONFIG_FILE" \
|
| 72 |
+
--vit_type qwen_2_5_vl_original \
|
| 73 |
+
--llm_qk_norm true \
|
| 74 |
+
--llm_qk_norm_und true \
|
| 75 |
+
--llm_qk_norm_gen true \
|
| 76 |
+
--tie_word_embeddings false \
|
| 77 |
+
--validation_num_timesteps $VALIDATION_NUM_TIMESTEPS \
|
| 78 |
+
--validation_timestep_shift $VALIDATION_TIMESTEP_SHIFT \
|
| 79 |
+
--copy_init_moe true \
|
| 80 |
+
--use_flex true \
|
| 81 |
+
--max_num_frames 1 \
|
| 82 |
+
--max_latent_size 64 \
|
| 83 |
+
--latent_patch_size 1 1 1 \
|
| 84 |
+
--num_replicate $NUM_REPLICATE \
|
| 85 |
+
--num_shard $NUM_SHARD \
|
| 86 |
+
--visual_und true \
|
| 87 |
+
--visual_gen true \
|
| 88 |
+
--vae_model_type wan \
|
| 89 |
+
--apply_qwen_2_5_vl_pos_emb true \
|
| 90 |
+
--apply_chat_template false \
|
| 91 |
+
--cfg_type 0 \
|
| 92 |
+
--validation_data_seed $EVALUATION_SEED \
|
| 93 |
+
--validation_max_samples 100000 \
|
| 94 |
+
--task $TASK_NAME \
|
| 95 |
+
--save_path_gen $SAVE_PATH_GEN \
|
| 96 |
+
--resolution image_768res \
|
| 97 |
+
--text_template true \
|
| 98 |
+
--sample_num_per_prompt 1 \
|
| 99 |
+
--cfg_text_scale $CFG_TEXT_SCALE \
|
| 100 |
+
--cfg_interval $CFG_INTERVAL_START $CFG_INTERVAL_END \
|
| 101 |
+
--use_KVcache $USE_KVCACHE
|
| 102 |
+
|
| 103 |
+
echo ""
|
| 104 |
+
echo "================================================"
|
| 105 |
+
echo "完成! 结果: ${SAVE_PATH_GEN}"
|
| 106 |
+
echo "================================================"
|
benchmarks/image_gen/GenEVAL/GenEVAL.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
benchmarks/image_gen/GenEVAL/README.md
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[Chinese Version](./README_zh.md)
|
| 2 |
+
|
| 3 |
+
# GenEVAL Image Generation Evaluation
|
| 4 |
+
|
| 5 |
+
Benchmark evaluation scripts for GenEVAL based on the Lance model.
|
| 6 |
+
|
| 7 |
+
## Files
|
| 8 |
+
|
| 9 |
+
- `sample_GenEVAL.py` - Python inference script
|
| 10 |
+
- `sample_GenEVAL.sh` - Launch script (recommended)
|
| 11 |
+
- `GenEVAL.jsonl` - Evaluation dataset
|
| 12 |
+
|
| 13 |
+
## Quick Start
|
| 14 |
+
|
| 15 |
+
### Basic Usage
|
| 16 |
+
|
| 17 |
+
```bash
|
| 18 |
+
bash benchmarks/image_gen/GenEVAL/sample_GenEVAL.sh
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
Before running, edit the "Inference Parameters" section at the top of `benchmarks/image_gen/GenEVAL/sample_GenEVAL.sh`.
|
| 22 |
+
|
| 23 |
+
## Parameters
|
| 24 |
+
|
| 25 |
+
| Parameter | Default | Description |
|
| 26 |
+
|------|--------|------|
|
| 27 |
+
| `TASK_NAME` | `t2i` | Task type. GenEVAL is fixed to image generation. |
|
| 28 |
+
| `VALIDATION_NUM_TIMESTEPS` | 50 | Number of inference steps. |
|
| 29 |
+
| `VALIDATION_TIMESTEP_SHIFT` | 3.5 | Timestep shift. |
|
| 30 |
+
| `EVALUATION_SEED` | 42 | Random seed. |
|
| 31 |
+
| `CFG_TEXT_SCALE` | 4.0 | CFG scale. |
|
| 32 |
+
| `CFG_INTERVAL_START` | 0.4 | Start of the CFG interval. |
|
| 33 |
+
| `CFG_INTERVAL_END` | 1.0 | End of the CFG interval. |
|
| 34 |
+
| `SAMPLE_NUM_PER_PROMPT` | 4 | Number of images generated per case. GenEVAL defaults to 4 images. |
|
| 35 |
+
| `USE_KVCACHE` | `true` | Whether to enable KV cache. |
|
| 36 |
+
| `NUM_GPUS` | 8 | Number of GPUs. |
|
| 37 |
+
| `VIDEO_HEIGHT`/`VIDEO_WIDTH` | 768 | Image resolution. |
|
| 38 |
+
| `MODEL_PATH` | `downloads/Lance_3B` | Path to the Lance checkpoint. |
|
| 39 |
+
| `VAL_DATASET_CONFIG_FILE` | `benchmarks/image_gen/GenEVAL/GenEVAL.jsonl` | Path to the evaluation data. |
|
| 40 |
+
|
| 41 |
+
## How To Modify
|
| 42 |
+
|
| 43 |
+
- Edit the "Inference Parameters" section at the top of `benchmarks/image_gen/GenEVAL/sample_GenEVAL.sh`.
|
| 44 |
+
- After updating the parameters, run `bash benchmarks/image_gen/GenEVAL/sample_GenEVAL.sh` directly.
|
| 45 |
+
- `SAVE_PATH_GEN` is generated automatically from the script parameters and does not need to be set manually.
|
| 46 |
+
|
| 47 |
+
## Output Format
|
| 48 |
+
|
| 49 |
+
Results are saved in a structure like this:
|
| 50 |
+
|
| 51 |
+
```
|
| 52 |
+
results/GenEVAL_ts50_tss3.5_seed42_cfg4.0_kvcache_20260507_120000/
|
| 53 |
+
├── 00000/
|
| 54 |
+
│ ├── metadata.jsonl
|
| 55 |
+
│ ├── grid.png
|
| 56 |
+
│ └── samples/
|
| 57 |
+
│ ├── 0.png
|
| 58 |
+
│ ├── 1.png
|
| 59 |
+
│ ├── 2.png
|
| 60 |
+
│ └── 3.png
|
| 61 |
+
├── 00001/
|
| 62 |
+
│ ├── metadata.jsonl
|
| 63 |
+
│ ├── grid.png
|
| 64 |
+
│ └── samples/
|
| 65 |
+
│ ...
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
Each case generates 4 images by default (`sample_num_per_prompt=4`).
|
| 69 |
+
|
| 70 |
+
## Notes
|
| 71 |
+
|
| 72 |
+
- If you need to switch the model, dataset, or resolution, edit the script configuration at the top directly.
|
| 73 |
+
- The ViT path is resolved automatically by the code and usually does not need to be configured separately.
|
benchmarks/image_gen/GenEVAL/README_zh.md
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[English Version](./README.md)
|
| 2 |
+
|
| 3 |
+
# GenEVAL 图像生成评估
|
| 4 |
+
|
| 5 |
+
基于 Lance 模型的 GenEVAL 评估基准测试脚本。
|
| 6 |
+
|
| 7 |
+
## 文件说明
|
| 8 |
+
|
| 9 |
+
- `sample_GenEVAL.py` - 推理 Python 脚本
|
| 10 |
+
- `sample_GenEVAL.sh` - 启动脚本(推荐使用)
|
| 11 |
+
- `GenEVAL.jsonl` - 评估数据集
|
| 12 |
+
|
| 13 |
+
## 快速开始
|
| 14 |
+
|
| 15 |
+
### 基本用法
|
| 16 |
+
|
| 17 |
+
```bash
|
| 18 |
+
bash benchmarks/image_gen/GenEVAL/sample_GenEVAL.sh
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
运行前请直接修改 `benchmarks/image_gen/GenEVAL/sample_GenEVAL.sh` 顶部的“推理参数配置”区。
|
| 22 |
+
|
| 23 |
+
## 参数说明
|
| 24 |
+
|
| 25 |
+
| 参数 | 默认值 | 说明 |
|
| 26 |
+
|------|--------|------|
|
| 27 |
+
| `TASK_NAME` | `t2i` | 任务类型,GenEVAL 固定为图像生成 |
|
| 28 |
+
| `VALIDATION_NUM_TIMESTEPS` | 50 | 推理步数 |
|
| 29 |
+
| `VALIDATION_TIMESTEP_SHIFT` | 3.5 | Timestep shift |
|
| 30 |
+
| `EVALUATION_SEED` | 42 | 随机种子 |
|
| 31 |
+
| `CFG_TEXT_SCALE` | 4.0 | CFG scale |
|
| 32 |
+
| `CFG_INTERVAL_START` | 0.4 | CFG 区间起点 |
|
| 33 |
+
| `CFG_INTERVAL_END` | 1.0 | CFG 区间终点 |
|
| 34 |
+
| `SAMPLE_NUM_PER_PROMPT` | 4 | 每个 case 生成的图像数量(GenEVAL 默认为 4 张图) |
|
| 35 |
+
| `USE_KVCACHE` | `true` | 是否启用 KV cache |
|
| 36 |
+
| `NUM_GPUS` | 8 | GPU 数量 |
|
| 37 |
+
| `VIDEO_HEIGHT`/`VIDEO_WIDTH` | 768 | 图像分辨率 |
|
| 38 |
+
| `MODEL_PATH` | `downloads/Lance_3B` | Lance checkpoint 路径 |
|
| 39 |
+
| `VAL_DATASET_CONFIG_FILE` | `benchmarks/image_gen/GenEVAL/GenEVAL.jsonl` | 评估数据路径 |
|
| 40 |
+
|
| 41 |
+
## 修改方式
|
| 42 |
+
|
| 43 |
+
- 请手动编辑 `benchmarks/image_gen/GenEVAL/sample_GenEVAL.sh` 顶部的“推理参数配置”区。
|
| 44 |
+
- 修改完成后,直接运行 `bash benchmarks/image_gen/GenEVAL/sample_GenEVAL.sh`。
|
| 45 |
+
- `SAVE_PATH_GEN` 由脚本根据顶部参数自动生成,不需要手动设置。
|
| 46 |
+
|
| 47 |
+
## 保存格式
|
| 48 |
+
|
| 49 |
+
结果会按照以下结构保存:
|
| 50 |
+
|
| 51 |
+
```
|
| 52 |
+
results/GenEVAL_ts50_tss3.5_seed42_cfg4.0_kvcache_20260507_120000/
|
| 53 |
+
├── 00000/
|
| 54 |
+
│ ├── metadata.jsonl
|
| 55 |
+
│ ├── grid.png
|
| 56 |
+
│ └── samples/
|
| 57 |
+
│ ├── 0.png
|
| 58 |
+
│ ├── 1.png
|
| 59 |
+
│ ├── 2.png
|
| 60 |
+
│ └── 3.png
|
| 61 |
+
├── 00001/
|
| 62 |
+
│ ├── metadata.jsonl
|
| 63 |
+
│ ├── grid.png
|
| 64 |
+
│ └── samples/
|
| 65 |
+
│ ...
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
每个案例生成 4 张图像(`sample_num_per_prompt=4`)。
|
| 69 |
+
|
| 70 |
+
## 注意事项
|
| 71 |
+
|
| 72 |
+
- 如果需要切换模型、数据集或分辨率,请直接修改脚本顶部配置。
|
| 73 |
+
- ViT 路径默认由代码内部自动解析,无需单独配置。
|
benchmarks/image_gen/GenEVAL/sample_GenEVAL.py
ADDED
|
@@ -0,0 +1,463 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# coding: utf-8
|
| 15 |
+
|
| 16 |
+
import warnings
|
| 17 |
+
warnings.filterwarnings("ignore", message=".*pkg_resources is deprecated.*", category=UserWarning)
|
| 18 |
+
warnings.filterwarnings("ignore", category=FutureWarning, module="diffusers.models.transformers.transformer_2d")
|
| 19 |
+
import os
|
| 20 |
+
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
|
| 21 |
+
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
| 22 |
+
|
| 23 |
+
import os.path as osp
|
| 24 |
+
from copy import deepcopy
|
| 25 |
+
import json
|
| 26 |
+
from typing import Tuple, cast, Optional
|
| 27 |
+
import torch
|
| 28 |
+
import torch.distributed as dist
|
| 29 |
+
from torch.utils.data import DataLoader
|
| 30 |
+
from transformers import HfArgumentParser, set_seed
|
| 31 |
+
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLVisionConfig
|
| 32 |
+
from safetensors.torch import load_file
|
| 33 |
+
from PIL import Image
|
| 34 |
+
from torchvision.utils import make_grid
|
| 35 |
+
import numpy as np
|
| 36 |
+
from tqdm import trange
|
| 37 |
+
|
| 38 |
+
from data.dataset_base import DataConfig, simple_custom_collate
|
| 39 |
+
from data.data_utils import add_special_tokens
|
| 40 |
+
from modeling.vae.wan.model import WanVideoVAE
|
| 41 |
+
from modeling.lance import LanceConfig, Lance, Qwen2ForCausalLM
|
| 42 |
+
from modeling.qwen2 import Qwen2Tokenizer
|
| 43 |
+
from modeling.qwen2.modeling_qwen2 import Qwen2Config
|
| 44 |
+
from modeling.vit.qwen2_5_vl_vit import Qwen2_5_VisionTransformerPretrainedModel
|
| 45 |
+
from common.utils.misc import tuple_mul, AutoEncoderParams
|
| 46 |
+
from common.val.utils import make_padded_latent
|
| 47 |
+
from data.datasets_custom import ValidationDataset
|
| 48 |
+
from config.config_factory import ModelArguments, DataArguments, EvaluationArguments, get_model_path
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def init_from_model_path_if_needed(model: Qwen2ForCausalLM, model_args: ModelArguments):
|
| 52 |
+
# 统一从 model_path 加载训练好的 Lance checkpoint。
|
| 53 |
+
path_dir = model_args.model_path
|
| 54 |
+
ema_path = osp.join(path_dir, "ema.safetensors")
|
| 55 |
+
model_path = osp.join(path_dir, "model.safetensors")
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
model_path_ft = None
|
| 60 |
+
if osp.exists(model_path):
|
| 61 |
+
model_path_ft = model_path
|
| 62 |
+
elif osp.exists(ema_path):
|
| 63 |
+
model_path_ft = ema_path
|
| 64 |
+
|
| 65 |
+
if model_path_ft:
|
| 66 |
+
model_state_dict = load_file(model_path_ft, device="cpu")
|
| 67 |
+
else:
|
| 68 |
+
raise FileNotFoundError(
|
| 69 |
+
f"Fine-tuning failed: No valid checkpoint ('ema.safetensors' or 'model.safetensors') found in {path_dir}"
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
# NOTE: position embeds are fixed sinusoidal embeddings, so we can just pop it off,
|
| 73 |
+
# which makes it easier to adapt to different resolutions.
|
| 74 |
+
if 'latent_pos_embed.pos_embed' in model_state_dict:
|
| 75 |
+
model_state_dict.pop('latent_pos_embed.pos_embed')
|
| 76 |
+
|
| 77 |
+
model.load_state_dict(model_state_dict, strict=False)
|
| 78 |
+
|
| 79 |
+
clean_memory(model_state_dict)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def clean_memory(*objects):
|
| 83 |
+
"""清理内存并释放 GPU 缓存"""
|
| 84 |
+
for obj in objects:
|
| 85 |
+
del obj
|
| 86 |
+
import gc
|
| 87 |
+
gc.collect()
|
| 88 |
+
if torch.cuda.is_available():
|
| 89 |
+
torch.cuda.empty_cache()
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def decode_video_tensor_for_geneval(v_list):
|
| 93 |
+
"""
|
| 94 |
+
专门为 GenEVAL 解码视频张量,保持原有的保存格式
|
| 95 |
+
"""
|
| 96 |
+
N_target = len(v_list)
|
| 97 |
+
if N_target != 1:
|
| 98 |
+
from einops import rearrange
|
| 99 |
+
padded_videos_latent = [v.permute(1, 0, 2, 3) for v in v_list]
|
| 100 |
+
v_tc_hw = rearrange(padded_videos_latent, "n t c h w -> t c h (n w)")
|
| 101 |
+
else:
|
| 102 |
+
v_tc_hw = v_list[0].permute(1, 0, 2, 3)
|
| 103 |
+
|
| 104 |
+
v_tc_hw = v_tc_hw.float().clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round().clamp(0, 255).to(torch.uint8)
|
| 105 |
+
return v_tc_hw
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def resolve_geneval_paths(
|
| 109 |
+
model_args: ModelArguments,
|
| 110 |
+
data_args: DataArguments,
|
| 111 |
+
) -> None:
|
| 112 |
+
if not model_args.model_path:
|
| 113 |
+
raise ValueError("GenEVAL requires --model_path to be provided explicitly.")
|
| 114 |
+
|
| 115 |
+
if not model_args.vit_path:
|
| 116 |
+
model_args.vit_path = get_model_path("vit.qwen2_5_vl")
|
| 117 |
+
|
| 118 |
+
if not data_args.val_dataset_config_file:
|
| 119 |
+
data_args.val_dataset_config_file = get_model_path("geneval.data")
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def build_runtime_dataset_config(
|
| 123 |
+
model_args: ModelArguments,
|
| 124 |
+
inference_args: EvaluationArguments,
|
| 125 |
+
vae_config: Optional[AutoEncoderParams],
|
| 126 |
+
) -> DataConfig:
|
| 127 |
+
"""
|
| 128 |
+
当前推理链不再依赖 dataset_config_file,运行期 DataConfig 由显式参数拼装。
|
| 129 |
+
"""
|
| 130 |
+
dataset_config = DataConfig()
|
| 131 |
+
|
| 132 |
+
dataset_config.num_frames = inference_args.num_frames
|
| 133 |
+
dataset_config.H = inference_args.video_height
|
| 134 |
+
dataset_config.W = inference_args.video_width
|
| 135 |
+
dataset_config.task = inference_args.task
|
| 136 |
+
dataset_config.resolution = inference_args.resolution
|
| 137 |
+
dataset_config.text_template = inference_args.text_template
|
| 138 |
+
dataset_config.max_duration = inference_args.max_duration
|
| 139 |
+
dataset_config.system_prompt_type = inference_args.system_prompt_type
|
| 140 |
+
|
| 141 |
+
if inference_args.visual_und:
|
| 142 |
+
dataset_config.vit_patch_size = model_args.vit_patch_size
|
| 143 |
+
dataset_config.vit_patch_size_temporal = model_args.vit_patch_size_temporal
|
| 144 |
+
dataset_config.vit_max_num_patch_per_side = model_args.vit_max_num_patch_per_side
|
| 145 |
+
|
| 146 |
+
if inference_args.visual_gen and vae_config:
|
| 147 |
+
assert len(model_args.latent_patch_size) == 3, "len(latent_patch_size) must be 3"
|
| 148 |
+
dataset_config.latent_patch_size = model_args.latent_patch_size
|
| 149 |
+
dataset_config.vae_downsample = tuple_mul(
|
| 150 |
+
model_args.latent_patch_size,
|
| 151 |
+
(vae_config.downsample_temporal, vae_config.downsample_spatial, vae_config.downsample_spatial),
|
| 152 |
+
)
|
| 153 |
+
dataset_config.max_latent_size = model_args.max_latent_size
|
| 154 |
+
dataset_config.max_num_frames = model_args.max_num_frames
|
| 155 |
+
|
| 156 |
+
dataset_config.text_cond_dropout_prob = model_args.text_cond_dropout_prob
|
| 157 |
+
dataset_config.vae_cond_dropout_prob = model_args.vae_cond_dropout_prob
|
| 158 |
+
dataset_config.vit_cond_dropout_prob = model_args.vit_cond_dropout_prob
|
| 159 |
+
|
| 160 |
+
return dataset_config
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def validate_on_fixed_batch(
|
| 164 |
+
fsdp_model: Lance,
|
| 165 |
+
vae_model: Optional[WanVideoVAE],
|
| 166 |
+
val_data_cpu: dict,
|
| 167 |
+
model_args: ModelArguments,
|
| 168 |
+
inference_args: EvaluationArguments,
|
| 169 |
+
new_token_ids,
|
| 170 |
+
image_token_id: int,
|
| 171 |
+
device: int,
|
| 172 |
+
save_source_video: bool = False,
|
| 173 |
+
save_path_gen: str = "",
|
| 174 |
+
sample_num_per_prompt: int = 1,
|
| 175 |
+
):
|
| 176 |
+
"""
|
| 177 |
+
验证逻辑,保持与原文件相同的保存格式
|
| 178 |
+
"""
|
| 179 |
+
# 检查是否初始化了分布式环境
|
| 180 |
+
if dist.is_initialized():
|
| 181 |
+
is_rank0 = (dist.get_rank() == 0)
|
| 182 |
+
else:
|
| 183 |
+
is_rank0 = True
|
| 184 |
+
|
| 185 |
+
val_data = val_data_cpu.cuda(device).to_dict()
|
| 186 |
+
|
| 187 |
+
with torch.no_grad(), torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
|
| 188 |
+
# 计算 padded_latent
|
| 189 |
+
if "padded_videos" in val_data.keys():
|
| 190 |
+
val_data["padded_latent"] = make_padded_latent(val_data["padded_videos"], val_data["vae_data_mode"], vae_model)
|
| 191 |
+
|
| 192 |
+
# 先根据val_data["index"]生成一个新的文件夹
|
| 193 |
+
index_save = val_data["index"]
|
| 194 |
+
index_save = f"{index_save:05d}"
|
| 195 |
+
os.makedirs(os.path.join(save_path_gen, index_save), exist_ok=True)
|
| 196 |
+
os.makedirs(os.path.join(save_path_gen, index_save, "samples"), exist_ok=True)
|
| 197 |
+
|
| 198 |
+
# 保存metadata.jsonl
|
| 199 |
+
metadata = val_data["additional_info"]
|
| 200 |
+
with open(os.path.join(save_path_gen, index_save, "metadata.jsonl"), 'w') as f:
|
| 201 |
+
f.write(json.dumps(metadata, ensure_ascii=False) + "\n")
|
| 202 |
+
|
| 203 |
+
# -------------------- GEN 分支 --------------------
|
| 204 |
+
tensor_list_for_grid = []
|
| 205 |
+
loop_iterator = trange(sample_num_per_prompt) if is_rank0 else range(sample_num_per_prompt)
|
| 206 |
+
|
| 207 |
+
for sample_num_per_prompt_index in loop_iterator:
|
| 208 |
+
# 采样生成
|
| 209 |
+
params = {
|
| 210 |
+
"val_packed_text_ids": val_data["packed_text_ids"],
|
| 211 |
+
"val_packed_text_indexes": val_data["packed_text_indexes"],
|
| 212 |
+
"val_sample_lens": val_data["sample_lens"],
|
| 213 |
+
"val_packed_position_ids": val_data["packed_position_ids"],
|
| 214 |
+
"val_split_lens": val_data["split_lens"],
|
| 215 |
+
"val_attn_modes": val_data["attn_modes"],
|
| 216 |
+
"val_sample_N_target": val_data["sample_N_target"],
|
| 217 |
+
"val_packed_vae_token_indexes": val_data["packed_vae_token_indexes"],
|
| 218 |
+
"timestep_shift": inference_args.validation_timestep_shift,
|
| 219 |
+
"num_timesteps": inference_args.validation_num_timesteps,
|
| 220 |
+
"val_mse_loss_indexes": val_data.get("mse_loss_indexes", None),
|
| 221 |
+
"val_padded_latent": val_data["padded_latent"],
|
| 222 |
+
"video_sizes": val_data["video_sizes"],
|
| 223 |
+
"cfg_text_scale": model_args.cfg_text_scale,
|
| 224 |
+
"cfg_interval": inference_args.cfg_interval,
|
| 225 |
+
"cfg_renorm_min": inference_args.cfg_renorm_min,
|
| 226 |
+
"cfg_renorm_type": inference_args.cfg_renorm_type,
|
| 227 |
+
"device": device,
|
| 228 |
+
"dtype": torch.bfloat16,
|
| 229 |
+
"new_token_ids": new_token_ids,
|
| 230 |
+
"max_samples": inference_args.validation_max_samples,
|
| 231 |
+
"validation_noise_seed": inference_args.validation_noise_seed + sample_num_per_prompt_index,
|
| 232 |
+
"apply_chat_template": inference_args.apply_chat_template,
|
| 233 |
+
"apply_qwen_2_5_vl_pos_emb": inference_args.apply_qwen_2_5_vl_pos_emb,
|
| 234 |
+
"image_token_id": image_token_id,
|
| 235 |
+
"val_packed_vit_token_indexes": val_data.get("packed_vit_token_indexes", None),
|
| 236 |
+
"val_packed_vit_tokens": val_data.get("packed_vit_tokens", None),
|
| 237 |
+
"vit_video_grid_thw": val_data.get("vit_video_grid_thw", None),
|
| 238 |
+
"vae_video_grid_thw": val_data["vae_video_grid_thw"],
|
| 239 |
+
"video_grid_thw": val_data.get("video_grid_thw", None),
|
| 240 |
+
"caption": val_data.get("caption", None),
|
| 241 |
+
"sample_task": val_data["sample_task"],
|
| 242 |
+
"sample_modality": val_data["sample_modality"],
|
| 243 |
+
"cfg_type": inference_args.cfg_type,
|
| 244 |
+
"cfg_uncond_token_id": inference_args.cfg_uncond_token_id,
|
| 245 |
+
"index": val_data["index"],
|
| 246 |
+
"val_padded_videos": val_data["padded_videos"] if save_source_video else None,
|
| 247 |
+
}
|
| 248 |
+
|
| 249 |
+
if inference_args.use_KVcache:
|
| 250 |
+
denoise_latent, _, _, _ = fsdp_model.validation_gen_KVcache(**params)
|
| 251 |
+
else:
|
| 252 |
+
denoise_latent, _, _, _ = fsdp_model.validation_gen(**params)
|
| 253 |
+
|
| 254 |
+
# 解码 + 保存
|
| 255 |
+
for latent in denoise_latent:
|
| 256 |
+
v_list = [vae_model.vae_decode([latent_])[0] for latent_ in latent]
|
| 257 |
+
|
| 258 |
+
# 保持与原文件相同的保存格式
|
| 259 |
+
v_thwc = decode_video_tensor_for_geneval(v_list)
|
| 260 |
+
|
| 261 |
+
# 直接取第0帧
|
| 262 |
+
if v_thwc.shape[0] == 1:
|
| 263 |
+
tensor_list_for_grid.append(v_thwc.squeeze(0).cpu())
|
| 264 |
+
|
| 265 |
+
# 保存单张图像
|
| 266 |
+
save_name = f"{save_path_gen}/{index_save}/samples/{sample_num_per_prompt_index}.png"
|
| 267 |
+
Image.fromarray((v_thwc.squeeze(0).permute(1, 2, 0).cpu().numpy()).astype('uint8')).save(save_name)
|
| 268 |
+
else:
|
| 269 |
+
raise NotImplementedError("需要保存图像")
|
| 270 |
+
|
| 271 |
+
# 保存 grid 图
|
| 272 |
+
save_name = f"{save_path_gen}/{index_save}/grid.png"
|
| 273 |
+
grid_tensor = make_grid(tensor_list_for_grid, nrow=int(np.sqrt(sample_num_per_prompt)), padding=0, pad_value=255)
|
| 274 |
+
grid_numpy = grid_tensor.permute(1, 2, 0).numpy()
|
| 275 |
+
Image.fromarray(grid_numpy).save(save_name)
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def main():
|
| 279 |
+
# ========================= Env setup ==============================
|
| 280 |
+
assert torch.cuda.is_available()
|
| 281 |
+
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
| 282 |
+
dist.init_process_group("nccl")
|
| 283 |
+
GLOBAL_RANK = dist.get_rank()
|
| 284 |
+
WORLD_SIZE = dist.get_world_size()
|
| 285 |
+
else:
|
| 286 |
+
GLOBAL_RANK = 0
|
| 287 |
+
WORLD_SIZE = 1
|
| 288 |
+
|
| 289 |
+
LOCAL_RANK = GLOBAL_RANK % torch.cuda.device_count()
|
| 290 |
+
DEVICE = LOCAL_RANK
|
| 291 |
+
torch.cuda.set_device(DEVICE)
|
| 292 |
+
|
| 293 |
+
# ========================= Args and logger setup ==============================
|
| 294 |
+
parser = HfArgumentParser((ModelArguments, DataArguments, EvaluationArguments))
|
| 295 |
+
model_args, data_args, inference_args = cast(Tuple[ModelArguments, DataArguments, EvaluationArguments], parser.parse_args_into_dataclasses())
|
| 296 |
+
|
| 297 |
+
# ========================= GenEVAL 路径解析 ==============================
|
| 298 |
+
resolve_geneval_paths(model_args, data_args)
|
| 299 |
+
|
| 300 |
+
# NOTE validation_noise_seed 与 validation_data_seed 相同
|
| 301 |
+
inference_args.validation_noise_seed = inference_args.evaluation_seed
|
| 302 |
+
inference_args.validation_data_seed = inference_args.evaluation_seed
|
| 303 |
+
# Set seed:
|
| 304 |
+
seed = inference_args.global_seed * WORLD_SIZE + GLOBAL_RANK
|
| 305 |
+
set_seed(seed)
|
| 306 |
+
log_rank0 = print if GLOBAL_RANK == 0 else (lambda *_: None)
|
| 307 |
+
|
| 308 |
+
# ========================= LLM model setup ==============================
|
| 309 |
+
llm_config: Qwen2Config = Qwen2Config.from_json_file(osp.join(model_args.model_path, "llm_config.json"))
|
| 310 |
+
|
| 311 |
+
llm_config.layer_module = model_args.layer_module
|
| 312 |
+
llm_config.qk_norm = model_args.llm_qk_norm
|
| 313 |
+
llm_config.qk_norm_und = model_args.llm_qk_norm_und
|
| 314 |
+
llm_config.qk_norm_gen = model_args.llm_qk_norm_gen
|
| 315 |
+
|
| 316 |
+
llm_config.tie_word_embeddings = model_args.tie_word_embeddings
|
| 317 |
+
llm_config.freeze_und = inference_args.freeze_und
|
| 318 |
+
llm_config.apply_qwen_2_5_vl_pos_emb = inference_args.apply_qwen_2_5_vl_pos_emb
|
| 319 |
+
|
| 320 |
+
language_model: Qwen2ForCausalLM = Qwen2ForCausalLM(llm_config)
|
| 321 |
+
|
| 322 |
+
if inference_args.visual_und:
|
| 323 |
+
if model_args.vit_type in ("qwen2_5_vl", "qwen_2_5_vl_original"):
|
| 324 |
+
vit_config = Qwen2_5_VLVisionConfig.from_pretrained(model_args.vit_path)
|
| 325 |
+
vit_model = Qwen2_5_VisionTransformerPretrainedModel(vit_config)
|
| 326 |
+
vit_weights = load_file(osp.join(model_args.vit_path, "vit.safetensors"))
|
| 327 |
+
vit_model.load_state_dict(vit_weights, strict=True)
|
| 328 |
+
else:
|
| 329 |
+
raise ValueError(f"Unsupported vit_type: {model_args.vit_type}")
|
| 330 |
+
|
| 331 |
+
clean_memory(vit_weights)
|
| 332 |
+
|
| 333 |
+
if inference_args.visual_gen:
|
| 334 |
+
vae_model = WanVideoVAE()
|
| 335 |
+
vae_config: AutoEncoderParams = deepcopy(vae_model.vae_config)
|
| 336 |
+
else:
|
| 337 |
+
vae_model = None
|
| 338 |
+
vae_config = None
|
| 339 |
+
|
| 340 |
+
# Lance的配置
|
| 341 |
+
config = LanceConfig(
|
| 342 |
+
visual_gen=inference_args.visual_gen,
|
| 343 |
+
visual_und=inference_args.visual_und,
|
| 344 |
+
llm_config=llm_config,
|
| 345 |
+
vit_config=vit_config if inference_args.visual_und else None,
|
| 346 |
+
vae_config=vae_config if inference_args.visual_gen else None,
|
| 347 |
+
latent_patch_size=model_args.latent_patch_size,
|
| 348 |
+
max_num_frames=model_args.max_num_frames,
|
| 349 |
+
max_latent_size=model_args.max_latent_size,
|
| 350 |
+
vit_max_num_patch_per_side=model_args.vit_max_num_patch_per_side,
|
| 351 |
+
connector_act=model_args.connector_act,
|
| 352 |
+
interpolate_pos=model_args.interpolate_pos,
|
| 353 |
+
timestep_shift=inference_args.timestep_shift,
|
| 354 |
+
)
|
| 355 |
+
model: Lance = Lance(
|
| 356 |
+
language_model=language_model,
|
| 357 |
+
vit_model=vit_model if inference_args.visual_und else None,
|
| 358 |
+
vit_type=model_args.vit_type,
|
| 359 |
+
config=config,
|
| 360 |
+
training_args=inference_args,
|
| 361 |
+
)
|
| 362 |
+
model = model.to(DEVICE)
|
| 363 |
+
|
| 364 |
+
# Setup tokenizer for model:
|
| 365 |
+
tokenizer: Qwen2Tokenizer = Qwen2Tokenizer.from_pretrained(model_args.model_path)
|
| 366 |
+
|
| 367 |
+
tokenizer, new_token_ids, num_new_tokens = add_special_tokens(tokenizer)
|
| 368 |
+
|
| 369 |
+
# 在加载ckpt前,初始化moe
|
| 370 |
+
if inference_args.copy_init_moe:
|
| 371 |
+
language_model.init_moe()
|
| 372 |
+
|
| 373 |
+
init_from_model_path_if_needed(model, model_args)
|
| 374 |
+
|
| 375 |
+
# 现在再 resize
|
| 376 |
+
if num_new_tokens > 0:
|
| 377 |
+
model.language_model.resize_token_embeddings(len(tokenizer))
|
| 378 |
+
model.config.llm_config.vocab_size = len(tokenizer)
|
| 379 |
+
model.language_model.config.vocab_size = len(tokenizer)
|
| 380 |
+
|
| 381 |
+
if model_args.vit_type.lower() == "qwen2_5_vl":
|
| 382 |
+
from common.model.hacks import hack_qwen2_5_vl_config
|
| 383 |
+
language_model = hack_qwen2_5_vl_config(language_model)
|
| 384 |
+
|
| 385 |
+
image_token_id = language_model.config.video_token_id
|
| 386 |
+
new_token_ids.update({"image_token_id": image_token_id})
|
| 387 |
+
model.update_tokenizer(tokenizer=tokenizer)
|
| 388 |
+
|
| 389 |
+
if model_args.tie_word_embeddings:
|
| 390 |
+
model.language_model.untie_lm_head()
|
| 391 |
+
model.language_model.copy_new_token_rows_to_lm_head(num_new_tokens)
|
| 392 |
+
|
| 393 |
+
model_args.tie_word_embeddings = False
|
| 394 |
+
llm_config.tie_word_embeddings = False
|
| 395 |
+
else:
|
| 396 |
+
assert model.language_model.get_input_embeddings().weight.data.data_ptr() != model.language_model.get_output_embeddings().weight.data.data_ptr(), 'tie_world_embeddings 冲突'
|
| 397 |
+
|
| 398 |
+
model = model.to(device=DEVICE, dtype=torch.bfloat16)
|
| 399 |
+
model.eval()
|
| 400 |
+
# Some VAE wrappers (e.g. `WanVideoVAE`) are plain helper objects rather
|
| 401 |
+
# than `nn.Module`s, and their internal model is already switched to eval.
|
| 402 |
+
if vae_model is not None and hasattr(vae_model, "eval"):
|
| 403 |
+
vae_model.eval()
|
| 404 |
+
|
| 405 |
+
dataset_config = build_runtime_dataset_config(
|
| 406 |
+
model_args=model_args,
|
| 407 |
+
inference_args=inference_args,
|
| 408 |
+
vae_config=vae_config,
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
# 创建数据集
|
| 412 |
+
val_dataset = ValidationDataset(
|
| 413 |
+
jsonl_path= data_args.val_dataset_config_file,
|
| 414 |
+
tokenizer=tokenizer,
|
| 415 |
+
data_args=data_args,
|
| 416 |
+
model_args=model_args,
|
| 417 |
+
training_args=inference_args,
|
| 418 |
+
new_token_ids=new_token_ids,
|
| 419 |
+
dataset_config=dataset_config,
|
| 420 |
+
local_rank=GLOBAL_RANK,
|
| 421 |
+
world_size=WORLD_SIZE,
|
| 422 |
+
)
|
| 423 |
+
val_loader = DataLoader(
|
| 424 |
+
val_dataset,
|
| 425 |
+
batch_size=1,
|
| 426 |
+
num_workers=0,
|
| 427 |
+
pin_memory=True,
|
| 428 |
+
collate_fn=simple_custom_collate,
|
| 429 |
+
drop_last=True,
|
| 430 |
+
prefetch_factor=None,
|
| 431 |
+
persistent_workers=False,
|
| 432 |
+
multiprocessing_context=None,
|
| 433 |
+
)
|
| 434 |
+
|
| 435 |
+
val_loader_iter = iter(val_loader)
|
| 436 |
+
|
| 437 |
+
if not os.path.exists(inference_args.save_path_gen):
|
| 438 |
+
os.makedirs(inference_args.save_path_gen, exist_ok=True)
|
| 439 |
+
|
| 440 |
+
# 主循环
|
| 441 |
+
for _ in trange(len(val_loader), desc="Validating", unit="batch", leave=True, ncols=80, disable=(GLOBAL_RANK != 0)):
|
| 442 |
+
val_data_cpu = next(val_loader_iter)
|
| 443 |
+
|
| 444 |
+
validate_on_fixed_batch(
|
| 445 |
+
fsdp_model=model,
|
| 446 |
+
vae_model=vae_model,
|
| 447 |
+
val_data_cpu=val_data_cpu,
|
| 448 |
+
model_args=model_args,
|
| 449 |
+
inference_args=inference_args,
|
| 450 |
+
new_token_ids=new_token_ids,
|
| 451 |
+
image_token_id=image_token_id,
|
| 452 |
+
device=DEVICE,
|
| 453 |
+
save_source_video=False,
|
| 454 |
+
save_path_gen=inference_args.save_path_gen,
|
| 455 |
+
sample_num_per_prompt=inference_args.sample_num_per_prompt,
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
if dist.is_initialized():
|
| 459 |
+
dist.destroy_process_group()
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
if __name__ == "__main__":
|
| 463 |
+
main()
|
benchmarks/image_gen/GenEVAL/sample_GenEVAL.sh
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
|
| 4 |
+
source "$SCRIPT_DIR/../../sample_env.sh"
|
| 5 |
+
|
| 6 |
+
# ========================= 推理参数配置 =========================
|
| 7 |
+
TASK_NAME="t2i"
|
| 8 |
+
NUM_GPUS=8
|
| 9 |
+
|
| 10 |
+
VALIDATION_NUM_TIMESTEPS=50
|
| 11 |
+
VALIDATION_TIMESTEP_SHIFT=3.5
|
| 12 |
+
EVALUATION_SEED=42
|
| 13 |
+
CFG_TEXT_SCALE=4.0
|
| 14 |
+
CFG_INTERVAL_START=0.4
|
| 15 |
+
CFG_INTERVAL_END=1.0
|
| 16 |
+
SAMPLE_NUM_PER_PROMPT=4
|
| 17 |
+
USE_KVCACHE=true
|
| 18 |
+
|
| 19 |
+
VIDEO_HEIGHT=768
|
| 20 |
+
VIDEO_WIDTH=768
|
| 21 |
+
|
| 22 |
+
MODEL_PATH="downloads/Lance_3B"
|
| 23 |
+
VAL_DATASET_CONFIG_FILE="benchmarks/image_gen/GenEVAL/GenEVAL.jsonl"
|
| 24 |
+
|
| 25 |
+
# ========================= 自动生成路径 =========================
|
| 26 |
+
TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
|
| 27 |
+
KVCACHE_TAG=""
|
| 28 |
+
if [ "$USE_KVCACHE" = "true" ]; then
|
| 29 |
+
KVCACHE_TAG="kvcache_"
|
| 30 |
+
fi
|
| 31 |
+
SAVE_PATH_GEN="results/GenEVAL_ts${VALIDATION_NUM_TIMESTEPS}_tss${VALIDATION_TIMESTEP_SHIFT}_seed${EVALUATION_SEED}_cfg${CFG_TEXT_SCALE}_${KVCACHE_TAG}${TIMESTAMP}"
|
| 32 |
+
|
| 33 |
+
if [ -z "$MODEL_PATH" ]; then
|
| 34 |
+
echo "错误: 请在脚本顶部配置区手动设置 MODEL_PATH"
|
| 35 |
+
exit 1
|
| 36 |
+
fi
|
| 37 |
+
|
| 38 |
+
# ============================== 环境与分布式配置 ==============================
|
| 39 |
+
lance_setup_common_env
|
| 40 |
+
lance_setup_distributed_env "$NUM_GPUS"
|
| 41 |
+
lance_setup_shard_env 1
|
| 42 |
+
|
| 43 |
+
# ========================= 显示任务配置 =========================
|
| 44 |
+
echo "================================================"
|
| 45 |
+
echo "GenEVAL T2I 推理"
|
| 46 |
+
echo "================================================"
|
| 47 |
+
echo "GPU数量: ${NUM_GPUS}"
|
| 48 |
+
echo "保存路径: ${SAVE_PATH_GEN}"
|
| 49 |
+
echo "分辨率: ${VIDEO_HEIGHT}x${VIDEO_WIDTH}"
|
| 50 |
+
echo "模型路径: ${MODEL_PATH}"
|
| 51 |
+
if [ -n "$VAL_DATASET_CONFIG_FILE" ]; then
|
| 52 |
+
echo "数据路径: ${VAL_DATASET_CONFIG_FILE}"
|
| 53 |
+
fi
|
| 54 |
+
echo ""
|
| 55 |
+
echo "关键参数:"
|
| 56 |
+
echo " - validation_num_timesteps: ${VALIDATION_NUM_TIMESTEPS}"
|
| 57 |
+
echo " - validation_timestep_shift: ${VALIDATION_TIMESTEP_SHIFT}"
|
| 58 |
+
echo " - evaluation_seed: ${EVALUATION_SEED}"
|
| 59 |
+
echo " - cfg_text_scale: ${CFG_TEXT_SCALE}"
|
| 60 |
+
echo " - cfg_interval: [${CFG_INTERVAL_START}, ${CFG_INTERVAL_END}]"
|
| 61 |
+
echo " - sample_num_per_prompt: ${SAMPLE_NUM_PER_PROMPT}"
|
| 62 |
+
echo " - use_KVcache: ${USE_KVCACHE}"
|
| 63 |
+
echo "================================================"
|
| 64 |
+
echo ""
|
| 65 |
+
|
| 66 |
+
# ============================== 执行推理 ==============================
|
| 67 |
+
# 注意:请直接修改本脚本顶部的“推理参数配置”区
|
| 68 |
+
accelerate launch \
|
| 69 |
+
--num_machines $NUM_MACHINES \
|
| 70 |
+
--num_processes $TOTAL_RANK \
|
| 71 |
+
--machine_rank $MACHINE_RANK \
|
| 72 |
+
--main_process_ip $MAIN_PROCESS_IP \
|
| 73 |
+
--main_process_port $MAIN_PROCESS_PORT \
|
| 74 |
+
--mixed_precision bf16 \
|
| 75 |
+
benchmarks/image_gen/GenEVAL/sample_GenEVAL.py \
|
| 76 |
+
--model_path "$MODEL_PATH" \
|
| 77 |
+
--val_dataset_config_file "$VAL_DATASET_CONFIG_FILE" \
|
| 78 |
+
--vit_type qwen_2_5_vl_original \
|
| 79 |
+
--llm_qk_norm true \
|
| 80 |
+
--llm_qk_norm_und true \
|
| 81 |
+
--llm_qk_norm_gen true \
|
| 82 |
+
--tie_word_embeddings false \
|
| 83 |
+
--validation_num_timesteps $VALIDATION_NUM_TIMESTEPS \
|
| 84 |
+
--validation_timestep_shift $VALIDATION_TIMESTEP_SHIFT \
|
| 85 |
+
--copy_init_moe true \
|
| 86 |
+
--max_num_frames 1 \
|
| 87 |
+
--max_latent_size 64 \
|
| 88 |
+
--latent_patch_size 1 1 1 \
|
| 89 |
+
--visual_und true \
|
| 90 |
+
--visual_gen true \
|
| 91 |
+
--vae_model_type wan \
|
| 92 |
+
--apply_qwen_2_5_vl_pos_emb true \
|
| 93 |
+
--apply_chat_template false \
|
| 94 |
+
--cfg_type 0 \
|
| 95 |
+
--validation_data_seed $EVALUATION_SEED \
|
| 96 |
+
--video_height $VIDEO_HEIGHT \
|
| 97 |
+
--video_width $VIDEO_WIDTH \
|
| 98 |
+
--task $TASK_NAME \
|
| 99 |
+
--save_path_gen $SAVE_PATH_GEN \
|
| 100 |
+
--resolution image_768res \
|
| 101 |
+
--text_template true \
|
| 102 |
+
--sample_num_per_prompt $SAMPLE_NUM_PER_PROMPT \
|
| 103 |
+
--cfg_text_scale $CFG_TEXT_SCALE \
|
| 104 |
+
--cfg_interval $CFG_INTERVAL_START $CFG_INTERVAL_END \
|
| 105 |
+
--use_KVcache $USE_KVCACHE
|
| 106 |
+
|
| 107 |
+
echo ""
|
| 108 |
+
echo "================================================"
|
| 109 |
+
echo "完成! 结果: ${SAVE_PATH_GEN}"
|
| 110 |
+
echo "================================================"
|
benchmarks/sample_env.sh
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
find_available_port() {
|
| 4 |
+
local start_port="${1:-6666}"
|
| 5 |
+
local end_port="${2:-8888}"
|
| 6 |
+
|
| 7 |
+
python3 - "$start_port" "$end_port" <<'PY'
|
| 8 |
+
import socket
|
| 9 |
+
import sys
|
| 10 |
+
|
| 11 |
+
start_port = int(sys.argv[1])
|
| 12 |
+
end_port = int(sys.argv[2])
|
| 13 |
+
|
| 14 |
+
for port in range(start_port, end_port):
|
| 15 |
+
try:
|
| 16 |
+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
| 17 |
+
sock.bind(("", port))
|
| 18 |
+
sock.close()
|
| 19 |
+
print(port)
|
| 20 |
+
raise SystemExit(0)
|
| 21 |
+
except OSError:
|
| 22 |
+
continue
|
| 23 |
+
|
| 24 |
+
print(start_port)
|
| 25 |
+
PY
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
lance_setup_common_env() {
|
| 30 |
+
export EXP_HW_20250819="${EXP_HW_20250819:-False}"
|
| 31 |
+
echo "EXP_HW_20250819: $EXP_HW_20250819"
|
| 32 |
+
|
| 33 |
+
export POSITION_EMBEDDING_3D_VERSION="${POSITION_EMBEDDING_3D_VERSION:-v2}"
|
| 34 |
+
echo "(shell) POSITION_EMBEDDING_3D_VERSION: $POSITION_EMBEDDING_3D_VERSION"
|
| 35 |
+
|
| 36 |
+
# Default to async CUDA execution for benchmark/inference throughput.
|
| 37 |
+
# Override with CUDA_LAUNCH_BLOCKING=1 only when debugging kernel failures.
|
| 38 |
+
export CUDA_LAUNCH_BLOCKING="${CUDA_LAUNCH_BLOCKING:-0}"
|
| 39 |
+
export NCCL_DEBUG="${NCCL_DEBUG:-VERSION}"
|
| 40 |
+
export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC="${TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC:-900}"
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
lance_setup_distributed_env() {
|
| 45 |
+
local num_gpus="${1:-1}"
|
| 46 |
+
local default_main_process_port
|
| 47 |
+
local has_explicit_main_process_port=0
|
| 48 |
+
|
| 49 |
+
NUM_GPUS="$num_gpus"
|
| 50 |
+
|
| 51 |
+
if [ -n "$MAIN_PROCESS_PORT" ]; then
|
| 52 |
+
has_explicit_main_process_port=1
|
| 53 |
+
fi
|
| 54 |
+
|
| 55 |
+
if [ -n "${ARNOLD_WORKER_NUM:-}" ]; then
|
| 56 |
+
echo "使用平台分布式环境"
|
| 57 |
+
NUM_MACHINES="${NUM_MACHINES:-$ARNOLD_WORKER_NUM}"
|
| 58 |
+
MACHINE_RANK="${MACHINE_RANK:-${ARNOLD_ID:-0}}"
|
| 59 |
+
MAIN_PROCESS_IP="${MAIN_PROCESS_IP:-${ARNOLD_WORKER_0_HOST:-127.0.0.1}}"
|
| 60 |
+
default_main_process_port="${ARNOLD_WORKER_0_PORT:-6666}"
|
| 61 |
+
|
| 62 |
+
if [ "$has_explicit_main_process_port" -eq 1 ]; then
|
| 63 |
+
:
|
| 64 |
+
elif [ "${NUM_MACHINES}" = "1" ]; then
|
| 65 |
+
MAIN_PROCESS_PORT="$(find_available_port "$default_main_process_port" "$((default_main_process_port + 500))")"
|
| 66 |
+
else
|
| 67 |
+
MAIN_PROCESS_PORT="$default_main_process_port"
|
| 68 |
+
echo "多机任务使用平台 rendezvous 端口: $MAIN_PROCESS_PORT"
|
| 69 |
+
fi
|
| 70 |
+
else
|
| 71 |
+
echo "使用本地或显式配置的分布式环境"
|
| 72 |
+
NUM_MACHINES="${NUM_MACHINES:-1}"
|
| 73 |
+
MACHINE_RANK="${MACHINE_RANK:-0}"
|
| 74 |
+
MAIN_PROCESS_IP="${MAIN_PROCESS_IP:-127.0.0.1}"
|
| 75 |
+
default_main_process_port=6666
|
| 76 |
+
|
| 77 |
+
if [ "$has_explicit_main_process_port" -eq 1 ]; then
|
| 78 |
+
:
|
| 79 |
+
else
|
| 80 |
+
MAIN_PROCESS_PORT="$(find_available_port "$default_main_process_port" "$((default_main_process_port + 500))")"
|
| 81 |
+
fi
|
| 82 |
+
fi
|
| 83 |
+
|
| 84 |
+
TOTAL_RANK=$((NUM_MACHINES * NUM_GPUS))
|
| 85 |
+
|
| 86 |
+
export NUM_GPUS NUM_MACHINES MACHINE_RANK MAIN_PROCESS_IP MAIN_PROCESS_PORT TOTAL_RANK
|
| 87 |
+
|
| 88 |
+
echo "NUM_MACHINES: $NUM_MACHINES"
|
| 89 |
+
echo "NUM_GPUS: $NUM_GPUS"
|
| 90 |
+
echo "TOTAL_RANK: $TOTAL_RANK"
|
| 91 |
+
echo "MACHINE_RANK: $MACHINE_RANK"
|
| 92 |
+
echo "MAIN_PROCESS_IP: $MAIN_PROCESS_IP"
|
| 93 |
+
echo "MAIN_PROCESS_PORT: $MAIN_PROCESS_PORT"
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
lance_setup_shard_env() {
|
| 98 |
+
local num_shard="${1:-1}"
|
| 99 |
+
|
| 100 |
+
NUM_SHARD="$num_shard"
|
| 101 |
+
NUM_REPLICATE=$((TOTAL_RANK / NUM_SHARD))
|
| 102 |
+
|
| 103 |
+
export NUM_SHARD NUM_REPLICATE
|
| 104 |
+
|
| 105 |
+
echo "NUM_REPLICATE: $NUM_REPLICATE"
|
| 106 |
+
echo "NUM_SHARD: $NUM_SHARD"
|
| 107 |
+
}
|
benchmarks/video_gen/Vbench/README.md
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[Chinese Version](./README_zh.md)
|
| 2 |
+
|
| 3 |
+
# VBench Video Generation Evaluation
|
| 4 |
+
|
| 5 |
+
Benchmark evaluation scripts for VBench based on the Lance model.
|
| 6 |
+
|
| 7 |
+
## Files
|
| 8 |
+
|
| 9 |
+
- `sample_vbench.py` - Python inference script
|
| 10 |
+
- `sample_vbench.sh` - Launch script (recommended)
|
| 11 |
+
- `Vbench_recaption.jsonl` - Evaluation dataset
|
| 12 |
+
|
| 13 |
+
## Quick Start
|
| 14 |
+
|
| 15 |
+
### Basic Usage
|
| 16 |
+
|
| 17 |
+
```bash
|
| 18 |
+
bash sample_vbench.sh
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
Before running, edit the "Inference Parameters" section at the top of `benchmarks/video_gen/Vbench/sample_vbench.sh`.
|
| 22 |
+
|
| 23 |
+
## Parameters
|
| 24 |
+
|
| 25 |
+
| Parameter | Default | Description |
|
| 26 |
+
|------|--------|------|
|
| 27 |
+
| `TASK_NAME` | `t2v` | Task type. VBench is fixed to video generation. |
|
| 28 |
+
| `VALIDATION_NUM_TIMESTEPS` | 50 | Number of inference steps. |
|
| 29 |
+
| `VALIDATION_TIMESTEP_SHIFT` | 3.5 | Timestep shift. |
|
| 30 |
+
| `EVALUATION_SEED` | 42 | Random seed. |
|
| 31 |
+
| `CFG_TEXT_SCALE` | 4.0 | CFG scale. |
|
| 32 |
+
| `CFG_INTERVAL_START` | 0.4 | Start of the CFG interval. |
|
| 33 |
+
| `CFG_INTERVAL_END` | 1.0 | End of the CFG interval. |
|
| 34 |
+
| `SAMPLE_NUM_PER_PROMPT` | 5 | Number of videos generated for each regular prompt. |
|
| 35 |
+
| `USE_KVCACHE` | `true` | Whether to enable KV cache. |
|
| 36 |
+
| `NUM_GPUS` | 8 | Number of GPUs. |
|
| 37 |
+
| `VIDEO_HEIGHT`/`VIDEO_WIDTH` | 480 | Video resolution. |
|
| 38 |
+
| `NUM_FRAMES` | 50 | Number of output video frames. |
|
| 39 |
+
| `MAX_NUM_FRAMES` | 121 | Maximum number of frames per sample. |
|
| 40 |
+
| `MAX_LATENT_SIZE` | 64 | Maximum latent size. |
|
| 41 |
+
| `RESOLUTION` | `video_480p` | Dataset resolution tag. |
|
| 42 |
+
| `MODEL_PATH` | `downloads/Lance_3B_Video` | Path to the Lance checkpoint. |
|
| 43 |
+
| `VAL_DATASET_CONFIG_FILE` | `benchmarks/video_gen/Vbench/Vbench_recaption.jsonl` | Path to the evaluation data. |
|
| 44 |
+
| `CONFIG_JSON_PATH` | `""` | Optional training configuration JSON. |
|
| 45 |
+
|
| 46 |
+
## How To Modify
|
| 47 |
+
|
| 48 |
+
- Edit the "Inference Parameters" section at the top of `benchmarks/video_gen/Vbench/sample_vbench.sh`.
|
| 49 |
+
- After updating the parameters, run `bash benchmarks/video_gen/Vbench/sample_vbench.sh` directly.
|
| 50 |
+
- `SAVE_PATH_GEN` is generated automatically from the script parameters and does not need to be set manually.
|
| 51 |
+
|
| 52 |
+
## Output Format
|
| 53 |
+
|
| 54 |
+
Results are saved in a structure like this:
|
| 55 |
+
|
| 56 |
+
```
|
| 57 |
+
results/Vbench_ts50_tss3.5_seed42_cfg4.0_kvcache_20260507_120000/
|
| 58 |
+
├── In a still frame, a stop sign-0.mp4
|
| 59 |
+
├── In a still frame, a stop sign-1.mp4
|
| 60 |
+
├── a toilet, frozen in time-0.mp4
|
| 61 |
+
├── ...
|
| 62 |
+
├── prompt.json
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
Each prompt generates `SAMPLE_NUM_PER_PROMPT` videos by default, named as `original-prompt-sample-index.mp4`. A `prompt.json` file is also written to record the generated text.
|
| 66 |
+
If `temporal_flickering_prompts.json` exists in the repository, the corresponding prompts automatically use a larger sample count. If the file does not exist, the script directly uses `SAMPLE_NUM_PER_PROMPT`.
|
| 67 |
+
|
| 68 |
+
## Notes
|
| 69 |
+
|
| 70 |
+
- If you need to switch the model, dataset, frame count, or resolution, edit the script configuration at the top directly.
|
| 71 |
+
- The ViT path is resolved automatically by the code and usually does not need to be configured separately.
|
| 72 |
+
- `CONFIG_JSON_PATH` is only passed through as an optional training configuration JSON and does not override the other explicit script parameters.
|
benchmarks/video_gen/Vbench/README_zh.md
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[English Version](./README.md)
|
| 2 |
+
|
| 3 |
+
# VBench 视频生成评估
|
| 4 |
+
|
| 5 |
+
基于 Lance 模型的 VBench 评估基准测试脚本。
|
| 6 |
+
|
| 7 |
+
## 文件说明
|
| 8 |
+
|
| 9 |
+
- `sample_vbench.py` - 推理 Python 脚本
|
| 10 |
+
- `sample_vbench.sh` - 启动脚本(推荐使用)
|
| 11 |
+
- `Vbench_recaption.jsonl` - 评估数据集
|
| 12 |
+
|
| 13 |
+
## 快速开始
|
| 14 |
+
|
| 15 |
+
### 基本用法
|
| 16 |
+
|
| 17 |
+
```bash
|
| 18 |
+
bash sample_vbench.sh
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
运行前请直接修改 `benchmarks/video_gen/Vbench/sample_vbench.sh` 顶部的“推理参数配置”区。
|
| 22 |
+
|
| 23 |
+
## 参数说明
|
| 24 |
+
|
| 25 |
+
| 参数 | 默认值 | 说明 |
|
| 26 |
+
|------|--------|------|
|
| 27 |
+
| `TASK_NAME` | `t2v` | 任务类型,VBench 固定为视频生成 |
|
| 28 |
+
| `VALIDATION_NUM_TIMESTEPS` | 50 | 推理步数 |
|
| 29 |
+
| `VALIDATION_TIMESTEP_SHIFT` | 3.5 | Timestep shift |
|
| 30 |
+
| `EVALUATION_SEED` | 42 | 随机种子 |
|
| 31 |
+
| `CFG_TEXT_SCALE` | 4.0 | CFG scale |
|
| 32 |
+
| `CFG_INTERVAL_START` | 0.4 | CFG 区间起点 |
|
| 33 |
+
| `CFG_INTERVAL_END` | 1.0 | CFG 区间终点 |
|
| 34 |
+
| `SAMPLE_NUM_PER_PROMPT` | 5 | 每个普通 prompt 生成的视频数量 |
|
| 35 |
+
| `USE_KVCACHE` | `true` | 是否启用 KV cache |
|
| 36 |
+
| `NUM_GPUS` | 8 | GPU 数量 |
|
| 37 |
+
| `VIDEO_HEIGHT`/`VIDEO_WIDTH` | 480 | 视频分辨率 |
|
| 38 |
+
| `NUM_FRAMES` | 50 | 输出视频帧数 |
|
| 39 |
+
| `MAX_NUM_FRAMES` | 121 | 单个样本最大帧数 |
|
| 40 |
+
| `MAX_LATENT_SIZE` | 64 | latent size 上限 |
|
| 41 |
+
| `RESOLUTION` | `video_480p` | 数据集分辨率标签 |
|
| 42 |
+
| `MODEL_PATH` | `downloads/Lance_3B_Video` | Lance checkpoint 路径 |
|
| 43 |
+
| `VAL_DATASET_CONFIG_FILE` | `benchmarks/video_gen/Vbench/Vbench_recaption.jsonl` | 评估数据路径 |
|
| 44 |
+
| `CONFIG_JSON_PATH` | `""` | 可选训练配置 JSON |
|
| 45 |
+
|
| 46 |
+
## 修改方式
|
| 47 |
+
|
| 48 |
+
- 请手动编辑 `benchmarks/video_gen/Vbench/sample_vbench.sh` 顶部的“推理参数配置”区。
|
| 49 |
+
- 修改完成后,直接运行 `bash benchmarks/video_gen/Vbench/sample_vbench.sh`。
|
| 50 |
+
- `SAVE_PATH_GEN` 由脚本根据顶部参数自动生成,不需要手动设置。
|
| 51 |
+
|
| 52 |
+
## 保存格式
|
| 53 |
+
|
| 54 |
+
结果会按照以下结构保存:
|
| 55 |
+
|
| 56 |
+
```
|
| 57 |
+
results/Vbench_ts50_tss3.5_seed42_cfg4.0_kvcache_20260507_120000/
|
| 58 |
+
├── In a still frame, a stop sign-0.mp4
|
| 59 |
+
├── In a still frame, a stop sign-1.mp4
|
| 60 |
+
├── a toilet, frozen in time-0.mp4
|
| 61 |
+
├── ...
|
| 62 |
+
├── prompt.json
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
每个 prompt 默认生成 `SAMPLE_NUM_PER_PROMPT` 个视频,并按 `原始 prompt-采样序号.mp4` 命名;同时会额外写出 `prompt.json` 记录生成文本。
|
| 66 |
+
如果仓库中存在 `temporal_flickering_prompts.json`,对应 prompt 会自动提升采样数;当前文件不存在时,脚本会直接使用 `SAMPLE_NUM_PER_PROMPT`。
|
| 67 |
+
|
| 68 |
+
## 注意事项
|
| 69 |
+
|
| 70 |
+
- 如果需要切换模型、数据集、帧数或分辨率,请直接修改脚本顶部配置。
|
| 71 |
+
- ViT 路径默认由代码内部自动解析,无需单独配置。
|
| 72 |
+
- `CONFIG_JSON_PATH` 仅作为可选训练配置 JSON 传入,不会替代脚本顶部其它显式参数。
|
benchmarks/video_gen/Vbench/Vbench_recaption.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
benchmarks/video_gen/Vbench/sample_vbench.py
ADDED
|
@@ -0,0 +1,559 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# coding: utf-8
|
| 15 |
+
|
| 16 |
+
import warnings
|
| 17 |
+
warnings.filterwarnings("ignore", message=".*pkg_resources is deprecated.*", category=UserWarning)
|
| 18 |
+
warnings.filterwarnings("ignore", category=FutureWarning, module="diffusers.models.transformers.transformer_2d")
|
| 19 |
+
import os
|
| 20 |
+
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
|
| 21 |
+
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
| 22 |
+
|
| 23 |
+
import json
|
| 24 |
+
import os.path as osp
|
| 25 |
+
from copy import deepcopy
|
| 26 |
+
from dataclasses import asdict, fields
|
| 27 |
+
from pathlib import Path
|
| 28 |
+
from typing import Optional, Tuple, cast
|
| 29 |
+
|
| 30 |
+
import imageio
|
| 31 |
+
import torch
|
| 32 |
+
import torch.distributed as dist
|
| 33 |
+
from safetensors.torch import load_file
|
| 34 |
+
from torch.utils.data import DataLoader
|
| 35 |
+
from tqdm import trange
|
| 36 |
+
from transformers import HfArgumentParser, set_seed
|
| 37 |
+
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLVisionConfig
|
| 38 |
+
|
| 39 |
+
from config.config_factory import (
|
| 40 |
+
DataArguments,
|
| 41 |
+
EvaluationArguments,
|
| 42 |
+
ModelArguments,
|
| 43 |
+
TrainingArguments,
|
| 44 |
+
get_model_path,
|
| 45 |
+
)
|
| 46 |
+
from common.model.hacks import hack_qwen2_5_vl_config
|
| 47 |
+
from common.utils.misc import AutoEncoderParams, tuple_mul
|
| 48 |
+
from common.val.utils import decode_video_tensor, make_padded_latent
|
| 49 |
+
from data.dataset_base import DataConfig, simple_custom_collate
|
| 50 |
+
from data.data_utils import add_special_tokens
|
| 51 |
+
from data.datasets_custom import ValidationDataset
|
| 52 |
+
from modeling.lance import Lance, LanceConfig, Qwen2ForCausalLM
|
| 53 |
+
from modeling.qwen2 import Qwen2Tokenizer
|
| 54 |
+
from modeling.qwen2.modeling_qwen2 import Qwen2Config
|
| 55 |
+
from modeling.vae.wan.model import WanVideoVAE
|
| 56 |
+
from modeling.vit.qwen2_5_vl_vit import Qwen2_5_VisionTransformerPretrainedModel
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
PROMPT_JSON_FILENAME = "prompt.json"
|
| 60 |
+
TEMPORAL_FLICKERING_SAMPLE_NUM = 25
|
| 61 |
+
DEFAULT_VBENCH_DATA = "benchmarks/video_gen/Vbench/Vbench_recaption.jsonl"
|
| 62 |
+
TEMPORAL_FLICKERING_PROMPT_FILE = (
|
| 63 |
+
Path(__file__).resolve().parent / "temporal_flickering_prompts.json"
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def load_temporal_flickering_prompts() -> set[str]:
|
| 68 |
+
if not TEMPORAL_FLICKERING_PROMPT_FILE.exists():
|
| 69 |
+
warnings.warn(
|
| 70 |
+
f"Temporal flickering prompt file not found: {TEMPORAL_FLICKERING_PROMPT_FILE}. "
|
| 71 |
+
"Falling back to an empty prompt set.",
|
| 72 |
+
stacklevel=2,
|
| 73 |
+
)
|
| 74 |
+
return set()
|
| 75 |
+
|
| 76 |
+
with TEMPORAL_FLICKERING_PROMPT_FILE.open("r", encoding="utf-8") as f:
|
| 77 |
+
data = json.load(f)
|
| 78 |
+
|
| 79 |
+
return set(data)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
PROMPT_WITH_TEMPORAL_FLICKERING = load_temporal_flickering_prompts()
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def clean_memory(*objects):
|
| 86 |
+
for obj in objects:
|
| 87 |
+
del obj
|
| 88 |
+
import gc
|
| 89 |
+
|
| 90 |
+
gc.collect()
|
| 91 |
+
if torch.cuda.is_available():
|
| 92 |
+
torch.cuda.empty_cache()
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def init_from_model_path_if_needed(
|
| 96 |
+
model: Qwen2ForCausalLM,
|
| 97 |
+
model_args: ModelArguments,
|
| 98 |
+
):
|
| 99 |
+
path_dir = model_args.model_path
|
| 100 |
+
ema_path = osp.join(path_dir, "ema.safetensors")
|
| 101 |
+
model_path = osp.join(path_dir, "model.safetensors")
|
| 102 |
+
|
| 103 |
+
model_path_ft = None
|
| 104 |
+
if osp.exists(model_path):
|
| 105 |
+
model_path_ft = model_path
|
| 106 |
+
elif osp.exists(ema_path):
|
| 107 |
+
model_path_ft = ema_path
|
| 108 |
+
|
| 109 |
+
if model_path_ft:
|
| 110 |
+
model_state_dict = load_file(model_path_ft, device="cpu")
|
| 111 |
+
else:
|
| 112 |
+
raise FileNotFoundError(
|
| 113 |
+
f"Fine-tuning failed: No valid checkpoint ('ema.safetensors' or 'model.safetensors') found in {path_dir}"
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
if "latent_pos_embed.pos_embed" in model_state_dict:
|
| 117 |
+
model_state_dict.pop("latent_pos_embed.pos_embed")
|
| 118 |
+
|
| 119 |
+
model.load_state_dict(model_state_dict, strict=False)
|
| 120 |
+
clean_memory(model_state_dict)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def resolve_vbench_paths(
|
| 124 |
+
model_args: ModelArguments,
|
| 125 |
+
data_args: DataArguments,
|
| 126 |
+
) -> None:
|
| 127 |
+
if not model_args.model_path:
|
| 128 |
+
raise ValueError("VBench requires --model_path to be provided explicitly.")
|
| 129 |
+
|
| 130 |
+
if not getattr(model_args, "llm_path", ""):
|
| 131 |
+
model_args.llm_path = model_args.model_path
|
| 132 |
+
|
| 133 |
+
if not model_args.vit_path:
|
| 134 |
+
model_args.vit_path = get_model_path("vit.qwen2_5_vl")
|
| 135 |
+
|
| 136 |
+
if not data_args.val_dataset_config_file:
|
| 137 |
+
data_args.val_dataset_config_file = DEFAULT_VBENCH_DATA
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def build_runtime_dataset_config(
|
| 141 |
+
model_args: ModelArguments,
|
| 142 |
+
training_args: TrainingArguments,
|
| 143 |
+
inference_args: EvaluationArguments,
|
| 144 |
+
vae_config: Optional[AutoEncoderParams],
|
| 145 |
+
) -> DataConfig:
|
| 146 |
+
dataset_config = DataConfig()
|
| 147 |
+
|
| 148 |
+
dataset_config.num_frames = inference_args.num_frames
|
| 149 |
+
dataset_config.H = inference_args.video_height
|
| 150 |
+
dataset_config.W = inference_args.video_width
|
| 151 |
+
dataset_config.task = inference_args.task
|
| 152 |
+
dataset_config.resolution = inference_args.resolution
|
| 153 |
+
dataset_config.text_template = inference_args.text_template
|
| 154 |
+
dataset_config.max_duration = inference_args.max_duration
|
| 155 |
+
dataset_config.system_prompt_type = inference_args.system_prompt_type
|
| 156 |
+
|
| 157 |
+
if training_args.visual_und:
|
| 158 |
+
dataset_config.vit_patch_size = model_args.vit_patch_size
|
| 159 |
+
dataset_config.vit_patch_size_temporal = model_args.vit_patch_size_temporal
|
| 160 |
+
dataset_config.vit_max_num_patch_per_side = model_args.vit_max_num_patch_per_side
|
| 161 |
+
|
| 162 |
+
if training_args.visual_gen and vae_config:
|
| 163 |
+
assert len(model_args.latent_patch_size) == 3, "len(latent_patch_size) must be 3"
|
| 164 |
+
dataset_config.latent_patch_size = model_args.latent_patch_size
|
| 165 |
+
dataset_config.vae_downsample = tuple_mul(
|
| 166 |
+
model_args.latent_patch_size,
|
| 167 |
+
(vae_config.downsample_temporal, vae_config.downsample_spatial, vae_config.downsample_spatial),
|
| 168 |
+
)
|
| 169 |
+
dataset_config.max_latent_size = model_args.max_latent_size
|
| 170 |
+
dataset_config.max_num_frames = model_args.max_num_frames
|
| 171 |
+
|
| 172 |
+
dataset_config.text_cond_dropout_prob = model_args.text_cond_dropout_prob
|
| 173 |
+
dataset_config.vae_cond_dropout_prob = model_args.vae_cond_dropout_prob
|
| 174 |
+
dataset_config.vit_cond_dropout_prob = model_args.vit_cond_dropout_prob
|
| 175 |
+
|
| 176 |
+
return dataset_config
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def save_prompt_results(prompt_data_dict, save_path_gen: str):
|
| 180 |
+
prompt_json_path = os.path.join(save_path_gen, PROMPT_JSON_FILENAME)
|
| 181 |
+
with open(prompt_json_path, "w", encoding="utf-8") as f:
|
| 182 |
+
json.dump(prompt_data_dict, f, ensure_ascii=False, indent=2)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def safe_instantiate(cls, cfg: dict, name: str):
|
| 186 |
+
valid_keys = {f.name for f in fields(cls)}
|
| 187 |
+
valid, invalid = {}, {}
|
| 188 |
+
for k, v in cfg.items():
|
| 189 |
+
if k in valid_keys:
|
| 190 |
+
valid[k] = v
|
| 191 |
+
else:
|
| 192 |
+
invalid[k] = v
|
| 193 |
+
|
| 194 |
+
if invalid:
|
| 195 |
+
print(f"[WARN] {name} 过滤无效参数: {invalid}")
|
| 196 |
+
return cls(**valid)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def is_valid_value(value):
|
| 200 |
+
return value is not None
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def merge_args(original_args, override_args):
|
| 204 |
+
merged_dict = asdict(original_args)
|
| 205 |
+
override_dict = asdict(override_args)
|
| 206 |
+
|
| 207 |
+
for key, value in override_dict.items():
|
| 208 |
+
if is_valid_value(value):
|
| 209 |
+
merged_dict[key] = value
|
| 210 |
+
|
| 211 |
+
return original_args.__class__(**merged_dict)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def apply_config_json_overrides(
|
| 215 |
+
model_args: ModelArguments,
|
| 216 |
+
data_args: DataArguments,
|
| 217 |
+
inference_args: EvaluationArguments,
|
| 218 |
+
):
|
| 219 |
+
if not inference_args.config_json_path or not inference_args.config_json_path.endswith(".json"):
|
| 220 |
+
return model_args, data_args, inference_args
|
| 221 |
+
|
| 222 |
+
model_path_original = model_args.model_path
|
| 223 |
+
val_dataset_config_file_original = data_args.val_dataset_config_file
|
| 224 |
+
|
| 225 |
+
with open(inference_args.config_json_path, "r", encoding="utf-8") as f:
|
| 226 |
+
config = json.load(f)
|
| 227 |
+
|
| 228 |
+
if "model_args" in config:
|
| 229 |
+
model_args = merge_args(
|
| 230 |
+
model_args,
|
| 231 |
+
safe_instantiate(ModelArguments, config["model_args"], "ModelArguments"),
|
| 232 |
+
)
|
| 233 |
+
if "data_args" in config:
|
| 234 |
+
data_args = merge_args(
|
| 235 |
+
data_args,
|
| 236 |
+
safe_instantiate(DataArguments, config["data_args"], "DataArguments"),
|
| 237 |
+
)
|
| 238 |
+
if "training_args" in config:
|
| 239 |
+
inference_args = merge_args(
|
| 240 |
+
inference_args,
|
| 241 |
+
safe_instantiate(EvaluationArguments, config["training_args"], "EvaluationArguments"),
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
model_args.model_path = model_path_original
|
| 245 |
+
if getattr(model_args, "llm_path", "") == "":
|
| 246 |
+
model_args.llm_path = model_path_original
|
| 247 |
+
data_args.val_dataset_config_file = val_dataset_config_file_original
|
| 248 |
+
return model_args, data_args, inference_args
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def get_sample_num_per_prompt(
|
| 252 |
+
inference_args: EvaluationArguments,
|
| 253 |
+
prompt: str,
|
| 254 |
+
) -> int:
|
| 255 |
+
if prompt in PROMPT_WITH_TEMPORAL_FLICKERING:
|
| 256 |
+
if inference_args.quick_debug:
|
| 257 |
+
return min(inference_args.sample_num_per_prompt, 5)
|
| 258 |
+
return max(inference_args.sample_num_per_prompt, TEMPORAL_FLICKERING_SAMPLE_NUM)
|
| 259 |
+
return inference_args.sample_num_per_prompt
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def validate_on_fixed_batch(
|
| 263 |
+
fsdp_model: Lance,
|
| 264 |
+
vae_model: Optional[WanVideoVAE],
|
| 265 |
+
val_data_cpu: dict,
|
| 266 |
+
training_args: TrainingArguments,
|
| 267 |
+
model_args: ModelArguments,
|
| 268 |
+
inference_args: EvaluationArguments,
|
| 269 |
+
new_token_ids,
|
| 270 |
+
image_token_id: int,
|
| 271 |
+
device: int,
|
| 272 |
+
save_path_gen: str,
|
| 273 |
+
):
|
| 274 |
+
is_rank0 = not dist.is_initialized() or dist.get_rank() == 0
|
| 275 |
+
val_data = val_data_cpu.cuda(device).to_dict()
|
| 276 |
+
|
| 277 |
+
with torch.no_grad(), torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
|
| 278 |
+
if "padded_videos" in val_data:
|
| 279 |
+
val_data["padded_latent"] = make_padded_latent(
|
| 280 |
+
val_data["padded_videos"],
|
| 281 |
+
val_data["vae_data_mode"],
|
| 282 |
+
vae_model,
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
prompt = val_data.get("original_prompt_en") or val_data.get("caption")
|
| 286 |
+
if not prompt:
|
| 287 |
+
raise ValueError("VBench sample requires `original_prompt_en` or `caption` in dataset.")
|
| 288 |
+
|
| 289 |
+
sample_num_per_prompt = get_sample_num_per_prompt(inference_args, prompt)
|
| 290 |
+
loop_iterator = trange(sample_num_per_prompt, disable=(not is_rank0), leave=False, desc="Sampling")
|
| 291 |
+
|
| 292 |
+
for sample_idx in loop_iterator:
|
| 293 |
+
save_name = f"{save_path_gen}/{prompt}-{sample_idx}.mp4"
|
| 294 |
+
if os.path.exists(save_name):
|
| 295 |
+
continue
|
| 296 |
+
|
| 297 |
+
params = {
|
| 298 |
+
"val_packed_text_ids": val_data["packed_text_ids"],
|
| 299 |
+
"val_packed_text_indexes": val_data["packed_text_indexes"],
|
| 300 |
+
"val_sample_lens": val_data["sample_lens"],
|
| 301 |
+
"val_packed_position_ids": val_data["packed_position_ids"],
|
| 302 |
+
"val_split_lens": val_data["split_lens"],
|
| 303 |
+
"val_attn_modes": val_data["attn_modes"],
|
| 304 |
+
"val_sample_N_target": val_data["sample_N_target"],
|
| 305 |
+
"val_packed_vae_token_indexes": val_data["packed_vae_token_indexes"],
|
| 306 |
+
"timestep_shift": training_args.validation_timestep_shift,
|
| 307 |
+
"num_timesteps": training_args.validation_num_timesteps,
|
| 308 |
+
"val_mse_loss_indexes": val_data.get("mse_loss_indexes", None),
|
| 309 |
+
"val_padded_latent": val_data["padded_latent"],
|
| 310 |
+
"video_sizes": val_data["video_sizes"],
|
| 311 |
+
"cfg_text_scale": model_args.cfg_text_scale,
|
| 312 |
+
"cfg_interval": training_args.cfg_interval,
|
| 313 |
+
"cfg_renorm_min": training_args.cfg_renorm_min,
|
| 314 |
+
"cfg_renorm_type": training_args.cfg_renorm_type,
|
| 315 |
+
"device": device,
|
| 316 |
+
"dtype": torch.bfloat16,
|
| 317 |
+
"new_token_ids": new_token_ids,
|
| 318 |
+
"max_samples": training_args.validation_max_samples,
|
| 319 |
+
"validation_noise_seed": training_args.validation_noise_seed + sample_idx,
|
| 320 |
+
"apply_chat_template": training_args.apply_chat_template,
|
| 321 |
+
"apply_qwen_2_5_vl_pos_emb": training_args.apply_qwen_2_5_vl_pos_emb,
|
| 322 |
+
"image_token_id": image_token_id,
|
| 323 |
+
"val_packed_vit_token_indexes": val_data.get("packed_vit_token_indexes", None),
|
| 324 |
+
"val_packed_vit_tokens": val_data.get("packed_vit_tokens", None),
|
| 325 |
+
"vit_video_grid_thw": val_data.get("vit_video_grid_thw", None),
|
| 326 |
+
"vae_video_grid_thw": val_data["vae_video_grid_thw"],
|
| 327 |
+
"video_grid_thw": val_data.get("video_grid_thw", None),
|
| 328 |
+
"caption": val_data.get("caption", None),
|
| 329 |
+
"sample_task": val_data["sample_task"],
|
| 330 |
+
"sample_modality": val_data["sample_modality"],
|
| 331 |
+
"cfg_type": training_args.cfg_type,
|
| 332 |
+
"cfg_uncond_token_id": training_args.cfg_uncond_token_id,
|
| 333 |
+
"index": val_data["index"],
|
| 334 |
+
"val_padded_videos": None,
|
| 335 |
+
}
|
| 336 |
+
|
| 337 |
+
if inference_args.use_KVcache:
|
| 338 |
+
denoise_latent, captions, _, _ = fsdp_model.validation_gen_KVcache(**params)
|
| 339 |
+
else:
|
| 340 |
+
denoise_latent, captions, _, _ = fsdp_model.validation_gen(**params)
|
| 341 |
+
|
| 342 |
+
for i_val, latent in enumerate(denoise_latent):
|
| 343 |
+
v_list = [vae_model.vae_decode([latent_])[0] for latent_ in latent]
|
| 344 |
+
v_thwc = decode_video_tensor(v_list)
|
| 345 |
+
imageio.mimsave(
|
| 346 |
+
save_name,
|
| 347 |
+
v_thwc,
|
| 348 |
+
fps=inference_args.validation_video_saving_fps,
|
| 349 |
+
format="mp4",
|
| 350 |
+
)
|
| 351 |
+
inference_args.prompt_data_dict[os.path.basename(save_name)] = captions[i_val]
|
| 352 |
+
clean_memory(v_list, v_thwc)
|
| 353 |
+
|
| 354 |
+
clean_memory(denoise_latent, captions)
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
def main():
|
| 358 |
+
assert torch.cuda.is_available()
|
| 359 |
+
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
| 360 |
+
dist.init_process_group("nccl")
|
| 361 |
+
global_rank = dist.get_rank()
|
| 362 |
+
world_size = dist.get_world_size()
|
| 363 |
+
else:
|
| 364 |
+
global_rank = 0
|
| 365 |
+
world_size = 1
|
| 366 |
+
|
| 367 |
+
local_rank = global_rank % torch.cuda.device_count()
|
| 368 |
+
device = local_rank
|
| 369 |
+
torch.cuda.set_device(device)
|
| 370 |
+
|
| 371 |
+
parser = HfArgumentParser((ModelArguments, DataArguments, EvaluationArguments))
|
| 372 |
+
model_args, data_args, inference_args = cast(
|
| 373 |
+
Tuple[ModelArguments, DataArguments, EvaluationArguments],
|
| 374 |
+
parser.parse_args_into_dataclasses(),
|
| 375 |
+
)
|
| 376 |
+
training_args = inference_args
|
| 377 |
+
|
| 378 |
+
model_args, data_args, inference_args = apply_config_json_overrides(
|
| 379 |
+
model_args,
|
| 380 |
+
data_args,
|
| 381 |
+
inference_args,
|
| 382 |
+
)
|
| 383 |
+
training_args = inference_args
|
| 384 |
+
resolve_vbench_paths(model_args, data_args)
|
| 385 |
+
|
| 386 |
+
training_args.validation_noise_seed = inference_args.evaluation_seed
|
| 387 |
+
training_args.validation_data_seed = inference_args.evaluation_seed
|
| 388 |
+
|
| 389 |
+
seed = training_args.global_seed * world_size + global_rank
|
| 390 |
+
set_seed(seed)
|
| 391 |
+
log_rank0 = print if global_rank == 0 else (lambda *_: None)
|
| 392 |
+
|
| 393 |
+
llm_config: Qwen2Config = Qwen2Config.from_json_file(osp.join(model_args.model_path, "llm_config.json"))
|
| 394 |
+
|
| 395 |
+
llm_config.layer_module = model_args.layer_module
|
| 396 |
+
llm_config.qk_norm = model_args.llm_qk_norm
|
| 397 |
+
llm_config.qk_norm_und = model_args.llm_qk_norm_und
|
| 398 |
+
llm_config.qk_norm_gen = model_args.llm_qk_norm_gen
|
| 399 |
+
llm_config.tie_word_embeddings = model_args.tie_word_embeddings
|
| 400 |
+
llm_config.freeze_und = training_args.freeze_und
|
| 401 |
+
llm_config.apply_qwen_2_5_vl_pos_emb = training_args.apply_qwen_2_5_vl_pos_emb
|
| 402 |
+
|
| 403 |
+
language_model: Qwen2ForCausalLM = Qwen2ForCausalLM(llm_config)
|
| 404 |
+
|
| 405 |
+
if training_args.visual_und:
|
| 406 |
+
if model_args.vit_type in ("qwen2_5_vl", "qwen_2_5_vl_original"):
|
| 407 |
+
vit_config = Qwen2_5_VLVisionConfig.from_pretrained(model_args.vit_path)
|
| 408 |
+
vit_model = Qwen2_5_VisionTransformerPretrainedModel(vit_config)
|
| 409 |
+
vit_weights = load_file(osp.join(model_args.vit_path, "vit.safetensors"))
|
| 410 |
+
vit_model.load_state_dict(vit_weights, strict=True)
|
| 411 |
+
else:
|
| 412 |
+
raise ValueError(f"Unsupported vit_type: {model_args.vit_type}")
|
| 413 |
+
clean_memory(vit_weights)
|
| 414 |
+
|
| 415 |
+
if training_args.visual_gen:
|
| 416 |
+
vae_model = WanVideoVAE()
|
| 417 |
+
vae_config: Optional[AutoEncoderParams] = deepcopy(vae_model.vae_config)
|
| 418 |
+
else:
|
| 419 |
+
vae_model = None
|
| 420 |
+
vae_config = None
|
| 421 |
+
|
| 422 |
+
config = LanceConfig(
|
| 423 |
+
visual_gen=training_args.visual_gen,
|
| 424 |
+
visual_und=training_args.visual_und,
|
| 425 |
+
llm_config=llm_config,
|
| 426 |
+
vit_config=vit_config if training_args.visual_und else None,
|
| 427 |
+
vae_config=vae_config if training_args.visual_gen else None,
|
| 428 |
+
latent_patch_size=model_args.latent_patch_size,
|
| 429 |
+
max_num_frames=model_args.max_num_frames,
|
| 430 |
+
max_latent_size=model_args.max_latent_size,
|
| 431 |
+
vit_max_num_patch_per_side=model_args.vit_max_num_patch_per_side,
|
| 432 |
+
connector_act=model_args.connector_act,
|
| 433 |
+
interpolate_pos=model_args.interpolate_pos,
|
| 434 |
+
timestep_shift=training_args.timestep_shift,
|
| 435 |
+
)
|
| 436 |
+
model: Lance = Lance(
|
| 437 |
+
language_model=language_model,
|
| 438 |
+
vit_model=vit_model if training_args.visual_und else None,
|
| 439 |
+
vit_type=model_args.vit_type,
|
| 440 |
+
config=config,
|
| 441 |
+
training_args=training_args,
|
| 442 |
+
)
|
| 443 |
+
model = model.to(device)
|
| 444 |
+
|
| 445 |
+
tokenizer: Qwen2Tokenizer = Qwen2Tokenizer.from_pretrained(model_args.model_path)
|
| 446 |
+
|
| 447 |
+
tokenizer, new_token_ids, num_new_tokens = add_special_tokens(tokenizer)
|
| 448 |
+
|
| 449 |
+
if training_args.copy_init_moe:
|
| 450 |
+
language_model.init_moe()
|
| 451 |
+
|
| 452 |
+
init_from_model_path_if_needed(model, model_args)
|
| 453 |
+
|
| 454 |
+
if num_new_tokens > 0:
|
| 455 |
+
model.language_model.resize_token_embeddings(len(tokenizer))
|
| 456 |
+
model.config.llm_config.vocab_size = len(tokenizer)
|
| 457 |
+
model.language_model.config.vocab_size = len(tokenizer)
|
| 458 |
+
|
| 459 |
+
if model_args.vit_type.lower() == "qwen2_5_vl":
|
| 460 |
+
language_model = hack_qwen2_5_vl_config(language_model)
|
| 461 |
+
|
| 462 |
+
image_token_id = language_model.config.video_token_id
|
| 463 |
+
new_token_ids.update({"image_token_id": image_token_id})
|
| 464 |
+
model.update_tokenizer(tokenizer=tokenizer)
|
| 465 |
+
|
| 466 |
+
if model_args.tie_word_embeddings:
|
| 467 |
+
model.language_model.untie_lm_head()
|
| 468 |
+
model.language_model.copy_new_token_rows_to_lm_head(num_new_tokens)
|
| 469 |
+
model_args.tie_word_embeddings = False
|
| 470 |
+
llm_config.tie_word_embeddings = False
|
| 471 |
+
else:
|
| 472 |
+
assert (
|
| 473 |
+
model.language_model.get_input_embeddings().weight.data.data_ptr()
|
| 474 |
+
!= model.language_model.get_output_embeddings().weight.data.data_ptr()
|
| 475 |
+
), "tie_world_embeddings 冲突"
|
| 476 |
+
|
| 477 |
+
model = model.to(device=device, dtype=torch.bfloat16)
|
| 478 |
+
model.eval()
|
| 479 |
+
if vae_model is not None and hasattr(vae_model, "eval"):
|
| 480 |
+
vae_model.eval()
|
| 481 |
+
|
| 482 |
+
dataset_config = build_runtime_dataset_config(
|
| 483 |
+
model_args=model_args,
|
| 484 |
+
training_args=training_args,
|
| 485 |
+
inference_args=inference_args,
|
| 486 |
+
vae_config=vae_config,
|
| 487 |
+
)
|
| 488 |
+
val_dataset = ValidationDataset(
|
| 489 |
+
jsonl_path=data_args.val_dataset_config_file,
|
| 490 |
+
tokenizer=tokenizer,
|
| 491 |
+
data_args=data_args,
|
| 492 |
+
model_args=model_args,
|
| 493 |
+
training_args=training_args,
|
| 494 |
+
new_token_ids=new_token_ids,
|
| 495 |
+
dataset_config=dataset_config,
|
| 496 |
+
local_rank=global_rank,
|
| 497 |
+
world_size=world_size,
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
val_loader = DataLoader(
|
| 501 |
+
val_dataset,
|
| 502 |
+
batch_size=1,
|
| 503 |
+
num_workers=0,
|
| 504 |
+
pin_memory=True,
|
| 505 |
+
collate_fn=simple_custom_collate,
|
| 506 |
+
drop_last=True,
|
| 507 |
+
prefetch_factor=None,
|
| 508 |
+
persistent_workers=False,
|
| 509 |
+
multiprocessing_context=None,
|
| 510 |
+
)
|
| 511 |
+
val_loader_iter = iter(val_loader)
|
| 512 |
+
|
| 513 |
+
if not hasattr(inference_args, "prompt_data_dict"):
|
| 514 |
+
inference_args.prompt_data_dict = {}
|
| 515 |
+
|
| 516 |
+
os.makedirs(inference_args.save_path_gen, exist_ok=True)
|
| 517 |
+
|
| 518 |
+
for _ in trange(
|
| 519 |
+
len(val_loader),
|
| 520 |
+
desc="Validating",
|
| 521 |
+
unit="batch",
|
| 522 |
+
leave=True,
|
| 523 |
+
ncols=80,
|
| 524 |
+
disable=(global_rank != 0),
|
| 525 |
+
):
|
| 526 |
+
val_data_cpu = next(val_loader_iter)
|
| 527 |
+
validate_on_fixed_batch(
|
| 528 |
+
fsdp_model=model,
|
| 529 |
+
vae_model=vae_model,
|
| 530 |
+
val_data_cpu=val_data_cpu,
|
| 531 |
+
training_args=training_args,
|
| 532 |
+
model_args=model_args,
|
| 533 |
+
inference_args=inference_args,
|
| 534 |
+
new_token_ids=new_token_ids,
|
| 535 |
+
image_token_id=image_token_id,
|
| 536 |
+
device=device,
|
| 537 |
+
save_path_gen=inference_args.save_path_gen,
|
| 538 |
+
)
|
| 539 |
+
|
| 540 |
+
if dist.is_initialized():
|
| 541 |
+
dist.barrier()
|
| 542 |
+
gathered = [None for _ in range(dist.get_world_size())]
|
| 543 |
+
dist.all_gather_object(gathered, inference_args.prompt_data_dict)
|
| 544 |
+
|
| 545 |
+
if global_rank == 0:
|
| 546 |
+
merged = {}
|
| 547 |
+
for d in gathered:
|
| 548 |
+
merged.update(d)
|
| 549 |
+
inference_args.prompt_data_dict = merged
|
| 550 |
+
save_prompt_results(inference_args.prompt_data_dict, inference_args.save_path_gen)
|
| 551 |
+
elif global_rank == 0:
|
| 552 |
+
save_prompt_results(inference_args.prompt_data_dict, inference_args.save_path_gen)
|
| 553 |
+
|
| 554 |
+
if dist.is_initialized():
|
| 555 |
+
dist.destroy_process_group()
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
if __name__ == "__main__":
|
| 559 |
+
main()
|
benchmarks/video_gen/Vbench/sample_vbench.sh
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
|
| 4 |
+
source "$SCRIPT_DIR/../../sample_env.sh"
|
| 5 |
+
|
| 6 |
+
# ========================= 推理参数配置 =========================
|
| 7 |
+
TASK_NAME="t2v"
|
| 8 |
+
NUM_GPUS=8
|
| 9 |
+
|
| 10 |
+
VALIDATION_NUM_TIMESTEPS=30 # 30 # 50 # 10 # 30 # 50
|
| 11 |
+
VALIDATION_TIMESTEP_SHIFT=3.0 # 3.5
|
| 12 |
+
EVALUATION_SEED=42
|
| 13 |
+
CFG_TEXT_SCALE=4.0
|
| 14 |
+
CFG_INTERVAL_START=0.4
|
| 15 |
+
CFG_INTERVAL_END=1.0
|
| 16 |
+
SAMPLE_NUM_PER_PROMPT=5
|
| 17 |
+
USE_KVCACHE=true
|
| 18 |
+
|
| 19 |
+
VIDEO_HEIGHT=480
|
| 20 |
+
VIDEO_WIDTH=848
|
| 21 |
+
NUM_FRAMES=50
|
| 22 |
+
MAX_NUM_FRAMES=121
|
| 23 |
+
MAX_LATENT_SIZE=64
|
| 24 |
+
RESOLUTION="video_480p"
|
| 25 |
+
|
| 26 |
+
MODEL_PATH="downloads/Lance_3B_Video"
|
| 27 |
+
VAL_DATASET_CONFIG_FILE="benchmarks/video_gen/Vbench/Vbench_recaption.jsonl"
|
| 28 |
+
|
| 29 |
+
# ========================= 自动生成路径 =========================
|
| 30 |
+
TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
|
| 31 |
+
KVCACHE_TAG=""
|
| 32 |
+
if [ "$USE_KVCACHE" = "true" ]; then
|
| 33 |
+
KVCACHE_TAG="kvcache_"
|
| 34 |
+
fi
|
| 35 |
+
SAVE_PATH_GEN="results/Vbench_ts${VALIDATION_NUM_TIMESTEPS}_tss${VALIDATION_TIMESTEP_SHIFT}_seed${EVALUATION_SEED}_cfg${CFG_TEXT_SCALE}_${KVCACHE_TAG}${TIMESTAMP}"
|
| 36 |
+
|
| 37 |
+
if [ -z "$MODEL_PATH" ]; then
|
| 38 |
+
echo "错误: 请在脚本顶部配置区手动设置 MODEL_PATH"
|
| 39 |
+
exit 1
|
| 40 |
+
fi
|
| 41 |
+
|
| 42 |
+
# ============================== 环境与分布式配置 ==============================
|
| 43 |
+
lance_setup_common_env
|
| 44 |
+
lance_setup_distributed_env "$NUM_GPUS"
|
| 45 |
+
lance_setup_shard_env 1
|
| 46 |
+
|
| 47 |
+
# ========================= 显示任务配置 =========================
|
| 48 |
+
echo "================================================"
|
| 49 |
+
echo "VBench T2V 推理"
|
| 50 |
+
echo "================================================"
|
| 51 |
+
echo "GPU数量: ${NUM_GPUS}"
|
| 52 |
+
echo "保存路径: ${SAVE_PATH_GEN}"
|
| 53 |
+
echo "分辨率: ${VIDEO_HEIGHT}x${VIDEO_WIDTH}"
|
| 54 |
+
echo "输出帧数: ${NUM_FRAMES}"
|
| 55 |
+
echo "最大帧数: ${MAX_NUM_FRAMES}"
|
| 56 |
+
echo "模型路径: ${MODEL_PATH}"
|
| 57 |
+
if [ -n "$VAL_DATASET_CONFIG_FILE" ]; then
|
| 58 |
+
echo "数据路径: ${VAL_DATASET_CONFIG_FILE}"
|
| 59 |
+
fi
|
| 60 |
+
if [ -n "$CONFIG_JSON_PATH" ]; then
|
| 61 |
+
echo "配置JSON: ${CONFIG_JSON_PATH}"
|
| 62 |
+
fi
|
| 63 |
+
echo ""
|
| 64 |
+
echo "关键参数:"
|
| 65 |
+
echo " - validation_num_timesteps: ${VALIDATION_NUM_TIMESTEPS}"
|
| 66 |
+
echo " - validation_timestep_shift: ${VALIDATION_TIMESTEP_SHIFT}"
|
| 67 |
+
echo " - evaluation_seed: ${EVALUATION_SEED}"
|
| 68 |
+
echo " - cfg_text_scale: ${CFG_TEXT_SCALE}"
|
| 69 |
+
echo " - cfg_interval: [${CFG_INTERVAL_START}, ${CFG_INTERVAL_END}]"
|
| 70 |
+
echo " - num_frames: ${NUM_FRAMES}"
|
| 71 |
+
echo " - sample_num_per_prompt: ${SAMPLE_NUM_PER_PROMPT}"
|
| 72 |
+
echo " - use_KVcache: ${USE_KVCACHE}"
|
| 73 |
+
echo "================================================"
|
| 74 |
+
echo ""
|
| 75 |
+
|
| 76 |
+
# ============================== 执行推理 ==============================
|
| 77 |
+
# 注意:请直接修改本脚本顶部的“推理参数配置”区
|
| 78 |
+
accelerate launch \
|
| 79 |
+
--num_machines $NUM_MACHINES \
|
| 80 |
+
--num_processes $TOTAL_RANK \
|
| 81 |
+
--machine_rank $MACHINE_RANK \
|
| 82 |
+
--main_process_ip $MAIN_PROCESS_IP \
|
| 83 |
+
--main_process_port $MAIN_PROCESS_PORT \
|
| 84 |
+
--mixed_precision bf16 \
|
| 85 |
+
benchmarks/video_gen/Vbench/sample_vbench.py \
|
| 86 |
+
--model_path "$MODEL_PATH" \
|
| 87 |
+
--val_dataset_config_file "$VAL_DATASET_CONFIG_FILE" \
|
| 88 |
+
--config_json_path "$CONFIG_JSON_PATH" \
|
| 89 |
+
--vit_type qwen_2_5_vl_original \
|
| 90 |
+
--llm_qk_norm true \
|
| 91 |
+
--llm_qk_norm_und true \
|
| 92 |
+
--llm_qk_norm_gen true \
|
| 93 |
+
--tie_word_embeddings false \
|
| 94 |
+
--validation_num_timesteps $VALIDATION_NUM_TIMESTEPS \
|
| 95 |
+
--validation_timestep_shift $VALIDATION_TIMESTEP_SHIFT \
|
| 96 |
+
--copy_init_moe true \
|
| 97 |
+
--use_flex true \
|
| 98 |
+
--max_num_frames $MAX_NUM_FRAMES \
|
| 99 |
+
--max_latent_size $MAX_LATENT_SIZE \
|
| 100 |
+
--latent_patch_size 1 1 1 \
|
| 101 |
+
--num_replicate $NUM_REPLICATE \
|
| 102 |
+
--num_shard $NUM_SHARD \
|
| 103 |
+
--visual_und true \
|
| 104 |
+
--visual_gen true \
|
| 105 |
+
--vae_model_type wan \
|
| 106 |
+
--apply_qwen_2_5_vl_pos_emb true \
|
| 107 |
+
--apply_chat_template false \
|
| 108 |
+
--cfg_type 0 \
|
| 109 |
+
--validation_video_saving_fps 12 \
|
| 110 |
+
--validation_log_type direct \
|
| 111 |
+
--video_height $VIDEO_HEIGHT \
|
| 112 |
+
--video_width $VIDEO_WIDTH \
|
| 113 |
+
--num_frames $NUM_FRAMES \
|
| 114 |
+
--task $TASK_NAME \
|
| 115 |
+
--save_path_gen $SAVE_PATH_GEN \
|
| 116 |
+
--resolution $RESOLUTION \
|
| 117 |
+
--evaluation_seed $EVALUATION_SEED \
|
| 118 |
+
--text_template true \
|
| 119 |
+
--sample_num_per_prompt $SAMPLE_NUM_PER_PROMPT \
|
| 120 |
+
--cfg_text_scale $CFG_TEXT_SCALE \
|
| 121 |
+
--cfg_interval $CFG_INTERVAL_START $CFG_INTERVAL_END \
|
| 122 |
+
--use_KVcache $USE_KVCACHE
|
| 123 |
+
|
| 124 |
+
echo ""
|
| 125 |
+
echo "================================================"
|
| 126 |
+
echo "完成! 结果: ${SAVE_PATH_GEN}"
|
| 127 |
+
echo "================================================"
|
benchmarks/video_gen/Vbench/temporal_flickering_prompts.json
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
"In a still frame, a stop sign",
|
| 3 |
+
"a toilet, frozen in time",
|
| 4 |
+
"a laptop, frozen in time",
|
| 5 |
+
"A tranquil tableau of alley",
|
| 6 |
+
"A tranquil tableau of bar",
|
| 7 |
+
"A tranquil tableau of barn",
|
| 8 |
+
"A tranquil tableau of bathroom",
|
| 9 |
+
"A tranquil tableau of bedroom",
|
| 10 |
+
"A tranquil tableau of cliff",
|
| 11 |
+
"In a still frame, courtyard",
|
| 12 |
+
"In a still frame, gas station",
|
| 13 |
+
"A tranquil tableau of house",
|
| 14 |
+
"indoor gymnasium, frozen in time",
|
| 15 |
+
"A tranquil tableau of indoor library",
|
| 16 |
+
"A tranquil tableau of kitchen",
|
| 17 |
+
"A tranquil tableau of palace",
|
| 18 |
+
"In a still frame, parking lot",
|
| 19 |
+
"In a still frame, phone booth",
|
| 20 |
+
"A tranquil tableau of restaurant",
|
| 21 |
+
"A tranquil tableau of tower",
|
| 22 |
+
"A tranquil tableau of a bowl",
|
| 23 |
+
"A tranquil tableau of an apple",
|
| 24 |
+
"A tranquil tableau of a bench",
|
| 25 |
+
"A tranquil tableau of a bed",
|
| 26 |
+
"A tranquil tableau of a chair",
|
| 27 |
+
"A tranquil tableau of a cup",
|
| 28 |
+
"A tranquil tableau of a dining table",
|
| 29 |
+
"In a still frame, a pear",
|
| 30 |
+
"A tranquil tableau of a bunch of grapes",
|
| 31 |
+
"A tranquil tableau of a bowl on the kitchen counter",
|
| 32 |
+
"A tranquil tableau of a beautiful, handcrafted ceramic bowl",
|
| 33 |
+
"A tranquil tableau of an antique bowl",
|
| 34 |
+
"A tranquil tableau of an exquisite mahogany dining table",
|
| 35 |
+
"A tranquil tableau of a wooden bench in the park",
|
| 36 |
+
"A tranquil tableau of a beautiful wrought-iron bench surrounded by blooming flowers",
|
| 37 |
+
"In a still frame, a park bench with a view of the lake",
|
| 38 |
+
"A tranquil tableau of a vintage rocking chair was placed on the porch",
|
| 39 |
+
"A tranquil tableau of the jail cell was small and dimly lit, with cold, steel bars",
|
| 40 |
+
"A tranquil tableau of the phone booth was tucked away in a quiet alley",
|
| 41 |
+
"a dilapidated phone booth stood as a relic of a bygone era on the sidewalk, frozen in time",
|
| 42 |
+
"A tranquil tableau of the old red barn stood weathered and iconic against the backdrop of the countryside",
|
| 43 |
+
"A tranquil tableau of a picturesque barn was painted a warm shade of red and nestled in a picturesque meadow",
|
| 44 |
+
"In a still frame, within the desolate desert, an oasis unfolded, characterized by the stoic presence of palm trees and a motionless, glassy pool of water",
|
| 45 |
+
"In a still frame, the Parthenon's majestic Doric columns stand in serene solitude atop the Acropolis, framed by the tranquil Athenian landscape",
|
| 46 |
+
"In a still frame, the Temple of Hephaestus, with its timeless Doric grace, stands stoically against the backdrop of a quiet Athens",
|
| 47 |
+
"In a still frame, the ornate Victorian streetlamp stands solemnly, adorned with intricate ironwork and stained glass panels",
|
| 48 |
+
"A tranquil tableau of the Stonehenge presented itself as an enigmatic puzzle, each colossal stone meticulously placed against the backdrop of tranquility",
|
| 49 |
+
"In a still frame, in the vast desert, an oasis nestled among dunes, featuring tall palm trees and an air of serenity",
|
| 50 |
+
"static view on a desert scene with an oasis, palm trees, and a clear, calm pool of water",
|
| 51 |
+
"A tranquil tableau of an ornate Victorian streetlamp standing on a cobblestone street corner, illuminating the empty night",
|
| 52 |
+
"A tranquil tableau of a tranquil lakeside cabin nestled among tall pines, its reflection mirrored perfectly in the calm water",
|
| 53 |
+
"In a still frame, a vintage gas lantern, adorned with intricate details, gracing a historic cobblestone square",
|
| 54 |
+
"In a still frame, a tranquil Japanese tea ceremony room, with tatami mats, a delicate tea set, and a bonsai tree in the corner",
|
| 55 |
+
"A tranquil tableau of the Parthenon stands resolute in its classical elegance, a timeless symbol of Athens' cultural legacy",
|
| 56 |
+
"A tranquil tableau of in the heart of Plaka, the neoclassical architecture of the old city harmonizes with the ancient ruins",
|
| 57 |
+
"A tranquil tableau of in the desolate beauty of the American Southwest, Chaco Canyon's ancient ruins whispered tales of an enigmatic civilization that once thrived amidst the arid landscapes",
|
| 58 |
+
"A tranquil tableau of at the edge of the Arabian Desert, the ancient city of Petra beckoned with its enigmatic rock-carved façades",
|
| 59 |
+
"In a still frame, amidst the cobblestone streets, an Art Nouveau lamppost stood tall",
|
| 60 |
+
"A tranquil tableau of in the quaint village square, a traditional wrought-iron streetlamp featured delicate filigree patterns and amber-hued glass panels",
|
| 61 |
+
"A tranquil tableau of the lampposts were adorned with Art Deco motifs, their geometric shapes and frosted glass creating a sense of vintage glamour",
|
| 62 |
+
"In a still frame, in the picturesque square, a Gothic-style lamppost adorned with intricate stone carvings added a touch of medieval charm to the setting",
|
| 63 |
+
"In a still frame, in the heart of the old city, a row of ornate lantern-style streetlamps bathed the narrow alleyway in a warm, welcoming light",
|
| 64 |
+
"A tranquil tableau of in the heart of the Utah desert, a massive sandstone arch spanned the horizon",
|
| 65 |
+
"A tranquil tableau of in the Arizona desert, a massive stone bridge arched across a rugged canyon",
|
| 66 |
+
"A tranquil tableau of in the corner of the minimalist tea room, a bonsai tree added a touch of nature's beauty to the otherwise simple and elegant space",
|
| 67 |
+
"In a still frame, amidst the hushed ambiance of the traditional tea room, a meticulously arranged tea set awaited, with porcelain cups, a bamboo whisk",
|
| 68 |
+
"In a still frame, nestled in the Zen garden, a rustic teahouse featured tatami seating and a traditional charcoal brazier",
|
| 69 |
+
"A tranquil tableau of a country estate's library featured elegant wooden shelves",
|
| 70 |
+
"A tranquil tableau of beneath the shade of a solitary oak tree, an old wooden park bench sat patiently",
|
| 71 |
+
"A tranquil tableau of beside a tranquil pond, a weeping willow tree draped its branches gracefully over the water's surface, creating a serene tableau of reflection and calm",
|
| 72 |
+
"A tranquil tableau of in the Zen garden, a perfectly raked gravel path led to a serene rock garden",
|
| 73 |
+
"In a still frame, a tranquil pond was fringed by weeping cherry trees, their blossoms drifting lazily onto the glassy surface",
|
| 74 |
+
"In a still frame, within the historic library's reading room, rows of antique leather chairs and mahogany tables offered a serene haven for literary contemplation",
|
| 75 |
+
"A tranquil tableau of a peaceful orchid garden showcased a variety of delicate blooms",
|
| 76 |
+
"A tranquil tableau of in the serene courtyard, a centuries-old stone well stood as a symbol of a bygone era, its mossy stones bearing witness to the passage of time"
|
| 77 |
+
]
|
common/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# coding: utf-8
|
| 15 |
+
|
| 16 |
+
"""Common utilities package."""
|
common/model/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# coding: utf-8
|
| 15 |
+
|
| 16 |
+
from .hacks import hack_qwen2_5_vl_config
|
| 17 |
+
|
| 18 |
+
__all__ = [
|
| 19 |
+
"hack_qwen2_5_vl_config",
|
| 20 |
+
]
|
common/model/checks.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# coding: utf-8
|
common/model/hacks.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# coding: utf-8
|
| 15 |
+
|
| 16 |
+
def hack_qwen2_5_vl_config(language_model):
|
| 17 |
+
# HACK!!!!!
|
| 18 |
+
language_model.config.image_token_id = 151655
|
| 19 |
+
language_model.config.video_token_id = 151656
|
| 20 |
+
language_model.config.vision_start_token_id = 151652
|
| 21 |
+
language_model.config.vision_end_token_id = 151653
|
| 22 |
+
|
| 23 |
+
language_model.config.vision_config = {
|
| 24 |
+
"depth": 32,
|
| 25 |
+
"hidden_act": "silu",
|
| 26 |
+
"hidden_size": 1280,
|
| 27 |
+
"intermediate_size": 3420,
|
| 28 |
+
"num_heads": 16,
|
| 29 |
+
"in_chans": 3,
|
| 30 |
+
"out_hidden_size": 2048,
|
| 31 |
+
"patch_size": 14,
|
| 32 |
+
"spatial_merge_size": 2,
|
| 33 |
+
"spatial_patch_size": 14,
|
| 34 |
+
"window_size": 112,
|
| 35 |
+
"fullatt_block_indexes": [
|
| 36 |
+
7,
|
| 37 |
+
15,
|
| 38 |
+
23,
|
| 39 |
+
31
|
| 40 |
+
],
|
| 41 |
+
"tokens_per_second": 2,
|
| 42 |
+
"temporal_patch_size": 2
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
language_model.config.rope_scaling = {
|
| 46 |
+
"type": "mrope",
|
| 47 |
+
"mrope_section": [
|
| 48 |
+
16,
|
| 49 |
+
24,
|
| 50 |
+
24
|
| 51 |
+
]
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
return language_model
|
common/utils/__init__.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# coding: utf-8
|
| 15 |
+
|
| 16 |
+
from .distributed import (
|
| 17 |
+
get_global_rank,
|
| 18 |
+
get_local_rank,
|
| 19 |
+
get_world_size,
|
| 20 |
+
is_master,
|
| 21 |
+
get_device,
|
| 22 |
+
barrier_if_distributed,
|
| 23 |
+
)
|
| 24 |
+
from .logging import get_logger
|
| 25 |
+
from .misc import AutoEncoderParams, tuple_mul
|
| 26 |
+
from .tensor_ops import (
|
| 27 |
+
flatten,
|
| 28 |
+
unflatten,
|
| 29 |
+
rearrange,
|
| 30 |
+
repeat,
|
| 31 |
+
pack,
|
| 32 |
+
unpack,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
__all__ = [
|
| 36 |
+
# distributed
|
| 37 |
+
"get_global_rank",
|
| 38 |
+
"get_local_rank",
|
| 39 |
+
"get_world_size",
|
| 40 |
+
"is_master",
|
| 41 |
+
"get_device",
|
| 42 |
+
"barrier_if_distributed",
|
| 43 |
+
# logging
|
| 44 |
+
"get_logger",
|
| 45 |
+
# misc
|
| 46 |
+
"AutoEncoderParams",
|
| 47 |
+
"tuple_mul",
|
| 48 |
+
# tensor_ops
|
| 49 |
+
"flatten",
|
| 50 |
+
"unflatten",
|
| 51 |
+
"rearrange",
|
| 52 |
+
"repeat",
|
| 53 |
+
"pack",
|
| 54 |
+
"unpack",
|
| 55 |
+
]
|
common/utils/distributed.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# coding: utf-8
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
import torch
|
| 18 |
+
import torch.distributed as dist
|
| 19 |
+
|
| 20 |
+
def get_global_rank() -> int:
|
| 21 |
+
"""
|
| 22 |
+
Get the global rank, the global index of the GPU.
|
| 23 |
+
"""
|
| 24 |
+
return int(os.environ.get("RANK", "0"))
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def get_local_rank() -> int:
|
| 28 |
+
"""
|
| 29 |
+
Get the local rank, the local index of the GPU.
|
| 30 |
+
"""
|
| 31 |
+
return int(os.environ.get("LOCAL_RANK", "0"))
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_world_size() -> int:
|
| 35 |
+
"""
|
| 36 |
+
Get the world size, the total amount of GPUs.
|
| 37 |
+
"""
|
| 38 |
+
return int(os.environ.get("WORLD_SIZE", "1"))
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def is_master():
|
| 42 |
+
"""
|
| 43 |
+
Check if the current process is the master process (rank 0).
|
| 44 |
+
"""
|
| 45 |
+
if not dist.is_available() or not dist.is_initialized():
|
| 46 |
+
return True
|
| 47 |
+
return dist.get_rank() == 0
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def get_device() -> torch.device:
|
| 51 |
+
"""
|
| 52 |
+
Get current rank device.
|
| 53 |
+
"""
|
| 54 |
+
return torch.device("cuda", get_local_rank())
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def barrier_if_distributed(*args, **kwargs):
|
| 58 |
+
"""
|
| 59 |
+
Synchronizes all processes if under distributed context.
|
| 60 |
+
"""
|
| 61 |
+
if dist.is_initialized():
|
| 62 |
+
return dist.barrier(*args, **kwargs)
|
common/utils/logging.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Logging utility functions.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import logging
|
| 20 |
+
import sys
|
| 21 |
+
from typing import Optional
|
| 22 |
+
|
| 23 |
+
from .distributed import get_global_rank, get_local_rank, get_world_size
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_logger(name: Optional[str] = None) -> logging.Logger:
|
| 27 |
+
logger = logging.getLogger(name)
|
| 28 |
+
logger.setLevel(logging.INFO)
|
| 29 |
+
logger.propagate = False # 修复: 禁用日志传播,防止日志被父级 logger 重复处理
|
| 30 |
+
|
| 31 |
+
if not logger.handlers: # 只看自身,避免祖先影响
|
| 32 |
+
h = logging.StreamHandler(sys.stdout)
|
| 33 |
+
fmt = logging.Formatter(
|
| 34 |
+
"[%(asctime)s] "
|
| 35 |
+
+ (f"[Rank:{get_global_rank()}]" if get_world_size() > 1 else "")
|
| 36 |
+
+ (f"[LocalRank:{get_local_rank()}]" if get_world_size() > 1 else "")
|
| 37 |
+
+ "[%(pathname)s:%(lineno)d][%(threadName).12s][%(name)s][%(levelname).5s] %(message)s"
|
| 38 |
+
)
|
| 39 |
+
h.setFormatter(fmt)
|
| 40 |
+
logger.addHandler(h)
|
| 41 |
+
return logger
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
|
common/utils/misc.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# coding: utf-8
|
| 15 |
+
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
|
| 18 |
+
@dataclass
|
| 19 |
+
class AutoEncoderParams:
|
| 20 |
+
downsample_spatial: int
|
| 21 |
+
downsample_temporal: int
|
| 22 |
+
z_channels: int
|
| 23 |
+
# for flux
|
| 24 |
+
scale_factor: float = 0.3611
|
| 25 |
+
shift_factor: float = 0.1159
|
| 26 |
+
|
| 27 |
+
def tuple_mul(a: tuple, b: tuple) -> tuple:
|
| 28 |
+
"""
|
| 29 |
+
返回两个同长度 tuple 的按位乘积。
|
| 30 |
+
|
| 31 |
+
参数:
|
| 32 |
+
a (tuple of numbers):第一个元组
|
| 33 |
+
b (tuple of numbers):第二个元组,长度需与 a 一致
|
| 34 |
+
|
| 35 |
+
返回:
|
| 36 |
+
tuple:按位相乘后的结果
|
| 37 |
+
"""
|
| 38 |
+
if len(a) != len(b):
|
| 39 |
+
raise ValueError("两个元组长度必须相等")
|
| 40 |
+
return tuple(x * y for x, y in zip(a, b))
|