Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- SpecForge-ext/.pre-commit-config.yaml +53 -0
- SpecForge-ext/convert_mtbench.py +22 -0
- SpecForge-ext/download_datasets.py +64 -0
- SpecForge-ext/download_mtbench.sh +23 -0
- SpecForge-ext/download_mtbench_data.py +51 -0
- SpecForge-ext/mtbench_sample.json +26 -0
- SpecForge-ext/pyproject.toml +44 -0
- SpecForge-ext/requirements.txt +0 -0
- SpecForge-ext/setup.py +33 -0
- SpecForge-ext/test_accept_length.md +300 -0
- SpecForge/.editorconfig +25 -0
- SpecForge/.isort.cfg +3 -0
- SpecForge/.pre-commit-config.yaml +53 -0
- SpecForge/LICENSE +21 -0
- SpecForge/MANIFEST.in +2 -0
- SpecForge/README.md +70 -0
- SpecForge/pyproject.toml +47 -0
- SpecForge/requirements-rocm.txt +20 -0
- SpecForge/version.txt +1 -0
- idea1/.editorconfig +25 -0
- idea1/.isort.cfg +3 -0
- idea1/.pre-commit-config.yaml +53 -0
- idea1/LICENSE +21 -0
- idea1/requirements-rocm.txt +20 -0
- idea1/version.txt +1 -0
- qwen3-8b_dflash_regen/.gitattributes +36 -0
- syxin/backup.log +0 -0
- syxin/dflash_lora_changelog.md +232 -0
- syxin/eval_accepted_length.md +217 -0
- syxin/eval_dflash_b16_baseline.py +354 -0
- syxin/eval_dflash_lora_inject.py +627 -0
- syxin/idea.md +23 -0
- syxin/launch_train.sh +37 -0
- syxin/launch_train_wrapper.py +21 -0
- syxin/list.md +12 -0
- syxin/merge_lora.py +66 -0
- syxin/oom_fix_progress.md +42 -0
- syxin/requirements.txt +0 -0
- syxin/run_bench.sh +68 -0
- syxin/run_bench_dflash.sh +71 -0
- syxin/run_bench_dflash_b16_baseline.sh +60 -0
- syxin/run_qwen3_8b_sft_32gpu.sh +31 -0
- syxin/run_train_dflash_direct_inject.sh +56 -0
- syxin/run_train_dflash_lora_inject.sh +71 -0
- syxin/run_train_multinode.sh +67 -0
- syxin/run_train_qwen3_8b_sft_32gpu.sh +66 -0
- syxin/server.log +186 -0
- syxin/start_server.sh +42 -0
- syxin/start_server_dflash.sh +54 -0
- syxin/step1.md +139 -0
SpecForge-ext/.pre-commit-config.yaml
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
default_stages: [pre-commit, pre-push, manual]
|
| 2 |
+
|
| 3 |
+
repos:
|
| 4 |
+
- repo: https://github.com/PyCQA/autoflake
|
| 5 |
+
rev: v2.3.1
|
| 6 |
+
hooks:
|
| 7 |
+
- id: autoflake
|
| 8 |
+
args: [--remove-all-unused-imports, --in-place]
|
| 9 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
| 10 |
+
rev: v5.0.0
|
| 11 |
+
hooks:
|
| 12 |
+
- id: check-symlinks
|
| 13 |
+
- id: destroyed-symlinks
|
| 14 |
+
- id: trailing-whitespace
|
| 15 |
+
- id: end-of-file-fixer
|
| 16 |
+
- id: check-yaml
|
| 17 |
+
args: [--allow-multiple-documents]
|
| 18 |
+
- id: check-toml
|
| 19 |
+
- id: check-ast
|
| 20 |
+
- id: check-added-large-files
|
| 21 |
+
- id: check-merge-conflict
|
| 22 |
+
- id: check-shebang-scripts-are-executable
|
| 23 |
+
- id: detect-private-key
|
| 24 |
+
- id: debug-statements
|
| 25 |
+
- id: no-commit-to-branch
|
| 26 |
+
- repo: https://github.com/PyCQA/isort
|
| 27 |
+
rev: 5.13.2
|
| 28 |
+
hooks:
|
| 29 |
+
- id: isort
|
| 30 |
+
- repo: https://github.com/astral-sh/ruff-pre-commit
|
| 31 |
+
rev: v0.11.10
|
| 32 |
+
hooks:
|
| 33 |
+
- id: ruff
|
| 34 |
+
args: [--select=F401, --fixable=F401]
|
| 35 |
+
files: ^(benchmark/|docs/|examples/)
|
| 36 |
+
exclude: \.ipynb$
|
| 37 |
+
- repo: https://github.com/psf/black
|
| 38 |
+
rev: 24.10.0
|
| 39 |
+
hooks:
|
| 40 |
+
- id: black-jupyter
|
| 41 |
+
- repo: https://github.com/pre-commit/mirrors-clang-format
|
| 42 |
+
rev: v18.1.8
|
| 43 |
+
hooks:
|
| 44 |
+
- id: clang-format
|
| 45 |
+
types_or: [c++, cuda]
|
| 46 |
+
args: [--style=file, --verbose]
|
| 47 |
+
- repo: https://github.com/kynan/nbstripout
|
| 48 |
+
rev: 0.8.1
|
| 49 |
+
hooks:
|
| 50 |
+
- id: nbstripout
|
| 51 |
+
args:
|
| 52 |
+
- '--keep-output'
|
| 53 |
+
- '--extra-keys=metadata.kernelspec metadata.language_info.version'
|
SpecForge-ext/convert_mtbench.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
# 读取 JSON 文件并转换为 JSONL
|
| 6 |
+
input_file = "/workspace/hanrui/SpecForge-ext/mtbench_sample.json"
|
| 7 |
+
with open(input_file, 'r') as f:
|
| 8 |
+
data = json.load(f)
|
| 9 |
+
|
| 10 |
+
# 保存为 jsonl
|
| 11 |
+
cache_dir = os.path.expanduser("~/.cache/sglang")
|
| 12 |
+
os.makedirs(cache_dir, exist_ok=True)
|
| 13 |
+
output_file = os.path.join(cache_dir, "mtbench.jsonl")
|
| 14 |
+
|
| 15 |
+
with open(output_file, 'w') as f:
|
| 16 |
+
for item in data:
|
| 17 |
+
f.write(json.dumps(item) + '\n')
|
| 18 |
+
|
| 19 |
+
print(f"Converted {len(data)} questions")
|
| 20 |
+
print(f"Saved to {output_file}")
|
| 21 |
+
print(f"\nFirst question:")
|
| 22 |
+
print(json.dumps(data[0], indent=2))
|
SpecForge-ext/download_datasets.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
下载 GSM8K 和 HumanEval 数据集到本地
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
import json
|
| 7 |
+
import requests
|
| 8 |
+
from datasets import load_dataset
|
| 9 |
+
|
| 10 |
+
DATA_DIR = "/workspace/hanrui/datasets"
|
| 11 |
+
os.makedirs(DATA_DIR, exist_ok=True)
|
| 12 |
+
|
| 13 |
+
print("=" * 60)
|
| 14 |
+
print("下载 GSM8K 数据集")
|
| 15 |
+
print("=" * 60)
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
# 下载 GSM8K
|
| 19 |
+
gsm8k_dir = os.path.join(DATA_DIR, "gsm8k")
|
| 20 |
+
os.makedirs(gsm8k_dir, exist_ok=True)
|
| 21 |
+
|
| 22 |
+
print("Loading GSM8K from HuggingFace...")
|
| 23 |
+
dataset = load_dataset("gsm8k", "main", split="test")
|
| 24 |
+
|
| 25 |
+
# 保存为 jsonl
|
| 26 |
+
output_file = os.path.join(gsm8k_dir, "test.jsonl")
|
| 27 |
+
with open(output_file, 'w') as f:
|
| 28 |
+
for item in dataset:
|
| 29 |
+
f.write(json.dumps(item) + '\n')
|
| 30 |
+
|
| 31 |
+
print(f"✓ GSM8K saved to {output_file}")
|
| 32 |
+
print(f" Total samples: {len(dataset)}")
|
| 33 |
+
|
| 34 |
+
except Exception as e:
|
| 35 |
+
print(f"✗ GSM8K download failed: {e}")
|
| 36 |
+
|
| 37 |
+
print("\n" + "=" * 60)
|
| 38 |
+
print("下载 HumanEval 数据集")
|
| 39 |
+
print("=" * 60)
|
| 40 |
+
|
| 41 |
+
try:
|
| 42 |
+
# 下载 HumanEval
|
| 43 |
+
humaneval_dir = os.path.join(DATA_DIR, "humaneval")
|
| 44 |
+
os.makedirs(humaneval_dir, exist_ok=True)
|
| 45 |
+
|
| 46 |
+
print("Loading HumanEval from HuggingFace...")
|
| 47 |
+
dataset = load_dataset("openai_humaneval", split="test")
|
| 48 |
+
|
| 49 |
+
# 保存为 jsonl
|
| 50 |
+
output_file = os.path.join(humaneval_dir, "test.jsonl")
|
| 51 |
+
with open(output_file, 'w') as f:
|
| 52 |
+
for item in dataset:
|
| 53 |
+
f.write(json.dumps(item) + '\n')
|
| 54 |
+
|
| 55 |
+
print(f"✓ HumanEval saved to {output_file}")
|
| 56 |
+
print(f" Total samples: {len(dataset)}")
|
| 57 |
+
|
| 58 |
+
except Exception as e:
|
| 59 |
+
print(f"✗ HumanEval download failed: {e}")
|
| 60 |
+
|
| 61 |
+
print("\n" + "=" * 60)
|
| 62 |
+
print("下载完成")
|
| 63 |
+
print("=" * 60)
|
| 64 |
+
print(f"数据保存在: {DATA_DIR}")
|
SpecForge-ext/download_mtbench.sh
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# 下载 mtbench 数据文件
|
| 4 |
+
# 如果无法访问 GitHub,需要手动下载或使用镜像
|
| 5 |
+
|
| 6 |
+
CACHE_DIR="$HOME/.cache/sglang"
|
| 7 |
+
mkdir -p "$CACHE_DIR"
|
| 8 |
+
|
| 9 |
+
echo "Downloading mtbench data..."
|
| 10 |
+
|
| 11 |
+
# 方法1:尝试使用代理下载
|
| 12 |
+
https_proxy=http://10.1.2.1:7890 http_proxy=http://10.1.2.1:7890 \
|
| 13 |
+
curl -L "https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl" \
|
| 14 |
+
-o "$CACHE_DIR/mtbench.jsonl"
|
| 15 |
+
|
| 16 |
+
if [ $? -eq 0 ]; then
|
| 17 |
+
echo "Downloaded to $CACHE_DIR/mtbench.jsonl"
|
| 18 |
+
ls -lh "$CACHE_DIR/mtbench.jsonl"
|
| 19 |
+
else
|
| 20 |
+
echo "Download failed. Please manually download the file from:"
|
| 21 |
+
echo "https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl"
|
| 22 |
+
echo "And save it to: $CACHE_DIR/mtbench.jsonl"
|
| 23 |
+
fi
|
SpecForge-ext/download_mtbench_data.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
下载并转换 MT-Bench 数据到本地目录
|
| 4 |
+
"""
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
import requests
|
| 8 |
+
|
| 9 |
+
# 目标目录
|
| 10 |
+
DATA_DIR = "/workspace/hanrui/datasets/mtbench"
|
| 11 |
+
os.makedirs(DATA_DIR, exist_ok=True)
|
| 12 |
+
|
| 13 |
+
# 下载 MT-Bench 问题数据
|
| 14 |
+
url = "https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl"
|
| 15 |
+
output_file = os.path.join(DATA_DIR, "question.jsonl")
|
| 16 |
+
|
| 17 |
+
print(f"Downloading MT-Bench questions from {url}")
|
| 18 |
+
print(f"Saving to {output_file}")
|
| 19 |
+
|
| 20 |
+
try:
|
| 21 |
+
# 使用代理下载
|
| 22 |
+
proxies = {
|
| 23 |
+
'http': 'http://10.1.2.1:7890',
|
| 24 |
+
'https': 'http://10.1.2.1:7890',
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
response = requests.get(url, proxies=proxies, timeout=30)
|
| 28 |
+
response.raise_for_status()
|
| 29 |
+
|
| 30 |
+
with open(output_file, 'wb') as f:
|
| 31 |
+
f.write(response.content)
|
| 32 |
+
|
| 33 |
+
print(f"✓ Downloaded successfully")
|
| 34 |
+
|
| 35 |
+
# 验证数据
|
| 36 |
+
with open(output_file, 'r') as f:
|
| 37 |
+
lines = f.readlines()
|
| 38 |
+
|
| 39 |
+
print(f"✓ Total questions: {len(lines)}")
|
| 40 |
+
|
| 41 |
+
# 显示第一个问题
|
| 42 |
+
first_question = json.loads(lines[0])
|
| 43 |
+
print(f"\nFirst question:")
|
| 44 |
+
print(json.dumps(first_question, indent=2))
|
| 45 |
+
|
| 46 |
+
except Exception as e:
|
| 47 |
+
print(f"✗ Download failed: {e}")
|
| 48 |
+
print(f"\nPlease manually download from:")
|
| 49 |
+
print(f" {url}")
|
| 50 |
+
print(f"And save to:")
|
| 51 |
+
print(f" {output_file}")
|
SpecForge-ext/mtbench_sample.json
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"question_id": 1,
|
| 4 |
+
"category": "writing",
|
| 5 |
+
"turns": [
|
| 6 |
+
"Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions.",
|
| 7 |
+
"Rewrite your previous response. Start every sentence with the letter A."
|
| 8 |
+
]
|
| 9 |
+
},
|
| 10 |
+
{
|
| 11 |
+
"question_id": 2,
|
| 12 |
+
"category": "roleplay",
|
| 13 |
+
"turns": [
|
| 14 |
+
"Imagine you are writing a blog post comparing two popular smartphone models. Develop an outline for the blog post, including key points and subheadings to effectively compare and contrast the features, performance, and user experience of the two models. Please answer in fewer than 200 words.",
|
| 15 |
+
"Take your previous response and rephrase it as a limerick."
|
| 16 |
+
]
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"question_id": 3,
|
| 20 |
+
"category": "reasoning",
|
| 21 |
+
"turns": [
|
| 22 |
+
"Describe a vivid and unique character, using strong imagery and creative language. Please answer in fewer than two paragraphs.",
|
| 23 |
+
"Revise your previous response and incorporate an allusion to a famous work of literature or historical event in each sentence."
|
| 24 |
+
]
|
| 25 |
+
}
|
| 26 |
+
]
|
SpecForge-ext/pyproject.toml
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=61.0", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "specforge"
|
| 7 |
+
dynamic = ["version", "description"]
|
| 8 |
+
readme = "README.md"
|
| 9 |
+
requires-python = ">=3.11"
|
| 10 |
+
dependencies = [
|
| 11 |
+
"pre-commit",
|
| 12 |
+
"torch==2.9.1",
|
| 13 |
+
"torchaudio==2.9.1",
|
| 14 |
+
"torchvision==0.24.1",
|
| 15 |
+
"transformers==4.57.1",
|
| 16 |
+
"qwen-vl-utils==0.0.11",
|
| 17 |
+
"datasets",
|
| 18 |
+
"setuptools",
|
| 19 |
+
"tqdm",
|
| 20 |
+
"wandb",
|
| 21 |
+
"psutil",
|
| 22 |
+
"numpy",
|
| 23 |
+
"accelerate",
|
| 24 |
+
"pydantic",
|
| 25 |
+
"sglang==0.5.6",
|
| 26 |
+
"openai-harmony",
|
| 27 |
+
"ninja",
|
| 28 |
+
"packaging",
|
| 29 |
+
"yunchang",
|
| 30 |
+
]
|
| 31 |
+
|
| 32 |
+
[tool.setuptools]
|
| 33 |
+
packages = ["specforge"]
|
| 34 |
+
|
| 35 |
+
[project.optional-dependencies]
|
| 36 |
+
dev = [
|
| 37 |
+
"pre-commit",
|
| 38 |
+
"unittest"
|
| 39 |
+
]
|
| 40 |
+
fa = ["flash-attn"]
|
| 41 |
+
|
| 42 |
+
[tool.setuptools.dynamic]
|
| 43 |
+
version = {file = "version.txt"}
|
| 44 |
+
description = {file = "README.md"}
|
SpecForge-ext/requirements.txt
ADDED
|
File without changes
|
SpecForge-ext/setup.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tomllib
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
from setuptools import find_packages, setup
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def read_readme():
|
| 8 |
+
with open("README.md", "r") as f:
|
| 9 |
+
return f.read()
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def read_version():
|
| 13 |
+
with open("version.txt", "r") as f:
|
| 14 |
+
return f.read().strip()
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def read_dependencies():
|
| 18 |
+
pyproject_path = Path(__file__).parent / "pyproject.toml"
|
| 19 |
+
with open(pyproject_path, "rb") as f:
|
| 20 |
+
pyproject = tomllib.load(f)
|
| 21 |
+
return pyproject.get("project", {}).get("dependencies", [])
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
setup(
|
| 25 |
+
name="specforge",
|
| 26 |
+
packages=find_packages(exclude=["configs", "scripts", "tests"]),
|
| 27 |
+
version=read_version(),
|
| 28 |
+
install_requires=read_dependencies(),
|
| 29 |
+
long_description=read_readme(),
|
| 30 |
+
long_description_content_type="text/markdown",
|
| 31 |
+
author="SGLang Team",
|
| 32 |
+
url="https://github.com/sgl-project/SpecForge",
|
| 33 |
+
)
|
SpecForge-ext/test_accept_length.md
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Accept Length 测试指南
|
| 2 |
+
|
| 3 |
+
## 0. 准备工作
|
| 4 |
+
|
| 5 |
+
### 创建目录
|
| 6 |
+
```bash
|
| 7 |
+
cd /workspace/hanrui/SpecForge-ext
|
| 8 |
+
mkdir -p logs results
|
| 9 |
+
```
|
| 10 |
+
|
| 11 |
+
### 下载数据集(首次运行)
|
| 12 |
+
```bash
|
| 13 |
+
cd /workspace/hanrui/SpecForge-ext
|
| 14 |
+
python download_datasets.py
|
| 15 |
+
```
|
| 16 |
+
|
| 17 |
+
数据保存位置:
|
| 18 |
+
- MT-Bench: `/workspace/hanrui/datasets/mtbench/question.jsonl`
|
| 19 |
+
- GSM8K: `/workspace/hanrui/datasets/gsm8k/test.jsonl`
|
| 20 |
+
- HumanEval: `/workspace/hanrui/datasets/humaneval/test.jsonl`
|
| 21 |
+
|
| 22 |
+
---
|
| 23 |
+
|
| 24 |
+
## 1. 测试 Baseline 模型
|
| 25 |
+
|
| 26 |
+
### 启动服务器(终端1)
|
| 27 |
+
```bash
|
| 28 |
+
cd /workspace/hanrui/SpecForge-ext
|
| 29 |
+
|
| 30 |
+
# 设置环境变量
|
| 31 |
+
export NO_PROXY="localhost,127.0.0.1,::1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16"
|
| 32 |
+
export no_proxy="localhost,127.0.0.1,::1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16"
|
| 33 |
+
|
| 34 |
+
# 启动 baseline 服务器
|
| 35 |
+
python3 -m sglang.launch_server \
|
| 36 |
+
--model /workspace/Qwen3-8B \
|
| 37 |
+
--speculative-algorithm EAGLE3 \
|
| 38 |
+
--speculative-draft-model-path /workspace/qwen3_8b_eagle3 \
|
| 39 |
+
--speculative-num-steps 3 \
|
| 40 |
+
--speculative-eagle-topk 1 \
|
| 41 |
+
--speculative-num-draft-tokens 4 \
|
| 42 |
+
--mem-fraction-static 0.75 \
|
| 43 |
+
--cuda-graph-max-bs 1 \
|
| 44 |
+
--tp 1 \
|
| 45 |
+
--trust-remote-code \
|
| 46 |
+
--host 0.0.0.0 \
|
| 47 |
+
--port 30000 \
|
| 48 |
+
--dtype bfloat16 \
|
| 49 |
+
--skip-server-warmup
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
等待看到 `Application startup complete` 后,继续下一步。
|
| 53 |
+
|
| 54 |
+
### 运行三个 Benchmark(终端2)
|
| 55 |
+
```bash
|
| 56 |
+
cd /workspace/hanrui/SpecForge-ext
|
| 57 |
+
conda activate /workspace/Hanrui/
|
| 58 |
+
|
| 59 |
+
# 设置环境变量
|
| 60 |
+
export NO_PROXY="localhost,127.0.0.1,::1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16"
|
| 61 |
+
export no_proxy="localhost,127.0.0.1,::1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16"
|
| 62 |
+
|
| 63 |
+
# 1. MT-Bench
|
| 64 |
+
echo "=== Running MT-Bench (Baseline) ==="
|
| 65 |
+
python benchmarks/bench_eagle3.py \
|
| 66 |
+
--model-path /workspace/Qwen3-8B \
|
| 67 |
+
--host 10.1.1.31 \
|
| 68 |
+
--port 30000 \
|
| 69 |
+
--config-list 1,3,1,4 \
|
| 70 |
+
--benchmark-list mtbench:80 \
|
| 71 |
+
--dtype bfloat16 \
|
| 72 |
+
--skip-launch-server \
|
| 73 |
+
--name baseline_mtbench \
|
| 74 |
+
--output-dir ./results \
|
| 75 |
+
2>&1 | tee logs/baseline_mtbench_$(date +%Y%m%d_%H%M%S).log
|
| 76 |
+
|
| 77 |
+
# 2. GSM8K
|
| 78 |
+
echo "=== Running GSM8K (Baseline) ==="
|
| 79 |
+
python benchmarks/bench_eagle3.py \
|
| 80 |
+
--model-path /workspace/Qwen3-8B \
|
| 81 |
+
--host 10.1.1.31 \
|
| 82 |
+
--port 30000 \
|
| 83 |
+
--config-list 1,3,1,4 \
|
| 84 |
+
--benchmark-list gsm8k:100 \
|
| 85 |
+
--dtype bfloat16 \
|
| 86 |
+
--skip-launch-server \
|
| 87 |
+
--name baseline_gsm8k \
|
| 88 |
+
--output-dir ./results \
|
| 89 |
+
2>&1 | tee logs/baseline_gsm8k_$(date +%Y%m%d_%H%M%S).log
|
| 90 |
+
|
| 91 |
+
# 3. HumanEval
|
| 92 |
+
echo "=== Running HumanEval (Baseline) ==="
|
| 93 |
+
python benchmarks/bench_eagle3.py \
|
| 94 |
+
--model-path /workspace/Qwen3-8B \
|
| 95 |
+
--host 10.1.1.31 \
|
| 96 |
+
--port 30000 \
|
| 97 |
+
--config-list 1,3,1,4 \
|
| 98 |
+
--benchmark-list humaneval:164 \
|
| 99 |
+
--dtype bfloat16 \
|
| 100 |
+
--skip-launch-server \
|
| 101 |
+
--name baseline_humaneval \
|
| 102 |
+
--output-dir ./results \
|
| 103 |
+
2>&1 | tee logs/baseline_humaneval_$(date +%Y%m%d_%H%M%S).log
|
| 104 |
+
|
| 105 |
+
echo "=== Baseline 测试完成 ==="
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
---
|
| 109 |
+
|
| 110 |
+
## 2. 测试训练后的模型
|
| 111 |
+
|
| 112 |
+
### 停止 Baseline 服务器并启动训练后的服务器(终端1)
|
| 113 |
+
```bash
|
| 114 |
+
cd /workspace/hanrui/SpecForge-ext
|
| 115 |
+
|
| 116 |
+
# 停止旧服务器
|
| 117 |
+
pkill -f "sglang.launch_server"
|
| 118 |
+
sleep 5
|
| 119 |
+
|
| 120 |
+
# 设置环境变量
|
| 121 |
+
export NO_PROXY="localhost,127.0.0.1,::1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16"
|
| 122 |
+
export no_proxy="localhost,127.0.0.1,::1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16"
|
| 123 |
+
|
| 124 |
+
# 启动训练后的服务器
|
| 125 |
+
python3 -m sglang.launch_server \
|
| 126 |
+
--model /workspace/Qwen3-8B \
|
| 127 |
+
--speculative-algorithm EAGLE3 \
|
| 128 |
+
--speculative-draft-model-path /workspace/hanrui/SpecForge-ext/outputs/qwen3-8b-qwen3eagle-5layer/epoch_9_step_12310 \
|
| 129 |
+
--speculative-num-steps 3 \
|
| 130 |
+
--speculative-eagle-topk 1 \
|
| 131 |
+
--speculative-num-draft-tokens 4 \
|
| 132 |
+
--mem-fraction-static 0.75 \
|
| 133 |
+
--cuda-graph-max-bs 1 \
|
| 134 |
+
--tp 1 \
|
| 135 |
+
--trust-remote-code \
|
| 136 |
+
--host 0.0.0.0 \
|
| 137 |
+
--port 30000 \
|
| 138 |
+
--dtype bfloat16 \
|
| 139 |
+
--skip-server-warmup
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
等待看到 `Application startup complete` 后,继续下一步。
|
| 143 |
+
|
| 144 |
+
### 运行三个 Benchmark(终端2)
|
| 145 |
+
```bash
|
| 146 |
+
cd /workspace/hanrui/SpecForge-ext
|
| 147 |
+
conda activate /workspace/Hanrui/
|
| 148 |
+
|
| 149 |
+
# 设置环境变量
|
| 150 |
+
export NO_PROXY="localhost,127.0.0.1,::1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16"
|
| 151 |
+
export no_proxy="localhost,127.0.0.1,::1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16"
|
| 152 |
+
|
| 153 |
+
# 1. MT-Bench
|
| 154 |
+
echo "=== Running MT-Bench (Trained) ==="
|
| 155 |
+
python benchmarks/bench_eagle3.py \
|
| 156 |
+
--model-path /workspace/Qwen3-8B \
|
| 157 |
+
--host 10.1.1.31 \
|
| 158 |
+
--port 30000 \
|
| 159 |
+
--config-list 1,3,1,4 \
|
| 160 |
+
--benchmark-list mtbench:80 \
|
| 161 |
+
--dtype bfloat16 \
|
| 162 |
+
--skip-launch-server \
|
| 163 |
+
--name trained_mtbench \
|
| 164 |
+
--output-dir ./results \
|
| 165 |
+
2>&1 | tee logs/trained_mtbench_$(date +%Y%m%d_%H%M%S).log
|
| 166 |
+
|
| 167 |
+
# 2. GSM8K
|
| 168 |
+
echo "=== Running GSM8K (Trained) ==="
|
| 169 |
+
python benchmarks/bench_eagle3.py \
|
| 170 |
+
--model-path /workspace/Qwen3-8B \
|
| 171 |
+
--host 10.1.1.31 \
|
| 172 |
+
--port 30000 \
|
| 173 |
+
--config-list 1,3,1,4 \
|
| 174 |
+
--benchmark-list gsm8k:100 \
|
| 175 |
+
--dtype bfloat16 \
|
| 176 |
+
--skip-launch-server \
|
| 177 |
+
--name trained_gsm8k \
|
| 178 |
+
--output-dir ./results \
|
| 179 |
+
2>&1 | tee logs/trained_gsm8k_$(date +%Y%m%d_%H%M%S).log
|
| 180 |
+
|
| 181 |
+
# 3. HumanEval
|
| 182 |
+
echo "=== Running HumanEval (Trained) ==="
|
| 183 |
+
python benchmarks/bench_eagle3.py \
|
| 184 |
+
--model-path /workspace/Qwen3-8B \
|
| 185 |
+
--host 10.1.1.31 \
|
| 186 |
+
--port 30000 \
|
| 187 |
+
--config-list 1,3,1,4 \
|
| 188 |
+
--benchmark-list humaneval:164 \
|
| 189 |
+
--dtype bfloat16 \
|
| 190 |
+
--skip-launch-server \
|
| 191 |
+
--name trained_humaneval \
|
| 192 |
+
--output-dir ./results \
|
| 193 |
+
2>&1 | tee logs/trained_humaneval_$(date +%Y%m%d_%H%M%S).log
|
| 194 |
+
|
| 195 |
+
echo "=== Trained 测试完成 ==="
|
| 196 |
+
```
|
| 197 |
+
|
| 198 |
+
---
|
| 199 |
+
|
| 200 |
+
## 3. 查看结果
|
| 201 |
+
|
| 202 |
+
### 日志文件位置
|
| 203 |
+
所有日志保存在:`/workspace/hanrui/SpecForge-ext/logs/`
|
| 204 |
+
- `baseline_mtbench_*.log`
|
| 205 |
+
- `baseline_gsm8k_*.log`
|
| 206 |
+
- `baseline_humaneval_*.log`
|
| 207 |
+
- `trained_mtbench_*.log`
|
| 208 |
+
- `trained_gsm8k_*.log`
|
| 209 |
+
- `trained_humaneval_*.log`
|
| 210 |
+
|
| 211 |
+
所有结果保存在:`/workspace/hanrui/SpecForge-ext/results/`
|
| 212 |
+
- `baseline_mtbench_*.jsonl`
|
| 213 |
+
- `baseline_gsm8k_*.jsonl`
|
| 214 |
+
- `baseline_humaneval_*.jsonl`
|
| 215 |
+
- `trained_mtbench_*.jsonl`
|
| 216 |
+
- `trained_gsm8k_*.jsonl`
|
| 217 |
+
- `trained_humaneval_*.jsonl`
|
| 218 |
+
|
| 219 |
+
### 生成对比报告
|
| 220 |
+
```bash
|
| 221 |
+
cd /workspace/hanrui/SpecForge-ext
|
| 222 |
+
|
| 223 |
+
python3 << 'EOF'
|
| 224 |
+
import json
|
| 225 |
+
import glob
|
| 226 |
+
|
| 227 |
+
print("=" * 80)
|
| 228 |
+
print("Accept Length 对比报告")
|
| 229 |
+
print("=" * 80)
|
| 230 |
+
|
| 231 |
+
datasets = ['mtbench', 'gsm8k', 'humaneval']
|
| 232 |
+
|
| 233 |
+
for dataset in datasets:
|
| 234 |
+
print(f"\n{'=' * 80}")
|
| 235 |
+
print(f"{dataset.upper()} 结果对比")
|
| 236 |
+
print('=' * 80)
|
| 237 |
+
|
| 238 |
+
baseline_files = sorted(glob.glob(f'results/baseline_{dataset}_*.jsonl'))
|
| 239 |
+
trained_files = sorted(glob.glob(f'results/trained_{dataset}_*.jsonl'))
|
| 240 |
+
|
| 241 |
+
if not baseline_files or not trained_files:
|
| 242 |
+
print(f" 未找到 {dataset} 的结果文件")
|
| 243 |
+
continue
|
| 244 |
+
|
| 245 |
+
with open(baseline_files[-1], 'r') as f:
|
| 246 |
+
baseline = json.load(f)
|
| 247 |
+
|
| 248 |
+
with open(trained_files[-1], 'r') as f:
|
| 249 |
+
trained = json.load(f)
|
| 250 |
+
|
| 251 |
+
baseline_metrics = baseline[dataset][0]['metrics'][0]
|
| 252 |
+
trained_metrics = trained[dataset][0]['metrics'][0]
|
| 253 |
+
|
| 254 |
+
print(f"\nBaseline:")
|
| 255 |
+
print(f" Accept Length: {baseline_metrics['accept_length']:.4f}")
|
| 256 |
+
print(f" Output Throughput: {baseline_metrics['output_throughput']:.2f} tokens/s")
|
| 257 |
+
if 'accuracy' in baseline_metrics and baseline_metrics['accuracy'] is not None:
|
| 258 |
+
print(f" Accuracy: {baseline_metrics['accuracy']:.2%}")
|
| 259 |
+
|
| 260 |
+
print(f"\nTrained:")
|
| 261 |
+
print(f" Accept Length: {trained_metrics['accept_length']:.4f}")
|
| 262 |
+
print(f" Output Throughput: {trained_metrics['output_throughput']:.2f} tokens/s")
|
| 263 |
+
if 'accuracy' in trained_metrics and trained_metrics['accuracy'] is not None:
|
| 264 |
+
print(f" Accuracy: {trained_metrics['accuracy']:.2%}")
|
| 265 |
+
|
| 266 |
+
accept_diff = trained_metrics['accept_length'] - baseline_metrics['accept_length']
|
| 267 |
+
accept_pct = (accept_diff / baseline_metrics['accept_length']) * 100
|
| 268 |
+
|
| 269 |
+
throughput_diff = trained_metrics['output_throughput'] - baseline_metrics['output_throughput']
|
| 270 |
+
throughput_pct = (throughput_diff / baseline_metrics['output_throughput']) * 100
|
| 271 |
+
|
| 272 |
+
print(f"\n差异:")
|
| 273 |
+
print(f" Accept Length: {accept_diff:+.4f} ({accept_pct:+.2f}%)")
|
| 274 |
+
print(f" Throughput: {throughput_diff:+.2f} tokens/s ({throughput_pct:+.2f}%)")
|
| 275 |
+
|
| 276 |
+
if 'accuracy' in baseline_metrics and baseline_metrics['accuracy'] is not None:
|
| 277 |
+
acc_diff = trained_metrics['accuracy'] - baseline_metrics['accuracy']
|
| 278 |
+
acc_pct = acc_diff * 100
|
| 279 |
+
print(f" Accuracy: {acc_pct:+.2f} percentage points")
|
| 280 |
+
|
| 281 |
+
print("\n" + "=" * 80)
|
| 282 |
+
EOF
|
| 283 |
+
```
|
| 284 |
+
|
| 285 |
+
---
|
| 286 |
+
|
| 287 |
+
## 4. 快速查看单个结果
|
| 288 |
+
```bash
|
| 289 |
+
cd /workspace/hanrui/SpecForge-ext
|
| 290 |
+
|
| 291 |
+
# 查看 baseline 的 accept_length
|
| 292 |
+
cat results/baseline_mtbench_*.jsonl | jq '.mtbench[0].metrics[0].accept_length'
|
| 293 |
+
cat results/baseline_gsm8k_*.jsonl | jq '.gsm8k[0].metrics[0].accept_length'
|
| 294 |
+
cat results/baseline_humaneval_*.jsonl | jq '.humaneval[0].metrics[0].accept_length'
|
| 295 |
+
|
| 296 |
+
# 查看 trained 的 accept_length
|
| 297 |
+
cat results/trained_mtbench_*.jsonl | jq '.mtbench[0].metrics[0].accept_length'
|
| 298 |
+
cat results/trained_gsm8k_*.jsonl | jq '.gsm8k[0].metrics[0].accept_length'
|
| 299 |
+
cat results/trained_humaneval_*.jsonl | jq '.humaneval[0].metrics[0].accept_length'
|
| 300 |
+
```
|
SpecForge/.editorconfig
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://editorconfig.org/
|
| 2 |
+
|
| 3 |
+
root = true
|
| 4 |
+
|
| 5 |
+
[*]
|
| 6 |
+
charset = utf-8
|
| 7 |
+
end_of_line = lf
|
| 8 |
+
indent_style = space
|
| 9 |
+
indent_size = 4
|
| 10 |
+
trim_trailing_whitespace = true
|
| 11 |
+
insert_final_newline = true
|
| 12 |
+
|
| 13 |
+
[*.{json,yaml,yml}]
|
| 14 |
+
indent_size = 2
|
| 15 |
+
|
| 16 |
+
[*.md]
|
| 17 |
+
indent_size = 2
|
| 18 |
+
x-soft-wrap-text = true
|
| 19 |
+
|
| 20 |
+
[*.rst]
|
| 21 |
+
indent_size = 4
|
| 22 |
+
x-soft-wrap-text = true
|
| 23 |
+
|
| 24 |
+
[Makefile]
|
| 25 |
+
indent_style = tab
|
SpecForge/.isort.cfg
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[settings]
|
| 2 |
+
profile=black
|
| 3 |
+
known_first_party=sgl-eagle
|
SpecForge/.pre-commit-config.yaml
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
default_stages: [pre-commit, pre-push, manual]
|
| 2 |
+
|
| 3 |
+
repos:
|
| 4 |
+
- repo: https://github.com/PyCQA/autoflake
|
| 5 |
+
rev: v2.3.1
|
| 6 |
+
hooks:
|
| 7 |
+
- id: autoflake
|
| 8 |
+
args: [--remove-all-unused-imports, --in-place]
|
| 9 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
| 10 |
+
rev: v5.0.0
|
| 11 |
+
hooks:
|
| 12 |
+
- id: check-symlinks
|
| 13 |
+
- id: destroyed-symlinks
|
| 14 |
+
- id: trailing-whitespace
|
| 15 |
+
- id: end-of-file-fixer
|
| 16 |
+
- id: check-yaml
|
| 17 |
+
args: [--allow-multiple-documents]
|
| 18 |
+
- id: check-toml
|
| 19 |
+
- id: check-ast
|
| 20 |
+
- id: check-added-large-files
|
| 21 |
+
- id: check-merge-conflict
|
| 22 |
+
- id: check-shebang-scripts-are-executable
|
| 23 |
+
- id: detect-private-key
|
| 24 |
+
- id: debug-statements
|
| 25 |
+
- id: no-commit-to-branch
|
| 26 |
+
- repo: https://github.com/PyCQA/isort
|
| 27 |
+
rev: 5.13.2
|
| 28 |
+
hooks:
|
| 29 |
+
- id: isort
|
| 30 |
+
- repo: https://github.com/astral-sh/ruff-pre-commit
|
| 31 |
+
rev: v0.11.10
|
| 32 |
+
hooks:
|
| 33 |
+
- id: ruff
|
| 34 |
+
args: [--select=F401, --fixable=F401]
|
| 35 |
+
files: ^(benchmark/|docs/|examples/)
|
| 36 |
+
exclude: \.ipynb$
|
| 37 |
+
- repo: https://github.com/psf/black
|
| 38 |
+
rev: 24.10.0
|
| 39 |
+
hooks:
|
| 40 |
+
- id: black-jupyter
|
| 41 |
+
- repo: https://github.com/pre-commit/mirrors-clang-format
|
| 42 |
+
rev: v18.1.8
|
| 43 |
+
hooks:
|
| 44 |
+
- id: clang-format
|
| 45 |
+
types_or: [c++, cuda]
|
| 46 |
+
args: [--style=file, --verbose]
|
| 47 |
+
- repo: https://github.com/kynan/nbstripout
|
| 48 |
+
rev: 0.8.1
|
| 49 |
+
hooks:
|
| 50 |
+
- id: nbstripout
|
| 51 |
+
args:
|
| 52 |
+
- '--keep-output'
|
| 53 |
+
- '--extra-keys=metadata.kernelspec metadata.language_info.version'
|
SpecForge/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 sgl-project
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
SpecForge/MANIFEST.in
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
include requirements.txt
|
| 2 |
+
include version.txt
|
SpecForge/README.md
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div align="center" id="sglangtop">
|
| 2 |
+
<img src="./assets/logo.png" alt="logo" width="400" margin="10px"></img>
|
| 3 |
+
|
| 4 |
+
[](https://docs.sglang.ai/SpecForge/)
|
| 5 |
+
[](https://huggingface.co/collections/lmsys/specbundle)
|
| 6 |
+
[](https://deepwiki.com/sgl-project/SpecForge)
|
| 7 |
+
|
| 8 |
+
[](https://lmsys.org/blog/2025-07-25-spec-forge/)
|
| 9 |
+
[](https://sgl-fru7574.slack.com/archives/C09784E3EN6)
|
| 10 |
+
[](./LICENSE)
|
| 11 |
+
|
| 12 |
+
</div>
|
| 13 |
+
|
| 14 |
+
## 📍 Overview
|
| 15 |
+
|
| 16 |
+
SpecForge is an ecosystem project developed by the SGLang team. It is a framework for training speculative decoding models so that you can smoothly port them over to the SGLang serving framework to speed up your inference.
|
| 17 |
+
|
| 18 |
+
We have seen many open-source projects for speculative decoding, but most of them are not well-maintained or not directly compatible with SGLang. We prepared this project because we wish that the open-source community can enjoy a speculative decoding framework that is
|
| 19 |
+
- regularly maintained by the SpecForge team: the code is runnable out-of-the-box
|
| 20 |
+
directly compatible with SGLang: no additional effort is needed to port models to SGLang
|
| 21 |
+
provides performant training capabilities: we provide online/offline/tensor-parallel/FSDP modes to suit your needs
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
Check out [**our documentation**](https://docs.sglang.ai/SpecForge/) to get started.
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
## 🚀 Accelerate with SpecBundle
|
| 28 |
+
|
| 29 |
+
SpecBundle is a collection of production-grade speculative decoding models that are released by the SpecForge team and our industry partners. They provide higher acceptance rate compared to the existing open-source checkpoints over a wide range of domains. Together with SGLang, you can experience up to 4x speedup for inference. Check out our resources below:
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
| Item | Link |
|
| 33 |
+
| --- | --- |
|
| 34 |
+
| 📝 Documentation | [Link](https://docs.sglang.ai/SpecForge/community_resources/specbundle.html) |
|
| 35 |
+
| 📊 Performance Dashboard | [Link](https://docs.sglang.ai/SpecForge/SpecBundle/index.html) |
|
| 36 |
+
| 🤗 Hugging Face Collection | [Link](https://huggingface.co/collections/lmsys/specbundle) |
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
## 🎉 News
|
| 40 |
+
|
| 41 |
+
- [2025-12] 🎉 Released SpecBundle (phase 1) and SpecForge v0.2. Check out our blog at [LMSYS.org](https://lmsys.org/blog/2025-12-23-spec-bundle-phase-1/)
|
| 42 |
+
- [2025-12] 🔔 Released the roadmap for 2026 Q1.
|
| 43 |
+
- [2025-08] 🔔 SpecForge is listed as a [flagship project](https://lmsys.org/about/) in LMSYS. Congratulations to the SpecForge team!
|
| 44 |
+
- [2025-08] 🔥 SpecForge powered the Eagle3 draft model for GPT-OSS. Check out the blog at [LMSYS.org](https://lmsys.org/blog/2025-08-27-gpt-oss/)
|
| 45 |
+
- [2025-07] 🔥 SpecForge is released together with Llama4-Eagle3 checkpoints. Check out our blog at [LMSYS.org](https://lmsys.org/blog/2025-07-25-spec-forge/)
|
| 46 |
+
|
| 47 |
+
## ✨ Acknowledgements
|
| 48 |
+
|
| 49 |
+
<img src="./assets/acknowledgements.png" alt="acknowledgements"></img>
|
| 50 |
+
|
| 51 |
+
We would like to express our sincere gratitude to the official EAGLE team, especially Hongyang Zhang and Yuhui Li, for their invaluable contributions and support. Our thanks also go to the NVIDIA team—particularly Avery H and Izzy Putterman—and to the Google team, especially Ying Wang, for their insightful discussions and generous assistance throughout the project.
|
| 52 |
+
|
| 53 |
+
We are especially grateful to Meituan for their strong backing and meaningful contributions, which played a vital role in driving this project forward.
|
| 54 |
+
|
| 55 |
+
This project has also been inspired by many outstanding open-source projects from the LLM community, including [EAGLE](https://github.com/SafeAILab/EAGLE), [BaldEagle](https://github.com/NickL77/BaldEagle), and [TensorRT-Model-Optimizer](https://github.com/NVIDIA/TensorRT-Model-Optimizer) and others. Their contributions and shared knowledge have greatly benefited our work.
|
| 56 |
+
|
| 57 |
+
## 💡 Special Thanks to Voltage Park
|
| 58 |
+
|
| 59 |
+
We would like to extend our sincere thanks to [Voltage Park](https://www.voltagepark.com/), our official infrastructure partner. As part of a formal collaboration with the SGLang team, Voltage Park provided critical GPU resources that empowered us to train and evaluate large-scale speculative decoding models efficiently and reliably. This partnership was instrumental in making SpecForge possible. We deeply appreciate Voltage Park’s mission to make cutting-edge AI infrastructure more accessible, and we look forward to continued collaboration as we push the boundaries of open-source LLM serving and optimization.
|
| 60 |
+
|
| 61 |
+
## 📃 Citation
|
| 62 |
+
|
| 63 |
+
```bibtex
|
| 64 |
+
@misc{specforge2025,
|
| 65 |
+
title={SpecForge: Train speculative decoding models effortlessly},
|
| 66 |
+
author={Shenggui Li and Yikai Zhu and Chao Wang and Fan Yin and Shuai Shi and Yubo Wang and Yi Zhang and Yingyi Huang and Haoshuai Zheng and Yineng Zhang},
|
| 67 |
+
year={2025},
|
| 68 |
+
publisher={GitHub},
|
| 69 |
+
howpublished={\url{https://github.com/sgl-project/specforge}},
|
| 70 |
+
}
|
SpecForge/pyproject.toml
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=61.0", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "specforge"
|
| 7 |
+
dynamic = ["version"]
|
| 8 |
+
readme = "README.md"
|
| 9 |
+
requires-python = ">=3.11"
|
| 10 |
+
description = "SpecForge: Speculative Decoding Training Framework"
|
| 11 |
+
authors = [{name = "SGLang Team"}]
|
| 12 |
+
urls = {Homepage = "https://github.com/sgl-project/SpecForge"}
|
| 13 |
+
dependencies = [
|
| 14 |
+
"pre-commit",
|
| 15 |
+
"torch==2.9.1",
|
| 16 |
+
"torchaudio==2.9.1",
|
| 17 |
+
"torchvision==0.24.1",
|
| 18 |
+
"transformers==4.57.1",
|
| 19 |
+
"qwen-vl-utils==0.0.11",
|
| 20 |
+
"datasets",
|
| 21 |
+
"setuptools",
|
| 22 |
+
"tqdm",
|
| 23 |
+
"wandb",
|
| 24 |
+
"psutil",
|
| 25 |
+
"numpy",
|
| 26 |
+
"accelerate",
|
| 27 |
+
"pydantic",
|
| 28 |
+
"sglang==0.5.9",
|
| 29 |
+
"openai-harmony",
|
| 30 |
+
"ninja",
|
| 31 |
+
"packaging",
|
| 32 |
+
"yunchang",
|
| 33 |
+
"tensorboard",
|
| 34 |
+
]
|
| 35 |
+
|
| 36 |
+
[tool.setuptools.packages.find]
|
| 37 |
+
exclude = ["configs*", "scripts*", "tests*"]
|
| 38 |
+
|
| 39 |
+
[project.optional-dependencies]
|
| 40 |
+
dev = [
|
| 41 |
+
"pre-commit",
|
| 42 |
+
"unittest"
|
| 43 |
+
]
|
| 44 |
+
fa = ["flash-attn"]
|
| 45 |
+
|
| 46 |
+
[tool.setuptools.dynamic]
|
| 47 |
+
version = {file = "version.txt"}
|
SpecForge/requirements-rocm.txt
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Use the PyTorch ROCm wheel index (choose the stream that matches your system)
|
| 2 |
+
--extra-index-url https://download.pytorch.org/whl/rocm6.3
|
| 3 |
+
|
| 4 |
+
pre-commit
|
| 5 |
+
torch==2.8.0+rocm6.3
|
| 6 |
+
torchaudio==2.8.0+rocm6.3
|
| 7 |
+
torchvision==0.23.0+rocm6.3
|
| 8 |
+
transformers==4.57.1
|
| 9 |
+
qwen-vl-utils==0.0.11
|
| 10 |
+
datasets
|
| 11 |
+
setuptools
|
| 12 |
+
tqdm
|
| 13 |
+
wandb
|
| 14 |
+
psutil
|
| 15 |
+
numpy
|
| 16 |
+
accelerate
|
| 17 |
+
pydantic
|
| 18 |
+
sglang[all]==0.5.4
|
| 19 |
+
openai-harmony
|
| 20 |
+
tensorboard
|
SpecForge/version.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
0.2.0
|
idea1/.editorconfig
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://editorconfig.org/
|
| 2 |
+
|
| 3 |
+
root = true
|
| 4 |
+
|
| 5 |
+
[*]
|
| 6 |
+
charset = utf-8
|
| 7 |
+
end_of_line = lf
|
| 8 |
+
indent_style = space
|
| 9 |
+
indent_size = 4
|
| 10 |
+
trim_trailing_whitespace = true
|
| 11 |
+
insert_final_newline = true
|
| 12 |
+
|
| 13 |
+
[*.{json,yaml,yml}]
|
| 14 |
+
indent_size = 2
|
| 15 |
+
|
| 16 |
+
[*.md]
|
| 17 |
+
indent_size = 2
|
| 18 |
+
x-soft-wrap-text = true
|
| 19 |
+
|
| 20 |
+
[*.rst]
|
| 21 |
+
indent_size = 4
|
| 22 |
+
x-soft-wrap-text = true
|
| 23 |
+
|
| 24 |
+
[Makefile]
|
| 25 |
+
indent_style = tab
|
idea1/.isort.cfg
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[settings]
|
| 2 |
+
profile=black
|
| 3 |
+
known_first_party=sgl-eagle
|
idea1/.pre-commit-config.yaml
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
default_stages: [pre-commit, pre-push, manual]
|
| 2 |
+
|
| 3 |
+
repos:
|
| 4 |
+
- repo: https://github.com/PyCQA/autoflake
|
| 5 |
+
rev: v2.3.1
|
| 6 |
+
hooks:
|
| 7 |
+
- id: autoflake
|
| 8 |
+
args: [--remove-all-unused-imports, --in-place]
|
| 9 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
| 10 |
+
rev: v5.0.0
|
| 11 |
+
hooks:
|
| 12 |
+
- id: check-symlinks
|
| 13 |
+
- id: destroyed-symlinks
|
| 14 |
+
- id: trailing-whitespace
|
| 15 |
+
- id: end-of-file-fixer
|
| 16 |
+
- id: check-yaml
|
| 17 |
+
args: [--allow-multiple-documents]
|
| 18 |
+
- id: check-toml
|
| 19 |
+
- id: check-ast
|
| 20 |
+
- id: check-added-large-files
|
| 21 |
+
- id: check-merge-conflict
|
| 22 |
+
- id: check-shebang-scripts-are-executable
|
| 23 |
+
- id: detect-private-key
|
| 24 |
+
- id: debug-statements
|
| 25 |
+
- id: no-commit-to-branch
|
| 26 |
+
- repo: https://github.com/PyCQA/isort
|
| 27 |
+
rev: 5.13.2
|
| 28 |
+
hooks:
|
| 29 |
+
- id: isort
|
| 30 |
+
- repo: https://github.com/astral-sh/ruff-pre-commit
|
| 31 |
+
rev: v0.11.10
|
| 32 |
+
hooks:
|
| 33 |
+
- id: ruff
|
| 34 |
+
args: [--select=F401, --fixable=F401]
|
| 35 |
+
files: ^(benchmark/|docs/|examples/)
|
| 36 |
+
exclude: \.ipynb$
|
| 37 |
+
- repo: https://github.com/psf/black
|
| 38 |
+
rev: 24.10.0
|
| 39 |
+
hooks:
|
| 40 |
+
- id: black-jupyter
|
| 41 |
+
- repo: https://github.com/pre-commit/mirrors-clang-format
|
| 42 |
+
rev: v18.1.8
|
| 43 |
+
hooks:
|
| 44 |
+
- id: clang-format
|
| 45 |
+
types_or: [c++, cuda]
|
| 46 |
+
args: [--style=file, --verbose]
|
| 47 |
+
- repo: https://github.com/kynan/nbstripout
|
| 48 |
+
rev: 0.8.1
|
| 49 |
+
hooks:
|
| 50 |
+
- id: nbstripout
|
| 51 |
+
args:
|
| 52 |
+
- '--keep-output'
|
| 53 |
+
- '--extra-keys=metadata.kernelspec metadata.language_info.version'
|
idea1/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 sgl-project
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
idea1/requirements-rocm.txt
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Use the PyTorch ROCm wheel index (choose the stream that matches your system)
|
| 2 |
+
--extra-index-url https://download.pytorch.org/whl/rocm6.3
|
| 3 |
+
|
| 4 |
+
pre-commit
|
| 5 |
+
torch==2.8.0+rocm6.3
|
| 6 |
+
torchaudio==2.8.0+rocm6.3
|
| 7 |
+
torchvision==0.23.0+rocm6.3
|
| 8 |
+
transformers==4.57.1
|
| 9 |
+
qwen-vl-utils==0.0.11
|
| 10 |
+
datasets
|
| 11 |
+
setuptools
|
| 12 |
+
tqdm
|
| 13 |
+
wandb
|
| 14 |
+
psutil
|
| 15 |
+
numpy
|
| 16 |
+
accelerate
|
| 17 |
+
pydantic
|
| 18 |
+
sglang[all]==0.5.4
|
| 19 |
+
openai-harmony
|
| 20 |
+
tensorboard
|
idea1/version.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
0.2.0
|
qwen3-8b_dflash_regen/.gitattributes
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
sharegpt_train_regenerated.jsonl filter=lfs diff=lfs merge=lfs -text
|
syxin/backup.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
syxin/dflash_lora_changelog.md
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DFlash LoRA 全部改动记录
|
| 2 |
+
|
| 3 |
+
## 概述
|
| 4 |
+
|
| 5 |
+
为了让 Qwen3-8B DFlash LoRA 训练在 2×H100 上跑通(解决 OOM),共新增/修改了 **5 个文件,1084 行代码**。改动分为两大阶段:基础搭建 + OOM 修复。
|
| 6 |
+
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
## 新增文件清单
|
| 10 |
+
|
| 11 |
+
| 文件 | 行数 | 用途 |
|
| 12 |
+
|------|------|------|
|
| 13 |
+
| `specforge/core/dflash_lora.py` | 453 | 训练 wrapper(OnlineDFlashLoRAModel) |
|
| 14 |
+
| `specforge/modeling/draft/dflash_lora.py` | 141 | LoRA draft 模型(DFlashLoRADraftModel) |
|
| 15 |
+
| `scripts/train_dflash_lora.py` | 449 | 训练入口脚本 |
|
| 16 |
+
| `scripts/run_train_dflash_lora.sh` | 31 | 启动 shell 脚本 |
|
| 17 |
+
| `configs/qwen3-8b-dflash-lora.json` | 10 | LoRA 配置文件 |
|
| 18 |
+
|
| 19 |
+
---
|
| 20 |
+
|
| 21 |
+
## Step 1 完成过程
|
| 22 |
+
|
| 23 |
+
### 1.1 分析现有代码
|
| 24 |
+
|
| 25 |
+
首先分析了非 LoRA 版 `train_dflash.py` 的完整流程:
|
| 26 |
+
|
| 27 |
+
```
|
| 28 |
+
input_ids → target_model.generate_dflash_data() → hidden_states
|
| 29 |
+
→ OnlineDFlashModel.forward():
|
| 30 |
+
1. 截断到 block 边界
|
| 31 |
+
2. prepare_noise_input(): anchor 保留,其余 → MASK
|
| 32 |
+
3. embed_tokens(noise_input_ids) → noise_embedding
|
| 33 |
+
4. 构建 DFlash attention mask
|
| 34 |
+
5. draft_model(noise_embedding, target_hidden, mask)
|
| 35 |
+
6. lm_head(hidden) → logits → CE loss
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
非 LoRA 版使用独立的小型 draft model + 冻结 target model 提取 hidden states。
|
| 39 |
+
|
| 40 |
+
### 1.2 确定 LoRA 版设计差异
|
| 41 |
+
|
| 42 |
+
| 方面 | 非 LoRA 版 (`train_dflash.py`) | LoRA 版 (`train_dflash_lora.py`) |
|
| 43 |
+
|------|------|------|
|
| 44 |
+
| Draft model | 自定义小模型 (1-10 层) | Qwen3-8B + PEFT LoRA |
|
| 45 |
+
| Target model | 冻结大模型提取 hidden states | 无需 — 模型用自身表征 |
|
| 46 |
+
| Attention | 自定义 Qwen3DFlashAttention,KV = [ctx, noise] concat | 标准 HF attention + DFlash mask |
|
| 47 |
+
| KV 结构 | Q_LEN = noise_len, KV_LEN = 2×noise_len | Q_LEN = KV_LEN = seq_len |
|
| 48 |
+
| 可训练参数 | 全部 draft model 参数 | 仅 LoRA (q/k/v/o_proj) |
|
| 49 |
+
|
| 50 |
+
### 1.3 新建 LoRA 版三个核心文件
|
| 51 |
+
|
| 52 |
+
#### `specforge/modeling/draft/dflash_lora.py` — DFlashLoRADraftModel
|
| 53 |
+
|
| 54 |
+
- `from_pretrained()`: 加载 Qwen3-8B,注入 PEFT LoRA,支持 `attn_implementation` 参数
|
| 55 |
+
- `forward()`: 标准 HF forward,支持 `output_hidden_states` 参数(chunked loss 需要)
|
| 56 |
+
- `get_lm_head()`: 穿透 PEFT 层级获取 lm_head 引用
|
| 57 |
+
- `gradient_checkpointing_enable()`: 代理到底层模型
|
| 58 |
+
- `save_pretrained()`: 仅保存 LoRA adapter 权重
|
| 59 |
+
|
| 60 |
+
#### `specforge/core/dflash_lora.py` — OnlineDFlashLoRAModel
|
| 61 |
+
|
| 62 |
+
- `prepare_noise_input()`: context 部分保持不变,block 部分只保留 anchor,其余替换为 MASK
|
| 63 |
+
- `build_dflash_full_attn_mask_fast()`: 向量化构建 4D additive mask `[bsz, 1, seq, seq]`
|
| 64 |
+
- `_compute_loss_weights()`: context + anchor 权重为 0,非 anchor 权重为 1(或 decay)
|
| 65 |
+
- `_full_lm_loss()`: 标准 CE loss 路径
|
| 66 |
+
- `_compute_accuracy()`: block-wise acceptance rate(累积正确预测长度 / block 非 anchor 长度)
|
| 67 |
+
- `forward()`: 完整训练 forward pass
|
| 68 |
+
|
| 69 |
+
LoRA 版 mask 规则:
|
| 70 |
+
- context token i → 因果注意力 (j ≤ i)
|
| 71 |
+
- block token i (属于 block b) → 所有 context + 同 block 内双向注意力
|
| 72 |
+
|
| 73 |
+
#### `scripts/train_dflash_lora.py` — 训练脚本
|
| 74 |
+
|
| 75 |
+
- 参数解析:model/lora/dataset/training/output/distributed/tracker 7 组参数
|
| 76 |
+
- `build_model()`: 加载模型 + 注入 LoRA + 包装 OnlineDFlashLoRAModel
|
| 77 |
+
- `build_dataloader()`: 复用 `build_eagle3_dataset` 和 `prepare_dp_dataloaders`
|
| 78 |
+
- FSDP 包装 + BF16Optimizer
|
| 79 |
+
- 训练循环:forward → backward → accumulation → optimizer step
|
| 80 |
+
- checkpoint 保存/恢复
|
| 81 |
+
|
| 82 |
+
---
|
| 83 |
+
|
| 84 |
+
## OOM 修复改动(4 项)
|
| 85 |
+
|
| 86 |
+
### 改动 1: FSDP FULL_SHARD (ZeRO-3)
|
| 87 |
+
|
| 88 |
+
**问题**: `SHARD_GRAD_OP` (ZeRO-2) 每卡持有完整 Qwen3-8B 参数 (~16GB bf16)
|
| 89 |
+
|
| 90 |
+
**修复**: `train_dflash_lora.py:362`
|
| 91 |
+
```python
|
| 92 |
+
# 之前
|
| 93 |
+
sharding_strategy=ShardingStrategy.SHARD_GRAD_OP
|
| 94 |
+
# 之后
|
| 95 |
+
sharding_strategy=ShardingStrategy.FULL_SHARD
|
| 96 |
+
```
|
| 97 |
+
|
| 98 |
+
**效果**: 参数跨卡分片,每卡省 ~8-12GB
|
| 99 |
+
|
| 100 |
+
### 改动 2: batch_size=1 + accumulation_steps=8
|
| 101 |
+
|
| 102 |
+
**问题**: `batch_size=2` 时峰值显存过高
|
| 103 |
+
|
| 104 |
+
**修复**: `run_train_dflash_lora.sh`
|
| 105 |
+
```bash
|
| 106 |
+
--batch-size 1 \
|
| 107 |
+
--accumulation-steps 8 \
|
| 108 |
+
```
|
| 109 |
+
|
| 110 |
+
**效果**: 等效 global batch size 不变,峰值显存减半
|
| 111 |
+
|
| 112 |
+
### 改动 3: flex_attention + BlockMask 替换 4D additive mask
|
| 113 |
+
|
| 114 |
+
**问题**: SDPA 不支持 4D additive mask → fallback 到 math backend → 每层 materialize 完整 `[bsz, 32heads, 2048, 2048]` attention scores
|
| 115 |
+
|
| 116 |
+
**修复**: 从非 LoRA 版 `dflash.py` 移植 `_get_or_create_block_mask()` 方法,适配 LoRA 场景
|
| 117 |
+
|
| 118 |
+
涉及文件:
|
| 119 |
+
|
| 120 |
+
1. **`specforge/core/dflash_lora.py`**
|
| 121 |
+
- `__init__()`: 添加 `attention_backend` 参数(默认 `"flex_attention"`),BlockMask 缓存字段
|
| 122 |
+
- 新增 `_get_or_create_block_mask()`: 用 `create_block_mask()` 构建零显存的 BlockMask
|
| 123 |
+
- `forward()`: 根据 `attention_backend` 选择 BlockMask 或 additive mask
|
| 124 |
+
|
| 125 |
+
2. **`specforge/modeling/draft/dflash_lora.py`**
|
| 126 |
+
- `from_pretrained()`: 当 backend 为 flex_attention 时,传 `attn_implementation="flex_attention"` 给 HuggingFace
|
| 127 |
+
|
| 128 |
+
3. **`scripts/train_dflash_lora.py`**
|
| 129 |
+
- `parse_args()`: `--attention-backend` 参数 (`flex_attention` | `additive`)
|
| 130 |
+
- `build_model()`: 根据 backend 选择 `attn_implementation`
|
| 131 |
+
|
| 132 |
+
BlockMask mask function(LoRA 版):
|
| 133 |
+
```python
|
| 134 |
+
def dflash_lora_mask_fn(b, h, q_idx, kv_idx):
|
| 135 |
+
# Context query: 标准因果
|
| 136 |
+
is_q_ctx = q_idx < context_len
|
| 137 |
+
ctx_visible = is_q_ctx & (kv_idx <= q_idx)
|
| 138 |
+
|
| 139 |
+
# Block query: 全部 context + 同 block 双向
|
| 140 |
+
is_q_block = q_idx >= context_len
|
| 141 |
+
is_k_ctx = kv_idx < context_len
|
| 142 |
+
q_block_id = (q_idx - context_len) // block_size
|
| 143 |
+
k_block_id = (kv_idx - context_len) // block_size
|
| 144 |
+
block_attend_ctx = is_q_block & is_k_ctx
|
| 145 |
+
block_attend_same = is_q_block & (~is_k_ctx) & (q_block_id == k_block_id)
|
| 146 |
+
|
| 147 |
+
return ctx_visible | (block_attend_ctx | block_attend_same)
|
| 148 |
+
```
|
| 149 |
+
|
| 150 |
+
**验证**: 手动逐元素对比 BlockMask 和 additive mask 输出,三组测试 (context_len=4/0, seq=12/16/64) pattern 完全一致。
|
| 151 |
+
|
| 152 |
+
**效果**: 不再 fallback 到 SDPA math backend,省去 `[bsz, heads, seq, seq]` attention scores 显存
|
| 153 |
+
|
| 154 |
+
### 改动 4: chunked cross-entropy loss
|
| 155 |
+
|
| 156 |
+
**问题**: `[bsz, 2048, 151936]` bf16 logits ≈ 1.18GB,加梯度 ~2.4GB+
|
| 157 |
+
|
| 158 |
+
**修复**: 从非 LoRA 版 `dflash.py:419-478` 移植 chunked loss
|
| 159 |
+
|
| 160 |
+
涉及文件:
|
| 161 |
+
|
| 162 |
+
1. **`specforge/core/dflash_lora.py`**
|
| 163 |
+
- `__init__()`: 添加 `lm_head_chunk_size` 参数(默认 0 = 不启用)
|
| 164 |
+
- 新增 `_chunked_lm_loss()`: 分 chunk 过 lm_head + CE loss + gradient checkpointing
|
| 165 |
+
- 提取 `_full_lm_loss()`: 原始非 chunked 路径
|
| 166 |
+
- `forward()`: `lm_head_chunk_size > 0` 时走 chunked 路径
|
| 167 |
+
|
| 168 |
+
2. **`specforge/modeling/draft/dflash_lora.py`**
|
| 169 |
+
- `forward()`: 新增 `output_hidden_states` 参数,True 时返回 last hidden state 而非 logits
|
| 170 |
+
- `get_lm_head()`: 穿透 PEFT 层级返回 `base_model.lm_head` 引用
|
| 171 |
+
|
| 172 |
+
3. **`scripts/train_dflash_lora.py`**
|
| 173 |
+
- `parse_args()`: `--lm-head-chunk-size` 参数(默认 0,推荐 256)
|
| 174 |
+
- `build_model()`: 传递到 OnlineDFlashLoRAModel
|
| 175 |
+
|
| 176 |
+
Chunked loss 核心逻辑:
|
| 177 |
+
```python
|
| 178 |
+
# 分 chunk 计算,每 chunk 用 gradient checkpointing(backward 时重算 logits,不存储)
|
| 179 |
+
for start in range(0, effective_len, chunk_size):
|
| 180 |
+
end = min(start + chunk_size, effective_len)
|
| 181 |
+
chunk_loss, chunk_weight = grad_checkpoint(
|
| 182 |
+
_chunk_ce, # lm_head + CE
|
| 183 |
+
hidden[:, start:end, :], # 只取当前 chunk
|
| 184 |
+
input_ids[:, start:end],
|
| 185 |
+
combined_mask[:, start:end],
|
| 186 |
+
use_reentrant=False,
|
| 187 |
+
)
|
| 188 |
+
total_loss += chunk_loss
|
| 189 |
+
total_weight += chunk_weight
|
| 190 |
+
loss = total_loss / total_weight
|
| 191 |
+
```
|
| 192 |
+
|
| 193 |
+
**效果**: logits 峰值显存从 `O(seq_len × vocab_size)` 降至 `O(chunk_size × vocab_size)`,256 chunk → ~150MB vs 1.18GB
|
| 194 |
+
|
| 195 |
+
---
|
| 196 |
+
|
| 197 |
+
## 当前训练命令
|
| 198 |
+
|
| 199 |
+
```bash
|
| 200 |
+
bash run_train_dflash_lora.sh 2 # 2 = GPU 数量
|
| 201 |
+
```
|
| 202 |
+
|
| 203 |
+
对应完整参数:
|
| 204 |
+
```bash
|
| 205 |
+
torchrun --nproc_per_node 2 scripts/train_dflash_lora.py \
|
| 206 |
+
--model-path /workspace/Qwen3-8B \
|
| 207 |
+
--train-data-path /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K \
|
| 208 |
+
--output-dir outputs/qwen3-8b-dflash-lora \
|
| 209 |
+
--lora-config configs/qwen3-8b-dflash-lora.json \
|
| 210 |
+
--block-size 16 \
|
| 211 |
+
--max-length 2048 \
|
| 212 |
+
--batch-size 1 \
|
| 213 |
+
--num-epochs 3 \
|
| 214 |
+
--learning-rate 2e-4 \
|
| 215 |
+
--accumulation-steps 8 \
|
| 216 |
+
--loss-decay-gamma 7 \
|
| 217 |
+
--attention-backend flex_attention \
|
| 218 |
+
--lm-head-chunk-size 256 \
|
| 219 |
+
--gradient-checkpointing \
|
| 220 |
+
--chat-template qwen \
|
| 221 |
+
--log-interval 50 \
|
| 222 |
+
--save-interval 500
|
| 223 |
+
```
|
| 224 |
+
|
| 225 |
+
---
|
| 226 |
+
|
| 227 |
+
## 待验证
|
| 228 |
+
|
| 229 |
+
- [ ] 跑 `bash run_train_dflash_lora.sh 2` 确认不再 OOM
|
| 230 |
+
- [ ] 确认无 SDPA math fallback warning
|
| 231 |
+
- [ ] 观察 GPU 显存峰值
|
| 232 |
+
- [ ] 确认 loss 下降和 accuracy 上升趋势正常
|
syxin/eval_accepted_length.md
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DFlash-LoRA-Inject 评测:Accepted Length & Accuracy
|
| 2 |
+
|
| 3 |
+
## 为什么不能用 sglang 在线评测?
|
| 4 |
+
|
| 5 |
+
DFlash-LoRA-Inject 的推理需要**逐层注入 target 模型的 hidden states** 到 draft 模型中,
|
| 6 |
+
这是 LoRA-Inject 训练时的核心机制。但 sglang 不支持这种推理模式:
|
| 7 |
+
|
| 8 |
+
| sglang 算法 | 问题 |
|
| 9 |
+
|---|---|
|
| 10 |
+
| `STANDALONE` | 把 draft 当独立自回归模型跑,**完全忽略 layer injection**。merged 模型 ≈ 原始 Qwen3-8B,accept_length 恒 ≈ 4.7,跟 LoRA 训没训没关系 |
|
| 11 |
+
| `DFLASH` | 期望 DFlash-b16 架构(5 层 + fc + hidden_norm),跟 LoRA-Inject(36 层全模型)结构不匹配 |
|
| 12 |
+
|
| 13 |
+
因此必须**离线评测**:加载 target + draft 两个模型,手动实现带 layer injection 的 speculative decoding 循环。
|
| 14 |
+
|
| 15 |
+
---
|
| 16 |
+
|
| 17 |
+
## 基本信息
|
| 18 |
+
|
| 19 |
+
| 项目 | 路径 / 值 |
|
| 20 |
+
|---|---|
|
| 21 |
+
| conda 环境 | `spec` |
|
| 22 |
+
| 基座模型(target) | `/workspace/models/Qwen3-8B` |
|
| 23 |
+
| 训练输出(最终 ckpt) | `.../outputs/qwen3-8b-dflash-lora-inject/epoch_3_step_1400` |
|
| 24 |
+
| 合并后 draft 模型 | `.../outputs/qwen3-8b-dflash-lora-inject-merged` |
|
| 25 |
+
| 评测脚本 | `/workspace/hanrui/syxin_old/eval_dflash_lora_inject.py` |
|
| 26 |
+
| 本地数据集 | `/workspace/hanrui/datasets/{humaneval,mtbench,gsm8k}` |
|
| 27 |
+
| 结果输出目录 | `/workspace/hanrui/syxin_old/Specforge/benchmarks/results/` |
|
| 28 |
+
| GPU | 8 × H100 80GB(单卡即可,需 ~32GB 加载两个 8B 模型) |
|
| 29 |
+
|
| 30 |
+
---
|
| 31 |
+
|
| 32 |
+
## Step 1:合并 LoRA 权重
|
| 33 |
+
|
| 34 |
+
LoRA-Inject 训练只保存 adapter 权重,评测时需要完整模型。
|
| 35 |
+
|
| 36 |
+
```bash
|
| 37 |
+
conda activate spec
|
| 38 |
+
|
| 39 |
+
python3 -c "
|
| 40 |
+
from peft import PeftModel
|
| 41 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 42 |
+
import torch, os
|
| 43 |
+
|
| 44 |
+
BASE = '/workspace/models/Qwen3-8B'
|
| 45 |
+
ADAPTER = '/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora-inject/epoch_3_step_1400'
|
| 46 |
+
MERGED = '/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora-inject-merged'
|
| 47 |
+
|
| 48 |
+
if os.path.exists(MERGED):
|
| 49 |
+
print(f'[skip] Merged model already exists: {MERGED}')
|
| 50 |
+
else:
|
| 51 |
+
print('[1/4] Loading base model to CPU ...')
|
| 52 |
+
model = AutoModelForCausalLM.from_pretrained(BASE, torch_dtype=torch.bfloat16, device_map='cpu')
|
| 53 |
+
print('[2/4] Loading LoRA adapter ...')
|
| 54 |
+
model = PeftModel.from_pretrained(model, ADAPTER)
|
| 55 |
+
print('[3/4] Merging weights ...')
|
| 56 |
+
model = model.merge_and_unload()
|
| 57 |
+
print('[4/4] Saving merged model ...')
|
| 58 |
+
os.makedirs(MERGED, exist_ok=True)
|
| 59 |
+
model.save_pretrained(MERGED, safe_serialization=True)
|
| 60 |
+
AutoTokenizer.from_pretrained(BASE).save_pretrained(MERGED)
|
| 61 |
+
print(f'Done. Merged model saved to: {MERGED}')
|
| 62 |
+
"
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
> 耗时约 3–5 分钟,CPU 内存占用 ≈ 16 GB。已存在则自动跳过。
|
| 66 |
+
|
| 67 |
+
---
|
| 68 |
+
|
| 69 |
+
## Step 2:离线评测 accepted length
|
| 70 |
+
|
| 71 |
+
**不需要启动 sglang server**,直接跑:
|
| 72 |
+
|
| 73 |
+
### 全部 Bench(推荐)
|
| 74 |
+
|
| 75 |
+
```bash
|
| 76 |
+
bash /workspace/hanrui/syxin_old/run_bench_dflash.sh
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
### 单独跑 / 快速测试
|
| 80 |
+
|
| 81 |
+
```bash
|
| 82 |
+
# 只跑 HumanEval
|
| 83 |
+
bash /workspace/hanrui/syxin_old/run_bench_dflash.sh humaneval
|
| 84 |
+
|
| 85 |
+
# 快速测试(每个 bench 20 条)
|
| 86 |
+
bash /workspace/hanrui/syxin_old/run_bench_dflash.sh --quick
|
| 87 |
+
|
| 88 |
+
# 指定 checkpoint
|
| 89 |
+
bash /workspace/hanrui/syxin_old/run_bench_dflash.sh --ckpt epoch_0_step_1000
|
| 90 |
+
|
| 91 |
+
# 组合
|
| 92 |
+
bash /workspace/hanrui/syxin_old/run_bench_dflash.sh humaneval gsm8k --quick
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
### 或者直接调 Python
|
| 96 |
+
|
| 97 |
+
```bash
|
| 98 |
+
conda activate spec
|
| 99 |
+
|
| 100 |
+
python3 /workspace/hanrui/syxin_old/eval_dflash_lora_inject.py \
|
| 101 |
+
--benchmarks humaneval mtbench gsm8k \
|
| 102 |
+
--block-size 16 \
|
| 103 |
+
--max-new-tokens 512 \
|
| 104 |
+
--temperature 0.0
|
| 105 |
+
```
|
| 106 |
+
|
| 107 |
+
---
|
| 108 |
+
|
| 109 |
+
## 结果文件说明
|
| 110 |
+
|
| 111 |
+
结果保存在 `results/` 下,文件名示例:
|
| 112 |
+
```
|
| 113 |
+
dflash_lora_inject_offline_epoch_3_step_1400_20260314_150000.json
|
| 114 |
+
```
|
| 115 |
+
|
| 116 |
+
```json
|
| 117 |
+
{
|
| 118 |
+
"model": "dflash-lora-inject/epoch_3_step_1400",
|
| 119 |
+
"block_size": 16,
|
| 120 |
+
"humaneval": {
|
| 121 |
+
"avg_accept_length": 3.42,
|
| 122 |
+
"total_tokens": 28500,
|
| 123 |
+
"latency": 120.5,
|
| 124 |
+
"throughput": 236.5,
|
| 125 |
+
"num_samples": 164,
|
| 126 |
+
"num_verify_rounds": 8320
|
| 127 |
+
},
|
| 128 |
+
"mtbench": { ... },
|
| 129 |
+
"gsm8k": { ... }
|
| 130 |
+
}
|
| 131 |
+
```
|
| 132 |
+
|
| 133 |
+
| 字段 | 含义 |
|
| 134 |
+
|---|---|
|
| 135 |
+
| `avg_accept_length` | **核心指标**:平均每次 verify 接受的 token 数(含 injection)。越高越好,`1.0` = draft 完全无效 |
|
| 136 |
+
| `total_tokens` | 总生成 token 数 |
|
| 137 |
+
| `throughput` | tokens/s(离线评测,不含 batching 优化) |
|
| 138 |
+
| `num_verify_rounds` | 总验证轮数 |
|
| 139 |
+
|
| 140 |
+
---
|
| 141 |
+
|
| 142 |
+
## 对比 baseline
|
| 143 |
+
|
| 144 |
+
对比未经 LoRA 训练的原始 Qwen3-8B 当 draft 的 accept_length:
|
| 145 |
+
|
| 146 |
+
```bash
|
| 147 |
+
python3 /workspace/hanrui/syxin_old/eval_dflash_lora_inject.py \
|
| 148 |
+
--merged-path /workspace/models/Qwen3-8B \
|
| 149 |
+
--benchmarks humaneval mtbench gsm8k \
|
| 150 |
+
--num-samples 50
|
| 151 |
+
```
|
| 152 |
+
|
| 153 |
+
> 这会用原始 Qwen3-8B 同时当 target 和 draft(带 injection),
|
| 154 |
+
> 对比 LoRA 训练前后 accept_length 是否有提升。
|
| 155 |
+
|
| 156 |
+
---
|
| 157 |
+
|
| 158 |
+
## 如何测其他 checkpoint
|
| 159 |
+
|
| 160 |
+
```bash
|
| 161 |
+
# 方法 1:直接加载 adapter(自动 merge,不保存)
|
| 162 |
+
python3 /workspace/hanrui/syxin_old/eval_dflash_lora_inject.py \
|
| 163 |
+
--ckpt epoch_0_step_1000 \
|
| 164 |
+
--benchmarks humaneval --num-samples 50
|
| 165 |
+
|
| 166 |
+
# 方法 2:预先 merge 到不同目录
|
| 167 |
+
python3 -c "
|
| 168 |
+
from peft import PeftModel
|
| 169 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 170 |
+
import torch, os
|
| 171 |
+
BASE = '/workspace/models/Qwen3-8B'
|
| 172 |
+
ADAPTER = '/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora-inject/epoch_0_step_1000'
|
| 173 |
+
MERGED = '/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora-inject-merged-epoch_0_step_1000'
|
| 174 |
+
model = AutoModelForCausalLM.from_pretrained(BASE, torch_dtype=torch.bfloat16, device_map='cpu')
|
| 175 |
+
model = PeftModel.from_pretrained(model, ADAPTER).merge_and_unload()
|
| 176 |
+
os.makedirs(MERGED, exist_ok=True)
|
| 177 |
+
model.save_pretrained(MERGED, safe_serialization=True)
|
| 178 |
+
AutoTokenizer.from_pretrained(BASE).save_pretrained(MERGED)
|
| 179 |
+
"
|
| 180 |
+
|
| 181 |
+
python3 /workspace/hanrui/syxin_old/eval_dflash_lora_inject.py \
|
| 182 |
+
--merged-path .../qwen3-8b-dflash-lora-inject-merged-epoch_0_step_1000 \
|
| 183 |
+
--benchmarks humaneval --num-samples 50
|
| 184 |
+
```
|
| 185 |
+
|
| 186 |
+
可用 checkpoint:`epoch_0_step_500` / `epoch_0_step_1000` / `epoch_0_step_1400` / `epoch_2_step_34500` / `epoch_2_step_35000` / `epoch_3_step_1400`
|
| 187 |
+
|
| 188 |
+
---
|
| 189 |
+
|
| 190 |
+
## 常见问题
|
| 191 |
+
|
| 192 |
+
### Q1:accept_length 和 STANDALONE 模式下差不多(都 ≈ 4.7)
|
| 193 |
+
|
| 194 |
+
这说明 layer injection 没有真正起作用。检查:
|
| 195 |
+
- 评测脚本确实用的是 `eval_dflash_lora_inject.py`(离线),不是 sglang bench
|
| 196 |
+
- merged 模型确实是 LoRA-Inject 版本(不是原始 Qwen3-8B)
|
| 197 |
+
|
| 198 |
+
### Q2:OOM(单卡放不下两个 8B 模型)
|
| 199 |
+
|
| 200 |
+
两个 bf16 的 Qwen3-8B ≈ 32GB,单卡 H100 80GB 够用。如果 OOM:
|
| 201 |
+
- 检查是否有其他进程占用显存
|
| 202 |
+
- 减小 `--max-new-tokens`(试 256)
|
| 203 |
+
- 减小 `--num-samples`
|
| 204 |
+
|
| 205 |
+
### Q3:数据集下载失败(无外网)
|
| 206 |
+
|
| 207 |
+
评测脚本优先读本地文件:
|
| 208 |
+
|
| 209 |
+
| bench | 本地文件 |
|
| 210 |
+
|---|---|
|
| 211 |
+
| GSM8K | `/workspace/hanrui/datasets/gsm8k/test.jsonl` |
|
| 212 |
+
| MT-Bench | `/workspace/hanrui/datasets/mtbench/question.jsonl` |
|
| 213 |
+
| HumanEval | `/workspace/hanrui/datasets/humaneval/test.jsonl` |
|
| 214 |
+
|
| 215 |
+
---
|
| 216 |
+
|
| 217 |
+
*基座:`/workspace/models/Qwen3-8B` | 最终 ckpt:`epoch_3_step_1400` | block_size:16*
|
syxin/eval_dflash_b16_baseline.py
ADDED
|
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Offline evaluation for DFlash-b16 baseline: measure accepted length.
|
| 4 |
+
8 GPUs parallel, each GPU loads target + draft independently.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
# 8 GPUs
|
| 8 |
+
torchrun --nproc_per_node 8 eval_dflash_b16_baseline.py
|
| 9 |
+
|
| 10 |
+
# quick test
|
| 11 |
+
torchrun --nproc_per_node 8 eval_dflash_b16_baseline.py --num-samples 20
|
| 12 |
+
|
| 13 |
+
# single GPU
|
| 14 |
+
python3 eval_dflash_b16_baseline.py --benchmarks humaneval
|
| 15 |
+
"""
|
| 16 |
+
import argparse
|
| 17 |
+
import json
|
| 18 |
+
import os
|
| 19 |
+
import sys
|
| 20 |
+
import time
|
| 21 |
+
from typing import List, Optional, Tuple
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
import torch.nn as nn
|
| 25 |
+
import torch.distributed as dist
|
| 26 |
+
from tqdm import tqdm
|
| 27 |
+
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, DynamicCache
|
| 28 |
+
|
| 29 |
+
# Add DFlash model path so we can import utils
|
| 30 |
+
sys.path.insert(0, "/workspace/models/Qwen3-8B-DFlash-b16")
|
| 31 |
+
from utils import extract_context_feature, sample
|
| 32 |
+
|
| 33 |
+
# ──────────────────────────────────────────────────────────────────
|
| 34 |
+
BASE_MODEL = "/workspace/models/Qwen3-8B"
|
| 35 |
+
DRAFT_MODEL = "/workspace/models/Qwen3-8B-DFlash-b16"
|
| 36 |
+
RESULT_DIR = "/workspace/hanrui/syxin_old/Specforge/benchmarks/results"
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# ──────────────────────────────────────────────────────────────────
|
| 40 |
+
# Distributed helpers
|
| 41 |
+
# ──────────────────────────────────────────────────────────────────
|
| 42 |
+
def is_distributed():
    """True when torch.distributed is available and a process group is up."""
    if not dist.is_available():
        return False
    return dist.is_initialized()


