Add files using upload-large-folder tool
Browse files- .claude/settings.local.json +17 -0
- ICL/.claude/settings.local.json +32 -0
- ICL/DAPO/verl-recipe/.github/workflows/pre-commit.yml +37 -0
- ICL/DAPO/verl-recipe/dapo/config/dapo_megatron_trainer.yaml +28 -0
- ICL/DAPO/verl-recipe/entropy/reward_score/entropy_math/math_normalize.py +192 -0
- ICL/DAPO/verl-recipe/fault_recover/agent_loop/fault_recover_agent_loop.py +137 -0
- ICL/DAPO/verl-recipe/spo/agent_loop/spo_agent_loop.py +155 -0
- ICL/DAPO/verl-recipe/spo/agent_loop/spo_tool_agent_loop.py +414 -0
- ICL/DAPO/verl-recipe/sppo/config/sppo_trainer.yaml +38 -0
- ICL/EVAL_GUIDE.md +47 -0
- ICL/LV/dataset_inspect.tree.txt +456 -0
- ICL/RL_DAPO/__init__.py +1 -0
- ICL/SFT_new/README.md +389 -0
- ICL/SFT_new/convert_and_eval.sh +87 -0
- ICL/SFT_new/ds_zero2.json +37 -0
- ICL/SFT_new/ds_zero3.json +28 -0
- ICL/SFT_new/eval.py +961 -0
- ICL/SFT_new/launch_wrapper.py +13 -0
- ICL/SFT_new/rebuild_and_train.sh +86 -0
- ICL/SFT_new/run_eval.sh +74 -0
- ICL/SFT_new/run_single_node.sh +49 -0
- ICL/SFT_new/submit_northjob.sh +38 -0
- ICL/SFT_new/train.py +659 -0
- ICL/build_embeddings.py +370 -0
- ICL/build_index.py +506 -0
- ICL/build_sft.py +466 -0
- ICL/dataset_inspect.tree.txt +456 -0
- ICL/eval_icl.py +524 -0
- ICL/extract_images.py +231 -0
- ICL/merge_captions.py +70 -0
- ICL/sft_model/epoch3_step1406_fp32/chat_template.json +3 -0
- ICL/sft_model/epoch3_step1406_fp32/config.json +62 -0
- ICL/sft_model/epoch3_step1406_fp32/generation_config.json +14 -0
- ICL/sft_model/epoch3_step1406_fp32/merges.txt +0 -0
- ICL/sft_model/epoch3_step1406_fp32/model.safetensors.index.json +757 -0
- ICL/sft_model/epoch3_step1406_fp32/preprocessor_config.json +21 -0
- ICL/sft_model/epoch3_step1406_fp32/tokenizer.json +0 -0
- ICL/sft_model/epoch3_step1406_fp32/tokenizer_config.json +239 -0
- ICL/sft_model/epoch3_step1406_fp32/video_preprocessor_config.json +21 -0
- ICL/sft_model/epoch3_step1406_fp32/vocab.json +0 -0
- ICL/sft_model/zero_to_fp32.py +760 -0
- RL_dataset/.gitattributes +89 -0
- RL_dataset/.msc +0 -0
- RL_dataset/.mv +1 -0
- RL_dataset/INFOSEEK_DOWNLOAD.md +337 -0
- RL_dataset/README.md +171 -0
- RL_dataset/dataset_infos.json +1 -0
- RL_dataset/download_oven_hf_mirror.sh +189 -0
- RL_dataset/download_scienceqa_hf.sh +135 -0
- download_hf.py +49 -0
.claude/settings.local.json
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"permissions": {
|
| 3 |
+
"allow": [
|
| 4 |
+
"Bash(find /workspace/xiaobin/ICL/SFT_new/output/emb_cache/ -type f -name *.json)",
|
| 5 |
+
"Bash(find /workspace/xiaobin/ICL/SFT_new/output -type f -name *.json)",
|
| 6 |
+
"Bash(find /workspace/xiaobin/ICL -type f -name *.json)",
|
| 7 |
+
"Bash(find /workspace/xiaobin/ICL/SFT_new/output/emb_cache -name *.json)",
|
| 8 |
+
"Bash(find /workspace/xiaobin -path */medlab/*vllm_thread* -o -path */medlab/*vllm*)",
|
| 9 |
+
"Bash(find /workspace/xiaobin/ICL -path */emb_cache/*.json)",
|
| 10 |
+
"Bash(python -c \"import py_compile; py_compile.compile\\(''build_sft.py'', doraise=True\\); print\\(''OK''\\)\")",
|
| 11 |
+
"Bash(python -c \"import py_compile; py_compile.compile\\(''generate_captions.py'', doraise=True\\); print\\(''OK''\\)\")",
|
| 12 |
+
"Bash(find /workspace/xiaobin -type f -name *.py)",
|
| 13 |
+
"Bash(python:*)",
|
| 14 |
+
"Bash(find /workspace -path */NorthServe/* -maxdepth 3)"
|
| 15 |
+
]
|
| 16 |
+
}
|
| 17 |
+
}
|
ICL/.claude/settings.local.json
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"permissions": {
|
| 3 |
+
"allow": [
|
| 4 |
+
"Bash(python3 -c \"import sys,json; line=sys.stdin.readline\\(\\); d=json.loads\\(line\\); print\\(list\\(d.keys\\(\\)\\)\\); [print\\(f''''{k}: {type\\(d[k]\\).__name__}, len={len\\(str\\(d[k]\\)\\)}''''\\) for k in d.keys\\(\\)]\")",
|
| 5 |
+
"Bash(python3:*)",
|
| 6 |
+
"Bash(find /workspace/xiaobin/ICL -name *sft*.jsonl -o -name output -type d)",
|
| 7 |
+
"Bash(find /workspace/xiaobin/ICL -name *.jsonl)",
|
| 8 |
+
"Bash(wc:*)",
|
| 9 |
+
"Bash(grep -r \"model_path\\\\|model-path\" /workspace/xiaobin/ICL/SFT_new/*.py)",
|
| 10 |
+
"Bash(grep -r Qwen /workspace/xiaobin/ICL/SFT_new/*.py)",
|
| 11 |
+
"Bash(grep -l embedding /workspace/xiaobin/ICL/SFT/*.py)",
|
| 12 |
+
"Bash(du -sh /workspace/xiaobin/dataset/*)",
|
| 13 |
+
"Bash(lscpu)",
|
| 14 |
+
"Bash(/workspace/miniconda3/envs/sft/bin/python -c \"import torch; print\\('torch:', torch.__version__\\); print\\('CXX11_ABI:', torch._C._GLIBCXX_USE_CXX11_ABI\\)\")",
|
| 15 |
+
"Bash(find /workspace/xiaobin/ICL -maxdepth 3 -name *eval* -o -name *inference* -o -name *test*)",
|
| 16 |
+
"Bash(ls /workspace/xiaobin/ICL/sft_model/final/*.py)",
|
| 17 |
+
"Bash(ls /workspace/xiaobin/ICL/sft_model/final/mp_rank*)",
|
| 18 |
+
"Bash(ls /workspace/xiaobin/ICL/sft_model/final/*.json)",
|
| 19 |
+
"Bash(ls /workspace/xiaobin/ICL/sft_model/final/*tag*)",
|
| 20 |
+
"Bash(pip show:*)",
|
| 21 |
+
"Bash(conda run:*)",
|
| 22 |
+
"Read(//workspace/xiaobin/dataset/sft/all/**)",
|
| 23 |
+
"Bash(find /workspace/xiaobin/ICL -type f -name *eval* -o -name *test* -o -name *infer* -o -name *benchmark* -o -name *generate* -o -name *predict*)",
|
| 24 |
+
"Bash(find /workspace/xiaobin/ICL -type f \\\\\\(-name *.jsonl -o -name *.json \\\\\\))",
|
| 25 |
+
"Bash(grep -E \"\\\\.\\(py|sh\\)$\")",
|
| 26 |
+
"Bash(find /workspace/xiaobin/ICL -type f -name *.jsonl)",
|
| 27 |
+
"Read(//workspace/xiaobin/dataset/sft/**)",
|
| 28 |
+
"Read(//workspace/xiaobin/dataset/**)",
|
| 29 |
+
"Bash(python3 -c \"import json; d=json.load\\(open\\(''/workspace/xiaobin/dataset/detail/captioning/coco/train/captions.json''\\)\\); print\\(''keys:'', list\\(d.keys\\(\\)\\)\\); items=d[''items'']; k=list\\(items.keys\\(\\)\\)[0]; print\\(k, ''->'', items[k][:100]\\)\")"
|
| 30 |
+
]
|
| 31 |
+
}
|
| 32 |
+
}
|
ICL/DAPO/verl-recipe/.github/workflows/pre-commit.yml
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# c.f. https://github.com/pre-commit/action?tab=readme-ov-file#using-this-action
|
| 2 |
+
name: pre-commit
|
| 3 |
+
|
| 4 |
+
# No need to avoid / cancel lightweight pre-commit jobs
|
| 5 |
+
on:
|
| 6 |
+
schedule:
|
| 7 |
+
- cron: "0 0 * * 0"
|
| 8 |
+
pull_request:
|
| 9 |
+
push:
|
| 10 |
+
branches:
|
| 11 |
+
- main
|
| 12 |
+
- v0.*
|
| 13 |
+
# Allow manual triggering
|
| 14 |
+
workflow_dispatch:
|
| 15 |
+
|
| 16 |
+
# Declare permissions just read content.
|
| 17 |
+
permissions:
|
| 18 |
+
contents: read
|
| 19 |
+
|
| 20 |
+
jobs:
|
| 21 |
+
pre-commit:
|
| 22 |
+
runs-on: ubuntu-latest
|
| 23 |
+
strategy:
|
| 24 |
+
matrix:
|
| 25 |
+
python-version: ["3.12"]
|
| 26 |
+
steps:
|
| 27 |
+
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
| 28 |
+
- name: Set up Python ${{ matrix.python-version }}
|
| 29 |
+
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
|
| 30 |
+
with:
|
| 31 |
+
python-version: ${{ matrix.python-version }}
|
| 32 |
+
- name: Set ruff --output-format=github
|
| 33 |
+
run: |
|
| 34 |
+
sed -i 's/--output-format=full/--output-format=github/' .pre-commit-config.yaml
|
| 35 |
+
git add .pre-commit-config.yaml
|
| 36 |
+
# Check "--all-files" by default
|
| 37 |
+
- uses: pre-commit/action@v3.0.1
|
ICL/DAPO/verl-recipe/dapo/config/dapo_megatron_trainer.yaml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
hydra:
|
| 2 |
+
searchpath:
|
| 3 |
+
- file://verl/trainer/config
|
| 4 |
+
|
| 5 |
+
defaults:
|
| 6 |
+
- ppo_megatron_trainer
|
| 7 |
+
- _self_
|
| 8 |
+
|
| 9 |
+
data:
|
| 10 |
+
gen_batch_size: ${data.train_batch_size}
|
| 11 |
+
|
| 12 |
+
reward_model:
|
| 13 |
+
reward_manager: dapo
|
| 14 |
+
overlong_buffer:
|
| 15 |
+
enable: False # We try to avoid forgetting to set enable
|
| 16 |
+
len: 0
|
| 17 |
+
penalty_factor: 0.0
|
| 18 |
+
log: False
|
| 19 |
+
|
| 20 |
+
algorithm:
|
| 21 |
+
filter_groups:
|
| 22 |
+
_target_: verl.trainer.config.FilterGroupsConfig
|
| 23 |
+
enable: False # We try to avoid forgetting to set enable
|
| 24 |
+
metric: null # acc / score / seq_reward / seq_final_reward / ...
|
| 25 |
+
max_num_gen_batches: 0 # Non-positive values mean no upper limit
|
| 26 |
+
|
| 27 |
+
trainer:
|
| 28 |
+
project_name: verl-dapo
|
ICL/DAPO/verl-recipe/entropy/reward_score/entropy_math/math_normalize.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 PRIME team and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# Copyright (c) 2021 Dan Hendrycks
|
| 16 |
+
#
|
| 17 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 18 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 19 |
+
# in the Software without restriction, including without limitation the rights
|
| 20 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 21 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 22 |
+
# furnished to do so, subject to the following conditions:
|
| 23 |
+
#
|
| 24 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 25 |
+
# copies or substantial portions of the Software.
|
| 26 |
+
#
|
| 27 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 28 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 29 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 30 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 31 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 32 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 33 |
+
# SOFTWARE.
|
| 34 |
+
"""
|
| 35 |
+
This logic is largely copied from the Hendrycks' MATH release (math_equivalence).
|
| 36 |
+
|
| 37 |
+
From: https://github.com/openai/prm800k/blob/main/prm800k/grading/math_normalize.py
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
import re
|
| 41 |
+
from typing import Optional
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def normalize_answer(answer: Optional[str]) -> Optional[str]:
|
| 45 |
+
if answer is None:
|
| 46 |
+
return None
|
| 47 |
+
answer = answer.strip()
|
| 48 |
+
try:
|
| 49 |
+
# Remove enclosing `\text{}`.
|
| 50 |
+
m = re.search(r"^\\text\{(?P<text>.+?)\}$", answer)
|
| 51 |
+
if m is not None:
|
| 52 |
+
answer = m.group("text").strip()
|
| 53 |
+
return _strip_string(answer)
|
| 54 |
+
except Exception:
|
| 55 |
+
return answer
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _fix_fracs(string):
|
| 59 |
+
substrs = string.split("\\frac")
|
| 60 |
+
new_str = substrs[0]
|
| 61 |
+
if len(substrs) > 1:
|
| 62 |
+
substrs = substrs[1:]
|
| 63 |
+
for substr in substrs:
|
| 64 |
+
new_str += "\\frac"
|
| 65 |
+
if substr[0] == "{":
|
| 66 |
+
new_str += substr
|
| 67 |
+
else:
|
| 68 |
+
try:
|
| 69 |
+
assert len(substr) >= 2
|
| 70 |
+
except Exception:
|
| 71 |
+
return string
|
| 72 |
+
a = substr[0]
|
| 73 |
+
b = substr[1]
|
| 74 |
+
if b != "{":
|
| 75 |
+
if len(substr) > 2:
|
| 76 |
+
post_substr = substr[2:]
|
| 77 |
+
new_str += "{" + a + "}{" + b + "}" + post_substr
|
| 78 |
+
else:
|
| 79 |
+
new_str += "{" + a + "}{" + b + "}"
|
| 80 |
+
else:
|
| 81 |
+
if len(substr) > 2:
|
| 82 |
+
post_substr = substr[2:]
|
| 83 |
+
new_str += "{" + a + "}" + b + post_substr
|
| 84 |
+
else:
|
| 85 |
+
new_str += "{" + a + "}" + b
|
| 86 |
+
string = new_str
|
| 87 |
+
return string
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _fix_a_slash_b(string):
|
| 91 |
+
if len(string.split("/")) != 2:
|
| 92 |
+
return string
|
| 93 |
+
a = string.split("/")[0]
|
| 94 |
+
b = string.split("/")[1]
|
| 95 |
+
try:
|
| 96 |
+
a = int(a)
|
| 97 |
+
b = int(b)
|
| 98 |
+
assert string == "{}/{}".format(a, b)
|
| 99 |
+
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
|
| 100 |
+
return new_string
|
| 101 |
+
except Exception:
|
| 102 |
+
return string
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def _remove_right_units(string):
|
| 106 |
+
# "\\text{ " only ever occurs (at least in the val set) when describing units
|
| 107 |
+
if "\\text{ " in string:
|
| 108 |
+
splits = string.split("\\text{ ")
|
| 109 |
+
assert len(splits) == 2
|
| 110 |
+
return splits[0]
|
| 111 |
+
else:
|
| 112 |
+
return string
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def _fix_sqrt(string):
|
| 116 |
+
if "\\sqrt" not in string:
|
| 117 |
+
return string
|
| 118 |
+
splits = string.split("\\sqrt")
|
| 119 |
+
new_string = splits[0]
|
| 120 |
+
for split in splits[1:]:
|
| 121 |
+
if split[0] != "{":
|
| 122 |
+
a = split[0]
|
| 123 |
+
new_substr = "\\sqrt{" + a + "}" + split[1:]
|
| 124 |
+
else:
|
| 125 |
+
new_substr = "\\sqrt" + split
|
| 126 |
+
new_string += new_substr
|
| 127 |
+
return new_string
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def _strip_string(string):
|
| 131 |
+
# linebreaks
|
| 132 |
+
string = string.replace("\n", "")
|
| 133 |
+
|
| 134 |
+
# remove inverse spaces
|
| 135 |
+
string = string.replace("\\!", "")
|
| 136 |
+
|
| 137 |
+
# replace \\ with \
|
| 138 |
+
string = string.replace("\\\\", "\\")
|
| 139 |
+
|
| 140 |
+
# replace tfrac and dfrac with frac
|
| 141 |
+
string = string.replace("tfrac", "frac")
|
| 142 |
+
string = string.replace("dfrac", "frac")
|
| 143 |
+
|
| 144 |
+
# remove \left and \right
|
| 145 |
+
string = string.replace("\\left", "")
|
| 146 |
+
string = string.replace("\\right", "")
|
| 147 |
+
|
| 148 |
+
# Remove circ (degrees)
|
| 149 |
+
string = string.replace("^{\\circ}", "")
|
| 150 |
+
string = string.replace("^\\circ", "")
|
| 151 |
+
|
| 152 |
+
# remove dollar signs
|
| 153 |
+
string = string.replace("\\$", "")
|
| 154 |
+
|
| 155 |
+
# remove units (on the right)
|
| 156 |
+
string = _remove_right_units(string)
|
| 157 |
+
|
| 158 |
+
# remove percentage
|
| 159 |
+
string = string.replace("\\\\%", "")
|
| 160 |
+
string = string.replace("\\%", "")
|
| 161 |
+
|
| 162 |
+
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
|
| 163 |
+
string = string.replace(" .", " 0.")
|
| 164 |
+
string = string.replace("{.", "{0.")
|
| 165 |
+
# if empty, return empty string
|
| 166 |
+
if len(string) == 0:
|
| 167 |
+
return string
|
| 168 |
+
if string[0] == ".":
|
| 169 |
+
string = "0" + string
|
| 170 |
+
|
| 171 |
+
# to consider: get rid of e.g. "k = " or "q = " at beginning
|
| 172 |
+
if len(string.split("=")) == 2 and len(string.split("=")[0]) <= 2:
|
| 173 |
+
string = string.split("=")[1]
|
| 174 |
+
|
| 175 |
+
# fix sqrt3 --> sqrt{3}
|
| 176 |
+
string = _fix_sqrt(string)
|
| 177 |
+
|
| 178 |
+
# remove spaces
|
| 179 |
+
string = string.replace(" ", "")
|
| 180 |
+
|
| 181 |
+
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1).
|
| 182 |
+
# Also does a/b --> \\frac{a}{b}
|
| 183 |
+
string = _fix_fracs(string)
|
| 184 |
+
|
| 185 |
+
# manually change 0.5 --> \frac{1}{2}
|
| 186 |
+
if string == "0.5":
|
| 187 |
+
string = "\\frac{1}{2}"
|
| 188 |
+
|
| 189 |
+
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
|
| 190 |
+
string = _fix_a_slash_b(string)
|
| 191 |
+
|
| 192 |
+
return string
|
ICL/DAPO/verl-recipe/fault_recover/agent_loop/fault_recover_agent_loop.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import logging
|
| 15 |
+
import os
|
| 16 |
+
from typing import Any, Optional
|
| 17 |
+
from uuid import uuid4
|
| 18 |
+
|
| 19 |
+
import ray
|
| 20 |
+
from omegaconf import DictConfig
|
| 21 |
+
|
| 22 |
+
from verl.experimental.agent_loop.agent_loop import AgentLoopManager, AgentLoopWorker, AsyncLLMServerManager
|
| 23 |
+
from verl.single_controller.ray.base import RayResourcePool, RayWorkerGroup
|
| 24 |
+
from verl.utils.rollout_trace import rollout_trace_op
|
| 25 |
+
from verl.workers.rollout.replica import TokenOutput
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__file__)
|
| 28 |
+
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class FaultRecoverAsyncLLMServerManager(AsyncLLMServerManager):
|
| 32 |
+
"""
|
| 33 |
+
A class to manage multiple OpenAI compatible LLM servers. This class provides
|
| 34 |
+
- Load balance: least requests load balancing
|
| 35 |
+
- Sticky session: send multi-turn chat completions to same server for automatic prefix caching
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
@rollout_trace_op
|
| 39 |
+
async def generate(
|
| 40 |
+
self,
|
| 41 |
+
request_id,
|
| 42 |
+
*,
|
| 43 |
+
prompt_ids: list[int],
|
| 44 |
+
sampling_params: dict[str, Any],
|
| 45 |
+
image_data: Optional[list[Any]] = None,
|
| 46 |
+
video_data: Optional[list[Any]] = None,
|
| 47 |
+
global_id: int = None,
|
| 48 |
+
) -> TokenOutput:
|
| 49 |
+
"""Generate tokens from prompt ids.
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
request_id (str): request id for sticky session.
|
| 53 |
+
prompt_ids (List[int]): List of prompt token ids.
|
| 54 |
+
sampling_params (Dict[str, Any]): Sampling parameters for the chat completion.
|
| 55 |
+
global_id: Global batch id of req.
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
TokenOutput: token output
|
| 59 |
+
"""
|
| 60 |
+
server = self._choose_server(request_id)
|
| 61 |
+
new_request_id = uuid4().hex
|
| 62 |
+
tokens_queue = None
|
| 63 |
+
if global_id is not None:
|
| 64 |
+
from recipe.fault_recover.fault_manager import get_tokens_queue
|
| 65 |
+
|
| 66 |
+
tokens_queue = get_tokens_queue()
|
| 67 |
+
|
| 68 |
+
if tokens_queue is not None:
|
| 69 |
+
await tokens_queue.put.remote((new_request_id, global_id))
|
| 70 |
+
|
| 71 |
+
output = await server.generate.remote(
|
| 72 |
+
request_id=new_request_id, # use new request_id for each turn
|
| 73 |
+
prompt_ids=prompt_ids,
|
| 74 |
+
sampling_params=sampling_params,
|
| 75 |
+
image_data=image_data,
|
| 76 |
+
video_data=video_data,
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
if tokens_queue is not None:
|
| 80 |
+
await tokens_queue.put.remote(
|
| 81 |
+
{
|
| 82 |
+
new_request_id: {
|
| 83 |
+
"log_probs": output.log_probs,
|
| 84 |
+
"routed_experts": output.routed_experts,
|
| 85 |
+
"num_preempted": output.num_preempted,
|
| 86 |
+
}
|
| 87 |
+
}
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
return output
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class FaultRecoverAgentLoopWorker(AgentLoopWorker):
|
| 94 |
+
"""Agent loop worker takes a batch of messages and run each message in an agent loop."""
|
| 95 |
+
|
| 96 |
+
def __init__(
|
| 97 |
+
self,
|
| 98 |
+
config: DictConfig,
|
| 99 |
+
server_handles: list[ray.actor.ActorHandle],
|
| 100 |
+
reward_loop_worker_handles: list[ray.actor.ActorHandle] = None,
|
| 101 |
+
):
|
| 102 |
+
super().__init__(config, server_handles, reward_loop_worker_handles)
|
| 103 |
+
self.server_manager = FaultRecoverAsyncLLMServerManager(config, server_handles)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class FaultRecoverAgentLoopManager(AgentLoopManager):
|
| 107 |
+
"""Agent loop manager that manages a group of agent loop workers."""
|
| 108 |
+
|
| 109 |
+
def __init__(
|
| 110 |
+
self,
|
| 111 |
+
config: DictConfig,
|
| 112 |
+
worker_group: RayWorkerGroup = None,
|
| 113 |
+
rollout_resource_pool: RayResourcePool = None,
|
| 114 |
+
reward_loop_worker_handles: list[ray.actor.ActorHandle] = None,
|
| 115 |
+
):
|
| 116 |
+
"""Initialize agent loop manager.
|
| 117 |
+
|
| 118 |
+
Args:
|
| 119 |
+
config (DictConfig): trainer config.
|
| 120 |
+
worker_group (RayWorkerGroup): ActorRolloutRef worker group for hybrid mode; None for standalone mode.
|
| 121 |
+
rollout_resource_pool (RayResourcePool): Resource pool for actor rollout (Colocate or Standalone mode).
|
| 122 |
+
reward_loop_worker_handles (List[ray.actor.ActorHandle]): Actor handles for streaming reward computation.
|
| 123 |
+
"""
|
| 124 |
+
self.config = config
|
| 125 |
+
self.worker_group = worker_group
|
| 126 |
+
self.reward_loop_worker_handles = reward_loop_worker_handles
|
| 127 |
+
|
| 128 |
+
# for recipe to change
|
| 129 |
+
if not hasattr(self, "rollout_replica_class"):
|
| 130 |
+
from recipe.fault_recover.vllm_rollout.vllm_async_server import FaultRecovervLLMReplica
|
| 131 |
+
|
| 132 |
+
self.rollout_replica_class = FaultRecovervLLMReplica
|
| 133 |
+
if not hasattr(self, "agent_loop_workers_class"):
|
| 134 |
+
self.agent_loop_workers_class = ray.remote(FaultRecoverAgentLoopWorker)
|
| 135 |
+
|
| 136 |
+
self._initialize_llm_servers(rollout_resource_pool)
|
| 137 |
+
self._init_agent_loop_workers()
|
ICL/DAPO/verl-recipe/spo/agent_loop/spo_agent_loop.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
# Modifications Copyright 2025 SPO authors
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""
|
| 16 |
+
SPO Agent Loop - Extends base agent loop with code generation support.
|
| 17 |
+
|
| 18 |
+
This module inherits from verl.experimental.agent_loop and only overrides
|
| 19 |
+
the generate_sequences method to add SPO-specific stop tokens for code generation.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import asyncio
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
import ray
|
| 26 |
+
|
| 27 |
+
from verl import DataProto
|
| 28 |
+
|
| 29 |
+
# Re-export all base classes for backward compatibility
|
| 30 |
+
from verl.experimental.agent_loop.agent_loop import AgentLoopManager, get_trajectory_info
|
| 31 |
+
from verl.experimental.agent_loop.agent_loop import (
|
| 32 |
+
AgentLoopWorkerBase as BaseAgentLoopWorkerBase,
|
| 33 |
+
)
|
| 34 |
+
from verl.utils.transferqueue_utils import tqbridge
|
| 35 |
+
|
| 36 |
+
__all__ = [
|
| 37 |
+
"AgentLoopWorkerBase",
|
| 38 |
+
"SPOAgentLoopWorker",
|
| 39 |
+
"SPOAgentLoopManager",
|
| 40 |
+
]
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class AgentLoopWorkerBase(BaseAgentLoopWorkerBase):
|
| 44 |
+
"""SPO-specific agent loop worker with code generation stop tokens.
|
| 45 |
+
|
| 46 |
+
Inherits all functionality from base AgentLoopWorkerBase and only overrides
|
| 47 |
+
the generate_sequences method to add SPO-specific parameters:
|
| 48 |
+
- stop="</code>" for code block termination
|
| 49 |
+
- include_stop_str_in_output=True to include the stop token
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
@tqbridge()
|
| 53 |
+
async def generate_sequences(self, batch: DataProto) -> DataProto:
|
| 54 |
+
"""Generate sequences from agent loop with SPO-specific stop tokens.
|
| 55 |
+
|
| 56 |
+
Override: Adds stop="</code>" and include_stop_str_in_output=True
|
| 57 |
+
to sampling_params for SPO code generation use case.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
batch (DataProto): Input batch.
|
| 61 |
+
|
| 62 |
+
Returns:
|
| 63 |
+
DataProto: Output batch.
|
| 64 |
+
- prompts: [bsz, prompt_length], prompt token ids from dataset.
|
| 65 |
+
- responses: [bsz, response_length], output token ids include response tokens
|
| 66 |
+
from LLM generation and observation tokens from tool_calls.
|
| 67 |
+
- response_mask: [bsz, response_length], 1 for LLM generated tokens, 0 for observation/padding tokens.
|
| 68 |
+
- input_ids: [bsz, prompt_length + response_length], whole sequence token ids, including prompt tokens
|
| 69 |
+
and response tokens.
|
| 70 |
+
- attention_mask: [bsz, prompt_length + response_length], 0 for padding tokens, 1 for other tokens.
|
| 71 |
+
- position_ids: [bsz, prompt_length + response_length], incremental position ids.
|
| 72 |
+
|
| 73 |
+
For multi-turn conversations:
|
| 74 |
+
responses: |<- LLM generation ->|<- tool_calls ->|<- LLM generation ->|<- padding ->|
|
| 75 |
+
response_mask: | 1, 1, 1, ..., 1, 1 | 0, 0, .., 0, 0 | 1, 1, 1, ..., 1, 1 | 0, 0, ..., 0|
|
| 76 |
+
"""
|
| 77 |
+
config = self.config.actor_rollout_ref.rollout
|
| 78 |
+
|
| 79 |
+
# SPO-specific: Add stop tokens for code generation
|
| 80 |
+
sampling_params = dict(
|
| 81 |
+
temperature=config.temperature,
|
| 82 |
+
top_p=config.top_p,
|
| 83 |
+
repetition_penalty=1.0,
|
| 84 |
+
logprobs=config.calculate_log_probs,
|
| 85 |
+
stop="</code>", # SPO-SPECIFIC
|
| 86 |
+
include_stop_str_in_output=True, # SPO-SPECIFIC
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
# override sampling params for validation
|
| 90 |
+
if batch.meta_info.get("validate", False):
|
| 91 |
+
sampling_params["top_p"] = config.val_kwargs.top_p
|
| 92 |
+
sampling_params["temperature"] = config.val_kwargs.temperature
|
| 93 |
+
|
| 94 |
+
# by default, we assume it's a single turn agent
|
| 95 |
+
if "agent_name" not in batch.non_tensor_batch:
|
| 96 |
+
default_agent_loop = config.agent.default_agent_loop
|
| 97 |
+
batch.non_tensor_batch["agent_name"] = np.array([default_agent_loop] * len(batch), dtype=object)
|
| 98 |
+
|
| 99 |
+
if "index" in batch.non_tensor_batch:
|
| 100 |
+
index = batch.non_tensor_batch["index"]
|
| 101 |
+
else:
|
| 102 |
+
index = np.arange(len(batch))
|
| 103 |
+
|
| 104 |
+
trajectory_info = await get_trajectory_info(
|
| 105 |
+
batch.meta_info.get("global_steps", -1), index.tolist(), batch.meta_info.get("validate", False)
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
tasks = []
|
| 109 |
+
for i in range(len(batch)):
|
| 110 |
+
kwargs = {k: v[i] for k, v in batch.non_tensor_batch.items()}
|
| 111 |
+
tasks.append(asyncio.create_task(self._run_agent_loop(sampling_params, trajectory_info[i], **kwargs)))
|
| 112 |
+
outputs = await asyncio.gather(*tasks)
|
| 113 |
+
|
| 114 |
+
output = self._postprocess(outputs)
|
| 115 |
+
return output
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
@ray.remote
|
| 119 |
+
class SPOAgentLoopWorker(AgentLoopWorkerBase):
|
| 120 |
+
"""SPO Agent Loop Worker as a Ray remote actor.
|
| 121 |
+
|
| 122 |
+
This is a Ray remote actor wrapper around AgentLoopWorkerBase,
|
| 123 |
+
enabling distributed execution with SPO-specific stop tokens.
|
| 124 |
+
"""
|
| 125 |
+
|
| 126 |
+
def __init__(self, config, server_handles, reward_router_address=None):
|
| 127 |
+
"""Initialize SPO Agent Loop Worker.
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
config: trainer config.
|
| 131 |
+
server_handles: OpenAI compatible LLM server actor handles.
|
| 132 |
+
reward_router_address: reward router address.
|
| 133 |
+
"""
|
| 134 |
+
super().__init__(config, server_handles, reward_router_address)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class SPOAgentLoopManager(AgentLoopManager):
|
| 138 |
+
"""SPO-specific Agent Loop Manager that uses SPO's AgentLoopWorker.
|
| 139 |
+
|
| 140 |
+
Inherits all functionality from base AgentLoopManager and only overrides
|
| 141 |
+
the agent_loop_workers_class to use SPOAgentLoopWorker which includes
|
| 142 |
+
code generation stop tokens.
|
| 143 |
+
"""
|
| 144 |
+
|
| 145 |
+
def __init__(self, config, worker_group=None, rm_wg=None):
|
| 146 |
+
"""Initialize SPO Agent Loop Manager.
|
| 147 |
+
|
| 148 |
+
Args:
|
| 149 |
+
config: trainer config.
|
| 150 |
+
worker_group: ActorRolloutRef worker group for hybrid mode; None for standalone mode.
|
| 151 |
+
rm_wg: Reward model worker group.
|
| 152 |
+
"""
|
| 153 |
+
# Set SPO-specific worker class before calling parent __init__
|
| 154 |
+
self.agent_loop_workers_class = SPOAgentLoopWorker
|
| 155 |
+
super().__init__(config, worker_group, rm_wg)
|
ICL/DAPO/verl-recipe/spo/agent_loop/spo_tool_agent_loop.py
ADDED
|
@@ -0,0 +1,414 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
# Modifications Copyright 2025 SPO authors
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
import asyncio
|
| 16 |
+
import copy
|
| 17 |
+
import logging
|
| 18 |
+
import os
|
| 19 |
+
from typing import Any, Optional
|
| 20 |
+
from uuid import uuid4
|
| 21 |
+
|
| 22 |
+
from verl.experimental.agent_loop.agent_loop import (
|
| 23 |
+
AgentLoopBase,
|
| 24 |
+
AgentLoopOutput,
|
| 25 |
+
register,
|
| 26 |
+
)
|
| 27 |
+
from verl.experimental.agent_loop.tool_agent_loop import AgentState
|
| 28 |
+
from verl.interactions.base import BaseInteraction
|
| 29 |
+
from verl.interactions.utils.interaction_registry import (
|
| 30 |
+
initialize_interactions_from_config,
|
| 31 |
+
)
|
| 32 |
+
from verl.tools.schemas import ToolResponse
|
| 33 |
+
from verl.tools.utils.tool_registry import initialize_tools_from_config
|
| 34 |
+
from verl.utils.profiler import simple_timer
|
| 35 |
+
from verl.utils.rollout_trace import rollout_trace_op
|
| 36 |
+
|
| 37 |
+
logger = logging.getLogger(__file__)
|
| 38 |
+
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class AgentData:
|
| 42 |
+
"""Encapsulates all state variables for the agent loop."""
|
| 43 |
+
|
| 44 |
+
def __init__(
|
| 45 |
+
self,
|
| 46 |
+
messages: list[dict[str, Any]],
|
| 47 |
+
image_data: Any,
|
| 48 |
+
metrics: dict[str, Any],
|
| 49 |
+
request_id: str,
|
| 50 |
+
tools_kwargs: dict[str, Any],
|
| 51 |
+
interaction: Optional[BaseInteraction] = None,
|
| 52 |
+
interaction_kwargs: Optional[dict[str, Any]] = None,
|
| 53 |
+
):
|
| 54 |
+
self.messages = messages
|
| 55 |
+
self.image_data = image_data
|
| 56 |
+
self.metrics = metrics
|
| 57 |
+
self.request_id = request_id
|
| 58 |
+
self.tools_kwargs = tools_kwargs
|
| 59 |
+
self.interaction = interaction
|
| 60 |
+
self.interaction_kwargs = interaction_kwargs or {}
|
| 61 |
+
|
| 62 |
+
# State variables
|
| 63 |
+
self.prompt_ids: list[int] = []
|
| 64 |
+
self.response_ids: list[int] = []
|
| 65 |
+
self.response_mask: list[int] = []
|
| 66 |
+
self.response_logprobs: list[float] = []
|
| 67 |
+
self.turn_scores: list[float] = []
|
| 68 |
+
self.tool_rewards: list[float] = []
|
| 69 |
+
self.user_turns = 0
|
| 70 |
+
self.assistant_turns = 0
|
| 71 |
+
|
| 72 |
+
# Temporary state for tool calls
|
| 73 |
+
self.tool_calls: list[str] = [] # Raw Python code strings extracted from <code> tags
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@register("spo_tool_agent")
|
| 77 |
+
class SPOToolAgentLoop(AgentLoopBase):
|
| 78 |
+
@classmethod
|
| 79 |
+
def init_class(cls, config, tokenizer, processor, **kwargs):
|
| 80 |
+
if cls._class_initialized:
|
| 81 |
+
return
|
| 82 |
+
cls._class_initialized = True
|
| 83 |
+
print("Performing class-level ToolAgentLoop initialization")
|
| 84 |
+
|
| 85 |
+
# Initialize tools from config file
|
| 86 |
+
cls.tokenizer = tokenizer
|
| 87 |
+
cls.processor = processor
|
| 88 |
+
cls.max_user_turns = config.actor_rollout_ref.rollout.multi_turn.max_user_turns
|
| 89 |
+
cls.max_assistant_turns = config.actor_rollout_ref.rollout.multi_turn.max_assistant_turns
|
| 90 |
+
cls.max_parallel_calls = config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls
|
| 91 |
+
cls.max_tool_response_length = config.actor_rollout_ref.rollout.multi_turn.max_tool_response_length
|
| 92 |
+
cls.tool_response_truncate_side = config.actor_rollout_ref.rollout.multi_turn.tool_response_truncate_side
|
| 93 |
+
tool_config_path = config.actor_rollout_ref.rollout.multi_turn.tool_config_path
|
| 94 |
+
tool_list = initialize_tools_from_config(tool_config_path) if tool_config_path else []
|
| 95 |
+
cls.tools = {tool.name: tool for tool in tool_list}
|
| 96 |
+
cls.tool_schemas = [tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True) for tool in tool_list]
|
| 97 |
+
print(f"Initialized tools: {cls.tools}")
|
| 98 |
+
|
| 99 |
+
cls.apply_chat_template_kwargs = config.data.get("apply_chat_template_kwargs", {})
|
| 100 |
+
cls.prompt_length = config.actor_rollout_ref.rollout.prompt_length
|
| 101 |
+
cls.response_length = config.actor_rollout_ref.rollout.response_length
|
| 102 |
+
cls.system_prompt = tokenizer.apply_chat_template(
|
| 103 |
+
[{}], add_generation_prompt=False, tokenize=True, **cls.apply_chat_template_kwargs
|
| 104 |
+
)
|
| 105 |
+
# Initialize interactions from config file
|
| 106 |
+
cls.interaction_config_file = config.actor_rollout_ref.rollout.multi_turn.interaction_config_path
|
| 107 |
+
if cls.interaction_config_file:
|
| 108 |
+
cls.interaction_map: dict[str, BaseInteraction] = cls._initialize_interactions(cls.interaction_config_file)
|
| 109 |
+
|
| 110 |
+
@rollout_trace_op
|
| 111 |
+
async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
|
| 112 |
+
messages = list(kwargs["raw_prompt"])
|
| 113 |
+
image_data = copy.deepcopy(kwargs.get("multi_modal_data", {}).get("image", None))
|
| 114 |
+
metrics = {}
|
| 115 |
+
request_id = uuid4().hex
|
| 116 |
+
tools_kwargs = kwargs.get("tools_kwargs", {})
|
| 117 |
+
|
| 118 |
+
# Initialize interaction if needed
|
| 119 |
+
interaction = None
|
| 120 |
+
interaction_kwargs = {}
|
| 121 |
+
if self.interaction_config_file:
|
| 122 |
+
interaction_kwargs = kwargs["extra_info"]["interaction_kwargs"]
|
| 123 |
+
if "name" not in interaction_kwargs:
|
| 124 |
+
raise ValueError("'name' key is required in interaction_kwargs")
|
| 125 |
+
interaction_name = interaction_kwargs["name"]
|
| 126 |
+
if interaction_name not in self.interaction_map:
|
| 127 |
+
raise ValueError(
|
| 128 |
+
f"Interaction '{interaction_name}' not found in interaction_map. Available interactions: "
|
| 129 |
+
f"{list(self.interaction_map.keys())}"
|
| 130 |
+
)
|
| 131 |
+
interaction = self.interaction_map[interaction_name]
|
| 132 |
+
await interaction.start_interaction(request_id, **interaction_kwargs)
|
| 133 |
+
# Create AgentData instance to encapsulate all state
|
| 134 |
+
agent_data = AgentData(
|
| 135 |
+
messages=messages,
|
| 136 |
+
image_data=image_data,
|
| 137 |
+
metrics=metrics,
|
| 138 |
+
request_id=request_id,
|
| 139 |
+
tools_kwargs=tools_kwargs,
|
| 140 |
+
interaction=interaction,
|
| 141 |
+
interaction_kwargs=interaction_kwargs,
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
# State machine loop
|
| 145 |
+
state = AgentState.PENDING
|
| 146 |
+
while state != AgentState.TERMINATED:
|
| 147 |
+
if state == AgentState.PENDING:
|
| 148 |
+
state = await self._handle_pending_state(agent_data, sampling_params)
|
| 149 |
+
elif state == AgentState.GENERATING:
|
| 150 |
+
state = await self._handle_generating_state(agent_data, sampling_params)
|
| 151 |
+
elif state == AgentState.PROCESSING_TOOLS:
|
| 152 |
+
state = await self._handle_processing_tools_state(agent_data)
|
| 153 |
+
elif state == AgentState.INTERACTING:
|
| 154 |
+
state = await self._handle_interacting_state(agent_data)
|
| 155 |
+
else:
|
| 156 |
+
logger.error(f"Invalid state: {state}")
|
| 157 |
+
state = AgentState.TERMINATED
|
| 158 |
+
|
| 159 |
+
# Finalize output
|
| 160 |
+
response_ids = agent_data.prompt_ids[-len(agent_data.response_mask) :]
|
| 161 |
+
prompt_ids = agent_data.prompt_ids[: len(agent_data.prompt_ids) - len(agent_data.response_mask)]
|
| 162 |
+
multi_modal_data = {"image": agent_data.image_data} if agent_data.image_data is not None else {}
|
| 163 |
+
output = AgentLoopOutput(
|
| 164 |
+
prompt_ids=prompt_ids,
|
| 165 |
+
response_ids=response_ids[: self.response_length],
|
| 166 |
+
response_mask=agent_data.response_mask[: self.response_length],
|
| 167 |
+
multi_modal_data=multi_modal_data,
|
| 168 |
+
response_logprobs=agent_data.response_logprobs[: self.response_length]
|
| 169 |
+
if agent_data.response_logprobs
|
| 170 |
+
else None,
|
| 171 |
+
num_turns=agent_data.user_turns + agent_data.assistant_turns + 1,
|
| 172 |
+
metrics=agent_data.metrics,
|
| 173 |
+
extra_fields={},
|
| 174 |
+
)
|
| 175 |
+
output.extra_fields.update({"turn_scores": agent_data.turn_scores, "tool_rewards": agent_data.tool_rewards})
|
| 176 |
+
return output
|
| 177 |
+
|
| 178 |
+
def _extract_code_blocks(self, response_ids: list[int]) -> list[str]:
|
| 179 |
+
"""Extract Python code from <code>...</code> tags in response.
|
| 180 |
+
|
| 181 |
+
Args:
|
| 182 |
+
response_ids: Token IDs from model response
|
| 183 |
+
|
| 184 |
+
Returns:
|
| 185 |
+
List of cleaned Python code strings
|
| 186 |
+
"""
|
| 187 |
+
import re
|
| 188 |
+
|
| 189 |
+
# Decode token IDs to text
|
| 190 |
+
response_text = self.tokenizer.decode(response_ids, skip_special_tokens=False)
|
| 191 |
+
|
| 192 |
+
# Extract all code blocks between <code> and </code> tags
|
| 193 |
+
pattern = r"<code>(.*?)</code>"
|
| 194 |
+
matches = re.findall(pattern, response_text, re.DOTALL)
|
| 195 |
+
|
| 196 |
+
# Clean each code block (remove markdown fences, strip whitespace)
|
| 197 |
+
cleaned_codes = []
|
| 198 |
+
for match in matches:
|
| 199 |
+
# Remove markdown code fences if present
|
| 200 |
+
cleaned = re.sub(r"^```(?:python)?\s*\n?", "", match.strip())
|
| 201 |
+
cleaned = re.sub(r"\n?```\s*$", "", cleaned)
|
| 202 |
+
cleaned_codes.append(cleaned.strip())
|
| 203 |
+
|
| 204 |
+
return cleaned_codes
|
| 205 |
+
|
| 206 |
+
async def _handle_pending_state(self, agent_data: AgentData, sampling_params: dict[str, Any]) -> AgentState:
|
| 207 |
+
"""Handle the pending state: prepare the prompt and start generation."""
|
| 208 |
+
problem = agent_data.messages[0]["content"]
|
| 209 |
+
user_prompt = (
|
| 210 |
+
"Solve the following problem step by step. "
|
| 211 |
+
"You now have the ability to selectively write executable Python code to enhance your reasoning process. "
|
| 212 |
+
"The Python code will be executed by an external sandbox, and the output "
|
| 213 |
+
"(wrapped in `<interpreter>output_str</interpreter>`)"
|
| 214 |
+
" can be returned to aid your reasoning and help you arrive at the final answer. "
|
| 215 |
+
"The Python code should be complete scripts, including necessary imports. "
|
| 216 |
+
"Important: The sandbox is stateless and non-interactive; thus, prior imports, definitions, "
|
| 217 |
+
"and state do not persist between executions and cannot be referenced.\n"
|
| 218 |
+
"Each code snippet is wrapped with `<code>\n```python\ncode snippet\n```\n</code>`.\n"
|
| 219 |
+
)
|
| 220 |
+
user_prompt += "*user question:*\n"
|
| 221 |
+
user_prompt += problem
|
| 222 |
+
messages = [{"role": "user", "content": user_prompt}]
|
| 223 |
+
agent_data.prompt_ids = await self.loop.run_in_executor(
|
| 224 |
+
None,
|
| 225 |
+
lambda: self.tokenizer.apply_chat_template(
|
| 226 |
+
messages, add_generation_prompt=True, tokenize=True, **self.apply_chat_template_kwargs
|
| 227 |
+
),
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
return AgentState.GENERATING
|
| 231 |
+
|
| 232 |
+
async def _handle_generating_state(
|
| 233 |
+
self, agent_data: AgentData, sampling_params: dict[str, Any], ignore_termination: bool = False
|
| 234 |
+
) -> AgentState:
|
| 235 |
+
"""Handle the generating state: generate model response and check for tool calls."""
|
| 236 |
+
add_messages: list[dict[str, Any]] = []
|
| 237 |
+
|
| 238 |
+
with simple_timer("generate_sequences", agent_data.metrics):
|
| 239 |
+
output = await self.server_manager.generate(
|
| 240 |
+
request_id=agent_data.request_id,
|
| 241 |
+
prompt_ids=agent_data.prompt_ids,
|
| 242 |
+
sampling_params=sampling_params,
|
| 243 |
+
image_data=agent_data.image_data,
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
agent_data.assistant_turns += 1
|
| 247 |
+
agent_data.response_ids = output.token_ids
|
| 248 |
+
agent_data.prompt_ids += agent_data.response_ids
|
| 249 |
+
agent_data.response_mask += [1] * len(agent_data.response_ids)
|
| 250 |
+
if output.log_probs:
|
| 251 |
+
agent_data.response_logprobs += output.log_probs
|
| 252 |
+
|
| 253 |
+
# Check termination conditions
|
| 254 |
+
if not ignore_termination and len(agent_data.response_mask) >= self.response_length:
|
| 255 |
+
return AgentState.TERMINATED
|
| 256 |
+
if self.max_assistant_turns and agent_data.assistant_turns >= self.max_assistant_turns:
|
| 257 |
+
return AgentState.TERMINATED
|
| 258 |
+
if self.max_user_turns and agent_data.user_turns >= self.max_user_turns:
|
| 259 |
+
return AgentState.TERMINATED
|
| 260 |
+
|
| 261 |
+
# Extract code blocks from <code> tags
|
| 262 |
+
agent_data.tool_calls = self._extract_code_blocks(agent_data.response_ids)
|
| 263 |
+
|
| 264 |
+
# Handle interaction if needed
|
| 265 |
+
if self.interaction_config_file:
|
| 266 |
+
assistant_message = await self.loop.run_in_executor(
|
| 267 |
+
None, lambda: self.tokenizer.decode(agent_data.response_ids, skip_special_tokens=True)
|
| 268 |
+
)
|
| 269 |
+
add_messages.append({"role": "assistant", "content": assistant_message})
|
| 270 |
+
agent_data.messages.extend(add_messages)
|
| 271 |
+
|
| 272 |
+
# Determine next state
|
| 273 |
+
if agent_data.tool_calls:
|
| 274 |
+
return AgentState.PROCESSING_TOOLS
|
| 275 |
+
elif self.interaction_config_file:
|
| 276 |
+
return AgentState.INTERACTING
|
| 277 |
+
else:
|
| 278 |
+
return AgentState.TERMINATED
|
| 279 |
+
|
| 280 |
+
async def _handle_processing_tools_state(self, agent_data: AgentData) -> AgentState:
|
| 281 |
+
"""Handle the processing tools state: execute tool calls and prepare tool responses."""
|
| 282 |
+
tasks = []
|
| 283 |
+
tool_call_names = []
|
| 284 |
+
for tool_call in agent_data.tool_calls[: self.max_parallel_calls]:
|
| 285 |
+
tasks.append(self._call_tool(tool_call, agent_data.tools_kwargs))
|
| 286 |
+
tool_call_names.append("code_interpreter")
|
| 287 |
+
|
| 288 |
+
with simple_timer("tool_calls", agent_data.metrics):
|
| 289 |
+
responses = await asyncio.gather(*tasks)
|
| 290 |
+
|
| 291 |
+
response_ids = await self.loop.run_in_executor(
|
| 292 |
+
None, lambda: self.tokenizer.encode(responses[0].text or "", add_special_tokens=False)
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
if len(agent_data.response_mask) + len(response_ids) >= self.response_length:
|
| 296 |
+
return AgentState.TERMINATED
|
| 297 |
+
# Update prompt_ids and response_mask
|
| 298 |
+
agent_data.prompt_ids += response_ids
|
| 299 |
+
agent_data.response_mask += [0] * len(response_ids)
|
| 300 |
+
if agent_data.response_logprobs:
|
| 301 |
+
agent_data.response_logprobs += [0.0] * len(response_ids)
|
| 302 |
+
agent_data.user_turns += 1
|
| 303 |
+
# Change agent_data.request_id to avoid caching issues
|
| 304 |
+
agent_data.request_id = uuid4().hex
|
| 305 |
+
return AgentState.GENERATING
|
| 306 |
+
|
| 307 |
+
async def _handle_interacting_state(self, agent_data: AgentData) -> AgentState:
|
| 308 |
+
"""Handle the interacting state: get user input from interaction."""
|
| 309 |
+
(
|
| 310 |
+
should_terminate_sequence,
|
| 311 |
+
interaction_responses,
|
| 312 |
+
reward,
|
| 313 |
+
metrics,
|
| 314 |
+
) = await agent_data.interaction.generate_response(
|
| 315 |
+
agent_data.request_id, agent_data.messages, **agent_data.interaction_kwargs
|
| 316 |
+
)
|
| 317 |
+
agent_data.user_turns += 1
|
| 318 |
+
|
| 319 |
+
add_messages: list[dict[str, Any]] = [{"role": "user", "content": interaction_responses}]
|
| 320 |
+
agent_data.messages.extend(add_messages)
|
| 321 |
+
|
| 322 |
+
if reward is not None:
|
| 323 |
+
agent_data.turn_scores.append(reward)
|
| 324 |
+
|
| 325 |
+
# Update prompt with user responses (similar to _handle_processing_tools_state)
|
| 326 |
+
if self.processor is not None:
|
| 327 |
+
raw_user_response = await self.loop.run_in_executor(
|
| 328 |
+
None,
|
| 329 |
+
lambda: self.processor.apply_chat_template(
|
| 330 |
+
add_messages,
|
| 331 |
+
add_generation_prompt=True,
|
| 332 |
+
tokenize=False,
|
| 333 |
+
**self.apply_chat_template_kwargs,
|
| 334 |
+
),
|
| 335 |
+
)
|
| 336 |
+
model_inputs = self.processor(text=[raw_user_response], images=None, return_tensors="pt")
|
| 337 |
+
response_ids = model_inputs.pop("input_ids").squeeze(0).tolist()
|
| 338 |
+
else:
|
| 339 |
+
response_ids = await self.loop.run_in_executor(
|
| 340 |
+
None,
|
| 341 |
+
lambda: self.tokenizer.apply_chat_template(add_messages, add_generation_prompt=True, tokenize=True),
|
| 342 |
+
)
|
| 343 |
+
response_ids = response_ids[len(self.system_prompt) :]
|
| 344 |
+
|
| 345 |
+
# Update prompt_ids and response_mask
|
| 346 |
+
agent_data.prompt_ids += response_ids
|
| 347 |
+
agent_data.response_mask += [0] * len(response_ids)
|
| 348 |
+
if agent_data.response_logprobs:
|
| 349 |
+
agent_data.response_logprobs += [0.0] * len(response_ids)
|
| 350 |
+
|
| 351 |
+
# double check prompt
|
| 352 |
+
# Check termination condition
|
| 353 |
+
if should_terminate_sequence:
|
| 354 |
+
return AgentState.TERMINATED
|
| 355 |
+
else:
|
| 356 |
+
return AgentState.GENERATING
|
| 357 |
+
|
| 358 |
+
async def _call_tool(self, tool_call: str, tools_kwargs: dict[str, Any]) -> tuple[ToolResponse, float, dict]:
|
| 359 |
+
"""Call tool and return tool response."""
|
| 360 |
+
tool, instance_id = None, None
|
| 361 |
+
try:
|
| 362 |
+
tool = self.tools["code_interpreter"]
|
| 363 |
+
instance_id, _ = await tool.create(create_kwargs={})
|
| 364 |
+
|
| 365 |
+
tool_execution_response, _, _ = await tool.execute(instance_id, tool_call)
|
| 366 |
+
except Exception as e:
|
| 367 |
+
logger.warning(f"Error when executing tool: {e}")
|
| 368 |
+
return (
|
| 369 |
+
ToolResponse(
|
| 370 |
+
text=f"Error when executing tool: {e}",
|
| 371 |
+
),
|
| 372 |
+
0.0,
|
| 373 |
+
{},
|
| 374 |
+
)
|
| 375 |
+
finally:
|
| 376 |
+
if tool and instance_id:
|
| 377 |
+
await tool.release(instance_id)
|
| 378 |
+
|
| 379 |
+
tool_response_text = tool_execution_response.text
|
| 380 |
+
if tool_response_text and len(tool_response_text) > self.max_tool_response_length:
|
| 381 |
+
if self.tool_response_truncate_side == "left":
|
| 382 |
+
tool_response_text = tool_response_text[: self.max_tool_response_length] + "...(truncated)"
|
| 383 |
+
elif self.tool_response_truncate_side == "right":
|
| 384 |
+
tool_response_text = "(truncated)..." + tool_response_text[-self.max_tool_response_length :]
|
| 385 |
+
else:
|
| 386 |
+
length = self.max_tool_response_length // 2
|
| 387 |
+
tool_response_text = tool_response_text[:length] + "...(truncated)..." + tool_response_text[-length:]
|
| 388 |
+
|
| 389 |
+
tool_response_text = f"<interpreter>\n{tool_response_text}\n</interpreter>\n\n"
|
| 390 |
+
|
| 391 |
+
# Create ToolResponse from tool execution result
|
| 392 |
+
tool_response_kwargs = {"text": tool_response_text}
|
| 393 |
+
|
| 394 |
+
# Add multimedia data if present
|
| 395 |
+
for attr_name in ["image", "video"]:
|
| 396 |
+
if hasattr(tool_execution_response, attr_name):
|
| 397 |
+
attr_value = getattr(tool_execution_response, attr_name)
|
| 398 |
+
if attr_value is not None:
|
| 399 |
+
tool_response_kwargs[attr_name] = attr_value
|
| 400 |
+
|
| 401 |
+
return ToolResponse(**tool_response_kwargs)
|
| 402 |
+
|
| 403 |
+
@classmethod
|
| 404 |
+
def _initialize_interactions(cls, interaction_config_file):
|
| 405 |
+
"""Initialize interactions from configuration.
|
| 406 |
+
Returns:
|
| 407 |
+
dict[str, BaseInteraction]: A dictionary mapping interaction names to interaction instances.
|
| 408 |
+
"""
|
| 409 |
+
if interaction_config_file is None:
|
| 410 |
+
return {}
|
| 411 |
+
|
| 412 |
+
interaction_map = initialize_interactions_from_config(interaction_config_file)
|
| 413 |
+
logger.info(f"Initialize interactions from configuration: interaction_map: {list(interaction_map.keys())}")
|
| 414 |
+
return interaction_map
|
ICL/DAPO/verl-recipe/sppo/config/sppo_trainer.yaml
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# the sppo config will override default ppo_trainer.yaml
|
| 2 |
+
|
| 3 |
+
hydra:
|
| 4 |
+
searchpath:
|
| 5 |
+
- file://verl/trainer/config
|
| 6 |
+
|
| 7 |
+
defaults:
|
| 8 |
+
- ppo_trainer
|
| 9 |
+
- _self_
|
| 10 |
+
|
| 11 |
+
actor_rollout_ref:
|
| 12 |
+
actor:
|
| 13 |
+
_target_: recipe.sppo.config.SPPOActorConfig
|
| 14 |
+
|
| 15 |
+
# sppo_eta is an additional hyperparameter for SPPO, not available in
|
| 16 |
+
# verl core. specifying _target_ with SPPOActorConfig is needed to
|
| 17 |
+
# extend verl ActorConfig with custom fields.
|
| 18 |
+
# additional, it is also possible to use the `extra` field natively supported
|
| 19 |
+
# by all verl core dataclasses, without having to define SPPOActorConfig
|
| 20 |
+
# extra:
|
| 21 |
+
# sppo_eta: 1.0
|
| 22 |
+
sppo_eta: 1.0
|
| 23 |
+
|
| 24 |
+
optim:
|
| 25 |
+
lr_warmup_steps: 15
|
| 26 |
+
rollout:
|
| 27 |
+
name: sglang
|
| 28 |
+
tensor_model_parallel_size: 2
|
| 29 |
+
gpu_memory_utilization: 0.5
|
| 30 |
+
val_kwargs:
|
| 31 |
+
n: 2 # 2 will trigger validation, 1 will bypass
|
| 32 |
+
|
| 33 |
+
algorithm:
|
| 34 |
+
adv_estimator: null
|
| 35 |
+
sppo_eta: 1.0
|
| 36 |
+
|
| 37 |
+
trainer:
|
| 38 |
+
log_val_generations: 0
|
ICL/EVAL_GUIDE.md
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ICL 模型评测步骤
|
| 2 |
+
|
| 3 |
+
## Step 1: 合并 DeepSpeed checkpoint(safetensors 格式)
|
| 4 |
+
|
| 5 |
+
```bash
|
| 6 |
+
cd /workspace/xiaobin/ICL
|
| 7 |
+
|
| 8 |
+
python3 sft_model/zero_to_fp32.py \
|
| 9 |
+
sft_model \
|
| 10 |
+
sft_model/merged_hf \
|
| 11 |
+
--safe_serialization
|
| 12 |
+
```
|
| 13 |
+
|
| 14 |
+
## Step 2: 复制 tokenizer 和 config(注意不要复制 model.safetensors.index.json)
|
| 15 |
+
|
| 16 |
+
```bash
|
| 17 |
+
cp /workspace/models/Qwen3-VL-8B-Instruct/config.json sft_model/merged_hf/
|
| 18 |
+
cp /workspace/models/Qwen3-VL-8B-Instruct/generation_config.json sft_model/merged_hf/
|
| 19 |
+
cp /workspace/models/Qwen3-VL-8B-Instruct/preprocessor_config.json sft_model/merged_hf/
|
| 20 |
+
cp /workspace/models/Qwen3-VL-8B-Instruct/chat_template.json sft_model/merged_hf/ 2>/dev/null
|
| 21 |
+
cp /workspace/models/Qwen3-VL-8B-Instruct/tokenizer* sft_model/merged_hf/
|
| 22 |
+
cp /workspace/models/Qwen3-VL-8B-Instruct/merges.txt sft_model/merged_hf/
|
| 23 |
+
cp /workspace/models/Qwen3-VL-8B-Instruct/vocab.json sft_model/merged_hf/
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
## Step 3: 跑评测
|
| 27 |
+
|
| 28 |
+
单卡:
|
| 29 |
+
|
| 30 |
+
```bash
|
| 31 |
+
python3 eval_icl.py \
|
| 32 |
+
--model-path sft_model/merged_hf \
|
| 33 |
+
--all-categories \
|
| 34 |
+
--num-samples 100 \
|
| 35 |
+
--max-rounds 4 \
|
| 36 |
+
--device cuda:0
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
多卡 (8 GPU):
|
| 40 |
+
|
| 41 |
+
```bash
|
| 42 |
+
torchrun --nproc_per_node=8 eval_icl.py \
|
| 43 |
+
--model-path sft_model/merged_hf \
|
| 44 |
+
--all-categories \
|
| 45 |
+
--num-samples 100 \
|
| 46 |
+
--max-rounds 4
|
| 47 |
+
```
|
ICL/LV/dataset_inspect.tree.txt
ADDED
|
@@ -0,0 +1,456 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
M3IT/
|
| 2 |
+
.git/
|
| 3 |
+
data/
|
| 4 |
+
.gitattributes (2.8KB)
|
| 5 |
+
.gitignore (29.0B)
|
| 6 |
+
M3IT.py (54.5KB)
|
| 7 |
+
README.md (18.3KB)
|
| 8 |
+
branches/
|
| 9 |
+
hooks/
|
| 10 |
+
info/
|
| 11 |
+
lfs/
|
| 12 |
+
logs/
|
| 13 |
+
objects/
|
| 14 |
+
refs/
|
| 15 |
+
FETCH_HEAD (110.0B)
|
| 16 |
+
HEAD (21.0B)
|
| 17 |
+
config (339.0B)
|
| 18 |
+
description (73.0B)
|
| 19 |
+
packed-refs (112.0B)
|
| 20 |
+
refs/
|
| 21 |
+
HEAD (189.0B)
|
| 22 |
+
heads/
|
| 23 |
+
remotes/
|
| 24 |
+
main (189.0B)
|
| 25 |
+
heads/
|
| 26 |
+
remotes/
|
| 27 |
+
tags/
|
| 28 |
+
origin/
|
| 29 |
+
HEAD (30.0B)
|
| 30 |
+
main (41.0B)
|
| 31 |
+
info/
|
| 32 |
+
pack/
|
| 33 |
+
pack-ee3e40a1a23ec17affa3b8afb61dc14bdffb229c.idx (38.9KB)
|
| 34 |
+
pack-ee3e40a1a23ec17affa3b8afb61dc14bdffb229c.pack (195.5KB)
|
| 35 |
+
applypatch-msg.sample (478.0B)
|
| 36 |
+
commit-msg.sample (896.0B)
|
| 37 |
+
fsmonitor-watchman.sample (4.5KB)
|
| 38 |
+
post-checkout (280.0B)
|
| 39 |
+
post-commit (276.0B)
|
| 40 |
+
post-merge (274.0B)
|
| 41 |
+
post-update.sample (189.0B)
|
| 42 |
+
pre-applypatch.sample (424.0B)
|
| 43 |
+
pre-commit.sample (1.6KB)
|
| 44 |
+
pre-merge-commit.sample (416.0B)
|
| 45 |
+
pre-push (270.0B)
|
| 46 |
+
pre-push.sample (1.3KB)
|
| 47 |
+
pre-rebase.sample (4.8KB)
|
| 48 |
+
pre-receive.sample (544.0B)
|
| 49 |
+
prepare-commit-msg.sample (1.5KB)
|
| 50 |
+
push-to-checkout.sample (2.7KB)
|
| 51 |
+
update.sample (3.6KB)
|
| 52 |
+
incomplete/
|
| 53 |
+
logs/
|
| 54 |
+
objects/
|
| 55 |
+
tmp/
|
| 56 |
+
0152398d9443f2d300adc9e6099a773c66303d4e2e085812cd502cb36da7a0c73483193049 (0.0B)
|
| 57 |
+
0152398d9443f2d300adc9e6099a773c66303d4e2e085812cd502cb36da7a0c7763208216 (0.0B)
|
| 58 |
+
0152398d9443f2d300adc9e6099a773c66303d4e2e085812cd502cb36da7a0c789921672 (2.5MB)
|
| 59 |
+
0968a4438d46277583968011563e959e130feaee66f51bb2d66dbd7e8c979f8c.part (0.0B)
|
| 60 |
+
1f77f56225e10edca84be06b6e0d796c579cbf1d4884aee46da564438ad1ba9b1484563810 (437.0KB)
|
| 61 |
+
1f77f56225e10edca84be06b6e0d796c579cbf1d4884aee46da564438ad1ba9b3850099655 (326.7KB)
|
| 62 |
+
1f77f56225e10edca84be06b6e0d796c579cbf1d4884aee46da564438ad1ba9b3898577811 (4.1MB)
|
| 63 |
+
220d32d087b6b29d1c5aaa49324d32b32ae1c19f42e9800f40f24d3a695c2a8d1743027097 (0.0B)
|
| 64 |
+
220d32d087b6b29d1c5aaa49324d32b32ae1c19f42e9800f40f24d3a695c2a8d3014727128 (0.0B)
|
| 65 |
+
220d32d087b6b29d1c5aaa49324d32b32ae1c19f42e9800f40f24d3a695c2a8d71894927 (62.6KB)
|
| 66 |
+
24f014bb5bc7b1fa7d9183dd65fd4b43c0c49aafd6af01bb91ae3a0e7e65502b2818819757 (49.3MB)
|
| 67 |
+
3da69649bfbc671710f38c2c2f7c6aaecb8f8544de3446866054bf927257c9332854861486 (158.6KB)
|
| 68 |
+
3da69649bfbc671710f38c2c2f7c6aaecb8f8544de3446866054bf927257c9334214717938 (0.0B)
|
| 69 |
+
3da69649bfbc671710f38c2c2f7c6aaecb8f8544de3446866054bf927257c933593947826 (0.0B)
|
| 70 |
+
45e8c51ed0df8edb1ae51d2012b3f7d6cd9cc84addf41e6f9f9adb0f625d41033126870057 (259.2MB)
|
| 71 |
+
4a80559730d917177e4d13246da0ce23ca318735b29d519d0448bea5579b1a771450117433 (154.4MB)
|
| 72 |
+
4fda2aa4918e5dec847935db6d46e9bebc570a173bd4201c5f48e60a3f73813a1530155941 (1.1MB)
|
| 73 |
+
4fda2aa4918e5dec847935db6d46e9bebc570a173bd4201c5f48e60a3f73813a2738070238 (0.0B)
|
| 74 |
+
4fda2aa4918e5dec847935db6d46e9bebc570a173bd4201c5f48e60a3f73813a2828099128 (0.0B)
|
| 75 |
+
52a445f8a26cd898e64129e7f1d4bfa6d7203311442068684f5344fc73407310.part (0.0B)
|
| 76 |
+
6728a8fb7bad0bad3a2a27669232cb9ae66461c635172f1f7958c80a28e09fa32607733000 (150.2MB)
|
| 77 |
+
6bb6c9f17e77eab7d88e4a4501c38cb31a6cf792fe77e3b75d511b964a5667df2998182268 (91.8MB)
|
| 78 |
+
8cb15647ff6bbac322142fea1a38599c523f73acb3614ddb7d12e6a1975a79dc1986657385 (0.0B)
|
| 79 |
+
8cb15647ff6bbac322142fea1a38599c523f73acb3614ddb7d12e6a1975a79dc2743098052 (0.0B)
|
| 80 |
+
8cb15647ff6bbac322142fea1a38599c523f73acb3614ddb7d12e6a1975a79dc4193739161 (0.0B)
|
| 81 |
+
9919274ad6bc88e37235a4c7245d05e357e404ef3352a90a1ba0594e694893c01114223911 (0.0B)
|
| 82 |
+
9919274ad6bc88e37235a4c7245d05e357e404ef3352a90a1ba0594e694893c03545613611 (0.0B)
|
| 83 |
+
9919274ad6bc88e37235a4c7245d05e357e404ef3352a90a1ba0594e694893c0559090370 (2.8MB)
|
| 84 |
+
9cdf4d1a6972db893c8db1a4f2be0d1ec0362ba22a44542402b336760029c87253830692 (88.0MB)
|
| 85 |
+
b6aed90c79d180c5346994f8e7d0657b3d8a9aab002c057503736b4013a2096b.part (0.0B)
|
| 86 |
+
ba47b9680dc949322877399218d1f210a057249803bc70addfb9528152e4b1662004000729 (218.5MB)
|
| 87 |
+
ca49e0b3f3400f38519a1103b2a567db32c9fa990a7395b1024b94454601479b.part (0.0B)
|
| 88 |
+
d66a5b3267a7935b8ff272bcc166a8f43a8d66fb89c59503d536ac87661a02022501429466 (0.0B)
|
| 89 |
+
d66a5b3267a7935b8ff272bcc166a8f43a8d66fb89c59503d536ac87661a020230475132 (0.0B)
|
| 90 |
+
d66a5b3267a7935b8ff272bcc166a8f43a8d66fb89c59503d536ac87661a0202373225118 (62.5KB)
|
| 91 |
+
e5a3eb3e2d0c47d6f014e294ef7398bf26375920c8d2af80fd65e255396dcc78.part (0.0B)
|
| 92 |
+
f19cacf3a9f9a57abdcafc4a6d242aa9c6fa48188ad0a394b1a2558cb8ab4dc5372340294 (199.2MB)
|
| 93 |
+
20251021T152133.441099492.log (1.4KB)
|
| 94 |
+
01/
|
| 95 |
+
02/
|
| 96 |
+
03/
|
| 97 |
+
05/
|
| 98 |
+
06/
|
| 99 |
+
07/
|
| 100 |
+
09/
|
| 101 |
+
0b/
|
| 102 |
+
0f/
|
| 103 |
+
10/
|
| 104 |
+
12/
|
| 105 |
+
15/
|
| 106 |
+
16/
|
| 107 |
+
19/
|
| 108 |
+
1d/
|
| 109 |
+
1e/
|
| 110 |
+
1f/
|
| 111 |
+
21/
|
| 112 |
+
22/
|
| 113 |
+
23/
|
| 114 |
+
24/
|
| 115 |
+
2a/
|
| 116 |
+
2b/
|
| 117 |
+
2c/
|
| 118 |
+
2d/
|
| 119 |
+
2f/
|
| 120 |
+
30/
|
| 121 |
+
32/
|
| 122 |
+
34/
|
| 123 |
+
37/
|
| 124 |
+
3b/
|
| 125 |
+
3d/
|
| 126 |
+
44/
|
| 127 |
+
45/
|
| 128 |
+
4a/
|
| 129 |
+
4f/
|
| 130 |
+
50/
|
| 131 |
+
52/
|
| 132 |
+
54/
|
| 133 |
+
56/
|
| 134 |
+
58/
|
| 135 |
+
5a/
|
| 136 |
+
5b/
|
| 137 |
+
60/
|
| 138 |
+
61/
|
| 139 |
+
64/
|
| 140 |
+
65/
|
| 141 |
+
67/
|
| 142 |
+
68/
|
| 143 |
+
69/
|
| 144 |
+
6b/
|
| 145 |
+
6d/
|
| 146 |
+
6e/
|
| 147 |
+
70/
|
| 148 |
+
75/
|
| 149 |
+
76/
|
| 150 |
+
7b/
|
| 151 |
+
7c/
|
| 152 |
+
80/
|
| 153 |
+
87/
|
| 154 |
+
88/
|
| 155 |
+
89/
|
| 156 |
+
8b/
|
| 157 |
+
8c/
|
| 158 |
+
90/
|
| 159 |
+
91/
|
| 160 |
+
93/
|
| 161 |
+
99/
|
| 162 |
+
9a/
|
| 163 |
+
9b/
|
| 164 |
+
9c/
|
| 165 |
+
9e/
|
| 166 |
+
9f/
|
| 167 |
+
a0/
|
| 168 |
+
a5/
|
| 169 |
+
a9/
|
| 170 |
+
ac/
|
| 171 |
+
ae/
|
| 172 |
+
b1/
|
| 173 |
+
b3/
|
| 174 |
+
b4/
|
| 175 |
+
b6/
|
| 176 |
+
ba/
|
| 177 |
+
bb/
|
| 178 |
+
bc/
|
| 179 |
+
bd/
|
| 180 |
+
be/
|
| 181 |
+
c0/
|
| 182 |
+
c1/
|
| 183 |
+
c2/
|
| 184 |
+
c4/
|
| 185 |
+
c6/
|
| 186 |
+
c7/
|
| 187 |
+
c8/
|
| 188 |
+
ca/
|
| 189 |
+
cb/
|
| 190 |
+
d6/
|
| 191 |
+
d9/
|
| 192 |
+
dd/
|
| 193 |
+
e2/
|
| 194 |
+
e5/
|
| 195 |
+
e7/
|
| 196 |
+
e8/
|
| 197 |
+
e9/
|
| 198 |
+
ee/
|
| 199 |
+
ef/
|
| 200 |
+
f1/
|
| 201 |
+
f3/
|
| 202 |
+
f4/
|
| 203 |
+
f5/
|
| 204 |
+
f6/
|
| 205 |
+
f7/
|
| 206 |
+
f8/
|
| 207 |
+
f9/
|
| 208 |
+
fc/
|
| 209 |
+
exclude (240.0B)
|
| 210 |
+
captioning/
|
| 211 |
+
classification/
|
| 212 |
+
generation/
|
| 213 |
+
reasoning/
|
| 214 |
+
vqa/
|
| 215 |
+
chinesefoodnet-10/
|
| 216 |
+
coco-goi/
|
| 217 |
+
coco-text/
|
| 218 |
+
imagenet/
|
| 219 |
+
iqa/
|
| 220 |
+
itm/
|
| 221 |
+
mocheg/
|
| 222 |
+
refcoco/
|
| 223 |
+
snli-ve/
|
| 224 |
+
ss/
|
| 225 |
+
vsr/
|
| 226 |
+
winoground/
|
| 227 |
+
.gitattributes (141.0B)
|
| 228 |
+
README.md (211.0B)
|
| 229 |
+
instructions.json (1.4KB)
|
| 230 |
+
labels.json (9.0KB)
|
| 231 |
+
test.jsonl (223.5MB)
|
| 232 |
+
train.jsonl (238.9MB)
|
| 233 |
+
val.jsonl (227.6MB)
|
| 234 |
+
README.md (31.0B)
|
| 235 |
+
esnlive_test.jsonl (743.0MB)
|
| 236 |
+
esnlive_train.jsonl (1000.8MB)
|
| 237 |
+
esnlive_val.jsonl (717.9MB)
|
| 238 |
+
instructions.json (1.9KB)
|
| 239 |
+
test_2023-10-09.jsonl (2.9GB)
|
| 240 |
+
train_2023-10-09.jsonl (3.9GB)
|
| 241 |
+
instructions.json (825.0B)
|
| 242 |
+
mapping.txt (30.9KB)
|
| 243 |
+
test_2023-10-08.jsonl (10.6GB)
|
| 244 |
+
train.jsonl (1.5GB)
|
| 245 |
+
train_2023-10-08.jsonl (5.9GB)
|
| 246 |
+
val.jsonl (2.6GB)
|
| 247 |
+
instructions.json (907.0B)
|
| 248 |
+
test.jsonl (330.4MB)
|
| 249 |
+
test_2023-10-09.jsonl (1.3GB)
|
| 250 |
+
train.jsonl (1.9GB)
|
| 251 |
+
train_2023-10-08.jsonl (7.8GB)
|
| 252 |
+
val.jsonl (330.8MB)
|
| 253 |
+
instructions.json (773.0B)
|
| 254 |
+
test.jsonl (730.0MB)
|
| 255 |
+
test_2023-10-09.jsonl (2.9GB)
|
| 256 |
+
train.jsonl (4.3GB)
|
| 257 |
+
train_2023-10-08.jsonl (17.1GB)
|
| 258 |
+
val.jsonl (730.2MB)
|
| 259 |
+
instructions.json (1.4KB)
|
| 260 |
+
test_2023-10-09.jsonl (553.7MB)
|
| 261 |
+
train_2023-10-09.jsonl (1.9GB)
|
| 262 |
+
vsr_test.jsonl (137.7MB)
|
| 263 |
+
vsr_train.jsonl (483.3MB)
|
| 264 |
+
vsr_val.jsonl (68.8MB)
|
| 265 |
+
instructions.json (774.0B)
|
| 266 |
+
test_2023-10-10.jsonl (7.6GB)
|
| 267 |
+
train.jsonl (8.2GB)
|
| 268 |
+
train_2023-10-08.jsonl (32.8GB)
|
| 269 |
+
val.jsonl (1.9GB)
|
| 270 |
+
instructions.json (733.0B)
|
| 271 |
+
test_2023-10-07.jsonl (279.1MB)
|
| 272 |
+
train.jsonl (2.0GB)
|
| 273 |
+
train_2023-10-06.jsonl (4.1GB)
|
| 274 |
+
val.jsonl (138.9MB)
|
| 275 |
+
instructions.json (2.0KB)
|
| 276 |
+
winoground_test.jsonl (245.5MB)
|
| 277 |
+
instructions.json (1.3KB)
|
| 278 |
+
test.jsonl (122.9MB)
|
| 279 |
+
instructions.json (1.0KB)
|
| 280 |
+
mocheg_test.jsonl (60.3MB)
|
| 281 |
+
mocheg_train.jsonl (631.7MB)
|
| 282 |
+
mocheg_val.jsonl (28.2MB)
|
| 283 |
+
test_2023-10-08.jsonl (242.5MB)
|
| 284 |
+
train_2023-10-08.jsonl (2.5GB)
|
| 285 |
+
instructions.json (1.5KB)
|
| 286 |
+
test.jsonl (701.9MB)
|
| 287 |
+
test_2023-10-08.jsonl (2.7GB)
|
| 288 |
+
train.jsonl (3.9GB)
|
| 289 |
+
train_2023-10-08.jsonl (15.6GB)
|
| 290 |
+
val.jsonl (667.7MB)
|
| 291 |
+
clevr/
|
| 292 |
+
nlvr/
|
| 293 |
+
science_qa/
|
| 294 |
+
vcr/
|
| 295 |
+
visual_mrc/
|
| 296 |
+
instructions.json (2.5KB)
|
| 297 |
+
science_qa_test.jsonl (174.0MB)
|
| 298 |
+
science_qa_train.jsonl (531.3MB)
|
| 299 |
+
science_qa_validation.jsonl (176.4MB)
|
| 300 |
+
instructions.json (976.0B)
|
| 301 |
+
train.jsonl (5.6GB)
|
| 302 |
+
train_2023-10-07.jsonl (11.1GB)
|
| 303 |
+
val.jsonl (379.6MB)
|
| 304 |
+
val_2023-10-07.jsonl (760.4MB)
|
| 305 |
+
instructions.json (911.0B)
|
| 306 |
+
test.jsonl (1.2GB)
|
| 307 |
+
train.jsonl (3.9GB)
|
| 308 |
+
val.jsonl (266.9MB)
|
| 309 |
+
instructions.json (1.3KB)
|
| 310 |
+
test.jsonl (909.3MB)
|
| 311 |
+
train.jsonl (4.3GB)
|
| 312 |
+
val.jsonl (992.9MB)
|
| 313 |
+
instructions.json (1.2KB)
|
| 314 |
+
test.jsonl (489.0MB)
|
| 315 |
+
train.jsonl (7.9GB)
|
| 316 |
+
val.jsonl (533.3MB)
|
| 317 |
+
mmchat/
|
| 318 |
+
multi30k/
|
| 319 |
+
vist/
|
| 320 |
+
visual_dialog/
|
| 321 |
+
instructions.json (818.0B)
|
| 322 |
+
test.jsonl (65.2MB)
|
| 323 |
+
test_2023-10-10.jsonl (262.2MB)
|
| 324 |
+
train.jsonl (3.2GB)
|
| 325 |
+
train_2023-10-09.jsonl (13.0GB)
|
| 326 |
+
val.jsonl (66.0MB)
|
| 327 |
+
instructions.json (1.2KB)
|
| 328 |
+
test.jsonl (610.6MB)
|
| 329 |
+
train.jsonl (4.4GB)
|
| 330 |
+
val.jsonl (301.1MB)
|
| 331 |
+
instructions.json (809.0B)
|
| 332 |
+
test.jsonl (2.3GB)
|
| 333 |
+
train.jsonl (6.2GB)
|
| 334 |
+
train_new.jsonl (6.2GB)
|
| 335 |
+
validation.jsonl (2.0GB)
|
| 336 |
+
instructions.json (1.0KB)
|
| 337 |
+
test.jsonl (14.0GB)
|
| 338 |
+
train.jsonl (15.4GB)
|
| 339 |
+
val.jsonl (13.0GB)
|
| 340 |
+
a-okvqa/
|
| 341 |
+
activitynet-qa/
|
| 342 |
+
docvqa/
|
| 343 |
+
fm-iqa/
|
| 344 |
+
gqa/
|
| 345 |
+
ivqa/
|
| 346 |
+
msrvtt-qa/
|
| 347 |
+
msvd-qa/
|
| 348 |
+
ocr-vqa/
|
| 349 |
+
okvqa/
|
| 350 |
+
shapes/
|
| 351 |
+
st-vqa/
|
| 352 |
+
text-vqa/
|
| 353 |
+
viquae/
|
| 354 |
+
vqav2/
|
| 355 |
+
instruction.json (905.0B)
|
| 356 |
+
train.jsonl (533.5MB)
|
| 357 |
+
train_new.jsonl (533.5MB)
|
| 358 |
+
validation.jsonl (228.3MB)
|
| 359 |
+
instructions.json (1.9KB)
|
| 360 |
+
train.jsonl (1.2GB)
|
| 361 |
+
train_v2.jsonl (1.2GB)
|
| 362 |
+
val.jsonl (77.7MB)
|
| 363 |
+
val_v2.jsonl (78.2MB)
|
| 364 |
+
instruction.json (905.0B)
|
| 365 |
+
test.jsonl (713.3MB)
|
| 366 |
+
train.jsonl (3.3GB)
|
| 367 |
+
validation_new.jsonl (529.5MB)
|
| 368 |
+
instruction.json (772.0B)
|
| 369 |
+
train.jsonl (1.5GB)
|
| 370 |
+
validation.jsonl (260.3MB)
|
| 371 |
+
instruction.json (853.0B)
|
| 372 |
+
test.jsonl (229.4MB)
|
| 373 |
+
train.jsonl (1.4GB)
|
| 374 |
+
README.md (288.0B)
|
| 375 |
+
instructions.json (1.2KB)
|
| 376 |
+
test.jsonl (132.4MB)
|
| 377 |
+
train.jsonl (343.1MB)
|
| 378 |
+
val.jsonl (60.9MB)
|
| 379 |
+
instructions.json (853.0B)
|
| 380 |
+
train.jsonl (1.9GB)
|
| 381 |
+
val.jsonl (1.9GB)
|
| 382 |
+
instructions.json (1.7KB)
|
| 383 |
+
train.jsonl (7.2GB)
|
| 384 |
+
val.jsonl (976.6MB)
|
| 385 |
+
instructions.json (1.5KB)
|
| 386 |
+
test.jsonl (1.4MB)
|
| 387 |
+
test_2023-10-08.jsonl (7.0MB)
|
| 388 |
+
train.large.jsonl (18.3MB)
|
| 389 |
+
train_2023-10-08.jsonl (92.6MB)
|
| 390 |
+
val.jsonl (1.4MB)
|
| 391 |
+
README.md (334.0B)
|
| 392 |
+
instructions.json (1.0KB)
|
| 393 |
+
test.jsonl (500.8MB)
|
| 394 |
+
train.jsonl (1.5GB)
|
| 395 |
+
val.jsonl (485.4MB)
|
| 396 |
+
README.md (434.0B)
|
| 397 |
+
instructions.json (1.0KB)
|
| 398 |
+
test.jsonl (348.1MB)
|
| 399 |
+
train.jsonl (757.5MB)
|
| 400 |
+
val.jsonl (58.0MB)
|
| 401 |
+
.gitattributes (141.0B)
|
| 402 |
+
README.md (332.0B)
|
| 403 |
+
instructions.json (1.4KB)
|
| 404 |
+
test.jsonl (474.7MB)
|
| 405 |
+
train.jsonl (2.1GB)
|
| 406 |
+
val.jsonl (1.1GB)
|
| 407 |
+
instructions.json (1.2KB)
|
| 408 |
+
train.jsonl (594.8MB)
|
| 409 |
+
train_v2.jsonl (596.3MB)
|
| 410 |
+
val.jsonl (334.3MB)
|
| 411 |
+
val_v2.jsonl (335.2MB)
|
| 412 |
+
instructions.json (802.0B)
|
| 413 |
+
para_train.jsonl (10.5GB)
|
| 414 |
+
para_val.jsonl (4.8GB)
|
| 415 |
+
train.jsonl (10.5GB)
|
| 416 |
+
val.jsonl (4.8GB)
|
| 417 |
+
instructions.json (1.2KB)
|
| 418 |
+
test.jsonl (122.5MB)
|
| 419 |
+
test_v2.jsonl (120.9MB)
|
| 420 |
+
train.jsonl (110.1MB)
|
| 421 |
+
train_v2.jsonl (110.2MB)
|
| 422 |
+
validation.jsonl (125.5MB)
|
| 423 |
+
validation_v2.jsonl (125.6MB)
|
| 424 |
+
coco/
|
| 425 |
+
coco-cn/
|
| 426 |
+
flickr8k-cn/
|
| 427 |
+
image_paragraph_captioning/
|
| 428 |
+
msrvtt/
|
| 429 |
+
textcap/
|
| 430 |
+
.gitattributes (141.0B)
|
| 431 |
+
README.md (490.0B)
|
| 432 |
+
instructions.json (1010.0B)
|
| 433 |
+
test.jsonl (117.1MB)
|
| 434 |
+
train.jsonl (231.1MB)
|
| 435 |
+
val.jsonl (116.9MB)
|
| 436 |
+
instructions.json (541.0B)
|
| 437 |
+
test.jsonl (49.4MB)
|
| 438 |
+
train.jsonl (300.0MB)
|
| 439 |
+
val.jsonl (49.9MB)
|
| 440 |
+
instructions.json (790.0B)
|
| 441 |
+
test.jsonl (66.4MB)
|
| 442 |
+
train.jsonl (1.2GB)
|
| 443 |
+
val.jsonl (65.0MB)
|
| 444 |
+
image_paragraph_captioning_test.jsonl (120.7MB)
|
| 445 |
+
image_paragraph_captioning_train.jsonl (701.2MB)
|
| 446 |
+
image_paragraph_captioning_val.jsonl (118.0MB)
|
| 447 |
+
instruction.json (1.4KB)
|
| 448 |
+
README.md (73.0B)
|
| 449 |
+
create_dataset.py (5.5KB)
|
| 450 |
+
instructions.json (882.0B)
|
| 451 |
+
test.jsonl (333.1MB)
|
| 452 |
+
train.jsonl (7.4GB)
|
| 453 |
+
val.jsonl (333.4MB)
|
| 454 |
+
instructions.json (1.1KB)
|
| 455 |
+
train.jsonl (5.7GB)
|
| 456 |
+
val.jsonl (851.3MB)
|
ICL/RL_DAPO/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
ICL/SFT_new/README.md
ADDED
|
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Qwen3-VL-8B Single-Step Decision SFT
|
| 2 |
+
|
| 3 |
+
## 项目结构
|
| 4 |
+
|
| 5 |
+
```
|
| 6 |
+
SFT_new/
|
| 7 |
+
├── build_sft.py # 数据构造 (SigLIP2 相似度选 shots, 单步决策格式)
|
| 8 |
+
├── generate_captions.py # VLM 批量 caption 生成 (替代短答案作为检索描述)
|
| 9 |
+
├── train.py # 训练主脚本 (DeepSpeed + Flash Attention 2)
|
| 10 |
+
├── ds_zero2.json # DeepSpeed ZeRO-2 配置 (推荐, 速度快)
|
| 11 |
+
├── ds_zero3.json # DeepSpeed ZeRO-3 配置 (备用, 更省显存)
|
| 12 |
+
├── run_single_node.sh # 单机启动脚本 (debug)
|
| 13 |
+
├── run_multi_node.sh # 多机训练入口 (每个 node 执行)
|
| 14 |
+
├── submit_northjob.sh # northjob 集群提交 (64卡)
|
| 15 |
+
├── launch_wrapper.py # northjob → bash 桥接
|
| 16 |
+
└── README.md # 本文件
|
| 17 |
+
```
|
| 18 |
+
|
| 19 |
+
---
|
| 20 |
+
|
| 21 |
+
## 整体 Pipeline
|
| 22 |
+
|
| 23 |
+
```
|
| 24 |
+
原始数据集 (jsonl + 图片)
|
| 25 |
+
│
|
| 26 |
+
▼
|
| 27 |
+
┌─────────────────┐
|
| 28 |
+
│ build_sft.py │ --build-cache ← 只跑一次, GPU
|
| 29 |
+
│ SigLIP2 编码 │ 生成 emb_cache/
|
| 30 |
+
└────────┬────────┘
|
| 31 |
+
│
|
| 32 |
+
▼
|
| 33 |
+
┌──────────────────────┐
|
| 34 |
+
│ generate_captions.py │ VLM API 批量生成 ← 只跑一次, 无需 GPU
|
| 35 |
+
│ 生成 caption_cache/ │ (vLLM 部署的 Qwen3-VL)
|
| 36 |
+
└────────┬─────────────┘
|
| 37 |
+
│
|
| 38 |
+
▼
|
| 39 |
+
┌─────────────────┐
|
| 40 |
+
│ build_sft.py │ 构造 SFT 数据 ← CPU, 可多进程并行
|
| 41 |
+
│ 读取 emb_cache │ 读取 caption_cache
|
| 42 |
+
│ + caption_cache │ 输出 sft.jsonl
|
| 43 |
+
└────────┬────────┘
|
| 44 |
+
│
|
| 45 |
+
▼
|
| 46 |
+
┌─────────────────┐
|
| 47 |
+
│ train.py │ DeepSpeed 训练
|
| 48 |
+
└─────────────────┘
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
---
|
| 52 |
+
|
| 53 |
+
## 1. 配环境
|
| 54 |
+
|
| 55 |
+
```bash
|
| 56 |
+
# 创建 conda 环境
|
| 57 |
+
conda create -n sft python=3.11 -y
|
| 58 |
+
conda activate sft
|
| 59 |
+
|
| 60 |
+
# PyTorch 2.4 + CUDA 12 (匹配 flash-attn whl)
|
| 61 |
+
pip install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124
|
| 62 |
+
|
| 63 |
+
# Flash Attention 2 (本地 whl, 先试 TRUE 版, 不行换 FALSE 版)
|
| 64 |
+
pip install /workspace/flash_attn-2.8.3+cu12torch2.4cxx11abiTRUE-cp311-cp311-linux_x86_64.whl
|
| 65 |
+
# 如果报 CXX11 ABI 不匹配:
|
| 66 |
+
# pip install /workspace/flash_attn-2.8.3+cu12torch2.4cxx11abiFALSE-cp311-cp311-linux_x86_64.whl
|
| 67 |
+
|
| 68 |
+
# 核心依赖
|
| 69 |
+
pip install transformers>=4.57.0
|
| 70 |
+
pip install accelerate>=1.13.0
|
| 71 |
+
pip install peft>=0.18.0
|
| 72 |
+
pip install deepspeed>=0.16.0
|
| 73 |
+
pip install qwen-vl-utils
|
| 74 |
+
pip install tqdm pillow
|
| 75 |
+
pip install openai # generate_captions.py 需要
|
| 76 |
+
|
| 77 |
+
# 验证安装
|
| 78 |
+
python -c "
|
| 79 |
+
import torch, transformers, deepspeed, flash_attn, peft
|
| 80 |
+
print(f'torch: {torch.__version__}')
|
| 81 |
+
print(f'transformers: {transformers.__version__}')
|
| 82 |
+
print(f'deepspeed: {deepspeed.__version__}')
|
| 83 |
+
print(f'flash_attn: {flash_attn.__version__}')
|
| 84 |
+
print(f'peft: {peft.__version__}')
|
| 85 |
+
print(f'CUDA: {torch.cuda.is_available()}, {torch.cuda.get_device_name(0)}')
|
| 86 |
+
from transformers import Qwen3VLForConditionalGeneration
|
| 87 |
+
print('Qwen3VL: OK')
|
| 88 |
+
"
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
**注意**: flash-attn whl 是针对 torch 2.4 编译的, 所以 PyTorch 必须装 2.4.x 版本.
|
| 92 |
+
|
| 93 |
+
---
|
| 94 |
+
|
| 95 |
+
## 2. 构造数据
|
| 96 |
+
|
| 97 |
+
### 2.1 构建 SigLIP embedding 缓存 (只跑一次, GPU)
|
| 98 |
+
|
| 99 |
+
```bash
|
| 100 |
+
conda activate sft
|
| 101 |
+
|
| 102 |
+
python /workspace/xiaobin/ICL/SFT_new/build_sft.py \
|
| 103 |
+
--build-cache \
|
| 104 |
+
--data-root /path/to/your/dataset \
|
| 105 |
+
--output-dir /workspace/xiaobin/ICL/SFT_new/output \
|
| 106 |
+
--siglip-model /workspace/models/siglip2-so400m-patch16-naflex \
|
| 107 |
+
--device cuda:0 \
|
| 108 |
+
--categories vqa,captioning,classification,reasoning
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
缓存保存在 `output/emb_cache/` 下, JSON 格式 (float16 base64), 可跨环境复用.
|
| 112 |
+
|
| 113 |
+
### 2.2 生成 VLM Caption (只跑一次, 调 API 无需本地 GPU)
|
| 114 |
+
|
| 115 |
+
**为什么需要这一步**: 很多 VQA 数据集的 answer 是短答案 ("yes", "3", "cab"), 不适合做语义检索的 query 描述. 用 VLM 给每张 pool 图片生成描述性 caption, 作为 `<RET>` 输出的 Description 和 context shot 的 Caption, 质量远好于原始 answer.
|
| 116 |
+
|
| 117 |
+
#### 启动 vLLM 服务 (NorthServe)
|
| 118 |
+
|
| 119 |
+
```bash
|
| 120 |
+
# 启动 Qwen3-VL-8B 推理服务(8 副本,每副本 1 卡)
|
| 121 |
+
HOME=/root /workspace/nex-agi/NorthServe/northserve launch \
|
| 122 |
+
--model-name qwen3vl8b-caption \
|
| 123 |
+
--served-model-name Qwen3-VL-8B-Instruct \
|
| 124 |
+
--namespace bg-agentic-coding \
|
| 125 |
+
--model-path /i_workspace/models/Qwen3-VL-8B-Instruct \
|
| 126 |
+
--volumes "i-xinsiyang-y4zy0sik0a:/i_workspace" \
|
| 127 |
+
--replicas 32 \
|
| 128 |
+
--gpus-per-pod 1 \
|
| 129 |
+
--pods-per-job 1 \
|
| 130 |
+
--profile generation \
|
| 131 |
+
--backend vllm \
|
| 132 |
+
--priority-class-name higher-priority-job \
|
| 133 |
+
--extra-cmds "--trust-remote-code --max-model-len 4096 --max-num-seqs 128" \
|
| 134 |
+
-y
|
| 135 |
+
|
| 136 |
+
# 验证(所有模型共用 http://10.51.6.110/v1,模型名在请求体里指定)
|
| 137 |
+
curl http://10.51.6.110/v1/models
|
| 138 |
+
```
|
| 139 |
+
|
| 140 |
+
#### 生成 caption (emb_cache 对齐版)
|
| 141 |
+
|
| 142 |
+
```bash
|
| 143 |
+
python /workspace/xiaobin/ICL/SFT_new/generate_captions.py \
|
| 144 |
+
--api-base http://10.51.6.110/v1 \
|
| 145 |
+
--model Qwen3-VL-8B-Instruct \
|
| 146 |
+
--emb-cache-dir /workspace/xiaobin/ICL/SFT_new/output/emb_cache \
|
| 147 |
+
--output-dir /workspace/xiaobin/ICL/SFT_new/output/caption_cache \
|
| 148 |
+
--num-workers 128 \
|
| 149 |
+
--prompt "Describe this image in one or two sentences. Focus on the main objects, their attributes, and spatial relationships."
|
| 150 |
+
```
|
| 151 |
+
|
| 152 |
+
#### 生成 caption (全量图片版, 按 split 分开保存)
|
| 153 |
+
|
| 154 |
+
```bash
|
| 155 |
+
# 全量跑 (~200 万张图)
|
| 156 |
+
python /workspace/xiaobin/ICL/SFT_new/generate_captions_all.py \
|
| 157 |
+
--api-base http://10.51.6.110/v1 \
|
| 158 |
+
--model Qwen3-VL-8B-Instruct \
|
| 159 |
+
--num-workers 128
|
| 160 |
+
|
| 161 |
+
# 只跑某个 category
|
| 162 |
+
python /workspace/xiaobin/ICL/SFT_new/generate_captions_all.py \
|
| 163 |
+
--api-base http://10.51.6.110/v1 \
|
| 164 |
+
--model Qwen3-VL-8B-Instruct \
|
| 165 |
+
--categories vqa \
|
| 166 |
+
--num-workers 128
|
| 167 |
+
```
|
| 168 |
+
|
| 169 |
+
输出到 `/workspace/xiaobin/dataset/detail/{category}/{dataset}/{split}/captions.json`
|
| 170 |
+
|
| 171 |
+
#### 停止服务
|
| 172 |
+
|
| 173 |
+
```bash
|
| 174 |
+
HOME=/root /workspace/nex-agi/NorthServe/northserve stop \
|
| 175 |
+
--model-name qwen3vl8b-caption
|
| 176 |
+
```
|
| 177 |
+
|
| 178 |
+
**关键特性**:
|
| 179 |
+
- **断点续传**: 已完成的文件自动跳过, 部分完成的只处理缺失图片
|
| 180 |
+
- **定期存盘**: 每 500 张自动保存 (防崩溃丢数据), `--save-every` 可调
|
| 181 |
+
- **并发请求**: `--num-workers 128`, 8 副本理论上限 1024, 不报错就往大了开
|
| 182 |
+
|
| 183 |
+
### 2.3 构建 SFT 数据集 (CPU, 不需要 GPU, 可多进程并行)
|
| 184 |
+
|
| 185 |
+
```bash
|
| 186 |
+
# 单进程
|
| 187 |
+
python /workspace/xiaobin/ICL/SFT_new/build_sft.py \
|
| 188 |
+
--data-root /path/to/your/dataset \
|
| 189 |
+
--output-dir /workspace/xiaobin/ICL/SFT_new/output \
|
| 190 |
+
--caption-cache-dir /workspace/xiaobin/ICL/SFT_new/output/caption_cache \
|
| 191 |
+
--samples-per-cat 20000 \
|
| 192 |
+
--max-shots 3 \
|
| 193 |
+
--answer-at-weights 3,3,2,1
|
| 194 |
+
|
| 195 |
+
# 多进程并行 (4 shards)
|
| 196 |
+
for i in 0 1 2 3; do
|
| 197 |
+
python /workspace/xiaobin/ICL/SFT_new/build_sft.py \
|
| 198 |
+
--data-root /path/to/your/dataset \
|
| 199 |
+
--output-dir /workspace/xiaobin/ICL/SFT_new/output \
|
| 200 |
+
--caption-cache-dir /workspace/xiaobin/ICL/SFT_new/output/caption_cache \
|
| 201 |
+
--shard-id $i --num-shards 4 &
|
| 202 |
+
done
|
| 203 |
+
wait
|
| 204 |
+
|
| 205 |
+
# 合并
|
| 206 |
+
python /workspace/xiaobin/ICL/SFT_new/build_sft.py \
|
| 207 |
+
--data-root /path/to/your/dataset \
|
| 208 |
+
--output-dir /workspace/xiaobin/ICL/SFT_new/output \
|
| 209 |
+
--merge --shuffle
|
| 210 |
+
```
|
| 211 |
+
|
| 212 |
+
**注意**: `--caption-cache-dir` 不传或目录不存在时行为和之前完全一致(用原始 answer)。正式训练前务必先跑 `generate_captions.py` 生成完整的 caption cache。
|
| 213 |
+
|
| 214 |
+
最终数据: `output/all/sft.jsonl`
|
| 215 |
+
|
| 216 |
+
**生成数据中的描述字段变化**:
|
| 217 |
+
```
|
| 218 |
+
# 之前 (用原始 answer, 短答案质量差)
|
| 219 |
+
{"from": "gpt", "value": "<RET>\nDescription: yes"}
|
| 220 |
+
{"from": "human", "value": "...<image>\nCaption: yes..."}
|
| 221 |
+
|
| 222 |
+
# 现在 (用 VLM 生成的描述, 适合语义检索)
|
| 223 |
+
{"from": "gpt", "value": "<RET>\nDescription: A woman cutting a large white cake in a kitchen."}
|
| 224 |
+
{"from": "human", "value": "...<image>\nCaption: A woman cutting a large white cake in a kitchen...."}
|
| 225 |
+
```
|
| 226 |
+
|
| 227 |
+
---
|
| 228 |
+
|
| 229 |
+
## 3. 训练
|
| 230 |
+
|
| 231 |
+
### 3.1 单机 debug (1 node x 8 H100)
|
| 232 |
+
|
| 233 |
+
```bash
|
| 234 |
+
conda activate sft
|
| 235 |
+
|
| 236 |
+
bash /workspace/xiaobin/ICL/SFT_new/run_single_node.sh \
|
| 237 |
+
/workspace/xiaobin/ICL/SFT_new/output/all/sft.jsonl \
|
| 238 |
+
8
|
| 239 |
+
```
|
| 240 |
+
|
| 241 |
+
可改 GPU 数快速 debug:
|
| 242 |
+
```bash
|
| 243 |
+
# 用 2 卡 debug
|
| 244 |
+
bash /workspace/xiaobin/ICL/SFT_new/run_single_node.sh /path/to/sft.jsonl 2
|
| 245 |
+
```
|
| 246 |
+
|
| 247 |
+
### 3.2 多机训练 (8 nodes x 8 GPUs = 64 H100)
|
| 248 |
+
|
| 249 |
+
**方式 A: northjob 提交 (推荐)**
|
| 250 |
+
|
| 251 |
+
先修改 `submit_northjob.sh` 里的 k8s 参数 (queue/namespace/pvc-name 改成你自己的), 然后:
|
| 252 |
+
|
| 253 |
+
```bash
|
| 254 |
+
bash /workspace/xiaobin/ICL/SFT_new/submit_northjob.sh 64 # 64卡
|
| 255 |
+
bash /workspace/xiaobin/ICL/SFT_new/submit_northjob.sh 32 # 32卡
|
| 256 |
+
```
|
| 257 |
+
|
| 258 |
+
**方式 B: 手动 torchrun (每个 node 上跑)**
|
| 259 |
+
|
| 260 |
+
```bash
|
| 261 |
+
# 在每个 node 上执行, 修改 --node_rank=0/1/2/.../7
|
| 262 |
+
torchrun \
|
| 263 |
+
--nproc_per_node=8 \
|
| 264 |
+
--nnodes=8 \
|
| 265 |
+
--node_rank=${NODE_RANK} \
|
| 266 |
+
--master_addr=${MASTER_ADDR} \
|
| 267 |
+
--master_port=29500 \
|
| 268 |
+
/workspace/xiaobin/ICL/SFT_new/train.py \
|
| 269 |
+
--model-path /workspace/models/Qwen3-VL-8B-Instruct \
|
| 270 |
+
--data-path /workspace/xiaobin/ICL/SFT_new/output/all/sft.jsonl \
|
| 271 |
+
--output-dir /workspace/xiaobin/ICL/SFT_new/output/qwen3vl_sft_64gpu \
|
| 272 |
+
--deepspeed /workspace/xiaobin/ICL/SFT_new/ds_zero2.json \
|
| 273 |
+
--num-epochs 3 \
|
| 274 |
+
--batch-size 1 \
|
| 275 |
+
--gradient-accumulation-steps 2 \
|
| 276 |
+
--learning-rate 2e-5
|
| 277 |
+
```
|
| 278 |
+
|
| 279 |
+
---
|
| 280 |
+
|
| 281 |
+
## 4. 训练策略说明
|
| 282 |
+
|
| 283 |
+
| 配置 | 单机 8 GPU (debug) | 64 GPU (正式) |
|
| 284 |
+
|------|-------------------|---------------|
|
| 285 |
+
| 并行 | DeepSpeed ZeRO-2 | DeepSpeed ZeRO-2 |
|
| 286 |
+
| micro_batch/GPU | 1 | 1 |
|
| 287 |
+
| grad_accum | 8 | 2 |
|
| 288 |
+
| **global_batch** | **64** | **128** |
|
| 289 |
+
| LR | 1e-5 | 2e-5 |
|
| 290 |
+
| Epochs | 3 | 3 |
|
| 291 |
+
| max_length | 4096 | 4096 |
|
| 292 |
+
| 精度 | BF16 | BF16 |
|
| 293 |
+
| Attention | Flash Attention 2 | Flash Attention 2 |
|
| 294 |
+
| Gradient ckpt | yes | yes |
|
| 295 |
+
| 训��方式 | Full fine-tuning | Full fine-tuning |
|
| 296 |
+
|
| 297 |
+
**为什么 ZeRO-2**: 8B 模型 BF16 约 16GB, H100 80GB 绰绰有余, ZeRO-2 比 ZeRO-3 快 30-40%.
|
| 298 |
+
|
| 299 |
+
**为什么 Full FT**: 任务需要学 `<RET>/<ANS>` 新 token + 新决策能力, LoRA 对 embedding 层学习有限. 加 `--use-lora` 可切换.
|
| 300 |
+
|
| 301 |
+
**Loss**: 只在 assistant turn 内容上计算, user turn 全部 mask (-100).
|
| 302 |
+
|
| 303 |
+
---
|
| 304 |
+
|
| 305 |
+
## 5. 关键参数调整
|
| 306 |
+
|
| 307 |
+
```bash
|
| 308 |
+
# 如果显存不够 → 降 max_pixels 或切 ZeRO-3
|
| 309 |
+
--max-pixels $((512*28*28)) # 减少图片分辨率
|
| 310 |
+
--deepspeed ds_zero3.json # 切 ZeRO-3
|
| 311 |
+
|
| 312 |
+
# 如果想用 LoRA (省显存, 快, 但效果可能差一点)
|
| 313 |
+
--use-lora --lora-rank 64 --lora-alpha 128
|
| 314 |
+
|
| 315 |
+
# 调整 n-shot 分布 (answer_at_weights)
|
| 316 |
+
--answer-at-weights 3,3,2,1 # 偏向少 shot (默认)
|
| 317 |
+
--answer-at-weights 1,1,1,1 # 均匀分布
|
| 318 |
+
--answer-at-weights 1,2,3,3 # 偏向多 shot
|
| 319 |
+
```
|
| 320 |
+
|
| 321 |
+
---
|
| 322 |
+
|
| 323 |
+
## 6. 输出目录结构
|
| 324 |
+
|
| 325 |
+
```
|
| 326 |
+
output/
|
| 327 |
+
├── emb_cache/ # SigLIP2 embedding 缓存
|
| 328 |
+
│ ├── vqa_vqav2.json
|
| 329 |
+
│ ├── vqa_okvqa.json
|
| 330 |
+
│ └── ...
|
| 331 |
+
├── caption_cache/ # VLM 生成的 caption 缓存
|
| 332 |
+
│ ├── vqa_vqav2.json
|
| 333 |
+
│ ├── vqa_okvqa.json
|
| 334 |
+
│ └── ...
|
| 335 |
+
├── vqa/
|
| 336 |
+
│ ├── sft.part00.jsonl # 分片
|
| 337 |
+
│ └── sft.jsonl # 合并后
|
| 338 |
+
├── captioning/
|
| 339 |
+
│ └── ...
|
| 340 |
+
├── classification/
|
| 341 |
+
│ └── ...
|
| 342 |
+
├── reasoning/
|
| 343 |
+
│ └── ...
|
| 344 |
+
└── all/
|
| 345 |
+
└── sft.jsonl # 全部合并 + shuffle, 训练用这个
|
| 346 |
+
```
|
| 347 |
+
|
| 348 |
+
---
|
| 349 |
+
|
| 350 |
+
## 7. 快速验证 (小规模测试)
|
| 351 |
+
|
| 352 |
+
```bash
|
| 353 |
+
# Step 1: 建 embedding cache
|
| 354 |
+
python build_sft.py --build-cache --data-root /path/to/data \
|
| 355 |
+
--categories vqa --device cuda:0
|
| 356 |
+
|
| 357 |
+
# Step 2: 生成 VLM caption (先小规模测试)
|
| 358 |
+
python generate_captions.py \
|
| 359 |
+
--api-base http://10.51.6.110/v1 \
|
| 360 |
+
--model Qwen3-VL-8B-Instruct \
|
| 361 |
+
--emb-cache-dir ./output/emb_cache \
|
| 362 |
+
--output-dir ./output/caption_cache \
|
| 363 |
+
--num-workers 128 --save-every 50
|
| 364 |
+
|
| 365 |
+
# Step 3: 检查 caption 质量
|
| 366 |
+
python -c "
|
| 367 |
+
import json
|
| 368 |
+
d = json.load(open('./output/caption_cache/vqa_vqav2.json'))
|
| 369 |
+
for k, v in list(d['items'].items())[:10]:
|
| 370 |
+
print(f'{k}\n → {v}\n')
|
| 371 |
+
"
|
| 372 |
+
|
| 373 |
+
# Step 4: 构造 SFT 数据 (100 条快速测试)
|
| 374 |
+
python build_sft.py --data-root /path/to/data \
|
| 375 |
+
--caption-cache-dir ./output/caption_cache \
|
| 376 |
+
--categories vqa --samples-per-cat 100
|
| 377 |
+
|
| 378 |
+
# Step 5: 检查生成结果
|
| 379 |
+
python -c "
|
| 380 |
+
import json
|
| 381 |
+
with open('./output/vqa/sft.part00.jsonl') as f:
|
| 382 |
+
for i, line in enumerate(f):
|
| 383 |
+
if i >= 5: break
|
| 384 |
+
r = json.loads(line)
|
| 385 |
+
for c in r['conversations']:
|
| 386 |
+
print(f'[{c[\"from\"]}] {c[\"value\"][:120]}')
|
| 387 |
+
print('---')
|
| 388 |
+
"
|
| 389 |
+
```
|
ICL/SFT_new/convert_and_eval.sh
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# =============================================================================
|
| 3 |
+
# DeepSpeed ZeRO checkpoint -> HuggingFace 格式转换 + 跑评测
|
| 4 |
+
#
|
| 5 |
+
# 用法:
|
| 6 |
+
# bash convert_and_eval.sh # 转换 epoch3_step1406,8卡评测
|
| 7 |
+
# bash convert_and_eval.sh final # 转换 final checkpoint
|
| 8 |
+
# bash convert_and_eval.sh epoch2_step937 # 转换指定 checkpoint
|
| 9 |
+
# NUM_GPUS=4 bash convert_and_eval.sh # 4卡评测
|
| 10 |
+
# SKIP_EVAL=1 bash convert_and_eval.sh # 只转换不评测
|
| 11 |
+
# =============================================================================
|
| 12 |
+
set -euo pipefail
|
| 13 |
+
|
| 14 |
+
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
| 15 |
+
|
| 16 |
+
# ---- 参数 ----
|
| 17 |
+
CKPT_TAG="${1:-epoch3_step1406}"
|
| 18 |
+
CKPT_DIR="/workspace/xiaobin/ICL/sft_model"
|
| 19 |
+
BASE_MODEL="/workspace/models/Qwen3-VL-8B-Instruct"
|
| 20 |
+
OUTPUT_DIR="${CKPT_DIR}/${CKPT_TAG}_fp32"
|
| 21 |
+
NUM_GPUS="${NUM_GPUS:-8}"
|
| 22 |
+
BATCH_SIZE="${BATCH_SIZE:-32}"
|
| 23 |
+
SKIP_EVAL="${SKIP_EVAL:-0}"
|
| 24 |
+
|
| 25 |
+
echo "============================================"
|
| 26 |
+
echo " Checkpoint: ${CKPT_TAG}"
|
| 27 |
+
echo " Source: ${CKPT_DIR}/${CKPT_TAG}"
|
| 28 |
+
echo " Output: ${OUTPUT_DIR}"
|
| 29 |
+
echo " Base model: ${BASE_MODEL}"
|
| 30 |
+
echo "============================================"
|
| 31 |
+
|
| 32 |
+
# ---- Step 1: 检查源 checkpoint 存在 ----
|
| 33 |
+
if [ ! -d "${CKPT_DIR}/${CKPT_TAG}" ]; then
|
| 34 |
+
echo "[ERROR] Checkpoint not found: ${CKPT_DIR}/${CKPT_TAG}"
|
| 35 |
+
echo "Available checkpoints:"
|
| 36 |
+
ls -d "${CKPT_DIR}"/epoch* "${CKPT_DIR}"/final 2>/dev/null || echo " (none)"
|
| 37 |
+
exit 1
|
| 38 |
+
fi
|
| 39 |
+
|
| 40 |
+
# ---- Step 2: 转换 DeepSpeed ZeRO -> fp32 ----
|
| 41 |
+
if [ -d "${OUTPUT_DIR}" ] && [ "$(ls -A "${OUTPUT_DIR}" 2>/dev/null)" ]; then
|
| 42 |
+
echo "[SKIP] ${OUTPUT_DIR} already exists, skipping conversion."
|
| 43 |
+
echo " Delete it if you want to re-convert."
|
| 44 |
+
else
|
| 45 |
+
echo "[1/3] Converting DeepSpeed ZeRO checkpoint to fp32..."
|
| 46 |
+
mkdir -p "${OUTPUT_DIR}"
|
| 47 |
+
python3 "${CKPT_DIR}/zero_to_fp32.py" \
|
| 48 |
+
"${CKPT_DIR}" \
|
| 49 |
+
"${OUTPUT_DIR}" \
|
| 50 |
+
--tag "${CKPT_TAG}" \
|
| 51 |
+
--safe_serialization
|
| 52 |
+
echo "Done."
|
| 53 |
+
fi
|
| 54 |
+
|
| 55 |
+
# ---- Step 3: 拷贝 config / tokenizer ----
|
| 56 |
+
echo "[2/3] Copying config & tokenizer from base model..."
|
| 57 |
+
FILES_TO_COPY=(
|
| 58 |
+
config.json
|
| 59 |
+
tokenizer.json
|
| 60 |
+
tokenizer_config.json
|
| 61 |
+
generation_config.json
|
| 62 |
+
preprocessor_config.json
|
| 63 |
+
video_preprocessor_config.json
|
| 64 |
+
special_tokens_map.json
|
| 65 |
+
chat_template.json
|
| 66 |
+
merges.txt
|
| 67 |
+
vocab.json
|
| 68 |
+
)
|
| 69 |
+
copied=0
|
| 70 |
+
for f in "${FILES_TO_COPY[@]}"; do
|
| 71 |
+
if [ -f "${BASE_MODEL}/${f}" ] && [ ! -f "${OUTPUT_DIR}/${f}" ]; then
|
| 72 |
+
cp "${BASE_MODEL}/${f}" "${OUTPUT_DIR}/"
|
| 73 |
+
copied=$((copied + 1))
|
| 74 |
+
fi
|
| 75 |
+
done
|
| 76 |
+
echo "Copied ${copied} files. Model ready at: ${OUTPUT_DIR}"
|
| 77 |
+
|
| 78 |
+
# ---- Step 4: 跑评测 ----
|
| 79 |
+
if [ "${SKIP_EVAL}" = "1" ]; then
|
| 80 |
+
echo "[3/3] SKIP_EVAL=1, skipping evaluation."
|
| 81 |
+
echo "To run eval manually:"
|
| 82 |
+
echo " MODEL_PATH=${OUTPUT_DIR} BATCH_SIZE=${BATCH_SIZE} bash ${SCRIPT_DIR}/run_eval.sh ${NUM_GPUS}"
|
| 83 |
+
exit 0
|
| 84 |
+
fi
|
| 85 |
+
|
| 86 |
+
echo "[3/3] Running evaluation (${NUM_GPUS} GPUs, batch_size=${BATCH_SIZE})..."
|
| 87 |
+
MODEL_PATH="${OUTPUT_DIR}" BATCH_SIZE="${BATCH_SIZE}" bash "${SCRIPT_DIR}/run_eval.sh" "${NUM_GPUS}"
|
ICL/SFT_new/ds_zero2.json
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bf16": {
|
| 3 |
+
"enabled": true
|
| 4 |
+
},
|
| 5 |
+
"zero_optimization": {
|
| 6 |
+
"stage": 2,
|
| 7 |
+
"overlap_comm": true,
|
| 8 |
+
"contiguous_gradients": true,
|
| 9 |
+
"reduce_scatter": true,
|
| 10 |
+
"reduce_bucket_size": 5e8,
|
| 11 |
+
"allgather_bucket_size": 5e8
|
| 12 |
+
},
|
| 13 |
+
"optimizer": {
|
| 14 |
+
"type": "AdamW",
|
| 15 |
+
"params": {
|
| 16 |
+
"lr": 1e-6,
|
| 17 |
+
"betas": [0.9, 0.999],
|
| 18 |
+
"eps": 1e-8,
|
| 19 |
+
"weight_decay": 0.1
|
| 20 |
+
}
|
| 21 |
+
},
|
| 22 |
+
"scheduler": {
|
| 23 |
+
"type": "WarmupDecayLR",
|
| 24 |
+
"params": {
|
| 25 |
+
"warmup_min_lr": 0,
|
| 26 |
+
"warmup_max_lr": 1e-6,
|
| 27 |
+
"warmup_num_steps": 50,
|
| 28 |
+
"total_num_steps": 950
|
| 29 |
+
}
|
| 30 |
+
},
|
| 31 |
+
"gradient_accumulation_steps": 4,
|
| 32 |
+
"gradient_clipping": 1.0,
|
| 33 |
+
"train_batch_size": 64,
|
| 34 |
+
"train_micro_batch_size_per_gpu": 2,
|
| 35 |
+
"wall_clock_breakdown": false,
|
| 36 |
+
"steps_per_print": 50
|
| 37 |
+
}
|
ICL/SFT_new/ds_zero3.json
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bf16": {
|
| 3 |
+
"enabled": true
|
| 4 |
+
},
|
| 5 |
+
"zero_optimization": {
|
| 6 |
+
"stage": 3,
|
| 7 |
+
"overlap_comm": true,
|
| 8 |
+
"contiguous_gradients": true,
|
| 9 |
+
"reduce_bucket_size": 5e8,
|
| 10 |
+
"stage3_prefetch_bucket_size": 5e8,
|
| 11 |
+
"stage3_param_persistence_threshold": 1e6,
|
| 12 |
+
"stage3_gather_16bit_weights_on_model_save": true
|
| 13 |
+
},
|
| 14 |
+
"optimizer": {
|
| 15 |
+
"type": "AdamW",
|
| 16 |
+
"params": {
|
| 17 |
+
"lr": 1e-5,
|
| 18 |
+
"betas": [0.9, 0.999],
|
| 19 |
+
"eps": 1e-8,
|
| 20 |
+
"weight_decay": 0.1
|
| 21 |
+
}
|
| 22 |
+
},
|
| 23 |
+
"gradient_accumulation_steps": 4,
|
| 24 |
+
"gradient_clipping": 1.0,
|
| 25 |
+
"train_micro_batch_size_per_gpu": 2,
|
| 26 |
+
"wall_clock_breakdown": false,
|
| 27 |
+
"steps_per_print": 50
|
| 28 |
+
}
|
ICL/SFT_new/eval.py
ADDED
|
@@ -0,0 +1,961 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
ICL 多轮推理评测脚本:模拟 RET/ANS 决策循环,验证 SFT 模型效果。
|
| 5 |
+
|
| 6 |
+
流程:
|
| 7 |
+
1. 从 source index 的 val split 加载原始记录(与训练集无重叠)
|
| 8 |
+
2. 给模型 query_image + question(0-shot)
|
| 9 |
+
3. 模型输出 <RET> → 从预计算 top5 取下一张 shot + caption,追加 context,再问
|
| 10 |
+
4. 模型输出 <ANS> → 提取答案,结束
|
| 11 |
+
5. 最多 max_rounds 轮(防止死循环 RET)
|
| 12 |
+
|
| 13 |
+
多卡策略:
|
| 14 |
+
- 每张 GPU 加载一份模型,按 dataset 粒度分配任务
|
| 15 |
+
- 只有 rank 0 打印进度日志(其他 rank 静默)
|
| 16 |
+
- 最后 rank 0 汇总并写出有序 JSON log
|
| 17 |
+
|
| 18 |
+
用法:
|
| 19 |
+
# 单卡 (debug)
|
| 20 |
+
python3 eval.py \\
|
| 21 |
+
--model-path /workspace/xiaobin/ICL/sft_model/merged_hf \\
|
| 22 |
+
--category vqa --dataset vqav2 --split val \\
|
| 23 |
+
--num-samples 20 --device cuda:0
|
| 24 |
+
|
| 25 |
+
# 多卡
|
| 26 |
+
torchrun --nproc_per_node=8 eval.py \\
|
| 27 |
+
--model-path /workspace/xiaobin/ICL/sft_model/merged_hf \\
|
| 28 |
+
--all-categories --split val --num-samples 200
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
import argparse
|
| 32 |
+
import json
|
| 33 |
+
import math
|
| 34 |
+
import os
|
| 35 |
+
import random
|
| 36 |
+
import re
|
| 37 |
+
import sys
|
| 38 |
+
import time
|
| 39 |
+
from collections import defaultdict
|
| 40 |
+
from pathlib import Path
|
| 41 |
+
from typing import Dict, List, Optional, Tuple
|
| 42 |
+
|
| 43 |
+
import torch
|
| 44 |
+
import torch.distributed as dist
|
| 45 |
+
|
| 46 |
+
# 绕过 transformers 对 torch<2.6 的 torch.load 安全检查 (CVE-2025-32434)
|
| 47 |
+
# 在 import transformers 之前 patch modeling_utils.load_state_dict
|
| 48 |
+
import transformers.utils.import_utils as _tu
|
| 49 |
+
if hasattr(_tu, "check_torch_load_is_safe"):
|
| 50 |
+
_tu.check_torch_load_is_safe = lambda: None
|
| 51 |
+
import transformers.modeling_utils as _mu
|
| 52 |
+
if hasattr(_mu, "check_torch_load_is_safe"):
|
| 53 |
+
_mu.check_torch_load_is_safe = lambda: None
|
| 54 |
+
# 直接 patch load_state_dict 里调用的那个
|
| 55 |
+
_orig_load_state_dict = getattr(_mu, "load_state_dict", None)
|
| 56 |
+
if _orig_load_state_dict is not None:
|
| 57 |
+
import functools
|
| 58 |
+
@functools.wraps(_orig_load_state_dict)
|
| 59 |
+
def _patched_load_state_dict(checkpoint_file, **kwargs):
|
| 60 |
+
# 直接用 torch.load 跳过安全检查
|
| 61 |
+
return torch.load(checkpoint_file, map_location="cpu", weights_only=False)
|
| 62 |
+
_mu.load_state_dict = _patched_load_state_dict
|
| 63 |
+
|
| 64 |
+
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
|
| 65 |
+
from qwen_vl_utils import process_vision_info
|
| 66 |
+
from tqdm import tqdm
|
| 67 |
+
|
| 68 |
+
# ---------------------------------------------------------------------------
|
| 69 |
+
# 默认路径
|
| 70 |
+
# ---------------------------------------------------------------------------
|
| 71 |
+
INDEX_ROOT = "/workspace/xiaobin/dataset/index"
|
| 72 |
+
EMBEDDINGS_DIR = "/workspace/xiaobin/dataset/embeddings"
|
| 73 |
+
CAPTION_CACHE_DIR = "/workspace/xiaobin/dataset/caption_cache"
|
| 74 |
+
|
| 75 |
+
# ---------------------------------------------------------------------------
|
| 76 |
+
# 分布式工具
|
| 77 |
+
# ---------------------------------------------------------------------------
|
| 78 |
+
|
| 79 |
+
def setup_distributed():
|
| 80 |
+
"""初始化分布式环境,返回 (rank, world_size, device)。"""
|
| 81 |
+
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
| 82 |
+
rank = int(os.environ["RANK"])
|
| 83 |
+
world_size = int(os.environ["WORLD_SIZE"])
|
| 84 |
+
local_rank = int(os.environ.get("LOCAL_RANK", rank))
|
| 85 |
+
dist.init_process_group("nccl")
|
| 86 |
+
torch.cuda.set_device(local_rank)
|
| 87 |
+
device = f"cuda:{local_rank}"
|
| 88 |
+
else:
|
| 89 |
+
rank, world_size = 0, 1
|
| 90 |
+
device = None
|
| 91 |
+
return rank, world_size, device
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def gather_results(local_results: List[Dict], rank: int, world_size: int) -> List[Dict]:
|
| 95 |
+
"""各 rank 结果汇总到 rank 0。"""
|
| 96 |
+
if world_size == 1:
|
| 97 |
+
return local_results
|
| 98 |
+
|
| 99 |
+
data = json.dumps(local_results, ensure_ascii=False).encode("utf-8")
|
| 100 |
+
size = torch.tensor([len(data)], dtype=torch.long, device=f"cuda:{rank}")
|
| 101 |
+
|
| 102 |
+
size_list = [torch.zeros(1, dtype=torch.long, device=f"cuda:{rank}") for _ in range(world_size)]
|
| 103 |
+
dist.all_gather(size_list, size)
|
| 104 |
+
max_size = max(s.item() for s in size_list)
|
| 105 |
+
|
| 106 |
+
padded = data + b"\x00" * (max_size - len(data))
|
| 107 |
+
tensor = torch.ByteTensor(list(padded)).cuda(rank)
|
| 108 |
+
tensor_list = [torch.zeros(max_size, dtype=torch.uint8, device=f"cuda:{rank}") for _ in range(world_size)]
|
| 109 |
+
dist.all_gather(tensor_list, tensor)
|
| 110 |
+
|
| 111 |
+
if rank == 0:
|
| 112 |
+
all_results = []
|
| 113 |
+
for t, s in zip(tensor_list, size_list):
|
| 114 |
+
raw = bytes(t[: s.item()].cpu().tolist())
|
| 115 |
+
all_results.extend(json.loads(raw.decode("utf-8")))
|
| 116 |
+
return all_results
|
| 117 |
+
return []
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def log(msg: str, rank: int = 0, force: bool = False):
|
| 121 |
+
"""只在 rank 0 或 force=True 时打印。"""
|
| 122 |
+
if rank == 0 or force:
|
| 123 |
+
print(msg, flush=True)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
# ---------------------------------------------------------------------------
|
| 127 |
+
# 数据加载
|
| 128 |
+
# ---------------------------------------------------------------------------
|
| 129 |
+
|
| 130 |
+
def load_records(cat: str, ds: str, split: str, limit: int = 0) -> List[Dict]:
|
| 131 |
+
"""从 index root 加载指定 split 的记录。"""
|
| 132 |
+
path = os.path.join(INDEX_ROOT, cat, ds, f"{split}.jsonl")
|
| 133 |
+
if not os.path.exists(path):
|
| 134 |
+
return []
|
| 135 |
+
records = []
|
| 136 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 137 |
+
for line in f:
|
| 138 |
+
line = line.strip()
|
| 139 |
+
if not line:
|
| 140 |
+
continue
|
| 141 |
+
r = json.loads(line)
|
| 142 |
+
if r.get("image") and r.get("answer"):
|
| 143 |
+
records.append(r)
|
| 144 |
+
if limit and len(records) >= limit:
|
| 145 |
+
break
|
| 146 |
+
return records
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def load_top5(cat: str, ds: str) -> Dict[str, List[str]]:
|
| 150 |
+
path = os.path.join(EMBEDDINGS_DIR, f"{cat}_{ds}_top5.json")
|
| 151 |
+
if not os.path.exists(path):
|
| 152 |
+
return {}
|
| 153 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 154 |
+
return json.load(f)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def load_caption_cache(cat: str, ds: str) -> Dict[str, str]:
|
| 158 |
+
path = os.path.join(CAPTION_CACHE_DIR, f"{cat}_{ds}.json")
|
| 159 |
+
if not os.path.exists(path):
|
| 160 |
+
return {}
|
| 161 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 162 |
+
data = json.load(f)
|
| 163 |
+
if isinstance(data, dict) and "items" in data:
|
| 164 |
+
return data["items"]
|
| 165 |
+
return data if isinstance(data, dict) else {}
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def load_instructions(cat: str, ds: str) -> List[str]:
|
| 169 |
+
path = os.path.join(INDEX_ROOT, cat, ds, "instructions.json")
|
| 170 |
+
if not os.path.exists(path):
|
| 171 |
+
return ["Look at the image and answer the question."]
|
| 172 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 173 |
+
data = json.load(f)
|
| 174 |
+
if isinstance(data, list):
|
| 175 |
+
return [str(x).strip() for x in data if str(x).strip()]
|
| 176 |
+
if isinstance(data, dict):
|
| 177 |
+
for key in ("instructions", "instruction", "prompts"):
|
| 178 |
+
v = data.get(key)
|
| 179 |
+
if isinstance(v, list):
|
| 180 |
+
return [str(x).strip() for x in v if str(x).strip()]
|
| 181 |
+
return ["Look at the image and answer the question."]
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def discover_datasets(categories: List[str]) -> List[Tuple[str, str]]:
|
| 185 |
+
results = []
|
| 186 |
+
for cat in sorted(os.listdir(INDEX_ROOT)):
|
| 187 |
+
if categories and cat not in categories:
|
| 188 |
+
continue
|
| 189 |
+
cat_dir = os.path.join(INDEX_ROOT, cat)
|
| 190 |
+
if not os.path.isdir(cat_dir):
|
| 191 |
+
continue
|
| 192 |
+
for ds in sorted(os.listdir(cat_dir)):
|
| 193 |
+
if os.path.isdir(os.path.join(cat_dir, ds)):
|
| 194 |
+
results.append((cat, ds))
|
| 195 |
+
return results
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
# ---------------------------------------------------------------------------
|
| 199 |
+
# 模型加载
|
| 200 |
+
# ---------------------------------------------------------------------------
|
| 201 |
+
|
| 202 |
+
def load_model(model_path: str, device: str):
|
| 203 |
+
from transformers import AutoConfig
|
| 204 |
+
|
| 205 |
+
processor = AutoProcessor.from_pretrained(
|
| 206 |
+
model_path,
|
| 207 |
+
trust_remote_code=True,
|
| 208 |
+
min_pixels=256 * 28 * 28,
|
| 209 |
+
max_pixels=1280 * 28 * 28,
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
# 先添加 special tokens 到 tokenizer,这样 vocab_size 对齐 checkpoint
|
| 213 |
+
special_tokens = ["<RET>", "<ANS>", "</ANS>", "<RETQ>", "</RETQ>"]
|
| 214 |
+
processor.tokenizer.add_tokens(special_tokens, special_tokens=True)
|
| 215 |
+
# batch 推理 decoder-only 模型必须左 padding
|
| 216 |
+
processor.tokenizer.padding_side = "left"
|
| 217 |
+
target_vocab_size = len(processor.tokenizer)
|
| 218 |
+
|
| 219 |
+
# 关键:把 config 的 vocab_size 改成 checkpoint 实际大小,
|
| 220 |
+
# 否则 ignore_mismatched_sizes 会导致 embed_tokens/lm_head 被随机初始化!
|
| 221 |
+
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
| 222 |
+
print(f"[load_model] text_config.vocab_size={config.text_config.vocab_size}, target={target_vocab_size}")
|
| 223 |
+
config.text_config.vocab_size = target_vocab_size
|
| 224 |
+
|
| 225 |
+
model = Qwen3VLForConditionalGeneration.from_pretrained(
|
| 226 |
+
model_path,
|
| 227 |
+
config=config,
|
| 228 |
+
trust_remote_code=True,
|
| 229 |
+
torch_dtype=torch.bfloat16,
|
| 230 |
+
attn_implementation="sdpa",
|
| 231 |
+
device_map=device,
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
model.eval()
|
| 235 |
+
|
| 236 |
+
ret_id = processor.tokenizer.convert_tokens_to_ids("<RET>")
|
| 237 |
+
ans_id = processor.tokenizer.convert_tokens_to_ids("<ANS>")
|
| 238 |
+
return model, processor, ret_id, ans_id
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
# ---------------------------------------------------------------------------
|
| 242 |
+
# 推理核心
|
| 243 |
+
# ---------------------------------------------------------------------------
|
| 244 |
+
|
| 245 |
+
def build_messages(
|
| 246 |
+
instruction: str,
|
| 247 |
+
query_image: str,
|
| 248 |
+
question: Optional[str],
|
| 249 |
+
shots: List[Dict],
|
| 250 |
+
min_pixels: int = 256 * 28 * 28,
|
| 251 |
+
max_pixels: int = 1280 * 28 * 28,
|
| 252 |
+
) -> List[Dict]:
|
| 253 |
+
"""构建 Qwen3-VL chat messages。"""
|
| 254 |
+
user_content = []
|
| 255 |
+
|
| 256 |
+
if instruction:
|
| 257 |
+
user_content.append({"type": "text", "text": instruction})
|
| 258 |
+
|
| 259 |
+
user_content.append({
|
| 260 |
+
"type": "image",
|
| 261 |
+
"image": f"file://{query_image}",
|
| 262 |
+
"min_pixels": min_pixels,
|
| 263 |
+
"max_pixels": max_pixels,
|
| 264 |
+
})
|
| 265 |
+
|
| 266 |
+
if question:
|
| 267 |
+
user_content.append({"type": "text", "text": f"Question: {question}"})
|
| 268 |
+
|
| 269 |
+
for shot in shots:
|
| 270 |
+
user_content.append({
|
| 271 |
+
"type": "image",
|
| 272 |
+
"image": f"file://{shot['image']}",
|
| 273 |
+
"min_pixels": min_pixels,
|
| 274 |
+
"max_pixels": max_pixels,
|
| 275 |
+
})
|
| 276 |
+
if shot.get("caption"):
|
| 277 |
+
user_content.append({"type": "text", "text": f"Caption: {shot['caption']}"})
|
| 278 |
+
|
| 279 |
+
user_content.append({"type": "text", "text": "Action:"})
|
| 280 |
+
return [{"role": "user", "content": user_content}]
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
@torch.no_grad()
|
| 284 |
+
def generate_action(model, processor, messages: List[Dict], max_new_tokens: int = 256) -> str:
|
| 285 |
+
"""单条推理(fallback 用)。"""
|
| 286 |
+
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 287 |
+
|
| 288 |
+
image_inputs = None
|
| 289 |
+
try:
|
| 290 |
+
image_inputs, _ = process_vision_info(messages)
|
| 291 |
+
except Exception:
|
| 292 |
+
pass
|
| 293 |
+
|
| 294 |
+
inputs = processor(
|
| 295 |
+
text=[text],
|
| 296 |
+
images=image_inputs if image_inputs else None,
|
| 297 |
+
return_tensors="pt",
|
| 298 |
+
padding=False,
|
| 299 |
+
truncation=False,
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
device = next(model.parameters()).device
|
| 303 |
+
inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}
|
| 304 |
+
|
| 305 |
+
outputs = model.generate(
|
| 306 |
+
**inputs,
|
| 307 |
+
max_new_tokens=max_new_tokens,
|
| 308 |
+
do_sample=False,
|
| 309 |
+
temperature=None,
|
| 310 |
+
top_p=None,
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
input_len = inputs["input_ids"].shape[1]
|
| 314 |
+
generated = outputs[0][input_len:]
|
| 315 |
+
return processor.tokenizer.decode(generated, skip_special_tokens=False)
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
@torch.no_grad()
|
| 319 |
+
def generate_action_batch(
|
| 320 |
+
model, processor, messages_list: List[List[Dict]],
|
| 321 |
+
max_new_tokens: int = 256, batch_size: int = 4,
|
| 322 |
+
pbar=None,
|
| 323 |
+
) -> List[str]:
|
| 324 |
+
"""批量推理,按 batch_size 分批处理。每个 batch 完成后更新 pbar。"""
|
| 325 |
+
all_results = []
|
| 326 |
+
device = next(model.parameters()).device
|
| 327 |
+
|
| 328 |
+
for start in range(0, len(messages_list), batch_size):
|
| 329 |
+
batch_msgs = messages_list[start : start + batch_size]
|
| 330 |
+
|
| 331 |
+
texts = []
|
| 332 |
+
all_images_nested = [] # 嵌套 list: [[sample0 imgs], [sample1 imgs], ...]
|
| 333 |
+
has_any_image = False
|
| 334 |
+
for msgs in batch_msgs:
|
| 335 |
+
texts.append(processor.apply_chat_template(
|
| 336 |
+
msgs, tokenize=False, add_generation_prompt=True
|
| 337 |
+
))
|
| 338 |
+
try:
|
| 339 |
+
imgs, _ = process_vision_info(msgs)
|
| 340 |
+
if imgs:
|
| 341 |
+
all_images_nested.append(imgs)
|
| 342 |
+
has_any_image = True
|
| 343 |
+
else:
|
| 344 |
+
all_images_nested.append([])
|
| 345 |
+
except Exception:
|
| 346 |
+
all_images_nested.append([])
|
| 347 |
+
|
| 348 |
+
inputs = processor(
|
| 349 |
+
text=texts,
|
| 350 |
+
images=all_images_nested if has_any_image else None,
|
| 351 |
+
return_tensors="pt",
|
| 352 |
+
padding=True,
|
| 353 |
+
truncation=False,
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}
|
| 357 |
+
|
| 358 |
+
outputs = model.generate(
|
| 359 |
+
**inputs,
|
| 360 |
+
max_new_tokens=max_new_tokens,
|
| 361 |
+
do_sample=False,
|
| 362 |
+
temperature=None,
|
| 363 |
+
top_p=None,
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
# 解码每条(左 padding 时,所有样本的 padded 输入长度相同)
|
| 367 |
+
input_len = inputs["input_ids"].shape[1]
|
| 368 |
+
for i in range(len(batch_msgs)):
|
| 369 |
+
generated = outputs[i][input_len:]
|
| 370 |
+
text = processor.tokenizer.decode(generated, skip_special_tokens=False)
|
| 371 |
+
all_results.append(text)
|
| 372 |
+
|
| 373 |
+
# 每个 batch 完成后更新进度条
|
| 374 |
+
if pbar is not None:
|
| 375 |
+
pbar.set_postfix_str(f"batch {start // batch_size + 1}/{math.ceil(len(messages_list) / batch_size)}")
|
| 376 |
+
|
| 377 |
+
return all_results
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def parse_action(text: str) -> Tuple[str, str]:
|
| 381 |
+
"""解析模型输出,返回 (action, content)。"""
|
| 382 |
+
text = text.strip()
|
| 383 |
+
|
| 384 |
+
if text.startswith("<RET>"):
|
| 385 |
+
desc = text[len("<RET>"):].strip()
|
| 386 |
+
if desc.startswith("Description:"):
|
| 387 |
+
desc = desc[len("Description:"):].strip()
|
| 388 |
+
for tok in ["<|im_end|>", "</s>", "<|endoftext|>"]:
|
| 389 |
+
desc = desc.replace(tok, "").strip()
|
| 390 |
+
return "ret", desc
|
| 391 |
+
|
| 392 |
+
if text.startswith("<ANS>"):
|
| 393 |
+
ans = text[len("<ANS>"):]
|
| 394 |
+
end_idx = ans.find("</ANS>")
|
| 395 |
+
if end_idx != -1:
|
| 396 |
+
ans = ans[:end_idx]
|
| 397 |
+
else:
|
| 398 |
+
for tok in ["<|im_end|>", "</s>", "<|endoftext|>"]:
|
| 399 |
+
ans = ans.replace(tok, "").strip()
|
| 400 |
+
return "ans", ans.strip()
|
| 401 |
+
|
| 402 |
+
return "unknown", text
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
def run_icl_loop(
|
| 406 |
+
model,
|
| 407 |
+
processor,
|
| 408 |
+
record: Dict,
|
| 409 |
+
instruction: str,
|
| 410 |
+
top5: Dict[str, List[str]],
|
| 411 |
+
caption_cache: Dict[str, str],
|
| 412 |
+
max_rounds: int = 4,
|
| 413 |
+
) -> Dict:
|
| 414 |
+
"""对单条记录跑多轮 RET/ANS 循环(fallback 用)。"""
|
| 415 |
+
query_image = record["image"]
|
| 416 |
+
question = record.get("question", "")
|
| 417 |
+
gt_answer = record.get("answer", "")
|
| 418 |
+
|
| 419 |
+
shots = []
|
| 420 |
+
used_images = {query_image}
|
| 421 |
+
rounds = []
|
| 422 |
+
candidates = top5.get(query_image, [])
|
| 423 |
+
|
| 424 |
+
for round_idx in range(max_rounds):
|
| 425 |
+
messages = build_messages(instruction, query_image, question, shots)
|
| 426 |
+
raw_output = generate_action(model, processor, messages)
|
| 427 |
+
action, content = parse_action(raw_output)
|
| 428 |
+
|
| 429 |
+
rounds.append({
|
| 430 |
+
"round": round_idx,
|
| 431 |
+
"action": action,
|
| 432 |
+
"content": content,
|
| 433 |
+
"raw": raw_output[:300],
|
| 434 |
+
})
|
| 435 |
+
|
| 436 |
+
if action == "ans":
|
| 437 |
+
return {
|
| 438 |
+
"image": query_image,
|
| 439 |
+
"question": question,
|
| 440 |
+
"gt_answer": gt_answer,
|
| 441 |
+
"final_answer": content,
|
| 442 |
+
"num_rounds": round_idx + 1,
|
| 443 |
+
"terminated_by": "ans",
|
| 444 |
+
"rounds": rounds,
|
| 445 |
+
}
|
| 446 |
+
|
| 447 |
+
if action == "ret":
|
| 448 |
+
next_image = None
|
| 449 |
+
for c in candidates:
|
| 450 |
+
if c not in used_images:
|
| 451 |
+
next_image = c
|
| 452 |
+
break
|
| 453 |
+
|
| 454 |
+
if next_image is None:
|
| 455 |
+
return {
|
| 456 |
+
"image": query_image,
|
| 457 |
+
"question": question,
|
| 458 |
+
"gt_answer": gt_answer,
|
| 459 |
+
"final_answer": None,
|
| 460 |
+
"num_rounds": round_idx + 1,
|
| 461 |
+
"terminated_by": "no_more_shots",
|
| 462 |
+
"rounds": rounds,
|
| 463 |
+
}
|
| 464 |
+
|
| 465 |
+
cap = caption_cache.get(next_image, content)
|
| 466 |
+
shots.append({"image": next_image, "caption": cap})
|
| 467 |
+
used_images.add(next_image)
|
| 468 |
+
else:
|
| 469 |
+
return {
|
| 470 |
+
"image": query_image,
|
| 471 |
+
"question": question,
|
| 472 |
+
"gt_answer": gt_answer,
|
| 473 |
+
"final_answer": content,
|
| 474 |
+
"num_rounds": round_idx + 1,
|
| 475 |
+
"terminated_by": "unknown_action",
|
| 476 |
+
"rounds": rounds,
|
| 477 |
+
}
|
| 478 |
+
|
| 479 |
+
return {
|
| 480 |
+
"image": query_image,
|
| 481 |
+
"question": question,
|
| 482 |
+
"gt_answer": gt_answer,
|
| 483 |
+
"final_answer": None,
|
| 484 |
+
"num_rounds": max_rounds,
|
| 485 |
+
"terminated_by": "max_rounds",
|
| 486 |
+
"rounds": rounds,
|
| 487 |
+
}
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
def run_icl_batch(
|
| 491 |
+
model, processor,
|
| 492 |
+
records: List[Dict],
|
| 493 |
+
instructions: List[str],
|
| 494 |
+
top5: Dict[str, List[str]],
|
| 495 |
+
caption_cache: Dict[str, str],
|
| 496 |
+
max_rounds: int = 4,
|
| 497 |
+
batch_size: int = 4,
|
| 498 |
+
rank: int = 0,
|
| 499 |
+
ds_label: str = "",
|
| 500 |
+
) -> List[Dict]:
|
| 501 |
+
"""对一批记录做 round-parallel 的批量 ICL 推理。
|
| 502 |
+
|
| 503 |
+
Round 0: 所有样本 batch 推理
|
| 504 |
+
Round 1: RET 的样本加 shot 后 batch 推理
|
| 505 |
+
...直到全部完成或 max_rounds
|
| 506 |
+
"""
|
| 507 |
+
rng = random.Random(42)
|
| 508 |
+
|
| 509 |
+
# 初始化每条样本的状态
|
| 510 |
+
states = []
|
| 511 |
+
for rec in records:
|
| 512 |
+
states.append({
|
| 513 |
+
"record": rec,
|
| 514 |
+
"instruction": rng.choice(instructions),
|
| 515 |
+
"query_image": rec["image"],
|
| 516 |
+
"question": rec.get("question", ""),
|
| 517 |
+
"gt_answer": rec.get("answer", ""),
|
| 518 |
+
"shots": [],
|
| 519 |
+
"used_images": {rec["image"]},
|
| 520 |
+
"candidates": top5.get(rec["image"], []),
|
| 521 |
+
"rounds": [],
|
| 522 |
+
"done": False,
|
| 523 |
+
"result": None,
|
| 524 |
+
})
|
| 525 |
+
|
| 526 |
+
total = len(states)
|
| 527 |
+
pbar = tqdm(total=total, desc=f" {ds_label}", unit="done",
|
| 528 |
+
disable=(rank != 0))
|
| 529 |
+
|
| 530 |
+
for round_idx in range(max_rounds):
|
| 531 |
+
# 收集未完成的样本
|
| 532 |
+
active = [(i, s) for i, s in enumerate(states) if not s["done"]]
|
| 533 |
+
if not active:
|
| 534 |
+
break
|
| 535 |
+
|
| 536 |
+
n_active = len(active)
|
| 537 |
+
pbar.set_postfix(round=round_idx, active=n_active)
|
| 538 |
+
|
| 539 |
+
# 构建 messages
|
| 540 |
+
messages_list = []
|
| 541 |
+
active_indices = []
|
| 542 |
+
for i, s in active:
|
| 543 |
+
msgs = build_messages(
|
| 544 |
+
s["instruction"], s["query_image"], s["question"], s["shots"]
|
| 545 |
+
)
|
| 546 |
+
messages_list.append(msgs)
|
| 547 |
+
active_indices.append(i)
|
| 548 |
+
|
| 549 |
+
# 批量推理
|
| 550 |
+
try:
|
| 551 |
+
raw_outputs = generate_action_batch(
|
| 552 |
+
model, processor, messages_list,
|
| 553 |
+
batch_size=batch_size,
|
| 554 |
+
pbar=pbar,
|
| 555 |
+
)
|
| 556 |
+
except Exception as e:
|
| 557 |
+
# batch 推理失败时 fallback 到逐条
|
| 558 |
+
log(f" [WARN] Batch failed at round {round_idx}, falling back to single: {e}", rank)
|
| 559 |
+
raw_outputs = []
|
| 560 |
+
for msgs in messages_list:
|
| 561 |
+
try:
|
| 562 |
+
raw_outputs.append(generate_action(model, processor, msgs))
|
| 563 |
+
except Exception:
|
| 564 |
+
raw_outputs.append("")
|
| 565 |
+
|
| 566 |
+
# 解析结果、更新状态
|
| 567 |
+
newly_done = 0
|
| 568 |
+
for idx_in_batch, global_idx in enumerate(active_indices):
|
| 569 |
+
s = states[global_idx]
|
| 570 |
+
raw = raw_outputs[idx_in_batch]
|
| 571 |
+
action, content = parse_action(raw)
|
| 572 |
+
|
| 573 |
+
s["rounds"].append({
|
| 574 |
+
"round": round_idx,
|
| 575 |
+
"action": action,
|
| 576 |
+
"content": content,
|
| 577 |
+
"raw": raw[:300],
|
| 578 |
+
})
|
| 579 |
+
|
| 580 |
+
if action == "ans":
|
| 581 |
+
s["done"] = True
|
| 582 |
+
s["result"] = {
|
| 583 |
+
"image": s["query_image"],
|
| 584 |
+
"question": s["question"],
|
| 585 |
+
"gt_answer": s["gt_answer"],
|
| 586 |
+
"final_answer": content,
|
| 587 |
+
"num_rounds": round_idx + 1,
|
| 588 |
+
"terminated_by": "ans",
|
| 589 |
+
"rounds": s["rounds"],
|
| 590 |
+
}
|
| 591 |
+
newly_done += 1
|
| 592 |
+
elif action == "ret":
|
| 593 |
+
next_image = None
|
| 594 |
+
for c in s["candidates"]:
|
| 595 |
+
if c not in s["used_images"]:
|
| 596 |
+
next_image = c
|
| 597 |
+
break
|
| 598 |
+
|
| 599 |
+
if next_image is None:
|
| 600 |
+
s["done"] = True
|
| 601 |
+
s["result"] = {
|
| 602 |
+
"image": s["query_image"],
|
| 603 |
+
"question": s["question"],
|
| 604 |
+
"gt_answer": s["gt_answer"],
|
| 605 |
+
"final_answer": None,
|
| 606 |
+
"num_rounds": round_idx + 1,
|
| 607 |
+
"terminated_by": "no_more_shots",
|
| 608 |
+
"rounds": s["rounds"],
|
| 609 |
+
}
|
| 610 |
+
newly_done += 1
|
| 611 |
+
else:
|
| 612 |
+
cap = caption_cache.get(next_image, content)
|
| 613 |
+
s["shots"].append({"image": next_image, "caption": cap})
|
| 614 |
+
s["used_images"].add(next_image)
|
| 615 |
+
else:
|
| 616 |
+
s["done"] = True
|
| 617 |
+
s["result"] = {
|
| 618 |
+
"image": s["query_image"],
|
| 619 |
+
"question": s["question"],
|
| 620 |
+
"gt_answer": s["gt_answer"],
|
| 621 |
+
"final_answer": content,
|
| 622 |
+
"num_rounds": round_idx + 1,
|
| 623 |
+
"terminated_by": "unknown_action",
|
| 624 |
+
"rounds": s["rounds"],
|
| 625 |
+
}
|
| 626 |
+
newly_done += 1
|
| 627 |
+
|
| 628 |
+
pbar.update(newly_done)
|
| 629 |
+
|
| 630 |
+
n_active = sum(1 for s in states if not s["done"])
|
| 631 |
+
if rank == 0:
|
| 632 |
+
pbar.set_postfix(round=round_idx, active=n_active)
|
| 633 |
+
|
| 634 |
+
# 处理还没完成的(达到 max_rounds)
|
| 635 |
+
for s in states:
|
| 636 |
+
if not s["done"]:
|
| 637 |
+
s["result"] = {
|
| 638 |
+
"image": s["query_image"],
|
| 639 |
+
"question": s["question"],
|
| 640 |
+
"gt_answer": s["gt_answer"],
|
| 641 |
+
"final_answer": None,
|
| 642 |
+
"num_rounds": max_rounds,
|
| 643 |
+
"terminated_by": "max_rounds",
|
| 644 |
+
"rounds": s["rounds"],
|
| 645 |
+
}
|
| 646 |
+
pbar.update(1)
|
| 647 |
+
|
| 648 |
+
pbar.close()
|
| 649 |
+
return [s["result"] for s in states]
|
| 650 |
+
|
| 651 |
+
|
| 652 |
+
# ---------------------------------------------------------------------------
|
| 653 |
+
# 答案质量指标
|
| 654 |
+
# ---------------------------------------------------------------------------
|
| 655 |
+
|
| 656 |
+
def normalize_answer(s: str) -> str:
|
| 657 |
+
"""归一化答案用于比较。"""
|
| 658 |
+
s = s.lower().strip()
|
| 659 |
+
# 去标点
|
| 660 |
+
s = re.sub(r"[^\w\s]", "", s)
|
| 661 |
+
# 去多余空格
|
| 662 |
+
s = " ".join(s.split())
|
| 663 |
+
return s
|
| 664 |
+
|
| 665 |
+
|
| 666 |
+
def compute_metrics(results: List[Dict]) -> Dict:
|
| 667 |
+
"""计算答案质量指标。"""
|
| 668 |
+
answered = [r for r in results if r.get("final_answer") is not None]
|
| 669 |
+
if not answered:
|
| 670 |
+
return {"exact_match": 0.0, "contains_gt": 0.0, "answer_rate": 0.0}
|
| 671 |
+
|
| 672 |
+
em_count = 0
|
| 673 |
+
contains_count = 0
|
| 674 |
+
|
| 675 |
+
for r in answered:
|
| 676 |
+
pred = normalize_answer(r["final_answer"])
|
| 677 |
+
gt = normalize_answer(r["gt_answer"])
|
| 678 |
+
|
| 679 |
+
if pred == gt:
|
| 680 |
+
em_count += 1
|
| 681 |
+
if gt in pred or pred in gt:
|
| 682 |
+
contains_count += 1
|
| 683 |
+
|
| 684 |
+
n_total = len(results)
|
| 685 |
+
n_answered = len(answered)
|
| 686 |
+
|
| 687 |
+
return {
|
| 688 |
+
"exact_match": em_count / n_answered * 100 if n_answered else 0.0,
|
| 689 |
+
"contains_gt": contains_count / n_answered * 100 if n_answered else 0.0,
|
| 690 |
+
"answer_rate": n_answered / n_total * 100 if n_total else 0.0,
|
| 691 |
+
"shot_distribution": compute_shot_distribution(results),
|
| 692 |
+
"avg_shots": compute_avg_shots(results),
|
| 693 |
+
}
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
def compute_shot_distribution(results: List[Dict]) -> Dict[str, int]:
|
| 697 |
+
"""统计 shot 数量分布。"""
|
| 698 |
+
shot_counts = defaultdict(int)
|
| 699 |
+
for r in results:
|
| 700 |
+
if r.get("terminated_by") == "ans":
|
| 701 |
+
n_shots = r["num_rounds"] - 1
|
| 702 |
+
else:
|
| 703 |
+
n_shots = r["num_rounds"]
|
| 704 |
+
shot_counts[f"{n_shots}-shot"] += 1
|
| 705 |
+
return dict(sorted(shot_counts.items()))
|
| 706 |
+
|
| 707 |
+
|
| 708 |
+
def compute_avg_shots(results: List[Dict]) -> float:
|
| 709 |
+
if not results:
|
| 710 |
+
return 0.0
|
| 711 |
+
total = 0
|
| 712 |
+
for r in results:
|
| 713 |
+
if r.get("terminated_by") == "ans":
|
| 714 |
+
total += r["num_rounds"] - 1
|
| 715 |
+
else:
|
| 716 |
+
total += r["num_rounds"]
|
| 717 |
+
return total / len(results)
|
| 718 |
+
|
| 719 |
+
|
| 720 |
+
# ---------------------------------------------------------------------------
|
| 721 |
+
# 统计输出
|
| 722 |
+
# ---------------------------------------------------------------------------
|
| 723 |
+
|
| 724 |
+
def print_stats(results: List[Dict], cat: str = "", ds: str = ""):
|
| 725 |
+
prefix = f"[{cat}/{ds}]" if ds else f"[{cat}]" if cat else "[ALL]"
|
| 726 |
+
n = len(results)
|
| 727 |
+
if n == 0:
|
| 728 |
+
print(f"{prefix} 无结果")
|
| 729 |
+
return
|
| 730 |
+
|
| 731 |
+
# 终止原因
|
| 732 |
+
term_counts = defaultdict(int)
|
| 733 |
+
for r in results:
|
| 734 |
+
term_counts[r["terminated_by"]] += 1
|
| 735 |
+
|
| 736 |
+
# 每轮 action 分布
|
| 737 |
+
round_actions = defaultdict(lambda: defaultdict(int))
|
| 738 |
+
for r in results:
|
| 739 |
+
for rd in r["rounds"]:
|
| 740 |
+
round_actions[rd["round"]][rd["action"]] += 1
|
| 741 |
+
|
| 742 |
+
avg_rounds = sum(r["num_rounds"] for r in results) / n
|
| 743 |
+
|
| 744 |
+
# 答案质量
|
| 745 |
+
metrics = compute_metrics(results)
|
| 746 |
+
|
| 747 |
+
print(f"\n{'=' * 64}")
|
| 748 |
+
print(f"{prefix} 共 {n} 条样本")
|
| 749 |
+
print(f" 平均轮次: {avg_rounds:.2f}")
|
| 750 |
+
print(f" 终止原因:")
|
| 751 |
+
for k, v in sorted(term_counts.items()):
|
| 752 |
+
print(f" {k}: {v} ({v / n * 100:.1f}%)")
|
| 753 |
+
|
| 754 |
+
print(f" 每轮 RET/ANS 分布:")
|
| 755 |
+
for rd_idx in sorted(round_actions.keys()):
|
| 756 |
+
actions = round_actions[rd_idx]
|
| 757 |
+
total = sum(actions.values())
|
| 758 |
+
parts = [f"{a}={c}({c / total * 100:.0f}%)" for a, c in sorted(actions.items())]
|
| 759 |
+
print(f" Round {rd_idx}: {' | '.join(parts)} (共 {total} 条)")
|
| 760 |
+
|
| 761 |
+
# Shot 数量统计(num_rounds - 1 = 回答前检索了几个 shot)
|
| 762 |
+
shot_counts = defaultdict(int)
|
| 763 |
+
for r in results:
|
| 764 |
+
if r["terminated_by"] == "ans":
|
| 765 |
+
n_shots = r["num_rounds"] - 1 # RET 次数 = 回答时已有的 shot 数
|
| 766 |
+
else:
|
| 767 |
+
n_shots = r["num_rounds"] # 没回答的,全是 RET
|
| 768 |
+
shot_counts[n_shots] += 1
|
| 769 |
+
|
| 770 |
+
print(f" Shot 数量分布 (回答时已有的 shot 数):")
|
| 771 |
+
for k in sorted(shot_counts.keys()):
|
| 772 |
+
v = shot_counts[k]
|
| 773 |
+
bar = "█" * int(v / n * 40)
|
| 774 |
+
print(f" {k}-shot: {v:4d} ({v / n * 100:5.1f}%) {bar}")
|
| 775 |
+
avg_shots = sum(k * v for k, v in shot_counts.items()) / n
|
| 776 |
+
print(f" 平均 shot 数: {avg_shots:.2f}")
|
| 777 |
+
|
| 778 |
+
answered = [r for r in results if r["final_answer"] is not None]
|
| 779 |
+
print(f" 产出答案: {len(answered)}/{n} ({metrics['answer_rate']:.1f}%)")
|
| 780 |
+
if answered:
|
| 781 |
+
print(f" 答案质量 (仅 ans 样本):")
|
| 782 |
+
print(f" Exact Match: {metrics['exact_match']:.1f}%")
|
| 783 |
+
print(f" Contains GT: {metrics['contains_gt']:.1f}%")
|
| 784 |
+
print(f"{'=' * 64}")
|
| 785 |
+
|
| 786 |
+
|
| 787 |
+
# ---------------------------------------------------------------------------
|
| 788 |
+
# Main
|
| 789 |
+
# ---------------------------------------------------------------------------
|
| 790 |
+
|
| 791 |
+
def main():
|
| 792 |
+
parser = argparse.ArgumentParser(description="ICL 多轮推理评测(支持多卡,log 对齐)")
|
| 793 |
+
parser.add_argument("--model-path", required=True, help="合并后的 HF 模型路径")
|
| 794 |
+
parser.add_argument("--category", type=str, default="")
|
| 795 |
+
parser.add_argument("--dataset", type=str, default="")
|
| 796 |
+
parser.add_argument("--split", type=str, default="val",
|
| 797 |
+
help="使用的数据 split(默认 val,与训练集 train 隔离)")
|
| 798 |
+
parser.add_argument("--all-categories", action="store_true")
|
| 799 |
+
parser.add_argument("--num-samples", type=int, default=100,
|
| 800 |
+
help="每个 dataset 采样数")
|
| 801 |
+
parser.add_argument("--max-rounds", type=int, default=4)
|
| 802 |
+
parser.add_argument("--batch-size", type=int, default=4,
|
| 803 |
+
help="每轮 batch 推理的样本数")
|
| 804 |
+
parser.add_argument("--device", type=str, default="cuda:0",
|
| 805 |
+
help="单卡时用的设备")
|
| 806 |
+
parser.add_argument("--output-dir", type=str,
|
| 807 |
+
default="/workspace/xiaobin/ICL/SFT_new/eval_results",
|
| 808 |
+
help="评测结果保存目录")
|
| 809 |
+
parser.add_argument("--seed", type=int, default=42)
|
| 810 |
+
args = parser.parse_args()
|
| 811 |
+
|
| 812 |
+
random.seed(args.seed)
|
| 813 |
+
|
| 814 |
+
# ---- 分布式初始化 ----
|
| 815 |
+
rank, world_size, dist_device = setup_distributed()
|
| 816 |
+
device = dist_device or args.device
|
| 817 |
+
is_main = rank == 0
|
| 818 |
+
|
| 819 |
+
log(f"World size: {world_size}", rank)
|
| 820 |
+
log(f"Model: {args.model_path}", rank)
|
| 821 |
+
log(f"Split: {args.split} (与训练集 train 隔离)", rank)
|
| 822 |
+
|
| 823 |
+
# ---- 加载模型 ----
|
| 824 |
+
model, processor, ret_id, ans_id = load_model(args.model_path, device)
|
| 825 |
+
log(f"Model loaded. <RET>={ret_id}, <ANS>={ans_id}", rank)
|
| 826 |
+
|
| 827 |
+
# ---- 确定 dataset 列表 ----
|
| 828 |
+
if args.all_categories:
|
| 829 |
+
categories = ["vqa", "captioning", "classification", "reasoning"]
|
| 830 |
+
elif args.category:
|
| 831 |
+
categories = [args.category]
|
| 832 |
+
else:
|
| 833 |
+
categories = ["vqa"]
|
| 834 |
+
|
| 835 |
+
if args.dataset:
|
| 836 |
+
ds_list = [(args.category or "vqa", args.dataset)]
|
| 837 |
+
else:
|
| 838 |
+
ds_list = discover_datasets(categories)
|
| 839 |
+
|
| 840 |
+
# ---- 按 rank 分配 dataset ----
|
| 841 |
+
my_ds_list = ds_list[rank::world_size]
|
| 842 |
+
log(f"共 {len(ds_list)} 个 dataset,rank {rank} 分到 {len(my_ds_list)} 个", rank)
|
| 843 |
+
|
| 844 |
+
local_results = []
|
| 845 |
+
t_start = time.time()
|
| 846 |
+
|
| 847 |
+
for ds_idx, (cat, ds) in enumerate(my_ds_list):
|
| 848 |
+
log(f"[{ds_idx + 1}/{len(my_ds_list)}] Evaluating {cat}/{ds} ({args.split})", rank)
|
| 849 |
+
|
| 850 |
+
records = load_records(cat, ds, args.split, limit=args.num_samples * 5)
|
| 851 |
+
if not records:
|
| 852 |
+
log(f" 跳过 {cat}/{ds}:无记录", rank)
|
| 853 |
+
continue
|
| 854 |
+
|
| 855 |
+
top5 = load_top5(cat, ds)
|
| 856 |
+
if not top5:
|
| 857 |
+
log(f" 跳过 {cat}/{ds}:无 top5 embedding", rank)
|
| 858 |
+
continue
|
| 859 |
+
|
| 860 |
+
caption_cache = load_caption_cache(cat, ds)
|
| 861 |
+
instructions = load_instructions(cat, ds)
|
| 862 |
+
|
| 863 |
+
# 过滤:需要 top5 覆盖
|
| 864 |
+
records = [r for r in records if r["image"] in top5]
|
| 865 |
+
if not records:
|
| 866 |
+
log(f" 跳过 {cat}/{ds}:val 图片无 top5 覆盖", rank)
|
| 867 |
+
continue
|
| 868 |
+
|
| 869 |
+
if len(records) > args.num_samples:
|
| 870 |
+
records = random.sample(records, args.num_samples)
|
| 871 |
+
log(f" {cat}/{ds}: {len(records)} 条, batch_size={args.batch_size}", rank)
|
| 872 |
+
|
| 873 |
+
ds_results = run_icl_batch(
|
| 874 |
+
model, processor, records, instructions, top5, caption_cache,
|
| 875 |
+
max_rounds=args.max_rounds,
|
| 876 |
+
batch_size=args.batch_size,
|
| 877 |
+
rank=rank,
|
| 878 |
+
ds_label=f"{cat}/{ds}",
|
| 879 |
+
)
|
| 880 |
+
for r in ds_results:
|
| 881 |
+
r["category"] = cat
|
| 882 |
+
r["dataset"] = ds
|
| 883 |
+
local_results.extend(ds_results)
|
| 884 |
+
|
| 885 |
+
elapsed = time.time() - t_start
|
| 886 |
+
log(f"\nrank {rank} 完成,{len(local_results)} 条,耗时 {elapsed:.1f}s", rank)
|
| 887 |
+
|
| 888 |
+
# ---- 汇总结果 ----
|
| 889 |
+
all_results = gather_results(local_results, rank, world_size)
|
| 890 |
+
|
| 891 |
+
if is_main:
|
| 892 |
+
# 排序:category → dataset → image
|
| 893 |
+
all_results.sort(key=lambda r: (r.get("category", ""), r.get("dataset", ""), r.get("image", "")))
|
| 894 |
+
|
| 895 |
+
# ---- 按 category / dataset 打印统计 ----
|
| 896 |
+
cat_results = defaultdict(list)
|
| 897 |
+
for r in all_results:
|
| 898 |
+
cat_results[r["category"]].append(r)
|
| 899 |
+
|
| 900 |
+
for cat in categories:
|
| 901 |
+
if not cat_results.get(cat):
|
| 902 |
+
continue
|
| 903 |
+
ds_groups = defaultdict(list)
|
| 904 |
+
for r in cat_results[cat]:
|
| 905 |
+
ds_groups[r["dataset"]].append(r)
|
| 906 |
+
for d in sorted(ds_groups):
|
| 907 |
+
print_stats(ds_groups[d], cat, d)
|
| 908 |
+
# category 汇总
|
| 909 |
+
if len(ds_groups) > 1:
|
| 910 |
+
print_stats(cat_results[cat], cat)
|
| 911 |
+
|
| 912 |
+
# 总汇总
|
| 913 |
+
if len(categories) > 1 or not args.dataset:
|
| 914 |
+
print_stats(all_results)
|
| 915 |
+
|
| 916 |
+
# ---- 保存 JSON log ----
|
| 917 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 918 |
+
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
| 919 |
+
output_path = os.path.join(args.output_dir, f"eval_{args.split}_{timestamp}.json")
|
| 920 |
+
|
| 921 |
+
# 构建 summary
|
| 922 |
+
summary = {
|
| 923 |
+
"model_path": args.model_path,
|
| 924 |
+
"split": args.split,
|
| 925 |
+
"num_samples_per_ds": args.num_samples,
|
| 926 |
+
"max_rounds": args.max_rounds,
|
| 927 |
+
"total_samples": len(all_results),
|
| 928 |
+
"world_size": world_size,
|
| 929 |
+
"elapsed_seconds": elapsed,
|
| 930 |
+
"metrics": {},
|
| 931 |
+
}
|
| 932 |
+
|
| 933 |
+
# 整体 metrics
|
| 934 |
+
summary["metrics"]["overall"] = compute_metrics(all_results)
|
| 935 |
+
|
| 936 |
+
# 按 category metrics
|
| 937 |
+
for cat in categories:
|
| 938 |
+
if cat_results.get(cat):
|
| 939 |
+
summary["metrics"][cat] = compute_metrics(cat_results[cat])
|
| 940 |
+
|
| 941 |
+
output_data = {
|
| 942 |
+
"summary": summary,
|
| 943 |
+
"results": all_results,
|
| 944 |
+
}
|
| 945 |
+
|
| 946 |
+
with open(output_path, "w", encoding="utf-8") as f:
|
| 947 |
+
json.dump(output_data, f, ensure_ascii=False, indent=2)
|
| 948 |
+
print(f"\n详细结果已保存到: {output_path}")
|
| 949 |
+
|
| 950 |
+
# 也保存一份不带时间戳的 latest
|
| 951 |
+
latest_path = os.path.join(args.output_dir, f"eval_{args.split}_latest.json")
|
| 952 |
+
with open(latest_path, "w", encoding="utf-8") as f:
|
| 953 |
+
json.dump(output_data, f, ensure_ascii=False, indent=2)
|
| 954 |
+
print(f"Latest 链接: {latest_path}")
|
| 955 |
+
|
| 956 |
+
if world_size > 1:
|
| 957 |
+
dist.destroy_process_group()
|
| 958 |
+
|
| 959 |
+
|
| 960 |
+
if __name__ == "__main__":
|
| 961 |
+
main()
|
ICL/SFT_new/launch_wrapper.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Wrapper for northjob: receives torchrun args, launches run_multi_node.sh."""
|
| 3 |
+
import subprocess
|
| 4 |
+
import sys
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
if __name__ == "__main__":
|
| 8 |
+
script_dir = os.path.dirname(os.path.abspath(__file__))
|
| 9 |
+
bash_script = os.path.join(script_dir, "run_multi_node.sh")
|
| 10 |
+
args = sys.argv[1:]
|
| 11 |
+
cmd = ["bash", bash_script] + args
|
| 12 |
+
result = subprocess.run(cmd, env=os.environ.copy())
|
| 13 |
+
sys.exit(result.returncode)
|
ICL/SFT_new/rebuild_and_train.sh
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# =============================================================================
|
| 3 |
+
# 一键:重建 SFT 数据 → 提交 16 卡训练任务
|
| 4 |
+
#
|
| 5 |
+
# 1. 用新配比 (answer_at_weights=1,3,3,2 + 去掉中间ANS) 重建数据
|
| 6 |
+
# 2. 通过 northjob 提交 16 GPU 训练
|
| 7 |
+
#
|
| 8 |
+
# Usage:
|
| 9 |
+
# bash rebuild_and_train.sh
|
| 10 |
+
# =============================================================================
|
| 11 |
+
set -euo pipefail
|
| 12 |
+
|
| 13 |
+
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
| 14 |
+
ICL_DIR="$(dirname "${SCRIPT_DIR}")"
|
| 15 |
+
PYTHON_BIN="/workspace/miniconda3/envs/sft/bin/python3"
|
| 16 |
+
|
| 17 |
+
BUILD_SCRIPT="${ICL_DIR}/build_sft.py"
|
| 18 |
+
SFT_OUTPUT="/workspace/xiaobin/dataset/sft"
|
| 19 |
+
SFT_DATA="${SFT_OUTPUT}/all/sft.jsonl"
|
| 20 |
+
|
| 21 |
+
echo "============================================"
|
| 22 |
+
echo "Step 1: 重建 SFT 数据集"
|
| 23 |
+
echo " 权重: 5,3,2,1 (多给 0-shot ANS,轨迹式无矛盾)"
|
| 24 |
+
echo " 轨迹式生成:同一输入只出现一种 action"
|
| 25 |
+
echo "============================================"
|
| 26 |
+
|
| 27 |
+
# 备份旧数据
|
| 28 |
+
if [ -f "${SFT_DATA}" ]; then
|
| 29 |
+
BACKUP="${SFT_DATA}.bak.$(date +%Y%m%d_%H%M%S)"
|
| 30 |
+
cp "${SFT_DATA}" "${BACKUP}"
|
| 31 |
+
echo "旧数据已备份: ${BACKUP}"
|
| 32 |
+
fi
|
| 33 |
+
|
| 34 |
+
# 重建数据(4 类,总量 ~6 万条 SFT 样本)
|
| 35 |
+
${PYTHON_BIN} "${BUILD_SCRIPT}" \
|
| 36 |
+
--answer-at-weights "5,3,2,1" \
|
| 37 |
+
--samples-per-cat 7800 \
|
| 38 |
+
--shuffle
|
| 39 |
+
|
| 40 |
+
echo ""
|
| 41 |
+
|
| 42 |
+
# 验证新数据
|
| 43 |
+
echo "============================================"
|
| 44 |
+
echo "Step 2: 验证新数据配比"
|
| 45 |
+
echo "============================================"
|
| 46 |
+
${PYTHON_BIN} -c "
|
| 47 |
+
import json
|
| 48 |
+
ret, ans = 0, 0
|
| 49 |
+
shot_ret, shot_ans = {}, {}
|
| 50 |
+
with open('${SFT_DATA}') as f:
|
| 51 |
+
for line in f:
|
| 52 |
+
r = json.loads(line)
|
| 53 |
+
n = len(r.get('shots', []))
|
| 54 |
+
if r['type'] == 'ret':
|
| 55 |
+
ret += 1
|
| 56 |
+
shot_ret[n] = shot_ret.get(n, 0) + 1
|
| 57 |
+
else:
|
| 58 |
+
ans += 1
|
| 59 |
+
shot_ans[n] = shot_ans.get(n, 0) + 1
|
| 60 |
+
total = ret + ans
|
| 61 |
+
print(f'总样本: {total}')
|
| 62 |
+
print(f'RET: {ret} ({ret/total*100:.1f}%)')
|
| 63 |
+
print(f'ANS: {ans} ({ans/total*100:.1f}%)')
|
| 64 |
+
print(f'RET/ANS 比: {ret/max(ans,1):.2f}:1')
|
| 65 |
+
print()
|
| 66 |
+
print('RET shot 分布:')
|
| 67 |
+
for k in sorted(shot_ret): print(f' {k}-shot: {shot_ret[k]}')
|
| 68 |
+
print('ANS shot 分布:')
|
| 69 |
+
for k in sorted(shot_ans): print(f' {k}-shot: {shot_ans[k]}')
|
| 70 |
+
r0 = shot_ret.get(0, 0); a0 = shot_ans.get(0, 0)
|
| 71 |
+
print(f'\n0-shot: RET={r0}({r0/(r0+a0)*100:.1f}%) ANS={a0}({a0/(r0+a0)*100:.1f}%)')
|
| 72 |
+
"
|
| 73 |
+
|
| 74 |
+
echo ""
|
| 75 |
+
echo "============================================"
|
| 76 |
+
echo "Step 3: 提交 16 卡训练任务"
|
| 77 |
+
echo "============================================"
|
| 78 |
+
|
| 79 |
+
bash "${SCRIPT_DIR}/submit_northjob.sh" 16
|
| 80 |
+
|
| 81 |
+
echo ""
|
| 82 |
+
echo "============================================"
|
| 83 |
+
echo "全部完成!"
|
| 84 |
+
echo " 数据: ${SFT_DATA}"
|
| 85 |
+
echo " 任务: 16 GPU via northjob"
|
| 86 |
+
echo "============================================"
|
ICL/SFT_new/run_eval.sh
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# =============================================================================
|
| 3 |
+
# ICL 评测启动脚本
|
| 4 |
+
#
|
| 5 |
+
# 默认:四类任务 (vqa, captioning, classification, reasoning) 各 500 条
|
| 6 |
+
#
|
| 7 |
+
# 用法:
|
| 8 |
+
# bash run_eval.sh # 单卡,四类各 500 条
|
| 9 |
+
# bash run_eval.sh 8 # 8 卡,四类各 500 条
|
| 10 |
+
# bash run_eval.sh 1 vqa vqav2 20 # 单卡,指定 dataset,20 条
|
| 11 |
+
#
|
| 12 |
+
# 环境变量:
|
| 13 |
+
# MODEL_PATH=... bash run_eval.sh # 指定模型路径
|
| 14 |
+
# BATCH_SIZE=8 bash run_eval.sh # 调大 batch
|
| 15 |
+
# =============================================================================
|
| 16 |
+
set -euo pipefail
|
| 17 |
+
|
| 18 |
+
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
| 19 |
+
|
| 20 |
+
# ---- 默认参数 ----
|
| 21 |
+
NUM_GPUS="${1:-1}"
|
| 22 |
+
CATEGORY="${2:-}"
|
| 23 |
+
DATASET="${3:-}"
|
| 24 |
+
NUM_SAMPLES="${4:-500}"
|
| 25 |
+
BATCH_SIZE="${BATCH_SIZE:-4}"
|
| 26 |
+
SPLIT="val"
|
| 27 |
+
MODEL_PATH="${MODEL_PATH:-/workspace/xiaobin/ICL/sft_model/epoch3_step1406_fp32}"
|
| 28 |
+
OUTPUT_DIR="${SCRIPT_DIR}/eval_results"
|
| 29 |
+
|
| 30 |
+
echo "============================================"
|
| 31 |
+
echo "ICL Evaluation"
|
| 32 |
+
echo " GPUs: ${NUM_GPUS}"
|
| 33 |
+
echo " Model: ${MODEL_PATH}"
|
| 34 |
+
echo " Split: ${SPLIT}"
|
| 35 |
+
echo " Batch size: ${BATCH_SIZE}"
|
| 36 |
+
echo " Samples/ds: ${NUM_SAMPLES}"
|
| 37 |
+
echo " Category: ${CATEGORY:-all (vqa,captioning,classification,reasoning)}"
|
| 38 |
+
echo " Dataset: ${DATASET:-all}"
|
| 39 |
+
echo " Output: ${OUTPUT_DIR}"
|
| 40 |
+
echo "============================================"
|
| 41 |
+
|
| 42 |
+
# ---- 构建参数 ----
|
| 43 |
+
EXTRA_ARGS=""
|
| 44 |
+
if [ -n "${CATEGORY}" ] && [ -n "${DATASET}" ]; then
|
| 45 |
+
EXTRA_ARGS="--category ${CATEGORY} --dataset ${DATASET}"
|
| 46 |
+
elif [ -n "${CATEGORY}" ]; then
|
| 47 |
+
EXTRA_ARGS="--category ${CATEGORY}"
|
| 48 |
+
else
|
| 49 |
+
EXTRA_ARGS="--all-categories"
|
| 50 |
+
fi
|
| 51 |
+
|
| 52 |
+
if [ "${NUM_GPUS}" -eq 1 ]; then
|
| 53 |
+
python3 "${SCRIPT_DIR}/eval.py" \
|
| 54 |
+
--model-path "${MODEL_PATH}" \
|
| 55 |
+
--split "${SPLIT}" \
|
| 56 |
+
--num-samples "${NUM_SAMPLES}" \
|
| 57 |
+
--batch-size "${BATCH_SIZE}" \
|
| 58 |
+
--max-rounds 4 \
|
| 59 |
+
--output-dir "${OUTPUT_DIR}" \
|
| 60 |
+
--device cuda:0 \
|
| 61 |
+
${EXTRA_ARGS}
|
| 62 |
+
else
|
| 63 |
+
torchrun \
|
| 64 |
+
--nproc_per_node="${NUM_GPUS}" \
|
| 65 |
+
--master_port=29501 \
|
| 66 |
+
"${SCRIPT_DIR}/eval.py" \
|
| 67 |
+
--model-path "${MODEL_PATH}" \
|
| 68 |
+
--split "${SPLIT}" \
|
| 69 |
+
--num-samples "${NUM_SAMPLES}" \
|
| 70 |
+
--batch-size "${BATCH_SIZE}" \
|
| 71 |
+
--max-rounds 4 \
|
| 72 |
+
--output-dir "${OUTPUT_DIR}" \
|
| 73 |
+
${EXTRA_ARGS}
|
| 74 |
+
fi
|
ICL/SFT_new/run_single_node.sh
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# =============================================================================
|
| 3 |
+
# Single-node training (1 machine, 8x H100)
|
| 4 |
+
# For debugging and quick iteration
|
| 5 |
+
#
|
| 6 |
+
# Usage:
|
| 7 |
+
# bash run_single_node.sh <data.jsonl> [num_gpus]
|
| 8 |
+
# bash run_single_node.sh /path/to/sft.jsonl 8
|
| 9 |
+
# bash run_single_node.sh /path/to/sft.jsonl 2 # quick debug
|
| 10 |
+
# =============================================================================
|
| 11 |
+
set -euo pipefail
|
| 12 |
+
|
| 13 |
+
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
| 14 |
+
|
| 15 |
+
# ---- Config ----
|
| 16 |
+
MODEL_PATH="/workspace/models/Qwen3-VL-8B-Instruct"
|
| 17 |
+
DATA_PATH="${1:?Usage: $0 <data.jsonl> [num_gpus]}"
|
| 18 |
+
NUM_GPUS="${2:-8}"
|
| 19 |
+
OUTPUT_DIR="/workspace/xiaobin/ICL/sft_model"
|
| 20 |
+
|
| 21 |
+
# ---- Env ----
|
| 22 |
+
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
| 23 |
+
export NCCL_P2P_DISABLE=0
|
| 24 |
+
export NCCL_IB_DISABLE=0
|
| 25 |
+
|
| 26 |
+
# ---- Launch ----
|
| 27 |
+
echo "============================================"
|
| 28 |
+
echo "Single-node SFT: ${NUM_GPUS} GPUs"
|
| 29 |
+
echo "Model: ${MODEL_PATH}"
|
| 30 |
+
echo "Data: ${DATA_PATH}"
|
| 31 |
+
echo "Output: ${OUTPUT_DIR}"
|
| 32 |
+
echo "============================================"
|
| 33 |
+
|
| 34 |
+
torchrun \
|
| 35 |
+
--nproc_per_node=${NUM_GPUS} \
|
| 36 |
+
--master_port=29500 \
|
| 37 |
+
${SCRIPT_DIR}/train.py \
|
| 38 |
+
--model-path ${MODEL_PATH} \
|
| 39 |
+
--data-path ${DATA_PATH} \
|
| 40 |
+
--output-dir ${OUTPUT_DIR} \
|
| 41 |
+
--deepspeed ${SCRIPT_DIR}/ds_zero2.json \
|
| 42 |
+
--num-epochs 3 \
|
| 43 |
+
--batch-size 2 \
|
| 44 |
+
--gradient-accumulation-steps 4 \
|
| 45 |
+
--learning-rate 1e-6 \
|
| 46 |
+
--max-length 32768 \
|
| 47 |
+
--gradient-checkpointing \
|
| 48 |
+
--log-interval 10 \
|
| 49 |
+
--save-interval 500
|
ICL/SFT_new/submit_northjob.sh
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# =============================================================================
|
| 3 |
+
# Submit multi-node job via northjob (16 GPUs = 2 nodes × 8 H100)
|
| 4 |
+
#
|
| 5 |
+
# Usage:
|
| 6 |
+
# bash submit_northjob.sh [num_gpus]
|
| 7 |
+
# bash submit_northjob.sh 16 # 2 nodes
|
| 8 |
+
# bash submit_northjob.sh 8 # 1 node (debug)
|
| 9 |
+
# =============================================================================
|
| 10 |
+
set -euo pipefail
|
| 11 |
+
|
| 12 |
+
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
| 13 |
+
GPU_NUMS="${1:-16}"
|
| 14 |
+
GPU_PER_NODE=8
|
| 15 |
+
NNODES=$((GPU_NUMS / GPU_PER_NODE))
|
| 16 |
+
|
| 17 |
+
JOB_NAME="qwen3vl-sft-${GPU_NUMS}gpu"
|
| 18 |
+
WORK_DIR="${SCRIPT_DIR}"
|
| 19 |
+
TRAIN_SCRIPT="${SCRIPT_DIR}/launch_wrapper.py"
|
| 20 |
+
|
| 21 |
+
echo "Submitting: ${JOB_NAME} (${NNODES} nodes × ${GPU_PER_NODE} GPUs)"
|
| 22 |
+
|
| 23 |
+
/workspace/miniconda3/envs/sft/bin/northjob \
|
| 24 |
+
create \
|
| 25 |
+
--job-type train \
|
| 26 |
+
--nproc-per-node ${GPU_PER_NODE} \
|
| 27 |
+
--gpu-per-node ${GPU_PER_NODE} \
|
| 28 |
+
--nnodes ${NNODES} \
|
| 29 |
+
--k8s-priority 3 \
|
| 30 |
+
--k8s-queue bg-agentic-coding \
|
| 31 |
+
--k8s-namespace bg-agentic-coding \
|
| 32 |
+
--k8s-pvc-name i-xinsiyang-y4zy0sik0a \
|
| 33 |
+
--k8s-pvc-mount-path /workspace \
|
| 34 |
+
--k8s-no-reclaim \
|
| 35 |
+
--k8s-images harbor.local.clusters/bp/megatron-bplm:25.03_fp8.ibgda.qwen3.next.fix_triton.fix_te.hf457.qwen3_vl \
|
| 36 |
+
--job-name ${JOB_NAME} \
|
| 37 |
+
--workspace ${WORK_DIR} \
|
| 38 |
+
${TRAIN_SCRIPT} ${GPU_PER_NODE}
|
ICL/SFT_new/train.py
ADDED
|
@@ -0,0 +1,659 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
Qwen3-VL-8B SFT Training Script (single-step RET/ANS decision).
|
| 5 |
+
|
| 6 |
+
Supports:
|
| 7 |
+
- Full fine-tuning or LoRA
|
| 8 |
+
- DeepSpeed ZeRO-2/3
|
| 9 |
+
- Multi-image conversations
|
| 10 |
+
- Loss masking on user turns only
|
| 11 |
+
- Flash Attention 2 on H100
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import argparse
|
| 15 |
+
import json
|
| 16 |
+
import logging
|
| 17 |
+
import math
|
| 18 |
+
import os
|
| 19 |
+
import sys
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from typing import Dict, List, Optional, Sequence
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
import torch.distributed as dist
|
| 25 |
+
from torch.utils.data import Dataset, DataLoader
|
| 26 |
+
|
| 27 |
+
from transformers import (
|
| 28 |
+
AutoProcessor,
|
| 29 |
+
Qwen3VLForConditionalGeneration,
|
| 30 |
+
get_cosine_schedule_with_warmup,
|
| 31 |
+
)
|
| 32 |
+
from peft import LoraConfig, get_peft_model, TaskType
|
| 33 |
+
from qwen_vl_utils import process_vision_info
|
| 34 |
+
|
| 35 |
+
try:
|
| 36 |
+
import deepspeed
|
| 37 |
+
HAS_DEEPSPEED = True
|
| 38 |
+
except ImportError:
|
| 39 |
+
HAS_DEEPSPEED = False
|
| 40 |
+
|
| 41 |
+
logging.basicConfig(
|
| 42 |
+
format="%(asctime)s [%(levelname)s] %(message)s",
|
| 43 |
+
level=logging.INFO,
|
| 44 |
+
)
|
| 45 |
+
logger = logging.getLogger(__name__)
|
| 46 |
+
|
| 47 |
+
# Special token IDs (Qwen3-VL)
|
| 48 |
+
IM_START_ID = 151644
|
| 49 |
+
IM_END_ID = 151645
|
| 50 |
+
IGNORE_INDEX = -100
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# ============================================================================
|
| 54 |
+
# Dataset
|
| 55 |
+
# ============================================================================
|
| 56 |
+
|
| 57 |
+
class SFTDataset(Dataset):
|
| 58 |
+
"""Load single-step SFT JSONL (轻量引用格式).
|
| 59 |
+
|
| 60 |
+
支持两种格式:
|
| 61 |
+
|
| 62 |
+
格式 A (新,轻量引用):
|
| 63 |
+
{
|
| 64 |
+
"type": "ret" | "ans",
|
| 65 |
+
"query_image": "/path/to/query.jpg",
|
| 66 |
+
"question": "What color?",
|
| 67 |
+
"answer": "black",
|
| 68 |
+
"instruction": "Answer the question...",
|
| 69 |
+
"shots": [{"image": "/path/shot.jpg", "caption": "A cat..."}],
|
| 70 |
+
"next_description": "A dog..." // 仅 ret 类型
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
格式 B (旧,conversations):
|
| 74 |
+
{
|
| 75 |
+
"images": ["path1.jpg", ...],
|
| 76 |
+
"conversations": [
|
| 77 |
+
{"from": "human", "value": "...<image>..."},
|
| 78 |
+
{"from": "gpt", "value": "<ANS>answer</ANS>"}
|
| 79 |
+
]
|
| 80 |
+
}
|
| 81 |
+
"""
|
| 82 |
+
|
| 83 |
+
def __init__(self, data_path: str, processor, max_length: int = 4096,
|
| 84 |
+
min_pixels: int = 256 * 28 * 28,
|
| 85 |
+
max_pixels: int = 1280 * 28 * 28):
|
| 86 |
+
self.processor = processor
|
| 87 |
+
self.max_length = max_length
|
| 88 |
+
self.min_pixels = min_pixels
|
| 89 |
+
self.max_pixels = max_pixels
|
| 90 |
+
self.records = []
|
| 91 |
+
|
| 92 |
+
logger.info(f"Loading data from {data_path}")
|
| 93 |
+
with open(data_path, "r", encoding="utf-8") as f:
|
| 94 |
+
for line in f:
|
| 95 |
+
line = line.strip()
|
| 96 |
+
if not line:
|
| 97 |
+
continue
|
| 98 |
+
try:
|
| 99 |
+
self.records.append(json.loads(line))
|
| 100 |
+
except Exception:
|
| 101 |
+
continue
|
| 102 |
+
logger.info(f"Loaded {len(self.records)} samples")
|
| 103 |
+
|
| 104 |
+
def __len__(self):
|
| 105 |
+
return len(self.records)
|
| 106 |
+
|
| 107 |
+
# ---- 新格式: 从引用字段动态构建 messages ----
|
| 108 |
+
|
| 109 |
+
def _build_messages_v2(self, record: Dict) -> List[Dict]:
|
| 110 |
+
"""从轻量引用格式构建 Qwen3-VL chat messages."""
|
| 111 |
+
user_content = []
|
| 112 |
+
|
| 113 |
+
# 1. instruction
|
| 114 |
+
inst = record.get("instruction", "")
|
| 115 |
+
if inst:
|
| 116 |
+
user_content.append({"type": "text", "text": inst})
|
| 117 |
+
|
| 118 |
+
# 2. query image
|
| 119 |
+
user_content.append({
|
| 120 |
+
"type": "image",
|
| 121 |
+
"image": f"file://{record['query_image']}",
|
| 122 |
+
"min_pixels": self.min_pixels,
|
| 123 |
+
"max_pixels": self.max_pixels,
|
| 124 |
+
})
|
| 125 |
+
|
| 126 |
+
# 3. question (可能为空,如 captioning 类)
|
| 127 |
+
question = record.get("question", "")
|
| 128 |
+
if question:
|
| 129 |
+
user_content.append({"type": "text", "text": f"Question: {question}"})
|
| 130 |
+
|
| 131 |
+
# 4. context shots (image + caption)
|
| 132 |
+
for shot in record.get("shots", []):
|
| 133 |
+
user_content.append({
|
| 134 |
+
"type": "image",
|
| 135 |
+
"image": f"file://{shot['image']}",
|
| 136 |
+
"min_pixels": self.min_pixels,
|
| 137 |
+
"max_pixels": self.max_pixels,
|
| 138 |
+
})
|
| 139 |
+
cap = shot.get("caption", "")
|
| 140 |
+
if cap:
|
| 141 |
+
user_content.append({"type": "text", "text": f"Caption: {cap}"})
|
| 142 |
+
|
| 143 |
+
# 5. Action prompt
|
| 144 |
+
user_content.append({"type": "text", "text": "Action:"})
|
| 145 |
+
|
| 146 |
+
# 6. assistant response
|
| 147 |
+
if record["type"] == "ret":
|
| 148 |
+
desc = record.get("next_description", "")
|
| 149 |
+
assistant_text = f"<RET>\nDescription: {desc}"
|
| 150 |
+
else:
|
| 151 |
+
assistant_text = f"<ANS>{record['answer']}</ANS>"
|
| 152 |
+
|
| 153 |
+
messages = [
|
| 154 |
+
{"role": "user", "content": user_content},
|
| 155 |
+
{"role": "assistant", "content": [{"type": "text", "text": assistant_text}]},
|
| 156 |
+
]
|
| 157 |
+
return messages
|
| 158 |
+
|
| 159 |
+
# ---- 旧格式: conversations + <image> 占位符 ----
|
| 160 |
+
|
| 161 |
+
def _build_messages_v1(self, record: Dict) -> List[Dict]:
|
| 162 |
+
"""Convert conversations format → Qwen3-VL chat messages."""
|
| 163 |
+
convs = record["conversations"]
|
| 164 |
+
image_paths = record.get("images", [])
|
| 165 |
+
messages = []
|
| 166 |
+
|
| 167 |
+
for turn in convs:
|
| 168 |
+
role = "user" if turn["from"] == "human" else "assistant"
|
| 169 |
+
text = turn["value"]
|
| 170 |
+
|
| 171 |
+
if role == "user":
|
| 172 |
+
content = []
|
| 173 |
+
parts = text.split("<image>")
|
| 174 |
+
img_idx = 0
|
| 175 |
+
for i, part in enumerate(parts):
|
| 176 |
+
if i > 0 and img_idx < len(image_paths):
|
| 177 |
+
content.append({
|
| 178 |
+
"type": "image",
|
| 179 |
+
"image": f"file://{image_paths[img_idx]}",
|
| 180 |
+
"min_pixels": self.min_pixels,
|
| 181 |
+
"max_pixels": self.max_pixels,
|
| 182 |
+
})
|
| 183 |
+
img_idx += 1
|
| 184 |
+
if part.strip():
|
| 185 |
+
content.append({"type": "text", "text": part.strip()})
|
| 186 |
+
messages.append({"role": role, "content": content})
|
| 187 |
+
else:
|
| 188 |
+
messages.append({
|
| 189 |
+
"role": role,
|
| 190 |
+
"content": [{"type": "text", "text": text}],
|
| 191 |
+
})
|
| 192 |
+
|
| 193 |
+
return messages
|
| 194 |
+
|
| 195 |
+
def __getitem__(self, idx):
|
| 196 |
+
record = self.records[idx]
|
| 197 |
+
|
| 198 |
+
# 自动检测格式
|
| 199 |
+
if "type" in record and "query_image" in record:
|
| 200 |
+
messages = self._build_messages_v2(record)
|
| 201 |
+
else:
|
| 202 |
+
messages = self._build_messages_v1(record)
|
| 203 |
+
|
| 204 |
+
# Apply chat template (no generation prompt for training)
|
| 205 |
+
text = self.processor.apply_chat_template(
|
| 206 |
+
messages, tokenize=False, add_generation_prompt=False
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
# Process images
|
| 210 |
+
image_inputs = None
|
| 211 |
+
try:
|
| 212 |
+
image_inputs, _ = process_vision_info(messages)
|
| 213 |
+
except Exception:
|
| 214 |
+
pass
|
| 215 |
+
|
| 216 |
+
# Tokenize — 不截断,避免图片 token 不匹配
|
| 217 |
+
inputs = self.processor(
|
| 218 |
+
text=[text],
|
| 219 |
+
images=image_inputs if image_inputs else None,
|
| 220 |
+
return_tensors="pt",
|
| 221 |
+
padding=False,
|
| 222 |
+
truncation=False,
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
# Squeeze batch dim
|
| 226 |
+
input_ids = inputs["input_ids"].squeeze(0)
|
| 227 |
+
attention_mask = inputs["attention_mask"].squeeze(0)
|
| 228 |
+
|
| 229 |
+
# 超长时截断文本部分(保留前 max_length 个 token)
|
| 230 |
+
if input_ids.shape[0] > self.max_length:
|
| 231 |
+
input_ids = input_ids[:self.max_length]
|
| 232 |
+
attention_mask = attention_mask[:self.max_length]
|
| 233 |
+
|
| 234 |
+
# Build labels: mask user turns, keep assistant turns
|
| 235 |
+
labels = self._build_labels(input_ids)
|
| 236 |
+
|
| 237 |
+
result = {
|
| 238 |
+
"input_ids": input_ids,
|
| 239 |
+
"attention_mask": attention_mask,
|
| 240 |
+
"labels": labels,
|
| 241 |
+
}
|
| 242 |
+
# Pass through pixel values if present
|
| 243 |
+
if "pixel_values" in inputs:
|
| 244 |
+
result["pixel_values"] = inputs["pixel_values"].squeeze(0) \
|
| 245 |
+
if inputs["pixel_values"].dim() > 3 else inputs["pixel_values"]
|
| 246 |
+
if "image_grid_thw" in inputs:
|
| 247 |
+
result["image_grid_thw"] = inputs["image_grid_thw"]
|
| 248 |
+
|
| 249 |
+
return result
|
| 250 |
+
|
| 251 |
+
def _build_labels(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| 252 |
+
"""Mask everything except assistant responses.
|
| 253 |
+
|
| 254 |
+
Strategy: find <|im_start|>assistant ... <|im_end|> spans,
|
| 255 |
+
only compute loss on tokens after 'assistant\n' until <|im_end|>.
|
| 256 |
+
"""
|
| 257 |
+
labels = torch.full_like(input_ids, IGNORE_INDEX)
|
| 258 |
+
ids = input_ids.tolist()
|
| 259 |
+
|
| 260 |
+
assist_tokens = self.processor.tokenizer.encode(
|
| 261 |
+
"assistant\n", add_special_tokens=False
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
i = 0
|
| 265 |
+
while i < len(ids):
|
| 266 |
+
if ids[i] == IM_START_ID:
|
| 267 |
+
start = i + 1
|
| 268 |
+
end = start + len(assist_tokens)
|
| 269 |
+
if end <= len(ids) and ids[start:end] == assist_tokens:
|
| 270 |
+
content_start = end
|
| 271 |
+
j = content_start
|
| 272 |
+
while j < len(ids) and ids[j] != IM_END_ID:
|
| 273 |
+
j += 1
|
| 274 |
+
labels[content_start:j + 1] = input_ids[content_start:j + 1]
|
| 275 |
+
i = j + 1
|
| 276 |
+
continue
|
| 277 |
+
i += 1
|
| 278 |
+
|
| 279 |
+
return labels
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
# ============================================================================
|
| 283 |
+
# Collator
|
| 284 |
+
# ============================================================================
|
| 285 |
+
|
| 286 |
+
class SFTCollator:
|
| 287 |
+
"""Pad variable-length samples into a batch."""
|
| 288 |
+
|
| 289 |
+
def __init__(self, pad_token_id: int, max_length: int = 4096):
|
| 290 |
+
self.pad_token_id = pad_token_id
|
| 291 |
+
self.max_length = max_length
|
| 292 |
+
|
| 293 |
+
def __call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]:
|
| 294 |
+
max_len = min(
|
| 295 |
+
max(f["input_ids"].size(0) for f in features),
|
| 296 |
+
self.max_length,
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
batch_input_ids = []
|
| 300 |
+
batch_attention_mask = []
|
| 301 |
+
batch_labels = []
|
| 302 |
+
batch_pixel_values = []
|
| 303 |
+
batch_image_grid_thw = []
|
| 304 |
+
|
| 305 |
+
for f in features:
|
| 306 |
+
ids = f["input_ids"][:max_len]
|
| 307 |
+
mask = f["attention_mask"][:max_len]
|
| 308 |
+
lab = f["labels"][:max_len]
|
| 309 |
+
pad_len = max_len - ids.size(0)
|
| 310 |
+
|
| 311 |
+
if pad_len > 0:
|
| 312 |
+
ids = torch.cat([ids, torch.full((pad_len,), self.pad_token_id, dtype=ids.dtype)])
|
| 313 |
+
mask = torch.cat([mask, torch.zeros(pad_len, dtype=mask.dtype)])
|
| 314 |
+
lab = torch.cat([lab, torch.full((pad_len,), IGNORE_INDEX, dtype=lab.dtype)])
|
| 315 |
+
|
| 316 |
+
batch_input_ids.append(ids)
|
| 317 |
+
batch_attention_mask.append(mask)
|
| 318 |
+
batch_labels.append(lab)
|
| 319 |
+
|
| 320 |
+
if "pixel_values" in f:
|
| 321 |
+
batch_pixel_values.append(f["pixel_values"])
|
| 322 |
+
if "image_grid_thw" in f:
|
| 323 |
+
batch_image_grid_thw.append(f["image_grid_thw"])
|
| 324 |
+
|
| 325 |
+
result = {
|
| 326 |
+
"input_ids": torch.stack(batch_input_ids),
|
| 327 |
+
"attention_mask": torch.stack(batch_attention_mask),
|
| 328 |
+
"labels": torch.stack(batch_labels),
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
if batch_pixel_values:
|
| 332 |
+
result["pixel_values"] = torch.cat(batch_pixel_values, dim=0)
|
| 333 |
+
if batch_image_grid_thw:
|
| 334 |
+
result["image_grid_thw"] = torch.cat(batch_image_grid_thw, dim=0)
|
| 335 |
+
|
| 336 |
+
return result
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
# ============================================================================
|
| 340 |
+
# Training
|
| 341 |
+
# ============================================================================
|
| 342 |
+
|
| 343 |
+
def train(args):
|
| 344 |
+
# ---- Distributed setup ----
|
| 345 |
+
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
| 346 |
+
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
| 347 |
+
rank = int(os.environ.get("RANK", 0))
|
| 348 |
+
|
| 349 |
+
if world_size > 1 and not dist.is_initialized():
|
| 350 |
+
dist.init_process_group("nccl")
|
| 351 |
+
|
| 352 |
+
torch.cuda.set_device(local_rank)
|
| 353 |
+
device = torch.device(f"cuda:{local_rank}")
|
| 354 |
+
is_main = rank == 0
|
| 355 |
+
|
| 356 |
+
if is_main:
|
| 357 |
+
logger.info(f"World size: {world_size}, Local rank: {local_rank}")
|
| 358 |
+
logger.info(f"Args: {vars(args)}")
|
| 359 |
+
|
| 360 |
+
# ---- Load processor & model ----
|
| 361 |
+
processor = AutoProcessor.from_pretrained(
|
| 362 |
+
args.model_path, trust_remote_code=True,
|
| 363 |
+
min_pixels=args.min_pixels, max_pixels=args.max_pixels,
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
model_kwargs = {
|
| 367 |
+
"trust_remote_code": True,
|
| 368 |
+
"torch_dtype": torch.bfloat16,
|
| 369 |
+
"attn_implementation": "flash_attention_2",
|
| 370 |
+
}
|
| 371 |
+
if not (HAS_DEEPSPEED and args.deepspeed):
|
| 372 |
+
model_kwargs["device_map"] = {"": device}
|
| 373 |
+
|
| 374 |
+
model = Qwen3VLForConditionalGeneration.from_pretrained(
|
| 375 |
+
args.model_path, **model_kwargs,
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
# Add special tokens
|
| 379 |
+
special_tokens = ["<RET>", "<ANS>", "</ANS>", "<RETQ>", "</RETQ>"]
|
| 380 |
+
num_added = processor.tokenizer.add_tokens(special_tokens, special_tokens=True)
|
| 381 |
+
if num_added > 0:
|
| 382 |
+
model.resize_token_embeddings(len(processor.tokenizer))
|
| 383 |
+
if is_main:
|
| 384 |
+
logger.info(f"Added {num_added} special tokens, vocab → {len(processor.tokenizer)}")
|
| 385 |
+
|
| 386 |
+
# ---- LoRA (optional) ----
|
| 387 |
+
if args.use_lora:
|
| 388 |
+
lora_config = LoraConfig(
|
| 389 |
+
task_type=TaskType.CAUSAL_LM,
|
| 390 |
+
r=args.lora_rank,
|
| 391 |
+
lora_alpha=args.lora_alpha,
|
| 392 |
+
lora_dropout=args.lora_dropout,
|
| 393 |
+
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
|
| 394 |
+
"gate_proj", "up_proj", "down_proj"],
|
| 395 |
+
)
|
| 396 |
+
model = get_peft_model(model, lora_config)
|
| 397 |
+
if is_main:
|
| 398 |
+
model.print_trainable_parameters()
|
| 399 |
+
else:
|
| 400 |
+
if args.gradient_checkpointing:
|
| 401 |
+
model.gradient_checkpointing_enable(
|
| 402 |
+
gradient_checkpointing_kwargs={"use_reentrant": False}
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
# ---- Dataset ----
|
| 406 |
+
train_dataset = SFTDataset(
|
| 407 |
+
args.data_path, processor, args.max_length,
|
| 408 |
+
args.min_pixels, args.max_pixels,
|
| 409 |
+
)
|
| 410 |
+
collator = SFTCollator(processor.tokenizer.pad_token_id, args.max_length)
|
| 411 |
+
|
| 412 |
+
# ---- DeepSpeed or vanilla DDP ----
|
| 413 |
+
if HAS_DEEPSPEED and args.deepspeed:
|
| 414 |
+
# Load DS config and dynamically set scheduler params
|
| 415 |
+
import copy
|
| 416 |
+
with open(args.deepspeed, "r") as _f:
|
| 417 |
+
ds_config = json.load(_f)
|
| 418 |
+
|
| 419 |
+
# Explicitly set all batch-size params (avoid "auto" which some DS versions don't support)
|
| 420 |
+
micro_bs = ds_config.get("train_micro_batch_size_per_gpu", args.batch_size)
|
| 421 |
+
grad_accum_cfg = ds_config.get("gradient_accumulation_steps", args.gradient_accumulation_steps)
|
| 422 |
+
ds_config["train_micro_batch_size_per_gpu"] = micro_bs
|
| 423 |
+
ds_config["gradient_accumulation_steps"] = grad_accum_cfg
|
| 424 |
+
ds_config["train_batch_size"] = micro_bs * grad_accum_cfg * world_size
|
| 425 |
+
|
| 426 |
+
# Override LR from CLI args
|
| 427 |
+
if "optimizer" in ds_config and "params" in ds_config["optimizer"]:
|
| 428 |
+
ds_config["optimizer"]["params"]["lr"] = args.learning_rate
|
| 429 |
+
|
| 430 |
+
if is_main:
|
| 431 |
+
logger.info(f"DeepSpeed config: micro_bs={micro_bs}, grad_accum={grad_accum_cfg}, "
|
| 432 |
+
f"world_size={world_size}, train_batch_size={ds_config['train_batch_size']}")
|
| 433 |
+
|
| 434 |
+
model_engine, optimizer, train_loader, _ = deepspeed.initialize(
|
| 435 |
+
model=model,
|
| 436 |
+
model_parameters=[p for p in model.parameters() if p.requires_grad],
|
| 437 |
+
training_data=train_dataset,
|
| 438 |
+
collate_fn=collator,
|
| 439 |
+
config=ds_config,
|
| 440 |
+
)
|
| 441 |
+
# total_steps = optimizer steps (micro-batch steps per epoch / grad_accum * num_epochs)
|
| 442 |
+
grad_accum = model_engine.gradient_accumulation_steps()
|
| 443 |
+
steps_per_epoch = len(train_loader) // grad_accum
|
| 444 |
+
total_steps = steps_per_epoch * args.num_epochs
|
| 445 |
+
warmup_steps = int(total_steps * args.warmup_ratio)
|
| 446 |
+
|
| 447 |
+
# Replace DS scheduler with cosine schedule
|
| 448 |
+
# Note: model_engine.optimizer is DeepSpeedZeroOptimizer (not a torch.optim.Optimizer),
|
| 449 |
+
# so we must use the underlying torch optimizer for LambdaLR.
|
| 450 |
+
base_optimizer = model_engine.optimizer.optimizer # unwrap to torch AdamW
|
| 451 |
+
ds_scheduler = get_cosine_schedule_with_warmup(
|
| 452 |
+
base_optimizer,
|
| 453 |
+
num_warmup_steps=warmup_steps,
|
| 454 |
+
num_training_steps=total_steps,
|
| 455 |
+
)
|
| 456 |
+
model_engine.lr_scheduler = ds_scheduler
|
| 457 |
+
scheduler = None
|
| 458 |
+
else:
|
| 459 |
+
# Vanilla DDP
|
| 460 |
+
if world_size > 1:
|
| 461 |
+
model = torch.nn.parallel.DistributedDataParallel(
|
| 462 |
+
model, device_ids=[local_rank],
|
| 463 |
+
find_unused_parameters=False,
|
| 464 |
+
)
|
| 465 |
+
sampler = torch.utils.data.distributed.DistributedSampler(
|
| 466 |
+
train_dataset, num_replicas=world_size, rank=rank, shuffle=True,
|
| 467 |
+
) if world_size > 1 else None
|
| 468 |
+
|
| 469 |
+
train_loader = DataLoader(
|
| 470 |
+
train_dataset, batch_size=args.batch_size,
|
| 471 |
+
sampler=sampler, shuffle=(sampler is None),
|
| 472 |
+
collate_fn=collator, num_workers=args.num_workers,
|
| 473 |
+
pin_memory=True, drop_last=True,
|
| 474 |
+
)
|
| 475 |
+
optimizer = torch.optim.AdamW(
|
| 476 |
+
[p for p in model.parameters() if p.requires_grad],
|
| 477 |
+
lr=args.learning_rate, weight_decay=args.weight_decay,
|
| 478 |
+
betas=(0.9, 0.999),
|
| 479 |
+
)
|
| 480 |
+
total_steps = (len(train_loader) * args.num_epochs) // args.gradient_accumulation_steps
|
| 481 |
+
warmup_steps = int(total_steps * args.warmup_ratio)
|
| 482 |
+
scheduler = get_cosine_schedule_with_warmup(
|
| 483 |
+
optimizer, warmup_steps, total_steps,
|
| 484 |
+
)
|
| 485 |
+
model_engine = None
|
| 486 |
+
|
| 487 |
+
if is_main:
|
| 488 |
+
logger.info(f"Dataset: {len(train_dataset)} samples")
|
| 489 |
+
logger.info(f"Total steps: {total_steps}, Warmup: {warmup_steps}")
|
| 490 |
+
|
| 491 |
+
# ---- Training loop ----
|
| 492 |
+
optimizer_step = 0
|
| 493 |
+
running_loss = 0.0
|
| 494 |
+
running_count = 0
|
| 495 |
+
accum_loss = 0.0 # accumulate loss across micro-batches within one grad accum cycle
|
| 496 |
+
|
| 497 |
+
for epoch in range(args.num_epochs):
|
| 498 |
+
if hasattr(train_loader, "sampler") and hasattr(train_loader.sampler, "set_epoch"):
|
| 499 |
+
train_loader.sampler.set_epoch(epoch)
|
| 500 |
+
|
| 501 |
+
model.train() if model_engine is None else model_engine.train()
|
| 502 |
+
|
| 503 |
+
for step, batch in enumerate(train_loader):
|
| 504 |
+
# Move batch to GPU
|
| 505 |
+
batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
|
| 506 |
+
for k, v in batch.items()}
|
| 507 |
+
|
| 508 |
+
# Forward
|
| 509 |
+
if model_engine:
|
| 510 |
+
outputs = model_engine(**batch)
|
| 511 |
+
loss = outputs.loss
|
| 512 |
+
model_engine.backward(loss)
|
| 513 |
+
model_engine.step()
|
| 514 |
+
|
| 515 |
+
# Accumulate loss across micro-batches
|
| 516 |
+
accum_loss += loss.item()
|
| 517 |
+
|
| 518 |
+
# Log/save only on optimizer step boundaries
|
| 519 |
+
if model_engine.is_gradient_accumulation_boundary():
|
| 520 |
+
grad_accum = model_engine.gradient_accumulation_steps()
|
| 521 |
+
optimizer_step += 1
|
| 522 |
+
cur_loss = accum_loss / grad_accum # average over micro-batches
|
| 523 |
+
accum_loss = 0.0
|
| 524 |
+
|
| 525 |
+
running_loss += cur_loss
|
| 526 |
+
running_count += 1
|
| 527 |
+
avg_loss = running_loss / running_count
|
| 528 |
+
|
| 529 |
+
if is_main and optimizer_step % args.log_interval == 0:
|
| 530 |
+
lr_now = ds_scheduler.get_last_lr()[0]
|
| 531 |
+
logger.info(
|
| 532 |
+
f"Epoch {epoch+1}/{args.num_epochs} "
|
| 533 |
+
f"Step {optimizer_step}/{total_steps} "
|
| 534 |
+
f"Loss {cur_loss:.4f} "
|
| 535 |
+
f"AvgLoss {avg_loss:.4f} "
|
| 536 |
+
f"LR {lr_now:.2e}"
|
| 537 |
+
)
|
| 538 |
+
|
| 539 |
+
# Save checkpoint
|
| 540 |
+
if optimizer_step > 0 and optimizer_step % args.save_interval == 0:
|
| 541 |
+
_save_checkpoint(args, model, model_engine, processor, epoch, optimizer_step, is_main)
|
| 542 |
+
|
| 543 |
+
else:
|
| 544 |
+
outputs = model(**batch)
|
| 545 |
+
loss = outputs.loss / args.gradient_accumulation_steps
|
| 546 |
+
loss.backward()
|
| 547 |
+
accum_loss += loss.item() * args.gradient_accumulation_steps
|
| 548 |
+
|
| 549 |
+
if (step + 1) % args.gradient_accumulation_steps == 0:
|
| 550 |
+
torch.nn.utils.clip_grad_norm_(
|
| 551 |
+
model.parameters(), args.max_grad_norm
|
| 552 |
+
)
|
| 553 |
+
optimizer.step()
|
| 554 |
+
scheduler.step()
|
| 555 |
+
optimizer.zero_grad()
|
| 556 |
+
optimizer_step += 1
|
| 557 |
+
|
| 558 |
+
cur_loss = accum_loss / args.gradient_accumulation_steps
|
| 559 |
+
accum_loss = 0.0
|
| 560 |
+
|
| 561 |
+
running_loss += cur_loss
|
| 562 |
+
running_count += 1
|
| 563 |
+
avg_loss = running_loss / running_count
|
| 564 |
+
|
| 565 |
+
if is_main and optimizer_step % args.log_interval == 0:
|
| 566 |
+
lr_now = scheduler.get_last_lr()[0]
|
| 567 |
+
logger.info(
|
| 568 |
+
f"Epoch {epoch+1}/{args.num_epochs} "
|
| 569 |
+
f"Step {optimizer_step}/{total_steps} "
|
| 570 |
+
f"Loss {cur_loss:.4f} "
|
| 571 |
+
f"AvgLoss {avg_loss:.4f} "
|
| 572 |
+
f"LR {lr_now:.2e}"
|
| 573 |
+
)
|
| 574 |
+
|
| 575 |
+
# End of epoch save
|
| 576 |
+
if model_engine:
|
| 577 |
+
_save_checkpoint(args, model, model_engine, processor, epoch, optimizer_step, is_main)
|
| 578 |
+
else:
|
| 579 |
+
_save_checkpoint(args, model, model_engine, processor, epoch, optimizer_step, is_main)
|
| 580 |
+
|
| 581 |
+
# Final save
|
| 582 |
+
if model_engine:
|
| 583 |
+
_save_checkpoint(args, model, model_engine, processor, args.num_epochs, optimizer_step, is_main, final=True)
|
| 584 |
+
else:
|
| 585 |
+
_save_checkpoint(args, model, model_engine, processor, args.num_epochs, optimizer_step, is_main, final=True)
|
| 586 |
+
|
| 587 |
+
if is_main:
|
| 588 |
+
logger.info("Training complete!")
|
| 589 |
+
|
| 590 |
+
|
| 591 |
+
def _save_checkpoint(args, model, model_engine, processor, epoch, step, is_main, final=False):
|
| 592 |
+
tag = "final" if final else f"epoch{epoch+1}_step{step}"
|
| 593 |
+
save_dir = Path(args.output_dir) / tag
|
| 594 |
+
|
| 595 |
+
if model_engine and HAS_DEEPSPEED:
|
| 596 |
+
# DeepSpeed save_checkpoint must be called by ALL ranks
|
| 597 |
+
model_engine.save_checkpoint(str(args.output_dir), tag=tag)
|
| 598 |
+
elif is_main:
|
| 599 |
+
unwrapped = model.module if hasattr(model, "module") else model
|
| 600 |
+
if args.use_lora:
|
| 601 |
+
unwrapped.save_pretrained(str(save_dir))
|
| 602 |
+
else:
|
| 603 |
+
unwrapped.save_pretrained(str(save_dir))
|
| 604 |
+
processor.save_pretrained(str(save_dir))
|
| 605 |
+
|
| 606 |
+
if is_main:
|
| 607 |
+
logger.info(f"Saved checkpoint → {save_dir}")
|
| 608 |
+
|
| 609 |
+
|
| 610 |
+
# ============================================================================
|
| 611 |
+
# Main
|
| 612 |
+
# ============================================================================
|
| 613 |
+
|
| 614 |
+
def parse_args():
|
| 615 |
+
p = argparse.ArgumentParser(description="Qwen3-VL-8B SFT")
|
| 616 |
+
|
| 617 |
+
# Model
|
| 618 |
+
p.add_argument("--model-path", default="/workspace/models/Qwen3-VL-8B-Instruct")
|
| 619 |
+
p.add_argument("--output-dir", default="/workspace/xiaobin/ICL/SFT_new/output/qwen3vl_sft")
|
| 620 |
+
|
| 621 |
+
# Data
|
| 622 |
+
p.add_argument("--data-path", required=True, help="Path to sft.jsonl")
|
| 623 |
+
p.add_argument("--max-length", type=int, default=4096)
|
| 624 |
+
p.add_argument("--min-pixels", type=int, default=256 * 28 * 28)
|
| 625 |
+
p.add_argument("--max-pixels", type=int, default=1280 * 28 * 28)
|
| 626 |
+
|
| 627 |
+
# Training
|
| 628 |
+
p.add_argument("--num-epochs", type=int, default=3)
|
| 629 |
+
p.add_argument("--batch-size", type=int, default=1,
|
| 630 |
+
help="Per-GPU micro batch size")
|
| 631 |
+
p.add_argument("--gradient-accumulation-steps", type=int, default=4)
|
| 632 |
+
p.add_argument("--learning-rate", type=float, default=1e-5)
|
| 633 |
+
p.add_argument("--weight-decay", type=float, default=0.1)
|
| 634 |
+
p.add_argument("--warmup-ratio", type=float, default=0.05)
|
| 635 |
+
p.add_argument("--max-grad-norm", type=float, default=1.0)
|
| 636 |
+
p.add_argument("--gradient-checkpointing", action="store_true", default=True)
|
| 637 |
+
p.add_argument("--num-workers", type=int, default=4)
|
| 638 |
+
|
| 639 |
+
# LoRA
|
| 640 |
+
p.add_argument("--use-lora", action="store_true", default=False)
|
| 641 |
+
p.add_argument("--lora-rank", type=int, default=64)
|
| 642 |
+
p.add_argument("--lora-alpha", type=int, default=128)
|
| 643 |
+
p.add_argument("--lora-dropout", type=float, default=0.05)
|
| 644 |
+
|
| 645 |
+
# Logging
|
| 646 |
+
p.add_argument("--log-interval", type=int, default=10)
|
| 647 |
+
p.add_argument("--save-interval", type=int, default=500)
|
| 648 |
+
|
| 649 |
+
# DeepSpeed
|
| 650 |
+
p.add_argument("--deepspeed", type=str, default=None,
|
| 651 |
+
help="Path to DeepSpeed config JSON")
|
| 652 |
+
p.add_argument("--local_rank", type=int, default=-1) # torchrun sets this
|
| 653 |
+
|
| 654 |
+
return p.parse_args()
|
| 655 |
+
|
| 656 |
+
|
| 657 |
+
if __name__ == "__main__":
|
| 658 |
+
args = parse_args()
|
| 659 |
+
train(args)
|
ICL/build_embeddings.py
ADDED
|
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
预算 SigLIP2 embeddings + Top5 相似图片映射(8卡 DataParallel)。
|
| 4 |
+
|
| 5 |
+
用法:
|
| 6 |
+
python3 build_embeddings.py # 8卡,全部
|
| 7 |
+
python3 build_embeddings.py --datasets vqa/shapes # 测试
|
| 8 |
+
python3 build_embeddings.py --force # 强制重建
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import json
|
| 13 |
+
import os
|
| 14 |
+
import numpy as np
|
| 15 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 16 |
+
from typing import Dict, List, Tuple
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
except ImportError:
|
| 21 |
+
def tqdm(x, **kw):
|
| 22 |
+
return x
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
import torch.nn as nn
|
| 26 |
+
import cv2
|
| 27 |
+
import numpy as np
|
| 28 |
+
from PIL import Image
|
| 29 |
+
from transformers import AutoModel, AutoProcessor
|
| 30 |
+
|
| 31 |
+
# ---------------------------------------------------------------------------
|
| 32 |
+
IMAGES_ROOT = "/workspace/xiaobin/dataset/images"
|
| 33 |
+
CAPTION_CACHE_DIR = "/workspace/xiaobin/dataset/caption_cache"
|
| 34 |
+
EMBEDDINGS_DIR = "/workspace/xiaobin/dataset/embeddings"
|
| 35 |
+
DEFAULT_MODEL = "/workspace/models/siglip2-so400m-patch14-384"
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# ---------------------------------------------------------------------------
|
| 39 |
+
# DataParallel wrappers
|
| 40 |
+
# ---------------------------------------------------------------------------
|
| 41 |
+
class SigLIPImageModule(nn.Module):
|
| 42 |
+
def __init__(self, model):
|
| 43 |
+
super().__init__()
|
| 44 |
+
self.model = model
|
| 45 |
+
|
| 46 |
+
def forward(self, **kwargs):
|
| 47 |
+
out = self.model.get_image_features(**kwargs)
|
| 48 |
+
feat = out.pooler_output if hasattr(out, "pooler_output") else out
|
| 49 |
+
return feat / feat.norm(dim=-1, keepdim=True)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class SigLIPTextModule(nn.Module):
|
| 53 |
+
def __init__(self, model):
|
| 54 |
+
super().__init__()
|
| 55 |
+
self.model = model
|
| 56 |
+
|
| 57 |
+
def forward(self, **kwargs):
|
| 58 |
+
out = self.model.get_text_features(**kwargs)
|
| 59 |
+
feat = out.pooler_output if hasattr(out, "pooler_output") else out
|
| 60 |
+
return feat / feat.norm(dim=-1, keepdim=True)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# ---------------------------------------------------------------------------
|
| 64 |
+
# Encoder: 单进程, 多线程读图, 小batch快速跑
|
| 65 |
+
# ---------------------------------------------------------------------------
|
| 66 |
+
class SigLIPEncoder:
|
| 67 |
+
def __init__(self, model_path: str, gpu_ids: List[int],
|
| 68 |
+
batch_size_per_gpu: int = 64, num_threads: int = 16):
|
| 69 |
+
self.gpu_ids = gpu_ids
|
| 70 |
+
self.n_gpus = len(gpu_ids)
|
| 71 |
+
self.batch_size = batch_size_per_gpu * self.n_gpus
|
| 72 |
+
self.num_threads = num_threads
|
| 73 |
+
self.primary = torch.device(f"cuda:{gpu_ids[0]}")
|
| 74 |
+
|
| 75 |
+
print(f" GPU: {gpu_ids} ({self.n_gpus} 张)")
|
| 76 |
+
print(f" batch: {batch_size_per_gpu}/卡 × {self.n_gpus}卡 = {self.batch_size}")
|
| 77 |
+
|
| 78 |
+
self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
|
| 79 |
+
base_model = AutoModel.from_pretrained(
|
| 80 |
+
model_path, dtype=torch.bfloat16, trust_remote_code=True
|
| 81 |
+
).to(self.primary).eval()
|
| 82 |
+
|
| 83 |
+
self.img_module = nn.DataParallel(
|
| 84 |
+
SigLIPImageModule(base_model), device_ids=gpu_ids)
|
| 85 |
+
self.txt_module = nn.DataParallel(
|
| 86 |
+
SigLIPTextModule(base_model), device_ids=gpu_ids)
|
| 87 |
+
|
| 88 |
+
@staticmethod
|
| 89 |
+
def _load_and_preprocess(path):
|
| 90 |
+
"""读图 + OpenCV resize + normalize → numpy (3, 384, 384) float32"""
|
| 91 |
+
try:
|
| 92 |
+
img = cv2.imread(path)
|
| 93 |
+
if img is None:
|
| 94 |
+
return (path, None)
|
| 95 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 96 |
+
img = cv2.resize(img, (384, 384))
|
| 97 |
+
img = img.astype(np.float32) / 255.0
|
| 98 |
+
img = (img - 0.5) / 0.5
|
| 99 |
+
img = np.transpose(img, (2, 0, 1)) # (3, 384, 384)
|
| 100 |
+
return (path, img)
|
| 101 |
+
except Exception:
|
| 102 |
+
return (path, None)
|
| 103 |
+
|
| 104 |
+
def encode_images(self, paths: List[str]) -> Tuple[List[str], np.ndarray]:
|
| 105 |
+
all_embs = []
|
| 106 |
+
valid_paths = []
|
| 107 |
+
n = len(paths)
|
| 108 |
+
pbar = tqdm(total=n, desc=" encode-img", unit="张", dynamic_ncols=True)
|
| 109 |
+
|
| 110 |
+
thread_pool = ThreadPoolExecutor(max_workers=self.num_threads)
|
| 111 |
+
|
| 112 |
+
batches = [paths[s:s + self.batch_size]
|
| 113 |
+
for s in range(0, n, self.batch_size)]
|
| 114 |
+
|
| 115 |
+
# 预提交第一批
|
| 116 |
+
if batches:
|
| 117 |
+
next_future = list(thread_pool.map(self._load_and_preprocess, batches[0]))
|
| 118 |
+
else:
|
| 119 |
+
next_future = []
|
| 120 |
+
|
| 121 |
+
for i, batch_paths in enumerate(batches):
|
| 122 |
+
loaded = next_future
|
| 123 |
+
|
| 124 |
+
# 提前提交下一批 IO + 预处理
|
| 125 |
+
if i + 1 < len(batches):
|
| 126 |
+
next_futures_list = [thread_pool.submit(self._load_and_preprocess, p)
|
| 127 |
+
for p in batches[i + 1]]
|
| 128 |
+
else:
|
| 129 |
+
next_futures_list = None
|
| 130 |
+
|
| 131 |
+
batch_valid = []
|
| 132 |
+
batch_arrays = []
|
| 133 |
+
for p, arr in loaded:
|
| 134 |
+
if arr is not None:
|
| 135 |
+
batch_valid.append(p)
|
| 136 |
+
batch_arrays.append(arr)
|
| 137 |
+
|
| 138 |
+
if not batch_arrays:
|
| 139 |
+
pbar.update(len(batch_paths))
|
| 140 |
+
if next_futures_list:
|
| 141 |
+
next_future = [f.result() for f in next_futures_list]
|
| 142 |
+
continue
|
| 143 |
+
|
| 144 |
+
# numpy stack → torch → GPU
|
| 145 |
+
pixel_values = torch.from_numpy(np.stack(batch_arrays)).to(
|
| 146 |
+
dtype=torch.bfloat16, device=self.primary)
|
| 147 |
+
|
| 148 |
+
with torch.inference_mode():
|
| 149 |
+
feat = self.img_module(pixel_values=pixel_values)
|
| 150 |
+
all_embs.append(feat.cpu().float().numpy())
|
| 151 |
+
valid_paths.extend(batch_valid)
|
| 152 |
+
|
| 153 |
+
pbar.update(len(batch_paths))
|
| 154 |
+
|
| 155 |
+
if next_futures_list:
|
| 156 |
+
next_future = [f.result() for f in next_futures_list]
|
| 157 |
+
|
| 158 |
+
thread_pool.shutdown(wait=False)
|
| 159 |
+
pbar.close()
|
| 160 |
+
if not all_embs:
|
| 161 |
+
return [], np.empty((0, 0), dtype=np.float16)
|
| 162 |
+
return valid_paths, np.concatenate(all_embs, axis=0).astype(np.float16)
|
| 163 |
+
|
| 164 |
+
def encode_texts(self, texts: List[str]) -> np.ndarray:
|
| 165 |
+
all_embs = []
|
| 166 |
+
n = len(texts)
|
| 167 |
+
pbar = tqdm(total=n, desc=" encode-txt", unit="条", dynamic_ncols=True)
|
| 168 |
+
|
| 169 |
+
for start in range(0, n, self.batch_size):
|
| 170 |
+
batch = texts[start:start + self.batch_size]
|
| 171 |
+
inp = self.processor(text=batch, return_tensors="pt",
|
| 172 |
+
padding="max_length", truncation=True,
|
| 173 |
+
max_length=64)
|
| 174 |
+
keys = {k: v.to(self.primary) for k, v in inp.items()
|
| 175 |
+
if k in ("input_ids", "attention_mask", "position_ids")}
|
| 176 |
+
with torch.inference_mode():
|
| 177 |
+
feat = self.txt_module(**keys)
|
| 178 |
+
all_embs.append(feat.cpu().float().numpy())
|
| 179 |
+
pbar.update(len(batch))
|
| 180 |
+
|
| 181 |
+
pbar.close()
|
| 182 |
+
if not all_embs:
|
| 183 |
+
return np.empty((0, 0), dtype=np.float16)
|
| 184 |
+
return np.concatenate(all_embs, axis=0).astype(np.float16)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
# ---------------------------------------------------------------------------
|
| 188 |
+
# Top-K(GPU)
|
| 189 |
+
# ---------------------------------------------------------------------------
|
| 190 |
+
def compute_top_k(caption_embs, image_embs, image_paths, k=5,
|
| 191 |
+
chunk_size=5000, device="cuda:0"):
|
| 192 |
+
n = len(image_paths)
|
| 193 |
+
img_gpu = torch.from_numpy(image_embs.astype(np.float32)).to(device)
|
| 194 |
+
top_k_map = {}
|
| 195 |
+
|
| 196 |
+
for start in tqdm(range(0, n, chunk_size), desc=" compute-top5", unit="chunk"):
|
| 197 |
+
end = min(start + chunk_size, n)
|
| 198 |
+
cap = torch.from_numpy(caption_embs[start:end].astype(np.float32)).to(device)
|
| 199 |
+
sim = cap @ img_gpu.T
|
| 200 |
+
idx_range = torch.arange(end - start, device=sim.device)
|
| 201 |
+
sim[idx_range, torch.arange(start, end, device=sim.device)] = -1.0
|
| 202 |
+
_, top_idx = sim.topk(k, dim=1)
|
| 203 |
+
top_idx_cpu = top_idx.cpu().numpy()
|
| 204 |
+
for i in range(end - start):
|
| 205 |
+
top_k_map[image_paths[start + i]] = [
|
| 206 |
+
image_paths[j] for j in top_idx_cpu[i]]
|
| 207 |
+
return top_k_map
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
# ---------------------------------------------------------------------------
|
| 211 |
+
# 数据集工具
|
| 212 |
+
# ---------------------------------------------------------------------------
|
| 213 |
+
def discover_datasets(categories=None, specific=None):
|
| 214 |
+
if specific:
|
| 215 |
+
return [(s.split("/")[0], s.split("/")[1]) for s in specific if "/" in s]
|
| 216 |
+
result = []
|
| 217 |
+
for cat in sorted(os.listdir(IMAGES_ROOT)):
|
| 218 |
+
d = os.path.join(IMAGES_ROOT, cat)
|
| 219 |
+
if not os.path.isdir(d):
|
| 220 |
+
continue
|
| 221 |
+
if categories and cat not in categories:
|
| 222 |
+
continue
|
| 223 |
+
for ds in sorted(os.listdir(d)):
|
| 224 |
+
if os.path.isdir(os.path.join(d, ds)):
|
| 225 |
+
result.append((cat, ds))
|
| 226 |
+
return result
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def load_captions(cat, ds):
|
| 230 |
+
p = os.path.join(CAPTION_CACHE_DIR, f"{cat}_{ds}.json")
|
| 231 |
+
if not os.path.exists(p):
|
| 232 |
+
return {}
|
| 233 |
+
try:
|
| 234 |
+
with open(p) as f:
|
| 235 |
+
return json.load(f).get("items", {})
|
| 236 |
+
except Exception:
|
| 237 |
+
return {}
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def collect_images(cat, ds):
|
| 241 |
+
base = os.path.join(IMAGES_ROOT, cat, ds)
|
| 242 |
+
paths = []
|
| 243 |
+
for split in ("train", "val", "test", "other"):
|
| 244 |
+
d = os.path.join(base, split)
|
| 245 |
+
if not os.path.isdir(d):
|
| 246 |
+
continue
|
| 247 |
+
for fn in sorted(os.listdir(d)):
|
| 248 |
+
fp = os.path.join(d, fn)
|
| 249 |
+
if os.path.isfile(fp):
|
| 250 |
+
paths.append(fp)
|
| 251 |
+
return paths
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
# ---------------------------------------------------------------------------
|
| 255 |
+
# 处理单个数据集(含断点续传)
|
| 256 |
+
# ---------------------------------------------------------------------------
|
| 257 |
+
def process_dataset(cat, ds, encoder, top_k, force):
|
| 258 |
+
tag = f"{cat}_{ds}"
|
| 259 |
+
npz_path = os.path.join(EMBEDDINGS_DIR, f"{tag}.npz")
|
| 260 |
+
top5_path = os.path.join(EMBEDDINGS_DIR, f"{tag}_top{top_k}.json")
|
| 261 |
+
|
| 262 |
+
# 断点1:全部完成
|
| 263 |
+
if not force and os.path.exists(npz_path) and os.path.exists(top5_path):
|
| 264 |
+
try:
|
| 265 |
+
data = np.load(npz_path, allow_pickle=True)
|
| 266 |
+
n_emb = len(data["image_paths"])
|
| 267 |
+
with open(top5_path) as f:
|
| 268 |
+
n_top = len(json.load(f))
|
| 269 |
+
if n_emb == n_top and n_emb > 0:
|
| 270 |
+
print(f" [SKIP] {tag} ({n_emb} 张)")
|
| 271 |
+
return True
|
| 272 |
+
except Exception:
|
| 273 |
+
pass
|
| 274 |
+
|
| 275 |
+
# 断点2:有 embeddings 缺 top5
|
| 276 |
+
if not force and os.path.exists(npz_path) and not os.path.exists(top5_path):
|
| 277 |
+
try:
|
| 278 |
+
data = np.load(npz_path, allow_pickle=True)
|
| 279 |
+
sp = list(data["image_paths"])
|
| 280 |
+
si, sc = data["image_embs"], data["caption_embs"]
|
| 281 |
+
if len(sp) > 0 and si.shape[0] == len(sp):
|
| 282 |
+
print(f" [RESUME] {tag} 只算 top{top_k} ({len(sp)} 张)")
|
| 283 |
+
m = compute_top_k(sc, si, sp, k=top_k, device=str(encoder.primary))
|
| 284 |
+
with open(top5_path, 'w') as f:
|
| 285 |
+
json.dump(m, f, ensure_ascii=False)
|
| 286 |
+
print(f" top{top_k}: {os.path.getsize(top5_path)/1048576:.1f}MB")
|
| 287 |
+
return True
|
| 288 |
+
except Exception:
|
| 289 |
+
pass
|
| 290 |
+
|
| 291 |
+
# 从头
|
| 292 |
+
all_paths = collect_images(cat, ds)
|
| 293 |
+
if not all_paths:
|
| 294 |
+
print(f" [SKIP] {tag} 无图片")
|
| 295 |
+
return False
|
| 296 |
+
|
| 297 |
+
captions = load_captions(cat, ds)
|
| 298 |
+
if not captions:
|
| 299 |
+
print(f" [WARN] {tag} 无 caption,跳过")
|
| 300 |
+
return False
|
| 301 |
+
|
| 302 |
+
paths_with_cap = [p for p in all_paths if p in captions]
|
| 303 |
+
if not paths_with_cap:
|
| 304 |
+
print(f" [WARN] {tag} 无交集,跳过")
|
| 305 |
+
return False
|
| 306 |
+
|
| 307 |
+
print(f"\n {tag}: {len(paths_with_cap)} 张图")
|
| 308 |
+
|
| 309 |
+
valid_paths, image_embs = encoder.encode_images(paths_with_cap)
|
| 310 |
+
if not valid_paths:
|
| 311 |
+
print(f" [ERROR] {tag} 编码失败")
|
| 312 |
+
return False
|
| 313 |
+
|
| 314 |
+
caption_embs = encoder.encode_texts([captions[p] for p in valid_paths])
|
| 315 |
+
|
| 316 |
+
os.makedirs(EMBEDDINGS_DIR, exist_ok=True)
|
| 317 |
+
np.savez_compressed(npz_path, image_paths=np.array(valid_paths),
|
| 318 |
+
image_embs=image_embs, caption_embs=caption_embs)
|
| 319 |
+
print(f" embeddings: {os.path.getsize(npz_path)/1048576:.1f}MB")
|
| 320 |
+
|
| 321 |
+
m = compute_top_k(caption_embs, image_embs, valid_paths,
|
| 322 |
+
k=top_k, device=str(encoder.primary))
|
| 323 |
+
with open(top5_path, 'w') as f:
|
| 324 |
+
json.dump(m, f, ensure_ascii=False)
|
| 325 |
+
print(f" top{top_k}: {os.path.getsize(top5_path)/1048576:.1f}MB")
|
| 326 |
+
return True
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
# ---------------------------------------------------------------------------
|
| 330 |
+
def main():
|
| 331 |
+
parser = argparse.ArgumentParser()
|
| 332 |
+
parser.add_argument("--model-path", default=DEFAULT_MODEL)
|
| 333 |
+
parser.add_argument("--gpus", default="")
|
| 334 |
+
parser.add_argument("--batch-size-per-gpu", type=int, default=256,
|
| 335 |
+
help="每卡batch(预处理不再是瓶颈,可以开大)")
|
| 336 |
+
parser.add_argument("--num-threads", type=int, default=16,
|
| 337 |
+
help="图片IO线程数")
|
| 338 |
+
parser.add_argument("--top-k", type=int, default=5)
|
| 339 |
+
parser.add_argument("--categories", default="")
|
| 340 |
+
parser.add_argument("--datasets", default="")
|
| 341 |
+
parser.add_argument("--force", action="store_true")
|
| 342 |
+
args = parser.parse_args()
|
| 343 |
+
|
| 344 |
+
gpu_ids = ([int(x) for x in args.gpus.split(",") if x.strip()]
|
| 345 |
+
or list(range(torch.cuda.device_count())))
|
| 346 |
+
total_batch = args.batch_size_per_gpu * len(gpu_ids)
|
| 347 |
+
print(f"GPU: {gpu_ids} ({len(gpu_ids)} 张), batch: {total_batch}")
|
| 348 |
+
|
| 349 |
+
cats = [c.strip() for c in args.categories.split(",") if c.strip()] or None
|
| 350 |
+
specific = [d.strip() for d in args.datasets.split(",") if d.strip()] or None
|
| 351 |
+
datasets = discover_datasets(categories=cats, specific=specific)
|
| 352 |
+
print(f"共 {len(datasets)} 个数据集\n")
|
| 353 |
+
|
| 354 |
+
encoder = SigLIPEncoder(args.model_path, gpu_ids,
|
| 355 |
+
args.batch_size_per_gpu, args.num_threads)
|
| 356 |
+
|
| 357 |
+
ok, fail = 0, 0
|
| 358 |
+
pbar = tqdm(datasets, desc="总进度", unit="ds", dynamic_ncols=True)
|
| 359 |
+
for i, (cat, ds) in enumerate(pbar, 1):
|
| 360 |
+
pbar.set_postfix(current=f"{cat}/{ds}", ok=ok, fail=fail)
|
| 361 |
+
if process_dataset(cat, ds, encoder, args.top_k, args.force):
|
| 362 |
+
ok += 1
|
| 363 |
+
else:
|
| 364 |
+
fail += 1
|
| 365 |
+
pbar.close()
|
| 366 |
+
print(f"\n完成: {ok} 成功, {fail} 失败/跳过 → {EMBEDDINGS_DIR}")
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
if __name__ == "__main__":
|
| 370 |
+
main()
|
ICL/build_index.py
ADDED
|
@@ -0,0 +1,506 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
生成索引 JSONL:将原始 base64 JSONL 的文本字段 + 提取后的图片路径 + VLM描述 对应起来。
|
| 4 |
+
|
| 5 |
+
输入:
|
| 6 |
+
/workspace/xiaobin/dataset/data/{cat}/{ds}/{split}.jsonl (原始,含base64)
|
| 7 |
+
/workspace/xiaobin/dataset/images/{cat}/{ds}/{split}/ (已提取的图片)
|
| 8 |
+
/workspace/xiaobin/dataset/detail/{cat}/{ds}/{split}/captions.json (VLM描述)
|
| 9 |
+
|
| 10 |
+
输出:
|
| 11 |
+
/workspace/xiaobin/dataset/index/{cat}/{ds}/{split}.jsonl (轻量索引)
|
| 12 |
+
|
| 13 |
+
每条记录格式:
|
| 14 |
+
{
|
| 15 |
+
"image": "/workspace/xiaobin/dataset/images/vqa/shapes/test/00000000.jpg",
|
| 16 |
+
"images": ["/path/..."], # 多图时(video_str/images字段)
|
| 17 |
+
"question": "...",
|
| 18 |
+
"answer": "...",
|
| 19 |
+
"description": "A cat sitting...", # 来自 detail/captions.json
|
| 20 |
+
"meta": {...}, # 原始meta(如有)
|
| 21 |
+
"id": "...", # 原始id/img_id
|
| 22 |
+
"category": "vqa",
|
| 23 |
+
"dataset": "shapes",
|
| 24 |
+
"split": "test"
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
用法:
|
| 28 |
+
python3 build_index.py # 全部(已完成的自动跳过)
|
| 29 |
+
python3 build_index.py vqa/shapes # 某个数据集
|
| 30 |
+
python3 build_index.py --force # 全部强制重建
|
| 31 |
+
python3 build_index.py --force vqa/shapes # 某个数据集强制重建
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
import os
|
| 35 |
+
import sys
|
| 36 |
+
import json
|
| 37 |
+
import glob
|
| 38 |
+
import re
|
| 39 |
+
from tqdm import tqdm
|
| 40 |
+
|
| 41 |
+
DATA_ROOT = "/workspace/xiaobin/dataset/data"
|
| 42 |
+
IMAGES_ROOT = "/workspace/xiaobin/dataset/images"
|
| 43 |
+
DETAIL_ROOT = "/workspace/xiaobin/dataset/detail"
|
| 44 |
+
INDEX_ROOT = "/workspace/xiaobin/dataset/index"
|
| 45 |
+
|
| 46 |
+
# 图片base64字段(用于判断"这行有图",和extract_images.py一致)
|
| 47 |
+
ALL_IMAGE_FIELDS = [
|
| 48 |
+
"image", "image_str", "image_base64_str", "img_str",
|
| 49 |
+
"base64", "image_base64", "image_base_url",
|
| 50 |
+
"video_str", "images",
|
| 51 |
+
]
|
| 52 |
+
|
| 53 |
+
# 文本字段提取
|
| 54 |
+
QUESTION_FIELDS = ["question", "text", "query", "prompt", "input", "inputs", "user_prompt"]
|
| 55 |
+
ANSWER_FIELDS = ["answer", "output", "outputs", "label", "target", "caption", "paraphrased_answer", "original_answer"]
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def classify_split(filename):
|
| 59 |
+
fn = filename.lower()
|
| 60 |
+
if "train" in fn:
|
| 61 |
+
return "train"
|
| 62 |
+
elif "test" in fn:
|
| 63 |
+
return "test"
|
| 64 |
+
elif "val" in fn:
|
| 65 |
+
return "val"
|
| 66 |
+
else:
|
| 67 |
+
return "other"
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def has_image(record):
|
| 71 |
+
"""判断这条记录是否有图(和 extract_images.py 逻辑一致)"""
|
| 72 |
+
for field in ALL_IMAGE_FIELDS:
|
| 73 |
+
if field not in record or not record[field]:
|
| 74 |
+
continue
|
| 75 |
+
val = record[field]
|
| 76 |
+
if isinstance(val, str) and len(val) > 100:
|
| 77 |
+
return True
|
| 78 |
+
elif isinstance(val, list):
|
| 79 |
+
if any(isinstance(item, str) and len(item) > 100 for item in val):
|
| 80 |
+
return True
|
| 81 |
+
return False
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def is_multi_image(record):
|
| 85 |
+
"""判断是否是多图记录(video_str/images 列表字段)"""
|
| 86 |
+
for field in ("video_str", "images"):
|
| 87 |
+
if field in record and isinstance(record[field], list):
|
| 88 |
+
items = [x for x in record[field] if isinstance(x, str) and len(x) > 100]
|
| 89 |
+
if len(items) > 1:
|
| 90 |
+
return True
|
| 91 |
+
# image_str/image_base64 也可能是list
|
| 92 |
+
for field in ("image_str", "image_base64"):
|
| 93 |
+
val = record.get(field)
|
| 94 |
+
if isinstance(val, list):
|
| 95 |
+
items = [x for x in val if isinstance(x, str) and len(x) > 100]
|
| 96 |
+
if len(items) > 1:
|
| 97 |
+
return True
|
| 98 |
+
return False
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def count_images_in_record(record):
|
| 102 |
+
"""统计这条记录里有几张图"""
|
| 103 |
+
for field in ALL_IMAGE_FIELDS:
|
| 104 |
+
if field not in record or not record[field]:
|
| 105 |
+
continue
|
| 106 |
+
val = record[field]
|
| 107 |
+
if isinstance(val, str) and len(val) > 100:
|
| 108 |
+
return 1
|
| 109 |
+
elif isinstance(val, list):
|
| 110 |
+
return len([x for x in val if isinstance(x, str) and len(x) > 100])
|
| 111 |
+
return 0
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def extract_text(record, fields):
|
| 115 |
+
"""从记录中提取文本字段"""
|
| 116 |
+
for k in fields:
|
| 117 |
+
v = record.get(k)
|
| 118 |
+
if isinstance(v, str) and v.strip():
|
| 119 |
+
return v.strip()
|
| 120 |
+
# 尝试 answers 列表
|
| 121 |
+
if "answers" in record:
|
| 122 |
+
v = record["answers"]
|
| 123 |
+
if isinstance(v, list):
|
| 124 |
+
for a in v:
|
| 125 |
+
if isinstance(a, str) and a.strip():
|
| 126 |
+
return a.strip()
|
| 127 |
+
return None
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def extract_id(record):
|
| 131 |
+
"""提取记录ID"""
|
| 132 |
+
for k in ("id", "image_id", "img_id"):
|
| 133 |
+
v = record.get(k)
|
| 134 |
+
if v is not None:
|
| 135 |
+
return str(v)
|
| 136 |
+
meta = record.get("meta")
|
| 137 |
+
if isinstance(meta, dict):
|
| 138 |
+
for k in ("img_id", "id", "image_id"):
|
| 139 |
+
v = meta.get(k)
|
| 140 |
+
if v is not None:
|
| 141 |
+
return str(v)
|
| 142 |
+
return None
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def extract_meta(record):
|
| 146 |
+
"""提取meta信息(去掉base64等大字段)"""
|
| 147 |
+
meta = record.get("meta")
|
| 148 |
+
if not isinstance(meta, dict):
|
| 149 |
+
return None
|
| 150 |
+
out = {}
|
| 151 |
+
for k, v in meta.items():
|
| 152 |
+
# 跳过所有图片/base64相关字段
|
| 153 |
+
if any(x in k.lower() for x in ("image", "img", "base64", "video")):
|
| 154 |
+
continue
|
| 155 |
+
# 跳过大字符串
|
| 156 |
+
if isinstance(v, str) and len(v) > 500:
|
| 157 |
+
continue
|
| 158 |
+
# 跳过含大字符串的列表
|
| 159 |
+
if isinstance(v, list) and v and isinstance(v[0], str) and len(v[0]) > 200:
|
| 160 |
+
continue
|
| 161 |
+
out[k] = v
|
| 162 |
+
return out if out else None
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def load_detail(category, dataset, split):
|
| 166 |
+
"""加载 VLM description 缓存"""
|
| 167 |
+
path = os.path.join(DETAIL_ROOT, category, dataset, split, "captions.json")
|
| 168 |
+
if not os.path.exists(path):
|
| 169 |
+
return {}
|
| 170 |
+
try:
|
| 171 |
+
with open(path, 'r', encoding='utf-8') as f:
|
| 172 |
+
data = json.load(f)
|
| 173 |
+
items = data.get("items", {})
|
| 174 |
+
if isinstance(items, dict):
|
| 175 |
+
return items
|
| 176 |
+
except Exception:
|
| 177 |
+
pass
|
| 178 |
+
return {}
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def count_lines(filepath):
|
| 182 |
+
count = 0
|
| 183 |
+
with open(filepath, 'rb') as f:
|
| 184 |
+
buf_size = 8 * 1024 * 1024
|
| 185 |
+
buf = f.raw.read(buf_size)
|
| 186 |
+
while buf:
|
| 187 |
+
count += buf.count(b'\n')
|
| 188 |
+
buf = f.raw.read(buf_size)
|
| 189 |
+
return count
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def process_one(jsonl_path, file_idx, total_files):
|
| 193 |
+
"""处理单个原始 JSONL,生成索引 JSONL"""
|
| 194 |
+
rel_path = os.path.relpath(jsonl_path, DATA_ROOT)
|
| 195 |
+
parts = rel_path.split(os.sep)
|
| 196 |
+
if len(parts) < 3:
|
| 197 |
+
return 0
|
| 198 |
+
|
| 199 |
+
category, dataset, filename = parts[0], parts[1], parts[2]
|
| 200 |
+
split = classify_split(filename)
|
| 201 |
+
|
| 202 |
+
# 图片目录
|
| 203 |
+
img_dir = os.path.join(IMAGES_ROOT, category, dataset, split)
|
| 204 |
+
if not os.path.isdir(img_dir):
|
| 205 |
+
print(f" [SKIP] 无图片目录: {img_dir}")
|
| 206 |
+
return 0
|
| 207 |
+
|
| 208 |
+
# 图片文件列表(按编号排序)
|
| 209 |
+
img_files = sorted([f for f in os.listdir(img_dir) if os.path.isfile(os.path.join(img_dir, f))])
|
| 210 |
+
if not img_files:
|
| 211 |
+
print(f" [SKIP] 图片目录为空: {img_dir}")
|
| 212 |
+
return 0
|
| 213 |
+
|
| 214 |
+
# VLM描述
|
| 215 |
+
detail = load_detail(category, dataset, split)
|
| 216 |
+
|
| 217 |
+
# 输出索引文件
|
| 218 |
+
out_path = os.path.join(INDEX_ROOT, category, dataset, f"{split}.jsonl")
|
| 219 |
+
os.makedirs(os.path.dirname(out_path), exist_ok=True)
|
| 220 |
+
|
| 221 |
+
total_lines = count_lines(jsonl_path)
|
| 222 |
+
file_size_mb = os.path.getsize(jsonl_path) / (1024 * 1024)
|
| 223 |
+
desc = f"[{file_idx}/{total_files}] {category}/{dataset}/{split} ({file_size_mb:.0f}MB)"
|
| 224 |
+
|
| 225 |
+
img_idx = 0 # 图片文件游标
|
| 226 |
+
written = 0
|
| 227 |
+
skipped = 0
|
| 228 |
+
|
| 229 |
+
with open(jsonl_path, 'r', encoding='utf-8') as fin, \
|
| 230 |
+
open(out_path, 'w', encoding='utf-8') as fout:
|
| 231 |
+
|
| 232 |
+
pbar = tqdm(fin, total=total_lines, desc=desc, unit="行",
|
| 233 |
+
dynamic_ncols=True, miniters=100)
|
| 234 |
+
|
| 235 |
+
for line in pbar:
|
| 236 |
+
line = line.strip()
|
| 237 |
+
if not line:
|
| 238 |
+
continue
|
| 239 |
+
try:
|
| 240 |
+
record = json.loads(line)
|
| 241 |
+
except json.JSONDecodeError:
|
| 242 |
+
continue
|
| 243 |
+
|
| 244 |
+
if not has_image(record):
|
| 245 |
+
skipped += 1
|
| 246 |
+
continue
|
| 247 |
+
|
| 248 |
+
n_imgs = count_images_in_record(record)
|
| 249 |
+
if img_idx + n_imgs > len(img_files):
|
| 250 |
+
# 图片不够了,可能extract时有错误
|
| 251 |
+
skipped += 1
|
| 252 |
+
continue
|
| 253 |
+
|
| 254 |
+
# 收集这条记录对应的图片路径
|
| 255 |
+
if n_imgs == 1:
|
| 256 |
+
img_path = os.path.join(img_dir, img_files[img_idx])
|
| 257 |
+
img_paths = [img_path]
|
| 258 |
+
else:
|
| 259 |
+
img_paths = [os.path.join(img_dir, img_files[img_idx + i])
|
| 260 |
+
for i in range(n_imgs)]
|
| 261 |
+
img_path = img_paths[0]
|
| 262 |
+
|
| 263 |
+
# 获取 VLM 描述
|
| 264 |
+
desc_text = detail.get(img_path, "")
|
| 265 |
+
# 多图时尝试获取每张的描述
|
| 266 |
+
if n_imgs > 1:
|
| 267 |
+
descs = [detail.get(p, "") for p in img_paths]
|
| 268 |
+
else:
|
| 269 |
+
descs = None
|
| 270 |
+
|
| 271 |
+
# 构建索引记录
|
| 272 |
+
idx_record = {
|
| 273 |
+
"image": img_path,
|
| 274 |
+
"question": extract_text(record, QUESTION_FIELDS),
|
| 275 |
+
"answer": extract_text(record, ANSWER_FIELDS),
|
| 276 |
+
"description": desc_text,
|
| 277 |
+
"category": category,
|
| 278 |
+
"dataset": dataset,
|
| 279 |
+
"split": split,
|
| 280 |
+
}
|
| 281 |
+
|
| 282 |
+
# 多图
|
| 283 |
+
if n_imgs > 1:
|
| 284 |
+
idx_record["images"] = img_paths
|
| 285 |
+
idx_record["descriptions"] = descs
|
| 286 |
+
|
| 287 |
+
# ID
|
| 288 |
+
rid = extract_id(record)
|
| 289 |
+
if rid:
|
| 290 |
+
idx_record["id"] = rid
|
| 291 |
+
|
| 292 |
+
# meta
|
| 293 |
+
meta = extract_meta(record)
|
| 294 |
+
if meta:
|
| 295 |
+
idx_record["meta"] = meta
|
| 296 |
+
|
| 297 |
+
# instructions(如有)
|
| 298 |
+
insts = record.get("instructions")
|
| 299 |
+
if isinstance(insts, list) and insts:
|
| 300 |
+
idx_record["instructions"] = insts
|
| 301 |
+
|
| 302 |
+
fout.write(json.dumps(idx_record, ensure_ascii=False) + "\n")
|
| 303 |
+
written += 1
|
| 304 |
+
img_idx += n_imgs
|
| 305 |
+
|
| 306 |
+
pbar.set_postfix(written=written, imgs=img_idx, skip=skipped, refresh=False)
|
| 307 |
+
|
| 308 |
+
pbar.close()
|
| 309 |
+
|
| 310 |
+
print(f" -> {written} 条, 用了 {img_idx} 张图, 跳过 {skipped} 行")
|
| 311 |
+
if img_idx != len(img_files):
|
| 312 |
+
print(f" [WARN] 图片游标 {img_idx} != 图片总数 {len(img_files)}")
|
| 313 |
+
return written
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def find_all_jsonl_files():
|
| 317 |
+
all_files = []
|
| 318 |
+
for jsonl_path in sorted(glob.glob(os.path.join(DATA_ROOT, "*/*/*.jsonl"))):
|
| 319 |
+
filename = os.path.basename(jsonl_path)
|
| 320 |
+
if re.search(r'_\d{4}-\d{2}-\d{2}\.jsonl$', filename):
|
| 321 |
+
continue
|
| 322 |
+
if '_v2.jsonl' in filename or '_new.jsonl' in filename:
|
| 323 |
+
continue
|
| 324 |
+
if filename.startswith('para_'):
|
| 325 |
+
continue
|
| 326 |
+
all_files.append(jsonl_path)
|
| 327 |
+
return all_files
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def group_by_split(files):
|
| 331 |
+
"""将多个JSONL文件按 (category/dataset/split) 分组,
|
| 332 |
+
同一split的多个文件按顺序合并处理(因为extract_images是按这个顺序提取的)"""
|
| 333 |
+
from collections import OrderedDict
|
| 334 |
+
groups = OrderedDict()
|
| 335 |
+
for f in files:
|
| 336 |
+
rel = os.path.relpath(f, DATA_ROOT)
|
| 337 |
+
parts = rel.split(os.sep)
|
| 338 |
+
if len(parts) < 3:
|
| 339 |
+
continue
|
| 340 |
+
cat, ds, fn = parts[0], parts[1], parts[2]
|
| 341 |
+
split = classify_split(fn)
|
| 342 |
+
key = (cat, ds, split)
|
| 343 |
+
groups.setdefault(key, []).append(f)
|
| 344 |
+
return groups
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
def process_group(jsonl_files, category, dataset, split, group_idx, total_groups,
|
| 348 |
+
force=False):
|
| 349 |
+
"""处理同一个 split 的一组 JSONL 文件(可能有多个)"""
|
| 350 |
+
out_path = os.path.join(INDEX_ROOT, category, dataset, f"{split}.jsonl")
|
| 351 |
+
|
| 352 |
+
# 断点续传:对比索引条数和图片数,一致才跳过
|
| 353 |
+
if not force and os.path.exists(out_path) and os.path.getsize(out_path) > 0:
|
| 354 |
+
existing = sum(1 for _ in open(out_path, 'r', encoding='utf-8'))
|
| 355 |
+
img_dir = os.path.join(IMAGES_ROOT, category, dataset, split)
|
| 356 |
+
if os.path.isdir(img_dir):
|
| 357 |
+
img_count = len([f for f in os.listdir(img_dir) if os.path.isfile(os.path.join(img_dir, f))])
|
| 358 |
+
if existing == img_count:
|
| 359 |
+
print(f" [SKIP] {category}/{dataset}/{split} 索引完整 ({existing}/{img_count})")
|
| 360 |
+
return existing
|
| 361 |
+
else:
|
| 362 |
+
print(f" [REDO] {category}/{dataset}/{split} 索引不完整 ({existing}/{img_count}), 重建")
|
| 363 |
+
|
| 364 |
+
img_dir = os.path.join(IMAGES_ROOT, category, dataset, split)
|
| 365 |
+
if not os.path.isdir(img_dir):
|
| 366 |
+
print(f" [SKIP] 无图片目录: {img_dir}")
|
| 367 |
+
return 0
|
| 368 |
+
|
| 369 |
+
img_files = sorted([f for f in os.listdir(img_dir) if os.path.isfile(os.path.join(img_dir, f))])
|
| 370 |
+
if not img_files:
|
| 371 |
+
print(f" [SKIP] 图片目录为空: {img_dir}")
|
| 372 |
+
return 0
|
| 373 |
+
|
| 374 |
+
detail = load_detail(category, dataset, split)
|
| 375 |
+
|
| 376 |
+
os.makedirs(os.path.dirname(out_path), exist_ok=True)
|
| 377 |
+
|
| 378 |
+
img_idx = 0 # 图片游标,跨文件累加
|
| 379 |
+
written = 0
|
| 380 |
+
|
| 381 |
+
with open(out_path, 'w', encoding='utf-8') as fout:
|
| 382 |
+
for fi, jsonl_path in enumerate(jsonl_files):
|
| 383 |
+
total_lines = count_lines(jsonl_path)
|
| 384 |
+
file_size_mb = os.path.getsize(jsonl_path) / (1024 * 1024)
|
| 385 |
+
fn = os.path.basename(jsonl_path)
|
| 386 |
+
if len(jsonl_files) > 1:
|
| 387 |
+
desc = f"[{group_idx}/{total_groups}] {category}/{dataset}/{split} ({fn}, {file_size_mb:.0f}MB)"
|
| 388 |
+
else:
|
| 389 |
+
desc = f"[{group_idx}/{total_groups}] {category}/{dataset}/{split} ({file_size_mb:.0f}MB)"
|
| 390 |
+
|
| 391 |
+
skipped = 0
|
| 392 |
+
with open(jsonl_path, 'r', encoding='utf-8') as fin:
|
| 393 |
+
pbar = tqdm(fin, total=total_lines, desc=desc, unit="行",
|
| 394 |
+
dynamic_ncols=True, miniters=100)
|
| 395 |
+
for line in pbar:
|
| 396 |
+
line = line.strip()
|
| 397 |
+
if not line:
|
| 398 |
+
continue
|
| 399 |
+
try:
|
| 400 |
+
record = json.loads(line)
|
| 401 |
+
except json.JSONDecodeError:
|
| 402 |
+
continue
|
| 403 |
+
|
| 404 |
+
if not has_image(record):
|
| 405 |
+
skipped += 1
|
| 406 |
+
continue
|
| 407 |
+
|
| 408 |
+
n_imgs = count_images_in_record(record)
|
| 409 |
+
if img_idx + n_imgs > len(img_files):
|
| 410 |
+
skipped += 1
|
| 411 |
+
continue
|
| 412 |
+
|
| 413 |
+
if n_imgs == 1:
|
| 414 |
+
img_path = os.path.join(img_dir, img_files[img_idx])
|
| 415 |
+
img_paths = [img_path]
|
| 416 |
+
else:
|
| 417 |
+
img_paths = [os.path.join(img_dir, img_files[img_idx + i])
|
| 418 |
+
for i in range(n_imgs)]
|
| 419 |
+
img_path = img_paths[0]
|
| 420 |
+
|
| 421 |
+
desc_text = detail.get(img_path, "")
|
| 422 |
+
|
| 423 |
+
idx_record = {
|
| 424 |
+
"image": img_path,
|
| 425 |
+
"question": extract_text(record, QUESTION_FIELDS),
|
| 426 |
+
"answer": extract_text(record, ANSWER_FIELDS),
|
| 427 |
+
"description": desc_text,
|
| 428 |
+
"category": category,
|
| 429 |
+
"dataset": dataset,
|
| 430 |
+
"split": split,
|
| 431 |
+
}
|
| 432 |
+
|
| 433 |
+
if n_imgs > 1:
|
| 434 |
+
idx_record["images"] = img_paths
|
| 435 |
+
idx_record["descriptions"] = [detail.get(p, "") for p in img_paths]
|
| 436 |
+
|
| 437 |
+
rid = extract_id(record)
|
| 438 |
+
if rid:
|
| 439 |
+
idx_record["id"] = rid
|
| 440 |
+
|
| 441 |
+
meta = extract_meta(record)
|
| 442 |
+
if meta:
|
| 443 |
+
idx_record["meta"] = meta
|
| 444 |
+
|
| 445 |
+
insts = record.get("instructions")
|
| 446 |
+
if isinstance(insts, list) and insts:
|
| 447 |
+
idx_record["instructions"] = insts
|
| 448 |
+
|
| 449 |
+
fout.write(json.dumps(idx_record, ensure_ascii=False) + "\n")
|
| 450 |
+
written += 1
|
| 451 |
+
img_idx += n_imgs
|
| 452 |
+
|
| 453 |
+
pbar.set_postfix(written=written, imgs=img_idx, skip=skipped, refresh=False)
|
| 454 |
+
pbar.close()
|
| 455 |
+
|
| 456 |
+
print(f" -> {written} 条, 用了 {img_idx}/{len(img_files)} 张图")
|
| 457 |
+
if img_idx != len(img_files):
|
| 458 |
+
print(f" [WARN] 图片游标 {img_idx} != 图片总数 {len(img_files)}")
|
| 459 |
+
return written
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
def main():
|
| 463 |
+
print("=" * 60)
|
| 464 |
+
print("生成索引 JSONL (图片路径 + 文本 + VLM描述)")
|
| 465 |
+
print(f"原始数据: {DATA_ROOT}")
|
| 466 |
+
print(f"图片目录: {IMAGES_ROOT}")
|
| 467 |
+
print(f"描述缓存: {DETAIL_ROOT}")
|
| 468 |
+
print(f"输出索引: {INDEX_ROOT}")
|
| 469 |
+
print("=" * 60)
|
| 470 |
+
|
| 471 |
+
force = "--force" in sys.argv
|
| 472 |
+
args = [a for a in sys.argv[1:] if a != "--force"]
|
| 473 |
+
|
| 474 |
+
if args:
|
| 475 |
+
target = args[0]
|
| 476 |
+
if os.path.isfile(target):
|
| 477 |
+
files = [target]
|
| 478 |
+
else:
|
| 479 |
+
files = sorted(glob.glob(os.path.join(DATA_ROOT, target, "*.jsonl")))
|
| 480 |
+
files = [f for f in files
|
| 481 |
+
if not re.search(r'_\d{4}-\d{2}-\d{2}\.jsonl$', os.path.basename(f))
|
| 482 |
+
and '_v2.jsonl' not in os.path.basename(f)
|
| 483 |
+
and '_new.jsonl' not in os.path.basename(f)
|
| 484 |
+
and not os.path.basename(f).startswith('para_')]
|
| 485 |
+
else:
|
| 486 |
+
files = find_all_jsonl_files()
|
| 487 |
+
|
| 488 |
+
groups = group_by_split(files)
|
| 489 |
+
print(f"\n共 {len(groups)} 个 split 组 ({len(files)} 个文件):")
|
| 490 |
+
for (cat, ds, split), flist in groups.items():
|
| 491 |
+
for f in flist:
|
| 492 |
+
size_mb = os.path.getsize(f) / (1024 * 1024)
|
| 493 |
+
print(f" {cat}/{ds}/{split}: {os.path.basename(f):40s} {size_mb:>10.1f} MB")
|
| 494 |
+
|
| 495 |
+
total = 0
|
| 496 |
+
for i, ((cat, ds, split), flist) in enumerate(groups.items(), 1):
|
| 497 |
+
n = process_group(flist, cat, ds, split, i, len(groups), force=force)
|
| 498 |
+
total += n
|
| 499 |
+
|
| 500 |
+
print(f"\n{'=' * 60}")
|
| 501 |
+
print(f"全部完成!共生成 {total} 条索引记录")
|
| 502 |
+
print(f"保存在: {INDEX_ROOT}")
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
if __name__ == "__main__":
|
| 506 |
+
main()
|
ICL/build_sft.py
ADDED
|
@@ -0,0 +1,466 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
构建单步决策 SFT 数据集(轻量版,只存引用路径)。
|
| 4 |
+
|
| 5 |
+
输入:
|
| 6 |
+
/workspace/xiaobin/dataset/index/{cat}/{ds}/{split}.jsonl (索引)
|
| 7 |
+
/workspace/xiaobin/dataset/embeddings/{cat}_{ds}_top5.json (预计算相似图)
|
| 8 |
+
/workspace/xiaobin/dataset/caption_cache/{cat}_{ds}.json (VLM描述)
|
| 9 |
+
/workspace/xiaobin/dataset/index/{cat}/{ds}/instructions.json
|
| 10 |
+
|
| 11 |
+
输出:
|
| 12 |
+
/workspace/xiaobin/dataset/sft/{cat}/sft.part{shard}.jsonl
|
| 13 |
+
/workspace/xiaobin/dataset/sft/all/sft.jsonl (合并后)
|
| 14 |
+
|
| 15 |
+
每条记录格式(不含conversation,由train.py动态构建):
|
| 16 |
+
{
|
| 17 |
+
"type": "ret" | "ans",
|
| 18 |
+
"query_image": "/path/to/query.jpg",
|
| 19 |
+
"question": "...",
|
| 20 |
+
"answer": "...",
|
| 21 |
+
"instruction": "...",
|
| 22 |
+
"shots": [{"image": "...", "caption": "..."}],
|
| 23 |
+
"next_description": "...", # 仅 ret 类型
|
| 24 |
+
"category": "vqa",
|
| 25 |
+
"dataset": "vqav2"
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
用法:
|
| 29 |
+
python3 build_sft.py # 全部
|
| 30 |
+
python3 build_sft.py --categories vqa # 单类
|
| 31 |
+
python3 build_sft.py --shard-id 0 --num-shards 4 # 分片
|
| 32 |
+
python3 build_sft.py --merge --shuffle # 合并
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
import argparse
|
| 36 |
+
import json
|
| 37 |
+
import os
|
| 38 |
+
import random
|
| 39 |
+
from pathlib import Path
|
| 40 |
+
from typing import Dict, List, Optional, Tuple
|
| 41 |
+
|
| 42 |
+
try:
|
| 43 |
+
from tqdm import tqdm
|
| 44 |
+
except ImportError:
|
| 45 |
+
def tqdm(x, **kw):
|
| 46 |
+
return x
|
| 47 |
+
|
| 48 |
+
# ---------------------------------------------------------------------------
|
| 49 |
+
# 默认路径
|
| 50 |
+
# ---------------------------------------------------------------------------
|
| 51 |
+
INDEX_ROOT = "/workspace/xiaobin/dataset/index"
|
| 52 |
+
EMBEDDINGS_DIR = "/workspace/xiaobin/dataset/embeddings"
|
| 53 |
+
CAPTION_CACHE_DIR = "/workspace/xiaobin/dataset/caption_cache"
|
| 54 |
+
OUTPUT_DIR = "/workspace/xiaobin/dataset/sft"
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# ---------------------------------------------------------------------------
|
| 58 |
+
# 数据加载
|
| 59 |
+
# ---------------------------------------------------------------------------
|
| 60 |
+
def discover_datasets(index_root: str, categories: List[str]) -> List[Tuple[str, str]]:
|
| 61 |
+
"""发现所有 (category, dataset) 对。"""
|
| 62 |
+
result = []
|
| 63 |
+
for cat in sorted(os.listdir(index_root)):
|
| 64 |
+
if categories and cat not in categories:
|
| 65 |
+
continue
|
| 66 |
+
cat_dir = os.path.join(index_root, cat)
|
| 67 |
+
if not os.path.isdir(cat_dir):
|
| 68 |
+
continue
|
| 69 |
+
for ds in sorted(os.listdir(cat_dir)):
|
| 70 |
+
if os.path.isdir(os.path.join(cat_dir, ds)):
|
| 71 |
+
result.append((cat, ds))
|
| 72 |
+
return result
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def load_index(index_root: str, cat: str, ds: str, split: str) -> List[Dict]:
|
| 76 |
+
"""加载索引 JSONL。"""
|
| 77 |
+
path = os.path.join(index_root, cat, ds, f"{split}.jsonl")
|
| 78 |
+
if not os.path.exists(path):
|
| 79 |
+
return []
|
| 80 |
+
records = []
|
| 81 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 82 |
+
for line in f:
|
| 83 |
+
line = line.strip()
|
| 84 |
+
if not line:
|
| 85 |
+
continue
|
| 86 |
+
try:
|
| 87 |
+
r = json.loads(line)
|
| 88 |
+
# 必须有 image + (question 或 answer)
|
| 89 |
+
if r.get("image"):
|
| 90 |
+
records.append(r)
|
| 91 |
+
except Exception:
|
| 92 |
+
continue
|
| 93 |
+
return records
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def load_top5(embeddings_dir: str, cat: str, ds: str, k: int = 5) -> Dict[str, List[str]]:
|
| 97 |
+
"""加载预计算的 top-k 相似图映射。"""
|
| 98 |
+
path = os.path.join(embeddings_dir, f"{cat}_{ds}_top{k}.json")
|
| 99 |
+
if not os.path.exists(path):
|
| 100 |
+
return {}
|
| 101 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 102 |
+
return json.load(f)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def load_captions(caption_cache_dir: str, cat: str, ds: str) -> Dict[str, str]:
|
| 106 |
+
"""加载 caption 缓存: {image_path: description}。"""
|
| 107 |
+
path = os.path.join(caption_cache_dir, f"{cat}_{ds}.json")
|
| 108 |
+
if not os.path.exists(path):
|
| 109 |
+
return {}
|
| 110 |
+
try:
|
| 111 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 112 |
+
data = json.load(f)
|
| 113 |
+
items = data.get("items", {})
|
| 114 |
+
return items if isinstance(items, dict) else {}
|
| 115 |
+
except Exception:
|
| 116 |
+
return {}
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def load_instructions(index_root: str, cat: str, ds: str) -> List[str]:
|
| 120 |
+
"""加载 instruction 模板。"""
|
| 121 |
+
path = os.path.join(index_root, cat, ds, "instructions.json")
|
| 122 |
+
if not os.path.exists(path):
|
| 123 |
+
return []
|
| 124 |
+
try:
|
| 125 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 126 |
+
data = json.load(f)
|
| 127 |
+
if isinstance(data, list):
|
| 128 |
+
return [str(x).strip() for x in data if str(x).strip()]
|
| 129 |
+
if isinstance(data, dict):
|
| 130 |
+
for key in ("instructions", "instruction", "prompts"):
|
| 131 |
+
v = data.get(key)
|
| 132 |
+
if isinstance(v, list):
|
| 133 |
+
return [str(x).strip() for x in v if str(x).strip()]
|
| 134 |
+
return []
|
| 135 |
+
except Exception:
|
| 136 |
+
return []
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
# ---------------------------------------------------------------------------
|
| 140 |
+
# 样本生成
|
| 141 |
+
# ---------------------------------------------------------------------------
|
| 142 |
+
def generate_samples(
|
| 143 |
+
records: List[Dict],
|
| 144 |
+
top5_map: Dict[str, List[str]],
|
| 145 |
+
caption_map: Dict[str, str],
|
| 146 |
+
instructions: List[str],
|
| 147 |
+
cat: str, ds: str,
|
| 148 |
+
rng: random.Random,
|
| 149 |
+
max_shots: int = 3,
|
| 150 |
+
answer_at_weights: List[float] = None,
|
| 151 |
+
target_count: int = 0,
|
| 152 |
+
) -> List[Dict]:
|
| 153 |
+
"""为一个数据集生成 SFT 样本。
|
| 154 |
+
|
| 155 |
+
target_count=0 表示全量(遍历每条记录),>0 表示随机抽样到目标数。
|
| 156 |
+
"""
|
| 157 |
+
if answer_at_weights is None:
|
| 158 |
+
answer_at_weights = [1, 3, 3, 2]
|
| 159 |
+
|
| 160 |
+
# 过滤:需要有 answer + top5;question 可为空(captioning 类)
|
| 161 |
+
valid = [r for r in records
|
| 162 |
+
if r.get("answer") and r.get("image") in top5_map]
|
| 163 |
+
if not valid:
|
| 164 |
+
return []
|
| 165 |
+
|
| 166 |
+
answer_at_values = list(range(len(answer_at_weights)))
|
| 167 |
+
default_inst = "Please answer the question based on the image."
|
| 168 |
+
samples = []
|
| 169 |
+
|
| 170 |
+
# 决定遍历源:全量遍历 or 随机抽样
|
| 171 |
+
if target_count > 0:
|
| 172 |
+
# 随机抽样模式
|
| 173 |
+
source = [rng.choice(valid) for _ in range(target_count * 5)]
|
| 174 |
+
else:
|
| 175 |
+
# 全量模式:遍历所有记录
|
| 176 |
+
source = valid
|
| 177 |
+
|
| 178 |
+
for q in source:
|
| 179 |
+
q_img = q["image"]
|
| 180 |
+
q_question = q.get("question") or ""
|
| 181 |
+
q_answer = q["answer"]
|
| 182 |
+
|
| 183 |
+
inst = rng.choice(instructions) if instructions else default_inst
|
| 184 |
+
|
| 185 |
+
answer_at = rng.choices(answer_at_values, weights=answer_at_weights, k=1)[0]
|
| 186 |
+
answer_at = min(answer_at, max_shots)
|
| 187 |
+
|
| 188 |
+
top5 = top5_map.get(q_img, [])
|
| 189 |
+
if answer_at > 0 and not top5:
|
| 190 |
+
continue
|
| 191 |
+
|
| 192 |
+
# 降级处理
|
| 193 |
+
if answer_at > len(top5):
|
| 194 |
+
answer_at = len(top5)
|
| 195 |
+
|
| 196 |
+
# 从 top5 里随机选 answer_at 个
|
| 197 |
+
chosen = rng.sample(top5, answer_at) if answer_at > 0 else []
|
| 198 |
+
|
| 199 |
+
shots = []
|
| 200 |
+
for img_path in chosen:
|
| 201 |
+
cap = caption_map.get(img_path, "")
|
| 202 |
+
shots.append({"image": img_path, "caption": cap})
|
| 203 |
+
|
| 204 |
+
remaining = [p for p in top5 if p not in chosen]
|
| 205 |
+
|
| 206 |
+
# ---- 轨迹式生成:每条记录只有一条一致的 RET→...→ANS 轨迹 ----
|
| 207 |
+
# answer_at=0: 直接 ANS(0-shot)
|
| 208 |
+
# answer_at=2: RET(0-shot) → RET(1-shot) → ANS(2-shot)
|
| 209 |
+
# 不在同一个 (image, question, n-shot) 下同时生成 RET 和 ANS,避免矛盾信号
|
| 210 |
+
for n in range(answer_at):
|
| 211 |
+
if n < len(chosen):
|
| 212 |
+
next_desc = caption_map.get(chosen[n], "")
|
| 213 |
+
elif remaining:
|
| 214 |
+
next_desc = caption_map.get(rng.choice(remaining), "")
|
| 215 |
+
else:
|
| 216 |
+
break
|
| 217 |
+
|
| 218 |
+
# RET 样本:在 n-shot 时决定继续检索
|
| 219 |
+
samples.append({
|
| 220 |
+
"type": "ret",
|
| 221 |
+
"query_image": q_img,
|
| 222 |
+
"question": q_question,
|
| 223 |
+
"answer": q_answer,
|
| 224 |
+
"instruction": inst,
|
| 225 |
+
"shots": shots[:n],
|
| 226 |
+
"next_description": next_desc,
|
| 227 |
+
"category": cat,
|
| 228 |
+
"dataset": ds,
|
| 229 |
+
})
|
| 230 |
+
|
| 231 |
+
# ANS 样本:在 answer_at shot 时回答
|
| 232 |
+
samples.append({
|
| 233 |
+
"type": "ans",
|
| 234 |
+
"query_image": q_img,
|
| 235 |
+
"question": q_question,
|
| 236 |
+
"answer": q_answer,
|
| 237 |
+
"instruction": inst,
|
| 238 |
+
"shots": shots[:answer_at],
|
| 239 |
+
"category": cat,
|
| 240 |
+
"dataset": ds,
|
| 241 |
+
})
|
| 242 |
+
|
| 243 |
+
if target_count > 0 and len(samples) >= target_count:
|
| 244 |
+
break
|
| 245 |
+
|
| 246 |
+
if target_count > 0:
|
| 247 |
+
samples = samples[:target_count]
|
| 248 |
+
|
| 249 |
+
return samples
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
# ---------------------------------------------------------------------------
|
| 253 |
+
# 文件工具
|
| 254 |
+
# ---------------------------------------------------------------------------
|
| 255 |
+
def write_jsonl(path: str, records: List[Dict]):
|
| 256 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
| 257 |
+
with open(path, "w", encoding="utf-8") as f:
|
| 258 |
+
for r in records:
|
| 259 |
+
f.write(json.dumps(r, ensure_ascii=False) + "\n")
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def concat_and_shuffle(output_dir: str, categories: List[str], shuffle: bool, seed: int):
|
| 263 |
+
"""合并各 category 的分片,生成最终数据集。"""
|
| 264 |
+
rng = random.Random(seed)
|
| 265 |
+
|
| 266 |
+
for cat in categories:
|
| 267 |
+
cat_dir = os.path.join(output_dir, cat)
|
| 268 |
+
if not os.path.isdir(cat_dir):
|
| 269 |
+
continue
|
| 270 |
+
parts = sorted(Path(cat_dir).glob("sft.part*.jsonl"))
|
| 271 |
+
if not parts:
|
| 272 |
+
continue
|
| 273 |
+
out_path = os.path.join(cat_dir, "sft.jsonl")
|
| 274 |
+
lines = []
|
| 275 |
+
for p in parts:
|
| 276 |
+
with open(p, "r", encoding="utf-8") as f:
|
| 277 |
+
lines.extend(line for line in f if line.strip())
|
| 278 |
+
if shuffle:
|
| 279 |
+
rng.shuffle(lines)
|
| 280 |
+
with open(out_path, "w", encoding="utf-8") as f:
|
| 281 |
+
f.writelines(lines)
|
| 282 |
+
print(f" [OK] {cat}: {len(lines)} 条")
|
| 283 |
+
|
| 284 |
+
# 合并所有 category
|
| 285 |
+
all_lines = []
|
| 286 |
+
for cat in categories:
|
| 287 |
+
cat_file = os.path.join(output_dir, cat, "sft.jsonl")
|
| 288 |
+
if os.path.exists(cat_file):
|
| 289 |
+
with open(cat_file, "r", encoding="utf-8") as f:
|
| 290 |
+
all_lines.extend(line for line in f if line.strip())
|
| 291 |
+
if all_lines:
|
| 292 |
+
if shuffle:
|
| 293 |
+
rng.shuffle(all_lines)
|
| 294 |
+
all_dir = os.path.join(output_dir, "all")
|
| 295 |
+
os.makedirs(all_dir, exist_ok=True)
|
| 296 |
+
all_path = os.path.join(all_dir, "sft.jsonl")
|
| 297 |
+
with open(all_path, "w", encoding="utf-8") as f:
|
| 298 |
+
f.writelines(all_lines)
|
| 299 |
+
print(f" [OK] all: {len(all_lines)} 条 → {all_path}")
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
# ---------------------------------------------------------------------------
|
| 303 |
+
# Main
|
| 304 |
+
# ---------------------------------------------------------------------------
|
| 305 |
+
def main():
|
| 306 |
+
parser = argparse.ArgumentParser(description="构建单步决策 SFT 数据集")
|
| 307 |
+
|
| 308 |
+
# 路径
|
| 309 |
+
parser.add_argument("--index-root", default=INDEX_ROOT)
|
| 310 |
+
parser.add_argument("--embeddings-dir", default=EMBEDDINGS_DIR)
|
| 311 |
+
parser.add_argument("--caption-cache-dir", default=CAPTION_CACHE_DIR)
|
| 312 |
+
parser.add_argument("--output-dir", default=OUTPUT_DIR)
|
| 313 |
+
|
| 314 |
+
# 数据集选择
|
| 315 |
+
parser.add_argument("--categories", default="vqa,captioning,classification,reasoning")
|
| 316 |
+
parser.add_argument("--split", default="train", help="query 来自哪个 split")
|
| 317 |
+
parser.add_argument("--top-k", type=int, default=5)
|
| 318 |
+
|
| 319 |
+
# 样本参数
|
| 320 |
+
parser.add_argument("--samples-per-cat", type=int, default=0,
|
| 321 |
+
help="每类目标数,0=全量遍历所有记录")
|
| 322 |
+
parser.add_argument("--samples-per-ds", type=int, default=0,
|
| 323 |
+
help="每个数据集最多取多少条原始记录(0=不限)")
|
| 324 |
+
parser.add_argument("--max-shots", type=int, default=3)
|
| 325 |
+
parser.add_argument("--answer-at-weights", default="1,3,3,2",
|
| 326 |
+
help="0/1/2/3-shot 的权重(默认 1,3,3,2,鼓励多轮 RET)")
|
| 327 |
+
parser.add_argument("--seed", type=int, default=42)
|
| 328 |
+
|
| 329 |
+
# 分片
|
| 330 |
+
parser.add_argument("--shard-id", type=int, default=0)
|
| 331 |
+
parser.add_argument("--num-shards", type=int, default=1)
|
| 332 |
+
|
| 333 |
+
# 模式
|
| 334 |
+
parser.add_argument("--merge", action="store_true", help="合并分片")
|
| 335 |
+
parser.add_argument("--shuffle", action="store_true", help="合并时 shuffle")
|
| 336 |
+
|
| 337 |
+
args = parser.parse_args()
|
| 338 |
+
categories = [c.strip() for c in args.categories.split(",") if c.strip()]
|
| 339 |
+
|
| 340 |
+
# ---- 合并模式 ----
|
| 341 |
+
if args.merge:
|
| 342 |
+
print("合并分片...")
|
| 343 |
+
concat_and_shuffle(args.output_dir, categories, args.shuffle, args.seed)
|
| 344 |
+
return
|
| 345 |
+
|
| 346 |
+
# ---- 构建模式 ----
|
| 347 |
+
aw = [float(x) for x in args.answer_at_weights.split(",") if x.strip()]
|
| 348 |
+
rng = random.Random(args.seed + args.shard_id * 1000003)
|
| 349 |
+
|
| 350 |
+
datasets = discover_datasets(args.index_root, categories)
|
| 351 |
+
print(f"共 {len(datasets)} 个数据集")
|
| 352 |
+
|
| 353 |
+
# 按 category 分组
|
| 354 |
+
cat_datasets: Dict[str, List[Tuple[str, str]]] = {}
|
| 355 |
+
for cat, ds in datasets:
|
| 356 |
+
cat_datasets.setdefault(cat, []).append((cat, ds))
|
| 357 |
+
|
| 358 |
+
for cat in categories:
|
| 359 |
+
ds_list = cat_datasets.get(cat, [])
|
| 360 |
+
if not ds_list:
|
| 361 |
+
print(f"[SKIP] {cat}: 无数据集")
|
| 362 |
+
continue
|
| 363 |
+
|
| 364 |
+
# 加载数据
|
| 365 |
+
ds_data = []
|
| 366 |
+
for c, d in ds_list:
|
| 367 |
+
records = load_index(args.index_root, c, d, args.split)
|
| 368 |
+
top5 = load_top5(args.embeddings_dir, c, d, args.top_k)
|
| 369 |
+
captions = load_captions(args.caption_cache_dir, c, d)
|
| 370 |
+
insts = load_instructions(args.index_root, c, d)
|
| 371 |
+
if not records or not top5:
|
| 372 |
+
print(f" [SKIP] {c}/{d}: records={len(records)} top5={len(top5)}")
|
| 373 |
+
continue
|
| 374 |
+
# 预检:有多少条记录同时有 answer + top5 覆盖
|
| 375 |
+
n_valid = sum(1 for r in records
|
| 376 |
+
if r.get("answer") and r.get("image") in top5)
|
| 377 |
+
if n_valid == 0:
|
| 378 |
+
print(f" [SKIP] {c}/{d}: {len(records)} 条但无 answer+top5 覆盖")
|
| 379 |
+
continue
|
| 380 |
+
|
| 381 |
+
ds_data.append({
|
| 382 |
+
"cat": c, "ds": d,
|
| 383 |
+
"records": records, "top5": top5,
|
| 384 |
+
"captions": captions, "instructions": insts,
|
| 385 |
+
})
|
| 386 |
+
# 统计 caption 覆盖率
|
| 387 |
+
n_cap = sum(1 for r in records if r.get("image") in captions)
|
| 388 |
+
n_top5 = sum(1 for r in records if r.get("image") in top5)
|
| 389 |
+
print(f" [OK] {c}/{d}: {len(records)} 条, "
|
| 390 |
+
f"valid={n_valid}, top5覆盖={n_top5}, caption覆盖={n_cap}, "
|
| 391 |
+
f"instructions={len(insts)}")
|
| 392 |
+
|
| 393 |
+
if not ds_data:
|
| 394 |
+
print(f"[WARN] {cat}: 无可用数据集")
|
| 395 |
+
continue
|
| 396 |
+
|
| 397 |
+
all_samples = []
|
| 398 |
+
|
| 399 |
+
# 计算每个数据集该抽多少条原始记录
|
| 400 |
+
n_ds = len(ds_data)
|
| 401 |
+
if args.samples_per_cat > 0:
|
| 402 |
+
# 目标: 每类 samples_per_cat 条 SFT 样本
|
| 403 |
+
# 保守估计每条记录生成 ~1.5 条样本(captioning等可能更少)
|
| 404 |
+
# 多抽一些,最后按 samples_per_cat 截断
|
| 405 |
+
records_per_ds = max(200, int(args.samples_per_cat / 1.0 / n_ds))
|
| 406 |
+
elif args.samples_per_ds > 0:
|
| 407 |
+
records_per_ds = args.samples_per_ds
|
| 408 |
+
else:
|
| 409 |
+
records_per_ds = 0 # 全量
|
| 410 |
+
|
| 411 |
+
print(f" {cat}: {n_ds} 个数据集, 每个抽 {records_per_ds} 条记录" if records_per_ds > 0
|
| 412 |
+
else f" {cat}: {n_ds} 个数据集, 全量")
|
| 413 |
+
|
| 414 |
+
for d in tqdm(ds_data, desc=f"{cat} shard{args.shard_id}"):
|
| 415 |
+
recs = d["records"]
|
| 416 |
+
|
| 417 |
+
# 抽样
|
| 418 |
+
if records_per_ds > 0 and len(recs) > records_per_ds:
|
| 419 |
+
recs = rng.sample(recs, records_per_ds)
|
| 420 |
+
|
| 421 |
+
samples = generate_samples(
|
| 422 |
+
records=recs,
|
| 423 |
+
top5_map=d["top5"],
|
| 424 |
+
caption_map=d["captions"],
|
| 425 |
+
instructions=d["instructions"],
|
| 426 |
+
cat=d["cat"], ds=d["ds"],
|
| 427 |
+
rng=rng,
|
| 428 |
+
max_shots=args.max_shots,
|
| 429 |
+
answer_at_weights=aw,
|
| 430 |
+
target_count=0, # 遍历抽出的所有记录
|
| 431 |
+
)
|
| 432 |
+
all_samples.extend(samples)
|
| 433 |
+
|
| 434 |
+
# 截断到目标数(仅 samples-per-cat>0 时)
|
| 435 |
+
if args.samples_per_cat > 0 and len(all_samples) > args.samples_per_cat:
|
| 436 |
+
rng.shuffle(all_samples)
|
| 437 |
+
all_samples = all_samples[:args.samples_per_cat]
|
| 438 |
+
|
| 439 |
+
# shuffle 保证混合
|
| 440 |
+
rng.shuffle(all_samples)
|
| 441 |
+
|
| 442 |
+
# 写出
|
| 443 |
+
out_path = os.path.join(args.output_dir, cat, f"sft.part{args.shard_id:02d}.jsonl")
|
| 444 |
+
write_jsonl(out_path, all_samples)
|
| 445 |
+
|
| 446 |
+
# 统计
|
| 447 |
+
n_ret = sum(1 for r in all_samples if r["type"] == "ret")
|
| 448 |
+
n_ans = sum(1 for r in all_samples if r["type"] == "ans")
|
| 449 |
+
n_dist = {}
|
| 450 |
+
for r in all_samples:
|
| 451 |
+
nc = len(r.get("shots", []))
|
| 452 |
+
n_dist[nc] = n_dist.get(nc, 0) + 1
|
| 453 |
+
print(f"[OK] {cat} shard{args.shard_id}: {len(all_samples)} 条 "
|
| 454 |
+
f"(ret={n_ret} ans={n_ans}) shot分布={dict(sorted(n_dist.items()))}")
|
| 455 |
+
print(f" → {out_path}")
|
| 456 |
+
|
| 457 |
+
# 单 shard 时自动合并 + shuffle
|
| 458 |
+
if args.num_shards == 1:
|
| 459 |
+
print("\n自动合并所有 category...")
|
| 460 |
+
concat_and_shuffle(args.output_dir, categories, shuffle=True, seed=args.seed)
|
| 461 |
+
|
| 462 |
+
print(f"\n完成!输出: {args.output_dir}")
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
if __name__ == "__main__":
|
| 466 |
+
main()
|
ICL/dataset_inspect.tree.txt
ADDED
|
@@ -0,0 +1,456 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
M3IT/
|
| 2 |
+
.git/
|
| 3 |
+
data/
|
| 4 |
+
.gitattributes (2.8KB)
|
| 5 |
+
.gitignore (29.0B)
|
| 6 |
+
M3IT.py (54.5KB)
|
| 7 |
+
README.md (18.3KB)
|
| 8 |
+
branches/
|
| 9 |
+
hooks/
|
| 10 |
+
info/
|
| 11 |
+
lfs/
|
| 12 |
+
logs/
|
| 13 |
+
objects/
|
| 14 |
+
refs/
|
| 15 |
+
FETCH_HEAD (110.0B)
|
| 16 |
+
HEAD (21.0B)
|
| 17 |
+
config (339.0B)
|
| 18 |
+
description (73.0B)
|
| 19 |
+
packed-refs (112.0B)
|
| 20 |
+
refs/
|
| 21 |
+
HEAD (189.0B)
|
| 22 |
+
heads/
|
| 23 |
+
remotes/
|
| 24 |
+
main (189.0B)
|
| 25 |
+
heads/
|
| 26 |
+
remotes/
|
| 27 |
+
tags/
|
| 28 |
+
origin/
|
| 29 |
+
HEAD (30.0B)
|
| 30 |
+
main (41.0B)
|
| 31 |
+
info/
|
| 32 |
+
pack/
|
| 33 |
+
pack-ee3e40a1a23ec17affa3b8afb61dc14bdffb229c.idx (38.9KB)
|
| 34 |
+
pack-ee3e40a1a23ec17affa3b8afb61dc14bdffb229c.pack (195.5KB)
|
| 35 |
+
applypatch-msg.sample (478.0B)
|
| 36 |
+
commit-msg.sample (896.0B)
|
| 37 |
+
fsmonitor-watchman.sample (4.5KB)
|
| 38 |
+
post-checkout (280.0B)
|
| 39 |
+
post-commit (276.0B)
|
| 40 |
+
post-merge (274.0B)
|
| 41 |
+
post-update.sample (189.0B)
|
| 42 |
+
pre-applypatch.sample (424.0B)
|
| 43 |
+
pre-commit.sample (1.6KB)
|
| 44 |
+
pre-merge-commit.sample (416.0B)
|
| 45 |
+
pre-push (270.0B)
|
| 46 |
+
pre-push.sample (1.3KB)
|
| 47 |
+
pre-rebase.sample (4.8KB)
|
| 48 |
+
pre-receive.sample (544.0B)
|
| 49 |
+
prepare-commit-msg.sample (1.5KB)
|
| 50 |
+
push-to-checkout.sample (2.7KB)
|
| 51 |
+
update.sample (3.6KB)
|
| 52 |
+
incomplete/
|
| 53 |
+
logs/
|
| 54 |
+
objects/
|
| 55 |
+
tmp/
|
| 56 |
+
0152398d9443f2d300adc9e6099a773c66303d4e2e085812cd502cb36da7a0c73483193049 (0.0B)
|
| 57 |
+
0152398d9443f2d300adc9e6099a773c66303d4e2e085812cd502cb36da7a0c7763208216 (0.0B)
|
| 58 |
+
0152398d9443f2d300adc9e6099a773c66303d4e2e085812cd502cb36da7a0c789921672 (2.5MB)
|
| 59 |
+
0968a4438d46277583968011563e959e130feaee66f51bb2d66dbd7e8c979f8c.part (0.0B)
|
| 60 |
+
1f77f56225e10edca84be06b6e0d796c579cbf1d4884aee46da564438ad1ba9b1484563810 (437.0KB)
|
| 61 |
+
1f77f56225e10edca84be06b6e0d796c579cbf1d4884aee46da564438ad1ba9b3850099655 (326.7KB)
|
| 62 |
+
1f77f56225e10edca84be06b6e0d796c579cbf1d4884aee46da564438ad1ba9b3898577811 (4.1MB)
|
| 63 |
+
220d32d087b6b29d1c5aaa49324d32b32ae1c19f42e9800f40f24d3a695c2a8d1743027097 (0.0B)
|
| 64 |
+
220d32d087b6b29d1c5aaa49324d32b32ae1c19f42e9800f40f24d3a695c2a8d3014727128 (0.0B)
|
| 65 |
+
220d32d087b6b29d1c5aaa49324d32b32ae1c19f42e9800f40f24d3a695c2a8d71894927 (62.6KB)
|
| 66 |
+
24f014bb5bc7b1fa7d9183dd65fd4b43c0c49aafd6af01bb91ae3a0e7e65502b2818819757 (49.3MB)
|
| 67 |
+
3da69649bfbc671710f38c2c2f7c6aaecb8f8544de3446866054bf927257c9332854861486 (158.6KB)
|
| 68 |
+
3da69649bfbc671710f38c2c2f7c6aaecb8f8544de3446866054bf927257c9334214717938 (0.0B)
|
| 69 |
+
3da69649bfbc671710f38c2c2f7c6aaecb8f8544de3446866054bf927257c933593947826 (0.0B)
|
| 70 |
+
45e8c51ed0df8edb1ae51d2012b3f7d6cd9cc84addf41e6f9f9adb0f625d41033126870057 (259.2MB)
|
| 71 |
+
4a80559730d917177e4d13246da0ce23ca318735b29d519d0448bea5579b1a771450117433 (154.4MB)
|
| 72 |
+
4fda2aa4918e5dec847935db6d46e9bebc570a173bd4201c5f48e60a3f73813a1530155941 (1.1MB)
|
| 73 |
+
4fda2aa4918e5dec847935db6d46e9bebc570a173bd4201c5f48e60a3f73813a2738070238 (0.0B)
|
| 74 |
+
4fda2aa4918e5dec847935db6d46e9bebc570a173bd4201c5f48e60a3f73813a2828099128 (0.0B)
|
| 75 |
+
52a445f8a26cd898e64129e7f1d4bfa6d7203311442068684f5344fc73407310.part (0.0B)
|
| 76 |
+
6728a8fb7bad0bad3a2a27669232cb9ae66461c635172f1f7958c80a28e09fa32607733000 (150.2MB)
|
| 77 |
+
6bb6c9f17e77eab7d88e4a4501c38cb31a6cf792fe77e3b75d511b964a5667df2998182268 (91.8MB)
|
| 78 |
+
8cb15647ff6bbac322142fea1a38599c523f73acb3614ddb7d12e6a1975a79dc1986657385 (0.0B)
|
| 79 |
+
8cb15647ff6bbac322142fea1a38599c523f73acb3614ddb7d12e6a1975a79dc2743098052 (0.0B)
|
| 80 |
+
8cb15647ff6bbac322142fea1a38599c523f73acb3614ddb7d12e6a1975a79dc4193739161 (0.0B)
|
| 81 |
+
9919274ad6bc88e37235a4c7245d05e357e404ef3352a90a1ba0594e694893c01114223911 (0.0B)
|
| 82 |
+
9919274ad6bc88e37235a4c7245d05e357e404ef3352a90a1ba0594e694893c03545613611 (0.0B)
|
| 83 |
+
9919274ad6bc88e37235a4c7245d05e357e404ef3352a90a1ba0594e694893c0559090370 (2.8MB)
|
| 84 |
+
9cdf4d1a6972db893c8db1a4f2be0d1ec0362ba22a44542402b336760029c87253830692 (88.0MB)
|
| 85 |
+
b6aed90c79d180c5346994f8e7d0657b3d8a9aab002c057503736b4013a2096b.part (0.0B)
|
| 86 |
+
ba47b9680dc949322877399218d1f210a057249803bc70addfb9528152e4b1662004000729 (218.5MB)
|
| 87 |
+
ca49e0b3f3400f38519a1103b2a567db32c9fa990a7395b1024b94454601479b.part (0.0B)
|
| 88 |
+
d66a5b3267a7935b8ff272bcc166a8f43a8d66fb89c59503d536ac87661a02022501429466 (0.0B)
|
| 89 |
+
d66a5b3267a7935b8ff272bcc166a8f43a8d66fb89c59503d536ac87661a020230475132 (0.0B)
|
| 90 |
+
d66a5b3267a7935b8ff272bcc166a8f43a8d66fb89c59503d536ac87661a0202373225118 (62.5KB)
|
| 91 |
+
e5a3eb3e2d0c47d6f014e294ef7398bf26375920c8d2af80fd65e255396dcc78.part (0.0B)
|
| 92 |
+
f19cacf3a9f9a57abdcafc4a6d242aa9c6fa48188ad0a394b1a2558cb8ab4dc5372340294 (199.2MB)
|
| 93 |
+
20251021T152133.441099492.log (1.4KB)
|
| 94 |
+
01/
|
| 95 |
+
02/
|
| 96 |
+
03/
|
| 97 |
+
05/
|
| 98 |
+
06/
|
| 99 |
+
07/
|
| 100 |
+
09/
|
| 101 |
+
0b/
|
| 102 |
+
0f/
|
| 103 |
+
10/
|
| 104 |
+
12/
|
| 105 |
+
15/
|
| 106 |
+
16/
|
| 107 |
+
19/
|
| 108 |
+
1d/
|
| 109 |
+
1e/
|
| 110 |
+
1f/
|
| 111 |
+
21/
|
| 112 |
+
22/
|
| 113 |
+
23/
|
| 114 |
+
24/
|
| 115 |
+
2a/
|
| 116 |
+
2b/
|
| 117 |
+
2c/
|
| 118 |
+
2d/
|
| 119 |
+
2f/
|
| 120 |
+
30/
|
| 121 |
+
32/
|
| 122 |
+
34/
|
| 123 |
+
37/
|
| 124 |
+
3b/
|
| 125 |
+
3d/
|
| 126 |
+
44/
|
| 127 |
+
45/
|
| 128 |
+
4a/
|
| 129 |
+
4f/
|
| 130 |
+
50/
|
| 131 |
+
52/
|
| 132 |
+
54/
|
| 133 |
+
56/
|
| 134 |
+
58/
|
| 135 |
+
5a/
|
| 136 |
+
5b/
|
| 137 |
+
60/
|
| 138 |
+
61/
|
| 139 |
+
64/
|
| 140 |
+
65/
|
| 141 |
+
67/
|
| 142 |
+
68/
|
| 143 |
+
69/
|
| 144 |
+
6b/
|
| 145 |
+
6d/
|
| 146 |
+
6e/
|
| 147 |
+
70/
|
| 148 |
+
75/
|
| 149 |
+
76/
|
| 150 |
+
7b/
|
| 151 |
+
7c/
|
| 152 |
+
80/
|
| 153 |
+
87/
|
| 154 |
+
88/
|
| 155 |
+
89/
|
| 156 |
+
8b/
|
| 157 |
+
8c/
|
| 158 |
+
90/
|
| 159 |
+
91/
|
| 160 |
+
93/
|
| 161 |
+
99/
|
| 162 |
+
9a/
|
| 163 |
+
9b/
|
| 164 |
+
9c/
|
| 165 |
+
9e/
|
| 166 |
+
9f/
|
| 167 |
+
a0/
|
| 168 |
+
a5/
|
| 169 |
+
a9/
|
| 170 |
+
ac/
|
| 171 |
+
ae/
|
| 172 |
+
b1/
|
| 173 |
+
b3/
|
| 174 |
+
b4/
|
| 175 |
+
b6/
|
| 176 |
+
ba/
|
| 177 |
+
bb/
|
| 178 |
+
bc/
|
| 179 |
+
bd/
|
| 180 |
+
be/
|
| 181 |
+
c0/
|
| 182 |
+
c1/
|
| 183 |
+
c2/
|
| 184 |
+
c4/
|
| 185 |
+
c6/
|
| 186 |
+
c7/
|
| 187 |
+
c8/
|
| 188 |
+
ca/
|
| 189 |
+
cb/
|
| 190 |
+
d6/
|
| 191 |
+
d9/
|
| 192 |
+
dd/
|
| 193 |
+
e2/
|
| 194 |
+
e5/
|
| 195 |
+
e7/
|
| 196 |
+
e8/
|
| 197 |
+
e9/
|
| 198 |
+
ee/
|
| 199 |
+
ef/
|
| 200 |
+
f1/
|
| 201 |
+
f3/
|
| 202 |
+
f4/
|
| 203 |
+
f5/
|
| 204 |
+
f6/
|
| 205 |
+
f7/
|
| 206 |
+
f8/
|
| 207 |
+
f9/
|
| 208 |
+
fc/
|
| 209 |
+
exclude (240.0B)
|
| 210 |
+
captioning/
|
| 211 |
+
classification/
|
| 212 |
+
generation/
|
| 213 |
+
reasoning/
|
| 214 |
+
vqa/
|
| 215 |
+
chinesefoodnet-10/
|
| 216 |
+
coco-goi/
|
| 217 |
+
coco-text/
|
| 218 |
+
imagenet/
|
| 219 |
+
iqa/
|
| 220 |
+
itm/
|
| 221 |
+
mocheg/
|
| 222 |
+
refcoco/
|
| 223 |
+
snli-ve/
|
| 224 |
+
ss/
|
| 225 |
+
vsr/
|
| 226 |
+
winoground/
|
| 227 |
+
.gitattributes (141.0B)
|
| 228 |
+
README.md (211.0B)
|
| 229 |
+
instructions.json (1.4KB)
|
| 230 |
+
labels.json (9.0KB)
|
| 231 |
+
test.jsonl (223.5MB)
|
| 232 |
+
train.jsonl (238.9MB)
|
| 233 |
+
val.jsonl (227.6MB)
|
| 234 |
+
README.md (31.0B)
|
| 235 |
+
esnlive_test.jsonl (743.0MB)
|
| 236 |
+
esnlive_train.jsonl (1000.8MB)
|
| 237 |
+
esnlive_val.jsonl (717.9MB)
|
| 238 |
+
instructions.json (1.9KB)
|
| 239 |
+
test_2023-10-09.jsonl (2.9GB)
|
| 240 |
+
train_2023-10-09.jsonl (3.9GB)
|
| 241 |
+
instructions.json (825.0B)
|
| 242 |
+
mapping.txt (30.9KB)
|
| 243 |
+
test_2023-10-08.jsonl (10.6GB)
|
| 244 |
+
train.jsonl (1.5GB)
|
| 245 |
+
train_2023-10-08.jsonl (5.9GB)
|
| 246 |
+
val.jsonl (2.6GB)
|
| 247 |
+
instructions.json (907.0B)
|
| 248 |
+
test.jsonl (330.4MB)
|
| 249 |
+
test_2023-10-09.jsonl (1.3GB)
|
| 250 |
+
train.jsonl (1.9GB)
|
| 251 |
+
train_2023-10-08.jsonl (7.8GB)
|
| 252 |
+
val.jsonl (330.8MB)
|
| 253 |
+
instructions.json (773.0B)
|
| 254 |
+
test.jsonl (730.0MB)
|
| 255 |
+
test_2023-10-09.jsonl (2.9GB)
|
| 256 |
+
train.jsonl (4.3GB)
|
| 257 |
+
train_2023-10-08.jsonl (17.1GB)
|
| 258 |
+
val.jsonl (730.2MB)
|
| 259 |
+
instructions.json (1.4KB)
|
| 260 |
+
test_2023-10-09.jsonl (553.7MB)
|
| 261 |
+
train_2023-10-09.jsonl (1.9GB)
|
| 262 |
+
vsr_test.jsonl (137.7MB)
|
| 263 |
+
vsr_train.jsonl (483.3MB)
|
| 264 |
+
vsr_val.jsonl (68.8MB)
|
| 265 |
+
instructions.json (774.0B)
|
| 266 |
+
test_2023-10-10.jsonl (7.6GB)
|
| 267 |
+
train.jsonl (8.2GB)
|
| 268 |
+
train_2023-10-08.jsonl (32.8GB)
|
| 269 |
+
val.jsonl (1.9GB)
|
| 270 |
+
instructions.json (733.0B)
|
| 271 |
+
test_2023-10-07.jsonl (279.1MB)
|
| 272 |
+
train.jsonl (2.0GB)
|
| 273 |
+
train_2023-10-06.jsonl (4.1GB)
|
| 274 |
+
val.jsonl (138.9MB)
|
| 275 |
+
instructions.json (2.0KB)
|
| 276 |
+
winoground_test.jsonl (245.5MB)
|
| 277 |
+
instructions.json (1.3KB)
|
| 278 |
+
test.jsonl (122.9MB)
|
| 279 |
+
instructions.json (1.0KB)
|
| 280 |
+
mocheg_test.jsonl (60.3MB)
|
| 281 |
+
mocheg_train.jsonl (631.7MB)
|
| 282 |
+
mocheg_val.jsonl (28.2MB)
|
| 283 |
+
test_2023-10-08.jsonl (242.5MB)
|
| 284 |
+
train_2023-10-08.jsonl (2.5GB)
|
| 285 |
+
instructions.json (1.5KB)
|
| 286 |
+
test.jsonl (701.9MB)
|
| 287 |
+
test_2023-10-08.jsonl (2.7GB)
|
| 288 |
+
train.jsonl (3.9GB)
|
| 289 |
+
train_2023-10-08.jsonl (15.6GB)
|
| 290 |
+
val.jsonl (667.7MB)
|
| 291 |
+
clevr/
|
| 292 |
+
nlvr/
|
| 293 |
+
science_qa/
|
| 294 |
+
vcr/
|
| 295 |
+
visual_mrc/
|
| 296 |
+
instructions.json (2.5KB)
|
| 297 |
+
science_qa_test.jsonl (174.0MB)
|
| 298 |
+
science_qa_train.jsonl (531.3MB)
|
| 299 |
+
science_qa_validation.jsonl (176.4MB)
|
| 300 |
+
instructions.json (976.0B)
|
| 301 |
+
train.jsonl (5.6GB)
|
| 302 |
+
train_2023-10-07.jsonl (11.1GB)
|
| 303 |
+
val.jsonl (379.6MB)
|
| 304 |
+
val_2023-10-07.jsonl (760.4MB)
|
| 305 |
+
instructions.json (911.0B)
|
| 306 |
+
test.jsonl (1.2GB)
|
| 307 |
+
train.jsonl (3.9GB)
|
| 308 |
+
val.jsonl (266.9MB)
|
| 309 |
+
instructions.json (1.3KB)
|
| 310 |
+
test.jsonl (909.3MB)
|
| 311 |
+
train.jsonl (4.3GB)
|
| 312 |
+
val.jsonl (992.9MB)
|
| 313 |
+
instructions.json (1.2KB)
|
| 314 |
+
test.jsonl (489.0MB)
|
| 315 |
+
train.jsonl (7.9GB)
|
| 316 |
+
val.jsonl (533.3MB)
|
| 317 |
+
mmchat/
|
| 318 |
+
multi30k/
|
| 319 |
+
vist/
|
| 320 |
+
visual_dialog/
|
| 321 |
+
instructions.json (818.0B)
|
| 322 |
+
test.jsonl (65.2MB)
|
| 323 |
+
test_2023-10-10.jsonl (262.2MB)
|
| 324 |
+
train.jsonl (3.2GB)
|
| 325 |
+
train_2023-10-09.jsonl (13.0GB)
|
| 326 |
+
val.jsonl (66.0MB)
|
| 327 |
+
instructions.json (1.2KB)
|
| 328 |
+
test.jsonl (610.6MB)
|
| 329 |
+
train.jsonl (4.4GB)
|
| 330 |
+
val.jsonl (301.1MB)
|
| 331 |
+
instructions.json (809.0B)
|
| 332 |
+
test.jsonl (2.3GB)
|
| 333 |
+
train.jsonl (6.2GB)
|
| 334 |
+
train_new.jsonl (6.2GB)
|
| 335 |
+
validation.jsonl (2.0GB)
|
| 336 |
+
instructions.json (1.0KB)
|
| 337 |
+
test.jsonl (14.0GB)
|
| 338 |
+
train.jsonl (15.4GB)
|
| 339 |
+
val.jsonl (13.0GB)
|
| 340 |
+
a-okvqa/
|
| 341 |
+
activitynet-qa/
|
| 342 |
+
docvqa/
|
| 343 |
+
fm-iqa/
|
| 344 |
+
gqa/
|
| 345 |
+
ivqa/
|
| 346 |
+
msrvtt-qa/
|
| 347 |
+
msvd-qa/
|
| 348 |
+
ocr-vqa/
|
| 349 |
+
okvqa/
|
| 350 |
+
shapes/
|
| 351 |
+
st-vqa/
|
| 352 |
+
text-vqa/
|
| 353 |
+
viquae/
|
| 354 |
+
vqav2/
|
| 355 |
+
instruction.json (905.0B)
|
| 356 |
+
train.jsonl (533.5MB)
|
| 357 |
+
train_new.jsonl (533.5MB)
|
| 358 |
+
validation.jsonl (228.3MB)
|
| 359 |
+
instructions.json (1.9KB)
|
| 360 |
+
train.jsonl (1.2GB)
|
| 361 |
+
train_v2.jsonl (1.2GB)
|
| 362 |
+
val.jsonl (77.7MB)
|
| 363 |
+
val_v2.jsonl (78.2MB)
|
| 364 |
+
instruction.json (905.0B)
|
| 365 |
+
test.jsonl (713.3MB)
|
| 366 |
+
train.jsonl (3.3GB)
|
| 367 |
+
validation_new.jsonl (529.5MB)
|
| 368 |
+
instruction.json (772.0B)
|
| 369 |
+
train.jsonl (1.5GB)
|
| 370 |
+
validation.jsonl (260.3MB)
|
| 371 |
+
instruction.json (853.0B)
|
| 372 |
+
test.jsonl (229.4MB)
|
| 373 |
+
train.jsonl (1.4GB)
|
| 374 |
+
README.md (288.0B)
|
| 375 |
+
instructions.json (1.2KB)
|
| 376 |
+
test.jsonl (132.4MB)
|
| 377 |
+
train.jsonl (343.1MB)
|
| 378 |
+
val.jsonl (60.9MB)
|
| 379 |
+
instructions.json (853.0B)
|
| 380 |
+
train.jsonl (1.9GB)
|
| 381 |
+
val.jsonl (1.9GB)
|
| 382 |
+
instructions.json (1.7KB)
|
| 383 |
+
train.jsonl (7.2GB)
|
| 384 |
+
val.jsonl (976.6MB)
|
| 385 |
+
instructions.json (1.5KB)
|
| 386 |
+
test.jsonl (1.4MB)
|
| 387 |
+
test_2023-10-08.jsonl (7.0MB)
|
| 388 |
+
train.large.jsonl (18.3MB)
|
| 389 |
+
train_2023-10-08.jsonl (92.6MB)
|
| 390 |
+
val.jsonl (1.4MB)
|
| 391 |
+
README.md (334.0B)
|
| 392 |
+
instructions.json (1.0KB)
|
| 393 |
+
test.jsonl (500.8MB)
|
| 394 |
+
train.jsonl (1.5GB)
|
| 395 |
+
val.jsonl (485.4MB)
|
| 396 |
+
README.md (434.0B)
|
| 397 |
+
instructions.json (1.0KB)
|
| 398 |
+
test.jsonl (348.1MB)
|
| 399 |
+
train.jsonl (757.5MB)
|
| 400 |
+
val.jsonl (58.0MB)
|
| 401 |
+
.gitattributes (141.0B)
|
| 402 |
+
README.md (332.0B)
|
| 403 |
+
instructions.json (1.4KB)
|
| 404 |
+
test.jsonl (474.7MB)
|
| 405 |
+
train.jsonl (2.1GB)
|
| 406 |
+
val.jsonl (1.1GB)
|
| 407 |
+
instructions.json (1.2KB)
|
| 408 |
+
train.jsonl (594.8MB)
|
| 409 |
+
train_v2.jsonl (596.3MB)
|
| 410 |
+
val.jsonl (334.3MB)
|
| 411 |
+
val_v2.jsonl (335.2MB)
|
| 412 |
+
instructions.json (802.0B)
|
| 413 |
+
para_train.jsonl (10.5GB)
|
| 414 |
+
para_val.jsonl (4.8GB)
|
| 415 |
+
train.jsonl (10.5GB)
|
| 416 |
+
val.jsonl (4.8GB)
|
| 417 |
+
instructions.json (1.2KB)
|
| 418 |
+
test.jsonl (122.5MB)
|
| 419 |
+
test_v2.jsonl (120.9MB)
|
| 420 |
+
train.jsonl (110.1MB)
|
| 421 |
+
train_v2.jsonl (110.2MB)
|
| 422 |
+
validation.jsonl (125.5MB)
|
| 423 |
+
validation_v2.jsonl (125.6MB)
|
| 424 |
+
coco/
|
| 425 |
+
coco-cn/
|
| 426 |
+
flickr8k-cn/
|
| 427 |
+
image_paragraph_captioning/
|
| 428 |
+
msrvtt/
|
| 429 |
+
textcap/
|
| 430 |
+
.gitattributes (141.0B)
|
| 431 |
+
README.md (490.0B)
|
| 432 |
+
instructions.json (1010.0B)
|
| 433 |
+
test.jsonl (117.1MB)
|
| 434 |
+
train.jsonl (231.1MB)
|
| 435 |
+
val.jsonl (116.9MB)
|
| 436 |
+
instructions.json (541.0B)
|
| 437 |
+
test.jsonl (49.4MB)
|
| 438 |
+
train.jsonl (300.0MB)
|
| 439 |
+
val.jsonl (49.9MB)
|
| 440 |
+
instructions.json (790.0B)
|
| 441 |
+
test.jsonl (66.4MB)
|
| 442 |
+
train.jsonl (1.2GB)
|
| 443 |
+
val.jsonl (65.0MB)
|
| 444 |
+
image_paragraph_captioning_test.jsonl (120.7MB)
|
| 445 |
+
image_paragraph_captioning_train.jsonl (701.2MB)
|
| 446 |
+
image_paragraph_captioning_val.jsonl (118.0MB)
|
| 447 |
+
instruction.json (1.4KB)
|
| 448 |
+
README.md (73.0B)
|
| 449 |
+
create_dataset.py (5.5KB)
|
| 450 |
+
instructions.json (882.0B)
|
| 451 |
+
test.jsonl (333.1MB)
|
| 452 |
+
train.jsonl (7.4GB)
|
| 453 |
+
val.jsonl (333.4MB)
|
| 454 |
+
instructions.json (1.1KB)
|
| 455 |
+
train.jsonl (5.7GB)
|
| 456 |
+
val.jsonl (851.3MB)
|
ICL/eval_icl.py
ADDED
|
@@ -0,0 +1,524 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
ICL 推理评测脚本:模拟多轮 RET/ANS 决策循环。支持多卡并行。
|
| 4 |
+
|
| 5 |
+
流程:
|
| 6 |
+
1. 给模型 query_image + question(0-shot)
|
| 7 |
+
2. 模型输出 <RET> → 用预计算 top5 检索下一张图+caption,追加到 context,再问
|
| 8 |
+
3. 模型输出 <ANS> → 结束,提取答案
|
| 9 |
+
4. 最多 max_rounds 轮(防止一直 RET)
|
| 10 |
+
|
| 11 |
+
多卡策略:
|
| 12 |
+
每张 GPU 加载一份模型,按 dataset 粒度分配任务,最后 rank 0 汇总。
|
| 13 |
+
|
| 14 |
+
用法:
|
| 15 |
+
# 单卡
|
| 16 |
+
python3 eval_icl.py \
|
| 17 |
+
--model-path /workspace/xiaobin/ICL/sft_model/merged_hf \
|
| 18 |
+
--category vqa --dataset vqav2 --split val \
|
| 19 |
+
--num-samples 200 --max-rounds 4 --device cuda:0
|
| 20 |
+
|
| 21 |
+
# 多卡(8 GPU)
|
| 22 |
+
torchrun --nproc_per_node=8 eval_icl.py \
|
| 23 |
+
--model-path /workspace/xiaobin/ICL/sft_model/merged_hf \
|
| 24 |
+
--all-categories --num-samples 100 --max-rounds 4
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
import argparse
|
| 28 |
+
import json
|
| 29 |
+
import os
|
| 30 |
+
import random
|
| 31 |
+
import sys
|
| 32 |
+
import time
|
| 33 |
+
from collections import defaultdict
|
| 34 |
+
from typing import Dict, List, Optional, Tuple
|
| 35 |
+
|
| 36 |
+
import torch
|
| 37 |
+
import torch.distributed as dist
|
| 38 |
+
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
|
| 39 |
+
from qwen_vl_utils import process_vision_info
|
| 40 |
+
|
| 41 |
+
# ---------------------------------------------------------------------------
|
| 42 |
+
# 默认路径
|
| 43 |
+
# ---------------------------------------------------------------------------
|
| 44 |
+
INDEX_ROOT = "/workspace/xiaobin/dataset/index"
|
| 45 |
+
EMBEDDINGS_DIR = "/workspace/xiaobin/dataset/embeddings"
|
| 46 |
+
CAPTION_CACHE_DIR = "/workspace/xiaobin/dataset/caption_cache"
|
| 47 |
+
|
| 48 |
+
# ---------------------------------------------------------------------------
|
| 49 |
+
# 分布式工具
|
| 50 |
+
# ---------------------------------------------------------------------------
|
| 51 |
+
|
| 52 |
+
def setup_distributed():
|
| 53 |
+
"""初始化分布式环境,返回 (rank, world_size, device)。
|
| 54 |
+
单卡时 rank=0, world_size=1。"""
|
| 55 |
+
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
| 56 |
+
rank = int(os.environ["RANK"])
|
| 57 |
+
world_size = int(os.environ["WORLD_SIZE"])
|
| 58 |
+
local_rank = int(os.environ.get("LOCAL_RANK", rank))
|
| 59 |
+
dist.init_process_group("nccl")
|
| 60 |
+
torch.cuda.set_device(local_rank)
|
| 61 |
+
device = f"cuda:{local_rank}"
|
| 62 |
+
else:
|
| 63 |
+
rank, world_size = 0, 1
|
| 64 |
+
device = None # 由 args.device 决定
|
| 65 |
+
return rank, world_size, device
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def gather_results(local_results: List[Dict], rank: int, world_size: int) -> List[Dict]:
|
| 69 |
+
"""把各 rank 的结果汇总到 rank 0。"""
|
| 70 |
+
if world_size == 1:
|
| 71 |
+
return local_results
|
| 72 |
+
|
| 73 |
+
# 序列化 → bytes → tensor
|
| 74 |
+
data = json.dumps(local_results, ensure_ascii=False).encode("utf-8")
|
| 75 |
+
size = torch.tensor([len(data)], dtype=torch.long, device=f"cuda:{rank}")
|
| 76 |
+
|
| 77 |
+
# 收集各 rank 的大小
|
| 78 |
+
size_list = [torch.zeros(1, dtype=torch.long, device=f"cuda:{rank}") for _ in range(world_size)]
|
| 79 |
+
dist.all_gather(size_list, size)
|
| 80 |
+
max_size = max(s.item() for s in size_list)
|
| 81 |
+
|
| 82 |
+
# pad 到相同长度
|
| 83 |
+
padded = data + b"\x00" * (max_size - len(data))
|
| 84 |
+
tensor = torch.ByteTensor(list(padded)).cuda(rank)
|
| 85 |
+
|
| 86 |
+
tensor_list = [torch.zeros(max_size, dtype=torch.uint8, device=f"cuda:{rank}") for _ in range(world_size)]
|
| 87 |
+
dist.all_gather(tensor_list, tensor)
|
| 88 |
+
|
| 89 |
+
if rank == 0:
|
| 90 |
+
all_results = []
|
| 91 |
+
for i, (t, s) in enumerate(zip(tensor_list, size_list)):
|
| 92 |
+
raw = bytes(t[:s.item()].cpu().tolist())
|
| 93 |
+
all_results.extend(json.loads(raw.decode("utf-8")))
|
| 94 |
+
return all_results
|
| 95 |
+
return []
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# ---------------------------------------------------------------------------
|
| 99 |
+
# 数据加载
|
| 100 |
+
# ---------------------------------------------------------------------------
|
| 101 |
+
|
| 102 |
+
def load_records(cat: str, ds: str, split: str, limit: int = 0) -> List[Dict]:
|
| 103 |
+
path = os.path.join(INDEX_ROOT, cat, ds, f"{split}.jsonl")
|
| 104 |
+
if not os.path.exists(path):
|
| 105 |
+
return []
|
| 106 |
+
records = []
|
| 107 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 108 |
+
for line in f:
|
| 109 |
+
line = line.strip()
|
| 110 |
+
if not line:
|
| 111 |
+
continue
|
| 112 |
+
r = json.loads(line)
|
| 113 |
+
if r.get("image") and r.get("answer"):
|
| 114 |
+
records.append(r)
|
| 115 |
+
if limit and len(records) >= limit:
|
| 116 |
+
break
|
| 117 |
+
return records
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def load_top5(cat: str, ds: str) -> Dict[str, List[str]]:
|
| 121 |
+
path = os.path.join(EMBEDDINGS_DIR, f"{cat}_{ds}_top5.json")
|
| 122 |
+
if not os.path.exists(path):
|
| 123 |
+
return {}
|
| 124 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 125 |
+
return json.load(f)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def load_caption_cache(cat: str, ds: str) -> Dict[str, str]:
|
| 129 |
+
path = os.path.join(CAPTION_CACHE_DIR, f"{cat}_{ds}.json")
|
| 130 |
+
if not os.path.exists(path):
|
| 131 |
+
return {}
|
| 132 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 133 |
+
return json.load(f)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def load_instructions(cat: str, ds: str) -> List[str]:
|
| 137 |
+
path = os.path.join(INDEX_ROOT, cat, ds, "instructions.json")
|
| 138 |
+
if not os.path.exists(path):
|
| 139 |
+
return ["Look at the image and answer the question."]
|
| 140 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 141 |
+
return json.load(f)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def discover_datasets(categories: List[str]) -> List[Tuple[str, str]]:
|
| 145 |
+
results = []
|
| 146 |
+
for cat in sorted(os.listdir(INDEX_ROOT)):
|
| 147 |
+
if categories and cat not in categories:
|
| 148 |
+
continue
|
| 149 |
+
cat_dir = os.path.join(INDEX_ROOT, cat)
|
| 150 |
+
if not os.path.isdir(cat_dir):
|
| 151 |
+
continue
|
| 152 |
+
for ds in sorted(os.listdir(cat_dir)):
|
| 153 |
+
ds_dir = os.path.join(cat_dir, ds)
|
| 154 |
+
if os.path.isdir(ds_dir):
|
| 155 |
+
results.append((cat, ds))
|
| 156 |
+
return results
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
# ---------------------------------------------------------------------------
|
| 160 |
+
# 模型加载
|
| 161 |
+
# ---------------------------------------------------------------------------
|
| 162 |
+
|
| 163 |
+
def load_model(model_path: str, device: str):
|
| 164 |
+
print(f"[{device}] Loading model from {model_path} ...")
|
| 165 |
+
processor = AutoProcessor.from_pretrained(
|
| 166 |
+
model_path, trust_remote_code=True,
|
| 167 |
+
min_pixels=256 * 28 * 28,
|
| 168 |
+
max_pixels=1280 * 28 * 28,
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
model = Qwen3VLForConditionalGeneration.from_pretrained(
|
| 172 |
+
model_path,
|
| 173 |
+
trust_remote_code=True,
|
| 174 |
+
torch_dtype=torch.bfloat16,
|
| 175 |
+
attn_implementation="flash_attention_2",
|
| 176 |
+
device_map=device,
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
special_tokens = ["<RET>", "<ANS>", "</ANS>", "<RETQ>", "</RETQ>"]
|
| 180 |
+
num_added = processor.tokenizer.add_tokens(special_tokens, special_tokens=True)
|
| 181 |
+
if num_added > 0:
|
| 182 |
+
model.resize_token_embeddings(len(processor.tokenizer))
|
| 183 |
+
|
| 184 |
+
model.eval()
|
| 185 |
+
|
| 186 |
+
ret_id = processor.tokenizer.convert_tokens_to_ids("<RET>")
|
| 187 |
+
ans_id = processor.tokenizer.convert_tokens_to_ids("<ANS>")
|
| 188 |
+
print(f"[{device}] Ready. <RET>={ret_id}, <ANS>={ans_id}")
|
| 189 |
+
|
| 190 |
+
return model, processor
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
# ---------------------------------------------------------------------------
|
| 194 |
+
# 推理核心
|
| 195 |
+
# ---------------------------------------------------------------------------
|
| 196 |
+
|
| 197 |
+
def build_messages(
|
| 198 |
+
instruction: str,
|
| 199 |
+
query_image: str,
|
| 200 |
+
question: Optional[str],
|
| 201 |
+
shots: List[Dict],
|
| 202 |
+
min_pixels: int = 256 * 28 * 28,
|
| 203 |
+
max_pixels: int = 1280 * 28 * 28,
|
| 204 |
+
) -> List[Dict]:
|
| 205 |
+
user_content = []
|
| 206 |
+
|
| 207 |
+
if instruction:
|
| 208 |
+
user_content.append({"type": "text", "text": instruction})
|
| 209 |
+
|
| 210 |
+
user_content.append({
|
| 211 |
+
"type": "image",
|
| 212 |
+
"image": f"file://{query_image}",
|
| 213 |
+
"min_pixels": min_pixels, "max_pixels": max_pixels,
|
| 214 |
+
})
|
| 215 |
+
|
| 216 |
+
if question:
|
| 217 |
+
user_content.append({"type": "text", "text": f"Question: {question}"})
|
| 218 |
+
|
| 219 |
+
for shot in shots:
|
| 220 |
+
user_content.append({
|
| 221 |
+
"type": "image",
|
| 222 |
+
"image": f"file://{shot['image']}",
|
| 223 |
+
"min_pixels": min_pixels, "max_pixels": max_pixels,
|
| 224 |
+
})
|
| 225 |
+
if shot.get("caption"):
|
| 226 |
+
user_content.append({"type": "text", "text": f"Caption: {shot['caption']}"})
|
| 227 |
+
|
| 228 |
+
user_content.append({"type": "text", "text": "Action:"})
|
| 229 |
+
return [{"role": "user", "content": user_content}]
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
@torch.no_grad()
|
| 233 |
+
def generate_action(model, processor, messages: List[Dict], max_new_tokens: int = 256) -> str:
|
| 234 |
+
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 235 |
+
|
| 236 |
+
image_inputs = None
|
| 237 |
+
try:
|
| 238 |
+
image_inputs, _ = process_vision_info(messages)
|
| 239 |
+
except Exception:
|
| 240 |
+
pass
|
| 241 |
+
|
| 242 |
+
inputs = processor(
|
| 243 |
+
text=[text],
|
| 244 |
+
images=image_inputs if image_inputs else None,
|
| 245 |
+
return_tensors="pt",
|
| 246 |
+
padding=False,
|
| 247 |
+
truncation=False,
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
device = next(model.parameters()).device
|
| 251 |
+
inputs = {k: v.to(device) if hasattr(v, 'to') else v for k, v in inputs.items()}
|
| 252 |
+
|
| 253 |
+
outputs = model.generate(
|
| 254 |
+
**inputs,
|
| 255 |
+
max_new_tokens=max_new_tokens,
|
| 256 |
+
do_sample=False,
|
| 257 |
+
temperature=None,
|
| 258 |
+
top_p=None,
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
input_len = inputs["input_ids"].shape[1]
|
| 262 |
+
generated = outputs[0][input_len:]
|
| 263 |
+
return processor.tokenizer.decode(generated, skip_special_tokens=False)
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def parse_action(text: str) -> Tuple[str, str]:
|
| 267 |
+
text = text.strip()
|
| 268 |
+
|
| 269 |
+
if text.startswith("<RET>"):
|
| 270 |
+
desc = text[len("<RET>"):].strip()
|
| 271 |
+
if desc.startswith("Description:"):
|
| 272 |
+
desc = desc[len("Description:"):].strip()
|
| 273 |
+
for tok in ["<|im_end|>", "</s>", "<|endoftext|>"]:
|
| 274 |
+
desc = desc.replace(tok, "").strip()
|
| 275 |
+
return "ret", desc
|
| 276 |
+
|
| 277 |
+
if text.startswith("<ANS>"):
|
| 278 |
+
ans = text[len("<ANS>"):]
|
| 279 |
+
end_idx = ans.find("</ANS>")
|
| 280 |
+
if end_idx != -1:
|
| 281 |
+
ans = ans[:end_idx]
|
| 282 |
+
else:
|
| 283 |
+
for tok in ["<|im_end|>", "</s>", "<|endoftext|>"]:
|
| 284 |
+
ans = ans.replace(tok, "").strip()
|
| 285 |
+
return "ans", ans.strip()
|
| 286 |
+
|
| 287 |
+
return "unknown", text
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def run_icl_loop(
|
| 291 |
+
model, processor,
|
| 292 |
+
record: Dict,
|
| 293 |
+
instruction: str,
|
| 294 |
+
top5: Dict[str, List[str]],
|
| 295 |
+
caption_cache: Dict[str, str],
|
| 296 |
+
max_rounds: int = 4,
|
| 297 |
+
) -> Dict:
|
| 298 |
+
query_image = record["image"]
|
| 299 |
+
question = record.get("question", "")
|
| 300 |
+
gt_answer = record.get("answer", "")
|
| 301 |
+
|
| 302 |
+
shots = []
|
| 303 |
+
used_images = {query_image}
|
| 304 |
+
rounds = []
|
| 305 |
+
candidates = top5.get(query_image, [])
|
| 306 |
+
|
| 307 |
+
for round_idx in range(max_rounds):
|
| 308 |
+
messages = build_messages(instruction, query_image, question, shots)
|
| 309 |
+
raw_output = generate_action(model, processor, messages)
|
| 310 |
+
action, content = parse_action(raw_output)
|
| 311 |
+
|
| 312 |
+
rounds.append({
|
| 313 |
+
"round": round_idx,
|
| 314 |
+
"action": action,
|
| 315 |
+
"content": content,
|
| 316 |
+
"raw": raw_output[:200],
|
| 317 |
+
})
|
| 318 |
+
|
| 319 |
+
if action == "ans":
|
| 320 |
+
return {
|
| 321 |
+
"image": query_image, "question": question,
|
| 322 |
+
"gt_answer": gt_answer, "rounds": rounds,
|
| 323 |
+
"final_answer": content, "num_rounds": round_idx + 1,
|
| 324 |
+
"terminated_by": "ans",
|
| 325 |
+
}
|
| 326 |
+
|
| 327 |
+
if action == "ret":
|
| 328 |
+
next_image = None
|
| 329 |
+
for c in candidates:
|
| 330 |
+
if c not in used_images:
|
| 331 |
+
next_image = c
|
| 332 |
+
break
|
| 333 |
+
|
| 334 |
+
if next_image is None:
|
| 335 |
+
return {
|
| 336 |
+
"image": query_image, "question": question,
|
| 337 |
+
"gt_answer": gt_answer, "rounds": rounds,
|
| 338 |
+
"final_answer": None, "num_rounds": round_idx + 1,
|
| 339 |
+
"terminated_by": "no_more_shots",
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
cap = caption_cache.get(next_image, content)
|
| 343 |
+
shots.append({"image": next_image, "caption": cap})
|
| 344 |
+
used_images.add(next_image)
|
| 345 |
+
else:
|
| 346 |
+
return {
|
| 347 |
+
"image": query_image, "question": question,
|
| 348 |
+
"gt_answer": gt_answer, "rounds": rounds,
|
| 349 |
+
"final_answer": content, "num_rounds": round_idx + 1,
|
| 350 |
+
"terminated_by": "unknown_action",
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
return {
|
| 354 |
+
"image": query_image, "question": question,
|
| 355 |
+
"gt_answer": gt_answer, "rounds": rounds,
|
| 356 |
+
"final_answer": None, "num_rounds": max_rounds,
|
| 357 |
+
"terminated_by": "max_rounds",
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
# ---------------------------------------------------------------------------
|
| 362 |
+
# 统计
|
| 363 |
+
# ---------------------------------------------------------------------------
|
| 364 |
+
|
| 365 |
+
def print_stats(results: List[Dict], cat: str = "", ds: str = ""):
|
| 366 |
+
prefix = f"[{cat}/{ds}]" if ds else f"[{cat}]" if cat else "[ALL]"
|
| 367 |
+
n = len(results)
|
| 368 |
+
if n == 0:
|
| 369 |
+
print(f"{prefix} 无结果")
|
| 370 |
+
return
|
| 371 |
+
|
| 372 |
+
term_counts = defaultdict(int)
|
| 373 |
+
for r in results:
|
| 374 |
+
term_counts[r["terminated_by"]] += 1
|
| 375 |
+
|
| 376 |
+
round_actions = defaultdict(lambda: defaultdict(int))
|
| 377 |
+
for r in results:
|
| 378 |
+
for rd in r["rounds"]:
|
| 379 |
+
round_actions[rd["round"]][rd["action"]] += 1
|
| 380 |
+
|
| 381 |
+
avg_rounds = sum(r["num_rounds"] for r in results) / n
|
| 382 |
+
|
| 383 |
+
print(f"\n{'='*60}")
|
| 384 |
+
print(f"{prefix} 共 {n} 条样本")
|
| 385 |
+
print(f" 平均轮次: {avg_rounds:.2f}")
|
| 386 |
+
print(f" 终止原因:")
|
| 387 |
+
for k, v in sorted(term_counts.items()):
|
| 388 |
+
print(f" {k}: {v} ({v/n*100:.1f}%)")
|
| 389 |
+
|
| 390 |
+
print(f" 每轮 RET/ANS 分布:")
|
| 391 |
+
for rd_idx in sorted(round_actions.keys()):
|
| 392 |
+
actions = round_actions[rd_idx]
|
| 393 |
+
total = sum(actions.values())
|
| 394 |
+
parts = [f"{a}={c}({c/total*100:.0f}%)" for a, c in sorted(actions.items())]
|
| 395 |
+
print(f" Round {rd_idx}: {' | '.join(parts)} (共{total}条)")
|
| 396 |
+
|
| 397 |
+
answered = [r for r in results if r["final_answer"] is not None]
|
| 398 |
+
print(f" 产出答案: {len(answered)}/{n} ({len(answered)/n*100:.1f}%)")
|
| 399 |
+
print(f"{'='*60}")
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
# ---------------------------------------------------------------------------
|
| 403 |
+
# Main
|
| 404 |
+
# ---------------------------------------------------------------------------
|
| 405 |
+
|
| 406 |
+
def main():
|
| 407 |
+
parser = argparse.ArgumentParser(description="ICL 多轮推理评测(支持多卡)")
|
| 408 |
+
parser.add_argument("--model-path", required=True, help="合并后的 HF 模型路径")
|
| 409 |
+
parser.add_argument("--category", type=str, default="")
|
| 410 |
+
parser.add_argument("--dataset", type=str, default="")
|
| 411 |
+
parser.add_argument("--split", type=str, default="val")
|
| 412 |
+
parser.add_argument("--all-categories", action="store_true")
|
| 413 |
+
parser.add_argument("--num-samples", type=int, default=100, help="每个 dataset 采样数")
|
| 414 |
+
parser.add_argument("--max-rounds", type=int, default=4)
|
| 415 |
+
parser.add_argument("--device", type=str, default="cuda:0", help="单卡时用的设备")
|
| 416 |
+
parser.add_argument("--output", type=str, default="")
|
| 417 |
+
parser.add_argument("--seed", type=int, default=42)
|
| 418 |
+
args = parser.parse_args()
|
| 419 |
+
|
| 420 |
+
random.seed(args.seed)
|
| 421 |
+
|
| 422 |
+
# 分布式初始化
|
| 423 |
+
rank, world_size, dist_device = setup_distributed()
|
| 424 |
+
device = dist_device or args.device
|
| 425 |
+
is_main = (rank == 0)
|
| 426 |
+
|
| 427 |
+
if is_main:
|
| 428 |
+
print(f"World size: {world_size}")
|
| 429 |
+
|
| 430 |
+
# 加载模型(每张卡一份)
|
| 431 |
+
model, processor = load_model(args.model_path, device)
|
| 432 |
+
|
| 433 |
+
# 确定 dataset 列表
|
| 434 |
+
if args.all_categories:
|
| 435 |
+
categories = ["vqa", "captioning", "classification", "reasoning"]
|
| 436 |
+
elif args.category:
|
| 437 |
+
categories = [args.category]
|
| 438 |
+
else:
|
| 439 |
+
categories = ["vqa"]
|
| 440 |
+
|
| 441 |
+
if args.dataset:
|
| 442 |
+
ds_list = [(args.category or "vqa", args.dataset)]
|
| 443 |
+
else:
|
| 444 |
+
ds_list = discover_datasets(categories)
|
| 445 |
+
|
| 446 |
+
# ---- 按 rank 分配 dataset ----
|
| 447 |
+
my_ds_list = ds_list[rank::world_size]
|
| 448 |
+
if is_main:
|
| 449 |
+
print(f"共 {len(ds_list)} 个 dataset,每卡约 {len(my_ds_list)} 个")
|
| 450 |
+
|
| 451 |
+
local_results = []
|
| 452 |
+
|
| 453 |
+
for cat, ds in my_ds_list:
|
| 454 |
+
print(f"[rank {rank}] Evaluating {cat}/{ds} ({args.split})")
|
| 455 |
+
|
| 456 |
+
records = load_records(cat, ds, args.split, limit=args.num_samples * 5)
|
| 457 |
+
if not records:
|
| 458 |
+
print(f" [rank {rank}] 跳过 {cat}/{ds}:无记录")
|
| 459 |
+
continue
|
| 460 |
+
|
| 461 |
+
top5 = load_top5(cat, ds)
|
| 462 |
+
if not top5:
|
| 463 |
+
print(f" [rank {rank}] 跳过 {cat}/{ds}:无 top5")
|
| 464 |
+
continue
|
| 465 |
+
|
| 466 |
+
caption_cache = load_caption_cache(cat, ds)
|
| 467 |
+
instructions = load_instructions(cat, ds)
|
| 468 |
+
|
| 469 |
+
records = [r for r in records if r["image"] in top5]
|
| 470 |
+
if not records:
|
| 471 |
+
print(f" [rank {rank}] 跳过 {cat}/{ds}:无 top5 覆盖")
|
| 472 |
+
continue
|
| 473 |
+
|
| 474 |
+
if len(records) > args.num_samples:
|
| 475 |
+
records = random.sample(records, args.num_samples)
|
| 476 |
+
print(f" [rank {rank}] {cat}/{ds}: {len(records)} 条")
|
| 477 |
+
|
| 478 |
+
for i, rec in enumerate(records):
|
| 479 |
+
inst = random.choice(instructions)
|
| 480 |
+
result = run_icl_loop(
|
| 481 |
+
model, processor, rec, inst, top5, caption_cache,
|
| 482 |
+
max_rounds=args.max_rounds,
|
| 483 |
+
)
|
| 484 |
+
result["category"] = cat
|
| 485 |
+
result["dataset"] = ds
|
| 486 |
+
local_results.append(result)
|
| 487 |
+
|
| 488 |
+
if (i + 1) % 10 == 0 or (i + 1) == len(records):
|
| 489 |
+
action_seq = " → ".join(rd["action"].upper() for rd in result["rounds"])
|
| 490 |
+
print(f" [rank {rank}] [{i+1}/{len(records)}] {action_seq} | "
|
| 491 |
+
f"{result['terminated_by']}")
|
| 492 |
+
|
| 493 |
+
# ---- 汇总结果 ----
|
| 494 |
+
all_results = gather_results(local_results, rank, world_size)
|
| 495 |
+
|
| 496 |
+
if is_main:
|
| 497 |
+
# 按 category 统计
|
| 498 |
+
cat_results = defaultdict(list)
|
| 499 |
+
for r in all_results:
|
| 500 |
+
cat_results[r["category"]].append(r)
|
| 501 |
+
|
| 502 |
+
for cat in categories:
|
| 503 |
+
if cat_results[cat]:
|
| 504 |
+
# 按 dataset 子统计
|
| 505 |
+
ds_groups = defaultdict(list)
|
| 506 |
+
for r in cat_results[cat]:
|
| 507 |
+
ds_groups[r["dataset"]].append(r)
|
| 508 |
+
for d in sorted(ds_groups):
|
| 509 |
+
print_stats(ds_groups[d], cat, d)
|
| 510 |
+
print_stats(cat_results[cat], cat)
|
| 511 |
+
|
| 512 |
+
print_stats(all_results)
|
| 513 |
+
|
| 514 |
+
output_path = args.output or f"/workspace/xiaobin/ICL/eval_results_{args.split}.json"
|
| 515 |
+
with open(output_path, "w", encoding="utf-8") as f:
|
| 516 |
+
json.dump(all_results, f, ensure_ascii=False, indent=2)
|
| 517 |
+
print(f"\n详细结果已保存到: {output_path}")
|
| 518 |
+
|
| 519 |
+
if world_size > 1:
|
| 520 |
+
dist.destroy_process_group()
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
if __name__ == "__main__":
|
| 524 |
+
main()
|
ICL/extract_images.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
从 /workspace/xiaobin/dataset/data 下所有 JSONL 文件中提取 base64 编码的图片,
|
| 4 |
+
保存到 /workspace/xiaobin/dataset/images/{category}/{dataset}/{split}/ 目录。
|
| 5 |
+
|
| 6 |
+
split 由文件名推断:含 train -> train, 含 test -> test, 含 val/validation -> val
|
| 7 |
+
|
| 8 |
+
图片字段名自动检测,支持:
|
| 9 |
+
image_str, image_base64_str, img_str, base64, image_base64, image_base_url,
|
| 10 |
+
video_str (list), images (list)
|
| 11 |
+
|
| 12 |
+
依赖:无需额外安装(tqdm 已有)
|
| 13 |
+
|
| 14 |
+
用法:
|
| 15 |
+
python3 extract_images.py # 处理全部
|
| 16 |
+
python3 extract_images.py vqa/shapes # 只处理某个数据集
|
| 17 |
+
python3 extract_images.py /path/to/some.jsonl # 只处理某个文件
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import os
|
| 21 |
+
import sys
|
| 22 |
+
import json
|
| 23 |
+
import base64
|
| 24 |
+
import glob
|
| 25 |
+
import re
|
| 26 |
+
from tqdm import tqdm
|
| 27 |
+
|
| 28 |
+
DATA_ROOT = "/workspace/xiaobin/dataset/data"
|
| 29 |
+
OUTPUT_ROOT = "/workspace/xiaobin/dataset/images"
|
| 30 |
+
|
| 31 |
+
# 所有可能的图片字段名(优先级顺序)
|
| 32 |
+
# 注意:有些字段在不同数据集中可能是 str 也可能是 list,统一处理
|
| 33 |
+
ALL_IMAGE_FIELDS = [
|
| 34 |
+
"image", # captioning/coco
|
| 35 |
+
"image_str", # 多个数据集(str 或 list)
|
| 36 |
+
"image_base64_str", # snli-ve, multi30k, vcr, visual_mrc
|
| 37 |
+
"img_str", # gqa, ocr-vqa, st-vqa, text-vqa, viquae, vqav2
|
| 38 |
+
"base64", # fm-iqa
|
| 39 |
+
"image_base64", # coco-cn, mmchat(str 或 list,如 chinesefoodnet-10)
|
| 40 |
+
"image_base_url", # textcap
|
| 41 |
+
"video_str", # msrvtt, ss, activitynet-qa, ivqa, msrvtt-qa, msvd-qa (list)
|
| 42 |
+
"images", # vist (list)
|
| 43 |
+
]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def detect_extension(data_bytes):
|
| 47 |
+
"""根据文件头判断图片格式"""
|
| 48 |
+
if data_bytes[:2] == b'\xff\xd8':
|
| 49 |
+
return ".jpg"
|
| 50 |
+
elif data_bytes[:8] == b'\x89PNG\r\n\x1a\n':
|
| 51 |
+
return ".png"
|
| 52 |
+
elif data_bytes[:4] == b'GIF8':
|
| 53 |
+
return ".gif"
|
| 54 |
+
elif data_bytes[:4] == b'RIFF' and data_bytes[8:12] == b'WEBP':
|
| 55 |
+
return ".webp"
|
| 56 |
+
else:
|
| 57 |
+
return ".jpg"
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def classify_split(filename):
|
| 61 |
+
"""从文件名推断 split 类型"""
|
| 62 |
+
fn = filename.lower()
|
| 63 |
+
if "train" in fn:
|
| 64 |
+
return "train"
|
| 65 |
+
elif "test" in fn:
|
| 66 |
+
return "test"
|
| 67 |
+
elif "val" in fn:
|
| 68 |
+
return "val"
|
| 69 |
+
else:
|
| 70 |
+
return "other"
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def extract_images_from_record(record):
|
| 74 |
+
"""从一条 JSONL 记录中提取图片 base64 字符串列表"""
|
| 75 |
+
for field in ALL_IMAGE_FIELDS:
|
| 76 |
+
if field not in record or not record[field]:
|
| 77 |
+
continue
|
| 78 |
+
val = record[field]
|
| 79 |
+
if isinstance(val, str) and len(val) > 100:
|
| 80 |
+
return [val]
|
| 81 |
+
elif isinstance(val, list):
|
| 82 |
+
return [item for item in val if isinstance(item, str) and len(item) > 100]
|
| 83 |
+
return []
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def count_lines(filepath):
|
| 87 |
+
"""快速统计文件行数(用于 tqdm total)"""
|
| 88 |
+
count = 0
|
| 89 |
+
with open(filepath, 'rb') as f:
|
| 90 |
+
# 用 buffer 读取,比逐行快很多
|
| 91 |
+
buf_size = 1024 * 1024 * 8 # 8MB
|
| 92 |
+
buf = f.raw.read(buf_size)
|
| 93 |
+
while buf:
|
| 94 |
+
count += buf.count(b'\n')
|
| 95 |
+
buf = f.raw.read(buf_size)
|
| 96 |
+
return count
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def process_jsonl_file(jsonl_path, file_idx, total_files):
|
| 100 |
+
"""处理单个 JSONL 文件,提取图片并保存"""
|
| 101 |
+
rel_path = os.path.relpath(jsonl_path, DATA_ROOT)
|
| 102 |
+
parts = rel_path.split(os.sep)
|
| 103 |
+
if len(parts) < 3:
|
| 104 |
+
print(f" [SKIP] 路径层级不够: {rel_path}")
|
| 105 |
+
return 0
|
| 106 |
+
|
| 107 |
+
category = parts[0]
|
| 108 |
+
dataset = parts[1]
|
| 109 |
+
filename = parts[2]
|
| 110 |
+
split = classify_split(filename)
|
| 111 |
+
|
| 112 |
+
out_dir = os.path.join(OUTPUT_ROOT, category, dataset, split)
|
| 113 |
+
os.makedirs(out_dir, exist_ok=True)
|
| 114 |
+
|
| 115 |
+
# 断点续传:统计已有图片数
|
| 116 |
+
existing_count = len([f for f in os.listdir(out_dir) if os.path.isfile(os.path.join(out_dir, f))])
|
| 117 |
+
|
| 118 |
+
# 快速统计总行数
|
| 119 |
+
file_size_mb = os.path.getsize(jsonl_path) / (1024 * 1024)
|
| 120 |
+
total_lines = count_lines(jsonl_path)
|
| 121 |
+
|
| 122 |
+
count = 0
|
| 123 |
+
skipped = 0
|
| 124 |
+
errors = 0
|
| 125 |
+
|
| 126 |
+
desc = f"[{file_idx}/{total_files}] {category}/{dataset}/{split} ({file_size_mb:.0f}MB)"
|
| 127 |
+
|
| 128 |
+
try:
|
| 129 |
+
with open(jsonl_path, 'r', encoding='utf-8') as f:
|
| 130 |
+
pbar = tqdm(f, total=total_lines, desc=desc, unit="行",
|
| 131 |
+
dynamic_ncols=True, miniters=50)
|
| 132 |
+
for line in pbar:
|
| 133 |
+
line = line.strip()
|
| 134 |
+
if not line:
|
| 135 |
+
continue
|
| 136 |
+
|
| 137 |
+
try:
|
| 138 |
+
record = json.loads(line)
|
| 139 |
+
except json.JSONDecodeError:
|
| 140 |
+
errors += 1
|
| 141 |
+
continue
|
| 142 |
+
|
| 143 |
+
b64_list = extract_images_from_record(record)
|
| 144 |
+
if not b64_list:
|
| 145 |
+
skipped += 1
|
| 146 |
+
continue
|
| 147 |
+
|
| 148 |
+
for img_idx, b64_str in enumerate(b64_list):
|
| 149 |
+
global_idx = existing_count + count
|
| 150 |
+
try:
|
| 151 |
+
img_bytes = base64.b64decode(b64_str)
|
| 152 |
+
ext = detect_extension(img_bytes)
|
| 153 |
+
if len(b64_list) > 1:
|
| 154 |
+
img_name = f"{global_idx:08d}_f{img_idx:03d}{ext}"
|
| 155 |
+
else:
|
| 156 |
+
img_name = f"{global_idx:08d}{ext}"
|
| 157 |
+
with open(os.path.join(out_dir, img_name), 'wb') as img_f:
|
| 158 |
+
img_f.write(img_bytes)
|
| 159 |
+
count += 1
|
| 160 |
+
except Exception as e:
|
| 161 |
+
errors += 1
|
| 162 |
+
if errors <= 3:
|
| 163 |
+
tqdm.write(f" [ERROR] {e}")
|
| 164 |
+
|
| 165 |
+
# 更新后缀信息
|
| 166 |
+
pbar.set_postfix(imgs=count, skip=skipped, err=errors, refresh=False)
|
| 167 |
+
pbar.close()
|
| 168 |
+
|
| 169 |
+
except Exception as e:
|
| 170 |
+
print(f" [FATAL] {e}")
|
| 171 |
+
|
| 172 |
+
print(f" -> 完成: {count} 张图片, 跳过 {skipped} 行(无图), 错误 {errors}")
|
| 173 |
+
return count
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def find_all_jsonl_files():
|
| 177 |
+
"""查找所有需要处理的 JSONL 文件"""
|
| 178 |
+
all_files = []
|
| 179 |
+
for jsonl_path in sorted(glob.glob(os.path.join(DATA_ROOT, "*/*/*.jsonl"))):
|
| 180 |
+
filename = os.path.basename(jsonl_path)
|
| 181 |
+
if re.search(r'_\d{4}-\d{2}-\d{2}\.jsonl$', filename):
|
| 182 |
+
continue
|
| 183 |
+
if '_v2.jsonl' in filename or '_new.jsonl' in filename:
|
| 184 |
+
continue
|
| 185 |
+
if filename.startswith('para_'):
|
| 186 |
+
continue
|
| 187 |
+
all_files.append(jsonl_path)
|
| 188 |
+
return all_files
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def main():
|
| 192 |
+
print("=" * 60)
|
| 193 |
+
print("JSONL 图片提取工具")
|
| 194 |
+
print(f"数据源: {DATA_ROOT}")
|
| 195 |
+
print(f"输出到: {OUTPUT_ROOT}")
|
| 196 |
+
print("=" * 60)
|
| 197 |
+
|
| 198 |
+
if len(sys.argv) > 1:
|
| 199 |
+
target = sys.argv[1]
|
| 200 |
+
if os.path.isfile(target):
|
| 201 |
+
files = [target]
|
| 202 |
+
else:
|
| 203 |
+
files = sorted(glob.glob(os.path.join(DATA_ROOT, target, "*.jsonl")))
|
| 204 |
+
files = [f for f in files
|
| 205 |
+
if not re.search(r'_\d{4}-\d{2}-\d{2}\.jsonl$', os.path.basename(f))
|
| 206 |
+
and '_v2.jsonl' not in os.path.basename(f)
|
| 207 |
+
and '_new.jsonl' not in os.path.basename(f)
|
| 208 |
+
and not os.path.basename(f).startswith('para_')]
|
| 209 |
+
else:
|
| 210 |
+
files = find_all_jsonl_files()
|
| 211 |
+
|
| 212 |
+
print(f"\n共 {len(files)} 个 JSONL 文件:")
|
| 213 |
+
total_size = 0
|
| 214 |
+
for f in files:
|
| 215 |
+
size_mb = os.path.getsize(f) / (1024 * 1024)
|
| 216 |
+
total_size += size_mb
|
| 217 |
+
print(f" {os.path.relpath(f, DATA_ROOT):50s} {size_mb:>10.1f} MB")
|
| 218 |
+
print(f" {'合计':50s} {total_size/1024:>10.1f} GB")
|
| 219 |
+
|
| 220 |
+
total_images = 0
|
| 221 |
+
for i, jsonl_path in enumerate(files, 1):
|
| 222 |
+
n = process_jsonl_file(jsonl_path, i, len(files))
|
| 223 |
+
total_images += n
|
| 224 |
+
|
| 225 |
+
print(f"\n{'=' * 60}")
|
| 226 |
+
print(f"全部完成!共提取 {total_images} 张图片")
|
| 227 |
+
print(f"保存在: {OUTPUT_ROOT}")
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
if __name__ == "__main__":
|
| 231 |
+
main()
|
ICL/merge_captions.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
把 detail/{cat}/{ds}/{split}/captions.json 合并成 build_sft.py 需要的格式:
|
| 4 |
+
caption_cache/{cat}_{ds}.json = {"items": {img_path: caption, ...}}
|
| 5 |
+
|
| 6 |
+
这样 build_sft.py --caption-cache-dir caption_cache 就能直接复用。
|
| 7 |
+
|
| 8 |
+
用法:
|
| 9 |
+
python3 merge_captions.py
|
| 10 |
+
python3 merge_captions.py --force # 强制重建
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import os
|
| 14 |
+
import sys
|
| 15 |
+
import json
|
| 16 |
+
import glob
|
| 17 |
+
|
| 18 |
+
DETAIL_ROOT = "/workspace/xiaobin/dataset/detail"
|
| 19 |
+
CAPTION_CACHE_DIR = "/workspace/xiaobin/dataset/caption_cache"
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def main():
|
| 23 |
+
force = "--force" in sys.argv
|
| 24 |
+
os.makedirs(CAPTION_CACHE_DIR, exist_ok=True)
|
| 25 |
+
|
| 26 |
+
# 找所有 dataset 目录 (cat/ds)
|
| 27 |
+
datasets = set()
|
| 28 |
+
for captions_file in glob.glob(os.path.join(DETAIL_ROOT, "*/*/*/captions.json")):
|
| 29 |
+
rel = os.path.relpath(captions_file, DETAIL_ROOT)
|
| 30 |
+
parts = rel.split(os.sep) # cat/ds/split/captions.json
|
| 31 |
+
datasets.add((parts[0], parts[1]))
|
| 32 |
+
|
| 33 |
+
print(f"共 {len(datasets)} 个数据集")
|
| 34 |
+
|
| 35 |
+
for cat, ds in sorted(datasets):
|
| 36 |
+
out_name = f"{cat}_{ds}.json"
|
| 37 |
+
out_path = os.path.join(CAPTION_CACHE_DIR, out_name)
|
| 38 |
+
|
| 39 |
+
if not force and os.path.exists(out_path) and os.path.getsize(out_path) > 0:
|
| 40 |
+
print(f" [SKIP] {out_name}")
|
| 41 |
+
continue
|
| 42 |
+
|
| 43 |
+
merged = {}
|
| 44 |
+
for split in ("train", "val", "test"):
|
| 45 |
+
src = os.path.join(DETAIL_ROOT, cat, ds, split, "captions.json")
|
| 46 |
+
if not os.path.exists(src):
|
| 47 |
+
continue
|
| 48 |
+
try:
|
| 49 |
+
with open(src, 'r', encoding='utf-8') as f:
|
| 50 |
+
data = json.load(f)
|
| 51 |
+
items = data.get("items", {})
|
| 52 |
+
if isinstance(items, dict):
|
| 53 |
+
merged.update(items)
|
| 54 |
+
except Exception as e:
|
| 55 |
+
print(f" [WARN] {src}: {e}")
|
| 56 |
+
|
| 57 |
+
if not merged:
|
| 58 |
+
print(f" [EMPTY] {cat}/{ds}")
|
| 59 |
+
continue
|
| 60 |
+
|
| 61 |
+
with open(out_path, 'w', encoding='utf-8') as f:
|
| 62 |
+
json.dump({"items": merged}, f, ensure_ascii=False)
|
| 63 |
+
|
| 64 |
+
print(f" [OK] {out_name}: {len(merged)} 条")
|
| 65 |
+
|
| 66 |
+
print(f"\n完成! 输出: {CAPTION_CACHE_DIR}")
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
if __name__ == "__main__":
|
| 70 |
+
main()
|
ICL/sft_model/epoch3_step1406_fp32/chat_template.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {%- if messages[0].content is string %}\n {{- messages[0].content }}\n {%- else %}\n {%- for content in messages[0].content %}\n {%- if 'text' in content %}\n {{- content.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].content is string %}\n {{- messages[0].content }}\n {%- else %}\n {%- for content in messages[0].content %}\n {%- if 'text' in content %}\n {{- content.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set image_count = namespace(value=0) %}\n{%- set video_count = namespace(value=0) %}\n{%- for message in messages %}\n {%- if message.role == \"user\" %}\n {{- '<|im_start|>' + message.role + '\\n' }}\n {%- if message.content is string %}\n {{- message.content }}\n {%- else %}\n {%- for content in message.content %}\n {%- if content.type == 'image' or 'image' in content or 'image_url' in content %}\n {%- set image_count.value = image_count.value + 1 %}\n {%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%}\n <|vision_start|><|image_pad|><|vision_end|>\n {%- elif content.type == 'video' or 'video' in content %}\n {%- set video_count.value = video_count.value + 1 %}\n {%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%}\n <|vision_start|><|video_pad|><|vision_end|>\n {%- elif 'text' in content %}\n {{- content.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role + '\\n' }}\n {%- if message.content is string %}\n {{- message.content }}\n {%- else %}\n {%- for content_item in message.content %}\n {%- if 'text' in content_item %}\n {{- content_item.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and message.content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {%- if message.content is string %}\n {{- message.content }}\n {%- else %}\n {%- for content in message.content %}\n {%- if content.type == 'image' or 'image' in content or 'image_url' in content %}\n {%- set image_count.value = image_count.value + 1 %}\n {%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%}\n <|vision_start|><|image_pad|><|vision_end|>\n {%- elif content.type == 'video' or 'video' in content %}\n {%- set video_count.value = video_count.value + 1 %}\n {%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%}\n <|vision_start|><|video_pad|><|vision_end|>\n {%- elif 'text' in content %}\n {{- content.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n"
|
| 3 |
+
}
|
ICL/sft_model/epoch3_step1406_fp32/config.json
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"Qwen3VLForConditionalGeneration"
|
| 4 |
+
],
|
| 5 |
+
"image_token_id": 151655,
|
| 6 |
+
"model_type": "qwen3_vl",
|
| 7 |
+
"text_config": {
|
| 8 |
+
"attention_bias": false,
|
| 9 |
+
"attention_dropout": 0.0,
|
| 10 |
+
"bos_token_id": 151643,
|
| 11 |
+
"dtype": "bfloat16",
|
| 12 |
+
"eos_token_id": 151645,
|
| 13 |
+
"head_dim": 128,
|
| 14 |
+
"hidden_act": "silu",
|
| 15 |
+
"hidden_size": 4096,
|
| 16 |
+
"initializer_range": 0.02,
|
| 17 |
+
"intermediate_size": 12288,
|
| 18 |
+
"max_position_embeddings": 262144,
|
| 19 |
+
"model_type": "qwen3_vl_text",
|
| 20 |
+
"num_attention_heads": 32,
|
| 21 |
+
"num_hidden_layers": 36,
|
| 22 |
+
"num_key_value_heads": 8,
|
| 23 |
+
"rms_norm_eps": 1e-06,
|
| 24 |
+
"rope_scaling": {
|
| 25 |
+
"mrope_interleaved": true,
|
| 26 |
+
"mrope_section": [
|
| 27 |
+
24,
|
| 28 |
+
20,
|
| 29 |
+
20
|
| 30 |
+
],
|
| 31 |
+
"rope_type": "default"
|
| 32 |
+
},
|
| 33 |
+
"rope_theta": 5000000,
|
| 34 |
+
"use_cache": true,
|
| 35 |
+
"vocab_size": 151936
|
| 36 |
+
},
|
| 37 |
+
"tie_word_embeddings": false,
|
| 38 |
+
"transformers_version": "4.57.0.dev0",
|
| 39 |
+
"video_token_id": 151656,
|
| 40 |
+
"vision_config": {
|
| 41 |
+
"deepstack_visual_indexes": [
|
| 42 |
+
8,
|
| 43 |
+
16,
|
| 44 |
+
24
|
| 45 |
+
],
|
| 46 |
+
"depth": 27,
|
| 47 |
+
"hidden_act": "gelu_pytorch_tanh",
|
| 48 |
+
"hidden_size": 1152,
|
| 49 |
+
"in_channels": 3,
|
| 50 |
+
"initializer_range": 0.02,
|
| 51 |
+
"intermediate_size": 4304,
|
| 52 |
+
"model_type": "qwen3_vl",
|
| 53 |
+
"num_heads": 16,
|
| 54 |
+
"num_position_embeddings": 2304,
|
| 55 |
+
"out_hidden_size": 4096,
|
| 56 |
+
"patch_size": 16,
|
| 57 |
+
"spatial_merge_size": 2,
|
| 58 |
+
"temporal_patch_size": 2
|
| 59 |
+
},
|
| 60 |
+
"vision_end_token_id": 151653,
|
| 61 |
+
"vision_start_token_id": 151652
|
| 62 |
+
}
|
ICL/sft_model/epoch3_step1406_fp32/generation_config.json
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token_id": 151643,
|
| 3 |
+
"pad_token_id": 151643,
|
| 4 |
+
"do_sample": true,
|
| 5 |
+
"eos_token_id": [
|
| 6 |
+
151645,
|
| 7 |
+
151643
|
| 8 |
+
],
|
| 9 |
+
"top_k": 20,
|
| 10 |
+
"top_p": 0.8,
|
| 11 |
+
"repetition_penalty": 1.0,
|
| 12 |
+
"temperature": 0.7,
|
| 13 |
+
"transformers_version": "4.56.0"
|
| 14 |
+
}
|
ICL/sft_model/epoch3_step1406_fp32/merges.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ICL/sft_model/epoch3_step1406_fp32/model.safetensors.index.json
ADDED
|
@@ -0,0 +1,757 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"metadata": {
|
| 3 |
+
"total_size": 35059909568
|
| 4 |
+
},
|
| 5 |
+
"weight_map": {
|
| 6 |
+
"lm_head.weight": "model-00008-of-00008.safetensors",
|
| 7 |
+
"model.language_model.embed_tokens.weight": "model-00001-of-00008.safetensors",
|
| 8 |
+
"model.language_model.layers.0.input_layernorm.weight": "model-00002-of-00008.safetensors",
|
| 9 |
+
"model.language_model.layers.0.mlp.down_proj.weight": "model-00002-of-00008.safetensors",
|
| 10 |
+
"model.language_model.layers.0.mlp.gate_proj.weight": "model-00002-of-00008.safetensors",
|
| 11 |
+
"model.language_model.layers.0.mlp.up_proj.weight": "model-00002-of-00008.safetensors",
|
| 12 |
+
"model.language_model.layers.0.post_attention_layernorm.weight": "model-00002-of-00008.safetensors",
|
| 13 |
+
"model.language_model.layers.0.self_attn.k_norm.weight": "model-00001-of-00008.safetensors",
|
| 14 |
+
"model.language_model.layers.0.self_attn.k_proj.weight": "model-00001-of-00008.safetensors",
|
| 15 |
+
"model.language_model.layers.0.self_attn.o_proj.weight": "model-00001-of-00008.safetensors",
|
| 16 |
+
"model.language_model.layers.0.self_attn.q_norm.weight": "model-00001-of-00008.safetensors",
|
| 17 |
+
"model.language_model.layers.0.self_attn.q_proj.weight": "model-00001-of-00008.safetensors",
|
| 18 |
+
"model.language_model.layers.0.self_attn.v_proj.weight": "model-00001-of-00008.safetensors",
|
| 19 |
+
"model.language_model.layers.1.input_layernorm.weight": "model-00002-of-00008.safetensors",
|
| 20 |
+
"model.language_model.layers.1.mlp.down_proj.weight": "model-00002-of-00008.safetensors",
|
| 21 |
+
"model.language_model.layers.1.mlp.gate_proj.weight": "model-00002-of-00008.safetensors",
|
| 22 |
+
"model.language_model.layers.1.mlp.up_proj.weight": "model-00002-of-00008.safetensors",
|
| 23 |
+
"model.language_model.layers.1.post_attention_layernorm.weight": "model-00002-of-00008.safetensors",
|
| 24 |
+
"model.language_model.layers.1.self_attn.k_norm.weight": "model-00002-of-00008.safetensors",
|
| 25 |
+
"model.language_model.layers.1.self_attn.k_proj.weight": "model-00002-of-00008.safetensors",
|
| 26 |
+
"model.language_model.layers.1.self_attn.o_proj.weight": "model-00002-of-00008.safetensors",
|
| 27 |
+
"model.language_model.layers.1.self_attn.q_norm.weight": "model-00002-of-00008.safetensors",
|
| 28 |
+
"model.language_model.layers.1.self_attn.q_proj.weight": "model-00002-of-00008.safetensors",
|
| 29 |
+
"model.language_model.layers.1.self_attn.v_proj.weight": "model-00002-of-00008.safetensors",
|
| 30 |
+
"model.language_model.layers.10.input_layernorm.weight": "model-00003-of-00008.safetensors",
|
| 31 |
+
"model.language_model.layers.10.mlp.down_proj.weight": "model-00003-of-00008.safetensors",
|
| 32 |
+
"model.language_model.layers.10.mlp.gate_proj.weight": "model-00003-of-00008.safetensors",
|
| 33 |
+
"model.language_model.layers.10.mlp.up_proj.weight": "model-00003-of-00008.safetensors",
|
| 34 |
+
"model.language_model.layers.10.post_attention_layernorm.weight": "model-00003-of-00008.safetensors",
|
| 35 |
+
"model.language_model.layers.10.self_attn.k_norm.weight": "model-00003-of-00008.safetensors",
|
| 36 |
+
"model.language_model.layers.10.self_attn.k_proj.weight": "model-00003-of-00008.safetensors",
|
| 37 |
+
"model.language_model.layers.10.self_attn.o_proj.weight": "model-00003-of-00008.safetensors",
|
| 38 |
+
"model.language_model.layers.10.self_attn.q_norm.weight": "model-00003-of-00008.safetensors",
|
| 39 |
+
"model.language_model.layers.10.self_attn.q_proj.weight": "model-00003-of-00008.safetensors",
|
| 40 |
+
"model.language_model.layers.10.self_attn.v_proj.weight": "model-00003-of-00008.safetensors",
|
| 41 |
+
"model.language_model.layers.11.input_layernorm.weight": "model-00003-of-00008.safetensors",
|
| 42 |
+
"model.language_model.layers.11.mlp.down_proj.weight": "model-00003-of-00008.safetensors",
|
| 43 |
+
"model.language_model.layers.11.mlp.gate_proj.weight": "model-00003-of-00008.safetensors",
|
| 44 |
+
"model.language_model.layers.11.mlp.up_proj.weight": "model-00003-of-00008.safetensors",
|
| 45 |
+
"model.language_model.layers.11.post_attention_layernorm.weight": "model-00003-of-00008.safetensors",
|
| 46 |
+
"model.language_model.layers.11.self_attn.k_norm.weight": "model-00003-of-00008.safetensors",
|
| 47 |
+
"model.language_model.layers.11.self_attn.k_proj.weight": "model-00003-of-00008.safetensors",
|
| 48 |
+
"model.language_model.layers.11.self_attn.o_proj.weight": "model-00003-of-00008.safetensors",
|
| 49 |
+
"model.language_model.layers.11.self_attn.q_norm.weight": "model-00003-of-00008.safetensors",
|
| 50 |
+
"model.language_model.layers.11.self_attn.q_proj.weight": "model-00003-of-00008.safetensors",
|
| 51 |
+
"model.language_model.layers.11.self_attn.v_proj.weight": "model-00003-of-00008.safetensors",
|
| 52 |
+
"model.language_model.layers.12.input_layernorm.weight": "model-00004-of-00008.safetensors",
|
| 53 |
+
"model.language_model.layers.12.mlp.down_proj.weight": "model-00004-of-00008.safetensors",
|
| 54 |
+
"model.language_model.layers.12.mlp.gate_proj.weight": "model-00003-of-00008.safetensors",
|
| 55 |
+
"model.language_model.layers.12.mlp.up_proj.weight": "model-00003-of-00008.safetensors",
|
| 56 |
+
"model.language_model.layers.12.post_attention_layernorm.weight": "model-00004-of-00008.safetensors",
|
| 57 |
+
"model.language_model.layers.12.self_attn.k_norm.weight": "model-00003-of-00008.safetensors",
|
| 58 |
+
"model.language_model.layers.12.self_attn.k_proj.weight": "model-00003-of-00008.safetensors",
|
| 59 |
+
"model.language_model.layers.12.self_attn.o_proj.weight": "model-00003-of-00008.safetensors",
|
| 60 |
+
"model.language_model.layers.12.self_attn.q_norm.weight": "model-00003-of-00008.safetensors",
|
| 61 |
+
"model.language_model.layers.12.self_attn.q_proj.weight": "model-00003-of-00008.safetensors",
|
| 62 |
+
"model.language_model.layers.12.self_attn.v_proj.weight": "model-00003-of-00008.safetensors",
|
| 63 |
+
"model.language_model.layers.13.input_layernorm.weight": "model-00004-of-00008.safetensors",
|
| 64 |
+
"model.language_model.layers.13.mlp.down_proj.weight": "model-00004-of-00008.safetensors",
|
| 65 |
+
"model.language_model.layers.13.mlp.gate_proj.weight": "model-00004-of-00008.safetensors",
|
| 66 |
+
"model.language_model.layers.13.mlp.up_proj.weight": "model-00004-of-00008.safetensors",
|
| 67 |
+
"model.language_model.layers.13.post_attention_layernorm.weight": "model-00004-of-00008.safetensors",
|
| 68 |
+
"model.language_model.layers.13.self_attn.k_norm.weight": "model-00004-of-00008.safetensors",
|
| 69 |
+
"model.language_model.layers.13.self_attn.k_proj.weight": "model-00004-of-00008.safetensors",
|
| 70 |
+
"model.language_model.layers.13.self_attn.o_proj.weight": "model-00004-of-00008.safetensors",
|
| 71 |
+
"model.language_model.layers.13.self_attn.q_norm.weight": "model-00004-of-00008.safetensors",
|
| 72 |
+
"model.language_model.layers.13.self_attn.q_proj.weight": "model-00004-of-00008.safetensors",
|
| 73 |
+
"model.language_model.layers.13.self_attn.v_proj.weight": "model-00004-of-00008.safetensors",
|
| 74 |
+
"model.language_model.layers.14.input_layernorm.weight": "model-00004-of-00008.safetensors",
|
| 75 |
+
"model.language_model.layers.14.mlp.down_proj.weight": "model-00004-of-00008.safetensors",
|
| 76 |
+
"model.language_model.layers.14.mlp.gate_proj.weight": "model-00004-of-00008.safetensors",
|
| 77 |
+
"model.language_model.layers.14.mlp.up_proj.weight": "model-00004-of-00008.safetensors",
|
| 78 |
+
"model.language_model.layers.14.post_attention_layernorm.weight": "model-00004-of-00008.safetensors",
|
| 79 |
+
"model.language_model.layers.14.self_attn.k_norm.weight": "model-00004-of-00008.safetensors",
|
| 80 |
+
"model.language_model.layers.14.self_attn.k_proj.weight": "model-00004-of-00008.safetensors",
|
| 81 |
+
"model.language_model.layers.14.self_attn.o_proj.weight": "model-00004-of-00008.safetensors",
|
| 82 |
+
"model.language_model.layers.14.self_attn.q_norm.weight": "model-00004-of-00008.safetensors",
|
| 83 |
+
"model.language_model.layers.14.self_attn.q_proj.weight": "model-00004-of-00008.safetensors",
|
| 84 |
+
"model.language_model.layers.14.self_attn.v_proj.weight": "model-00004-of-00008.safetensors",
|
| 85 |
+
"model.language_model.layers.15.input_layernorm.weight": "model-00004-of-00008.safetensors",
|
| 86 |
+
"model.language_model.layers.15.mlp.down_proj.weight": "model-00004-of-00008.safetensors",
|
| 87 |
+
"model.language_model.layers.15.mlp.gate_proj.weight": "model-00004-of-00008.safetensors",
|
| 88 |
+
"model.language_model.layers.15.mlp.up_proj.weight": "model-00004-of-00008.safetensors",
|
| 89 |
+
"model.language_model.layers.15.post_attention_layernorm.weight": "model-00004-of-00008.safetensors",
|
| 90 |
+
"model.language_model.layers.15.self_attn.k_norm.weight": "model-00004-of-00008.safetensors",
|
| 91 |
+
"model.language_model.layers.15.self_attn.k_proj.weight": "model-00004-of-00008.safetensors",
|
| 92 |
+
"model.language_model.layers.15.self_attn.o_proj.weight": "model-00004-of-00008.safetensors",
|
| 93 |
+
"model.language_model.layers.15.self_attn.q_norm.weight": "model-00004-of-00008.safetensors",
|
| 94 |
+
"model.language_model.layers.15.self_attn.q_proj.weight": "model-00004-of-00008.safetensors",
|
| 95 |
+
"model.language_model.layers.15.self_attn.v_proj.weight": "model-00004-of-00008.safetensors",
|
| 96 |
+
"model.language_model.layers.16.input_layernorm.weight": "model-00004-of-00008.safetensors",
|
| 97 |
+
"model.language_model.layers.16.mlp.down_proj.weight": "model-00004-of-00008.safetensors",
|
| 98 |
+
"model.language_model.layers.16.mlp.gate_proj.weight": "model-00004-of-00008.safetensors",
|
| 99 |
+
"model.language_model.layers.16.mlp.up_proj.weight": "model-00004-of-00008.safetensors",
|
| 100 |
+
"model.language_model.layers.16.post_attention_layernorm.weight": "model-00004-of-00008.safetensors",
|
| 101 |
+
"model.language_model.layers.16.self_attn.k_norm.weight": "model-00004-of-00008.safetensors",
|
| 102 |
+
"model.language_model.layers.16.self_attn.k_proj.weight": "model-00004-of-00008.safetensors",
|
| 103 |
+
"model.language_model.layers.16.self_attn.o_proj.weight": "model-00004-of-00008.safetensors",
|
| 104 |
+
"model.language_model.layers.16.self_attn.q_norm.weight": "model-00004-of-00008.safetensors",
|
| 105 |
+
"model.language_model.layers.16.self_attn.q_proj.weight": "model-00004-of-00008.safetensors",
|
| 106 |
+
"model.language_model.layers.16.self_attn.v_proj.weight": "model-00004-of-00008.safetensors",
|
| 107 |
+
"model.language_model.layers.17.input_layernorm.weight": "model-00004-of-00008.safetensors",
|
| 108 |
+
"model.language_model.layers.17.mlp.down_proj.weight": "model-00004-of-00008.safetensors",
|
| 109 |
+
"model.language_model.layers.17.mlp.gate_proj.weight": "model-00004-of-00008.safetensors",
|
| 110 |
+
"model.language_model.layers.17.mlp.up_proj.weight": "model-00004-of-00008.safetensors",
|
| 111 |
+
"model.language_model.layers.17.post_attention_layernorm.weight": "model-00004-of-00008.safetensors",
|
| 112 |
+
"model.language_model.layers.17.self_attn.k_norm.weight": "model-00004-of-00008.safetensors",
|
| 113 |
+
"model.language_model.layers.17.self_attn.k_proj.weight": "model-00004-of-00008.safetensors",
|
| 114 |
+
"model.language_model.layers.17.self_attn.o_proj.weight": "model-00004-of-00008.safetensors",
|
| 115 |
+
"model.language_model.layers.17.self_attn.q_norm.weight": "model-00004-of-00008.safetensors",
|
| 116 |
+
"model.language_model.layers.17.self_attn.q_proj.weight": "model-00004-of-00008.safetensors",
|
| 117 |
+
"model.language_model.layers.17.self_attn.v_proj.weight": "model-00004-of-00008.safetensors",
|
| 118 |
+
"model.language_model.layers.18.input_layernorm.weight": "model-00004-of-00008.safetensors",
|
| 119 |
+
"model.language_model.layers.18.mlp.down_proj.weight": "model-00004-of-00008.safetensors",
|
| 120 |
+
"model.language_model.layers.18.mlp.gate_proj.weight": "model-00004-of-00008.safetensors",
|
| 121 |
+
"model.language_model.layers.18.mlp.up_proj.weight": "model-00004-of-00008.safetensors",
|
| 122 |
+
"model.language_model.layers.18.post_attention_layernorm.weight": "model-00004-of-00008.safetensors",
|
| 123 |
+
"model.language_model.layers.18.self_attn.k_norm.weight": "model-00004-of-00008.safetensors",
|
| 124 |
+
"model.language_model.layers.18.self_attn.k_proj.weight": "model-00004-of-00008.safetensors",
|
| 125 |
+
"model.language_model.layers.18.self_attn.o_proj.weight": "model-00004-of-00008.safetensors",
|
| 126 |
+
"model.language_model.layers.18.self_attn.q_norm.weight": "model-00004-of-00008.safetensors",
|
| 127 |
+
"model.language_model.layers.18.self_attn.q_proj.weight": "model-00004-of-00008.safetensors",
|
| 128 |
+
"model.language_model.layers.18.self_attn.v_proj.weight": "model-00004-of-00008.safetensors",
|
| 129 |
+
"model.language_model.layers.19.input_layernorm.weight": "model-00005-of-00008.safetensors",
|
| 130 |
+
"model.language_model.layers.19.mlp.down_proj.weight": "model-00005-of-00008.safetensors",
|
| 131 |
+
"model.language_model.layers.19.mlp.gate_proj.weight": "model-00005-of-00008.safetensors",
|
| 132 |
+
"model.language_model.layers.19.mlp.up_proj.weight": "model-00005-of-00008.safetensors",
|
| 133 |
+
"model.language_model.layers.19.post_attention_layernorm.weight": "model-00005-of-00008.safetensors",
|
| 134 |
+
"model.language_model.layers.19.self_attn.k_norm.weight": "model-00004-of-00008.safetensors",
|
| 135 |
+
"model.language_model.layers.19.self_attn.k_proj.weight": "model-00004-of-00008.safetensors",
|
| 136 |
+
"model.language_model.layers.19.self_attn.o_proj.weight": "model-00004-of-00008.safetensors",
|
| 137 |
+
"model.language_model.layers.19.self_attn.q_norm.weight": "model-00004-of-00008.safetensors",
|
| 138 |
+
"model.language_model.layers.19.self_attn.q_proj.weight": "model-00004-of-00008.safetensors",
|
| 139 |
+
"model.language_model.layers.19.self_attn.v_proj.weight": "model-00004-of-00008.safetensors",
|
| 140 |
+
"model.language_model.layers.2.input_layernorm.weight": "model-00002-of-00008.safetensors",
|
| 141 |
+
"model.language_model.layers.2.mlp.down_proj.weight": "model-00002-of-00008.safetensors",
|
| 142 |
+
"model.language_model.layers.2.mlp.gate_proj.weight": "model-00002-of-00008.safetensors",
|
| 143 |
+
"model.language_model.layers.2.mlp.up_proj.weight": "model-00002-of-00008.safetensors",
|
| 144 |
+
"model.language_model.layers.2.post_attention_layernorm.weight": "model-00002-of-00008.safetensors",
|
| 145 |
+
"model.language_model.layers.2.self_attn.k_norm.weight": "model-00002-of-00008.safetensors",
|
| 146 |
+
"model.language_model.layers.2.self_attn.k_proj.weight": "model-00002-of-00008.safetensors",
|
| 147 |
+
"model.language_model.layers.2.self_attn.o_proj.weight": "model-00002-of-00008.safetensors",
|
| 148 |
+
"model.language_model.layers.2.self_attn.q_norm.weight": "model-00002-of-00008.safetensors",
|
| 149 |
+
"model.language_model.layers.2.self_attn.q_proj.weight": "model-00002-of-00008.safetensors",
|
| 150 |
+
"model.language_model.layers.2.self_attn.v_proj.weight": "model-00002-of-00008.safetensors",
|
| 151 |
+
"model.language_model.layers.20.input_layernorm.weight": "model-00005-of-00008.safetensors",
|
| 152 |
+
"model.language_model.layers.20.mlp.down_proj.weight": "model-00005-of-00008.safetensors",
|
| 153 |
+
"model.language_model.layers.20.mlp.gate_proj.weight": "model-00005-of-00008.safetensors",
|
| 154 |
+
"model.language_model.layers.20.mlp.up_proj.weight": "model-00005-of-00008.safetensors",
|
| 155 |
+
"model.language_model.layers.20.post_attention_layernorm.weight": "model-00005-of-00008.safetensors",
|
| 156 |
+
"model.language_model.layers.20.self_attn.k_norm.weight": "model-00005-of-00008.safetensors",
|
| 157 |
+
"model.language_model.layers.20.self_attn.k_proj.weight": "model-00005-of-00008.safetensors",
|
| 158 |
+
"model.language_model.layers.20.self_attn.o_proj.weight": "model-00005-of-00008.safetensors",
|
| 159 |
+
"model.language_model.layers.20.self_attn.q_norm.weight": "model-00005-of-00008.safetensors",
|
| 160 |
+
"model.language_model.layers.20.self_attn.q_proj.weight": "model-00005-of-00008.safetensors",
|
| 161 |
+
"model.language_model.layers.20.self_attn.v_proj.weight": "model-00005-of-00008.safetensors",
|
| 162 |
+
"model.language_model.layers.21.input_layernorm.weight": "model-00005-of-00008.safetensors",
|
| 163 |
+
"model.language_model.layers.21.mlp.down_proj.weight": "model-00005-of-00008.safetensors",
|
| 164 |
+
"model.language_model.layers.21.mlp.gate_proj.weight": "model-00005-of-00008.safetensors",
|
| 165 |
+
"model.language_model.layers.21.mlp.up_proj.weight": "model-00005-of-00008.safetensors",
|
| 166 |
+
"model.language_model.layers.21.post_attention_layernorm.weight": "model-00005-of-00008.safetensors",
|
| 167 |
+
"model.language_model.layers.21.self_attn.k_norm.weight": "model-00005-of-00008.safetensors",
|
| 168 |
+
"model.language_model.layers.21.self_attn.k_proj.weight": "model-00005-of-00008.safetensors",
|
| 169 |
+
"model.language_model.layers.21.self_attn.o_proj.weight": "model-00005-of-00008.safetensors",
|
| 170 |
+
"model.language_model.layers.21.self_attn.q_norm.weight": "model-00005-of-00008.safetensors",
|
| 171 |
+
"model.language_model.layers.21.self_attn.q_proj.weight": "model-00005-of-00008.safetensors",
|
| 172 |
+
"model.language_model.layers.21.self_attn.v_proj.weight": "model-00005-of-00008.safetensors",
|
| 173 |
+
"model.language_model.layers.22.input_layernorm.weight": "model-00005-of-00008.safetensors",
|
| 174 |
+
"model.language_model.layers.22.mlp.down_proj.weight": "model-00005-of-00008.safetensors",
|
| 175 |
+
"model.language_model.layers.22.mlp.gate_proj.weight": "model-00005-of-00008.safetensors",
|
| 176 |
+
"model.language_model.layers.22.mlp.up_proj.weight": "model-00005-of-00008.safetensors",
|
| 177 |
+
"model.language_model.layers.22.post_attention_layernorm.weight": "model-00005-of-00008.safetensors",
|
| 178 |
+
"model.language_model.layers.22.self_attn.k_norm.weight": "model-00005-of-00008.safetensors",
|
| 179 |
+
"model.language_model.layers.22.self_attn.k_proj.weight": "model-00005-of-00008.safetensors",
|
| 180 |
+
"model.language_model.layers.22.self_attn.o_proj.weight": "model-00005-of-00008.safetensors",
|
| 181 |
+
"model.language_model.layers.22.self_attn.q_norm.weight": "model-00005-of-00008.safetensors",
|
| 182 |
+
"model.language_model.layers.22.self_attn.q_proj.weight": "model-00005-of-00008.safetensors",
|
| 183 |
+
"model.language_model.layers.22.self_attn.v_proj.weight": "model-00005-of-00008.safetensors",
|
| 184 |
+
"model.language_model.layers.23.input_layernorm.weight": "model-00005-of-00008.safetensors",
|
| 185 |
+
"model.language_model.layers.23.mlp.down_proj.weight": "model-00005-of-00008.safetensors",
|
| 186 |
+
"model.language_model.layers.23.mlp.gate_proj.weight": "model-00005-of-00008.safetensors",
|
| 187 |
+
"model.language_model.layers.23.mlp.up_proj.weight": "model-00005-of-00008.safetensors",
|
| 188 |
+
"model.language_model.layers.23.post_attention_layernorm.weight": "model-00005-of-00008.safetensors",
|
| 189 |
+
"model.language_model.layers.23.self_attn.k_norm.weight": "model-00005-of-00008.safetensors",
|
| 190 |
+
"model.language_model.layers.23.self_attn.k_proj.weight": "model-00005-of-00008.safetensors",
|
| 191 |
+
"model.language_model.layers.23.self_attn.o_proj.weight": "model-00005-of-00008.safetensors",
|
| 192 |
+
"model.language_model.layers.23.self_attn.q_norm.weight": "model-00005-of-00008.safetensors",
|
| 193 |
+
"model.language_model.layers.23.self_attn.q_proj.weight": "model-00005-of-00008.safetensors",
|
| 194 |
+
"model.language_model.layers.23.self_attn.v_proj.weight": "model-00005-of-00008.safetensors",
|
| 195 |
+
"model.language_model.layers.24.input_layernorm.weight": "model-00005-of-00008.safetensors",
|
| 196 |
+
"model.language_model.layers.24.mlp.down_proj.weight": "model-00005-of-00008.safetensors",
|
| 197 |
+
"model.language_model.layers.24.mlp.gate_proj.weight": "model-00005-of-00008.safetensors",
|
| 198 |
+
"model.language_model.layers.24.mlp.up_proj.weight": "model-00005-of-00008.safetensors",
|
| 199 |
+
"model.language_model.layers.24.post_attention_layernorm.weight": "model-00005-of-00008.safetensors",
|
| 200 |
+
"model.language_model.layers.24.self_attn.k_norm.weight": "model-00005-of-00008.safetensors",
|
| 201 |
+
"model.language_model.layers.24.self_attn.k_proj.weight": "model-00005-of-00008.safetensors",
|
| 202 |
+
"model.language_model.layers.24.self_attn.o_proj.weight": "model-00005-of-00008.safetensors",
|
| 203 |
+
"model.language_model.layers.24.self_attn.q_norm.weight": "model-00005-of-00008.safetensors",
|
| 204 |
+
"model.language_model.layers.24.self_attn.q_proj.weight": "model-00005-of-00008.safetensors",
|
| 205 |
+
"model.language_model.layers.24.self_attn.v_proj.weight": "model-00005-of-00008.safetensors",
|
| 206 |
+
"model.language_model.layers.25.input_layernorm.weight": "model-00006-of-00008.safetensors",
|
| 207 |
+
"model.language_model.layers.25.mlp.down_proj.weight": "model-00006-of-00008.safetensors",
|
| 208 |
+
"model.language_model.layers.25.mlp.gate_proj.weight": "model-00005-of-00008.safetensors",
|
| 209 |
+
"model.language_model.layers.25.mlp.up_proj.weight": "model-00006-of-00008.safetensors",
|
| 210 |
+
"model.language_model.layers.25.post_attention_layernorm.weight": "model-00006-of-00008.safetensors",
|
| 211 |
+
"model.language_model.layers.25.self_attn.k_norm.weight": "model-00005-of-00008.safetensors",
|
| 212 |
+
"model.language_model.layers.25.self_attn.k_proj.weight": "model-00005-of-00008.safetensors",
|
| 213 |
+
"model.language_model.layers.25.self_attn.o_proj.weight": "model-00005-of-00008.safetensors",
|
| 214 |
+
"model.language_model.layers.25.self_attn.q_norm.weight": "model-00005-of-00008.safetensors",
|
| 215 |
+
"model.language_model.layers.25.self_attn.q_proj.weight": "model-00005-of-00008.safetensors",
|
| 216 |
+
"model.language_model.layers.25.self_attn.v_proj.weight": "model-00005-of-00008.safetensors",
|
| 217 |
+
"model.language_model.layers.26.input_layernorm.weight": "model-00006-of-00008.safetensors",
|
| 218 |
+
"model.language_model.layers.26.mlp.down_proj.weight": "model-00006-of-00008.safetensors",
|
| 219 |
+
"model.language_model.layers.26.mlp.gate_proj.weight": "model-00006-of-00008.safetensors",
|
| 220 |
+
"model.language_model.layers.26.mlp.up_proj.weight": "model-00006-of-00008.safetensors",
|
| 221 |
+
"model.language_model.layers.26.post_attention_layernorm.weight": "model-00006-of-00008.safetensors",
|
| 222 |
+
"model.language_model.layers.26.self_attn.k_norm.weight": "model-00006-of-00008.safetensors",
|
| 223 |
+
"model.language_model.layers.26.self_attn.k_proj.weight": "model-00006-of-00008.safetensors",
|
| 224 |
+
"model.language_model.layers.26.self_attn.o_proj.weight": "model-00006-of-00008.safetensors",
|
| 225 |
+
"model.language_model.layers.26.self_attn.q_norm.weight": "model-00006-of-00008.safetensors",
|
| 226 |
+
"model.language_model.layers.26.self_attn.q_proj.weight": "model-00006-of-00008.safetensors",
|
| 227 |
+
"model.language_model.layers.26.self_attn.v_proj.weight": "model-00006-of-00008.safetensors",
|
| 228 |
+
"model.language_model.layers.27.input_layernorm.weight": "model-00006-of-00008.safetensors",
|
| 229 |
+
"model.language_model.layers.27.mlp.down_proj.weight": "model-00006-of-00008.safetensors",
|
| 230 |
+
"model.language_model.layers.27.mlp.gate_proj.weight": "model-00006-of-00008.safetensors",
|
| 231 |
+
"model.language_model.layers.27.mlp.up_proj.weight": "model-00006-of-00008.safetensors",
|
| 232 |
+
"model.language_model.layers.27.post_attention_layernorm.weight": "model-00006-of-00008.safetensors",
|
| 233 |
+
"model.language_model.layers.27.self_attn.k_norm.weight": "model-00006-of-00008.safetensors",
|
| 234 |
+
"model.language_model.layers.27.self_attn.k_proj.weight": "model-00006-of-00008.safetensors",
|
| 235 |
+
"model.language_model.layers.27.self_attn.o_proj.weight": "model-00006-of-00008.safetensors",
|
| 236 |
+
"model.language_model.layers.27.self_attn.q_norm.weight": "model-00006-of-00008.safetensors",
|
| 237 |
+
"model.language_model.layers.27.self_attn.q_proj.weight": "model-00006-of-00008.safetensors",
|
| 238 |
+
"model.language_model.layers.27.self_attn.v_proj.weight": "model-00006-of-00008.safetensors",
|
| 239 |
+
"model.language_model.layers.28.input_layernorm.weight": "model-00006-of-00008.safetensors",
|
| 240 |
+
"model.language_model.layers.28.mlp.down_proj.weight": "model-00006-of-00008.safetensors",
|
| 241 |
+
"model.language_model.layers.28.mlp.gate_proj.weight": "model-00006-of-00008.safetensors",
|
| 242 |
+
"model.language_model.layers.28.mlp.up_proj.weight": "model-00006-of-00008.safetensors",
|
| 243 |
+
"model.language_model.layers.28.post_attention_layernorm.weight": "model-00006-of-00008.safetensors",
|
| 244 |
+
"model.language_model.layers.28.self_attn.k_norm.weight": "model-00006-of-00008.safetensors",
|
| 245 |
+
"model.language_model.layers.28.self_attn.k_proj.weight": "model-00006-of-00008.safetensors",
|
| 246 |
+
"model.language_model.layers.28.self_attn.o_proj.weight": "model-00006-of-00008.safetensors",
|
| 247 |
+
"model.language_model.layers.28.self_attn.q_norm.weight": "model-00006-of-00008.safetensors",
|
| 248 |
+
"model.language_model.layers.28.self_attn.q_proj.weight": "model-00006-of-00008.safetensors",
|
| 249 |
+
"model.language_model.layers.28.self_attn.v_proj.weight": "model-00006-of-00008.safetensors",
|
| 250 |
+
"model.language_model.layers.29.input_layernorm.weight": "model-00006-of-00008.safetensors",
|
| 251 |
+
"model.language_model.layers.29.mlp.down_proj.weight": "model-00006-of-00008.safetensors",
|
| 252 |
+
"model.language_model.layers.29.mlp.gate_proj.weight": "model-00006-of-00008.safetensors",
|
| 253 |
+
"model.language_model.layers.29.mlp.up_proj.weight": "model-00006-of-00008.safetensors",
|
| 254 |
+
"model.language_model.layers.29.post_attention_layernorm.weight": "model-00006-of-00008.safetensors",
|
| 255 |
+
"model.language_model.layers.29.self_attn.k_norm.weight": "model-00006-of-00008.safetensors",
|
| 256 |
+
"model.language_model.layers.29.self_attn.k_proj.weight": "model-00006-of-00008.safetensors",
|
| 257 |
+
"model.language_model.layers.29.self_attn.o_proj.weight": "model-00006-of-00008.safetensors",
|
| 258 |
+
"model.language_model.layers.29.self_attn.q_norm.weight": "model-00006-of-00008.safetensors",
|
| 259 |
+
"model.language_model.layers.29.self_attn.q_proj.weight": "model-00006-of-00008.safetensors",
|
| 260 |
+
"model.language_model.layers.29.self_attn.v_proj.weight": "model-00006-of-00008.safetensors",
|
| 261 |
+
"model.language_model.layers.3.input_layernorm.weight": "model-00002-of-00008.safetensors",
|
| 262 |
+
"model.language_model.layers.3.mlp.down_proj.weight": "model-00002-of-00008.safetensors",
|
| 263 |
+
"model.language_model.layers.3.mlp.gate_proj.weight": "model-00002-of-00008.safetensors",
|
| 264 |
+
"model.language_model.layers.3.mlp.up_proj.weight": "model-00002-of-00008.safetensors",
|
| 265 |
+
"model.language_model.layers.3.post_attention_layernorm.weight": "model-00002-of-00008.safetensors",
|
| 266 |
+
"model.language_model.layers.3.self_attn.k_norm.weight": "model-00002-of-00008.safetensors",
|
| 267 |
+
"model.language_model.layers.3.self_attn.k_proj.weight": "model-00002-of-00008.safetensors",
|
| 268 |
+
"model.language_model.layers.3.self_attn.o_proj.weight": "model-00002-of-00008.safetensors",
|
| 269 |
+
"model.language_model.layers.3.self_attn.q_norm.weight": "model-00002-of-00008.safetensors",
|
| 270 |
+
"model.language_model.layers.3.self_attn.q_proj.weight": "model-00002-of-00008.safetensors",
|
| 271 |
+
"model.language_model.layers.3.self_attn.v_proj.weight": "model-00002-of-00008.safetensors",
|
| 272 |
+
"model.language_model.layers.30.input_layernorm.weight": "model-00006-of-00008.safetensors",
|
| 273 |
+
"model.language_model.layers.30.mlp.down_proj.weight": "model-00006-of-00008.safetensors",
|
| 274 |
+
"model.language_model.layers.30.mlp.gate_proj.weight": "model-00006-of-00008.safetensors",
|
| 275 |
+
"model.language_model.layers.30.mlp.up_proj.weight": "model-00006-of-00008.safetensors",
|
| 276 |
+
"model.language_model.layers.30.post_attention_layernorm.weight": "model-00006-of-00008.safetensors",
|
| 277 |
+
"model.language_model.layers.30.self_attn.k_norm.weight": "model-00006-of-00008.safetensors",
|
| 278 |
+
"model.language_model.layers.30.self_attn.k_proj.weight": "model-00006-of-00008.safetensors",
|
| 279 |
+
"model.language_model.layers.30.self_attn.o_proj.weight": "model-00006-of-00008.safetensors",
|
| 280 |
+
"model.language_model.layers.30.self_attn.q_norm.weight": "model-00006-of-00008.safetensors",
|
| 281 |
+
"model.language_model.layers.30.self_attn.q_proj.weight": "model-00006-of-00008.safetensors",
|
| 282 |
+
"model.language_model.layers.30.self_attn.v_proj.weight": "model-00006-of-00008.safetensors",
|
| 283 |
+
"model.language_model.layers.31.input_layernorm.weight": "model-00007-of-00008.safetensors",
|
| 284 |
+
"model.language_model.layers.31.mlp.down_proj.weight": "model-00007-of-00008.safetensors",
|
| 285 |
+
"model.language_model.layers.31.mlp.gate_proj.weight": "model-00006-of-00008.safetensors",
|
| 286 |
+
"model.language_model.layers.31.mlp.up_proj.weight": "model-00006-of-00008.safetensors",
|
| 287 |
+
"model.language_model.layers.31.post_attention_layernorm.weight": "model-00007-of-00008.safetensors",
|
| 288 |
+
"model.language_model.layers.31.self_attn.k_norm.weight": "model-00006-of-00008.safetensors",
|
| 289 |
+
"model.language_model.layers.31.self_attn.k_proj.weight": "model-00006-of-00008.safetensors",
|
| 290 |
+
"model.language_model.layers.31.self_attn.o_proj.weight": "model-00006-of-00008.safetensors",
|
| 291 |
+
"model.language_model.layers.31.self_attn.q_norm.weight": "model-00006-of-00008.safetensors",
|
| 292 |
+
"model.language_model.layers.31.self_attn.q_proj.weight": "model-00006-of-00008.safetensors",
|
| 293 |
+
"model.language_model.layers.31.self_attn.v_proj.weight": "model-00006-of-00008.safetensors",
|
| 294 |
+
"model.language_model.layers.32.input_layernorm.weight": "model-00007-of-00008.safetensors",
|
| 295 |
+
"model.language_model.layers.32.mlp.down_proj.weight": "model-00007-of-00008.safetensors",
|
| 296 |
+
"model.language_model.layers.32.mlp.gate_proj.weight": "model-00007-of-00008.safetensors",
|
| 297 |
+
"model.language_model.layers.32.mlp.up_proj.weight": "model-00007-of-00008.safetensors",
|
| 298 |
+
"model.language_model.layers.32.post_attention_layernorm.weight": "model-00007-of-00008.safetensors",
|
| 299 |
+
"model.language_model.layers.32.self_attn.k_norm.weight": "model-00007-of-00008.safetensors",
|
| 300 |
+
"model.language_model.layers.32.self_attn.k_proj.weight": "model-00007-of-00008.safetensors",
|
| 301 |
+
"model.language_model.layers.32.self_attn.o_proj.weight": "model-00007-of-00008.safetensors",
|
| 302 |
+
"model.language_model.layers.32.self_attn.q_norm.weight": "model-00007-of-00008.safetensors",
|
| 303 |
+
"model.language_model.layers.32.self_attn.q_proj.weight": "model-00007-of-00008.safetensors",
|
| 304 |
+
"model.language_model.layers.32.self_attn.v_proj.weight": "model-00007-of-00008.safetensors",
|
| 305 |
+
"model.language_model.layers.33.input_layernorm.weight": "model-00007-of-00008.safetensors",
|
| 306 |
+
"model.language_model.layers.33.mlp.down_proj.weight": "model-00007-of-00008.safetensors",
|
| 307 |
+
"model.language_model.layers.33.mlp.gate_proj.weight": "model-00007-of-00008.safetensors",
|
| 308 |
+
"model.language_model.layers.33.mlp.up_proj.weight": "model-00007-of-00008.safetensors",
|
| 309 |
+
"model.language_model.layers.33.post_attention_layernorm.weight": "model-00007-of-00008.safetensors",
|
| 310 |
+
"model.language_model.layers.33.self_attn.k_norm.weight": "model-00007-of-00008.safetensors",
|
| 311 |
+
"model.language_model.layers.33.self_attn.k_proj.weight": "model-00007-of-00008.safetensors",
|
| 312 |
+
"model.language_model.layers.33.self_attn.o_proj.weight": "model-00007-of-00008.safetensors",
|
| 313 |
+
"model.language_model.layers.33.self_attn.q_norm.weight": "model-00007-of-00008.safetensors",
|
| 314 |
+
"model.language_model.layers.33.self_attn.q_proj.weight": "model-00007-of-00008.safetensors",
|
| 315 |
+
"model.language_model.layers.33.self_attn.v_proj.weight": "model-00007-of-00008.safetensors",
|
| 316 |
+
"model.language_model.layers.34.input_layernorm.weight": "model-00007-of-00008.safetensors",
|
| 317 |
+
"model.language_model.layers.34.mlp.down_proj.weight": "model-00007-of-00008.safetensors",
|
| 318 |
+
"model.language_model.layers.34.mlp.gate_proj.weight": "model-00007-of-00008.safetensors",
|
| 319 |
+
"model.language_model.layers.34.mlp.up_proj.weight": "model-00007-of-00008.safetensors",
|
| 320 |
+
"model.language_model.layers.34.post_attention_layernorm.weight": "model-00007-of-00008.safetensors",
|
| 321 |
+
"model.language_model.layers.34.self_attn.k_norm.weight": "model-00007-of-00008.safetensors",
|
| 322 |
+
"model.language_model.layers.34.self_attn.k_proj.weight": "model-00007-of-00008.safetensors",
|
| 323 |
+
"model.language_model.layers.34.self_attn.o_proj.weight": "model-00007-of-00008.safetensors",
|
| 324 |
+
"model.language_model.layers.34.self_attn.q_norm.weight": "model-00007-of-00008.safetensors",
|
| 325 |
+
"model.language_model.layers.34.self_attn.q_proj.weight": "model-00007-of-00008.safetensors",
|
| 326 |
+
"model.language_model.layers.34.self_attn.v_proj.weight": "model-00007-of-00008.safetensors",
|
| 327 |
+
"model.language_model.layers.35.input_layernorm.weight": "model-00007-of-00008.safetensors",
|
| 328 |
+
"model.language_model.layers.35.mlp.down_proj.weight": "model-00007-of-00008.safetensors",
|
| 329 |
+
"model.language_model.layers.35.mlp.gate_proj.weight": "model-00007-of-00008.safetensors",
|
| 330 |
+
"model.language_model.layers.35.mlp.up_proj.weight": "model-00007-of-00008.safetensors",
|
| 331 |
+
"model.language_model.layers.35.post_attention_layernorm.weight": "model-00007-of-00008.safetensors",
|
| 332 |
+
"model.language_model.layers.35.self_attn.k_norm.weight": "model-00007-of-00008.safetensors",
|
| 333 |
+
"model.language_model.layers.35.self_attn.k_proj.weight": "model-00007-of-00008.safetensors",
|
| 334 |
+
"model.language_model.layers.35.self_attn.o_proj.weight": "model-00007-of-00008.safetensors",
|
| 335 |
+
"model.language_model.layers.35.self_attn.q_norm.weight": "model-00007-of-00008.safetensors",
|
| 336 |
+
"model.language_model.layers.35.self_attn.q_proj.weight": "model-00007-of-00008.safetensors",
|
| 337 |
+
"model.language_model.layers.35.self_attn.v_proj.weight": "model-00007-of-00008.safetensors",
|
| 338 |
+
"model.language_model.layers.4.input_layernorm.weight": "model-00002-of-00008.safetensors",
|
| 339 |
+
"model.language_model.layers.4.mlp.down_proj.weight": "model-00002-of-00008.safetensors",
|
| 340 |
+
"model.language_model.layers.4.mlp.gate_proj.weight": "model-00002-of-00008.safetensors",
|
| 341 |
+
"model.language_model.layers.4.mlp.up_proj.weight": "model-00002-of-00008.safetensors",
|
| 342 |
+
"model.language_model.layers.4.post_attention_layernorm.weight": "model-00002-of-00008.safetensors",
|
| 343 |
+
"model.language_model.layers.4.self_attn.k_norm.weight": "model-00002-of-00008.safetensors",
|
| 344 |
+
"model.language_model.layers.4.self_attn.k_proj.weight": "model-00002-of-00008.safetensors",
|
| 345 |
+
"model.language_model.layers.4.self_attn.o_proj.weight": "model-00002-of-00008.safetensors",
|
| 346 |
+
"model.language_model.layers.4.self_attn.q_norm.weight": "model-00002-of-00008.safetensors",
|
| 347 |
+
"model.language_model.layers.4.self_attn.q_proj.weight": "model-00002-of-00008.safetensors",
|
| 348 |
+
"model.language_model.layers.4.self_attn.v_proj.weight": "model-00002-of-00008.safetensors",
|
| 349 |
+
"model.language_model.layers.5.input_layernorm.weight": "model-00002-of-00008.safetensors",
|
| 350 |
+
"model.language_model.layers.5.mlp.down_proj.weight": "model-00002-of-00008.safetensors",
|
| 351 |
+
"model.language_model.layers.5.mlp.gate_proj.weight": "model-00002-of-00008.safetensors",
|
| 352 |
+
"model.language_model.layers.5.mlp.up_proj.weight": "model-00002-of-00008.safetensors",
|
| 353 |
+
"model.language_model.layers.5.post_attention_layernorm.weight": "model-00002-of-00008.safetensors",
|
| 354 |
+
"model.language_model.layers.5.self_attn.k_norm.weight": "model-00002-of-00008.safetensors",
|
| 355 |
+
"model.language_model.layers.5.self_attn.k_proj.weight": "model-00002-of-00008.safetensors",
|
| 356 |
+
"model.language_model.layers.5.self_attn.o_proj.weight": "model-00002-of-00008.safetensors",
|
| 357 |
+
"model.language_model.layers.5.self_attn.q_norm.weight": "model-00002-of-00008.safetensors",
|
| 358 |
+
"model.language_model.layers.5.self_attn.q_proj.weight": "model-00002-of-00008.safetensors",
|
| 359 |
+
"model.language_model.layers.5.self_attn.v_proj.weight": "model-00002-of-00008.safetensors",
|
| 360 |
+
"model.language_model.layers.6.input_layernorm.weight": "model-00003-of-00008.safetensors",
|
| 361 |
+
"model.language_model.layers.6.mlp.down_proj.weight": "model-00003-of-00008.safetensors",
|
| 362 |
+
"model.language_model.layers.6.mlp.gate_proj.weight": "model-00002-of-00008.safetensors",
|
| 363 |
+
"model.language_model.layers.6.mlp.up_proj.weight": "model-00003-of-00008.safetensors",
|
| 364 |
+
"model.language_model.layers.6.post_attention_layernorm.weight": "model-00003-of-00008.safetensors",
|
| 365 |
+
"model.language_model.layers.6.self_attn.k_norm.weight": "model-00002-of-00008.safetensors",
|
| 366 |
+
"model.language_model.layers.6.self_attn.k_proj.weight": "model-00002-of-00008.safetensors",
|
| 367 |
+
"model.language_model.layers.6.self_attn.o_proj.weight": "model-00002-of-00008.safetensors",
|
| 368 |
+
"model.language_model.layers.6.self_attn.q_norm.weight": "model-00002-of-00008.safetensors",
|
| 369 |
+
"model.language_model.layers.6.self_attn.q_proj.weight": "model-00002-of-00008.safetensors",
|
| 370 |
+
"model.language_model.layers.6.self_attn.v_proj.weight": "model-00002-of-00008.safetensors",
|
| 371 |
+
"model.language_model.layers.7.input_layernorm.weight": "model-00003-of-00008.safetensors",
|
| 372 |
+
"model.language_model.layers.7.mlp.down_proj.weight": "model-00003-of-00008.safetensors",
|
| 373 |
+
"model.language_model.layers.7.mlp.gate_proj.weight": "model-00003-of-00008.safetensors",
|
| 374 |
+
"model.language_model.layers.7.mlp.up_proj.weight": "model-00003-of-00008.safetensors",
|
| 375 |
+
"model.language_model.layers.7.post_attention_layernorm.weight": "model-00003-of-00008.safetensors",
|
| 376 |
+
"model.language_model.layers.7.self_attn.k_norm.weight": "model-00003-of-00008.safetensors",
|
| 377 |
+
"model.language_model.layers.7.self_attn.k_proj.weight": "model-00003-of-00008.safetensors",
|
| 378 |
+
"model.language_model.layers.7.self_attn.o_proj.weight": "model-00003-of-00008.safetensors",
|
| 379 |
+
"model.language_model.layers.7.self_attn.q_norm.weight": "model-00003-of-00008.safetensors",
|
| 380 |
+
"model.language_model.layers.7.self_attn.q_proj.weight": "model-00003-of-00008.safetensors",
|
| 381 |
+
"model.language_model.layers.7.self_attn.v_proj.weight": "model-00003-of-00008.safetensors",
|
| 382 |
+
"model.language_model.layers.8.input_layernorm.weight": "model-00003-of-00008.safetensors",
|
| 383 |
+
"model.language_model.layers.8.mlp.down_proj.weight": "model-00003-of-00008.safetensors",
|
| 384 |
+
"model.language_model.layers.8.mlp.gate_proj.weight": "model-00003-of-00008.safetensors",
|
| 385 |
+
"model.language_model.layers.8.mlp.up_proj.weight": "model-00003-of-00008.safetensors",
|
| 386 |
+
"model.language_model.layers.8.post_attention_layernorm.weight": "model-00003-of-00008.safetensors",
|
| 387 |
+
"model.language_model.layers.8.self_attn.k_norm.weight": "model-00003-of-00008.safetensors",
|
| 388 |
+
"model.language_model.layers.8.self_attn.k_proj.weight": "model-00003-of-00008.safetensors",
|
| 389 |
+
"model.language_model.layers.8.self_attn.o_proj.weight": "model-00003-of-00008.safetensors",
|
| 390 |
+
"model.language_model.layers.8.self_attn.q_norm.weight": "model-00003-of-00008.safetensors",
|
| 391 |
+
"model.language_model.layers.8.self_attn.q_proj.weight": "model-00003-of-00008.safetensors",
|
| 392 |
+
"model.language_model.layers.8.self_attn.v_proj.weight": "model-00003-of-00008.safetensors",
|
| 393 |
+
"model.language_model.layers.9.input_layernorm.weight": "model-00003-of-00008.safetensors",
|
| 394 |
+
"model.language_model.layers.9.mlp.down_proj.weight": "model-00003-of-00008.safetensors",
|
| 395 |
+
"model.language_model.layers.9.mlp.gate_proj.weight": "model-00003-of-00008.safetensors",
|
| 396 |
+
"model.language_model.layers.9.mlp.up_proj.weight": "model-00003-of-00008.safetensors",
|
| 397 |
+
"model.language_model.layers.9.post_attention_layernorm.weight": "model-00003-of-00008.safetensors",
|
| 398 |
+
"model.language_model.layers.9.self_attn.k_norm.weight": "model-00003-of-00008.safetensors",
|
| 399 |
+
"model.language_model.layers.9.self_attn.k_proj.weight": "model-00003-of-00008.safetensors",
|
| 400 |
+
"model.language_model.layers.9.self_attn.o_proj.weight": "model-00003-of-00008.safetensors",
|
| 401 |
+
"model.language_model.layers.9.self_attn.q_norm.weight": "model-00003-of-00008.safetensors",
|
| 402 |
+
"model.language_model.layers.9.self_attn.q_proj.weight": "model-00003-of-00008.safetensors",
|
| 403 |
+
"model.language_model.layers.9.self_attn.v_proj.weight": "model-00003-of-00008.safetensors",
|
| 404 |
+
"model.language_model.norm.weight": "model-00007-of-00008.safetensors",
|
| 405 |
+
"model.visual.blocks.0.attn.proj.bias": "model-00001-of-00008.safetensors",
|
| 406 |
+
"model.visual.blocks.0.attn.proj.weight": "model-00001-of-00008.safetensors",
|
| 407 |
+
"model.visual.blocks.0.attn.qkv.bias": "model-00001-of-00008.safetensors",
|
| 408 |
+
"model.visual.blocks.0.attn.qkv.weight": "model-00001-of-00008.safetensors",
|
| 409 |
+
"model.visual.blocks.0.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
|
| 410 |
+
"model.visual.blocks.0.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
|
| 411 |
+
"model.visual.blocks.0.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
|
| 412 |
+
"model.visual.blocks.0.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
|
| 413 |
+
"model.visual.blocks.0.norm1.bias": "model-00001-of-00008.safetensors",
|
| 414 |
+
"model.visual.blocks.0.norm1.weight": "model-00001-of-00008.safetensors",
|
| 415 |
+
"model.visual.blocks.0.norm2.bias": "model-00001-of-00008.safetensors",
|
| 416 |
+
"model.visual.blocks.0.norm2.weight": "model-00001-of-00008.safetensors",
|
| 417 |
+
"model.visual.blocks.1.attn.proj.bias": "model-00001-of-00008.safetensors",
|
| 418 |
+
"model.visual.blocks.1.attn.proj.weight": "model-00001-of-00008.safetensors",
|
| 419 |
+
"model.visual.blocks.1.attn.qkv.bias": "model-00001-of-00008.safetensors",
|
| 420 |
+
"model.visual.blocks.1.attn.qkv.weight": "model-00001-of-00008.safetensors",
|
| 421 |
+
"model.visual.blocks.1.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
|
| 422 |
+
"model.visual.blocks.1.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
|
| 423 |
+
"model.visual.blocks.1.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
|
| 424 |
+
"model.visual.blocks.1.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
|
| 425 |
+
"model.visual.blocks.1.norm1.bias": "model-00001-of-00008.safetensors",
|
| 426 |
+
"model.visual.blocks.1.norm1.weight": "model-00001-of-00008.safetensors",
|
| 427 |
+
"model.visual.blocks.1.norm2.bias": "model-00001-of-00008.safetensors",
|
| 428 |
+
"model.visual.blocks.1.norm2.weight": "model-00001-of-00008.safetensors",
|
| 429 |
+
"model.visual.blocks.10.attn.proj.bias": "model-00001-of-00008.safetensors",
|
| 430 |
+
"model.visual.blocks.10.attn.proj.weight": "model-00001-of-00008.safetensors",
|
| 431 |
+
"model.visual.blocks.10.attn.qkv.bias": "model-00001-of-00008.safetensors",
|
| 432 |
+
"model.visual.blocks.10.attn.qkv.weight": "model-00001-of-00008.safetensors",
|
| 433 |
+
"model.visual.blocks.10.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
|
| 434 |
+
"model.visual.blocks.10.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
|
| 435 |
+
"model.visual.blocks.10.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
|
| 436 |
+
"model.visual.blocks.10.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
|
| 437 |
+
"model.visual.blocks.10.norm1.bias": "model-00001-of-00008.safetensors",
|
| 438 |
+
"model.visual.blocks.10.norm1.weight": "model-00001-of-00008.safetensors",
|
| 439 |
+
"model.visual.blocks.10.norm2.bias": "model-00001-of-00008.safetensors",
|
| 440 |
+
"model.visual.blocks.10.norm2.weight": "model-00001-of-00008.safetensors",
|
| 441 |
+
"model.visual.blocks.11.attn.proj.bias": "model-00001-of-00008.safetensors",
|
| 442 |
+
"model.visual.blocks.11.attn.proj.weight": "model-00001-of-00008.safetensors",
|
| 443 |
+
"model.visual.blocks.11.attn.qkv.bias": "model-00001-of-00008.safetensors",
|
| 444 |
+
"model.visual.blocks.11.attn.qkv.weight": "model-00001-of-00008.safetensors",
|
| 445 |
+
"model.visual.blocks.11.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
|
| 446 |
+
"model.visual.blocks.11.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
|
| 447 |
+
"model.visual.blocks.11.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
|
| 448 |
+
"model.visual.blocks.11.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
|
| 449 |
+
"model.visual.blocks.11.norm1.bias": "model-00001-of-00008.safetensors",
|
| 450 |
+
"model.visual.blocks.11.norm1.weight": "model-00001-of-00008.safetensors",
|
| 451 |
+
"model.visual.blocks.11.norm2.bias": "model-00001-of-00008.safetensors",
|
| 452 |
+
"model.visual.blocks.11.norm2.weight": "model-00001-of-00008.safetensors",
|
| 453 |
+
"model.visual.blocks.12.attn.proj.bias": "model-00001-of-00008.safetensors",
|
| 454 |
+
"model.visual.blocks.12.attn.proj.weight": "model-00001-of-00008.safetensors",
|
| 455 |
+
"model.visual.blocks.12.attn.qkv.bias": "model-00001-of-00008.safetensors",
|
| 456 |
+
"model.visual.blocks.12.attn.qkv.weight": "model-00001-of-00008.safetensors",
|
| 457 |
+
"model.visual.blocks.12.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
|
| 458 |
+
"model.visual.blocks.12.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
|
| 459 |
+
"model.visual.blocks.12.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
|
| 460 |
+
"model.visual.blocks.12.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
|
| 461 |
+
"model.visual.blocks.12.norm1.bias": "model-00001-of-00008.safetensors",
|
| 462 |
+
"model.visual.blocks.12.norm1.weight": "model-00001-of-00008.safetensors",
|
| 463 |
+
"model.visual.blocks.12.norm2.bias": "model-00001-of-00008.safetensors",
|
| 464 |
+
"model.visual.blocks.12.norm2.weight": "model-00001-of-00008.safetensors",
|
| 465 |
+
"model.visual.blocks.13.attn.proj.bias": "model-00001-of-00008.safetensors",
|
| 466 |
+
"model.visual.blocks.13.attn.proj.weight": "model-00001-of-00008.safetensors",
|
| 467 |
+
"model.visual.blocks.13.attn.qkv.bias": "model-00001-of-00008.safetensors",
|
| 468 |
+
"model.visual.blocks.13.attn.qkv.weight": "model-00001-of-00008.safetensors",
|
| 469 |
+
"model.visual.blocks.13.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
|
| 470 |
+
"model.visual.blocks.13.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
|
| 471 |
+
"model.visual.blocks.13.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
|
| 472 |
+
"model.visual.blocks.13.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
|
| 473 |
+
"model.visual.blocks.13.norm1.bias": "model-00001-of-00008.safetensors",
|
| 474 |
+
"model.visual.blocks.13.norm1.weight": "model-00001-of-00008.safetensors",
|
| 475 |
+
"model.visual.blocks.13.norm2.bias": "model-00001-of-00008.safetensors",
|
| 476 |
+
"model.visual.blocks.13.norm2.weight": "model-00001-of-00008.safetensors",
|
| 477 |
+
"model.visual.blocks.14.attn.proj.bias": "model-00001-of-00008.safetensors",
|
| 478 |
+
"model.visual.blocks.14.attn.proj.weight": "model-00001-of-00008.safetensors",
|
| 479 |
+
"model.visual.blocks.14.attn.qkv.bias": "model-00001-of-00008.safetensors",
|
| 480 |
+
"model.visual.blocks.14.attn.qkv.weight": "model-00001-of-00008.safetensors",
|
| 481 |
+
"model.visual.blocks.14.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
|
| 482 |
+
"model.visual.blocks.14.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
|
| 483 |
+
"model.visual.blocks.14.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
|
| 484 |
+
"model.visual.blocks.14.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
|
| 485 |
+
"model.visual.blocks.14.norm1.bias": "model-00001-of-00008.safetensors",
|
| 486 |
+
"model.visual.blocks.14.norm1.weight": "model-00001-of-00008.safetensors",
|
| 487 |
+
"model.visual.blocks.14.norm2.bias": "model-00001-of-00008.safetensors",
|
| 488 |
+
"model.visual.blocks.14.norm2.weight": "model-00001-of-00008.safetensors",
|
| 489 |
+
"model.visual.blocks.15.attn.proj.bias": "model-00001-of-00008.safetensors",
|
| 490 |
+
"model.visual.blocks.15.attn.proj.weight": "model-00001-of-00008.safetensors",
|
| 491 |
+
"model.visual.blocks.15.attn.qkv.bias": "model-00001-of-00008.safetensors",
|
| 492 |
+
"model.visual.blocks.15.attn.qkv.weight": "model-00001-of-00008.safetensors",
|
| 493 |
+
"model.visual.blocks.15.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
|
| 494 |
+
"model.visual.blocks.15.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
|
| 495 |
+
"model.visual.blocks.15.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
|
| 496 |
+
"model.visual.blocks.15.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
|
| 497 |
+
"model.visual.blocks.15.norm1.bias": "model-00001-of-00008.safetensors",
|
| 498 |
+
"model.visual.blocks.15.norm1.weight": "model-00001-of-00008.safetensors",
|
| 499 |
+
"model.visual.blocks.15.norm2.bias": "model-00001-of-00008.safetensors",
|
| 500 |
+
"model.visual.blocks.15.norm2.weight": "model-00001-of-00008.safetensors",
|
| 501 |
+
"model.visual.blocks.16.attn.proj.bias": "model-00001-of-00008.safetensors",
|
| 502 |
+
"model.visual.blocks.16.attn.proj.weight": "model-00001-of-00008.safetensors",
|
| 503 |
+
"model.visual.blocks.16.attn.qkv.bias": "model-00001-of-00008.safetensors",
|
| 504 |
+
"model.visual.blocks.16.attn.qkv.weight": "model-00001-of-00008.safetensors",
|
| 505 |
+
"model.visual.blocks.16.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
|
| 506 |
+
"model.visual.blocks.16.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
|
| 507 |
+
"model.visual.blocks.16.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
|
| 508 |
+
"model.visual.blocks.16.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
|
| 509 |
+
"model.visual.blocks.16.norm1.bias": "model-00001-of-00008.safetensors",
|
| 510 |
+
"model.visual.blocks.16.norm1.weight": "model-00001-of-00008.safetensors",
|
| 511 |
+
"model.visual.blocks.16.norm2.bias": "model-00001-of-00008.safetensors",
|
| 512 |
+
"model.visual.blocks.16.norm2.weight": "model-00001-of-00008.safetensors",
|
| 513 |
+
"model.visual.blocks.17.attn.proj.bias": "model-00001-of-00008.safetensors",
|
| 514 |
+
"model.visual.blocks.17.attn.proj.weight": "model-00001-of-00008.safetensors",
|
| 515 |
+
"model.visual.blocks.17.attn.qkv.bias": "model-00001-of-00008.safetensors",
|
| 516 |
+
"model.visual.blocks.17.attn.qkv.weight": "model-00001-of-00008.safetensors",
|
| 517 |
+
"model.visual.blocks.17.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
|
| 518 |
+
"model.visual.blocks.17.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
|
| 519 |
+
"model.visual.blocks.17.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
|
| 520 |
+
"model.visual.blocks.17.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
|
| 521 |
+
"model.visual.blocks.17.norm1.bias": "model-00001-of-00008.safetensors",
|
| 522 |
+
"model.visual.blocks.17.norm1.weight": "model-00001-of-00008.safetensors",
|
| 523 |
+
"model.visual.blocks.17.norm2.bias": "model-00001-of-00008.safetensors",
|
| 524 |
+
"model.visual.blocks.17.norm2.weight": "model-00001-of-00008.safetensors",
|
| 525 |
+
"model.visual.blocks.18.attn.proj.bias": "model-00001-of-00008.safetensors",
|
| 526 |
+
"model.visual.blocks.18.attn.proj.weight": "model-00001-of-00008.safetensors",
|
| 527 |
+
"model.visual.blocks.18.attn.qkv.bias": "model-00001-of-00008.safetensors",
|
| 528 |
+
"model.visual.blocks.18.attn.qkv.weight": "model-00001-of-00008.safetensors",
|
| 529 |
+
"model.visual.blocks.18.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
|
| 530 |
+
"model.visual.blocks.18.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
|
| 531 |
+
"model.visual.blocks.18.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
|
| 532 |
+
"model.visual.blocks.18.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
|
| 533 |
+
"model.visual.blocks.18.norm1.bias": "model-00001-of-00008.safetensors",
|
| 534 |
+
"model.visual.blocks.18.norm1.weight": "model-00001-of-00008.safetensors",
|
| 535 |
+
"model.visual.blocks.18.norm2.bias": "model-00001-of-00008.safetensors",
|
| 536 |
+
"model.visual.blocks.18.norm2.weight": "model-00001-of-00008.safetensors",
|
| 537 |
+
"model.visual.blocks.19.attn.proj.bias": "model-00001-of-00008.safetensors",
|
| 538 |
+
"model.visual.blocks.19.attn.proj.weight": "model-00001-of-00008.safetensors",
|
| 539 |
+
"model.visual.blocks.19.attn.qkv.bias": "model-00001-of-00008.safetensors",
|
| 540 |
+
"model.visual.blocks.19.attn.qkv.weight": "model-00001-of-00008.safetensors",
|
| 541 |
+
"model.visual.blocks.19.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
|
| 542 |
+
"model.visual.blocks.19.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
|
| 543 |
+
"model.visual.blocks.19.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
|
| 544 |
+
"model.visual.blocks.19.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
|
| 545 |
+
"model.visual.blocks.19.norm1.bias": "model-00001-of-00008.safetensors",
|
| 546 |
+
"model.visual.blocks.19.norm1.weight": "model-00001-of-00008.safetensors",
|
| 547 |
+
"model.visual.blocks.19.norm2.bias": "model-00001-of-00008.safetensors",
|
| 548 |
+
"model.visual.blocks.19.norm2.weight": "model-00001-of-00008.safetensors",
|
| 549 |
+
"model.visual.blocks.2.attn.proj.bias": "model-00001-of-00008.safetensors",
|
| 550 |
+
"model.visual.blocks.2.attn.proj.weight": "model-00001-of-00008.safetensors",
|
| 551 |
+
"model.visual.blocks.2.attn.qkv.bias": "model-00001-of-00008.safetensors",
|
| 552 |
+
"model.visual.blocks.2.attn.qkv.weight": "model-00001-of-00008.safetensors",
|
| 553 |
+
"model.visual.blocks.2.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
|
| 554 |
+
"model.visual.blocks.2.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
|
| 555 |
+
"model.visual.blocks.2.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
|
| 556 |
+
"model.visual.blocks.2.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
|
| 557 |
+
"model.visual.blocks.2.norm1.bias": "model-00001-of-00008.safetensors",
|
| 558 |
+
"model.visual.blocks.2.norm1.weight": "model-00001-of-00008.safetensors",
|
| 559 |
+
"model.visual.blocks.2.norm2.bias": "model-00001-of-00008.safetensors",
|
| 560 |
+
"model.visual.blocks.2.norm2.weight": "model-00001-of-00008.safetensors",
|
| 561 |
+
"model.visual.blocks.20.attn.proj.bias": "model-00001-of-00008.safetensors",
|
| 562 |
+
"model.visual.blocks.20.attn.proj.weight": "model-00001-of-00008.safetensors",
|
| 563 |
+
"model.visual.blocks.20.attn.qkv.bias": "model-00001-of-00008.safetensors",
|
| 564 |
+
"model.visual.blocks.20.attn.qkv.weight": "model-00001-of-00008.safetensors",
|
| 565 |
+
"model.visual.blocks.20.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
|
| 566 |
+
"model.visual.blocks.20.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
|
| 567 |
+
"model.visual.blocks.20.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
|
| 568 |
+
"model.visual.blocks.20.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
|
| 569 |
+
"model.visual.blocks.20.norm1.bias": "model-00001-of-00008.safetensors",
|
| 570 |
+
"model.visual.blocks.20.norm1.weight": "model-00001-of-00008.safetensors",
|
| 571 |
+
"model.visual.blocks.20.norm2.bias": "model-00001-of-00008.safetensors",
|
| 572 |
+
"model.visual.blocks.20.norm2.weight": "model-00001-of-00008.safetensors",
|
| 573 |
+
"model.visual.blocks.21.attn.proj.bias": "model-00001-of-00008.safetensors",
|
| 574 |
+
"model.visual.blocks.21.attn.proj.weight": "model-00001-of-00008.safetensors",
|
| 575 |
+
"model.visual.blocks.21.attn.qkv.bias": "model-00001-of-00008.safetensors",
|
| 576 |
+
"model.visual.blocks.21.attn.qkv.weight": "model-00001-of-00008.safetensors",
|
| 577 |
+
"model.visual.blocks.21.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
|
| 578 |
+
"model.visual.blocks.21.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
|
| 579 |
+
"model.visual.blocks.21.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
|
| 580 |
+
"model.visual.blocks.21.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
|
| 581 |
+
"model.visual.blocks.21.norm1.bias": "model-00001-of-00008.safetensors",
|
| 582 |
+
"model.visual.blocks.21.norm1.weight": "model-00001-of-00008.safetensors",
|
| 583 |
+
"model.visual.blocks.21.norm2.bias": "model-00001-of-00008.safetensors",
|
| 584 |
+
"model.visual.blocks.21.norm2.weight": "model-00001-of-00008.safetensors",
|
| 585 |
+
"model.visual.blocks.22.attn.proj.bias": "model-00001-of-00008.safetensors",
|
| 586 |
+
"model.visual.blocks.22.attn.proj.weight": "model-00001-of-00008.safetensors",
|
| 587 |
+
"model.visual.blocks.22.attn.qkv.bias": "model-00001-of-00008.safetensors",
|
| 588 |
+
"model.visual.blocks.22.attn.qkv.weight": "model-00001-of-00008.safetensors",
|
| 589 |
+
"model.visual.blocks.22.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
|
| 590 |
+
"model.visual.blocks.22.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
|
| 591 |
+
"model.visual.blocks.22.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
|
| 592 |
+
"model.visual.blocks.22.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
|
| 593 |
+
"model.visual.blocks.22.norm1.bias": "model-00001-of-00008.safetensors",
|
| 594 |
+
"model.visual.blocks.22.norm1.weight": "model-00001-of-00008.safetensors",
|
| 595 |
+
"model.visual.blocks.22.norm2.bias": "model-00001-of-00008.safetensors",
|
| 596 |
+
"model.visual.blocks.22.norm2.weight": "model-00001-of-00008.safetensors",
|
| 597 |
+
"model.visual.blocks.23.attn.proj.bias": "model-00001-of-00008.safetensors",
|
| 598 |
+
"model.visual.blocks.23.attn.proj.weight": "model-00001-of-00008.safetensors",
|
| 599 |
+
"model.visual.blocks.23.attn.qkv.bias": "model-00001-of-00008.safetensors",
|
| 600 |
+
"model.visual.blocks.23.attn.qkv.weight": "model-00001-of-00008.safetensors",
|
| 601 |
+
"model.visual.blocks.23.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
|
| 602 |
+
"model.visual.blocks.23.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
|
| 603 |
+
"model.visual.blocks.23.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
|
| 604 |
+
"model.visual.blocks.23.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
|
| 605 |
+
"model.visual.blocks.23.norm1.bias": "model-00001-of-00008.safetensors",
|
| 606 |
+
"model.visual.blocks.23.norm1.weight": "model-00001-of-00008.safetensors",
|
| 607 |
+
"model.visual.blocks.23.norm2.bias": "model-00001-of-00008.safetensors",
|
| 608 |
+
"model.visual.blocks.23.norm2.weight": "model-00001-of-00008.safetensors",
|
| 609 |
+
"model.visual.blocks.24.attn.proj.bias": "model-00001-of-00008.safetensors",
|
| 610 |
+
"model.visual.blocks.24.attn.proj.weight": "model-00001-of-00008.safetensors",
|
| 611 |
+
"model.visual.blocks.24.attn.qkv.bias": "model-00001-of-00008.safetensors",
|
| 612 |
+
"model.visual.blocks.24.attn.qkv.weight": "model-00001-of-00008.safetensors",
|
| 613 |
+
"model.visual.blocks.24.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
|
| 614 |
+
"model.visual.blocks.24.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
|
| 615 |
+
"model.visual.blocks.24.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
|
| 616 |
+
"model.visual.blocks.24.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
|
| 617 |
+
"model.visual.blocks.24.norm1.bias": "model-00001-of-00008.safetensors",
|
| 618 |
+
"model.visual.blocks.24.norm1.weight": "model-00001-of-00008.safetensors",
|
| 619 |
+
"model.visual.blocks.24.norm2.bias": "model-00001-of-00008.safetensors",
|
| 620 |
+
"model.visual.blocks.24.norm2.weight": "model-00001-of-00008.safetensors",
|
| 621 |
+
"model.visual.blocks.25.attn.proj.bias": "model-00001-of-00008.safetensors",
|
| 622 |
+
"model.visual.blocks.25.attn.proj.weight": "model-00001-of-00008.safetensors",
|
| 623 |
+
"model.visual.blocks.25.attn.qkv.bias": "model-00001-of-00008.safetensors",
|
| 624 |
+
"model.visual.blocks.25.attn.qkv.weight": "model-00001-of-00008.safetensors",
|
| 625 |
+
"model.visual.blocks.25.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
|
| 626 |
+
"model.visual.blocks.25.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
|
| 627 |
+
"model.visual.blocks.25.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
|
| 628 |
+
"model.visual.blocks.25.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
|
| 629 |
+
"model.visual.blocks.25.norm1.bias": "model-00001-of-00008.safetensors",
|
| 630 |
+
"model.visual.blocks.25.norm1.weight": "model-00001-of-00008.safetensors",
|
| 631 |
+
"model.visual.blocks.25.norm2.bias": "model-00001-of-00008.safetensors",
|
| 632 |
+
"model.visual.blocks.25.norm2.weight": "model-00001-of-00008.safetensors",
|
| 633 |
+
"model.visual.blocks.26.attn.proj.bias": "model-00001-of-00008.safetensors",
|
| 634 |
+
"model.visual.blocks.26.attn.proj.weight": "model-00001-of-00008.safetensors",
|
| 635 |
+
"model.visual.blocks.26.attn.qkv.bias": "model-00001-of-00008.safetensors",
|
| 636 |
+
"model.visual.blocks.26.attn.qkv.weight": "model-00001-of-00008.safetensors",
|
| 637 |
+
"model.visual.blocks.26.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
|
| 638 |
+
"model.visual.blocks.26.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
|
| 639 |
+
"model.visual.blocks.26.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
|
| 640 |
+
"model.visual.blocks.26.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
|
| 641 |
+
"model.visual.blocks.26.norm1.bias": "model-00001-of-00008.safetensors",
|
| 642 |
+
"model.visual.blocks.26.norm1.weight": "model-00001-of-00008.safetensors",
|
| 643 |
+
"model.visual.blocks.26.norm2.bias": "model-00001-of-00008.safetensors",
|
| 644 |
+
"model.visual.blocks.26.norm2.weight": "model-00001-of-00008.safetensors",
|
| 645 |
+
"model.visual.blocks.3.attn.proj.bias": "model-00001-of-00008.safetensors",
|
| 646 |
+
"model.visual.blocks.3.attn.proj.weight": "model-00001-of-00008.safetensors",
|
| 647 |
+
"model.visual.blocks.3.attn.qkv.bias": "model-00001-of-00008.safetensors",
|
| 648 |
+
"model.visual.blocks.3.attn.qkv.weight": "model-00001-of-00008.safetensors",
|
| 649 |
+
"model.visual.blocks.3.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
|
| 650 |
+
"model.visual.blocks.3.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
|
| 651 |
+
"model.visual.blocks.3.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
|
| 652 |
+
"model.visual.blocks.3.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
|
| 653 |
+
"model.visual.blocks.3.norm1.bias": "model-00001-of-00008.safetensors",
|
| 654 |
+
"model.visual.blocks.3.norm1.weight": "model-00001-of-00008.safetensors",
|
| 655 |
+
"model.visual.blocks.3.norm2.bias": "model-00001-of-00008.safetensors",
|
| 656 |
+
"model.visual.blocks.3.norm2.weight": "model-00001-of-00008.safetensors",
|
| 657 |
+
"model.visual.blocks.4.attn.proj.bias": "model-00001-of-00008.safetensors",
|
| 658 |
+
"model.visual.blocks.4.attn.proj.weight": "model-00001-of-00008.safetensors",
|
| 659 |
+
"model.visual.blocks.4.attn.qkv.bias": "model-00001-of-00008.safetensors",
|
| 660 |
+
"model.visual.blocks.4.attn.qkv.weight": "model-00001-of-00008.safetensors",
|
| 661 |
+
"model.visual.blocks.4.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
|
| 662 |
+
"model.visual.blocks.4.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
|
| 663 |
+
"model.visual.blocks.4.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
|
| 664 |
+
"model.visual.blocks.4.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
|
| 665 |
+
"model.visual.blocks.4.norm1.bias": "model-00001-of-00008.safetensors",
|
| 666 |
+
"model.visual.blocks.4.norm1.weight": "model-00001-of-00008.safetensors",
|
| 667 |
+
"model.visual.blocks.4.norm2.bias": "model-00001-of-00008.safetensors",
|
| 668 |
+
"model.visual.blocks.4.norm2.weight": "model-00001-of-00008.safetensors",
|
| 669 |
+
"model.visual.blocks.5.attn.proj.bias": "model-00001-of-00008.safetensors",
|
| 670 |
+
"model.visual.blocks.5.attn.proj.weight": "model-00001-of-00008.safetensors",
|
| 671 |
+
"model.visual.blocks.5.attn.qkv.bias": "model-00001-of-00008.safetensors",
|
| 672 |
+
"model.visual.blocks.5.attn.qkv.weight": "model-00001-of-00008.safetensors",
|
| 673 |
+
"model.visual.blocks.5.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
|
| 674 |
+
"model.visual.blocks.5.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
|
| 675 |
+
"model.visual.blocks.5.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
|
| 676 |
+
"model.visual.blocks.5.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
|
| 677 |
+
"model.visual.blocks.5.norm1.bias": "model-00001-of-00008.safetensors",
|
| 678 |
+
"model.visual.blocks.5.norm1.weight": "model-00001-of-00008.safetensors",
|
| 679 |
+
"model.visual.blocks.5.norm2.bias": "model-00001-of-00008.safetensors",
|
| 680 |
+
"model.visual.blocks.5.norm2.weight": "model-00001-of-00008.safetensors",
|
| 681 |
+
"model.visual.blocks.6.attn.proj.bias": "model-00001-of-00008.safetensors",
|
| 682 |
+
"model.visual.blocks.6.attn.proj.weight": "model-00001-of-00008.safetensors",
|
| 683 |
+
"model.visual.blocks.6.attn.qkv.bias": "model-00001-of-00008.safetensors",
|
| 684 |
+
"model.visual.blocks.6.attn.qkv.weight": "model-00001-of-00008.safetensors",
|
| 685 |
+
"model.visual.blocks.6.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
|
| 686 |
+
"model.visual.blocks.6.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
|
| 687 |
+
"model.visual.blocks.6.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
|
| 688 |
+
"model.visual.blocks.6.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
|
| 689 |
+
"model.visual.blocks.6.norm1.bias": "model-00001-of-00008.safetensors",
|
| 690 |
+
"model.visual.blocks.6.norm1.weight": "model-00001-of-00008.safetensors",
|
| 691 |
+
"model.visual.blocks.6.norm2.bias": "model-00001-of-00008.safetensors",
|
| 692 |
+
"model.visual.blocks.6.norm2.weight": "model-00001-of-00008.safetensors",
|
| 693 |
+
"model.visual.blocks.7.attn.proj.bias": "model-00001-of-00008.safetensors",
|
| 694 |
+
"model.visual.blocks.7.attn.proj.weight": "model-00001-of-00008.safetensors",
|
| 695 |
+
"model.visual.blocks.7.attn.qkv.bias": "model-00001-of-00008.safetensors",
|
| 696 |
+
"model.visual.blocks.7.attn.qkv.weight": "model-00001-of-00008.safetensors",
|
| 697 |
+
"model.visual.blocks.7.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
|
| 698 |
+
"model.visual.blocks.7.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
|
| 699 |
+
"model.visual.blocks.7.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
|
| 700 |
+
"model.visual.blocks.7.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
|
| 701 |
+
"model.visual.blocks.7.norm1.bias": "model-00001-of-00008.safetensors",
|
| 702 |
+
"model.visual.blocks.7.norm1.weight": "model-00001-of-00008.safetensors",
|
| 703 |
+
"model.visual.blocks.7.norm2.bias": "model-00001-of-00008.safetensors",
|
| 704 |
+
"model.visual.blocks.7.norm2.weight": "model-00001-of-00008.safetensors",
|
| 705 |
+
"model.visual.blocks.8.attn.proj.bias": "model-00001-of-00008.safetensors",
|
| 706 |
+
"model.visual.blocks.8.attn.proj.weight": "model-00001-of-00008.safetensors",
|
| 707 |
+
"model.visual.blocks.8.attn.qkv.bias": "model-00001-of-00008.safetensors",
|
| 708 |
+
"model.visual.blocks.8.attn.qkv.weight": "model-00001-of-00008.safetensors",
|
| 709 |
+
"model.visual.blocks.8.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
|
| 710 |
+
"model.visual.blocks.8.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
|
| 711 |
+
"model.visual.blocks.8.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
|
| 712 |
+
"model.visual.blocks.8.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
|
| 713 |
+
"model.visual.blocks.8.norm1.bias": "model-00001-of-00008.safetensors",
|
| 714 |
+
"model.visual.blocks.8.norm1.weight": "model-00001-of-00008.safetensors",
|
| 715 |
+
"model.visual.blocks.8.norm2.bias": "model-00001-of-00008.safetensors",
|
| 716 |
+
"model.visual.blocks.8.norm2.weight": "model-00001-of-00008.safetensors",
|
| 717 |
+
"model.visual.blocks.9.attn.proj.bias": "model-00001-of-00008.safetensors",
|
| 718 |
+
"model.visual.blocks.9.attn.proj.weight": "model-00001-of-00008.safetensors",
|
| 719 |
+
"model.visual.blocks.9.attn.qkv.bias": "model-00001-of-00008.safetensors",
|
| 720 |
+
"model.visual.blocks.9.attn.qkv.weight": "model-00001-of-00008.safetensors",
|
| 721 |
+
"model.visual.blocks.9.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
|
| 722 |
+
"model.visual.blocks.9.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
|
| 723 |
+
"model.visual.blocks.9.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
|
| 724 |
+
"model.visual.blocks.9.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
|
| 725 |
+
"model.visual.blocks.9.norm1.bias": "model-00001-of-00008.safetensors",
|
| 726 |
+
"model.visual.blocks.9.norm1.weight": "model-00001-of-00008.safetensors",
|
| 727 |
+
"model.visual.blocks.9.norm2.bias": "model-00001-of-00008.safetensors",
|
| 728 |
+
"model.visual.blocks.9.norm2.weight": "model-00001-of-00008.safetensors",
|
| 729 |
+
"model.visual.deepstack_merger_list.0.linear_fc1.bias": "model-00001-of-00008.safetensors",
|
| 730 |
+
"model.visual.deepstack_merger_list.0.linear_fc1.weight": "model-00001-of-00008.safetensors",
|
| 731 |
+
"model.visual.deepstack_merger_list.0.linear_fc2.bias": "model-00001-of-00008.safetensors",
|
| 732 |
+
"model.visual.deepstack_merger_list.0.linear_fc2.weight": "model-00001-of-00008.safetensors",
|
| 733 |
+
"model.visual.deepstack_merger_list.0.norm.bias": "model-00001-of-00008.safetensors",
|
| 734 |
+
"model.visual.deepstack_merger_list.0.norm.weight": "model-00001-of-00008.safetensors",
|
| 735 |
+
"model.visual.deepstack_merger_list.1.linear_fc1.bias": "model-00001-of-00008.safetensors",
|
| 736 |
+
"model.visual.deepstack_merger_list.1.linear_fc1.weight": "model-00001-of-00008.safetensors",
|
| 737 |
+
"model.visual.deepstack_merger_list.1.linear_fc2.bias": "model-00001-of-00008.safetensors",
|
| 738 |
+
"model.visual.deepstack_merger_list.1.linear_fc2.weight": "model-00001-of-00008.safetensors",
|
| 739 |
+
"model.visual.deepstack_merger_list.1.norm.bias": "model-00001-of-00008.safetensors",
|
| 740 |
+
"model.visual.deepstack_merger_list.1.norm.weight": "model-00001-of-00008.safetensors",
|
| 741 |
+
"model.visual.deepstack_merger_list.2.linear_fc1.bias": "model-00001-of-00008.safetensors",
|
| 742 |
+
"model.visual.deepstack_merger_list.2.linear_fc1.weight": "model-00001-of-00008.safetensors",
|
| 743 |
+
"model.visual.deepstack_merger_list.2.linear_fc2.bias": "model-00001-of-00008.safetensors",
|
| 744 |
+
"model.visual.deepstack_merger_list.2.linear_fc2.weight": "model-00001-of-00008.safetensors",
|
| 745 |
+
"model.visual.deepstack_merger_list.2.norm.bias": "model-00001-of-00008.safetensors",
|
| 746 |
+
"model.visual.deepstack_merger_list.2.norm.weight": "model-00001-of-00008.safetensors",
|
| 747 |
+
"model.visual.merger.linear_fc1.bias": "model-00001-of-00008.safetensors",
|
| 748 |
+
"model.visual.merger.linear_fc1.weight": "model-00001-of-00008.safetensors",
|
| 749 |
+
"model.visual.merger.linear_fc2.bias": "model-00001-of-00008.safetensors",
|
| 750 |
+
"model.visual.merger.linear_fc2.weight": "model-00001-of-00008.safetensors",
|
| 751 |
+
"model.visual.merger.norm.bias": "model-00001-of-00008.safetensors",
|
| 752 |
+
"model.visual.merger.norm.weight": "model-00001-of-00008.safetensors",
|
| 753 |
+
"model.visual.patch_embed.proj.bias": "model-00001-of-00008.safetensors",
|
| 754 |
+
"model.visual.patch_embed.proj.weight": "model-00001-of-00008.safetensors",
|
| 755 |
+
"model.visual.pos_embed.weight": "model-00001-of-00008.safetensors"
|
| 756 |
+
}
|
| 757 |
+
}
|
ICL/sft_model/epoch3_step1406_fp32/preprocessor_config.json
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"size": {
|
| 3 |
+
"longest_edge": 16777216,
|
| 4 |
+
"shortest_edge": 65536
|
| 5 |
+
},
|
| 6 |
+
"patch_size": 16,
|
| 7 |
+
"temporal_patch_size": 2,
|
| 8 |
+
"merge_size": 2,
|
| 9 |
+
"image_mean": [
|
| 10 |
+
0.5,
|
| 11 |
+
0.5,
|
| 12 |
+
0.5
|
| 13 |
+
],
|
| 14 |
+
"image_std": [
|
| 15 |
+
0.5,
|
| 16 |
+
0.5,
|
| 17 |
+
0.5
|
| 18 |
+
],
|
| 19 |
+
"processor_class": "Qwen3VLProcessor",
|
| 20 |
+
"image_processor_type": "Qwen2VLImageProcessorFast"
|
| 21 |
+
}
|
ICL/sft_model/epoch3_step1406_fp32/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ICL/sft_model/epoch3_step1406_fp32/tokenizer_config.json
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_bos_token": false,
|
| 3 |
+
"add_prefix_space": false,
|
| 4 |
+
"added_tokens_decoder": {
|
| 5 |
+
"151643": {
|
| 6 |
+
"content": "<|endoftext|>",
|
| 7 |
+
"lstrip": false,
|
| 8 |
+
"normalized": false,
|
| 9 |
+
"rstrip": false,
|
| 10 |
+
"single_word": false,
|
| 11 |
+
"special": true
|
| 12 |
+
},
|
| 13 |
+
"151644": {
|
| 14 |
+
"content": "<|im_start|>",
|
| 15 |
+
"lstrip": false,
|
| 16 |
+
"normalized": false,
|
| 17 |
+
"rstrip": false,
|
| 18 |
+
"single_word": false,
|
| 19 |
+
"special": true
|
| 20 |
+
},
|
| 21 |
+
"151645": {
|
| 22 |
+
"content": "<|im_end|>",
|
| 23 |
+
"lstrip": false,
|
| 24 |
+
"normalized": false,
|
| 25 |
+
"rstrip": false,
|
| 26 |
+
"single_word": false,
|
| 27 |
+
"special": true
|
| 28 |
+
},
|
| 29 |
+
"151646": {
|
| 30 |
+
"content": "<|object_ref_start|>",
|
| 31 |
+
"lstrip": false,
|
| 32 |
+
"normalized": false,
|
| 33 |
+
"rstrip": false,
|
| 34 |
+
"single_word": false,
|
| 35 |
+
"special": true
|
| 36 |
+
},
|
| 37 |
+
"151647": {
|
| 38 |
+
"content": "<|object_ref_end|>",
|
| 39 |
+
"lstrip": false,
|
| 40 |
+
"normalized": false,
|
| 41 |
+
"rstrip": false,
|
| 42 |
+
"single_word": false,
|
| 43 |
+
"special": true
|
| 44 |
+
},
|
| 45 |
+
"151648": {
|
| 46 |
+
"content": "<|box_start|>",
|
| 47 |
+
"lstrip": false,
|
| 48 |
+
"normalized": false,
|
| 49 |
+
"rstrip": false,
|
| 50 |
+
"single_word": false,
|
| 51 |
+
"special": true
|
| 52 |
+
},
|
| 53 |
+
"151649": {
|
| 54 |
+
"content": "<|box_end|>",
|
| 55 |
+
"lstrip": false,
|
| 56 |
+
"normalized": false,
|
| 57 |
+
"rstrip": false,
|
| 58 |
+
"single_word": false,
|
| 59 |
+
"special": true
|
| 60 |
+
},
|
| 61 |
+
"151650": {
|
| 62 |
+
"content": "<|quad_start|>",
|
| 63 |
+
"lstrip": false,
|
| 64 |
+
"normalized": false,
|
| 65 |
+
"rstrip": false,
|
| 66 |
+
"single_word": false,
|
| 67 |
+
"special": true
|
| 68 |
+
},
|
| 69 |
+
"151651": {
|
| 70 |
+
"content": "<|quad_end|>",
|
| 71 |
+
"lstrip": false,
|
| 72 |
+
"normalized": false,
|
| 73 |
+
"rstrip": false,
|
| 74 |
+
"single_word": false,
|
| 75 |
+
"special": true
|
| 76 |
+
},
|
| 77 |
+
"151652": {
|
| 78 |
+
"content": "<|vision_start|>",
|
| 79 |
+
"lstrip": false,
|
| 80 |
+
"normalized": false,
|
| 81 |
+
"rstrip": false,
|
| 82 |
+
"single_word": false,
|
| 83 |
+
"special": true
|
| 84 |
+
},
|
| 85 |
+
"151653": {
|
| 86 |
+
"content": "<|vision_end|>",
|
| 87 |
+
"lstrip": false,
|
| 88 |
+
"normalized": false,
|
| 89 |
+
"rstrip": false,
|
| 90 |
+
"single_word": false,
|
| 91 |
+
"special": true
|
| 92 |
+
},
|
| 93 |
+
"151654": {
|
| 94 |
+
"content": "<|vision_pad|>",
|
| 95 |
+
"lstrip": false,
|
| 96 |
+
"normalized": false,
|
| 97 |
+
"rstrip": false,
|
| 98 |
+
"single_word": false,
|
| 99 |
+
"special": true
|
| 100 |
+
},
|
| 101 |
+
"151655": {
|
| 102 |
+
"content": "<|image_pad|>",
|
| 103 |
+
"lstrip": false,
|
| 104 |
+
"normalized": false,
|
| 105 |
+
"rstrip": false,
|
| 106 |
+
"single_word": false,
|
| 107 |
+
"special": true
|
| 108 |
+
},
|
| 109 |
+
"151656": {
|
| 110 |
+
"content": "<|video_pad|>",
|
| 111 |
+
"lstrip": false,
|
| 112 |
+
"normalized": false,
|
| 113 |
+
"rstrip": false,
|
| 114 |
+
"single_word": false,
|
| 115 |
+
"special": true
|
| 116 |
+
},
|
| 117 |
+
"151657": {
|
| 118 |
+
"content": "<tool_call>",
|
| 119 |
+
"lstrip": false,
|
| 120 |
+
"normalized": false,
|
| 121 |
+
"rstrip": false,
|
| 122 |
+
"single_word": false,
|
| 123 |
+
"special": false
|
| 124 |
+
},
|
| 125 |
+
"151658": {
|
| 126 |
+
"content": "</tool_call>",
|
| 127 |
+
"lstrip": false,
|
| 128 |
+
"normalized": false,
|
| 129 |
+
"rstrip": false,
|
| 130 |
+
"single_word": false,
|
| 131 |
+
"special": false
|
| 132 |
+
},
|
| 133 |
+
"151659": {
|
| 134 |
+
"content": "<|fim_prefix|>",
|
| 135 |
+
"lstrip": false,
|
| 136 |
+
"normalized": false,
|
| 137 |
+
"rstrip": false,
|
| 138 |
+
"single_word": false,
|
| 139 |
+
"special": false
|
| 140 |
+
},
|
| 141 |
+
"151660": {
|
| 142 |
+
"content": "<|fim_middle|>",
|
| 143 |
+
"lstrip": false,
|
| 144 |
+
"normalized": false,
|
| 145 |
+
"rstrip": false,
|
| 146 |
+
"single_word": false,
|
| 147 |
+
"special": false
|
| 148 |
+
},
|
| 149 |
+
"151661": {
|
| 150 |
+
"content": "<|fim_suffix|>",
|
| 151 |
+
"lstrip": false,
|
| 152 |
+
"normalized": false,
|
| 153 |
+
"rstrip": false,
|
| 154 |
+
"single_word": false,
|
| 155 |
+
"special": false
|
| 156 |
+
},
|
| 157 |
+
"151662": {
|
| 158 |
+
"content": "<|fim_pad|>",
|
| 159 |
+
"lstrip": false,
|
| 160 |
+
"normalized": false,
|
| 161 |
+
"rstrip": false,
|
| 162 |
+
"single_word": false,
|
| 163 |
+
"special": false
|
| 164 |
+
},
|
| 165 |
+
"151663": {
|
| 166 |
+
"content": "<|repo_name|>",
|
| 167 |
+
"lstrip": false,
|
| 168 |
+
"normalized": false,
|
| 169 |
+
"rstrip": false,
|
| 170 |
+
"single_word": false,
|
| 171 |
+
"special": false
|
| 172 |
+
},
|
| 173 |
+
"151664": {
|
| 174 |
+
"content": "<|file_sep|>",
|
| 175 |
+
"lstrip": false,
|
| 176 |
+
"normalized": false,
|
| 177 |
+
"rstrip": false,
|
| 178 |
+
"single_word": false,
|
| 179 |
+
"special": false
|
| 180 |
+
},
|
| 181 |
+
"151665": {
|
| 182 |
+
"content": "<tool_response>",
|
| 183 |
+
"lstrip": false,
|
| 184 |
+
"normalized": false,
|
| 185 |
+
"rstrip": false,
|
| 186 |
+
"single_word": false,
|
| 187 |
+
"special": false
|
| 188 |
+
},
|
| 189 |
+
"151666": {
|
| 190 |
+
"content": "</tool_response>",
|
| 191 |
+
"lstrip": false,
|
| 192 |
+
"normalized": false,
|
| 193 |
+
"rstrip": false,
|
| 194 |
+
"single_word": false,
|
| 195 |
+
"special": false
|
| 196 |
+
},
|
| 197 |
+
"151667": {
|
| 198 |
+
"content": "<think>",
|
| 199 |
+
"lstrip": false,
|
| 200 |
+
"normalized": false,
|
| 201 |
+
"rstrip": false,
|
| 202 |
+
"single_word": false,
|
| 203 |
+
"special": false
|
| 204 |
+
},
|
| 205 |
+
"151668": {
|
| 206 |
+
"content": "</think>",
|
| 207 |
+
"lstrip": false,
|
| 208 |
+
"normalized": false,
|
| 209 |
+
"rstrip": false,
|
| 210 |
+
"single_word": false,
|
| 211 |
+
"special": false
|
| 212 |
+
}
|
| 213 |
+
},
|
| 214 |
+
"additional_special_tokens": [
|
| 215 |
+
"<|im_start|>",
|
| 216 |
+
"<|im_end|>",
|
| 217 |
+
"<|object_ref_start|>",
|
| 218 |
+
"<|object_ref_end|>",
|
| 219 |
+
"<|box_start|>",
|
| 220 |
+
"<|box_end|>",
|
| 221 |
+
"<|quad_start|>",
|
| 222 |
+
"<|quad_end|>",
|
| 223 |
+
"<|vision_start|>",
|
| 224 |
+
"<|vision_end|>",
|
| 225 |
+
"<|vision_pad|>",
|
| 226 |
+
"<|image_pad|>",
|
| 227 |
+
"<|video_pad|>"
|
| 228 |
+
],
|
| 229 |
+
"bos_token": null,
|
| 230 |
+
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {%- if messages[0].content is string %}\n {{- messages[0].content }}\n {%- else %}\n {%- for content in messages[0].content %}\n {%- if 'text' in content %}\n {{- content.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].content is string %}\n {{- messages[0].content }}\n {%- else %}\n {%- for content in messages[0].content %}\n {%- if 'text' in content %}\n {{- content.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set image_count = namespace(value=0) %}\n{%- set video_count = namespace(value=0) %}\n{%- for message in messages %}\n {%- if message.role == \"user\" %}\n {{- '<|im_start|>' + message.role + '\\n' }}\n {%- if message.content is string %}\n {{- message.content }}\n {%- else %}\n {%- for content in message.content %}\n {%- if content.type == 'image' or 'image' in content or 'image_url' in content %}\n {%- set image_count.value = image_count.value + 1 %}\n {%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%}\n <|vision_start|><|image_pad|><|vision_end|>\n {%- elif content.type == 'video' or 'video' in content %}\n {%- set video_count.value = video_count.value + 1 %}\n {%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%}\n <|vision_start|><|video_pad|><|vision_end|>\n {%- elif 'text' in content %}\n {{- content.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role + '\\n' }}\n {%- if message.content is string %}\n {{- message.content }}\n {%- else %}\n {%- for content_item in message.content %}\n {%- if 'text' in content_item %}\n {{- content_item.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and message.content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {%- if message.content is string %}\n {{- message.content }}\n {%- else %}\n {%- for content in message.content %}\n {%- if content.type == 'image' or 'image' in content or 'image_url' in content %}\n {%- set image_count.value = image_count.value + 1 %}\n {%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%}\n <|vision_start|><|image_pad|><|vision_end|>\n {%- elif content.type == 'video' or 'video' in content %}\n {%- set video_count.value = video_count.value + 1 %}\n {%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%}\n <|vision_start|><|video_pad|><|vision_end|>\n {%- elif 'text' in content %}\n {{- content.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
|
| 231 |
+
"clean_up_tokenization_spaces": false,
|
| 232 |
+
"eos_token": "<|im_end|>",
|
| 233 |
+
"errors": "replace",
|
| 234 |
+
"model_max_length": 262144,
|
| 235 |
+
"pad_token": "<|endoftext|>",
|
| 236 |
+
"split_special_tokens": false,
|
| 237 |
+
"tokenizer_class": "Qwen2Tokenizer",
|
| 238 |
+
"unk_token": null
|
| 239 |
+
}
|
ICL/sft_model/epoch3_step1406_fp32/video_preprocessor_config.json
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"size": {
|
| 3 |
+
"longest_edge": 25165824,
|
| 4 |
+
"shortest_edge": 4096
|
| 5 |
+
},
|
| 6 |
+
"patch_size": 16,
|
| 7 |
+
"temporal_patch_size": 2,
|
| 8 |
+
"merge_size": 2,
|
| 9 |
+
"image_mean": [
|
| 10 |
+
0.5,
|
| 11 |
+
0.5,
|
| 12 |
+
0.5
|
| 13 |
+
],
|
| 14 |
+
"image_std": [
|
| 15 |
+
0.5,
|
| 16 |
+
0.5,
|
| 17 |
+
0.5
|
| 18 |
+
],
|
| 19 |
+
"processor_class": "Qwen3VLProcessor",
|
| 20 |
+
"video_processor_type": "Qwen3VLVideoProcessor"
|
| 21 |
+
}
|
ICL/sft_model/epoch3_step1406_fp32/vocab.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ICL/sft_model/zero_to_fp32.py
ADDED
|
@@ -0,0 +1,760 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright (c) Microsoft Corporation.
|
| 4 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 5 |
+
|
| 6 |
+
# DeepSpeed Team
|
| 7 |
+
|
| 8 |
+
# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
|
| 9 |
+
# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
|
| 10 |
+
# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
|
| 11 |
+
# application.
|
| 12 |
+
#
|
| 13 |
+
# example:
|
| 14 |
+
# python zero_to_fp32.py . output_dir/
|
| 15 |
+
# or
|
| 16 |
+
# python zero_to_fp32.py . output_dir/ --safe_serialization
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
import torch
|
| 20 |
+
import glob
|
| 21 |
+
import math
|
| 22 |
+
import os
|
| 23 |
+
import re
|
| 24 |
+
import gc
|
| 25 |
+
import json
|
| 26 |
+
import numpy as np
|
| 27 |
+
from tqdm import tqdm
|
| 28 |
+
from collections import OrderedDict
|
| 29 |
+
from dataclasses import dataclass
|
| 30 |
+
|
| 31 |
+
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
|
| 32 |
+
# DeepSpeed data structures it has to be available in the current python environment.
|
| 33 |
+
from deepspeed.utils import logger
|
| 34 |
+
from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
|
| 35 |
+
FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
|
| 36 |
+
FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@dataclass
|
| 40 |
+
class zero_model_state:
|
| 41 |
+
buffers: dict()
|
| 42 |
+
param_shapes: dict()
|
| 43 |
+
shared_params: list
|
| 44 |
+
ds_version: int
|
| 45 |
+
frozen_param_shapes: dict()
|
| 46 |
+
frozen_param_fragments: dict()
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
debug = 0
|
| 50 |
+
|
| 51 |
+
# load to cpu
|
| 52 |
+
device = torch.device('cpu')
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def atoi(text):
|
| 56 |
+
return int(text) if text.isdigit() else text
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def natural_keys(text):
|
| 60 |
+
'''
|
| 61 |
+
alist.sort(key=natural_keys) sorts in human order
|
| 62 |
+
http://nedbatchelder.com/blog/200712/human_sorting.html
|
| 63 |
+
(See Toothy's implementation in the comments)
|
| 64 |
+
'''
|
| 65 |
+
return [atoi(c) for c in re.split(r'(\d+)', text)]
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def get_model_state_file(checkpoint_dir, zero_stage):
|
| 69 |
+
if not os.path.isdir(checkpoint_dir):
|
| 70 |
+
raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
|
| 71 |
+
|
| 72 |
+
# there should be only one file
|
| 73 |
+
if zero_stage <= 2:
|
| 74 |
+
file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
|
| 75 |
+
elif zero_stage == 3:
|
| 76 |
+
file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
|
| 77 |
+
|
| 78 |
+
if not os.path.exists(file):
|
| 79 |
+
raise FileNotFoundError(f"can't find model states file at '{file}'")
|
| 80 |
+
|
| 81 |
+
return file
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def get_checkpoint_files(checkpoint_dir, glob_pattern):
|
| 85 |
+
# XXX: need to test that this simple glob rule works for multi-node setup too
|
| 86 |
+
ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
|
| 87 |
+
|
| 88 |
+
if len(ckpt_files) == 0:
|
| 89 |
+
raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
|
| 90 |
+
|
| 91 |
+
return ckpt_files
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def get_optim_files(checkpoint_dir):
|
| 95 |
+
return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def get_model_state_files(checkpoint_dir):
|
| 99 |
+
return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def parse_model_states(files):
|
| 103 |
+
zero_model_states = []
|
| 104 |
+
for file in files:
|
| 105 |
+
state_dict = torch.load(file, map_location=device, weights_only=False)
|
| 106 |
+
|
| 107 |
+
if BUFFER_NAMES not in state_dict:
|
| 108 |
+
raise ValueError(f"{file} is not a model state checkpoint")
|
| 109 |
+
buffer_names = state_dict[BUFFER_NAMES]
|
| 110 |
+
if debug:
|
| 111 |
+
print("Found buffers:", buffer_names)
|
| 112 |
+
|
| 113 |
+
# recover just the buffers while restoring them to fp32 if they were saved in fp16
|
| 114 |
+
buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
|
| 115 |
+
param_shapes = state_dict[PARAM_SHAPES]
|
| 116 |
+
|
| 117 |
+
# collect parameters that are included in param_shapes
|
| 118 |
+
param_names = []
|
| 119 |
+
for s in param_shapes:
|
| 120 |
+
for name in s.keys():
|
| 121 |
+
param_names.append(name)
|
| 122 |
+
|
| 123 |
+
# update with frozen parameters
|
| 124 |
+
frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
|
| 125 |
+
if frozen_param_shapes is not None:
|
| 126 |
+
if debug:
|
| 127 |
+
print(f"Found frozen_param_shapes: {frozen_param_shapes}")
|
| 128 |
+
param_names += list(frozen_param_shapes.keys())
|
| 129 |
+
|
| 130 |
+
# handle shared params
|
| 131 |
+
shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
|
| 132 |
+
|
| 133 |
+
ds_version = state_dict.get(DS_VERSION, None)
|
| 134 |
+
|
| 135 |
+
frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
|
| 136 |
+
|
| 137 |
+
z_model_state = zero_model_state(buffers=buffers,
|
| 138 |
+
param_shapes=param_shapes,
|
| 139 |
+
shared_params=shared_params,
|
| 140 |
+
ds_version=ds_version,
|
| 141 |
+
frozen_param_shapes=frozen_param_shapes,
|
| 142 |
+
frozen_param_fragments=frozen_param_fragments)
|
| 143 |
+
zero_model_states.append(z_model_state)
|
| 144 |
+
|
| 145 |
+
return zero_model_states
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def parse_optim_states(files, ds_checkpoint_dir):
|
| 149 |
+
total_files = len(files)
|
| 150 |
+
state_dicts = []
|
| 151 |
+
for f in tqdm(files, desc='Loading checkpoint shards'):
|
| 152 |
+
state_dict = torch.load(f, map_location=device, mmap=True, weights_only=False)
|
| 153 |
+
# immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
|
| 154 |
+
# and also handle the case where it was already removed by another helper script
|
| 155 |
+
state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
|
| 156 |
+
state_dicts.append(state_dict)
|
| 157 |
+
|
| 158 |
+
if ZERO_STAGE not in state_dicts[0][OPTIMIZER_STATE_DICT]:
|
| 159 |
+
raise ValueError(f"{files[0]} is not a zero checkpoint")
|
| 160 |
+
zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
|
| 161 |
+
world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
|
| 162 |
+
|
| 163 |
+
# For ZeRO-2 each param group can have different partition_count as data parallelism for expert
|
| 164 |
+
# parameters can be different from data parallelism for non-expert parameters. So we can just
|
| 165 |
+
# use the max of the partition_count to get the dp world_size.
|
| 166 |
+
|
| 167 |
+
if type(world_size) is list:
|
| 168 |
+
world_size = max(world_size)
|
| 169 |
+
|
| 170 |
+
if world_size != total_files:
|
| 171 |
+
raise ValueError(
|
| 172 |
+
f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
|
| 173 |
+
"Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
# the groups are named differently in each stage
|
| 177 |
+
if zero_stage <= 2:
|
| 178 |
+
fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
|
| 179 |
+
elif zero_stage == 3:
|
| 180 |
+
fp32_groups_key = FP32_FLAT_GROUPS
|
| 181 |
+
else:
|
| 182 |
+
raise ValueError(f"unknown zero stage {zero_stage}")
|
| 183 |
+
|
| 184 |
+
fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
|
| 185 |
+
return zero_stage, world_size, fp32_flat_groups
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
|
| 189 |
+
"""
|
| 190 |
+
Returns fp32 state_dict reconstructed from ds checkpoint
|
| 191 |
+
|
| 192 |
+
Args:
|
| 193 |
+
- ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
|
| 194 |
+
|
| 195 |
+
"""
|
| 196 |
+
print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
|
| 197 |
+
|
| 198 |
+
optim_files = get_optim_files(ds_checkpoint_dir)
|
| 199 |
+
zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
|
| 200 |
+
print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
|
| 201 |
+
|
| 202 |
+
model_files = get_model_state_files(ds_checkpoint_dir)
|
| 203 |
+
|
| 204 |
+
zero_model_states = parse_model_states(model_files)
|
| 205 |
+
print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
|
| 206 |
+
|
| 207 |
+
if zero_stage <= 2:
|
| 208 |
+
return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
|
| 209 |
+
exclude_frozen_parameters)
|
| 210 |
+
elif zero_stage == 3:
|
| 211 |
+
return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
|
| 212 |
+
exclude_frozen_parameters)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def _zero2_merge_frozen_params(state_dict, zero_model_states):
|
| 216 |
+
if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
|
| 217 |
+
return
|
| 218 |
+
|
| 219 |
+
frozen_param_shapes = zero_model_states[0].frozen_param_shapes
|
| 220 |
+
frozen_param_fragments = zero_model_states[0].frozen_param_fragments
|
| 221 |
+
|
| 222 |
+
if debug:
|
| 223 |
+
num_elem = sum(s.numel() for s in frozen_param_shapes.values())
|
| 224 |
+
print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
|
| 225 |
+
|
| 226 |
+
wanted_params = len(frozen_param_shapes)
|
| 227 |
+
wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
|
| 228 |
+
avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
|
| 229 |
+
print(f'Frozen params: Have {avail_numel} numels to process.')
|
| 230 |
+
print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
|
| 231 |
+
|
| 232 |
+
total_params = 0
|
| 233 |
+
total_numel = 0
|
| 234 |
+
for name, shape in frozen_param_shapes.items():
|
| 235 |
+
total_params += 1
|
| 236 |
+
unpartitioned_numel = shape.numel()
|
| 237 |
+
total_numel += unpartitioned_numel
|
| 238 |
+
|
| 239 |
+
state_dict[name] = frozen_param_fragments[name]
|
| 240 |
+
|
| 241 |
+
if debug:
|
| 242 |
+
print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
|
| 243 |
+
|
| 244 |
+
print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def _has_callable(obj, fn):
|
| 248 |
+
attr = getattr(obj, fn, None)
|
| 249 |
+
return callable(attr)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
|
| 253 |
+
param_shapes = zero_model_states[0].param_shapes
|
| 254 |
+
|
| 255 |
+
# Reconstruction protocol:
|
| 256 |
+
#
|
| 257 |
+
# XXX: document this
|
| 258 |
+
|
| 259 |
+
if debug:
|
| 260 |
+
for i in range(world_size):
|
| 261 |
+
for j in range(len(fp32_flat_groups[0])):
|
| 262 |
+
print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
|
| 263 |
+
|
| 264 |
+
# XXX: memory usage doubles here (zero2)
|
| 265 |
+
num_param_groups = len(fp32_flat_groups[0])
|
| 266 |
+
merged_single_partition_of_fp32_groups = []
|
| 267 |
+
for i in range(num_param_groups):
|
| 268 |
+
merged_partitions = [sd[i] for sd in fp32_flat_groups]
|
| 269 |
+
full_single_fp32_vector = torch.cat(merged_partitions, 0)
|
| 270 |
+
merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
|
| 271 |
+
avail_numel = sum(
|
| 272 |
+
[full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
|
| 273 |
+
|
| 274 |
+
if debug:
|
| 275 |
+
wanted_params = sum([len(shapes) for shapes in param_shapes])
|
| 276 |
+
wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
|
| 277 |
+
# not asserting if there is a mismatch due to possible padding
|
| 278 |
+
print(f"Have {avail_numel} numels to process.")
|
| 279 |
+
print(f"Need {wanted_numel} numels in {wanted_params} params.")
|
| 280 |
+
|
| 281 |
+
# params
|
| 282 |
+
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
|
| 283 |
+
# out-of-core computing solution
|
| 284 |
+
total_numel = 0
|
| 285 |
+
total_params = 0
|
| 286 |
+
for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
|
| 287 |
+
offset = 0
|
| 288 |
+
avail_numel = full_single_fp32_vector.numel()
|
| 289 |
+
for name, shape in shapes.items():
|
| 290 |
+
|
| 291 |
+
unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
|
| 292 |
+
total_numel += unpartitioned_numel
|
| 293 |
+
total_params += 1
|
| 294 |
+
|
| 295 |
+
if debug:
|
| 296 |
+
print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
|
| 297 |
+
state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
|
| 298 |
+
offset += unpartitioned_numel
|
| 299 |
+
|
| 300 |
+
# Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
|
| 301 |
+
# avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
|
| 302 |
+
# paddings performed in the code it's almost impossible to predict the exact numbers w/o the
|
| 303 |
+
# live optimizer object, so we are checking that the numbers are within the right range
|
| 304 |
+
align_to = 2 * world_size
|
| 305 |
+
|
| 306 |
+
def zero2_align(x):
|
| 307 |
+
return align_to * math.ceil(x / align_to)
|
| 308 |
+
|
| 309 |
+
if debug:
|
| 310 |
+
print(f"original offset={offset}, avail_numel={avail_numel}")
|
| 311 |
+
|
| 312 |
+
offset = zero2_align(offset)
|
| 313 |
+
avail_numel = zero2_align(avail_numel)
|
| 314 |
+
|
| 315 |
+
if debug:
|
| 316 |
+
print(f"aligned offset={offset}, avail_numel={avail_numel}")
|
| 317 |
+
|
| 318 |
+
# Sanity check
|
| 319 |
+
if offset != avail_numel:
|
| 320 |
+
raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
|
| 321 |
+
|
| 322 |
+
print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
|
| 326 |
+
exclude_frozen_parameters):
|
| 327 |
+
state_dict = OrderedDict()
|
| 328 |
+
|
| 329 |
+
# buffers
|
| 330 |
+
buffers = zero_model_states[0].buffers
|
| 331 |
+
state_dict.update(buffers)
|
| 332 |
+
if debug:
|
| 333 |
+
print(f"added {len(buffers)} buffers")
|
| 334 |
+
|
| 335 |
+
if not exclude_frozen_parameters:
|
| 336 |
+
_zero2_merge_frozen_params(state_dict, zero_model_states)
|
| 337 |
+
|
| 338 |
+
_zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
|
| 339 |
+
|
| 340 |
+
# recover shared parameters
|
| 341 |
+
for pair in zero_model_states[0].shared_params:
|
| 342 |
+
if pair[1] in state_dict:
|
| 343 |
+
state_dict[pair[0]] = state_dict[pair[1]]
|
| 344 |
+
|
| 345 |
+
return state_dict
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def zero3_partitioned_param_info(unpartitioned_numel, world_size):
|
| 349 |
+
remainder = unpartitioned_numel % world_size
|
| 350 |
+
padding_numel = (world_size - remainder) if remainder else 0
|
| 351 |
+
partitioned_numel = math.ceil(unpartitioned_numel / world_size)
|
| 352 |
+
return partitioned_numel, padding_numel
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
|
| 356 |
+
if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
|
| 357 |
+
return
|
| 358 |
+
|
| 359 |
+
if debug:
|
| 360 |
+
for i in range(world_size):
|
| 361 |
+
num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
|
| 362 |
+
print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
|
| 363 |
+
|
| 364 |
+
frozen_param_shapes = zero_model_states[0].frozen_param_shapes
|
| 365 |
+
wanted_params = len(frozen_param_shapes)
|
| 366 |
+
wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
|
| 367 |
+
avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
|
| 368 |
+
print(f'Frozen params: Have {avail_numel} numels to process.')
|
| 369 |
+
print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
|
| 370 |
+
|
| 371 |
+
total_params = 0
|
| 372 |
+
total_numel = 0
|
| 373 |
+
for name, shape in zero_model_states[0].frozen_param_shapes.items():
|
| 374 |
+
total_params += 1
|
| 375 |
+
unpartitioned_numel = shape.numel()
|
| 376 |
+
total_numel += unpartitioned_numel
|
| 377 |
+
|
| 378 |
+
param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
|
| 379 |
+
state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
|
| 380 |
+
|
| 381 |
+
partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
|
| 382 |
+
|
| 383 |
+
if debug:
|
| 384 |
+
print(
|
| 385 |
+
f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
class GatheredTensor:
|
| 392 |
+
"""
|
| 393 |
+
A pseudo tensor that collects partitioned weights.
|
| 394 |
+
It is more memory efficient when there are multiple groups.
|
| 395 |
+
"""
|
| 396 |
+
|
| 397 |
+
def __init__(self, flat_groups, flat_groups_offset, offset, partitioned_numel, shape):
|
| 398 |
+
self.flat_groups = flat_groups
|
| 399 |
+
self.flat_groups_offset = flat_groups_offset
|
| 400 |
+
self.offset = offset
|
| 401 |
+
self.partitioned_numel = partitioned_numel
|
| 402 |
+
self.shape = shape
|
| 403 |
+
self.dtype = self.flat_groups[0][0].dtype
|
| 404 |
+
|
| 405 |
+
def contiguous(self):
|
| 406 |
+
"""
|
| 407 |
+
Merge partitioned weights from flat_groups into a single tensor.
|
| 408 |
+
"""
|
| 409 |
+
end_idx = self.offset + self.partitioned_numel
|
| 410 |
+
world_size = len(self.flat_groups)
|
| 411 |
+
pad_flat_param_chunks = []
|
| 412 |
+
|
| 413 |
+
for rank_i in range(world_size):
|
| 414 |
+
# for each rank, we need to collect weights from related group/groups
|
| 415 |
+
flat_groups_at_rank_i = self.flat_groups[rank_i]
|
| 416 |
+
start_group_id = None
|
| 417 |
+
end_group_id = None
|
| 418 |
+
for group_id in range(len(self.flat_groups_offset)):
|
| 419 |
+
if self.flat_groups_offset[group_id] <= self.offset < self.flat_groups_offset[group_id + 1]:
|
| 420 |
+
start_group_id = group_id
|
| 421 |
+
if self.flat_groups_offset[group_id] < end_idx <= self.flat_groups_offset[group_id + 1]:
|
| 422 |
+
end_group_id = group_id
|
| 423 |
+
break
|
| 424 |
+
# collect weights from related group/groups
|
| 425 |
+
for group_id in range(start_group_id, end_group_id + 1):
|
| 426 |
+
flat_tensor = flat_groups_at_rank_i[group_id]
|
| 427 |
+
start_offset = self.offset - self.flat_groups_offset[group_id]
|
| 428 |
+
end_offset = min(end_idx, self.flat_groups_offset[group_id + 1]) - self.flat_groups_offset[group_id]
|
| 429 |
+
pad_flat_param_chunks.append(flat_tensor[start_offset:end_offset])
|
| 430 |
+
|
| 431 |
+
# collect weights from all ranks
|
| 432 |
+
pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0)
|
| 433 |
+
param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous()
|
| 434 |
+
return param
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
|
| 438 |
+
param_shapes = zero_model_states[0].param_shapes
|
| 439 |
+
avail_numel = sum([flat_group.numel() for flat_group in fp32_flat_groups[0]]) * world_size
|
| 440 |
+
|
| 441 |
+
# Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
|
| 442 |
+
# param, re-consolidating each param, while dealing with padding if any
|
| 443 |
+
|
| 444 |
+
# merge list of dicts, preserving order
|
| 445 |
+
param_shapes = {k: v for d in param_shapes for k, v in d.items()}
|
| 446 |
+
|
| 447 |
+
if debug:
|
| 448 |
+
for i in range(world_size):
|
| 449 |
+
print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
|
| 450 |
+
|
| 451 |
+
wanted_params = len(param_shapes)
|
| 452 |
+
wanted_numel = sum(shape.numel() for shape in param_shapes.values())
|
| 453 |
+
# not asserting if there is a mismatch due to possible padding
|
| 454 |
+
avail_numel = fp32_flat_groups[0].numel() * world_size
|
| 455 |
+
print(f"Trainable params: Have {avail_numel} numels to process.")
|
| 456 |
+
print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
|
| 457 |
+
|
| 458 |
+
# params
|
| 459 |
+
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
|
| 460 |
+
# out-of-core computing solution
|
| 461 |
+
offset = 0
|
| 462 |
+
total_numel = 0
|
| 463 |
+
total_params = 0
|
| 464 |
+
flat_groups_offset = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]]))
|
| 465 |
+
for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'):
|
| 466 |
+
unpartitioned_numel = shape.numel()
|
| 467 |
+
total_numel += unpartitioned_numel
|
| 468 |
+
total_params += 1
|
| 469 |
+
partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
|
| 470 |
+
|
| 471 |
+
if debug:
|
| 472 |
+
print(
|
| 473 |
+
f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
|
| 474 |
+
)
|
| 475 |
+
|
| 476 |
+
# memory efficient tensor
|
| 477 |
+
tensor = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape)
|
| 478 |
+
state_dict[name] = tensor
|
| 479 |
+
offset += partitioned_numel
|
| 480 |
+
|
| 481 |
+
offset *= world_size
|
| 482 |
+
|
| 483 |
+
# Sanity check
|
| 484 |
+
if offset != avail_numel:
|
| 485 |
+
raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
|
| 486 |
+
|
| 487 |
+
print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
|
| 491 |
+
exclude_frozen_parameters):
|
| 492 |
+
state_dict = OrderedDict()
|
| 493 |
+
|
| 494 |
+
# buffers
|
| 495 |
+
buffers = zero_model_states[0].buffers
|
| 496 |
+
state_dict.update(buffers)
|
| 497 |
+
if debug:
|
| 498 |
+
print(f"added {len(buffers)} buffers")
|
| 499 |
+
|
| 500 |
+
if not exclude_frozen_parameters:
|
| 501 |
+
_zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
|
| 502 |
+
|
| 503 |
+
_zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
|
| 504 |
+
|
| 505 |
+
# recover shared parameters
|
| 506 |
+
for pair in zero_model_states[0].shared_params:
|
| 507 |
+
if pair[1] in state_dict:
|
| 508 |
+
state_dict[pair[0]] = state_dict[pair[1]]
|
| 509 |
+
|
| 510 |
+
return state_dict
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
def to_torch_tensor(state_dict, return_empty_tensor=False):
|
| 514 |
+
"""
|
| 515 |
+
Convert state_dict of GatheredTensor to torch tensor
|
| 516 |
+
"""
|
| 517 |
+
torch_state_dict = {}
|
| 518 |
+
converted_tensors = {}
|
| 519 |
+
for name, tensor in state_dict.items():
|
| 520 |
+
tensor_id = id(tensor)
|
| 521 |
+
if tensor_id in converted_tensors: # shared tensors
|
| 522 |
+
shared_tensor = torch_state_dict[converted_tensors[tensor_id]]
|
| 523 |
+
torch_state_dict[name] = shared_tensor
|
| 524 |
+
else:
|
| 525 |
+
converted_tensors[tensor_id] = name
|
| 526 |
+
if return_empty_tensor:
|
| 527 |
+
torch_state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
|
| 528 |
+
else:
|
| 529 |
+
torch_state_dict[name] = tensor.contiguous()
|
| 530 |
+
return torch_state_dict
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
|
| 534 |
+
tag=None,
|
| 535 |
+
exclude_frozen_parameters=False,
|
| 536 |
+
lazy_mode=False):
|
| 537 |
+
"""
|
| 538 |
+
Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
|
| 539 |
+
``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
|
| 540 |
+
via a model hub.
|
| 541 |
+
|
| 542 |
+
Args:
|
| 543 |
+
- ``checkpoint_dir``: path to the desired checkpoint folder
|
| 544 |
+
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
|
| 545 |
+
- ``exclude_frozen_parameters``: exclude frozen parameters
|
| 546 |
+
- ``lazy_mode``: get state_dict in lazy mode. It returns a dict of pesduo tensor instead of torch tensor, which is more memory efficient.
|
| 547 |
+
Convert the pesduo tensor to torch tensor by ``.contiguous()``
|
| 548 |
+
|
| 549 |
+
Returns:
|
| 550 |
+
- pytorch ``state_dict``
|
| 551 |
+
|
| 552 |
+
A typical usage might be ::
|
| 553 |
+
|
| 554 |
+
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
|
| 555 |
+
# do the training and checkpoint saving
|
| 556 |
+
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
|
| 557 |
+
model = model.cpu() # move to cpu
|
| 558 |
+
model.load_state_dict(state_dict)
|
| 559 |
+
# submit to model hub or save the model to share with others
|
| 560 |
+
|
| 561 |
+
In this example the ``model`` will no longer be usable in the deepspeed context of the same
|
| 562 |
+
application. i.e. you will need to re-initialize the deepspeed engine, since
|
| 563 |
+
``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
|
| 564 |
+
|
| 565 |
+
If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
|
| 566 |
+
|
| 567 |
+
Note: the above usage may not work if your application doesn't have sufficient free CPU memory.
|
| 568 |
+
You may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
|
| 569 |
+
the checkpoint. Or you can load state_dict in lazy mode ::
|
| 570 |
+
|
| 571 |
+
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
|
| 572 |
+
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=True) # not on cpu
|
| 573 |
+
for name, lazy_tensor in state_dict.item():
|
| 574 |
+
tensor = lazy_tensor.contiguous() # to cpu
|
| 575 |
+
print(name, tensor)
|
| 576 |
+
# del tensor to release memory if it no longer in use
|
| 577 |
+
"""
|
| 578 |
+
if tag is None:
|
| 579 |
+
latest_path = os.path.join(checkpoint_dir, 'latest')
|
| 580 |
+
if os.path.isfile(latest_path):
|
| 581 |
+
with open(latest_path, 'r') as fd:
|
| 582 |
+
tag = fd.read().strip()
|
| 583 |
+
else:
|
| 584 |
+
raise ValueError(f"Unable to find 'latest' file at {latest_path}")
|
| 585 |
+
|
| 586 |
+
ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
|
| 587 |
+
|
| 588 |
+
if not os.path.isdir(ds_checkpoint_dir):
|
| 589 |
+
raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
|
| 590 |
+
|
| 591 |
+
state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
|
| 592 |
+
if lazy_mode:
|
| 593 |
+
return state_dict
|
| 594 |
+
else:
|
| 595 |
+
return to_torch_tensor(state_dict)
|
| 596 |
+
|
| 597 |
+
|
| 598 |
+
def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
|
| 599 |
+
output_dir,
|
| 600 |
+
max_shard_size="5GB",
|
| 601 |
+
safe_serialization=False,
|
| 602 |
+
tag=None,
|
| 603 |
+
exclude_frozen_parameters=False):
|
| 604 |
+
"""
|
| 605 |
+
Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
|
| 606 |
+
loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
|
| 607 |
+
|
| 608 |
+
Args:
|
| 609 |
+
- ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
|
| 610 |
+
- ``output_dir``: directory to the pytorch fp32 state_dict output files
|
| 611 |
+
- ``max_shard_size``: the maximum size for a checkpoint before being sharded, default value is 5GB
|
| 612 |
+
- ``safe_serialization``: whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
|
| 613 |
+
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
|
| 614 |
+
- ``exclude_frozen_parameters``: exclude frozen parameters
|
| 615 |
+
"""
|
| 616 |
+
|
| 617 |
+
# Dependency pre-check
|
| 618 |
+
if safe_serialization:
|
| 619 |
+
try:
|
| 620 |
+
from safetensors.torch import save_file
|
| 621 |
+
except ImportError:
|
| 622 |
+
print('If you want to use `safe_serialization`, please `pip install safetensors`')
|
| 623 |
+
raise
|
| 624 |
+
if max_shard_size is not None:
|
| 625 |
+
try:
|
| 626 |
+
from huggingface_hub import split_torch_state_dict_into_shards
|
| 627 |
+
except ImportError:
|
| 628 |
+
print('If you want to use `max_shard_size`, please `pip install huggingface_hub`')
|
| 629 |
+
raise
|
| 630 |
+
|
| 631 |
+
# Convert zero checkpoint to state_dict
|
| 632 |
+
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
|
| 633 |
+
tag,
|
| 634 |
+
exclude_frozen_parameters,
|
| 635 |
+
lazy_mode=True)
|
| 636 |
+
|
| 637 |
+
# Shard the model if it is too big.
|
| 638 |
+
weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
|
| 639 |
+
if max_shard_size is not None:
|
| 640 |
+
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
|
| 641 |
+
# an memory-efficient approach for sharding
|
| 642 |
+
empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True)
|
| 643 |
+
state_dict_split = split_torch_state_dict_into_shards(empty_state_dict,
|
| 644 |
+
filename_pattern=filename_pattern,
|
| 645 |
+
max_shard_size=max_shard_size)
|
| 646 |
+
else:
|
| 647 |
+
from collections import namedtuple
|
| 648 |
+
StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"])
|
| 649 |
+
state_dict_split = StateDictSplit(is_sharded=False,
|
| 650 |
+
filename_to_tensors={weights_name: list(state_dict.keys())})
|
| 651 |
+
|
| 652 |
+
# Save the model by shard
|
| 653 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 654 |
+
filename_to_tensors = state_dict_split.filename_to_tensors.items()
|
| 655 |
+
for shard_file, tensors in tqdm(filename_to_tensors, desc="Saving checkpoint shards"):
|
| 656 |
+
shard_state_dict = {tensor_name: state_dict[tensor_name] for tensor_name in tensors}
|
| 657 |
+
shard_state_dict = to_torch_tensor(shard_state_dict)
|
| 658 |
+
output_path = os.path.join(output_dir, shard_file)
|
| 659 |
+
if safe_serialization:
|
| 660 |
+
save_file(shard_state_dict, output_path, metadata={"format": "pt"})
|
| 661 |
+
else:
|
| 662 |
+
torch.save(shard_state_dict, output_path)
|
| 663 |
+
# release the memory of current shard
|
| 664 |
+
for tensor_name in list(shard_state_dict.keys()):
|
| 665 |
+
del state_dict[tensor_name]
|
| 666 |
+
del shard_state_dict[tensor_name]
|
| 667 |
+
del shard_state_dict
|
| 668 |
+
gc.collect()
|
| 669 |
+
|
| 670 |
+
# Save index if sharded
|
| 671 |
+
if state_dict_split.is_sharded:
|
| 672 |
+
index = {
|
| 673 |
+
"metadata": state_dict_split.metadata,
|
| 674 |
+
"weight_map": state_dict_split.tensor_to_filename,
|
| 675 |
+
}
|
| 676 |
+
save_index_file = "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json"
|
| 677 |
+
save_index_file = os.path.join(output_dir, save_index_file)
|
| 678 |
+
with open(save_index_file, "w", encoding="utf-8") as f:
|
| 679 |
+
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
| 680 |
+
f.write(content)
|
| 681 |
+
|
| 682 |
+
|
| 683 |
+
def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
|
| 684 |
+
"""
|
| 685 |
+
1. Put the provided model to cpu
|
| 686 |
+
2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
|
| 687 |
+
3. Load it into the provided model
|
| 688 |
+
|
| 689 |
+
Args:
|
| 690 |
+
- ``model``: the model object to update
|
| 691 |
+
- ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
|
| 692 |
+
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
|
| 693 |
+
|
| 694 |
+
Returns:
|
| 695 |
+
- ``model`: modified model
|
| 696 |
+
|
| 697 |
+
Make sure you have plenty of CPU memory available before you call this function. If you don't
|
| 698 |
+
have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
|
| 699 |
+
conveniently placed for you in the checkpoint folder.
|
| 700 |
+
|
| 701 |
+
A typical usage might be ::
|
| 702 |
+
|
| 703 |
+
from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
|
| 704 |
+
model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
|
| 705 |
+
# submit to model hub or save the model to share with others
|
| 706 |
+
|
| 707 |
+
Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
|
| 708 |
+
of the same application. i.e. you will need to re-initialize the deepspeed engine, since
|
| 709 |
+
``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
|
| 710 |
+
|
| 711 |
+
"""
|
| 712 |
+
logger.info("Extracting fp32 weights")
|
| 713 |
+
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
|
| 714 |
+
|
| 715 |
+
logger.info("Overwriting model with fp32 weights")
|
| 716 |
+
model = model.cpu()
|
| 717 |
+
model.load_state_dict(state_dict, strict=False)
|
| 718 |
+
|
| 719 |
+
return model
|
| 720 |
+
|
| 721 |
+
|
| 722 |
+
if __name__ == "__main__":
|
| 723 |
+
parser = argparse.ArgumentParser()
|
| 724 |
+
parser.add_argument("checkpoint_dir",
|
| 725 |
+
type=str,
|
| 726 |
+
help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
|
| 727 |
+
parser.add_argument("output_dir",
|
| 728 |
+
type=str,
|
| 729 |
+
help="directory to the pytorch fp32 state_dict output files"
|
| 730 |
+
"(e.g. path/checkpoint-12-output/)")
|
| 731 |
+
parser.add_argument(
|
| 732 |
+
"--max_shard_size",
|
| 733 |
+
type=str,
|
| 734 |
+
default="5GB",
|
| 735 |
+
help="The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size"
|
| 736 |
+
"lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`"
|
| 737 |
+
"We default it to 5GB in order for models to be able to run easily on free-tier google colab instances"
|
| 738 |
+
"without CPU OOM issues.")
|
| 739 |
+
parser.add_argument(
|
| 740 |
+
"--safe_serialization",
|
| 741 |
+
default=False,
|
| 742 |
+
action='store_true',
|
| 743 |
+
help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).")
|
| 744 |
+
parser.add_argument("-t",
|
| 745 |
+
"--tag",
|
| 746 |
+
type=str,
|
| 747 |
+
default=None,
|
| 748 |
+
help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
|
| 749 |
+
parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
|
| 750 |
+
parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
|
| 751 |
+
args = parser.parse_args()
|
| 752 |
+
|
| 753 |
+
debug = args.debug
|
| 754 |
+
|
| 755 |
+
convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
|
| 756 |
+
args.output_dir,
|
| 757 |
+
max_shard_size=args.max_shard_size,
|
| 758 |
+
safe_serialization=args.safe_serialization,
|
| 759 |
+
tag=args.tag,
|
| 760 |
+
exclude_frozen_parameters=args.exclude_frozen_parameters)
|
RL_dataset/.gitattributes
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bin.* filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.mat filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.hdf5 filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*.tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.db* filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.ark* filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
**/*ckpt*data* filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
**/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
**/*ckpt*.index filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
*.bmp filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
*.gif filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
*.webp filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
*.mp3 filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
*.wav filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
*.wma filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
*.aac filter=lfs diff=lfs merge=lfs -text
|
| 53 |
+
*.ogg filter=lfs diff=lfs merge=lfs -text
|
| 54 |
+
*.m4a filter=lfs diff=lfs merge=lfs -text
|
| 55 |
+
*.m3u8 filter=lfs diff=lfs merge=lfs -text
|
| 56 |
+
*.amr filter=lfs diff=lfs merge=lfs -text
|
| 57 |
+
*.audio filter=lfs diff=lfs merge=lfs -text
|
| 58 |
+
*.avi filter=lfs diff=lfs merge=lfs -text
|
| 59 |
+
*.flv filter=lfs diff=lfs merge=lfs -text
|
| 60 |
+
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 61 |
+
*.mpg filter=lfs diff=lfs merge=lfs -text
|
| 62 |
+
*.asf filter=lfs diff=lfs merge=lfs -text
|
| 63 |
+
*.mov filter=lfs diff=lfs merge=lfs -text
|
| 64 |
+
*.mpeg filter=lfs diff=lfs merge=lfs -text
|
| 65 |
+
*.3gp filter=lfs diff=lfs merge=lfs -text
|
| 66 |
+
*.wmv filter=lfs diff=lfs merge=lfs -text
|
| 67 |
+
*.rmvb filter=lfs diff=lfs merge=lfs -text
|
| 68 |
+
*.rm filter=lfs diff=lfs merge=lfs -text
|
| 69 |
+
*.ts filter=lfs diff=lfs merge=lfs -text
|
| 70 |
+
*.mkv filter=lfs diff=lfs merge=lfs -text
|
| 71 |
+
*.flash filter=lfs diff=lfs merge=lfs -text
|
| 72 |
+
*.vob filter=lfs diff=lfs merge=lfs -text
|
| 73 |
+
*.pdf filter=lfs diff=lfs merge=lfs -text
|
| 74 |
+
*.ost filter=lfs diff=lfs merge=lfs -text
|
| 75 |
+
*.pst filter=lfs diff=lfs merge=lfs -text
|
| 76 |
+
*.doc filter=lfs diff=lfs merge=lfs -text
|
| 77 |
+
*.docx filter=lfs diff=lfs merge=lfs -text
|
| 78 |
+
*.txt filter=lfs diff=lfs merge=lfs -text
|
| 79 |
+
*.ppt filter=lfs diff=lfs merge=lfs -text
|
| 80 |
+
*.pptx filter=lfs diff=lfs merge=lfs -text
|
| 81 |
+
*.xls filter=lfs diff=lfs merge=lfs -text
|
| 82 |
+
*.xlsx filter=lfs diff=lfs merge=lfs -text
|
| 83 |
+
*.vsd filter=lfs diff=lfs merge=lfs -text
|
| 84 |
+
*.vsdx filter=lfs diff=lfs merge=lfs -text
|
| 85 |
+
*.jsonl filter=lfs diff=lfs merge=lfs -text
|
| 86 |
+
*.json filter=lfs diff=lfs merge=lfs -text
|
| 87 |
+
dataset_infos.json ignore
|
| 88 |
+
*.csv filter=lfs diff=lfs merge=lfs -text
|
| 89 |
+
*.tsv filter=lfs diff=lfs merge=lfs -text
|
RL_dataset/.msc
ADDED
|
Binary file (546 Bytes). View file
|
|
|
RL_dataset/.mv
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
master
|
RL_dataset/INFOSEEK_DOWNLOAD.md
ADDED
|
@@ -0,0 +1,337 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# InfoSeek Data Download
|
| 2 |
+
|
| 3 |
+
This document collects ready-to-run scripts for downloading the InfoSeek dataset into:
|
| 4 |
+
|
| 5 |
+
`/workspace/xiaobin/RL_dataset/data`
|
| 6 |
+
|
| 7 |
+
It covers:
|
| 8 |
+
|
| 9 |
+
- InfoSeek annotations
|
| 10 |
+
- InfoSeek KB mapping files
|
| 11 |
+
- InfoSeek human set
|
| 12 |
+
- Wiki6M text files
|
| 13 |
+
- OVEN image snapshot on Hugging Face
|
| 14 |
+
- OVEN original-source image download workflow
|
| 15 |
+
|
| 16 |
+
InfoSeek images are derived from OVEN, so image download is handled through the OVEN release pipeline.
|
| 17 |
+
|
| 18 |
+
## 1. Recommended Directory Layout
|
| 19 |
+
|
| 20 |
+
```bash
|
| 21 |
+
mkdir -p /workspace/xiaobin/RL_dataset/data/infoseek
|
| 22 |
+
mkdir -p /workspace/xiaobin/RL_dataset/data/oven_hf
|
| 23 |
+
mkdir -p /workspace/xiaobin/RL_dataset/data/oven_source
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
Suggested usage:
|
| 27 |
+
|
| 28 |
+
- `/workspace/xiaobin/RL_dataset/data/infoseek`: InfoSeek jsonl files
|
| 29 |
+
- `/workspace/xiaobin/RL_dataset/data/oven_hf`: Hugging Face image snapshot files
|
| 30 |
+
- `/workspace/xiaobin/RL_dataset/data/oven_source`: upstream OVEN repo for original-source image download
|
| 31 |
+
|
| 32 |
+
## 2. Proxy Workaround
|
| 33 |
+
|
| 34 |
+
If your shell is configured with an invalid local proxy such as `127.0.0.1:7890`, use one of these patterns.
|
| 35 |
+
|
| 36 |
+
Temporarily disable proxy for a single command:
|
| 37 |
+
|
| 38 |
+
```bash
|
| 39 |
+
env -u http_proxy -u https_proxy -u HTTP_PROXY -u HTTPS_PROXY wget -c URL
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
Or disable proxy for the current shell session:
|
| 43 |
+
|
| 44 |
+
```bash
|
| 45 |
+
unset http_proxy https_proxy HTTP_PROXY HTTPS_PROXY
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
## 3. Download All InfoSeek Text Data With `wget`
|
| 49 |
+
|
| 50 |
+
This is the simplest full download for the released InfoSeek jsonl files.
|
| 51 |
+
|
| 52 |
+
```bash
|
| 53 |
+
#!/usr/bin/env bash
|
| 54 |
+
set -euo pipefail
|
| 55 |
+
|
| 56 |
+
TARGET_DIR="/workspace/xiaobin/RL_dataset/data/infoseek"
|
| 57 |
+
mkdir -p "${TARGET_DIR}"
|
| 58 |
+
cd "${TARGET_DIR}"
|
| 59 |
+
|
| 60 |
+
wget -c http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_train.jsonl
|
| 61 |
+
wget -c http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_val.jsonl
|
| 62 |
+
wget -c http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_test.jsonl
|
| 63 |
+
wget -c http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_train_withkb.jsonl
|
| 64 |
+
wget -c http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_val_withkb.jsonl
|
| 65 |
+
wget -c http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_human.jsonl
|
| 66 |
+
wget -c http://storage.googleapis.com/gresearch/open-vision-language/Wiki6M_ver_1_0.jsonl.gz
|
| 67 |
+
wget -c http://storage.googleapis.com/gresearch/open-vision-language/Wiki6M_ver_1_0_title_only.jsonl
|
| 68 |
+
|
| 69 |
+
ls -lh "${TARGET_DIR}"
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
## 4. Download All InfoSeek Text Data With `curl`
|
| 73 |
+
|
| 74 |
+
Use this if `wget` is not available.
|
| 75 |
+
|
| 76 |
+
```bash
|
| 77 |
+
#!/usr/bin/env bash
|
| 78 |
+
set -euo pipefail
|
| 79 |
+
|
| 80 |
+
TARGET_DIR="/workspace/xiaobin/RL_dataset/data/infoseek"
|
| 81 |
+
mkdir -p "${TARGET_DIR}"
|
| 82 |
+
cd "${TARGET_DIR}"
|
| 83 |
+
|
| 84 |
+
curl -L -O http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_train.jsonl
|
| 85 |
+
curl -L -O http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_val.jsonl
|
| 86 |
+
curl -L -O http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_test.jsonl
|
| 87 |
+
curl -L -O http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_train_withkb.jsonl
|
| 88 |
+
curl -L -O http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_val_withkb.jsonl
|
| 89 |
+
curl -L -O http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_human.jsonl
|
| 90 |
+
curl -L -O http://storage.googleapis.com/gresearch/open-vision-language/Wiki6M_ver_1_0.jsonl.gz
|
| 91 |
+
curl -L -O http://storage.googleapis.com/gresearch/open-vision-language/Wiki6M_ver_1_0_title_only.jsonl
|
| 92 |
+
|
| 93 |
+
ls -lh "${TARGET_DIR}"
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
+
## 5. Download Only Core InfoSeek Splits
|
| 97 |
+
|
| 98 |
+
If you only need the standard train/val/test annotations:
|
| 99 |
+
|
| 100 |
+
```bash
|
| 101 |
+
#!/usr/bin/env bash
|
| 102 |
+
set -euo pipefail
|
| 103 |
+
|
| 104 |
+
TARGET_DIR="/workspace/xiaobin/RL_dataset/data/infoseek"
|
| 105 |
+
mkdir -p "${TARGET_DIR}"
|
| 106 |
+
cd "${TARGET_DIR}"
|
| 107 |
+
|
| 108 |
+
wget -c http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_train.jsonl
|
| 109 |
+
wget -c http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_val.jsonl
|
| 110 |
+
wget -c http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_test.jsonl
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
## 6. Download Only KB Mapping Files
|
| 114 |
+
|
| 115 |
+
```bash
|
| 116 |
+
#!/usr/bin/env bash
|
| 117 |
+
set -euo pipefail
|
| 118 |
+
|
| 119 |
+
TARGET_DIR="/workspace/xiaobin/RL_dataset/data/infoseek"
|
| 120 |
+
mkdir -p "${TARGET_DIR}"
|
| 121 |
+
cd "${TARGET_DIR}"
|
| 122 |
+
|
| 123 |
+
wget -c http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_train_withkb.jsonl
|
| 124 |
+
wget -c http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_val_withkb.jsonl
|
| 125 |
+
```
|
| 126 |
+
|
| 127 |
+
## 7. Download Only Human Eval Set
|
| 128 |
+
|
| 129 |
+
```bash
|
| 130 |
+
#!/usr/bin/env bash
|
| 131 |
+
set -euo pipefail
|
| 132 |
+
|
| 133 |
+
TARGET_DIR="/workspace/xiaobin/RL_dataset/data/infoseek"
|
| 134 |
+
mkdir -p "${TARGET_DIR}"
|
| 135 |
+
cd "${TARGET_DIR}"
|
| 136 |
+
|
| 137 |
+
wget -c http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_human.jsonl
|
| 138 |
+
```
|
| 139 |
+
|
| 140 |
+
## 8. Download Only Wiki6M Files
|
| 141 |
+
|
| 142 |
+
```bash
|
| 143 |
+
#!/usr/bin/env bash
|
| 144 |
+
set -euo pipefail
|
| 145 |
+
|
| 146 |
+
TARGET_DIR="/workspace/xiaobin/RL_dataset/data/infoseek"
|
| 147 |
+
mkdir -p "${TARGET_DIR}"
|
| 148 |
+
cd "${TARGET_DIR}"
|
| 149 |
+
|
| 150 |
+
wget -c http://storage.googleapis.com/gresearch/open-vision-language/Wiki6M_ver_1_0.jsonl.gz
|
| 151 |
+
wget -c http://storage.googleapis.com/gresearch/open-vision-language/Wiki6M_ver_1_0_title_only.jsonl
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
Optional decompression:
|
| 155 |
+
|
| 156 |
+
```bash
|
| 157 |
+
gunzip -k /workspace/xiaobin/RL_dataset/data/infoseek/Wiki6M_ver_1_0.jsonl.gz
|
| 158 |
+
```
|
| 159 |
+
|
| 160 |
+
## 9. Download OVEN Image Snapshot From Hugging Face
|
| 161 |
+
|
| 162 |
+
Upstream OVEN now points image snapshot downloads to the gated dataset `ychenNLP/oven` on Hugging Face. Before downloading:
|
| 163 |
+
|
| 164 |
+
1. Open `https://huggingface.co/datasets/ychenNLP/oven`
|
| 165 |
+
2. Accept the dataset access conditions
|
| 166 |
+
3. Log in with the Hugging Face CLI
|
| 167 |
+
|
| 168 |
+
Install the CLI if needed:
|
| 169 |
+
|
| 170 |
+
```bash
|
| 171 |
+
python -m pip install -U "huggingface_hub[cli]"
|
| 172 |
+
```
|
| 173 |
+
|
| 174 |
+
Login:
|
| 175 |
+
|
| 176 |
+
```bash
|
| 177 |
+
hf auth login
|
| 178 |
+
```
|
| 179 |
+
|
| 180 |
+
Download the image snapshot and mapping file into `/workspace/xiaobin/RL_dataset/data/oven_hf`:
|
| 181 |
+
|
| 182 |
+
```bash
|
| 183 |
+
#!/usr/bin/env bash
|
| 184 |
+
set -euo pipefail
|
| 185 |
+
|
| 186 |
+
TARGET_DIR="/workspace/xiaobin/RL_dataset/data/oven_hf"
|
| 187 |
+
mkdir -p "${TARGET_DIR}"
|
| 188 |
+
|
| 189 |
+
hf download ychenNLP/oven \
|
| 190 |
+
--repo-type dataset \
|
| 191 |
+
--local-dir "${TARGET_DIR}" \
|
| 192 |
+
--include "shard*.tar" \
|
| 193 |
+
--include "all_wikipedia_images.tar" \
|
| 194 |
+
--include "ovenid2impath.csv"
|
| 195 |
+
```
|
| 196 |
+
|
| 197 |
+
Extract the snapshot tar files:
|
| 198 |
+
|
| 199 |
+
```bash
|
| 200 |
+
#!/usr/bin/env bash
|
| 201 |
+
set -euo pipefail
|
| 202 |
+
|
| 203 |
+
HF_DIR="/workspace/xiaobin/RL_dataset/data/oven_hf"
|
| 204 |
+
IMG_DIR="/workspace/xiaobin/RL_dataset/data/infoseek/images"
|
| 205 |
+
mkdir -p "${IMG_DIR}"
|
| 206 |
+
|
| 207 |
+
for f in "${HF_DIR}"/shard*.tar; do
|
| 208 |
+
tar -xf "${f}" -C "${IMG_DIR}"
|
| 209 |
+
done
|
| 210 |
+
|
| 211 |
+
tar -xf "${HF_DIR}/all_wikipedia_images.tar" -C "${IMG_DIR}"
|
| 212 |
+
```
|
| 213 |
+
|
| 214 |
+
Notes:
|
| 215 |
+
|
| 216 |
+
- Hugging Face file listing shows `shard01.tar` to `shard08.tar` plus `all_wikipedia_images.tar`
|
| 217 |
+
- The compressed download is very large, roughly 293 GB based on the published file sizes
|
| 218 |
+
- You need additional free space for extraction
|
| 219 |
+
|
| 220 |
+
## 10. Download OVEN Images From Original Sources
|
| 221 |
+
|
| 222 |
+
This follows the upstream `oven_eval/image_downloads` workflow.
|
| 223 |
+
|
| 224 |
+
### 10.1 Clone the Upstream Repo
|
| 225 |
+
|
| 226 |
+
```bash
|
| 227 |
+
git clone https://github.com/edchengg/oven_eval /workspace/xiaobin/RL_dataset/data/oven_source/oven_eval
|
| 228 |
+
```
|
| 229 |
+
|
| 230 |
+
### 10.2 Run All Source Download Scripts
|
| 231 |
+
|
| 232 |
+
The upstream image download directory contains these scripts:
|
| 233 |
+
|
| 234 |
+
- `download_aircraft.sh`
|
| 235 |
+
- `download_car196.sh`
|
| 236 |
+
- `download_coco.sh`
|
| 237 |
+
- `download_food101.sh`
|
| 238 |
+
- `download_gldv2.sh`
|
| 239 |
+
- `download_imagenet.sh`
|
| 240 |
+
- `download_inat.sh`
|
| 241 |
+
- `download_oxfordflower.sh`
|
| 242 |
+
- `download_sports100.sh`
|
| 243 |
+
- `download_sun397.sh`
|
| 244 |
+
- `download_textvqa.sh`
|
| 245 |
+
- `download_v7w.sh`
|
| 246 |
+
- `download_vg.sh`
|
| 247 |
+
|
| 248 |
+
Run them one by one:
|
| 249 |
+
|
| 250 |
+
```bash
|
| 251 |
+
#!/usr/bin/env bash
|
| 252 |
+
set -euo pipefail
|
| 253 |
+
|
| 254 |
+
cd /workspace/xiaobin/RL_dataset/data/oven_source/oven_eval/image_downloads
|
| 255 |
+
|
| 256 |
+
bash download_aircraft.sh
|
| 257 |
+
bash download_car196.sh
|
| 258 |
+
bash download_coco.sh
|
| 259 |
+
bash download_food101.sh
|
| 260 |
+
bash download_gldv2.sh
|
| 261 |
+
bash download_imagenet.sh
|
| 262 |
+
bash download_inat.sh
|
| 263 |
+
bash download_oxfordflower.sh
|
| 264 |
+
bash download_sports100.sh
|
| 265 |
+
bash download_sun397.sh
|
| 266 |
+
bash download_textvqa.sh
|
| 267 |
+
bash download_v7w.sh
|
| 268 |
+
bash download_vg.sh
|
| 269 |
+
```
|
| 270 |
+
|
| 271 |
+
Or run them in a loop:
|
| 272 |
+
|
| 273 |
+
```bash
|
| 274 |
+
#!/usr/bin/env bash
|
| 275 |
+
set -euo pipefail
|
| 276 |
+
|
| 277 |
+
cd /workspace/xiaobin/RL_dataset/data/oven_source/oven_eval/image_downloads
|
| 278 |
+
|
| 279 |
+
for script in download_*.sh; do
|
| 280 |
+
bash "${script}"
|
| 281 |
+
done
|
| 282 |
+
```
|
| 283 |
+
|
| 284 |
+
### 10.3 Download `ovenid2impath.csv`
|
| 285 |
+
|
| 286 |
+
You need `ovenid2impath.csv` for the merge step. The current recommended source is the Hugging Face dataset:
|
| 287 |
+
|
| 288 |
+
```bash
|
| 289 |
+
#!/usr/bin/env bash
|
| 290 |
+
set -euo pipefail
|
| 291 |
+
|
| 292 |
+
TARGET_DIR="/workspace/xiaobin/RL_dataset/data/oven_hf"
|
| 293 |
+
mkdir -p "${TARGET_DIR}"
|
| 294 |
+
|
| 295 |
+
hf download ychenNLP/oven \
|
| 296 |
+
--repo-type dataset \
|
| 297 |
+
--local-dir "${TARGET_DIR}" \
|
| 298 |
+
--include "ovenid2impath.csv"
|
| 299 |
+
```
|
| 300 |
+
|
| 301 |
+
### 10.4 Merge Into the Final OVEN Image Layout
|
| 302 |
+
|
| 303 |
+
Run the upstream merge script after all downloads finish:
|
| 304 |
+
|
| 305 |
+
```bash
|
| 306 |
+
cd /workspace/xiaobin/RL_dataset/data/oven_source/oven_eval/image_downloads
|
| 307 |
+
python merge_oven_images.py
|
| 308 |
+
```
|
| 309 |
+
|
| 310 |
+
The upstream documentation states that `merge_oven_images.py` should be run after all image download scripts complete and after `ovenid2impath.csv` is available.
|
| 311 |
+
|
| 312 |
+
## 11. Verify the Downloaded Files
|
| 313 |
+
|
| 314 |
+
Check text files:
|
| 315 |
+
|
| 316 |
+
```bash
|
| 317 |
+
ls -lh /workspace/xiaobin/RL_dataset/data/infoseek
|
| 318 |
+
```
|
| 319 |
+
|
| 320 |
+
Check Hugging Face snapshot files:
|
| 321 |
+
|
| 322 |
+
```bash
|
| 323 |
+
ls -lh /workspace/xiaobin/RL_dataset/data/oven_hf
|
| 324 |
+
```
|
| 325 |
+
|
| 326 |
+
Check extracted images:
|
| 327 |
+
|
| 328 |
+
```bash
|
| 329 |
+
find /workspace/xiaobin/RL_dataset/data/infoseek/images -type f | wc -l
|
| 330 |
+
```
|
| 331 |
+
|
| 332 |
+
## 12. Upstream References
|
| 333 |
+
|
| 334 |
+
- InfoSeek release page: `https://github.com/open-vision-language/infoseek`
|
| 335 |
+
- OVEN image download page: `https://github.com/edchengg/oven_eval/tree/main/image_downloads`
|
| 336 |
+
- Hugging Face OVEN dataset: `https://huggingface.co/datasets/ychenNLP/oven`
|
| 337 |
+
- Hugging Face CLI download docs: `https://huggingface.co/docs/huggingface_hub/guides/cli`
|
RL_dataset/README.md
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
task_categories:
|
| 4 |
+
- question-answering
|
| 5 |
+
tags:
|
| 6 |
+
- deep-research
|
| 7 |
+
- hierarchical-reasoning
|
| 8 |
+
- multi-hop-qa
|
| 9 |
+
- synthetic-data
|
| 10 |
+
- data-synthesis
|
| 11 |
+
language:
|
| 12 |
+
- en
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
# InfoSeek: Open Data Synthesis For Deep Research
|
| 16 |
+
|
| 17 |
+
[Paper](https://huggingface.co/papers/2509.00375) | [Code](https://github.com/VectorSpaceLab/InfoSeek)
|
| 18 |
+
|
| 19 |
+
## Dataset Information
|
| 20 |
+
|
| 21 |
+
* **`data/InfoSeek.jsonl`**
|
| 22 |
+
Contains the full research tree structures of *InfoSeek*. Each sample starts from a root node with a research question, its corresponding entity, and process information for sub-questions (stored in `root`). Also expands into intermediate tree structure during each step of construction (stored in `all_tree_list`). Totally 52K samples.
|
| 23 |
+
|
| 24 |
+
* **`data/InfoSeekQA.jsonl`**
|
| 25 |
+
A collection of QA pairs derived from *InfoSeek*. Each entry corresponds to the final question (`sample['root']['question']`) and its answer entity (`sample['root']['entity']`) in `InfoSeek.jsonl`.
|
| 26 |
+
|
| 27 |
+
* **`data/InfoSeek-Hard-18K.jsonl`**
|
| 28 |
+
A challenging subset of *InfoSeek* (18K samples), which is better to conduct end-to-end RL, identified using an LLM with a dedicated prompt for complex deep research.
|
| 29 |
+
|
| 30 |
+
* **`data/Trajectory-RFT-17K.jsonl`**
|
| 31 |
+
Contains 17K reasoning trajectories generated through the workflow described in our paper. These can be used as training data for supervised fine-tuning (SFT).
|
| 32 |
+
|
| 33 |
+
## Abstract
|
| 34 |
+
Large language models (LLMs) are increasingly expected to go beyond simple factual queries toward Deep Research-tasks that require decomposing questions into sub-problems, coordinating multi-step reasoning, and synthesizing evidence from diverse sources. We formalize Deep Research tasks with verifiable answers as Hierarchical Constraint Satisfaction Problems (HCSPs), which are fundamentally different from single-constraint, multi-hop, or flat CSP formulations. However, existing benchmarks (e.g., Natural Questions, HotpotQA) fail to capture this complexity, while recent synthetic datasets often introduce shortcut reasoning, knowledge leakage, or lack sufficient structural depth. To address this gap, we introduce InfoSeek, a scalable framework for synthesizing complex Deep Research tasks. InfoSeek uses a dual-agent system to recursively build a Research Tree from large-scale webpages, blurring intermediate nodes into valid sub-problems, and converting these trees into natural language questions that require traversing the full hierarchy. It also enables rapid scaling, yielding over 50K training examples, a curated test set, and reasoning trajectories generated via reject sampling. Experiments show that models trained on InfoSeek consistently outperform strong baselines. On a challenging benchmark BrowseComp-Plus, 3B LLMs optimized with InfoSeek surpass much larger 32B models and lightweight commercial APIs (e.g., Gemini2.5-Flash), while achieving performance comparable to stronger APIs (e.g., Gemini2.5-Pro). By preserving meta-information such as intermediate steps and retrieval labels, InfoSeek further supports advanced optimization strategies, including compound reward design and trajectory-level exploration.
|
| 35 |
+
|
| 36 |
+
## 🔆 Overview
|
| 37 |
+
We propose **InfoSeek**, a scalable data synthesis framework for constructing structurally complex Deep Research tasks. InfoSeek designs a dual-agent system to recursively build a *Research Tree* by mining entities and relations from large-scale text, and blurring itermediate vertices to ensure they form valid sub-problems. The agent then transform these trees into natural language questions whose solutions require traversing the entire hierarchy. Using InfoSeek pipeline, we construct a high-quality, complexity-controllable, and intrinsically verifiable dataset.
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
### Example 1:
|
| 41 |
+
**Question:** What is a species of bird that was named by a person employed under his father between 1818 and 1824, whose wife was a British artist, and which has three subspecies and body length is generally no more than 6 inches?
|
| 42 |
+
|
| 43 |
+
**Answer:** Russet sparrow
|
| 44 |
+
|
| 45 |
+
<details>
|
| 46 |
+
<summary>Tree Structure</summary>
|
| 47 |
+
|
| 48 |
+
```
|
| 49 |
+
{
|
| 50 |
+
"root": {
|
| 51 |
+
"id": "A",
|
| 52 |
+
"entity": "Russet sparrow",
|
| 53 |
+
"question": "What is a species of bird that was named by a person employed under his father between 1818 and 1824, whose wife was a British artist, and which has three subspecies and body length is generally no more than 6 inches?",
|
| 54 |
+
"claims": [
|
| 55 |
+
{ "target_id": "B", "claim": "A was named by B" },
|
| 56 |
+
{ "target_id": "C", "claim": "A has three subspecies" },
|
| 57 |
+
{ "target_id": "D", "claim": "A's body length is generally no more than 6 inches" }
|
| 58 |
+
],\
|
| 59 |
+
"children": [
|
| 60 |
+
{
|
| 61 |
+
"id": "B",
|
| 62 |
+
"entity": "John Gould",
|
| 63 |
+
"claims": [
|
| 64 |
+
{ "target_id": "E", "claim": "B was employed by his father between 1818 and 1824" },
|
| 65 |
+
{ "target_id": "F", "claim": "B's wife was F" }
|
| 66 |
+
],\
|
| 67 |
+
"children": [
|
| 68 |
+
{ "id": "E", "entity": "None", "claims": [], "children": [] },
|
| 69 |
+
{ "id": "F", "entity": "Elizabeth Gould", "claims": [], "children": [] }
|
| 70 |
+
]
|
| 71 |
+
},\
|
| 72 |
+
{ "id": "C", "entity": "None", "claims": [], "children": [] },
|
| 73 |
+
{ "id": "D", "entity": "None", "claims": [], "children": [] }
|
| 74 |
+
]
|
| 75 |
+
}
|
| 76 |
+
}
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
```
|
| 80 |
+
(A: Russet sparrow)
|
| 81 |
+
│
|
| 82 |
+
│
|
| 83 |
+
│── [claim] "was named by" ──> (B: John Gould)
|
| 84 |
+
│ │
|
| 85 |
+
│ │
|
| 86 |
+
│ │── [claim] "was employed by his father (1818-1824)"
|
| 87 |
+
│ │
|
| 88 |
+
│ │
|
| 89 |
+
│ │── [claim] "wife was" ──> (F: Elizabeth Gould)
|
| 90 |
+
│
|
| 91 |
+
│
|
| 92 |
+
│── [claim] "has three subspecies"
|
| 93 |
+
│
|
| 94 |
+
│
|
| 95 |
+
│── [claim] "body length is generally no more than 6 inches"
|
| 96 |
+
```
|
| 97 |
+
</details>
|
| 98 |
+
|
| 99 |
+
### Example 2:
|
| 100 |
+
|
| 101 |
+
**Question:** What is a women's football team whose first goals in the 2. Bundesliga were scored by a player born in Korogocho, who was discovered and developed by the Mathare Youth Sports Association?
|
| 102 |
+
|
| 103 |
+
**Answer:** SV Werder Bremen (women)
|
| 104 |
+
|
| 105 |
+
<details>
|
| 106 |
+
<summary>Tree Structure</summary>
|
| 107 |
+
|
| 108 |
+
```
|
| 109 |
+
{
|
| 110 |
+
"root": {
|
| 111 |
+
"id": "A",
|
| 112 |
+
"entity": "SV Werder Bremen (women)",
|
| 113 |
+
"question": "What is a women's football team whose first goals in the 2. Bundesliga were scored by a player born in Korogocho, who was discovered and developed by the Mathare Youth Sports Association?",
|
| 114 |
+
"claims": [
|
| 115 |
+
{ "target_id": "B", "claim": "A's first goals in the 2. Bundesliga were scored by B" }
|
| 116 |
+
],\
|
| 117 |
+
"children": [
|
| 118 |
+
{
|
| 119 |
+
"id": "B",
|
| 120 |
+
"entity": "Doreen Nabwire",
|
| 121 |
+
"claims": [
|
| 122 |
+
{ "target_id": "C", "claim": "B was discovered and developed by C" },
|
| 123 |
+
{ "target_id": "D", "claim": "B was born in D" }
|
| 124 |
+
],\
|
| 125 |
+
"children": [
|
| 126 |
+
{ "id": "C", "entity": "Mathare Youth Sports Association", "claims": [], "children": [] },
|
| 127 |
+
{ "id": "D", "entity": "Korogocho", "claims": [], "children": [] }
|
| 128 |
+
]
|
| 129 |
+
}
|
| 130 |
+
]
|
| 131 |
+
}
|
| 132 |
+
}
|
| 133 |
+
```
|
| 134 |
+
|
| 135 |
+
```
|
| 136 |
+
(A: SV Werder Bremen (women))
|
| 137 |
+
│
|
| 138 |
+
│
|
| 139 |
+
│── [claim] "first goals scored by" ──> (B: Doreen Nabwire)
|
| 140 |
+
│
|
| 141 |
+
│
|
| 142 |
+
│── [claim] "discovered and developed by" ──> (C:Mathare Youth Sports Association)
|
| 143 |
+
│
|
| 144 |
+
│
|
| 145 |
+
│── [claim] "was born in" ──> (D: Korogocho)
|
| 146 |
+
```
|
| 147 |
+
</details>
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
## 📊 Performance
|
| 151 |
+
Model trained on InfoSeek and our framework shows strong performances on traditional multi-hop benchmarks:
|
| 152 |
+
|
| 153 |
+
<img src="https://github.com/VectorSpaceLab/InfoSeek/raw/main/assets/results.png" width="800">
|
| 154 |
+
|
| 155 |
+
Our 3B model shows competitive results on [BrowseComp-Plus](https://github.com/texttron/BrowseComp-Plus):
|
| 156 |
+
|
| 157 |
+
<img src="https://github.com/VectorSpaceLab/InfoSeek/raw/main/assets/browsecomp_plus.png" width="800">
|
| 158 |
+
|
| 159 |
+
## ❤️ Citing Us
|
| 160 |
+
If you find this repository or our work useful, please consider giving a star ⭐ and or citing our work, which would be greatly appreciated:
|
| 161 |
+
```bibtex
|
| 162 |
+
@misc{xia2025opendatasynthesisdeep,
|
| 163 |
+
title={Open Data Synthesis For Deep Research},
|
| 164 |
+
author={Ziyi Xia and Kun Luo and Hongjin Qian and Zheng Liu},
|
| 165 |
+
year={2025},\
|
| 166 |
+
eprint={2509.00375},
|
| 167 |
+
archivePrefix={arXiv},
|
| 168 |
+
primaryClass={cs.CL},\
|
| 169 |
+
url={https://arxiv.org/abs/2509.00375},
|
| 170 |
+
}
|
| 171 |
+
```
|
RL_dataset/dataset_infos.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"default": {"features": {"root": {"_type": "Value"}, "all_tree_list": {"_type": "Value"}, "vertices": {"_type": "Value"}}, "splits": {"train": {"name": "train", "dataset_name": "InfoSeek"}}}}
|
RL_dataset/download_oven_hf_mirror.sh
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
MODE="${1:-all}"
|
| 5 |
+
|
| 6 |
+
REPO_ID="ychenNLP/oven"
|
| 7 |
+
TARGET_DIR="/workspace/xiaobin/RL_dataset/data"
|
| 8 |
+
CACHE_DIR="${TARGET_DIR}/.hf_cache"
|
| 9 |
+
ASSETS_DIR="${TARGET_DIR}/.hf_assets"
|
| 10 |
+
DEFAULT_ENDPOINT="https://hf-mirror.com"
|
| 11 |
+
MIRROR_URL="${HF_ENDPOINT:-${HF_ENDPOINT_OVERRIDE:-${DEFAULT_ENDPOINT}}}"
|
| 12 |
+
HARDCODED_TOKEN="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"
|
| 13 |
+
META_FILES=(
|
| 14 |
+
"download_infoseek_jsonl.sh"
|
| 15 |
+
"download_oven_jsonl.sh"
|
| 16 |
+
"ovenid2impath.csv"
|
| 17 |
+
)
|
| 18 |
+
IMAGE_FILES=(
|
| 19 |
+
"shard01.tar"
|
| 20 |
+
"shard02.tar"
|
| 21 |
+
"shard03.tar"
|
| 22 |
+
"shard04.tar"
|
| 23 |
+
"shard05.tar"
|
| 24 |
+
"shard06.tar"
|
| 25 |
+
"shard07.tar"
|
| 26 |
+
"shard08.tar"
|
| 27 |
+
"all_wikipedia_images.tar"
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
unset http_proxy
|
| 31 |
+
unset https_proxy
|
| 32 |
+
unset HTTP_PROXY
|
| 33 |
+
unset HTTPS_PROXY
|
| 34 |
+
unset all_proxy
|
| 35 |
+
unset ALL_PROXY
|
| 36 |
+
|
| 37 |
+
export HF_ENDPOINT="${MIRROR_URL}"
|
| 38 |
+
export HF_HUB_CACHE="${CACHE_DIR}"
|
| 39 |
+
export HF_ASSETS_CACHE="${ASSETS_DIR}"
|
| 40 |
+
|
| 41 |
+
mkdir -p "${TARGET_DIR}" "${CACHE_DIR}" "${ASSETS_DIR}"
|
| 42 |
+
|
| 43 |
+
if command -v hf >/dev/null 2>&1; then
|
| 44 |
+
HF_BIN=(hf download)
|
| 45 |
+
elif command -v huggingface-cli >/dev/null 2>&1; then
|
| 46 |
+
HF_BIN=(huggingface-cli download)
|
| 47 |
+
else
|
| 48 |
+
echo "Missing Hugging Face CLI. Install it with:" >&2
|
| 49 |
+
echo " python -m pip install -U \"huggingface_hub[cli]\"" >&2
|
| 50 |
+
exit 1
|
| 51 |
+
fi
|
| 52 |
+
|
| 53 |
+
TOKEN_ARGS=()
|
| 54 |
+
if [[ -n "${HF_TOKEN:-}" ]]; then
|
| 55 |
+
TOKEN_ARGS=(--token "${HF_TOKEN}")
|
| 56 |
+
elif [[ -n "${HARDCODED_TOKEN}" ]]; then
|
| 57 |
+
TOKEN_ARGS=(--token "${HARDCODED_TOKEN}")
|
| 58 |
+
fi
|
| 59 |
+
|
| 60 |
+
print_help() {
|
| 61 |
+
cat <<'EOF'
|
| 62 |
+
Usage:
|
| 63 |
+
bash download_oven_hf_mirror.sh [meta|images|all]
|
| 64 |
+
|
| 65 |
+
Modes:
|
| 66 |
+
meta Download metadata files only:
|
| 67 |
+
- download_infoseek_jsonl.sh
|
| 68 |
+
- download_oven_jsonl.sh
|
| 69 |
+
- ovenid2impath.csv
|
| 70 |
+
images Download image tar files only:
|
| 71 |
+
- shard01.tar ... shard08.tar
|
| 72 |
+
- all_wikipedia_images.tar
|
| 73 |
+
all Download both metadata and image tar files
|
| 74 |
+
|
| 75 |
+
Behavior:
|
| 76 |
+
- unsets proxy variables before downloading
|
| 77 |
+
- uses the mirror endpoint: https://hf-mirror.com
|
| 78 |
+
- endpoint can be overridden:
|
| 79 |
+
HF_ENDPOINT=https://huggingface.co bash download_oven_hf_mirror.sh meta
|
| 80 |
+
- stores downloaded files in: /workspace/xiaobin/RL_dataset/data
|
| 81 |
+
- stores Hugging Face cache in: /workspace/xiaobin/RL_dataset/data/.hf_cache
|
| 82 |
+
|
| 83 |
+
Notes:
|
| 84 |
+
- The dataset is gated. First accept access at:
|
| 85 |
+
https://huggingface.co/datasets/ychenNLP/oven
|
| 86 |
+
- The script contains a hardcoded token by default.
|
| 87 |
+
- If needed, export your token before running to override it:
|
| 88 |
+
export HF_TOKEN=hf_xxx
|
| 89 |
+
EOF
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
if [[ "${MODE}" == "-h" || "${MODE}" == "--help" || "${MODE}" == "help" ]]; then
|
| 93 |
+
print_help
|
| 94 |
+
exit 0
|
| 95 |
+
fi
|
| 96 |
+
|
| 97 |
+
require_auth() {
|
| 98 |
+
if [[ -n "${HF_TOKEN:-}" ]]; then
|
| 99 |
+
return 0
|
| 100 |
+
fi
|
| 101 |
+
|
| 102 |
+
if hf auth whoami >/dev/null 2>&1; then
|
| 103 |
+
return 0
|
| 104 |
+
fi
|
| 105 |
+
|
| 106 |
+
echo "No Hugging Face authentication detected." >&2
|
| 107 |
+
echo "Do this first:" >&2
|
| 108 |
+
echo " 1. Open https://huggingface.co/datasets/ychenNLP/oven and accept access." >&2
|
| 109 |
+
echo " 2. Run: hf auth login" >&2
|
| 110 |
+
echo " or: export HF_TOKEN=hf_xxx" >&2
|
| 111 |
+
exit 2
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
run_download() {
|
| 115 |
+
if ! "$@"; then
|
| 116 |
+
echo >&2
|
| 117 |
+
echo "Download failed." >&2
|
| 118 |
+
echo "Check these items:" >&2
|
| 119 |
+
echo " - access was approved for https://huggingface.co/datasets/ychenNLP/oven" >&2
|
| 120 |
+
echo " - HF_TOKEN is valid, or 'hf auth login' succeeded" >&2
|
| 121 |
+
echo " - the mirror endpoint is reachable: ${MIRROR_URL}" >&2
|
| 122 |
+
exit 1
|
| 123 |
+
fi
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
verify_files() {
|
| 127 |
+
local missing=0
|
| 128 |
+
local file
|
| 129 |
+
|
| 130 |
+
for file in "$@"; do
|
| 131 |
+
if [[ ! -f "${TARGET_DIR}/${file}" ]]; then
|
| 132 |
+
echo "Missing expected file: ${TARGET_DIR}/${file}" >&2
|
| 133 |
+
missing=1
|
| 134 |
+
fi
|
| 135 |
+
done
|
| 136 |
+
|
| 137 |
+
if [[ "${missing}" -ne 0 ]]; then
|
| 138 |
+
echo >&2
|
| 139 |
+
echo "Download did not complete successfully." >&2
|
| 140 |
+
echo "This usually means one of these:" >&2
|
| 141 |
+
echo " - the mirror endpoint could not be reached" >&2
|
| 142 |
+
echo " - access to the gated dataset was not approved" >&2
|
| 143 |
+
echo " - authentication was missing or invalid" >&2
|
| 144 |
+
exit 1
|
| 145 |
+
fi
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
download_meta() {
|
| 149 |
+
run_download "${HF_BIN[@]}" "${REPO_ID}" \
|
| 150 |
+
--repo-type dataset \
|
| 151 |
+
--local-dir "${TARGET_DIR}" \
|
| 152 |
+
--include "download_infoseek_jsonl.sh" \
|
| 153 |
+
--include "download_oven_jsonl.sh" \
|
| 154 |
+
--include "ovenid2impath.csv" \
|
| 155 |
+
"${TOKEN_ARGS[@]}"
|
| 156 |
+
verify_files "${META_FILES[@]}"
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
download_images() {
|
| 160 |
+
run_download "${HF_BIN[@]}" "${REPO_ID}" \
|
| 161 |
+
--repo-type dataset \
|
| 162 |
+
--local-dir "${TARGET_DIR}" \
|
| 163 |
+
--include "all_wikipedia_images.tar" \
|
| 164 |
+
--include "shard*.tar" \
|
| 165 |
+
"${TOKEN_ARGS[@]}"
|
| 166 |
+
verify_files "${IMAGE_FILES[@]}"
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
require_auth
|
| 170 |
+
|
| 171 |
+
case "${MODE}" in
|
| 172 |
+
meta)
|
| 173 |
+
download_meta
|
| 174 |
+
;;
|
| 175 |
+
images)
|
| 176 |
+
download_images
|
| 177 |
+
;;
|
| 178 |
+
all)
|
| 179 |
+
download_meta
|
| 180 |
+
download_images
|
| 181 |
+
;;
|
| 182 |
+
*)
|
| 183 |
+
echo "Unknown mode: ${MODE}" >&2
|
| 184 |
+
print_help >&2
|
| 185 |
+
exit 1
|
| 186 |
+
;;
|
| 187 |
+
esac
|
| 188 |
+
|
| 189 |
+
echo "Download completed. Files are under: ${TARGET_DIR}"
|
RL_dataset/download_scienceqa_hf.sh
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
MODE="${1:-all}"
|
| 5 |
+
|
| 6 |
+
REPO_ID="derek-thomas/ScienceQA"
|
| 7 |
+
ROOT_DIR="/workspace/xiaobin/RL_dataset/data/ScienceQA"
|
| 8 |
+
HF_DIR="${ROOT_DIR}/hf"
|
| 9 |
+
IMG_DIR="${ROOT_DIR}/images"
|
| 10 |
+
CACHE_DIR="${ROOT_DIR}/.hf_cache"
|
| 11 |
+
DEFAULT_ENDPOINT="https://hf-mirror.com"
|
| 12 |
+
HF_ENDPOINT_VALUE="${HF_ENDPOINT:-${HF_ENDPOINT_OVERRIDE:-${DEFAULT_ENDPOINT}}}"
|
| 13 |
+
|
| 14 |
+
unset http_proxy
|
| 15 |
+
unset https_proxy
|
| 16 |
+
unset HTTP_PROXY
|
| 17 |
+
unset HTTPS_PROXY
|
| 18 |
+
unset all_proxy
|
| 19 |
+
unset ALL_PROXY
|
| 20 |
+
|
| 21 |
+
export HF_ENDPOINT="${HF_ENDPOINT_VALUE}"
|
| 22 |
+
|
| 23 |
+
mkdir -p "${HF_DIR}" "${IMG_DIR}" "${CACHE_DIR}"
|
| 24 |
+
|
| 25 |
+
if command -v hf >/dev/null 2>&1; then
|
| 26 |
+
HF_BIN=(hf download)
|
| 27 |
+
elif command -v huggingface-cli >/dev/null 2>&1; then
|
| 28 |
+
HF_BIN=(huggingface-cli download)
|
| 29 |
+
else
|
| 30 |
+
echo "Missing Hugging Face CLI. Install it with:" >&2
|
| 31 |
+
echo " python -m pip install -U \"huggingface_hub[cli]\"" >&2
|
| 32 |
+
exit 1
|
| 33 |
+
fi
|
| 34 |
+
|
| 35 |
+
print_help() {
|
| 36 |
+
cat <<'EOF'
|
| 37 |
+
Usage:
|
| 38 |
+
bash download_scienceqa_hf.sh [parquet|images|all]
|
| 39 |
+
|
| 40 |
+
Modes:
|
| 41 |
+
parquet Download the public Hugging Face parquet files only
|
| 42 |
+
images Download the original ScienceQA image zip files only
|
| 43 |
+
all Download both parquet files and images
|
| 44 |
+
|
| 45 |
+
Output layout:
|
| 46 |
+
/workspace/xiaobin/RL_dataset/data/ScienceQA/hf
|
| 47 |
+
/workspace/xiaobin/RL_dataset/data/ScienceQA/images
|
| 48 |
+
|
| 49 |
+
Notes:
|
| 50 |
+
- This dataset is public and should not require an HF token.
|
| 51 |
+
- Image URLs are adapted from:
|
| 52 |
+
/workspace/xiaobin/RL_dataset/ScienceQA/tools/download.sh
|
| 53 |
+
- Proxies are unset before download.
|
| 54 |
+
- Default HF endpoint: https://hf-mirror.com
|
| 55 |
+
- To override and use the official endpoint:
|
| 56 |
+
HF_ENDPOINT=https://huggingface.co bash download_scienceqa_hf.sh parquet
|
| 57 |
+
EOF
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
if [[ "${MODE}" == "-h" || "${MODE}" == "--help" || "${MODE}" == "help" ]]; then
|
| 61 |
+
print_help
|
| 62 |
+
exit 0
|
| 63 |
+
fi
|
| 64 |
+
|
| 65 |
+
verify_glob() {
|
| 66 |
+
local pattern="$1"
|
| 67 |
+
|
| 68 |
+
if ! compgen -G "${pattern}" >/dev/null; then
|
| 69 |
+
echo "Missing expected file matching: ${pattern}" >&2
|
| 70 |
+
exit 1
|
| 71 |
+
fi
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
download_parquet() {
|
| 75 |
+
"${HF_BIN[@]}" "${REPO_ID}" \
|
| 76 |
+
--repo-type dataset \
|
| 77 |
+
--cache-dir "${CACHE_DIR}" \
|
| 78 |
+
--local-dir "${HF_DIR}" \
|
| 79 |
+
--include "data/*.parquet" \
|
| 80 |
+
--include "README.md" \
|
| 81 |
+
--include "ScienceQA.py"
|
| 82 |
+
|
| 83 |
+
verify_glob "${HF_DIR}/data/train-*.parquet"
|
| 84 |
+
verify_glob "${HF_DIR}/data/validation-*.parquet"
|
| 85 |
+
verify_glob "${HF_DIR}/data/test-*.parquet"
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
download_one_split() {
|
| 89 |
+
local split="$1"
|
| 90 |
+
local zip_path="${IMG_DIR}/${split}.zip"
|
| 91 |
+
local split_dir="${IMG_DIR}/${split}"
|
| 92 |
+
local url="https://scienceqa.s3.us-west-1.amazonaws.com/images/${split}.zip"
|
| 93 |
+
|
| 94 |
+
if [[ -d "${split_dir}" ]]; then
|
| 95 |
+
echo "Image split already exists: ${split_dir}"
|
| 96 |
+
return 0
|
| 97 |
+
fi
|
| 98 |
+
|
| 99 |
+
wget -c -O "${zip_path}" "${url}"
|
| 100 |
+
unzip -q -o "${zip_path}" -d "${IMG_DIR}"
|
| 101 |
+
rm -f "${zip_path}"
|
| 102 |
+
|
| 103 |
+
if [[ ! -d "${split_dir}" ]]; then
|
| 104 |
+
echo "Failed to extract image split: ${split}" >&2
|
| 105 |
+
exit 1
|
| 106 |
+
fi
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
download_images() {
|
| 110 |
+
download_one_split train
|
| 111 |
+
download_one_split val
|
| 112 |
+
download_one_split test
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
case "${MODE}" in
|
| 116 |
+
parquet)
|
| 117 |
+
download_parquet
|
| 118 |
+
;;
|
| 119 |
+
images)
|
| 120 |
+
download_images
|
| 121 |
+
;;
|
| 122 |
+
all)
|
| 123 |
+
download_parquet
|
| 124 |
+
download_images
|
| 125 |
+
;;
|
| 126 |
+
*)
|
| 127 |
+
echo "Unknown mode: ${MODE}" >&2
|
| 128 |
+
print_help >&2
|
| 129 |
+
exit 1
|
| 130 |
+
;;
|
| 131 |
+
esac
|
| 132 |
+
|
| 133 |
+
echo "Download completed."
|
| 134 |
+
echo "Parquet dir: ${HF_DIR}"
|
| 135 |
+
echo "Image dir: ${IMG_DIR}"
|
download_hf.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Hugging Face 断点续传下载脚本
|
| 4 |
+
镜像站: hf-mirror.com
|
| 5 |
+
目标: MMInstruction/M3IT
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
|
| 11 |
+
# 设置国内镜像站
|
| 12 |
+
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
|
| 13 |
+
|
| 14 |
+
from huggingface_hub import snapshot_download
|
| 15 |
+
from huggingface_hub import hf_hub_download
|
| 16 |
+
import huggingface_hub
|
| 17 |
+
|
| 18 |
+
REPO_ID = "MMInstruction/M3IT"
|
| 19 |
+
LOCAL_DIR = "/workspace/xiaobin/dataset"
|
| 20 |
+
REPO_TYPE = "dataset" # M3IT 是数据集
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def download():
|
| 24 |
+
print(f"镜像站: {os.environ['HF_ENDPOINT']}")
|
| 25 |
+
print(f"下载仓库: {REPO_ID}")
|
| 26 |
+
print(f"保存目录: {LOCAL_DIR}")
|
| 27 |
+
print("-" * 50)
|
| 28 |
+
|
| 29 |
+
os.makedirs(LOCAL_DIR, exist_ok=True)
|
| 30 |
+
|
| 31 |
+
try:
|
| 32 |
+
snapshot_download(
|
| 33 |
+
repo_id=REPO_ID,
|
| 34 |
+
repo_type=REPO_TYPE,
|
| 35 |
+
local_dir=LOCAL_DIR,
|
| 36 |
+
local_dir_use_symlinks=False, # 直接复制文件,不用软链接
|
| 37 |
+
resume_download=True, # 断点续传
|
| 38 |
+
ignore_patterns=["*.gitattributes"],
|
| 39 |
+
)
|
| 40 |
+
print("\n下载完成!")
|
| 41 |
+
|
| 42 |
+
except Exception as e:
|
| 43 |
+
print(f"\n出错: {e}")
|
| 44 |
+
print("提示: 如果是模型仓库,请将 REPO_TYPE 改为 'model' 后重试")
|
| 45 |
+
sys.exit(1)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
if __name__ == "__main__":
|
| 49 |
+
download()
|