def get_rank():
    """Global rank of this process; 0 when running single-process."""
    return dist.get_rank() if is_distributed() else 0


def get_world_size():
    """Number of participating processes (1 when not distributed)."""
    if is_distributed():
        return dist.get_world_size()
    return 1


def is_main():
    """True only on the rank-0 process."""
    return get_rank() == 0


def print_rank0(*args, **kwargs):
    """print() that is a no-op on every rank except rank 0."""
    if not is_main():
        return
    print(*args, **kwargs)


def split_list(lst, rank, world_size):
    """Round-robin shard of *lst* for *rank*: items rank, rank+ws, rank+2*ws, ..."""
    return lst[rank::world_size]
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# ──────────────────────────────────────────────────────────────────
|
| 63 |
+
# Prompts
|
| 64 |
+
# ──────────────────────────────────────────────────────────────────
|
| 65 |
+
def load_prompts(bench_name: str, num_samples: Optional[int] = None) -> List[str]:
    """Load benchmark prompts, preferring local JSONL dumps over the HF Hub.

    Args:
        bench_name: one of "humaneval", "mtbench", "gsm8k".
        num_samples: optional cap on the number of prompts returned.

    Returns:
        List of fully formatted prompt strings.

    Raises:
        ValueError: if *bench_name* is not a known benchmark.
    """
    local_paths = {
        "humaneval": "/workspace/hanrui/datasets/humaneval/test.jsonl",
        "mtbench": "/workspace/hanrui/datasets/mtbench/question.jsonl",
        "gsm8k": "/workspace/hanrui/datasets/gsm8k/test.jsonl",
    }
    prompts: List[str] = []
    path = local_paths.get(bench_name)
    if path and os.path.exists(path):
        with open(path) as f:
            for line in f:
                item = json.loads(line)
                if bench_name == "humaneval":
                    p = f"Write a solution to the following problem and make sure that it passes the tests:\n```python\n{item['prompt']}\n```"
                elif bench_name == "mtbench":
                    p = item.get("turns", [item.get("prompt", "")])[0]
                else:
                    # bench_name == "gsm8k" — guaranteed by the local_paths keys,
                    # so the old `p = str(item)` fallback was dead code and is removed.
                    p = item["question"] + "\nPlease reason step by step, and put your final answer within \\boxed{}."
                prompts.append(p)
    else:
        # Fail loudly on an unknown benchmark instead of silently returning [],
        # and do so before importing `datasets` (which may be unavailable offline).
        if bench_name not in local_paths:
            raise ValueError(f"Unknown benchmark: {bench_name!r}")
        from datasets import load_dataset
        if bench_name == "humaneval":
            ds = load_dataset("openai/openai_humaneval", split="test")
            prompts = [f"Write a solution to the following problem and make sure that it passes the tests:\n```python\n{x['prompt']}\n```" for x in ds]
        elif bench_name == "mtbench":
            ds = load_dataset("HuggingFaceH4/mt_bench_prompts", split="train")
            prompts = [x["prompt"][0] for x in ds]
        elif bench_name == "gsm8k":
            ds = load_dataset("openai/gsm8k", "main", split="test")
            prompts = [x["question"] + "\nPlease reason step by step, and put your final answer within \\boxed{}." for x in ds]
    if num_samples is not None:
        prompts = prompts[:num_samples]
    return prompts
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# ──────────────────────────────────────────────────────────────────
|
| 103 |
+
# spec_generate with acceptance_lengths returned
|
| 104 |
+
# (Same logic as DFlashDraftModel.spec_generate but returns accept lens)
|
| 105 |
+
# ──────────────────────────────────────────────────────────────────
|
| 106 |
+
@torch.inference_mode()
def spec_generate_b16(
    draft_model,
    target_model: nn.Module,
    input_ids: torch.LongTensor,
    max_new_tokens: int = 512,
    temperature: float = 0.0,
    stop_token_ids: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, List[int]]:
    """Same as DFlashDraftModel.spec_generate but also returns acceptance_lengths.

    One speculative-decoding loop: the draft model proposes a block of tokens
    from the target's hidden states; the target verifies the block in a single
    forward pass; the matching prefix (plus one bonus token from the target's
    own posterior) is committed.

    Args:
        draft_model: DFlash draft model exposing `block_size`, `mask_token_id`
            and `target_layer_ids` (structure assumed from usage — not shown here).
        target_model: full causal LM used for prefill and verification.
        input_ids: (1, prompt_len) prompt tokens on the target device.
        max_new_tokens: generation budget beyond the prompt.
        temperature: 0.0 means greedy verification.
        stop_token_ids: generation halts once any of these appears.

    Returns:
        Tuple of (output_ids, acceptance_lengths) where acceptance_lengths[i]
        is the number of tokens committed in verify round i (>= 1 per round).
    """
    draft_model.eval()
    device = target_model.device if hasattr(target_model, 'device') else input_ids.device
    num_input_tokens = input_ids.shape[1]
    max_length = num_input_tokens + max_new_tokens
    block_size = draft_model.block_size
    mask_token_id = draft_model.mask_token_id

    # Buffer is over-allocated by block_size so the final block (and the
    # bonus-token write below) can spill past max_length without bounds checks;
    # unfilled slots keep mask_token_id and are stripped at the end.
    output_ids = torch.full(
        (1, max_length + block_size), mask_token_id,
        dtype=torch.long, device=device,
    )
    position_ids = torch.arange(output_ids.shape[1], device=device).unsqueeze(0)

    past_key_values_target = DynamicCache()
    past_key_values_draft = DynamicCache()

    # Prefill: run the prompt through the target once; logits_to_keep=1 because
    # only the last position's logits are needed to sample the first token.
    output = target_model(
        input_ids,
        position_ids=position_ids[:, :num_input_tokens],
        past_key_values=past_key_values_target,
        use_cache=True,
        logits_to_keep=1,
        output_hidden_states=True,
    )
    output_ids[:, :num_input_tokens] = input_ids
    output_ids[:, num_input_tokens:num_input_tokens + 1] = sample(output.logits, temperature)
    # Context features from selected target layers feed the draft model.
    target_hidden = extract_context_feature(output.hidden_states, draft_model.target_layer_ids)

    # Decode
    acceptance_lengths = []
    start = num_input_tokens
    while start < max_length:
        # Slot 0 of the block holds the already-committed token; the remaining
        # slots still contain mask_token_id for the draft to fill in.
        block_output_ids = output_ids[:, start:start + block_size].clone()
        block_position_ids = position_ids[:, start:start + block_size]
        noise_embedding = target_model.model.embed_tokens(block_output_ids)

        # Draft forward (non-causal within the block); project its hidden
        # states through the target's lm_head. `-block_size + 1:` keeps logits
        # only for the positions the draft must predict (all but slot 0).
        draft_logits = target_model.lm_head(
            draft_model(
                target_hidden=target_hidden,
                noise_embedding=noise_embedding,
                position_ids=position_ids[:, past_key_values_draft.get_seq_length():start + block_size],
                past_key_values=past_key_values_draft,
                use_cache=True,
                is_causal=False,
            )[:, -block_size + 1:, :]
        )
        # Roll the draft cache back to the committed prefix; the speculative
        # block entries must not persist into the next round.
        past_key_values_draft.crop(start)
        block_output_ids[:, 1:] = sample(draft_logits)

        # Target verifies the whole candidate block in one forward pass.
        output = target_model(
            block_output_ids,
            position_ids=block_position_ids,
            past_key_values=past_key_values_target,
            use_cache=True,
            output_hidden_states=True,
        )

        posterior = sample(output.logits, temperature)
        # Longest prefix where the draft token equals what the target would
        # have produced at that position (cumprod stops counting at the first
        # mismatch).
        acceptance_length = (
            (block_output_ids[:, 1:] == posterior[:, :-1])
            .cumprod(dim=1).sum(dim=1)[0].item()
        )
        # Commit accepted prefix, then one bonus token from the target itself.
        output_ids[:, start:start + int(acceptance_length) + 1] = block_output_ids[:, :int(acceptance_length) + 1]
        output_ids[:, start + int(acceptance_length) + 1] = posterior[:, int(acceptance_length)]
        start += int(acceptance_length) + 1
        # Discard target-cache entries for the rejected tail of the block.
        past_key_values_target.crop(start)
        target_hidden = extract_context_feature(
            output.hidden_states, draft_model.target_layer_ids
        )[:, :int(acceptance_length) + 1, :]
        acceptance_lengths.append(int(acceptance_length) + 1)

        if stop_token_ids is not None and any(
            sid in output_ids[:, num_input_tokens:start] for sid in stop_token_ids
        ):
            break

    # Trim the over-allocated tail and any residual mask placeholders.
    output_ids = output_ids[:, :max_length]
    output_ids = output_ids[:, output_ids[0] != mask_token_id]
    # Truncate at the first stop token in the generated region (inclusive).
    if stop_token_ids is not None:
        stop_t = torch.tensor(stop_token_ids, device=output_ids.device)
        stop_idx = torch.isin(output_ids[0][num_input_tokens:], stop_t).nonzero(as_tuple=True)[0]
        if stop_idx.numel() > 0:
            output_ids = output_ids[:, :num_input_tokens + stop_idx[0] + 1]

    return output_ids, acceptance_lengths
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
# ──────────────────────────────────────────────────────────────────
|
| 205 |
+
def parse_args():
    """Parse command-line options for the DFlash-b16 baseline evaluation run."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-model", default=BASE_MODEL)
    parser.add_argument("--draft-model", default=DRAFT_MODEL)
    parser.add_argument("--max-new-tokens", default=512, type=int)
    parser.add_argument("--temperature", default=0.0, type=float)
    parser.add_argument("--benchmarks", default=["humaneval", "mtbench", "gsm8k"], nargs="+")
    parser.add_argument("--num-samples", default=None, type=int)
    parser.add_argument("--output-dir", default=RESULT_DIR)
    return parser.parse_args()
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def main():
    """Entry point: load target + DFlash-b16 draft, measure accepted length.

    Each rank loads its own full copy of both models on its local GPU and
    processes a round-robin shard of every benchmark's prompts; per-rank
    statistics are all-reduced, and rank 0 writes a JSON summary.
    """
    args = parse_args()

    # torchrun sets LOCAL_RANK / WORLD_SIZE; defaults cover single-GPU runs.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    if world_size > 1:
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(local_rank)

    device = f"cuda:{local_rank}"
    rank = get_rank()

    print_rank0(f"Running DFlash-b16 baseline on {world_size} GPU(s)")

    # ── Load models ──
    print_rank0(f"Loading target: {args.base_model}")
    target_model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        torch_dtype=torch.bfloat16,
        device_map=device,
        trust_remote_code=True,
    )
    target_model.eval()

    # Draft is a custom architecture shipped with the checkpoint, hence
    # AutoModel + trust_remote_code.
    print_rank0(f"Loading DFlash-b16 draft: {args.draft_model}")
    draft_model = AutoModel.from_pretrained(
        args.draft_model,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    ).to(device)
    draft_model.eval()

    tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
    stop_token_ids = [tokenizer.eos_token_id]

    print_rank0(f"DFlash-b16: block_size={draft_model.block_size}, "
                f"target_layer_ids={draft_model.target_layer_ids}, "
                f"num_layers={len(draft_model.layers)}")

    # ── Run benchmarks ──
    results = {"model": "Qwen3-8B-DFlash-b16", "type": "baseline",
               "block_size": draft_model.block_size}

    for bench_name in args.benchmarks:
        print_rank0(f"\n{'='*60}")
        print_rank0(f"Benchmark: {bench_name} ({world_size} GPUs)")
        print_rank0(f"{'='*60}")

        all_prompts = load_prompts(bench_name, args.num_samples)
        # Every rank loads the full prompt list, then keeps its shard.
        my_prompts = split_list(all_prompts, rank, world_size)
        print_rank0(f"Total {len(all_prompts)} prompts, ~{len(my_prompts)} per GPU")

        local_accept_lengths = []
        local_tokens = 0
        t0 = time.time()

        # Progress bar only on rank 0 to keep the console readable.
        iterator = tqdm(my_prompts, desc=f"[GPU{rank}] {bench_name}", unit="sample",
                        disable=(rank != 0))
        for prompt in iterator:
            messages = [{"role": "user", "content": prompt}]
            text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)

            output_ids, accept_lens = spec_generate_b16(
                draft_model=draft_model,
                target_model=target_model,
                input_ids=input_ids,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                stop_token_ids=stop_token_ids,
            )

            local_accept_lengths.extend(accept_lens)
            num_gen = output_ids.shape[1] - input_ids.shape[1]
            local_tokens += num_gen

            if rank == 0 and len(local_accept_lengths) > 0:
                # Running average shown live in the progress bar.
                avg = sum(local_accept_lengths) / len(local_accept_lengths)
                iterator.set_postfix(accept_len=f"{avg:.2f}", tokens=local_tokens, gen=num_gen)

        elapsed = time.time() - t0

        # ── Gather ──
        # All-reduce scalar stats instead of gathering per-sample lists.
        if world_size > 1:
            local_sum = torch.tensor(sum(local_accept_lengths), dtype=torch.float64, device=device)
            local_count = torch.tensor(len(local_accept_lengths), dtype=torch.long, device=device)
            local_tok = torch.tensor(local_tokens, dtype=torch.long, device=device)
            dist.all_reduce(local_sum, op=dist.ReduceOp.SUM)
            dist.all_reduce(local_count, op=dist.ReduceOp.SUM)
            dist.all_reduce(local_tok, op=dist.ReduceOp.SUM)
            total_accept_sum = local_sum.item()
            total_count = local_count.item()
            total_tokens = local_tok.item()
        else:
            total_accept_sum = sum(local_accept_lengths)
            total_count = len(local_accept_lengths)
            total_tokens = local_tokens

        # max(..., 1) guards against division by zero when no rounds ran.
        avg_accept_length = total_accept_sum / max(total_count, 1)
        throughput = total_tokens / elapsed if elapsed > 0 else 0

        print_rank0(f"\n{bench_name} Results:")
        print_rank0(f"  Avg Accept Length: {avg_accept_length:.3f}")
        print_rank0(f"  Total tokens: {total_tokens}")
        print_rank0(f"  Latency: {elapsed:.1f}s")
        print_rank0(f"  Throughput: {throughput:.1f} tok/s (aggregate {world_size} GPUs)")
        print_rank0(f"  Num verify rounds: {total_count}")
        print_rank0(f"  Num samples: {len(all_prompts)}")

        results[bench_name] = {
            "avg_accept_length": avg_accept_length,
            "total_tokens": total_tokens,
            "latency": elapsed,
            "throughput": throughput,
            "num_samples": len(all_prompts),
            "num_verify_rounds": total_count,
            "num_gpus": world_size,
        }

    # ── Save ── (rank 0 only, to avoid concurrent writes)
    if is_main():
        os.makedirs(args.output_dir, exist_ok=True)
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        result_file = os.path.join(
            args.output_dir,
            f"dflash_b16_baseline_offline_{timestamp}.json",
        )
        with open(result_file, "w") as f:
            json.dump(results, f, indent=2)
        print(f"\nResults saved to: {result_file}")

    if world_size > 1:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
|
syxin/eval_dflash_lora_inject.py
ADDED
|
@@ -0,0 +1,627 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Offline evaluation for DFlash-LoRA-Inject: measure accepted length & speedup.
|
| 4 |
+
Aligned with official DFlash benchmark.py methodology.
|
| 5 |
+
|
| 6 |
+
Unlike DFlash-b16 which uses a small 5-layer draft model with fc/hidden_norm,
|
| 7 |
+
LoRA-Inject uses a full Qwen3-8B with LoRA adapters that receives target hidden
|
| 8 |
+
states via layer-by-layer injection.
|
| 9 |
+
|
| 10 |
+
Usage:
|
| 11 |
+
conda activate spec
|
| 12 |
+
|
| 13 |
+
# 8 GPU parallel (default, all 10 benchmarks)
|
| 14 |
+
torchrun --nproc_per_node 8 eval_dflash_lora_inject.py
|
| 15 |
+
|
| 16 |
+
# single GPU
|
| 17 |
+
python3 eval_dflash_lora_inject.py
|
| 18 |
+
|
| 19 |
+
# specific checkpoint / benchmark
|
| 20 |
+
torchrun --nproc_per_node 8 eval_dflash_lora_inject.py --ckpt epoch_0_step_1000 --datasets humaneval
|
| 21 |
+
|
| 22 |
+
# quick test
|
| 23 |
+
torchrun --nproc_per_node 8 eval_dflash_lora_inject.py --max-samples 20
|
| 24 |
+
"""
|
| 25 |
+
import argparse
|
| 26 |
+
import json
|
| 27 |
+
import os
|
| 28 |
+
import random
|
| 29 |
+
import sys
|
| 30 |
+
import time
|
| 31 |
+
import warnings
|
| 32 |
+
from itertools import chain
|
| 33 |
+
from types import SimpleNamespace
|
| 34 |
+
from typing import List, Optional, Tuple
|
| 35 |
+
|
| 36 |
+
import numpy as np
|
| 37 |
+
import torch
|
| 38 |
+
import torch.nn as nn
|
| 39 |
+
import torch.distributed as dist
|
| 40 |
+
from peft import PeftModel
|
| 41 |
+
from tqdm import tqdm
|
| 42 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
|
| 43 |
+
|
| 44 |
+
# Import official dataset loader
|
| 45 |
+
sys.path.insert(0, "/workspace/hanrui/dflash")
|
| 46 |
+
from model.utils import load_and_process_dataset
|
| 47 |
+
|
| 48 |
+
# ──────────────────────────────────────────────────────────────────
|
| 49 |
+
# Config defaults
|
| 50 |
+
# ──────────────────────────────────────────────────────────────────
|
| 51 |
+
BASE_MODEL = "/workspace/models/Qwen3-8B"
|
| 52 |
+
ADAPTER_ROOT = "/workspace/hanrui/syxin/Specforge/outputs/qwen3-8b-dflash-lora-inject"
|
| 53 |
+
DEFAULT_CKPT = "epoch_3_step_1400"
|
| 54 |
+
MASK_TOKEN_ID = 151669 # Qwen3 <|mask|>
|
| 55 |
+
BLOCK_SIZE = 16
|
| 56 |
+
RESULT_DIR = "/workspace/hanrui/syxin/Specforge/benchmarks/results"
|
| 57 |
+
|
| 58 |
+
# Official benchmark tasks (from run_benchmark.sh)
|
| 59 |
+
OFFICIAL_TASKS = {
|
| 60 |
+
"gsm8k": 128,
|
| 61 |
+
"math500": 128,
|
| 62 |
+
"aime24": 30,
|
| 63 |
+
"aime25": 30,
|
| 64 |
+
"humaneval": 164,
|
| 65 |
+
"mbpp": 128,
|
| 66 |
+
"livecodebench": 128,
|
| 67 |
+
"swe-bench": 128,
|
| 68 |
+
"mt-bench": 80,
|
| 69 |
+
"alpaca": 128,
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# ──────────────────────────────────────────────────────────────────
|
| 74 |
+
# CUDA-synchronised timer (matches official benchmark.py)
|
| 75 |
+
# ──────────────────────────────────────────────────────────────────
|
| 76 |
+
def cuda_time() -> float:
    """Return a wall-clock timestamp taken after all pending CUDA work drains.

    Matches the official benchmark.py timing methodology: synchronize first so
    the timestamp reflects completed GPU work, not just enqueued kernels.
    """
    torch.cuda.synchronize()
    return time.perf_counter()
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def has_flash_attn() -> bool:
    """Report whether the flash_attn package can be imported.

    Prints a warning (once per call) when it is missing so the caller knows
    the attention implementation will fall back to sdpa.
    """
    try:
        import flash_attn  # noqa: F401
    except ImportError:
        print("[WARN] flash_attn not installed, falling back to sdpa.")
        return False
    return True
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# ──────────────────────────────────────────────────────────────────
|
| 91 |
+
# Distributed helpers (mirrors official distributed.py)
|
| 92 |
+
# ──────────────────────────────────────────────────────────────────
|
| 93 |
+
def dist_init():
    """Initialise the NCCL process group when launched under torchrun.

    When RANK is absent (plain single-process run) a warning is emitted
    and the process group is left uninitialised.
    """
    if "RANK" in os.environ:
        dist.init_process_group(backend="nccl", init_method="env://")
    else:
        warnings.warn("RANK not set. Skipping distributed init.")
|
| 98 |
+
|
| 99 |
+
def dist_rank():
    """Global rank of this process (0 when not launched via torchrun)."""
    return int(os.environ.get("RANK", "0"))
|
| 101 |
+
|
| 102 |
+
def dist_size():
    """World size of the job (1 when not distributed)."""
    return int(os.environ.get("WORLD_SIZE", "1"))
|
| 104 |
+
|
| 105 |
+
def dist_local_rank():
    """Rank of this process on the local node (0 when not distributed)."""
    return int(os.environ.get("LOCAL_RANK", "0"))
|
| 107 |
+
|
| 108 |
+
def dist_is_main():
    """True only on the rank-0 process (the one that logs and saves)."""
    rank = dist_rank()
    return rank == 0
|
| 110 |
+
|
| 111 |
+
def dist_gather(obj, dst=0):
    """Gather picklable objects from all ranks onto rank ``dst``.

    Returns the list of gathered objects on the destination rank and
    ``None`` elsewhere. Without an initialised process group, the local
    object is returned as a one-element list.
    """
    if not dist.is_initialized():
        return [obj]
    if not dist_is_main():
        dist.gather_object(obj, dst=dst)
        return None
    bucket = [None] * dist_size()
    dist.gather_object(obj, bucket, dst=dst)
    return bucket
|
| 121 |
+
|
| 122 |
+
def print_rank0(*args, **kwargs):
    """Forward to ``print`` on rank 0 only; a no-op on all other ranks."""
    if not dist_is_main():
        return
    print(*args, **kwargs)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# ──────────────────────────────────────────────────────────────────
|
| 128 |
+
# Sampling (matches official model/utils.py::sample)
|
| 129 |
+
# ──────────────────────────────────────────────────────────────────
|
| 130 |
+
def sample(logits: torch.Tensor, temperature: float = 0.0) -> torch.Tensor:
    """Greedy or temperature sampling over the vocabulary dimension.

    Temperatures below 1e-5 select the argmax; otherwise tokens are drawn
    from the softmax of the temperature-scaled logits. Input is
    (batch, seq_len, vocab); the returned token ids are (batch, seq_len).
    Matches the official model/utils.py::sample.
    """
    if temperature < 1e-5:
        return logits.argmax(dim=-1)
    bsz, seq_len, vocab = logits.shape
    scaled = logits.reshape(-1, vocab) / temperature
    probs = torch.softmax(scaled, dim=-1)
    drawn = torch.multinomial(probs, num_samples=1)
    return drawn.view(bsz, seq_len)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
# ──────────────────────────────────────────────────────────────────
|
| 141 |
+
# Build DFlash attention mask (vectorized, no Python loops)
|
| 142 |
+
# ──────────────────────────────────────────────────────────────────
|
| 143 |
+
def build_dflash_mask(ctx_len: int, block_size: int, device, dtype=torch.bfloat16):
    """
    Build DFlash attention mask for [context | block] sequence.
    - Context part: standard causal
    - Block part: each token sees all context + all tokens in same block (bidirectional)

    Returns an additive mask of shape (1, 1, full_len, full_len) where
    allowed positions are 0 and disallowed positions are the dtype minimum.
    """
    full_len = ctx_len + block_size
    neg_inf = torch.finfo(dtype).min

    # Start fully masked, then open the allowed regions.
    mask = torch.full((1, 1, full_len, full_len), neg_inf, device=device, dtype=dtype)

    if ctx_len > 0:
        # Context attends causally to itself.
        ctx_rows = torch.arange(ctx_len, device=device)
        ctx_cols = torch.arange(ctx_len, device=device)
        causal = ctx_cols.unsqueeze(0) <= ctx_rows.unsqueeze(1)
        mask[0, 0, :ctx_len, :ctx_len].masked_fill_(causal, 0)
        # Block tokens attend to the entire context.
        mask[0, 0, ctx_len:, :ctx_len] = 0

    # Block tokens attend bidirectionally within the block.
    # Bug fix: this was previously guarded by `ctx_len > 0`, which left the
    # block rows fully masked (all -inf) when called with an empty context.
    mask[0, 0, ctx_len:, ctx_len:] = 0

    return mask
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
# ──────────────────────────────────────────────────────────────────
|
| 168 |
+
# Pure autoregressive generation (target model only, no draft)
|
| 169 |
+
# Used for AR baseline timing -- avoids inflating AR time with draft overhead.
|
| 170 |
+
# ──────────────────────────────────────────────────────────────────
|
| 171 |
+
@torch.inference_mode()
def ar_generate(
    target_model: nn.Module,
    input_ids: torch.LongTensor,
    max_new_tokens: int = 2048,
    mask_token_id: int = MASK_TOKEN_ID,
    temperature: float = 0.0,
    stop_token_ids: Optional[List[int]] = None,
) -> SimpleNamespace:
    """
    Pure autoregressive generation using only the target model.
    Mirrors official benchmark.py with block_size=1 (no draft model involved).
    Returns SimpleNamespace matching official dflash_generate output format.

    Used as the 1x baseline for speedup computation; the draft model is
    intentionally never touched so AR timing is not inflated.
    """
    device = input_ids.device
    num_input_tokens = input_ids.shape[1]
    max_length = num_input_tokens + max_new_tokens

    # Pre-allocated output buffer filled with the mask token; unused slots
    # are stripped after generation. The +1 leaves room for the last write.
    output_ids = torch.full(
        (1, max_length + 1), mask_token_id,
        dtype=torch.long, device=device,
    )
    output_ids[:, :num_input_tokens] = input_ids
    position_ids = torch.arange(output_ids.shape[1], device=device).unsqueeze(0)
    past_key_values = DynamicCache()

    # Prefill: one pass over the prompt; only the last logit is needed.
    prefill_start = cuda_time()
    output = target_model(
        input_ids,
        position_ids=position_ids[:, :num_input_tokens],
        past_key_values=past_key_values,
        use_cache=True,
        logits_to_keep=1,
        output_hidden_states=False,
    )
    first_token = sample(output.logits, temperature)
    output_ids[:, num_input_tokens:num_input_tokens + 1] = first_token
    time_to_first_token = cuda_time() - prefill_start

    # Decode (autoregressive, one token at a time)
    decode_start = cuda_time()
    start = num_input_tokens

    while start < max_length:
        cur_token = output_ids[:, start:start + 1]
        cur_pos = position_ids[:, start:start + 1]

        output = target_model(
            cur_token,
            position_ids=cur_pos,
            past_key_values=past_key_values,
            use_cache=True,
            output_hidden_states=False,
        )

        next_token = sample(output.logits, temperature)
        start += 1
        output_ids[:, start:start + 1] = next_token
        # Keep the KV cache aligned with the committed prefix length.
        past_key_values.crop(start)

        # Check stop tokens (matches official: check all generated)
        if stop_token_ids is not None and any(
            sid in output_ids[:, num_input_tokens:] for sid in stop_token_ids
        ):
            break

    # Drop the scratch slot, remove unfilled mask positions, then truncate
    # at the first stop token (if any) within the generated suffix.
    output_ids = output_ids[:, :max_length]
    output_ids = output_ids[:, output_ids[0] != mask_token_id]
    if stop_token_ids is not None:
        stop_t = torch.tensor(stop_token_ids, device=output_ids.device)
        stop_idx = torch.isin(output_ids[0][num_input_tokens:], stop_t).nonzero(as_tuple=True)[0]
        if stop_idx.numel() > 0:
            output_ids = output_ids[:, :num_input_tokens + stop_idx[0] + 1]

    num_output_tokens = output_ids.shape[1] - num_input_tokens
    total_decode_time = cuda_time() - decode_start
    time_per_output_token = total_decode_time / max(num_output_tokens, 1)

    return SimpleNamespace(
        output_ids=output_ids,
        num_input_tokens=num_input_tokens,
        num_output_tokens=num_output_tokens,
        time_to_first_token=time_to_first_token,
        time_per_output_token=time_per_output_token,
        acceptance_lengths=[1] * max(num_output_tokens, 0),  # AR: always 1
    )
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
# ──────────────────────────────────────────────────────────────────
|
| 261 |
+
# Core: spec_generate with layer-by-layer injection (KV-cached)
|
| 262 |
+
# ──────────────────────────────────────────────────────────────────
|
| 263 |
+
@torch.inference_mode()
def spec_generate_inject(
    target_model: nn.Module,
    draft_model: nn.Module,
    input_ids: torch.LongTensor,
    max_new_tokens: int = 2048,
    block_size: int = 16,
    mask_token_id: int = MASK_TOKEN_ID,
    temperature: float = 0.0,
    stop_token_ids: Optional[List[int]] = None,
) -> SimpleNamespace:
    """
    Speculative generation using DFlash-LoRA-Inject inference pattern.
    Returns SimpleNamespace matching official dflash_generate output format.

    Per step: the draft model decodes a block of ``block_size`` mask tokens
    in parallel, with the target model's per-layer context hidden states
    injected layer-by-layer in front of the block; the target model then
    verifies the block in one KV-cached forward pass and the longest
    matching prefix is accepted.
    """
    device = input_ids.device
    num_input_tokens = input_ids.shape[1]
    max_length = num_input_tokens + max_new_tokens

    # Draft-model internals accessed directly for layer-wise injection.
    draft_layers = draft_model.model.layers
    draft_norm = draft_model.model.norm
    draft_lm_head = draft_model.lm_head
    rotary_emb = draft_model.model.rotary_emb
    num_layers = len(draft_layers)

    # Extra block_size slack so a full block write never overruns.
    output_ids = torch.full(
        (1, max_length + block_size), mask_token_id,
        dtype=torch.long, device=device,
    )
    output_ids[:, :num_input_tokens] = input_ids

    # ── Prefill: target with KV cache + hidden states ──
    prefill_start = cuda_time()
    target_kv = DynamicCache()
    target_output = target_model(
        input_ids,
        past_key_values=target_kv,
        use_cache=True,
        output_hidden_states=True,
    )
    first_token = sample(target_output.logits[:, -1:, :], temperature)
    output_ids[:, num_input_tokens] = first_token.squeeze()

    # hidden_states[0] is the embedding output, so layer i's output lives
    # at index i + 1; keep one growing context tensor per draft layer.
    ctx_hidden_per_layer = [
        target_output.hidden_states[i + 1]
        for i in range(num_layers)
    ]

    time_to_first_token = cuda_time() - prefill_start

    # Decode
    decode_start = cuda_time()
    acceptance_lengths = []
    start = num_input_tokens
    draft_prefill = True

    while start < max_length:
        end = min(start + block_size, max_length)
        actual_block_size = end - start

        # Slot 0 of the block already holds the last accepted token;
        # the remaining positions are mask tokens to be drafted.
        block_ids = output_ids[:, start:end].clone()

        # ── Draft: forward with layer-by-layer injection ──
        draft_hidden = draft_model.model.embed_tokens(block_ids)
        ctx_len = ctx_hidden_per_layer[0].shape[1]

        dflash_mask = build_dflash_mask(ctx_len, actual_block_size, device)
        combined_pos = torch.arange(ctx_len + actual_block_size, device=device).unsqueeze(0)

        # Rotary embeddings only need the hidden size/dtype of this tensor,
        # not its values, so an uninitialised buffer suffices.
        dummy_combined = torch.empty(1, ctx_len + actual_block_size, draft_hidden.shape[-1],
                                     device=device, dtype=torch.bfloat16)
        position_embeddings = rotary_emb(dummy_combined, combined_pos)

        for layer_idx in range(num_layers):
            # Prepend the target's layer-l context states, run the draft
            # layer, then keep only the block part for the next layer.
            target_ctx = ctx_hidden_per_layer[layer_idx]
            combined = torch.cat([target_ctx, draft_hidden], dim=1)

            layer_output = draft_layers[layer_idx](
                combined,
                attention_mask=dflash_mask,
                position_ids=combined_pos,
                position_embeddings=position_embeddings,
            )
            if isinstance(layer_output, tuple):
                layer_output = layer_output[0]
            draft_hidden = layer_output[:, ctx_len:, :]

        draft_hidden = draft_norm(draft_hidden)
        draft_logits = draft_lm_head(draft_hidden)

        # Position i predicts token i+1; fill the masked slots of the block.
        draft_predictions = sample(draft_logits[:, :-1, :], temperature)
        block_ids[:, 1:actual_block_size] = draft_predictions[:, :actual_block_size - 1]

        # Exclude draft's first prefill from decode timing (matches official pattern)
        if draft_prefill:
            draft_prefill = False
            decode_start = cuda_time()

        # ── Verify: target forward on block tokens (with KV cache) ──
        position_ids_block = torch.arange(
            start, start + actual_block_size, device=device
        ).unsqueeze(0)

        target_verify = target_model(
            block_ids,
            position_ids=position_ids_block,
            past_key_values=target_kv,
            use_cache=True,
            output_hidden_states=True,
        )
        target_tokens = sample(target_verify.logits, temperature)

        # Acceptance: longest prefix where draft agrees with target.
        matches = (block_ids[:, 1:actual_block_size] == target_tokens[:, :actual_block_size - 1])
        acceptance_length = int(matches.cumprod(dim=1).sum(dim=1)[0].item())

        # Commit accepted tokens plus the target's bonus correction token.
        output_ids[:, start:start + acceptance_length + 1] = block_ids[:, :acceptance_length + 1]
        output_ids[:, start + acceptance_length + 1] = target_tokens[:, acceptance_length]

        # Roll back KV entries for the rejected part of the block.
        accepted_end = start + acceptance_length + 1
        target_kv.crop(accepted_end)

        # Grow the per-layer context with the accepted positions' states.
        for i in range(num_layers):
            new_hidden = target_verify.hidden_states[i + 1][:, :acceptance_length + 1, :]
            ctx_hidden_per_layer[i] = torch.cat([ctx_hidden_per_layer[i], new_hidden], dim=1)

        start += acceptance_length + 1
        acceptance_lengths.append(acceptance_length + 1)

        # Official: check ALL generated tokens
        if stop_token_ids is not None and any(
            sid in output_ids[:, num_input_tokens:] for sid in stop_token_ids
        ):
            break

    # Strip slack/mask slots, then truncate at the first stop token.
    output_ids = output_ids[:, :min(start, max_length)]
    output_ids = output_ids[:, output_ids[0] != mask_token_id]
    if stop_token_ids is not None:
        stop_t = torch.tensor(stop_token_ids, device=output_ids.device)
        stop_idx = torch.isin(output_ids[0][num_input_tokens:], stop_t).nonzero(as_tuple=True)[0]
        if stop_idx.numel() > 0:
            output_ids = output_ids[:, :num_input_tokens + stop_idx[0] + 1]

    num_output_tokens = output_ids.shape[1] - num_input_tokens
    total_decode_time = cuda_time() - decode_start
    time_per_output_token = total_decode_time / max(num_output_tokens, 1)

    return SimpleNamespace(
        output_ids=output_ids,
        num_input_tokens=num_input_tokens,
        num_output_tokens=num_output_tokens,
        time_to_first_token=time_to_first_token,
        time_per_output_token=time_per_output_token,
        acceptance_lengths=acceptance_lengths,
    )
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
# ──────────────────────────────────────────────────────────────────
|
| 421 |
+
# Main
|
| 422 |
+
# ──────────────────────────────────────────────────────────────────
|
| 423 |
+
def parse_args():
    """Build and parse the CLI for the offline DFlash-LoRA-Inject eval."""
    parser = argparse.ArgumentParser(description="Offline eval for DFlash-LoRA-Inject (aligned with official)")
    parser.add_argument("--base-model", default=BASE_MODEL)
    parser.add_argument("--adapter-root", default=ADAPTER_ROOT)
    parser.add_argument("--ckpt", default=DEFAULT_CKPT, help="Checkpoint folder name")
    parser.add_argument(
        "--merged-path",
        default="/workspace/hanrui/syxin/Specforge/outputs/qwen3-8b-dflash-lora-inject-merged",
        help="Path to pre-merged model. If None, will merge on the fly.",
    )
    parser.add_argument("--block-size", type=int, default=BLOCK_SIZE)
    parser.add_argument(
        "--max-new-tokens", type=int, default=2048,
        help="Max new tokens per turn (official shell uses 2048)",
    )
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument(
        "--datasets", nargs="+", default=list(OFFICIAL_TASKS.keys()),
        help="Benchmarks to run (default: all 10 official tasks)",
    )
    parser.add_argument(
        "--max-samples", type=int, default=None,
        help="Override max samples per dataset (None = use official per-task counts)",
    )
    parser.add_argument("--output-dir", default=RESULT_DIR)
    return parser.parse_args()
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
def main():
    """Entry point: load models, run each benchmark, gather metrics on rank 0.

    One process per GPU (torchrun); samples are strided across ranks and
    the per-sample results are gathered back to rank 0 for metric
    computation and JSON output.
    """
    args = parse_args()

    # Fix random seeds (matches official)
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # ── Init distributed ──
    dist_init()
    torch.cuda.set_device(dist_local_rank())
    device = torch.device(f"cuda:{dist_local_rank()}")

    print_rank0(f"Running on {dist_size()} GPU(s)")

    # Detect flash_attn (only for target model; draft needs sdpa for custom DFlash mask)
    installed_flash_attn = has_flash_attn()
    target_attn_impl = "flash_attention_2" if installed_flash_attn else "sdpa"
    draft_attn_impl = "sdpa"  # DFlash injection uses custom attention mask
    print_rank0(f"Using attn_implementation: target={target_attn_impl}, draft={draft_attn_impl}")

    # ── Load models ──
    print_rank0(f"Loading target model: {args.base_model}")
    target_model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        torch_dtype=torch.bfloat16,
        attn_implementation=target_attn_impl,
        device_map=device,
        trust_remote_code=True,
    )
    target_model.eval()

    # Draft model: prefer a pre-merged checkpoint; otherwise merge the LoRA
    # adapter into the base weights on the fly.
    if args.merged_path and os.path.isdir(args.merged_path):
        print_rank0(f"Loading pre-merged draft model: {args.merged_path}")
        draft_model = AutoModelForCausalLM.from_pretrained(
            args.merged_path,
            torch_dtype=torch.bfloat16,
            attn_implementation=draft_attn_impl,
            device_map=device,
            trust_remote_code=True,
        )
    else:
        adapter_path = os.path.join(args.adapter_root, args.ckpt)
        print_rank0(f"Loading base + LoRA adapter: {adapter_path}")
        draft_model = AutoModelForCausalLM.from_pretrained(
            args.base_model,
            torch_dtype=torch.bfloat16,
            attn_implementation=draft_attn_impl,
            device_map=device,
            trust_remote_code=True,
        )
        draft_model = PeftModel.from_pretrained(draft_model, adapter_path)
        draft_model = draft_model.merge_and_unload()
    draft_model.eval()

    tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
    stop_token_ids = [tokenizer.eos_token_id]

    block_size = args.block_size

    # ── Run benchmarks ──
    all_results = {"model": f"dflash-lora-inject/{args.ckpt}", "block_size": block_size}

    for dataset_name in args.datasets:
        print_rank0(f"\n{'=' * 60}")
        print_rank0(f"Benchmark: {dataset_name} ({dist_size()} GPUs)")
        print_rank0(f"{'=' * 60}")

        # Load dataset using official loader
        dataset = load_and_process_dataset(dataset_name)

        # Sample selection: official uses shuffle(seed=0).select()
        max_samples = args.max_samples if args.max_samples is not None else OFFICIAL_TASKS.get(dataset_name)
        if max_samples is not None and len(dataset) > max_samples:
            dataset = dataset.shuffle(seed=0).select(range(max_samples))

        print_rank0(f"Total {len(dataset)} samples, distributed across {dist_size()} GPUs")

        responses = []
        # Strided sharding: each rank processes every dist_size()-th sample.
        indices = range(dist_rank(), len(dataset), dist_size())

        iterator = tqdm(indices, desc=f"[GPU{dist_rank()}] {dataset_name}",
                        unit="sample", disable=not dist_is_main())

        for idx in iterator:
            instance = dataset[idx]

            # Multi-turn support (matches official benchmark.py)
            messages = []
            for turn_index, user_content in enumerate(instance["turns"]):
                messages.append({"role": "user", "content": user_content})
                input_text = tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True,
                    enable_thinking=False,
                )
                input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

                # Keyed by block size: 1 = AR baseline, block_size = speculative.
                response = {}

                # AR baseline: pure target-only autoregressive (no draft overhead)
                response[1] = ar_generate(
                    target_model=target_model,
                    input_ids=input_ids,
                    max_new_tokens=args.max_new_tokens,
                    mask_token_id=MASK_TOKEN_ID,
                    temperature=args.temperature,
                    stop_token_ids=stop_token_ids,
                )

                # Speculative: DFlash-LoRA-Inject
                response[block_size] = spec_generate_inject(
                    target_model=target_model,
                    draft_model=draft_model,
                    input_ids=input_ids,
                    max_new_tokens=args.max_new_tokens,
                    block_size=block_size,
                    mask_token_id=MASK_TOKEN_ID,
                    temperature=args.temperature,
                    stop_token_ids=stop_token_ids,
                )

                # Append assistant response for multi-turn context
                spec_response = response[block_size]
                generated_ids = spec_response.output_ids[0, spec_response.num_input_tokens:]
                output_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
                messages.append({"role": "assistant", "content": output_text})
                responses.append(response)

            # Live progress readout: mean accept length over the last 5 turns.
            if dist_is_main() and responses:
                recent_tau = np.mean([np.mean(r[block_size].acceptance_lengths) for r in responses[-5:]])
                iterator.set_postfix(accept_len=f"{recent_tau:.2f}")

        # ── Gather to rank 0 (matches official) ──
        if dist_size() > 1:
            gathered = dist_gather(responses, dst=0)
            if not dist_is_main():
                continue
            responses = list(chain(*gathered))
        elif not dist_is_main():
            continue

        # ── Compute metrics (exact official formulas) ──
        t1 = np.mean([r[1].time_per_output_token for r in responses])
        tb = np.mean([r[block_size].time_per_output_token for r in responses])
        speedup = t1 / tb if tb > 0 else 0

        # Acceptance length: per-sample mean, then mean of means (official)
        tau = np.mean([np.mean(r[block_size].acceptance_lengths) for r in responses])

        # Histogram
        acceptance_lengths = list(chain(*[r[block_size].acceptance_lengths for r in responses]))
        histogram = [acceptance_lengths.count(b) / len(acceptance_lengths) for b in range(block_size + 1)]

        print_rank0(f"\n{dataset_name} Results:")
        print_rank0(f"  Decoding speedup: {speedup:.2f}x")
        print_rank0(f"  Average Acceptance length: {tau:.2f}")
        print_rank0(f"  Acceptance length histogram: {[f'{x * 100:.1f}%' for x in histogram]}")
        print_rank0(f"  Num responses: {len(responses)}")

        all_results[dataset_name] = {
            "decoding_speedup": speedup,
            "avg_accept_length": tau,
            "acceptance_histogram": histogram,
            "num_responses": len(responses),
            "num_gpus": dist_size(),
        }

    # ── Save results ──
    if dist_is_main():
        os.makedirs(args.output_dir, exist_ok=True)
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        result_file = os.path.join(
            args.output_dir,
            f"dflash_lora_inject_offline_{args.ckpt}_{timestamp}.json",
        )
        with open(result_file, "w") as f:
            json.dump(all_results, f, indent=2)
        print(f"\nResults saved to: {result_file}")


if __name__ == "__main__":
    main()
|
syxin/idea.md
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
现在关于target model的hidden state注入
|
| 2 |
+
|
| 3 |
+
dflash的做法是,抽5层的feature过一下fc然后concat到mask token对应的hidden state前面
|
| 4 |
+
|
| 5 |
+
但是如果我们的draft是用lora的原始模型
|
| 6 |
+
|
| 7 |
+
我们不用这样注入
|
| 8 |
+
|
| 9 |
+
我们可以直接把target model的hidden state直接层对层拉过来
|
| 10 |
+
|
| 11 |
+
我是把加了lora后的模型作为draft model用的
|
| 12 |
+
|
| 13 |
+
它本质上还是一个speculative decode
|
| 14 |
+
|
| 15 |
+
我的想法的核心是,因为这个draft model足够大,也和target model足够像,把他转为和dflash一样每次用mask直接生成16个token,可能能得到很长的accept len,以此获得加速
|
| 16 |
+
|
| 17 |
+
而dflash能work的核心是,它在生成阶段是使用的部分target model的hidden state,注入到mask token的hidden state前面
|
| 18 |
+
|
| 19 |
+
我们也用相同的做法
|
| 20 |
+
|
| 21 |
+
带lora的模型,lora只负责让它能并行解码16个mask token,但是前面的上下文信息,依然用原始model跑出来的,通过注入放进draft的时候
|
| 22 |
+
|
| 23 |
+
而且由于模型结构的一致,我们可以直接层对层注入进去
|
syxin/launch_train.sh
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Launch single-node (8-GPU) DFlash-LoRA-Inject training for Qwen3-8B.
set -euo pipefail

cd /workspace/hanrui/syxin/Specforge

# Cache/env knobs: inductor kernel cache, dataset preprocessing workers,
# expandable CUDA allocator segments (reduces fragmentation-driven OOM),
# and HuggingFace cache locations on the shared workspace volume.
export TORCHINDUCTOR_CACHE_DIR=/workspace/hanrui/cache/compiled_kernels
export SPECFORGE_DATA_NUM_PROC=16
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export PYTORCH_ALLOC_CONF=expandable_segments:True
export HF_DATASETS_CACHE=/workspace/hanrui/cache/hf_datasets
export HF_HOME=/workspace/hanrui/cache/hf_home

# 8 ranks, one per GPU; effective batch = 4 * 8 accumulation * 8 ranks.
torchrun --nproc_per_node=8 \
    scripts/train_dflash_lora_inject.py \
    --target-model-path /workspace/models/Qwen3-8B \
    --target-model-backend hf \
    --train-data-path /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K \
    --output-dir outputs/qwen3-8b-sft-32gpu-v2 \
    --block-size 16 \
    --attention-backend additive \
    --attn-implementation sdpa \
    --max-length 2048 \
    --batch-size 4 \
    --accumulation-steps 8 \
    --num-epochs 3 \
    --learning-rate 5e-5 \
    --loss-decay-gamma 7 \
    --gradient-checkpointing \
    --chat-template qwen \
    --log-interval 50 \
    --save-interval 500 \
    --cache-dir /workspace/hanrui/cache \
    --lora-rank 32 \
    --lora-alpha 64 \
    --lora-dropout 0.1 \
    --trust-remote-code \
    --dataloader-num-workers 0
|
syxin/launch_train_wrapper.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
"""
Thin Python entry point that forwards its CLI arguments to the
multi-node bash training launcher and propagates its exit code.
"""
import subprocess
import sys
import os

if __name__ == "__main__":
    # Fixed launcher script; every CLI argument is passed straight through.
    bash_script = "/workspace/hanrui/syxin/run_train_multinode.sh"
    forwarded = sys.argv[1:]

    command = ["bash", bash_script, *forwarded]

    # Inherit the current environment (torchrun/cluster variables included).
    completed = subprocess.run(command, env=dict(os.environ))

    # Mirror the launcher's exit status.
    sys.exit(completed.returncode)
|
syxin/list.md
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
### 1. `train_dflash_lora.py`
|
| 2 |
+
* 加了lora,原来是调用小模型,现在是hidden states+lora预测。
|
| 3 |
+
* `dflash_lora_mask_fn`函数是在处理预测的那一块草稿Block时,可以同时看到这一块里的所有词。
|
| 4 |
+
|
| 5 |
+
### 2. OOM优化
|
| 6 |
+
* 分片策略ZeRO-3,FSDP切分从`SHARD_GRAD_OP`升级到`FULL_SHARD`。
|
| 7 |
+
* `batch-size=1`,`accumulation-steps=8`。
|
| 8 |
+
* 参考之前的代码用了FlexAttention(`dflash_lora_mask_fn`)。
|
| 9 |
+
* `_chunked_lm_loss()`,把算loss切片成256块来算+梯度检查。
|
| 10 |
+
|
| 11 |
+
### 运行
|
| 12 |
+
* bash /workspace/hanrui/junquan/SpecForge/scripts/run_train_dflash_lora.sh 2
|
syxin/merge_lora.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Step 1: Merge DFlash-LoRA adapter into base model.
|
| 3 |
+
Usage:
|
| 4 |
+
conda activate sglang
|
| 5 |
+
python3 merge_lora.py
|
| 6 |
+
python3 merge_lora.py --ckpt epoch_2_step_15000 # 测其他 checkpoint
|
| 7 |
+
"""
|
| 8 |
+
import argparse
|
| 9 |
+
import os
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from peft import PeftModel
|
| 13 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 14 |
+
|
| 15 |
+
# Base model whose weights the LoRA adapter is merged into.
BASE_MODEL = "/workspace/models/Qwen3-8B"
# Directory containing the trained LoRA checkpoints.
OUTPUT_ROOT = "/workspace/hanrui/syxin/Specforge/outputs/qwen3-8b-dflash-lora"
# Default destination for the merged full-weight model.
MERGE_ROOT = "/workspace/hanrui/syxin/Specforge/outputs/qwen3-8b-dflash-lora-merged"
|
| 18 |
+
|
| 19 |
+
def parse_args():
    """Parse CLI options for the LoRA merge step."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", default="epoch_3_step_18576",
                        help="Checkpoint folder name under OUTPUT_ROOT")
    parser.add_argument("--merged-path", default=MERGE_ROOT,
                        help="Where to save the merged model")
    return parser.parse_args()
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def main():
    """Merge the selected LoRA checkpoint into the base model on CPU.

    Skips silently when the merge destination already exists; otherwise
    loads base weights, applies the adapter, merges, and saves the merged
    model plus tokenizer to --merged-path.
    """
    args = parse_args()
    adapter_path = os.path.join(OUTPUT_ROOT, args.ckpt)
    merged_path = args.merged_path

    # Idempotent: an existing output directory means a previous merge ran.
    if os.path.exists(merged_path):
        print(f"[skip] Merged model already exists: {merged_path}")
        return

    assert os.path.isdir(adapter_path), f"Adapter not found: {adapter_path}"

    print(f"Base model : {BASE_MODEL}")
    print(f"Adapter    : {adapter_path}")
    print(f"Output     : {merged_path}")
    print()

    # CPU load keeps GPUs free; bf16 matches the training/serving dtype.
    print("[1/4] Loading base model to CPU ...")
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )

    print("[2/4] Loading LoRA adapter ...")
    model = PeftModel.from_pretrained(model, adapter_path)

    # merge_and_unload folds LoRA deltas into the base weights and drops
    # the PEFT wrapper, yielding a plain CausalLM.
    print("[3/4] Merging weights ...")
    model = model.merge_and_unload()

    print("[4/4] Saving merged model ...")
    os.makedirs(merged_path, exist_ok=True)
    model.save_pretrained(merged_path, safe_serialization=True)
    # Ship the tokenizer alongside so the merged dir is self-contained.
    AutoTokenizer.from_pretrained(BASE_MODEL).save_pretrained(merged_path)

    print(f"\nDone. Merged model saved to: {merged_path}")


if __name__ == "__main__":
    main()
|
syxin/oom_fix_progress.md
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DFlash LoRA OOM 修复记录
|
| 2 |
+
|
| 3 |
+
## OOM 根因分析
|
| 4 |
+
|
| 5 |
+
1. **SHARD_GRAD_OP (ZeRO-2)** — 每卡持有完整 Qwen3-8B 参数 (~16GB bf16),参数未分片
|
| 6 |
+
2. **SDPA + 4D additive mask** — FlashAttention 不支持 4D additive mask,fallback 到 math backend,每层 materialize 完整 attention scores (`bsz × 32heads × 2048 × 2048`)
|
| 7 |
+
3. **大 vocab logits** — `[bsz, 2048, 151936]` bf16 ≈ 1.18GB,加上梯度和 boolean indexing 拷贝,峰值 ~3-4GB
|
| 8 |
+
4. **机器只有 2 张 H100**,脚本默认 `NUM_GPUS=4`
|
| 9 |
+
|
| 10 |
+
## 已完成的改动
|
| 11 |
+
|
| 12 |
+
### 1. FSDP sharding 改为 FULL_SHARD (ZeRO-3)
|
| 13 |
+
- 文件: `SpecForge/scripts/train_dflash_lora.py:347`
|
| 14 |
+
- `ShardingStrategy.SHARD_GRAD_OP` → `ShardingStrategy.FULL_SHARD`
|
| 15 |
+
- 效果: 参数跨卡分片,每卡省 ~8-12GB
|
| 16 |
+
|
| 17 |
+
### 2. 降 batch-size,提高 accumulation-steps
|
| 18 |
+
- 文件: `SpecForge/scripts/run_train_dflash_lora.sh`
|
| 19 |
+
- `--batch-size 2` → `1`,`--accumulation-steps 4` → `8`
|
| 20 |
+
- 效果: 等效 global batch size 不变,峰值显存减半
|
| 21 |
+
|
| 22 |
+
## 待验证 / 后续优化
|
| 23 |
+
|
| 24 |
+
- [ ] 运行时传 `bash run_train_dflash_lora.sh 2` 确保用 2 卡
|
| 25 |
+
- [x] 如仍 OOM,考虑 chunked cross-entropy loss 避免大 vocab logits 全量 materialize
|
| 26 |
+
- [x] 长期可探索自定义 attention kernel 支持 block-sparse mask,绕过 SDPA math fallback
|
| 27 |
+
|
| 28 |
+
### 3. flex_attention + BlockMask 替换 4D additive mask
|
| 29 |
+
- 文件: `SpecForge/specforge/core/dflash_lora.py`, `specforge/modeling/draft/dflash_lora.py`, `scripts/train_dflash_lora.py`
|
| 30 |
+
- 从非 LoRA 版 `dflash.py` 移植 `_get_or_create_block_mask()` 方法,适配 LoRA 场景 (Q_LEN == KV_LEN == seq_len)
|
| 31 |
+
- LoRA 版 mask: context causal + block bidirectional (非 LoRA 版是 [context, noise] concat KV)
|
| 32 |
+
- 用 `--attention-backend flex_attention` 启用 (默认),退回 `--attention-backend additive` 走原有 4D mask
|
| 33 |
+
- HuggingFace model 用 `attn_implementation="flex_attention"` 加载
|
| 34 |
+
- 效果: 不再 fallback 到 SDPA math backend,省去 `[bsz, heads, seq, seq]` attention scores 的显存
|
| 35 |
+
|
| 36 |
+
### 4. chunked cross-entropy loss
|
| 37 |
+
- 文件: `SpecForge/specforge/core/dflash_lora.py`, `specforge/modeling/draft/dflash_lora.py`, `scripts/train_dflash_lora.py`
|
| 38 |
+
- 从非 LoRA 版 `dflash.py` 移植 `_chunked_lm_loss()` 方法
|
| 39 |
+
- 分 chunk 过 lm_head + CE loss + gradient checkpointing,避免 materialize 完整 `[bsz, seq, vocab]` logits
|
| 40 |
+
- 用 `--lm-head-chunk-size 256` 启用 (默认 0 = 不启用)
|
| 41 |
+
- `DFlashLoRADraftModel.forward()` 新增 `output_hidden_states` 参数,chunked 时返回 hidden states
|
| 42 |
+
- 效果: logits 峰值显存从 O(seq_len × vocab_size) 降至 O(chunk_size × vocab_size)
|
syxin/requirements.txt
ADDED
|
File without changes
|
syxin/run_bench.sh
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Step 3: Run HumanEval / MT-Bench / GSM8K benchmarks.
# Run AFTER start_server.sh is up.
# Usage:
#   bash run_bench.sh                  # all three benches, full dataset
#   bash run_bench.sh humaneval        # only humaneval
#   bash run_bench.sh mtbench gsm8k    # pick any subset

# Fail fast; pipefail so the `... | tee` pipeline below propagates
# bench_eagle3.py's exit code instead of tee's.
set -eo pipefail

INTRANET_IP=10.1.1.131
PORT=30000
BASE_MODEL=/workspace/models/Qwen3-8B
MERGED=/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora-merged
BENCH_DIR=/workspace/hanrui/syxin_old/Specforge/benchmarks
RESULT_DIR=$BENCH_DIR/results

# ---- sanity check: refuse to run if the sglang server is not reachable ----
echo "Checking server at http://$INTRANET_IP:$PORT ..."
curl -sf "http://$INTRANET_IP:$PORT/v1/models" > /dev/null || {
    echo "[ERROR] Server not reachable. Start it first: bash start_server.sh"
    exit 1
}
echo "Server OK."

mkdir -p "$RESULT_DIR"
cd "$BENCH_DIR"
export PYTHONPATH=/workspace/hanrui/syxin_old/Specforge:$PYTHONPATH

# ---- decide which benches to run ----
TARGETS=("$@")
if [ ${#TARGETS[@]} -eq 0 ]; then
    TARGETS=(humaneval mtbench gsm8k)
fi

# Map bench names to "name:num_samples" specs.  Kept as an array (not a
# whitespace-joined string) so each spec stays one argument under quoting
# (shellcheck SC2086).
BENCH_ARGS=()
for t in "${TARGETS[@]}"; do
    case $t in
        humaneval) BENCH_ARGS+=(humaneval:164) ;;
        mtbench)   BENCH_ARGS+=(mtbench:80) ;;
        gsm8k)     BENCH_ARGS+=(gsm8k:1319) ;;
        *)
            echo "[ERROR] Unknown bench: $t (choices: humaneval mtbench gsm8k)"
            exit 1
            ;;
    esac
done

TIMESTAMP=$(date +%Y%m%d_%H%M%S)
echo "Running: ${BENCH_ARGS[*]}"
echo "Results -> $RESULT_DIR"
echo ""

python3 bench_eagle3.py \
    --model-path "$BASE_MODEL" \
    --speculative-draft-model-path "$MERGED" \
    --host "$INTRANET_IP" \
    --port "$PORT" \
    --config-list "16,4,1,4" \
    --benchmark-list "${BENCH_ARGS[@]}" \
    --output-dir "$RESULT_DIR" \
    --name "dflash_lora_${TIMESTAMP}" \
    --skip-launch-server \
    2>&1 | tee "$RESULT_DIR/bench_${TIMESTAMP}.log"

echo ""
echo "Done. Latest result files:"
# `|| true`: an empty results dir makes ls fail, which would abort the script
# under pipefail; a missing listing is not an error here.
ls -lht "$RESULT_DIR"/*.jsonl 2>/dev/null | head -5 || true
|
syxin/run_bench_dflash.sh
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Evaluate DFlash-LoRA-Inject accepted length (offline, 8 GPUs parallel).
# No sglang server needed. Each GPU loads its own target+draft and processes a shard.
#
# Usage:
#   bash run_bench_dflash.sh                           # 8 GPUs, all 3 benches
#   bash run_bench_dflash.sh humaneval                 # only humaneval
#   bash run_bench_dflash.sh mtbench gsm8k             # pick any subset
#   bash run_bench_dflash.sh --quick                   # quick test (20 samples)
#   bash run_bench_dflash.sh --ckpt epoch_0_step_500   # specific checkpoint
#   NUM_GPUS=4 bash run_bench_dflash.sh                # use 4 GPUs

# Fail fast; pipefail so the tee pipeline propagates the eval's exit code.
set -eo pipefail

SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)
PYTHON=/workspace/miniconda3/envs/spec/bin/python3
RESULT_DIR=/workspace/hanrui/syxin_old/Specforge/benchmarks/results
NUM_GPUS=${NUM_GPUS:-8}

# ---- parse args: known bench names are collected, --quick expands to a
# sample cap, everything else is forwarded verbatim to the eval script ----
BENCHMARKS=()
EXTRA_ARGS=()
QUICK=false

for arg in "$@"; do
    case $arg in
        humaneval|mtbench|gsm8k)
            BENCHMARKS+=("$arg")
            ;;
        --quick)
            QUICK=true
            ;;
        *)
            EXTRA_ARGS+=("$arg")
            ;;
    esac
done

# Default: run every benchmark.
if [ ${#BENCHMARKS[@]} -eq 0 ]; then
    BENCHMARKS=(humaneval mtbench gsm8k)
fi

if [ "$QUICK" = true ]; then
    EXTRA_ARGS+=(--num-samples 20)
fi

TIMESTAMP=$(date +%Y%m%d_%H%M%S)

echo "============================================"
echo " DFlash-LoRA-Inject Offline Eval"
echo " GPUs       : $NUM_GPUS"
echo " benchmarks : ${BENCHMARKS[*]}"
echo " extra args : ${EXTRA_ARGS[*]}"
echo " results    : $RESULT_DIR"
echo "============================================"
echo ""

mkdir -p "$RESULT_DIR"

# Quote the array expansions so every element stays one argument (SC2068).
$PYTHON -m torch.distributed.run \
    --standalone \
    --nproc_per_node "$NUM_GPUS" \
    "$SCRIPT_DIR"/eval_dflash_lora_inject.py \
    --benchmarks "${BENCHMARKS[@]}" \
    --output-dir "$RESULT_DIR" \
    "${EXTRA_ARGS[@]}" \
    2>&1 | tee "$RESULT_DIR/bench_dflash_lora_inject_offline_${TIMESTAMP}.log"

echo ""
echo "Done. Latest result files:"
# `|| true`: ls fails on an empty results dir; not an error under pipefail.
ls -lht "$RESULT_DIR"/*.json 2>/dev/null | head -5 || true
|
syxin/run_bench_dflash_b16_baseline.sh
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# DFlash-b16 baseline: measure accepted length offline, 8 GPUs parallel.
# Usage:
#   bash run_bench_dflash_b16_baseline.sh            # 8 GPUs, all 3 benches
#   bash run_bench_dflash_b16_baseline.sh humaneval  # only humaneval
#   bash run_bench_dflash_b16_baseline.sh --quick    # 20 samples per bench
#   NUM_GPUS=4 bash run_bench_dflash_b16_baseline.sh # 4 GPUs

# Fail fast; pipefail so the tee pipeline propagates the eval's exit code.
set -eo pipefail

SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)
PYTHON=/workspace/miniconda3/envs/spec/bin/python3
RESULT_DIR=/workspace/hanrui/syxin_old/Specforge/benchmarks/results
NUM_GPUS=${NUM_GPUS:-8}

# Known bench names are collected; --quick caps the sample count; anything
# else is forwarded verbatim to the eval script.
BENCHMARKS=()
EXTRA_ARGS=()
QUICK=false

for arg in "$@"; do
    case $arg in
        humaneval|mtbench|gsm8k) BENCHMARKS+=("$arg") ;;
        --quick) QUICK=true ;;
        *) EXTRA_ARGS+=("$arg") ;;
    esac
done

# Default: run every benchmark.
if [ ${#BENCHMARKS[@]} -eq 0 ]; then
    BENCHMARKS=(humaneval mtbench gsm8k)
fi

if [ "$QUICK" = true ]; then
    EXTRA_ARGS+=(--num-samples 20)
fi

TIMESTAMP=$(date +%Y%m%d_%H%M%S)

echo "============================================"
echo " DFlash-b16 Baseline Offline Eval"
echo " GPUs       : $NUM_GPUS"
echo " draft      : /workspace/models/Qwen3-8B-DFlash-b16"
echo " benchmarks : ${BENCHMARKS[*]}"
echo " extra args : ${EXTRA_ARGS[*]}"
echo "============================================"
echo ""

mkdir -p "$RESULT_DIR"

# Quote the array expansions so every element stays one argument (SC2068).
$PYTHON -m torch.distributed.run \
    --standalone \
    --nproc_per_node "$NUM_GPUS" \
    "$SCRIPT_DIR"/eval_dflash_b16_baseline.py \
    --benchmarks "${BENCHMARKS[@]}" \
    --output-dir "$RESULT_DIR" \
    "${EXTRA_ARGS[@]}" \
    2>&1 | tee "$RESULT_DIR/bench_dflash_b16_baseline_${TIMESTAMP}.log"

echo ""
echo "Done. Latest result files:"
# `|| true`: ls fails on an empty results dir; not an error under pipefail.
ls -lht "$RESULT_DIR"/*.json 2>/dev/null | head -5 || true
|
syxin/run_qwen3_8b_sft_32gpu.sh
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Submit the Qwen3-8B SFT training job to the cluster via northjob.
# GPU_NUMS is split into NNODES x GPU_NUMS_PER_NODE (8 GPUs per node).

# Fail fast on errors, unset variables, and pipeline failures.
set -euo pipefail

export JOB_NAME='qwen3-8b-sft'
export GPU_NUMS=64
export TRAIN_SCRIPT='/workspace/hanrui/syxin/launch_train_wrapper.py'
export WORK_DIR='/workspace/hanrui/syxin/Specforge'

# Derive node topology: below 8 GPUs everything fits on one node; otherwise
# spread across nodes of 8 GPUs each (assumes GPU_NUMS is a multiple of 8).
if [ "$GPU_NUMS" -lt 8 ]; then
    export NNODES=1
    export GPU_NUMS_PER_NODE=$GPU_NUMS
else
    export NNODES=$((GPU_NUMS/8))
    export GPU_NUMS_PER_NODE=8
fi

# Use the northjob binary from the `spec` conda environment.
/workspace/miniconda3/envs/spec/bin/northjob \
    create \
    --job-type train \
    --nproc-per-node "$GPU_NUMS_PER_NODE" \
    --gpu-per-node "$GPU_NUMS_PER_NODE" \
    --nnodes "$NNODES" \
    --k8s-priority 3 \
    --k8s-queue bg-agentic-coding \
    --k8s-namespace bg-agentic-coding \
    --k8s-pvc-name i-xinsiyang-y4zy0sik0a \
    --k8s-pvc-mount-path /workspace \
    --k8s-no-reclaim \
    --k8s-images harbor.local.clusters/bp/megatron-bplm:25.03_fp8.ibgda.qwen3.next.fix_triton.fix_te.hf457.qwen3_vl \
    --job-name "$JOB_NAME" \
    --workspace "$WORK_DIR" \
    "$TRAIN_SCRIPT" "$GPU_NUMS_PER_NODE"
|
syxin/run_train_dflash_direct_inject.sh
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Launch single-node DFlash direct-inject training for Qwen3-8B.
# Usage: bash run_train_dflash_direct_inject.sh [NUM_GPUS] [OUTPUT_DIR] [extra train args...]
set -euo pipefail

ROOT_DIR=/workspace/hanrui/syxin_old/Specforge
NUM_GPUS=8
OUTPUT_DIR=$ROOT_DIR/outputs/qwen3-8b-dflash-direct-inject
# Single cache root; every cache path below is derived from it so they
# cannot drift apart.
CACHE_DIR=$ROOT_DIR/cache

# Positional args: first is the GPU count; second (if not a flag) overrides
# the output dir; everything else is forwarded to the training script.
if [[ $# -ge 1 ]]; then
    NUM_GPUS=$1
    shift
fi
if [[ $# -ge 1 && "${1:0:1}" != "-" ]]; then
    OUTPUT_DIR=$1
    shift
fi
EXTRA_ARGS=("$@")

export TORCHINDUCTOR_CACHE_DIR=$CACHE_DIR/compiled_kernels
export SPECFORGE_DATA_NUM_PROC=16
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export PYTORCH_ALLOC_CONF=expandable_segments:True
export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}"

# Prefer the dedicated specforge interpreter; fall back to system python3.
DEFAULT_SPECFORGE_PY=/workspace/hanrui/specforge/bin/python3
if [[ -z "${PYTHON_BIN:-}" ]]; then
    if [[ -x "$DEFAULT_SPECFORGE_PY" ]]; then
        PYTHON_BIN="$DEFAULT_SPECFORGE_PY"
    else
        PYTHON_BIN=python3
    fi
fi

cd "$ROOT_DIR"

$PYTHON_BIN -m torch.distributed.run \
    --standalone \
    --nproc_per_node "$NUM_GPUS" \
    scripts/train_dflash.py \
    --target-model-path /workspace/models/Qwen3-8B \
    --target-model-backend sglang \
    --train-data-path /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K \
    --output-dir "$OUTPUT_DIR" \
    --block-size 16 \
    --num-draft-layers 36 \
    --attention-backend flex_attention \
    --max-length 2048 \
    --batch-size 1 \
    --accumulation-steps 8 \
    --num-epochs 3 \
    --learning-rate 6e-4 \
    --loss-decay-gamma 7 \
    --lm-head-chunk-size 256 \
    --gradient-checkpointing \
    --chat-template qwen \
    --log-interval 50 \
    --save-interval 500 \
    --cache-dir "$CACHE_DIR" \
    "${EXTRA_ARGS[@]}"
|
syxin/run_train_dflash_lora_inject.sh
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Launch single-node DFlash LoRA-inject training for Qwen3-8B.
# Usage: bash run_train_dflash_lora_inject.sh [NUM_GPUS] [OUTPUT_DIR] [extra train args...]
set -euo pipefail

ROOT_DIR=/workspace/hanrui/syxin/Specforge
NUM_GPUS=8
OUTPUT_DIR=$ROOT_DIR/outputs/qwen3-8b-dflash-lora-inject
# Single cache root; all cache env vars below are derived from it so the
# paths cannot silently diverge if the root is changed.
CACHE_DIR=/tmp/specforge_cache

# Positional args: first is the GPU count; second (if not a flag) overrides
# the output dir; everything else is forwarded to the training script.
if [[ $# -ge 1 ]]; then
    NUM_GPUS=$1
    shift
fi
if [[ $# -ge 1 && "${1:0:1}" != "-" ]]; then
    OUTPUT_DIR=$1
    shift
fi
EXTRA_ARGS=("$@")

# Environment variables (cache paths derived from CACHE_DIR).
export TORCHINDUCTOR_CACHE_DIR=$CACHE_DIR/compiled_kernels
export SPECFORGE_DATA_NUM_PROC=16
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export PYTORCH_ALLOC_CONF=expandable_segments:True
export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}"
export HF_DATASETS_CACHE=$CACHE_DIR/hf_datasets
export HF_HOME=$CACHE_DIR/hf_home

# Prefer the dedicated specforge interpreter; fall back to system python3.
DEFAULT_SPECFORGE_PY=/workspace/hanrui/specforge/bin/python3
if [[ -z "${PYTHON_BIN:-}" ]]; then
    if [[ -x "$DEFAULT_SPECFORGE_PY" ]]; then
        PYTHON_BIN="$DEFAULT_SPECFORGE_PY"
    else
        PYTHON_BIN=python3
    fi
fi

cd "$ROOT_DIR"

$PYTHON_BIN -m torch.distributed.run \
    --standalone \
    --nproc_per_node "$NUM_GPUS" \
    scripts/train_dflash_lora_inject.py \
    --target-model-path /workspace/models/Qwen3-8B \
    --target-model-backend hf \
    --train-data-path /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K \
    --output-dir "$OUTPUT_DIR" \
    --block-size 16 \
    --attention-backend additive \
    --attn-implementation sdpa \
    --max-length 2048 \
    --batch-size 8 \
    --accumulation-steps 8 \
    --num-epochs 3 \
    --learning-rate 5e-5 \
    --loss-decay-gamma 7 \
    --gradient-checkpointing \
    --chat-template qwen \
    --log-interval 50 \
    --save-interval 500 \
    --cache-dir "$CACHE_DIR" \
    --lora-rank 32 \
    --lora-alpha 64 \
    --lora-dropout 0.1 \
    --trust-remote-code \
    --dataloader-num-workers 0 \
    --early-stop \
    --early-stop-patience 5 \
    --early-stop-min-delta 0.005 \
    "${EXTRA_ARGS[@]}"
|
syxin/run_train_multinode.sh
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Multi-node DFlash LoRA-inject training entry point, launched per-node by
# the northjob scheduler.
# Usage: bash run_train_multinode.sh [NUM_GPUS] [OUTPUT_DIR] [extra train args...]
set -euo pipefail

ROOT_DIR=/workspace/hanrui/syxin/Specforge
NUM_GPUS=8
OUTPUT_DIR=$ROOT_DIR/outputs/qwen3-8b-sft-32gpu-v3
# Single cache root; all cache env vars below are derived from it so the
# paths cannot silently diverge if the root is changed.
CACHE_DIR=/tmp/specforge_cache

# Positional args: first is the GPU count; second (if not a flag) overrides
# the output dir; everything else is forwarded to the training script.
if [[ $# -ge 1 ]]; then
    NUM_GPUS=$1
    shift
fi
if [[ $# -ge 1 && "${1:0:1}" != "-" ]]; then
    OUTPUT_DIR=$1
    shift
fi
EXTRA_ARGS=("$@")

# Environment variables (cache paths derived from CACHE_DIR).
export TORCHINDUCTOR_CACHE_DIR=$CACHE_DIR/compiled_kernels
export SPECFORGE_DATA_NUM_PROC=16
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export PYTORCH_ALLOC_CONF=expandable_segments:True
export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}"
export HF_DATASETS_CACHE=$CACHE_DIR/hf_datasets
export HF_HOME=$CACHE_DIR/hf_home

# Prefer the `spec` conda interpreter; fall back to system python3.
DEFAULT_SPECFORGE_PY=/workspace/miniconda3/envs/spec/bin/python3
if [[ -z "${PYTHON_BIN:-}" ]]; then
    if [[ -x "$DEFAULT_SPECFORGE_PY" ]]; then
        PYTHON_BIN="$DEFAULT_SPECFORGE_PY"
    else
        PYTHON_BIN=python3
    fi
fi

cd "$ROOT_DIR"

# northjob has already set the distributed environment variables via torchrun,
# so run the training script directly — do NOT launch torch.distributed.run again.
$PYTHON_BIN scripts/train_dflash_lora_inject.py \
    --target-model-path /workspace/models/Qwen3-8B \
    --target-model-backend hf \
    --train-data-path /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K \
    --output-dir "$OUTPUT_DIR" \
    --block-size 16 \
    --attention-backend additive \
    --attn-implementation sdpa \
    --max-length 2048 \
    --batch-size 4 \
    --accumulation-steps 16 \
    --num-epochs 3 \
    --learning-rate 5e-5 \
    --loss-decay-gamma 7 \
    --gradient-checkpointing \
    --chat-template qwen \
    --log-interval 50 \
    --save-interval 500 \
    --cache-dir "$CACHE_DIR" \
    --lora-rank 32 \
    --lora-alpha 64 \
    --lora-dropout 0.1 \
    --trust-remote-code \
    --dataloader-num-workers 0 \
    "${EXTRA_ARGS[@]}"
|
syxin/run_train_qwen3_8b_sft_32gpu.sh
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Qwen3-8B SFT (32-GPU) training entry point, launched per-node by the
# northjob scheduler.
# Usage: bash run_train_qwen3_8b_sft_32gpu.sh [NUM_GPUS] [OUTPUT_DIR] [extra train args...]
set -euo pipefail

ROOT_DIR=/workspace/hanrui/syxin_old/Specforge
NUM_GPUS=8
OUTPUT_DIR=$ROOT_DIR/outputs/qwen3-8b-sft-32gpu-v2
# Single cache root; all cache env vars below are derived from it so the
# paths cannot silently diverge if the root is changed.
CACHE_DIR=/tmp/specforge_cache_sft

# Positional args: first is the GPU count; second (if not a flag) overrides
# the output dir; everything else is forwarded to the training script.
if [[ $# -ge 1 ]]; then
    NUM_GPUS=$1
    shift
fi
if [[ $# -ge 1 && "${1:0:1}" != "-" ]]; then
    OUTPUT_DIR=$1
    shift
fi
EXTRA_ARGS=("$@")

# Environment variables (cache paths derived from CACHE_DIR).
export TORCHINDUCTOR_CACHE_DIR=$CACHE_DIR/compiled_kernels
export SPECFORGE_DATA_NUM_PROC=16
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export PYTORCH_ALLOC_CONF=expandable_segments:True
export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}"
export HF_DATASETS_CACHE=$CACHE_DIR/hf_datasets
export HF_HOME=$CACHE_DIR/hf_home

# Prefer the dedicated specforge interpreter; fall back to system python3.
DEFAULT_SPECFORGE_PY=/workspace/hanrui/specforge/bin/python3
if [[ -z "${PYTHON_BIN:-}" ]]; then
    if [[ -x "$DEFAULT_SPECFORGE_PY" ]]; then
        PYTHON_BIN="$DEFAULT_SPECFORGE_PY"
    else
        PYTHON_BIN=python3
    fi
fi

cd "$ROOT_DIR"

# northjob has already launched the distributed run via torchrun, so run the
# training script directly here.
$PYTHON_BIN "$ROOT_DIR"/scripts/train_dflash_lora_inject.py \
    --target-model-path /workspace/models/Qwen3-8B \
    --target-model-backend hf \
    --train-data-path /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K \
    --output-dir "$OUTPUT_DIR" \
    --block-size 16 \
    --attention-backend additive \
    --attn-implementation sdpa \
    --max-length 2048 \
    --batch-size 8 \
    --accumulation-steps 8 \
    --num-epochs 3 \
    --learning-rate 5e-5 \
    --loss-decay-gamma 7 \
    --gradient-checkpointing \
    --chat-template qwen \
    --log-interval 50 \
    --save-interval 500 \
    --cache-dir "$CACHE_DIR" \
    --lora-rank 32 \
    --lora-alpha 64 \
    --lora-dropout 0.1 \
    --trust-remote-code \
    --dataloader-num-workers 0 \
    "${EXTRA_ARGS[@]}"
|
syxin/server.log
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/workspace/hanrui/sglang/python/sglang/launch_server.py:51: UserWarning: 'python -m sglang.launch_server' is still supported, but 'sglang serve' is the recommended entrypoint.
|
| 2 |
+
Example: sglang serve --model-path <model> [options]
|
| 3 |
+
warnings.warn(
|
| 4 |
+
[2026-03-07 15:24:13] INFO server_args.py:2048: Attention backend not specified. Use fa3 backend by default.
|
| 5 |
+
[2026-03-07 15:24:13] WARNING server_args.py:2629: Max running requests is reset to 48 for speculative decoding. You can override this by explicitly setting --max-running-requests.
|
| 6 |
+
[2026-03-07 15:24:13] WARNING server_args.py:2650: Overlap scheduler is disabled when spec v2 is off or using unsupported speculative algorithm. You can set env SGLANG_ENABLE_SPEC_V2=True to enable the experimental overlap scheduler.
|
| 7 |
+
[2026-03-07 15:24:13] WARNING server_args.py:2712: speculative_num_draft_tokens is adjusted to speculative_num_steps + 1 when speculative_eagle_topk == 1
|
| 8 |
+
[2026-03-07 15:24:14] server_args=ServerArgs(model_path='/workspace/models/Qwen3-8B', tokenizer_path='/workspace/models/Qwen3-8B', tokenizer_mode='auto', tokenizer_worker_num=1, skip_tokenizer_init=False, load_format='auto', model_loader_extra_config='{}', trust_remote_code=True, context_length=None, is_embedding=False, enable_multimodal=None, revision=None, model_impl='auto', host='10.233.100.123', port=30000, fastapi_root_path='', grpc_mode=False, skip_server_warmup=False, warmups=None, nccl_port=None, checkpoint_engine_wait_weights_before_ready=False, ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, ssl_keyfile_password=None, enable_ssl_refresh=False, dtype='bfloat16', quantization=None, quantization_param_path=None, kv_cache_dtype='auto', enable_fp32_lm_head=False, modelopt_quant=None, modelopt_checkpoint_restore_path=None, modelopt_checkpoint_save_path=None, modelopt_export_path=None, quantize_and_serve=False, rl_quant_profile=None, mem_fraction_static=0.8, max_running_requests=48, max_queued_requests=None, max_total_tokens=None, chunked_prefill_size=8192, enable_dynamic_chunking=False, max_prefill_tokens=16384, prefill_max_requests=None, schedule_policy='fcfs', enable_priority_scheduling=False, disable_priority_preemption=False, default_priority_value=None, abort_on_priority_when_disabled=False, schedule_low_priority_values_first=False, priority_scheduling_preemption_threshold=10, schedule_conservativeness=1.0, page_size=1, swa_full_tokens_ratio=0.8, disable_hybrid_swa_memory=False, radix_eviction_policy='lru', enable_prefill_delayer=False, prefill_delayer_max_delay_passes=30, prefill_delayer_token_usage_low_watermark=None, prefill_delayer_forward_passes_buckets=None, prefill_delayer_wait_seconds_buckets=None, device='cuda', tp_size=4, pp_size=1, pp_max_micro_batch_size=None, pp_async_batch_depth=0, stream_interval=1, stream_output=False, enable_streaming_session=False, random_seed=551181117, constrained_json_whitespace_pattern=None, 
constrained_json_disable_any_whitespace=False, watchdog_timeout=300, soft_watchdog_timeout=None, dist_timeout=None, download_dir=None, model_checksum=None, base_gpu_id=0, gpu_id_step=1, sleep_on_idle=False, use_ray=False, custom_sigquit_handler=None, log_level='info', log_level_http=None, log_requests=False, log_requests_level=2, log_requests_format='text', log_requests_target=None, uvicorn_access_log_exclude_prefixes=[], crash_dump_folder=None, show_time_cost=False, enable_metrics=False, enable_metrics_for_all_schedulers=False, tokenizer_metrics_custom_labels_header='x-custom-labels', tokenizer_metrics_allowed_custom_labels=None, extra_metric_labels=None, bucket_time_to_first_token=None, bucket_inter_token_latency=None, bucket_e2e_request_latency=None, collect_tokens_histogram=False, prompt_tokens_buckets=None, generation_tokens_buckets=None, gc_warning_threshold_secs=0.0, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, enable_trace=False, otlp_traces_endpoint='localhost:4317', export_metrics_to_file=False, export_metrics_to_file_dir=None, api_key=None, admin_api_key=None, served_model_name='/workspace/models/Qwen3-8B', weight_version='default', chat_template=None, hf_chat_template_name=None, completion_template=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, tool_call_parser=None, tool_server=None, sampling_defaults='model', dp_size=1, load_balance_method='round_robin', attn_cp_size=1, moe_dp_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, enable_lora=None, enable_lora_overlap_loading=None, max_lora_rank=None, lora_target_modules=None, lora_paths=None, max_loaded_loras=None, max_loras_per_batch=8, lora_eviction_policy='lru', lora_backend='csgmv', max_lora_chunk_size=16, attention_backend='fa3', decode_attention_backend=None, prefill_attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', 
mm_attention_backend=None, fp8_gemm_runner_backend='auto', fp4_gemm_runner_backend='flashinfer_cutlass', nsa_prefill_backend=None, nsa_decode_backend=None, disable_flashinfer_autotune=False, mamba_backend='triton', speculative_algorithm='STANDALONE', speculative_draft_model_path='/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora-merged', speculative_draft_model_revision='main', speculative_draft_load_format=None, speculative_num_steps=4, speculative_eagle_topk=1, speculative_num_draft_tokens=5, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, speculative_attention_mode='prefill', speculative_draft_attention_backend=None, speculative_moe_runner_backend='auto', speculative_moe_a2a_backend=None, speculative_draft_model_quantization=None, speculative_ngram_min_match_window_size=1, speculative_ngram_max_match_window_size=12, speculative_ngram_min_bfs_breadth=1, speculative_ngram_max_bfs_breadth=10, speculative_ngram_match_type='BFS', speculative_ngram_branch_length=18, speculative_ngram_capacity=10000000, enable_multi_layer_eagle=False, ep_size=1, moe_a2a_backend='none', moe_runner_backend='auto', flashinfer_mxfp4_moe_precision='default', enable_flashinfer_allreduce_fusion=False, enable_aiter_allreduce_fusion=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm=None, init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, eplb_rebalance_layers_per_chunk=None, eplb_min_rebalancing_utilization_threshold=1.0, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, moe_dense_tp_size=None, elastic_ep_backend=None, enable_elastic_expert_backup=False, mooncake_ib_device=None, max_mamba_cache_size=None, mamba_ssm_dtype=None, mamba_full_memory_ratio=0.9, mamba_scheduler_strategy='no_buffer', mamba_track_interval=256, 
linear_attn_backend='triton', linear_attn_decode_backend=None, linear_attn_prefill_backend=None, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through', hicache_io_backend='kernel', hicache_mem_layout='layer_first', disable_hicache_numa_detect=False, hicache_storage_backend=None, hicache_storage_prefetch_policy='best_effort', hicache_storage_backend_extra_config=None, hierarchical_sparse_attention_extra_config=None, enable_lmcache=False, kt_weight_path=None, kt_method='AMXINT4', kt_cpuinfer=None, kt_threadpool_count=2, kt_num_gpu_experts=None, kt_max_deferred_experts_per_token=None, dllm_algorithm=None, dllm_algorithm_config=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, cpu_offload_gb=0, offload_group_size=-1, offload_num_in_group=1, offload_prefetch_step=1, offload_mode='cpu', multi_item_scoring_delimiter=None, disable_radix_cache=False, cuda_graph_max_bs=512, cuda_graph_bs=[1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 40, 44, 48, 52, 56, 60, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248, 256, 272, 288, 304, 320, 336, 352, 368, 384, 400, 416, 432, 448, 464, 480, 496, 512], disable_cuda_graph=False, disable_cuda_graph_padding=False, enable_profile_cuda_graph=False, enable_cudagraph_gc=False, enable_layerwise_nvtx_marker=False, enable_nccl_nvls=False, enable_symm_mem=False, disable_flashinfer_cutlass_moe_fp4_allgather=False, enable_tokenizer_batch_encode=False, disable_tokenizer_batch_decode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, enable_mscclpp=False, enable_torch_symm_mem=False, disable_overlap_schedule=True, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, enable_single_batch_overlap=False, 
tbo_token_distribution_threshold=0.48, enable_torch_compile=False, disable_piecewise_cuda_graph=True, enforce_piecewise_cuda_graph=False, enable_torch_compile_debug_mode=False, torch_compile_max_bs=32, piecewise_cuda_graph_max_tokens=8192, piecewise_cuda_graph_tokens=[4, 8, 12, 16, 20, 24, 28, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256, 288, 320, 352, 384, 416, 448, 480, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1280, 1536, 1792, 2048, 2304, 2560, 2816, 3072, 3328, 3584, 3840, 4096, 4608, 5120, 5632, 6144, 6656, 7168, 7680, 8192], piecewise_cuda_graph_compiler='eager', torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, triton_attention_split_tile_size=None, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, enable_weights_cpu_backup=False, enable_draft_weights_cpu_backup=False, allow_auto_truncate=False, enable_custom_logit_processor=False, flashinfer_mla_disable_ragged=False, disable_shared_experts_fusion=False, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, keep_mm_feature_on_device=False, enable_return_hidden_states=False, enable_return_routed_experts=False, scheduler_recv_interval=1, numa_node=None, enable_deterministic_inference=False, rl_on_policy_target=None, enable_attn_tp_input_scattered=False, enable_nsa_prefill_context_parallel=False, nsa_prefill_cp_mode='round-robin-split', enable_fused_qk_norm_rope=False, enable_precise_embedding_interpolation=False, enable_fused_moe_sum_all_reduce=False, enable_dynamic_batch_tokenizer=False, dynamic_batch_tokenizer_batch_size=32, dynamic_batch_tokenizer_batch_timeout=0.002, debug_tensor_dump_output_folder=None, debug_tensor_dump_layers=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_transfer_backend='mooncake', disaggregation_bootstrap_port=8998, 
disaggregation_ib_device=None, disaggregation_decode_enable_offload_kvcache=False, num_reserved_decode_tokens=512, disaggregation_decode_polling_interval=1, encoder_only=False, language_only=False, encoder_transfer_backend='zmq_to_scheduler', encoder_urls=[], enable_adaptive_dispatch_to_encoder=False, custom_weight_loader=[], weight_loader_disable_mmap=False, remote_instance_weight_loader_seed_instance_ip=None, remote_instance_weight_loader_seed_instance_service_port=None, remote_instance_weight_loader_send_weights_group_ports=None, remote_instance_weight_loader_backend='nccl', remote_instance_weight_loader_start_seed_via_transfer_engine=False, enable_pdmux=False, pdmux_config_path=None, sm_group_num=8, mm_max_concurrent_calls=32, mm_per_request_timeout=10.0, enable_broadcast_mm_inputs_process=False, enable_prefix_mm_cache=False, mm_enable_dp_encoder=False, mm_process_config={}, limit_mm_data_per_request=None, enable_mm_global_cache=False, decrypted_config_file=None, decrypted_draft_config_file=None, forward_hooks=None)
|
| 9 |
+
[2026-03-07 15:24:15] Using default HuggingFace chat template with detected content format: string
|
| 10 |
+
[2026-03-07 15:24:25 TP2] Mamba selective_state_update backend initialized: triton
|
| 11 |
+
[2026-03-07 15:24:25 TP2] Init torch distributed begin.
|
| 12 |
+
[2026-03-07 15:24:26 TP0] Mamba selective_state_update backend initialized: triton
|
| 13 |
+
[2026-03-07 15:24:26 TP0] Init torch distributed begin.
|
| 14 |
+
[2026-03-07 15:24:26 TP3] Mamba selective_state_update backend initialized: triton
|
| 15 |
+
[2026-03-07 15:24:26 TP1] Mamba selective_state_update backend initialized: triton
|
| 16 |
+
[2026-03-07 15:24:26 TP3] Init torch distributed begin.
|
| 17 |
+
[2026-03-07 15:24:26 TP1] Init torch distributed begin.
|
| 18 |
+
[Gloo] Rank 1 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
|
| 19 |
+
[Gloo] Rank 0 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
|
| 20 |
+
[Gloo] Rank 3 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
|
| 21 |
+
[Gloo] Rank 2 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
|
| 22 |
+
[Gloo] Rank 0 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
|
| 23 |
+
[Gloo] Rank 2 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
|
| 24 |
+
[Gloo] Rank 1 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
|
| 25 |
+
[Gloo] Rank 3 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
|
| 26 |
+
[2026-03-07 15:24:27 TP0] sglang is using nccl==2.27.5
|
| 27 |
+
[2026-03-07 15:24:29 TP0] Scheduler hit an exception: Traceback (most recent call last):
|
| 28 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/managers/scheduler.py", line 3239, in run_scheduler_process
|
| 29 |
+
scheduler = Scheduler(
|
| 30 |
+
^^^^^^^^^^
|
| 31 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/managers/scheduler.py", line 365, in __init__
|
| 32 |
+
self.init_model_worker()
|
| 33 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/managers/scheduler.py", line 561, in init_model_worker
|
| 34 |
+
self.init_tp_model_worker()
|
| 35 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/managers/scheduler.py", line 519, in init_tp_model_worker
|
| 36 |
+
self.tp_worker = TpModelWorker(
|
| 37 |
+
^^^^^^^^^^^^^^
|
| 38 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/managers/tp_worker.py", line 258, in __init__
|
| 39 |
+
self._init_model_runner()
|
| 40 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/managers/tp_worker.py", line 341, in _init_model_runner
|
| 41 |
+
self._model_runner = ModelRunner(
|
| 42 |
+
^^^^^^^^^^^^
|
| 43 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/model_executor/model_runner.py", line 395, in __init__
|
| 44 |
+
pre_model_load_memory = self.init_torch_distributed()
|
| 45 |
+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 46 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/model_executor/model_runner.py", line 813, in init_torch_distributed
|
| 47 |
+
initialize_model_parallel(
|
| 48 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/distributed/parallel_state.py", line 1764, in initialize_model_parallel
|
| 49 |
+
_TP = init_model_parallel_group(
|
| 50 |
+
^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 51 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/distributed/parallel_state.py", line 1450, in init_model_parallel_group
|
| 52 |
+
return GroupCoordinator(
|
| 53 |
+
^^^^^^^^^^^^^^^^^
|
| 54 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/distributed/parallel_state.py", line 357, in __init__
|
| 55 |
+
self.pynccl_comm = PyNcclCommunicator(
|
| 56 |
+
^^^^^^^^^^^^^^^^^^^
|
| 57 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/distributed/device_communicators/pynccl.py", line 113, in __init__
|
| 58 |
+
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
|
| 59 |
+
^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 60 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py", line 401, in ncclCommInitRank
|
| 61 |
+
self.NCCL_CHECK(
|
| 62 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py", line 376, in NCCL_CHECK
|
| 63 |
+
raise RuntimeError(f"NCCL error: {error_str}")
|
| 64 |
+
RuntimeError: NCCL error: unhandled system error (run with NCCL_DEBUG=INFO for details)
|
| 65 |
+
|
| 66 |
+
[2026-03-07 15:24:29] Received sigquit from a child process. It usually means the child failed.
|
| 67 |
+
[2026-03-07 15:24:29 TP2] Scheduler hit an exception: Traceback (most recent call last):
|
| 68 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/managers/scheduler.py", line 3239, in run_scheduler_process
|
| 69 |
+
scheduler = Scheduler(
|
| 70 |
+
^^^^^^^^^^
|
| 71 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/managers/scheduler.py", line 365, in __init__
|
| 72 |
+
self.init_model_worker()
|
| 73 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/managers/scheduler.py", line 561, in init_model_worker
|
| 74 |
+
self.init_tp_model_worker()
|
| 75 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/managers/scheduler.py", line 519, in init_tp_model_worker
|
| 76 |
+
self.tp_worker = TpModelWorker(
|
| 77 |
+
^^^^^^^^^^^^^^
|
| 78 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/managers/tp_worker.py", line 258, in __init__
|
| 79 |
+
self._init_model_runner()
|
| 80 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/managers/tp_worker.py", line 341, in _init_model_runner
|
| 81 |
+
self._model_runner = ModelRunner(
|
| 82 |
+
^^^^^^^^^^^^
|
| 83 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/model_executor/model_runner.py", line 395, in __init__
|
| 84 |
+
pre_model_load_memory = self.init_torch_distributed()
|
| 85 |
+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 86 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/model_executor/model_runner.py", line 813, in init_torch_distributed
|
| 87 |
+
initialize_model_parallel(
|
| 88 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/distributed/parallel_state.py", line 1764, in initialize_model_parallel
|
| 89 |
+
_TP = init_model_parallel_group(
|
| 90 |
+
^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 91 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/distributed/parallel_state.py", line 1450, in init_model_parallel_group
|
| 92 |
+
return GroupCoordinator(
|
| 93 |
+
^^^^^^^^^^^^^^^^^
|
| 94 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/distributed/parallel_state.py", line 357, in __init__
|
| 95 |
+
self.pynccl_comm = PyNcclCommunicator(
|
| 96 |
+
^^^^^^^^^^^^^^^^^^^
|
| 97 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/distributed/device_communicators/pynccl.py", line 113, in __init__
|
| 98 |
+
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
|
| 99 |
+
^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 100 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py", line 401, in ncclCommInitRank
|
| 101 |
+
self.NCCL_CHECK(
|
| 102 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py", line 376, in NCCL_CHECK
|
| 103 |
+
raise RuntimeError(f"NCCL error: {error_str}")
|
| 104 |
+
RuntimeError: NCCL error: unhandled system error (run with NCCL_DEBUG=INFO for details)
|
| 105 |
+
|
| 106 |
+
[2026-03-07 15:24:29 TP1] Scheduler hit an exception: Traceback (most recent call last):
|
| 107 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/managers/scheduler.py", line 3239, in run_scheduler_process
|
| 108 |
+
scheduler = Scheduler(
|
| 109 |
+
^^^^^^^^^^
|
| 110 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/managers/scheduler.py", line 365, in __init__
|
| 111 |
+
self.init_model_worker()
|
| 112 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/managers/scheduler.py", line 561, in init_model_worker
|
| 113 |
+
self.init_tp_model_worker()
|
| 114 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/managers/scheduler.py", line 519, in init_tp_model_worker
|
| 115 |
+
self.tp_worker = TpModelWorker(
|
| 116 |
+
^^^^^^^^^^^^^^
|
| 117 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/managers/tp_worker.py", line 258, in __init__
|
| 118 |
+
self._init_model_runner()
|
| 119 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/managers/tp_worker.py", line 341, in _init_model_runner
|
| 120 |
+
self._model_runner = ModelRunner(
|
| 121 |
+
^^^^^^^^^^^^
|
| 122 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/model_executor/model_runner.py", line 395, in __init__
|
| 123 |
+
pre_model_load_memory = self.init_torch_distributed()
|
| 124 |
+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 125 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/model_executor/model_runner.py", line 813, in init_torch_distributed
|
| 126 |
+
initialize_model_parallel(
|
| 127 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/distributed/parallel_state.py", line 1764, in initialize_model_parallel
|
| 128 |
+
_TP = init_model_parallel_group(
|
| 129 |
+
^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 130 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/distributed/parallel_state.py", line 1450, in init_model_parallel_group
|
| 131 |
+
return GroupCoordinator(
|
| 132 |
+
^^^^^^^^^^^^^^^^^
|
| 133 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/distributed/parallel_state.py", line 357, in __init__
|
| 134 |
+
self.pynccl_comm = PyNcclCommunicator(
|
| 135 |
+
^^^^^^^^^^^^^^^^^^^
|
| 136 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/distributed/device_communicators/pynccl.py", line 113, in __init__
|
| 137 |
+
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
|
| 138 |
+
^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 139 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py", line 401, in ncclCommInitRank
|
| 140 |
+
self.NCCL_CHECK(
|
| 141 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py", line 376, in NCCL_CHECK
|
| 142 |
+
raise RuntimeError(f"NCCL error: {error_str}")
|
| 143 |
+
RuntimeError: NCCL error: unhandled system error (run with NCCL_DEBUG=INFO for details)
|
| 144 |
+
|
| 145 |
+
[2026-03-07 15:24:29] Received sigquit from a child process. It usually means the child failed.
|
| 146 |
+
[2026-03-07 15:24:29] Received sigquit from a child process. It usually means the child failed.
|
| 147 |
+
[2026-03-07 15:24:29 TP3] Scheduler hit an exception: Traceback (most recent call last):
|
| 148 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/managers/scheduler.py", line 3239, in run_scheduler_process
|
| 149 |
+
scheduler = Scheduler(
|
| 150 |
+
^^^^^^^^^^
|
| 151 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/managers/scheduler.py", line 365, in __init__
|
| 152 |
+
self.init_model_worker()
|
| 153 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/managers/scheduler.py", line 561, in init_model_worker
|
| 154 |
+
self.init_tp_model_worker()
|
| 155 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/managers/scheduler.py", line 519, in init_tp_model_worker
|
| 156 |
+
self.tp_worker = TpModelWorker(
|
| 157 |
+
^^^^^^^^^^^^^^
|
| 158 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/managers/tp_worker.py", line 258, in __init__
|
| 159 |
+
self._init_model_runner()
|
| 160 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/managers/tp_worker.py", line 341, in _init_model_runner
|
| 161 |
+
self._model_runner = ModelRunner(
|
| 162 |
+
^^^^^^^^^^^^
|
| 163 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/model_executor/model_runner.py", line 395, in __init__
|
| 164 |
+
pre_model_load_memory = self.init_torch_distributed()
|
| 165 |
+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 166 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/model_executor/model_runner.py", line 813, in init_torch_distributed
|
| 167 |
+
initialize_model_parallel(
|
| 168 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/distributed/parallel_state.py", line 1764, in initialize_model_parallel
|
| 169 |
+
_TP = init_model_parallel_group(
|
| 170 |
+
^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 171 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/distributed/parallel_state.py", line 1450, in init_model_parallel_group
|
| 172 |
+
return GroupCoordinator(
|
| 173 |
+
^^^^^^^^^^^^^^^^^
|
| 174 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/distributed/parallel_state.py", line 357, in __init__
|
| 175 |
+
self.pynccl_comm = PyNcclCommunicator(
|
| 176 |
+
^^^^^^^^^^^^^^^^^^^
|
| 177 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/distributed/device_communicators/pynccl.py", line 113, in __init__
|
| 178 |
+
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
|
| 179 |
+
^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 180 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py", line 401, in ncclCommInitRank
|
| 181 |
+
self.NCCL_CHECK(
|
| 182 |
+
File "/workspace/hanrui/sglang/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py", line 376, in NCCL_CHECK
|
| 183 |
+
raise RuntimeError(f"NCCL error: {error_str}")
|
| 184 |
+
RuntimeError: NCCL error: unhandled system error (run with NCCL_DEBUG=INFO for details)
|
| 185 |
+
|
| 186 |
+
[2026-03-07 15:24:29] Received sigquit from a child process. It usually means the child failed.
|
syxin/start_server.sh
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Step 2: Launch SGLang server with STANDALONE speculative decoding.
|
| 3 |
+
# Usage:
|
| 4 |
+
# bash start_server.sh
|
| 5 |
+
# bash start_server.sh 8 # use tp=8
|
| 6 |
+
|
| 7 |
+
set -e
|
| 8 |
+
|
| 9 |
+
TP=${1:-2}
|
| 10 |
+
|
| 11 |
+
BASE_MODEL=/workspace/models/Qwen3-8B
|
| 12 |
+
MERGED=/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora-merged
|
| 13 |
+
INTRANET_IP=10.1.1.131
|
| 14 |
+
PORT=30000
|
| 15 |
+
|
| 16 |
+
if [ ! -d "$MERGED" ]; then
|
| 17 |
+
echo "[ERROR] Merged model not found: $MERGED"
|
| 18 |
+
echo " Run: conda activate sglang && python3 merge_lora.py"
|
| 19 |
+
exit 1
|
| 20 |
+
fi
|
| 21 |
+
|
| 22 |
+
echo "============================================"
|
| 23 |
+
echo " SGLang STANDALONE Speculative Decoding"
|
| 24 |
+
echo " target : $BASE_MODEL"
|
| 25 |
+
echo " draft : $MERGED"
|
| 26 |
+
echo " host : $INTRANET_IP:$PORT"
|
| 27 |
+
echo " tp : $TP"
|
| 28 |
+
echo "============================================"
|
| 29 |
+
|
| 30 |
+
/workspace/miniconda3/envs/sglang/bin/python3 -m sglang.launch_server \
|
| 31 |
+
--model-path $BASE_MODEL \
|
| 32 |
+
--speculative-algorithm STANDALONE \
|
| 33 |
+
--speculative-draft-model-path $MERGED \
|
| 34 |
+
--speculative-num-steps 4 \
|
| 35 |
+
--speculative-eagle-topk 1 \
|
| 36 |
+
--speculative-num-draft-tokens 4 \
|
| 37 |
+
--tp-size $TP \
|
| 38 |
+
--mem-fraction-static 0.30 \
|
| 39 |
+
--trust-remote-code \
|
| 40 |
+
--host $INTRANET_IP \
|
| 41 |
+
--port $PORT \
|
| 42 |
+
--dtype bfloat16
|
syxin/start_server_dflash.sh
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Evaluate DFlash-LoRA-Inject: measure accepted length OFFLINE.
|
| 3 |
+
# 8 GPUs parallel by default, each GPU runs a shard of prompts independently.
|
| 4 |
+
#
|
| 5 |
+
# WHY offline?
|
| 6 |
+
# sglang STANDALONE treats draft as an independent autoregressive model,
|
| 7 |
+
# completely ignoring the layer-by-layer injection that LoRA-Inject was
|
| 8 |
+
# trained with. Result: accept_length ≈ 4.7 for ALL models (no signal).
|
| 9 |
+
#
|
| 10 |
+
# sglang DFLASH expects the DFlash-b16 architecture (5-layer, fc+hidden_norm),
|
| 11 |
+
# which is structurally different from LoRA-Inject (full 36-layer + LoRA).
|
| 12 |
+
#
|
| 13 |
+
# So we run offline spec-generate with the correct injection pattern.
|
| 14 |
+
#
|
| 15 |
+
# Usage:
|
| 16 |
+
# bash start_server_dflash.sh # 8 GPUs, all benchmarks
|
| 17 |
+
# bash start_server_dflash.sh 4 # 4 GPUs
|
| 18 |
+
# bash start_server_dflash.sh 8 humaneval # specific benchmark
|
| 19 |
+
# bash start_server_dflash.sh 8 --num-samples 20 # quick test
|
| 20 |
+
|
| 21 |
+
set -e
|
| 22 |
+
|
| 23 |
+
SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)
|
| 24 |
+
|
| 25 |
+
NUM_GPUS=${1:-8}
|
| 26 |
+
shift 2>/dev/null || true
|
| 27 |
+
|
| 28 |
+
# ---- defaults ----
|
| 29 |
+
BASE_MODEL=/workspace/models/Qwen3-8B
|
| 30 |
+
ADAPTER_ROOT=/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora-inject
|
| 31 |
+
CKPT=epoch_3_step_1400
|
| 32 |
+
MERGED=/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora-inject-merged
|
| 33 |
+
RESULT_DIR=/workspace/hanrui/syxin_old/Specforge/benchmarks/results
|
| 34 |
+
PYTHON=/workspace/miniconda3/envs/spec/bin/python3
|
| 35 |
+
|
| 36 |
+
echo "============================================"
|
| 37 |
+
echo " DFlash-LoRA-Inject Offline Evaluation"
|
| 38 |
+
echo " target : $BASE_MODEL"
|
| 39 |
+
echo " ckpt : $CKPT"
|
| 40 |
+
echo " merged : $MERGED"
|
| 41 |
+
echo " GPUs : $NUM_GPUS"
|
| 42 |
+
echo "============================================"
|
| 43 |
+
|
| 44 |
+
$PYTHON -m torch.distributed.run \
|
| 45 |
+
--standalone \
|
| 46 |
+
--nproc_per_node $NUM_GPUS \
|
| 47 |
+
$SCRIPT_DIR/eval_dflash_lora_inject.py \
|
| 48 |
+
--base-model $BASE_MODEL \
|
| 49 |
+
--adapter-root $ADAPTER_ROOT \
|
| 50 |
+
--ckpt $CKPT \
|
| 51 |
+
--merged-path $MERGED \
|
| 52 |
+
--block-size 16 \
|
| 53 |
+
--output-dir $RESULT_DIR \
|
| 54 |
+
"$@"
|
syxin/step1.md
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Now I have a complete picture. Here's the full analysis and proposed plan.
|
| 2 |
+
|
| 3 |
+
---
|
| 4 |
+
Current Logic of train_dflash.py
|
| 5 |
+
|
| 6 |
+
This script trains a small, standalone DFlash draft model (1–10 transformer layers) that learns to predict token blocks in parallel,
|
| 7 |
+
guided by hidden states extracted from a frozen target model.
|
| 8 |
+
|
| 9 |
+
Architecture (two separate models)
|
| 10 |
+
|
| 11 |
+
1. Target model (frozen, e.g. Qwen3-8B/32B) — runs a full forward pass to produce hidden states at selected layers.
|
| 12 |
+
2. Draft model (DFlashDraftModel, ~1–10 layers) — a lightweight Qwen3-based decoder that takes noise embeddings + target hidden states
|
| 13 |
+
and predicts the block tokens.
|
| 14 |
+
3. Target embed_tokens + lm_head — loaded separately via TargetEmbeddingsAndHead to avoid duplicating the full target model in memory.
|
| 15 |
+
|
| 16 |
+
Key locations
|
| 17 |
+
|
| 18 |
+
┌──────────────────────────┬────────────────────────────────────┬───────────────────────────────────────────────────────┐
|
| 19 |
+
│ Component │ File │ Lines │
|
| 20 |
+
├──────────────────────────┼────────────────────────────────────┼───────────────────────────────────────────────────────┤
|
| 21 |
+
│ Model init │ scripts/train_dflash.py │ build_models() L254–311 │
|
| 22 |
+
├──────────────────────────┼────────────────────────────────────┼───────────────────────────────────────────────────────┤
|
| 23 |
+
│ Target hidden extraction │ scripts/train_dflash.py │ L644–647 (target_model.generate_dflash_data) │
|
| 24 |
+
├──────────────────────────┼────────────────────────────────────┼───────────────────────────────────────────────────────┤
|
| 25 |
+
│ Forward pass │ specforge/core/dflash.py │ OnlineDFlashModel.forward() L243–332 │
|
| 26 |
+
├──────────────────────────┼────────────────────────────────────┼───────────────────────────────────────────────────────┤
|
| 27 |
+
│ Loss calculation │ specforge/core/dflash.py │ _full_lm_loss() L382–417, _chunked_lm_loss() L419–478 │
|
| 28 |
+
├──────────────────────────┼────────────────────────────────────┼───────────────────────────────────────────────────────┤
|
| 29 |
+
│ Loss mask │ specforge/core/dflash.py │ create_dflash_loss_mask() L481–509 │
|
| 30 |
+
├──────────────────────────┼────────────────────────────────────┼───────────────────────────────────────────────────────┤
|
| 31 |
+
│ Draft model architecture │ specforge/modeling/draft/dflash.py │ DFlashDraftModel L212–266 │
|
| 32 |
+
├──────────────────────────┼────────────────────────────────────┼───────────────────────────────────────────────────────┤
|
| 33 |
+
│ DFlash attention │ specforge/modeling/draft/dflash.py │ Qwen3DFlashAttention L42–134 │
|
| 34 |
+
└──────────────────────────┴────────────────────────────────────┴───────────────────────────────────────────────────────┘
|
| 35 |
+
|
| 36 |
+
Forward pass flow (per training step)
|
| 37 |
+
|
| 38 |
+
input_ids, attention_mask, loss_mask → target_model.generate_dflash_data()
|
| 39 |
+
↓
|
| 40 |
+
hidden_states (from target layers [1,9,17,25,33])
|
| 41 |
+
↓
|
| 42 |
+
OnlineDFlashModel.forward():
|
| 43 |
+
1. Truncate to block boundary
|
| 44 |
+
2. prepare_noise_input(): anchor tokens kept, rest → MASK
|
| 45 |
+
3. embed_tokens(noise_input_ids) → noise_embedding
|
| 46 |
+
4. Build DFlash attention mask (flex_attention or additive)
|
| 47 |
+
5. draft_model(noise_embedding, target_hidden, mask)
|
| 48 |
+
6. lm_head(hidden) → logits
|
| 49 |
+
7. CE loss on non-anchor positions (weighted by loss_mask × decay)
|
| 50 |
+
|
| 51 |
+
The draft model's custom Qwen3DFlashAttention concatenates [context_hidden, noise_hidden] as KV, with queries only from noise tokens. The
|
| 52 |
+
attention mask enforces: block tokens see all preceding blocks' context + bidirectional within their own block.
|
| 53 |
+
|
| 54 |
+
---
|
| 55 |
+
What already exists: train_dflash_lora.py
|
| 56 |
+
|
| 57 |
+
Interestingly, the repo already has a LoRA variant at scripts/train_dflash_lora.py with its own model (DFlashLoRADraftModel) and wrapper
|
| 58 |
+
(OnlineDFlashLoRAModel). This is exactly the approach you described — Qwen3-8B + LoRA, no separate target model, 1-step diffusion
|
| 59 |
+
training. The key differences from train_dflash.py:
|
| 60 |
+
|
| 61 |
+
┌─────────────────┬─────────────────────────────────────────────────────────────┬────────────────────────────────────────────────────┐
|
| 62 |
+
│ Aspect │ train_dflash.py │ train_dflash_lora.py │
|
| 63 |
+
├─────────────────┼─────────────────────────────────────────────────────────────┼────────────────────────────────────────────────────┤
|
| 64 |
+
│ Draft model │ Small custom DFlashDraftModel (1–10 layers) │ Full Qwen3-8B + LoRA adapters │
|
| 65 |
+
├─────────────────┼─────────────────────────────────────────────────────────────┼────────────────────────────────────────────────────┤
|
| 66 |
+
│ Target model │ Separate frozen model for hidden state extraction │ None — model uses its own representations │
|
| 67 |
+
├─────────────────┼─────────────────────────────────────────────────────────────┼────────────────────────────────────────────────────┤
|
| 68 |
+
│ Attention │ Custom Qwen3DFlashAttention (Q from noise, KV from [ctx, │ Standard HF attention with 4D additive DFlash mask │
|
| 69 |
+
│ │ noise]) │ │
|
| 70 |
+
├─────────────────┼─────────────────────────────────────────────────────────────┼────────────────────────────────────────────────────┤
|
| 71 |
+
│ Forward │ draft_model(noise_emb, target_hidden, mask) │ model(noise_input_ids, 4d_mask, position_ids) → │
|
| 72 |
+
│ │ │ logits │
|
| 73 |
+
├─────────────────┼─────────────────────────────────────────────────────────────┼────────────────────────────────────────────────────┤
|
| 74 |
+
│ Trainable │ All draft model params │ Only LoRA (q/k/v/o_proj), base frozen │
|
| 75 |
+
│ params │ │ │
|
| 76 |
+
├─────────────────┼─────────────────────────────────────────────────────────────┼────────────────────────────────────────────────────┤
|
| 77 |
+
│ FSDP strategy │ SHARD_GRAD_OP │ FULL_SHARD │
|
| 78 |
+
└─────────────────┴─────────────────────────────────────────────────────────────┴────────────────────────────────────────────────────┘
|
| 79 |
+
|
| 80 |
+
---
|
| 81 |
+
Proposed Modification Plan
|
| 82 |
+
|
| 83 |
+
Since train_dflash_lora.py already implements the core idea, the plan focuses on what's missing or needs improvement to make it a proper
|
| 84 |
+
"1-step dLLM draft model" for your research:
|
| 85 |
+
|
| 86 |
+
Phase 1: Validate and extend the existing LoRA pipeline
|
| 87 |
+
|
| 88 |
+
1. Add MLP to LoRA targets — The current config only targets q_proj, k_proj, v_proj, o_proj. For stronger 1-step diffusion capability,
|
| 89 |
+
add gate_proj, up_proj, down_proj to lora_target_modules. This gives the model more capacity to learn the non-autoregressive distribution
|
| 90 |
+
shift.
|
| 91 |
+
2. Add multi-step noise schedule support — Currently the training is strictly 1-step (all non-anchors → MASK). For a proper diffusion/AR
|
| 92 |
+
fusion, add an option for a noise schedule where a fraction of block tokens are revealed (not just the anchor), controlled by a
|
| 93 |
+
noise_ratio parameter. This would modify prepare_noise_input() in OnlineDFlashLoRAModel:
|
| 94 |
+
# Instead of: all non-anchor → MASK
|
| 95 |
+
# Allow: randomly keep some non-anchor tokens with probability (1 - noise_ratio)
|
| 96 |
+
3. Add configurable context_len strategy — Currently context_len=0 treats the whole sequence as blocks. Add a --context-ratio arg that
|
| 97 |
+
dynamically sets context_len as a fraction of the sequence, so the model learns to condition on varying amounts of AR-decoded prefix.
|
| 98 |
+
|
| 99 |
+
Phase 2: Training logic improvements
|
| 100 |
+
|
| 101 |
+
4. Add KL divergence loss — In addition to CE loss against ground truth, add an optional KL loss against the base model's AR distribution
|
| 102 |
+
(teacher forcing). This regularizes the LoRA model to stay close to the original Qwen3-8B distribution. Modify
|
| 103 |
+
OnlineDFlashLoRAModel.forward():
|
| 104 |
+
# Compute base model logits (no_grad, no LoRA) as teacher
|
| 105 |
+
# KL(draft_logits || teacher_logits) on block positions
|
| 106 |
+
# total_loss = ce_loss + kl_weight * kl_loss
|
| 107 |
+
5. Add evaluation with speculative decoding metrics — The current accuracy metric is block-wise acceptance rate. Add an eval loop that
|
| 108 |
+
actually runs speculative decoding (draft → verify) to measure real speedup, using the LoRA model as the drafter and the base model (with
|
| 109 |
+
LoRA disabled) as the verifier.
|
| 110 |
+
|
| 111 |
+
Phase 3: Integration with train_dflash.py style features
|
| 112 |
+
|
| 113 |
+
6. Port random anchor sampling — train_dflash.py has --random-anchor (L147–156) which samples diverse anchor positions per sequence. This
|
| 114 |
+
is missing from the LoRA variant and would improve training diversity. Port _sample_anchor_positions and _build_blocks_from_anchors from
|
| 115 |
+
OnlineDFlashModel to OnlineDFlashLoRAModel.
|
| 116 |
+
7. Port chunked lm_head — The LoRA variant materializes full [bsz, seq_len, vocab_size] logits. For Qwen3-8B's 152K vocab, this is ~1.2GB
|
| 117 |
+
per sample at seq_len=2048. Port the _chunked_lm_loss from OnlineDFlashModel for memory efficiency.
|
| 118 |
+
8. Add tensor parallelism support — The LoRA script currently forces tp_size=1. For Qwen3-8B on multi-GPU setups, add TP support to shard
|
| 119 |
+
the base model across GPUs while keeping LoRA params on each rank.
|
| 120 |
+
|
| 121 |
+
Minimal starting point
|
| 122 |
+
|
| 123 |
+
If you want to just get training running immediately, the existing train_dflash_lora.py + configs/qwen3-8b-dflash-lora.json should work
|
| 124 |
+
out of the box:
|
| 125 |
+
|
| 126 |
+
torchrun --nproc_per_node=4 scripts/train_dflash_lora.py \
|
| 127 |
+
--model-path Qwen/Qwen3-8B \
|
| 128 |
+
--lora-config configs/qwen3-8b-dflash-lora.json \
|
| 129 |
+
--train-data-path /path/to/data.jsonl \
|
| 130 |
+
--output-dir ./outputs/qwen3-8b-dflash-lora \
|
| 131 |
+
--block-size 16 \
|
| 132 |
+
--gradient-checkpointing \
|
| 133 |
+
--learning-rate 2e-4 \
|
| 134 |
+
--batch-size 1 \
|
| 135 |
+
--max-length 2048 \
|
| 136 |
+
--num-epochs 3
|
| 137 |
+
|
| 138 |
+
Want me to start implementing any of these phases? I'd suggest starting with Phase 1 (noise schedule + MLP targets) since that's the most
|
| 139 |
+
impactful for the diffusion/AR fusion research angle.
|