Lekr0 commited on
Commit
7c50656
·
verified ·
1 Parent(s): 40d87dd

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. SpecForge-ext/.pre-commit-config.yaml +53 -0
  2. SpecForge-ext/convert_mtbench.py +22 -0
  3. SpecForge-ext/download_datasets.py +64 -0
  4. SpecForge-ext/download_mtbench.sh +23 -0
  5. SpecForge-ext/download_mtbench_data.py +51 -0
  6. SpecForge-ext/mtbench_sample.json +26 -0
  7. SpecForge-ext/pyproject.toml +44 -0
  8. SpecForge-ext/requirements.txt +0 -0
  9. SpecForge-ext/setup.py +33 -0
  10. SpecForge-ext/test_accept_length.md +300 -0
  11. SpecForge/.editorconfig +25 -0
  12. SpecForge/.isort.cfg +3 -0
  13. SpecForge/.pre-commit-config.yaml +53 -0
  14. SpecForge/LICENSE +21 -0
  15. SpecForge/MANIFEST.in +2 -0
  16. SpecForge/README.md +70 -0
  17. SpecForge/pyproject.toml +47 -0
  18. SpecForge/requirements-rocm.txt +20 -0
  19. SpecForge/version.txt +1 -0
  20. idea1/.editorconfig +25 -0
  21. idea1/.isort.cfg +3 -0
  22. idea1/.pre-commit-config.yaml +53 -0
  23. idea1/LICENSE +21 -0
  24. idea1/requirements-rocm.txt +20 -0
  25. idea1/version.txt +1 -0
  26. qwen3-8b_dflash_regen/.gitattributes +36 -0
  27. syxin/backup.log +0 -0
  28. syxin/dflash_lora_changelog.md +232 -0
  29. syxin/eval_accepted_length.md +217 -0
  30. syxin/eval_dflash_b16_baseline.py +354 -0
  31. syxin/eval_dflash_lora_inject.py +627 -0
  32. syxin/idea.md +23 -0
  33. syxin/launch_train.sh +37 -0
  34. syxin/launch_train_wrapper.py +21 -0
  35. syxin/list.md +12 -0
  36. syxin/merge_lora.py +66 -0
  37. syxin/oom_fix_progress.md +42 -0
  38. syxin/requirements.txt +0 -0
  39. syxin/run_bench.sh +68 -0
  40. syxin/run_bench_dflash.sh +71 -0
  41. syxin/run_bench_dflash_b16_baseline.sh +60 -0
  42. syxin/run_qwen3_8b_sft_32gpu.sh +31 -0
  43. syxin/run_train_dflash_direct_inject.sh +56 -0
  44. syxin/run_train_dflash_lora_inject.sh +71 -0
  45. syxin/run_train_multinode.sh +67 -0
  46. syxin/run_train_qwen3_8b_sft_32gpu.sh +66 -0
  47. syxin/server.log +186 -0
  48. syxin/start_server.sh +42 -0
  49. syxin/start_server_dflash.sh +54 -0
  50. syxin/step1.md +139 -0
SpecForge-ext/.pre-commit-config.yaml ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ default_stages: [pre-commit, pre-push, manual]
2
+
3
+ repos:
4
+ - repo: https://github.com/PyCQA/autoflake
5
+ rev: v2.3.1
6
+ hooks:
7
+ - id: autoflake
8
+ args: [--remove-all-unused-imports, --in-place]
9
+ - repo: https://github.com/pre-commit/pre-commit-hooks
10
+ rev: v5.0.0
11
+ hooks:
12
+ - id: check-symlinks
13
+ - id: destroyed-symlinks
14
+ - id: trailing-whitespace
15
+ - id: end-of-file-fixer
16
+ - id: check-yaml
17
+ args: [--allow-multiple-documents]
18
+ - id: check-toml
19
+ - id: check-ast
20
+ - id: check-added-large-files
21
+ - id: check-merge-conflict
22
+ - id: check-shebang-scripts-are-executable
23
+ - id: detect-private-key
24
+ - id: debug-statements
25
+ - id: no-commit-to-branch
26
+ - repo: https://github.com/PyCQA/isort
27
+ rev: 5.13.2
28
+ hooks:
29
+ - id: isort
30
+ - repo: https://github.com/astral-sh/ruff-pre-commit
31
+ rev: v0.11.10
32
+ hooks:
33
+ - id: ruff
34
+ args: [--select=F401, --fixable=F401]
35
+ files: ^(benchmark/|docs/|examples/)
36
+ exclude: \.ipynb$
37
+ - repo: https://github.com/psf/black
38
+ rev: 24.10.0
39
+ hooks:
40
+ - id: black-jupyter
41
+ - repo: https://github.com/pre-commit/mirrors-clang-format
42
+ rev: v18.1.8
43
+ hooks:
44
+ - id: clang-format
45
+ types_or: [c++, cuda]
46
+ args: [--style=file, --verbose]
47
+ - repo: https://github.com/kynan/nbstripout
48
+ rev: 0.8.1
49
+ hooks:
50
+ - id: nbstripout
51
+ args:
52
+ - '--keep-output'
53
+ - '--extra-keys=metadata.kernelspec metadata.language_info.version'
SpecForge-ext/convert_mtbench.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""Convert an MT-Bench question file from a JSON array to JSONL.

Reads a JSON file containing a list of question records and rewrites it
as one compact JSON object per line in the sglang cache directory,
which is where the benchmark scripts expect ``mtbench.jsonl``.
"""
import json
import os

# Source file: a JSON array of MT-Bench question objects.
DEFAULT_INPUT = "/workspace/hanrui/SpecForge-ext/mtbench_sample.json"


def convert(input_file, output_file):
    """Convert *input_file* (JSON array) to *output_file* (JSONL).

    Returns the list of records that were written so the caller can
    report on them.
    """
    with open(input_file, 'r') as f:
        data = json.load(f)

    # One compact JSON object per line -- the JSONL format consumed by
    # the sglang benchmark loader.
    with open(output_file, 'w') as f:
        for item in data:
            f.write(json.dumps(item) + '\n')
    return data


def main():
    cache_dir = os.path.expanduser("~/.cache/sglang")
    os.makedirs(cache_dir, exist_ok=True)
    output_file = os.path.join(cache_dir, "mtbench.jsonl")

    data = convert(DEFAULT_INPUT, output_file)

    print(f"Converted {len(data)} questions")
    print(f"Saved to {output_file}")
    print(f"\nFirst question:")
    print(json.dumps(data[0], indent=2))


if __name__ == "__main__":
    main()
SpecForge-ext/download_datasets.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""Download the GSM8K and HumanEval test splits to local JSONL files.

Each dataset is fetched from the Hugging Face Hub and written as one
JSON object per line under ``DATA_DIR/<name>/test.jsonl``.
"""
import json
import os

# Root directory where all benchmark datasets are stored.
DATA_DIR = "/workspace/hanrui/datasets"


def save_jsonl(records, output_file):
    """Write an iterable of JSON-serializable records as JSONL.

    Returns the number of records written.
    """
    count = 0
    with open(output_file, 'w') as f:
        for item in records:
            f.write(json.dumps(item) + '\n')
            count += 1
    return count


def _banner(title, leading_newline=False):
    # Section banner for the console log, matching the original output.
    if leading_newline:
        print("\n" + "=" * 60)
    else:
        print("=" * 60)
    print(title)
    print("=" * 60)


def _download_split(title, subdir, *load_args, **load_kwargs):
    """Download one dataset split and save it as ``<subdir>/test.jsonl``.

    ``load_args``/``load_kwargs`` are forwarded to
    ``datasets.load_dataset``. Failures are reported but do not abort,
    so the remaining datasets are still attempted.
    """
    try:
        # Imported lazily so this module stays importable without the
        # third-party ``datasets`` package installed.
        from datasets import load_dataset

        target_dir = os.path.join(DATA_DIR, subdir)
        os.makedirs(target_dir, exist_ok=True)

        print(f"Loading {title} from HuggingFace...")
        dataset = load_dataset(*load_args, **load_kwargs)

        output_file = os.path.join(target_dir, "test.jsonl")
        total = save_jsonl(dataset, output_file)

        print(f"✓ {title} saved to {output_file}")
        print(f"  Total samples: {total}")
    except Exception as e:
        print(f"✗ {title} download failed: {e}")


def main():
    os.makedirs(DATA_DIR, exist_ok=True)

    _banner("下载 GSM8K 数据集")
    _download_split("GSM8K", "gsm8k", "gsm8k", "main", split="test")

    _banner("下载 HumanEval 数据集", leading_newline=True)
    _download_split("HumanEval", "humaneval", "openai_humaneval", split="test")

    _banner("下载完成", leading_newline=True)
    print(f"数据保存在: {DATA_DIR}")


if __name__ == "__main__":
    main()
SpecForge-ext/download_mtbench.sh ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash

# Download the MT-Bench question file into the sglang cache directory.
# If GitHub is unreachable, print manual-download instructions instead.

CACHE_DIR="$HOME/.cache/sglang"
mkdir -p "$CACHE_DIR"

URL="https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl"

echo "Downloading mtbench data..."

# -f makes curl exit non-zero on HTTP errors (e.g. 404); without it an
# HTML error page would be saved and reported as success. The proxy is
# only needed on hosts without direct GitHub access.
if https_proxy=http://10.1.2.1:7890 http_proxy=http://10.1.2.1:7890 \
    curl -fL "$URL" -o "$CACHE_DIR/mtbench.jsonl"; then
    echo "Downloaded to $CACHE_DIR/mtbench.jsonl"
    ls -lh "$CACHE_DIR/mtbench.jsonl"
else
    echo "Download failed. Please manually download the file from:"
    echo "$URL"
    echo "And save it to: $CACHE_DIR/mtbench.jsonl"
fi
SpecForge-ext/download_mtbench_data.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""Download the MT-Bench question set into the local datasets directory.

Fetches ``question.jsonl`` from the FastChat repository (through a
local HTTP proxy) and prints a sample of the result as a sanity check.
"""
import json
import os

# Target directory for the MT-Bench questions.
DATA_DIR = "/workspace/hanrui/datasets/mtbench"

# Upstream location of the MT-Bench question file.
URL = (
    "https://raw.githubusercontent.com/lm-sys/FastChat/main/"
    "fastchat/llm_judge/data/mt_bench/question.jsonl"
)

# HTTP proxy used from hosts without direct GitHub access.
PROXIES = {
    'http': 'http://10.1.2.1:7890',
    'https': 'http://10.1.2.1:7890',
}


def download(url, output_file, proxies=None):
    """Download *url* to *output_file* and return its raw lines.

    Raises on network or HTTP errors; the caller decides how to report
    them.
    """
    # Imported lazily so this module stays importable without requests.
    import requests

    response = requests.get(url, proxies=proxies, timeout=30)
    response.raise_for_status()

    with open(output_file, 'wb') as f:
        f.write(response.content)

    # Re-read as text to validate the file and count the questions.
    with open(output_file, 'r') as f:
        return f.readlines()


def main():
    os.makedirs(DATA_DIR, exist_ok=True)
    output_file = os.path.join(DATA_DIR, "question.jsonl")

    print(f"Downloading MT-Bench questions from {URL}")
    print(f"Saving to {output_file}")

    try:
        lines = download(URL, output_file, proxies=PROXIES)

        print(f"✓ Downloaded successfully")
        print(f"✓ Total questions: {len(lines)}")

        # Show the first question as a sanity check.
        first_question = json.loads(lines[0])
        print(f"\nFirst question:")
        print(json.dumps(first_question, indent=2))

    except Exception as e:
        print(f"✗ Download failed: {e}")
        print(f"\nPlease manually download from:")
        print(f"  {URL}")
        print(f"And save to:")
        print(f"  {output_file}")


if __name__ == "__main__":
    main()
SpecForge-ext/mtbench_sample.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "question_id": 1,
4
+ "category": "writing",
5
+ "turns": [
6
+ "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions.",
7
+ "Rewrite your previous response. Start every sentence with the letter A."
8
+ ]
9
+ },
10
+ {
11
+ "question_id": 2,
12
+ "category": "roleplay",
13
+ "turns": [
14
+ "Imagine you are writing a blog post comparing two popular smartphone models. Develop an outline for the blog post, including key points and subheadings to effectively compare and contrast the features, performance, and user experience of the two models. Please answer in fewer than 200 words.",
15
+ "Take your previous response and rephrase it as a limerick."
16
+ ]
17
+ },
18
+ {
19
+ "question_id": 3,
20
+ "category": "reasoning",
21
+ "turns": [
22
+ "Describe a vivid and unique character, using strong imagery and creative language. Please answer in fewer than two paragraphs.",
23
+ "Revise your previous response and incorporate an allusion to a famous work of literature or historical event in each sentence."
24
+ ]
25
+ }
26
+ ]
SpecForge-ext/pyproject.toml ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "specforge"
dynamic = ["version", "description"]
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
    "pre-commit",
    "torch==2.9.1",
    "torchaudio==2.9.1",
    "torchvision==0.24.1",
    "transformers==4.57.1",
    "qwen-vl-utils==0.0.11",
    "datasets",
    "setuptools",
    "tqdm",
    "wandb",
    "psutil",
    "numpy",
    "accelerate",
    "pydantic",
    "sglang==0.5.6",
    "openai-harmony",
    "ninja",
    "packaging",
    "yunchang",
]

[tool.setuptools]
packages = ["specforge"]

[project.optional-dependencies]
# NOTE: do not list "unittest" here -- it ships with the Python
# standard library, and the PyPI project of that name is an unrelated
# placeholder package that would be installed instead.
dev = [
    "pre-commit",
]
fa = ["flash-attn"]

[tool.setuptools.dynamic]
version = {file = "version.txt"}
description = {file = "README.md"}
SpecForge-ext/requirements.txt ADDED
File without changes
SpecForge-ext/setup.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import tomllib
from pathlib import Path

from setuptools import find_packages, setup


def read_readme():
    """Return the long description from README.md."""
    # Explicit encoding: the README contains non-ASCII characters and
    # the platform default encoding is not guaranteed to be UTF-8.
    with open("README.md", "r", encoding="utf-8") as f:
        return f.read()


def read_version():
    """Return the package version stored in version.txt."""
    with open("version.txt", "r", encoding="utf-8") as f:
        return f.read().strip()


def read_dependencies():
    """Return the runtime dependency list declared in pyproject.toml."""
    pyproject_path = Path(__file__).parent / "pyproject.toml"
    # tomllib (Python 3.11+) requires the file to be opened in binary mode.
    with open(pyproject_path, "rb") as f:
        pyproject = tomllib.load(f)
    return pyproject.get("project", {}).get("dependencies", [])


setup(
    name="specforge",
    packages=find_packages(exclude=["configs", "scripts", "tests"]),
    version=read_version(),
    install_requires=read_dependencies(),
    long_description=read_readme(),
    long_description_content_type="text/markdown",
    author="SGLang Team",
    url="https://github.com/sgl-project/SpecForge",
)
SpecForge-ext/test_accept_length.md ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Accept Length 测试指南
2
+
3
+ ## 0. 准备工作
4
+
5
+ ### 创建目录
6
+ ```bash
7
+ cd /workspace/hanrui/SpecForge-ext
8
+ mkdir -p logs results
9
+ ```
10
+
11
+ ### 下载数据集(首次运行)
12
+ ```bash
13
+ cd /workspace/hanrui/SpecForge-ext
14
+ python download_datasets.py
15
+ ```
16
+
17
+ 数据保存位置:
18
+ - MT-Bench: `/workspace/hanrui/datasets/mtbench/question.jsonl`
19
+ - GSM8K: `/workspace/hanrui/datasets/gsm8k/test.jsonl`
20
+ - HumanEval: `/workspace/hanrui/datasets/humaneval/test.jsonl`
21
+
22
+ ---
23
+
24
+ ## 1. 测试 Baseline 模型
25
+
26
+ ### 启动服务器(终端1)
27
+ ```bash
28
+ cd /workspace/hanrui/SpecForge-ext
29
+
30
+ # 设置环境变量
31
+ export NO_PROXY="localhost,127.0.0.1,::1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16"
32
+ export no_proxy="localhost,127.0.0.1,::1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16"
33
+
34
+ # 启动 baseline 服务器
35
+ python3 -m sglang.launch_server \
36
+ --model /workspace/Qwen3-8B \
37
+ --speculative-algorithm EAGLE3 \
38
+ --speculative-draft-model-path /workspace/qwen3_8b_eagle3 \
39
+ --speculative-num-steps 3 \
40
+ --speculative-eagle-topk 1 \
41
+ --speculative-num-draft-tokens 4 \
42
+ --mem-fraction-static 0.75 \
43
+ --cuda-graph-max-bs 1 \
44
+ --tp 1 \
45
+ --trust-remote-code \
46
+ --host 0.0.0.0 \
47
+ --port 30000 \
48
+ --dtype bfloat16 \
49
+ --skip-server-warmup
50
+ ```
51
+
52
+ 等待看到 `Application startup complete` 后,继续下一步。
53
+
54
+ ### 运行三个 Benchmark(终端2)
55
+ ```bash
56
+ cd /workspace/hanrui/SpecForge-ext
57
+ conda activate /workspace/Hanrui/
58
+
59
+ # 设置环境变量
60
+ export NO_PROXY="localhost,127.0.0.1,::1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16"
61
+ export no_proxy="localhost,127.0.0.1,::1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16"
62
+
63
+ # 1. MT-Bench
64
+ echo "=== Running MT-Bench (Baseline) ==="
65
+ python benchmarks/bench_eagle3.py \
66
+ --model-path /workspace/Qwen3-8B \
67
+ --host 10.1.1.31 \
68
+ --port 30000 \
69
+ --config-list 1,3,1,4 \
70
+ --benchmark-list mtbench:80 \
71
+ --dtype bfloat16 \
72
+ --skip-launch-server \
73
+ --name baseline_mtbench \
74
+ --output-dir ./results \
75
+ 2>&1 | tee logs/baseline_mtbench_$(date +%Y%m%d_%H%M%S).log
76
+
77
+ # 2. GSM8K
78
+ echo "=== Running GSM8K (Baseline) ==="
79
+ python benchmarks/bench_eagle3.py \
80
+ --model-path /workspace/Qwen3-8B \
81
+ --host 10.1.1.31 \
82
+ --port 30000 \
83
+ --config-list 1,3,1,4 \
84
+ --benchmark-list gsm8k:100 \
85
+ --dtype bfloat16 \
86
+ --skip-launch-server \
87
+ --name baseline_gsm8k \
88
+ --output-dir ./results \
89
+ 2>&1 | tee logs/baseline_gsm8k_$(date +%Y%m%d_%H%M%S).log
90
+
91
+ # 3. HumanEval
92
+ echo "=== Running HumanEval (Baseline) ==="
93
+ python benchmarks/bench_eagle3.py \
94
+ --model-path /workspace/Qwen3-8B \
95
+ --host 10.1.1.31 \
96
+ --port 30000 \
97
+ --config-list 1,3,1,4 \
98
+ --benchmark-list humaneval:164 \
99
+ --dtype bfloat16 \
100
+ --skip-launch-server \
101
+ --name baseline_humaneval \
102
+ --output-dir ./results \
103
+ 2>&1 | tee logs/baseline_humaneval_$(date +%Y%m%d_%H%M%S).log
104
+
105
+ echo "=== Baseline 测试完成 ==="
106
+ ```
107
+
108
+ ---
109
+
110
+ ## 2. 测试训练后的模型
111
+
112
+ ### 停止 Baseline 服务器并启动训练后的服务器(终端1)
113
+ ```bash
114
+ cd /workspace/hanrui/SpecForge-ext
115
+
116
+ # 停止旧服务器
117
+ pkill -f "sglang.launch_server"
118
+ sleep 5
119
+
120
+ # 设置环境变量
121
+ export NO_PROXY="localhost,127.0.0.1,::1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16"
122
+ export no_proxy="localhost,127.0.0.1,::1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16"
123
+
124
+ # 启动训练后的服务器
125
+ python3 -m sglang.launch_server \
126
+ --model /workspace/Qwen3-8B \
127
+ --speculative-algorithm EAGLE3 \
128
+ --speculative-draft-model-path /workspace/hanrui/SpecForge-ext/outputs/qwen3-8b-qwen3eagle-5layer/epoch_9_step_12310 \
129
+ --speculative-num-steps 3 \
130
+ --speculative-eagle-topk 1 \
131
+ --speculative-num-draft-tokens 4 \
132
+ --mem-fraction-static 0.75 \
133
+ --cuda-graph-max-bs 1 \
134
+ --tp 1 \
135
+ --trust-remote-code \
136
+ --host 0.0.0.0 \
137
+ --port 30000 \
138
+ --dtype bfloat16 \
139
+ --skip-server-warmup
140
+ ```
141
+
142
+ 等待看到 `Application startup complete` 后,继续下一步。
143
+
144
+ ### 运行三个 Benchmark(终端2)
145
+ ```bash
146
+ cd /workspace/hanrui/SpecForge-ext
147
+ conda activate /workspace/Hanrui/
148
+
149
+ # 设置环境变量
150
+ export NO_PROXY="localhost,127.0.0.1,::1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16"
151
+ export no_proxy="localhost,127.0.0.1,::1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16"
152
+
153
+ # 1. MT-Bench
154
+ echo "=== Running MT-Bench (Trained) ==="
155
+ python benchmarks/bench_eagle3.py \
156
+ --model-path /workspace/Qwen3-8B \
157
+ --host 10.1.1.31 \
158
+ --port 30000 \
159
+ --config-list 1,3,1,4 \
160
+ --benchmark-list mtbench:80 \
161
+ --dtype bfloat16 \
162
+ --skip-launch-server \
163
+ --name trained_mtbench \
164
+ --output-dir ./results \
165
+ 2>&1 | tee logs/trained_mtbench_$(date +%Y%m%d_%H%M%S).log
166
+
167
+ # 2. GSM8K
168
+ echo "=== Running GSM8K (Trained) ==="
169
+ python benchmarks/bench_eagle3.py \
170
+ --model-path /workspace/Qwen3-8B \
171
+ --host 10.1.1.31 \
172
+ --port 30000 \
173
+ --config-list 1,3,1,4 \
174
+ --benchmark-list gsm8k:100 \
175
+ --dtype bfloat16 \
176
+ --skip-launch-server \
177
+ --name trained_gsm8k \
178
+ --output-dir ./results \
179
+ 2>&1 | tee logs/trained_gsm8k_$(date +%Y%m%d_%H%M%S).log
180
+
181
+ # 3. HumanEval
182
+ echo "=== Running HumanEval (Trained) ==="
183
+ python benchmarks/bench_eagle3.py \
184
+ --model-path /workspace/Qwen3-8B \
185
+ --host 10.1.1.31 \
186
+ --port 30000 \
187
+ --config-list 1,3,1,4 \
188
+ --benchmark-list humaneval:164 \
189
+ --dtype bfloat16 \
190
+ --skip-launch-server \
191
+ --name trained_humaneval \
192
+ --output-dir ./results \
193
+ 2>&1 | tee logs/trained_humaneval_$(date +%Y%m%d_%H%M%S).log
194
+
195
+ echo "=== Trained 测试完成 ==="
196
+ ```
197
+
198
+ ---
199
+
200
+ ## 3. 查看结果
201
+
202
+ ### 日志文件位置
203
+ 所有日志保存在:`/workspace/hanrui/SpecForge-ext/logs/`
204
+ - `baseline_mtbench_*.log`
205
+ - `baseline_gsm8k_*.log`
206
+ - `baseline_humaneval_*.log`
207
+ - `trained_mtbench_*.log`
208
+ - `trained_gsm8k_*.log`
209
+ - `trained_humaneval_*.log`
210
+
211
+ 所有结果保存在:`/workspace/hanrui/SpecForge-ext/results/`
212
+ - `baseline_mtbench_*.jsonl`
213
+ - `baseline_gsm8k_*.jsonl`
214
+ - `baseline_humaneval_*.jsonl`
215
+ - `trained_mtbench_*.jsonl`
216
+ - `trained_gsm8k_*.jsonl`
217
+ - `trained_humaneval_*.jsonl`
218
+
219
+ ### 生成对比报告
220
+ ```bash
221
+ cd /workspace/hanrui/SpecForge-ext
222
+
223
+ python3 << 'EOF'
224
+ import json
225
+ import glob
226
+
227
+ print("=" * 80)
228
+ print("Accept Length 对比报告")
229
+ print("=" * 80)
230
+
231
+ datasets = ['mtbench', 'gsm8k', 'humaneval']
232
+
233
+ for dataset in datasets:
234
+ print(f"\n{'=' * 80}")
235
+ print(f"{dataset.upper()} 结果对比")
236
+ print('=' * 80)
237
+
238
+ baseline_files = sorted(glob.glob(f'results/baseline_{dataset}_*.jsonl'))
239
+ trained_files = sorted(glob.glob(f'results/trained_{dataset}_*.jsonl'))
240
+
241
+ if not baseline_files or not trained_files:
242
+ print(f" 未找到 {dataset} 的结果文件")
243
+ continue
244
+
245
+ with open(baseline_files[-1], 'r') as f:
246
+ baseline = json.load(f)
247
+
248
+ with open(trained_files[-1], 'r') as f:
249
+ trained = json.load(f)
250
+
251
+ baseline_metrics = baseline[dataset][0]['metrics'][0]
252
+ trained_metrics = trained[dataset][0]['metrics'][0]
253
+
254
+ print(f"\nBaseline:")
255
+ print(f" Accept Length: {baseline_metrics['accept_length']:.4f}")
256
+ print(f" Output Throughput: {baseline_metrics['output_throughput']:.2f} tokens/s")
257
+ if 'accuracy' in baseline_metrics and baseline_metrics['accuracy'] is not None:
258
+ print(f" Accuracy: {baseline_metrics['accuracy']:.2%}")
259
+
260
+ print(f"\nTrained:")
261
+ print(f" Accept Length: {trained_metrics['accept_length']:.4f}")
262
+ print(f" Output Throughput: {trained_metrics['output_throughput']:.2f} tokens/s")
263
+ if 'accuracy' in trained_metrics and trained_metrics['accuracy'] is not None:
264
+ print(f" Accuracy: {trained_metrics['accuracy']:.2%}")
265
+
266
+ accept_diff = trained_metrics['accept_length'] - baseline_metrics['accept_length']
267
+ accept_pct = (accept_diff / baseline_metrics['accept_length']) * 100
268
+
269
+ throughput_diff = trained_metrics['output_throughput'] - baseline_metrics['output_throughput']
270
+ throughput_pct = (throughput_diff / baseline_metrics['output_throughput']) * 100
271
+
272
+ print(f"\n差异:")
273
+ print(f" Accept Length: {accept_diff:+.4f} ({accept_pct:+.2f}%)")
274
+ print(f" Throughput: {throughput_diff:+.2f} tokens/s ({throughput_pct:+.2f}%)")
275
+
276
+ if 'accuracy' in baseline_metrics and baseline_metrics['accuracy'] is not None:
277
+ acc_diff = trained_metrics['accuracy'] - baseline_metrics['accuracy']
278
+ acc_pct = acc_diff * 100
279
+ print(f" Accuracy: {acc_pct:+.2f} percentage points")
280
+
281
+ print("\n" + "=" * 80)
282
+ EOF
283
+ ```
284
+
285
+ ---
286
+
287
+ ## 4. 快速查看单个结果
288
+ ```bash
289
+ cd /workspace/hanrui/SpecForge-ext
290
+
291
+ # 查看 baseline 的 accept_length
292
+ cat results/baseline_mtbench_*.jsonl | jq '.mtbench[0].metrics[0].accept_length'
293
+ cat results/baseline_gsm8k_*.jsonl | jq '.gsm8k[0].metrics[0].accept_length'
294
+ cat results/baseline_humaneval_*.jsonl | jq '.humaneval[0].metrics[0].accept_length'
295
+
296
+ # 查看 trained 的 accept_length
297
+ cat results/trained_mtbench_*.jsonl | jq '.mtbench[0].metrics[0].accept_length'
298
+ cat results/trained_gsm8k_*.jsonl | jq '.gsm8k[0].metrics[0].accept_length'
299
+ cat results/trained_humaneval_*.jsonl | jq '.humaneval[0].metrics[0].accept_length'
300
+ ```
SpecForge/.editorconfig ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://editorconfig.org/
2
+
3
+ root = true
4
+
5
+ [*]
6
+ charset = utf-8
7
+ end_of_line = lf
8
+ indent_style = space
9
+ indent_size = 4
10
+ trim_trailing_whitespace = true
11
+ insert_final_newline = true
12
+
13
+ [*.{json,yaml,yml}]
14
+ indent_size = 2
15
+
16
+ [*.md]
17
+ indent_size = 2
18
+ x-soft-wrap-text = true
19
+
20
+ [*.rst]
21
+ indent_size = 4
22
+ x-soft-wrap-text = true
23
+
24
+ [Makefile]
25
+ indent_style = tab
SpecForge/.isort.cfg ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [settings]
2
+ profile=black
3
+ known_first_party=sgl-eagle
SpecForge/.pre-commit-config.yaml ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ default_stages: [pre-commit, pre-push, manual]
2
+
3
+ repos:
4
+ - repo: https://github.com/PyCQA/autoflake
5
+ rev: v2.3.1
6
+ hooks:
7
+ - id: autoflake
8
+ args: [--remove-all-unused-imports, --in-place]
9
+ - repo: https://github.com/pre-commit/pre-commit-hooks
10
+ rev: v5.0.0
11
+ hooks:
12
+ - id: check-symlinks
13
+ - id: destroyed-symlinks
14
+ - id: trailing-whitespace
15
+ - id: end-of-file-fixer
16
+ - id: check-yaml
17
+ args: [--allow-multiple-documents]
18
+ - id: check-toml
19
+ - id: check-ast
20
+ - id: check-added-large-files
21
+ - id: check-merge-conflict
22
+ - id: check-shebang-scripts-are-executable
23
+ - id: detect-private-key
24
+ - id: debug-statements
25
+ - id: no-commit-to-branch
26
+ - repo: https://github.com/PyCQA/isort
27
+ rev: 5.13.2
28
+ hooks:
29
+ - id: isort
30
+ - repo: https://github.com/astral-sh/ruff-pre-commit
31
+ rev: v0.11.10
32
+ hooks:
33
+ - id: ruff
34
+ args: [--select=F401, --fixable=F401]
35
+ files: ^(benchmark/|docs/|examples/)
36
+ exclude: \.ipynb$
37
+ - repo: https://github.com/psf/black
38
+ rev: 24.10.0
39
+ hooks:
40
+ - id: black-jupyter
41
+ - repo: https://github.com/pre-commit/mirrors-clang-format
42
+ rev: v18.1.8
43
+ hooks:
44
+ - id: clang-format
45
+ types_or: [c++, cuda]
46
+ args: [--style=file, --verbose]
47
+ - repo: https://github.com/kynan/nbstripout
48
+ rev: 0.8.1
49
+ hooks:
50
+ - id: nbstripout
51
+ args:
52
+ - '--keep-output'
53
+ - '--extra-keys=metadata.kernelspec metadata.language_info.version'
SpecForge/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 sgl-project
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
SpecForge/MANIFEST.in ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ include requirements.txt
2
+ include version.txt
SpecForge/README.md ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center" id="sglangtop">
2
+ <img src="./assets/logo.png" alt="logo" width="400" margin="10px"></img>
3
+
4
+ [![documentation](https://img.shields.io/badge/📖-Documentation-red.svg?style=flat)](https://docs.sglang.ai/SpecForge/)
5
+ [![SpecBundle](https://img.shields.io/badge/🤗%20SpecBundle-yellow.svg?style=flat)](https://huggingface.co/collections/lmsys/specbundle)
6
+ [![DeepWiki](https://img.shields.io/badge/DeepWiki-SpecForge-blue.svg?logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAACwAAAAyCAYAAAAnWDnqAAAAAXNSR0IArs4c6QAAA05JREFUaEPtmUtyEzEQhtWTQyQLHNak2AB7ZnyXZMEjXMGeK/AIi+QuHrMnbChYY7MIh8g01fJoopFb0uhhEqqcbWTp06/uv1saEDv4O3n3dV60RfP947Mm9/SQc0ICFQgzfc4CYZoTPAswgSJCCUJUnAAoRHOAUOcATwbmVLWdGoH//PB8mnKqScAhsD0kYP3j/Yt5LPQe2KvcXmGvRHcDnpxfL2zOYJ1mFwrryWTz0advv1Ut4CJgf5uhDuDj5eUcAUoahrdY/56ebRWeraTjMt/00Sh3UDtjgHtQNHwcRGOC98BJEAEymycmYcWwOprTgcB6VZ5JK5TAJ+fXGLBm3FDAmn6oPPjR4rKCAoJCal2eAiQp2x0vxTPB3ALO2CRkwmDy5WohzBDwSEFKRwPbknEggCPB/imwrycgxX2NzoMCHhPkDwqYMr9tRcP5qNrMZHkVnOjRMWwLCcr8ohBVb1OMjxLwGCvjTikrsBOiA6fNyCrm8V1rP93iVPpwaE+gO0SsWmPiXB+jikdf6SizrT5qKasx5j8ABbHpFTx+vFXp9EnYQmLx02h1QTTrl6eDqxLnGjporxl3NL3agEvXdT0WmEost648sQOYAeJS9Q7bfUVoMGnjo4AZdUMQku50McDcMWcBPvr0SzbTAFDfvJqwLzgxwATnCgnp4wDl6Aa+Ax283gghmj+vj7feE2KBBRMW3FzOpLOADl0Isb5587h/U4gGvkt5v60Z1VLG8BhYjbzRwyQZemwAd6cCR5/XFWLYZRIMpX39AR0tjaGGiGzLVyhse5C9RKC6ai42ppWPKiBagOvaYk8lO7DajerabOZP46Lby5wKjw1HCRx7p9sVMOWGzb/vA1hwiWc6jm3MvQDTogQkiqIhJV0nBQBTU+3okKCFDy9WwferkHjtxib7t3xIUQtHxnIwtx4mpg26/HfwVNVDb4oI9RHmx5WGelRVlrtiw43zboCLaxv46AZeB3IlTkwouebTr1y2NjSpHz68WNFjHvupy3q8TFn3Hos2IAk4Ju5dCo8B3wP7VPr/FGaKiG+T+v+TQqIrOqMTL1VdWV1DdmcbO8KXBz6esmYWYKPwDL5b5FA1a0hwapHiom0r/cKaoqr+27/XcrS5UwSMbQAAAABJRU5ErkJggg==)](https://deepwiki.com/sgl-project/SpecForge)
7
+
8
+ [![github badge](https://img.shields.io/badge/📃%20LMSYS-Blog-black.svg?style=flat)](https://lmsys.org/blog/2025-07-25-spec-forge/)
9
+ [![slack badge](https://img.shields.io/badge/Slack-join-blueviolet?logo=slack&amp)](https://sgl-fru7574.slack.com/archives/C09784E3EN6)
10
+ [![license](https://img.shields.io/badge/License-MIT%202.0-blue)](./LICENSE)
11
+
12
+ </div>
13
+
14
+ ## 📍 Overview
15
+
16
+ SpecForge is an ecosystem project developed by the SGLang team. It is a framework for training speculative decoding models so that you can smoothly port them over to the SGLang serving framework to speed up your inference.
17
+
18
+ We have seen many open-source projects for speculative decoding, but most of them are not well-maintained or not directly compatible with SGLang. We prepared this project because we wish that the open-source community can enjoy a speculative decoding framework that is
19
+ - regularly maintained by the SpecForge team: the code is runnable out-of-the-box
20
+ - directly compatible with SGLang: there is no additional efforts for porting to SGLang
21
+ - provide performant training capabilities: we provided online/offline/tensor-parallel/FSDP to suit your needs
22
+
23
+
24
+ Check out [**our documentation**](https://docs.sglang.ai/SpecForge/) to get started.
25
+
26
+
27
+ ## 🚀 Accelerate with SpecBundle
28
+
29
+ SpecBundle is a collection of production-grade speculative decoding models that are released by the SpecForge team and our industry partners. They provide higher acceptance rate compared to the existing open-source checkpoints over a wide range of domains. Together with SGLang, you can experience up to 4x speedup for inference. Check out our resources below:
30
+
31
+
32
+ | Item | Link |
33
+ | --- | --- |
34
+ | 📝 Documentation | [Link](https://docs.sglang.ai/SpecForge/community_resources/specbundle.html) |
35
+ | 📊 Performance Dashboard | [Link](https://docs.sglang.ai/SpecForge/SpecBundle/index.html) |
36
+ | 🤗 Hugging Face Collection | [Link](https://huggingface.co/collections/lmsys/specbundle) |
37
+
38
+
39
+ ## 🎉 News
40
+
41
+ - [2025-12] 🎉 Released SpecBundle (phase 1) and SpecForge v0.2. Check out our blog at [LMSYS.org](https://lmsys.org/blog/2025-12-23-spec-bundle-phase-1/)
42
+ - [2025-12] 🔔 Released the roadmap for 2026 Q1.
43
+ - [2025-08] 🔔 SpecForge is listed as a [flagship project](https://lmsys.org/about/) in LMSYS. Congratulations to the SpecForge team!
44
+ - [2025-08] 🔥 SpecForge powered the Eagle3 draft model for GPT-OSS. Check out the blog at [LMSYS.org](https://lmsys.org/blog/2025-08-27-gpt-oss/)
45
+ - [2025-07] 🔥 SpecForge is released together with Llama4-Eagle3 checkpoints. Check out our blog at [LMSYS.org](https://lmsys.org/blog/2025-07-25-spec-forge/)
46
+
47
+ ## ✨ Acknowledgements
48
+
49
+ <img src="./assets/acknowledgements.png" alt="acknowledgements"></img>
50
+
51
+ We would like to express our sincere gratitude to the official EAGLE team, especially Hongyang Zhang and Yuhui Li, for their invaluable contributions and support. Our thanks also go to the NVIDIA team—particularly Avery H and Izzy Putterman—and to the Google team, especially Ying Wang, for their insightful discussions and generous assistance throughout the project.
52
+
53
+ We are especially grateful to Meituan for their strong backing and meaningful contributions, which played a vital role in driving this project forward.
54
+
55
+ This project has also been inspired by many outstanding open-source projects from the LLM community, including [EAGLE](https://github.com/SafeAILab/EAGLE), [BaldEagle](https://github.com/NickL77/BaldEagle), and [TensorRT-Model-Optimizer](https://github.com/NVIDIA/TensorRT-Model-Optimizer) and others. Their contributions and shared knowledge have greatly benefited our work.
56
+
57
+ ## 💡 Special Thanks to Voltage Park
58
+
59
+ We would like to extend our sincere thanks to [Voltage Park](https://www.voltagepark.com/), our official infrastructure partner. As part of a formal collaboration with the SGLang team, Voltage Park provided critical GPU resources that empowered us to train and evaluate large-scale speculative decoding models efficiently and reliably. This partnership was instrumental in making SpecForge possible. We deeply appreciate Voltage Park’s mission to make cutting-edge AI infrastructure more accessible, and we look forward to continued collaboration as we push the boundaries of open-source LLM serving and optimization.
60
+
61
+ ## 📃 Citation
62
+
63
+ ```bibtex
64
+ @misc{specforge2025,
65
+ title={SpecForge: Train speculative decoding models effortlessly},
66
+ author={Shenggui Li, Yikai Zhu, Chao Wang, Fan Yin, Shuai Shi, Yubo Wang, Yi Zhang, Yingyi Huang, Haoshuai Zheng, Yineng Zhang},
67
+ year={2025},
68
+ publisher={GitHub},
69
+ howpublished={\url{https://github.com/sgl-project/specforge}},
70
+ }
SpecForge/pyproject.toml ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=61.0", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "specforge"
7
+ dynamic = ["version"]
8
+ readme = "README.md"
9
+ requires-python = ">=3.11"
10
+ description = "SpecForge: Speculative Decoding Training Framework"
11
+ authors = [{name = "SGLang Team"}]
12
+ urls = {Homepage = "https://github.com/sgl-project/SpecForge"}
13
+ dependencies = [
14
+ "pre-commit",
15
+ "torch==2.9.1",
16
+ "torchaudio==2.9.1",
17
+ "torchvision==0.24.1",
18
+ "transformers==4.57.1",
19
+ "qwen-vl-utils==0.0.11",
20
+ "datasets",
21
+ "setuptools",
22
+ "tqdm",
23
+ "wandb",
24
+ "psutil",
25
+ "numpy",
26
+ "accelerate",
27
+ "pydantic",
28
+ "sglang==0.5.9",
29
+ "openai-harmony",
30
+ "ninja",
31
+ "packaging",
32
+ "yunchang",
33
+ "tensorboard",
34
+ ]
35
+
36
+ [tool.setuptools.packages.find]
37
+ exclude = ["configs*", "scripts*", "tests*"]
38
+
39
+ [project.optional-dependencies]
40
+ dev = [
41
+ "pre-commit",
42
+ "pytest"
43
+ ]
44
+ fa = ["flash-attn"]
45
+
46
+ [tool.setuptools.dynamic]
47
+ version = {file = "version.txt"}
SpecForge/requirements-rocm.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use the PyTorch ROCm wheel index (choose the stream that matches your system)
2
+ --extra-index-url https://download.pytorch.org/whl/rocm6.3
3
+
4
+ pre-commit
5
+ torch==2.8.0+rocm6.3
6
+ torchaudio==2.8.0+rocm6.3
7
+ torchvision==0.23.0+rocm6.3
8
+ transformers==4.57.1
9
+ qwen-vl-utils==0.0.11
10
+ datasets
11
+ setuptools
12
+ tqdm
13
+ wandb
14
+ psutil
15
+ numpy
16
+ accelerate
17
+ pydantic
18
+ sglang[all]==0.5.4
19
+ openai-harmony
20
+ tensorboard
SpecForge/version.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ 0.2.0
idea1/.editorconfig ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://editorconfig.org/
2
+
3
+ root = true
4
+
5
+ [*]
6
+ charset = utf-8
7
+ end_of_line = lf
8
+ indent_style = space
9
+ indent_size = 4
10
+ trim_trailing_whitespace = true
11
+ insert_final_newline = true
12
+
13
+ [*.{json,yaml,yml}]
14
+ indent_size = 2
15
+
16
+ [*.md]
17
+ indent_size = 2
18
+ x-soft-wrap-text = true
19
+
20
+ [*.rst]
21
+ indent_size = 4
22
+ x-soft-wrap-text = true
23
+
24
+ [Makefile]
25
+ indent_style = tab
idea1/.isort.cfg ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [settings]
2
+ profile=black
3
+ known_first_party=sgl-eagle
idea1/.pre-commit-config.yaml ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ default_stages: [pre-commit, pre-push, manual]
2
+
3
+ repos:
4
+ - repo: https://github.com/PyCQA/autoflake
5
+ rev: v2.3.1
6
+ hooks:
7
+ - id: autoflake
8
+ args: [--remove-all-unused-imports, --in-place]
9
+ - repo: https://github.com/pre-commit/pre-commit-hooks
10
+ rev: v5.0.0
11
+ hooks:
12
+ - id: check-symlinks
13
+ - id: destroyed-symlinks
14
+ - id: trailing-whitespace
15
+ - id: end-of-file-fixer
16
+ - id: check-yaml
17
+ args: [--allow-multiple-documents]
18
+ - id: check-toml
19
+ - id: check-ast
20
+ - id: check-added-large-files
21
+ - id: check-merge-conflict
22
+ - id: check-shebang-scripts-are-executable
23
+ - id: detect-private-key
24
+ - id: debug-statements
25
+ - id: no-commit-to-branch
26
+ - repo: https://github.com/PyCQA/isort
27
+ rev: 5.13.2
28
+ hooks:
29
+ - id: isort
30
+ - repo: https://github.com/astral-sh/ruff-pre-commit
31
+ rev: v0.11.10
32
+ hooks:
33
+ - id: ruff
34
+ args: [--select=F401, --fixable=F401]
35
+ files: ^(benchmark/|docs/|examples/)
36
+ exclude: \.ipynb$
37
+ - repo: https://github.com/psf/black
38
+ rev: 24.10.0
39
+ hooks:
40
+ - id: black-jupyter
41
+ - repo: https://github.com/pre-commit/mirrors-clang-format
42
+ rev: v18.1.8
43
+ hooks:
44
+ - id: clang-format
45
+ types_or: [c++, cuda]
46
+ args: [--style=file, --verbose]
47
+ - repo: https://github.com/kynan/nbstripout
48
+ rev: 0.8.1
49
+ hooks:
50
+ - id: nbstripout
51
+ args:
52
+ - '--keep-output'
53
+ - '--extra-keys=metadata.kernelspec metadata.language_info.version'
idea1/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 sgl-project
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
idea1/requirements-rocm.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use the PyTorch ROCm wheel index (choose the stream that matches your system)
2
+ --extra-index-url https://download.pytorch.org/whl/rocm6.3
3
+
4
+ pre-commit
5
+ torch==2.8.0+rocm6.3
6
+ torchaudio==2.8.0+rocm6.3
7
+ torchvision==0.23.0+rocm6.3
8
+ transformers==4.57.1
9
+ qwen-vl-utils==0.0.11
10
+ datasets
11
+ setuptools
12
+ tqdm
13
+ wandb
14
+ psutil
15
+ numpy
16
+ accelerate
17
+ pydantic
18
+ sglang[all]==0.5.4
19
+ openai-harmony
20
+ tensorboard
idea1/version.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ 0.2.0
qwen3-8b_dflash_regen/.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ sharegpt_train_regenerated.jsonl filter=lfs diff=lfs merge=lfs -text
syxin/backup.log ADDED
The diff for this file is too large to render. See raw diff
 
syxin/dflash_lora_changelog.md ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DFlash LoRA 全部改动记录
2
+
3
+ ## 概述
4
+
5
+ 为了让 Qwen3-8B DFlash LoRA 训练在 2×H100 上跑通(解决 OOM),共新增/修改了 **5 个文件,1084 行代码**。改动分为两大阶段:基础搭建 + OOM 修复。
6
+
7
+ ---
8
+
9
+ ## 新增文件清单
10
+
11
+ | 文件 | 行数 | 用途 |
12
+ |------|------|------|
13
+ | `specforge/core/dflash_lora.py` | 453 | 训练 wrapper(OnlineDFlashLoRAModel) |
14
+ | `specforge/modeling/draft/dflash_lora.py` | 141 | LoRA draft 模型(DFlashLoRADraftModel) |
15
+ | `scripts/train_dflash_lora.py` | 449 | 训练入口脚本 |
16
+ | `scripts/run_train_dflash_lora.sh` | 31 | 启动 shell 脚本 |
17
+ | `configs/qwen3-8b-dflash-lora.json` | 10 | LoRA 配置文件 |
18
+
19
+ ---
20
+
21
+ ## Step 1 完成过程
22
+
23
+ ### 1.1 分析现有代码
24
+
25
+ 首先分析了非 LoRA 版 `train_dflash.py` 的完整流程:
26
+
27
+ ```
28
+ input_ids → target_model.generate_dflash_data() → hidden_states
29
+ → OnlineDFlashModel.forward():
30
+ 1. 截断到 block 边界
31
+ 2. prepare_noise_input(): anchor 保留,其余 → MASK
32
+ 3. embed_tokens(noise_input_ids) → noise_embedding
33
+ 4. 构建 DFlash attention mask
34
+ 5. draft_model(noise_embedding, target_hidden, mask)
35
+ 6. lm_head(hidden) → logits → CE loss
36
+ ```
37
+
38
+ 非 LoRA 版使用独立的小型 draft model + 冻结 target model 提取 hidden states。
39
+
40
+ ### 1.2 确定 LoRA 版设计差异
41
+
42
+ | 方面 | 非 LoRA 版 (`train_dflash.py`) | LoRA 版 (`train_dflash_lora.py`) |
43
+ |------|------|------|
44
+ | Draft model | 自定义小模型 (1-10 层) | Qwen3-8B + PEFT LoRA |
45
+ | Target model | 冻结大模型提取 hidden states | 无需 — 模型用自身表征 |
46
+ | Attention | 自定义 Qwen3DFlashAttention,KV = [ctx, noise] concat | 标准 HF attention + DFlash mask |
47
+ | KV 结构 | Q_LEN = noise_len, KV_LEN = 2×noise_len | Q_LEN = KV_LEN = seq_len |
48
+ | 可训练参数 | 全部 draft model 参数 | 仅 LoRA (q/k/v/o_proj) |
49
+
50
+ ### 1.3 新建 LoRA 版三个核心文件
51
+
52
+ #### `specforge/modeling/draft/dflash_lora.py` — DFlashLoRADraftModel
53
+
54
+ - `from_pretrained()`: 加载 Qwen3-8B,注入 PEFT LoRA,支持 `attn_implementation` 参数
55
+ - `forward()`: 标准 HF forward,支持 `output_hidden_states` 参数(chunked loss 需要)
56
+ - `get_lm_head()`: 穿透 PEFT 层级获取 lm_head 引用
57
+ - `gradient_checkpointing_enable()`: 代理到底层模型
58
+ - `save_pretrained()`: 仅保存 LoRA adapter 权重
59
+
60
+ #### `specforge/core/dflash_lora.py` — OnlineDFlashLoRAModel
61
+
62
+ - `prepare_noise_input()`: context 部分保持不变,block 部分只保留 anchor,其余替换为 MASK
63
+ - `build_dflash_full_attn_mask_fast()`: 向量化构建 4D additive mask `[bsz, 1, seq, seq]`
64
+ - `_compute_loss_weights()`: context + anchor 权重为 0,非 anchor 权重为 1(或 decay)
65
+ - `_full_lm_loss()`: 标准 CE loss 路径
66
+ - `_compute_accuracy()`: block-wise acceptance rate(累积正确预测长度 / block 非 anchor 长度)
67
+ - `forward()`: 完整训练 forward pass
68
+
69
+ LoRA 版 mask 规则:
70
+ - context token i → 因果注意力 (j ≤ i)
71
+ - block token i (属于 block b) → 所有 context + 同 block 内双向注意力
72
+
73
+ #### `scripts/train_dflash_lora.py` — 训练脚本
74
+
75
+ - 参数解析:model/lora/dataset/training/output/distributed/tracker 7 组参数
76
+ - `build_model()`: 加载模型 + 注入 LoRA + 包装 OnlineDFlashLoRAModel
77
+ - `build_dataloader()`: 复用 `build_eagle3_dataset` 和 `prepare_dp_dataloaders`
78
+ - FSDP 包装 + BF16Optimizer
79
+ - 训练循环:forward → backward → accumulation → optimizer step
80
+ - checkpoint 保存/恢复
81
+
82
+ ---
83
+
84
+ ## OOM 修复改动(4 项)
85
+
86
+ ### 改动 1: FSDP FULL_SHARD (ZeRO-3)
87
+
88
+ **问题**: `SHARD_GRAD_OP` (ZeRO-2) 每卡持有完整 Qwen3-8B 参数 (~16GB bf16)
89
+
90
+ **修复**: `train_dflash_lora.py:362`
91
+ ```python
92
+ # 之前
93
+ sharding_strategy=ShardingStrategy.SHARD_GRAD_OP
94
+ # 之后
95
+ sharding_strategy=ShardingStrategy.FULL_SHARD
96
+ ```
97
+
98
+ **效果**: 参数跨卡分片,每卡省 ~8-12GB
99
+
100
+ ### 改动 2: batch_size=1 + accumulation_steps=8
101
+
102
+ **问题**: `batch_size=2` 时峰值显存过高
103
+
104
+ **修复**: `run_train_dflash_lora.sh`
105
+ ```bash
106
+ --batch-size 1 \
107
+ --accumulation-steps 8 \
108
+ ```
109
+
110
+ **效果**: 等效 global batch size 不变,峰值显存减半
111
+
112
+ ### 改动 3: flex_attention + BlockMask 替换 4D additive mask
113
+
114
+ **问题**: SDPA 不支持 4D additive mask → fallback 到 math backend → 每层 materialize 完整 `[bsz, 32heads, 2048, 2048]` attention scores
115
+
116
+ **修复**: 从非 LoRA 版 `dflash.py` 移植 `_get_or_create_block_mask()` 方法,适配 LoRA 场景
117
+
118
+ 涉及文件:
119
+
120
+ 1. **`specforge/core/dflash_lora.py`**
121
+ - `__init__()`: 添加 `attention_backend` 参数(默认 `"flex_attention"`),BlockMask 缓存字段
122
+ - 新增 `_get_or_create_block_mask()`: 用 `create_block_mask()` 构建零显存的 BlockMask
123
+ - `forward()`: 根据 `attention_backend` 选择 BlockMask 或 additive mask
124
+
125
+ 2. **`specforge/modeling/draft/dflash_lora.py`**
126
+ - `from_pretrained()`: 当 backend 为 flex_attention 时,传 `attn_implementation="flex_attention"` 给 HuggingFace
127
+
128
+ 3. **`scripts/train_dflash_lora.py`**
129
+ - `parse_args()`: `--attention-backend` 参数 (`flex_attention` | `additive`)
130
+ - `build_model()`: 根据 backend 选择 `attn_implementation`
131
+
132
+ BlockMask mask function(LoRA 版):
133
+ ```python
134
+ def dflash_lora_mask_fn(b, h, q_idx, kv_idx):
135
+ # Context query: 标准因果
136
+ is_q_ctx = q_idx < context_len
137
+ ctx_visible = is_q_ctx & (kv_idx <= q_idx)
138
+
139
+ # Block query: 全部 context + 同 block 双向
140
+ is_q_block = q_idx >= context_len
141
+ is_k_ctx = kv_idx < context_len
142
+ q_block_id = (q_idx - context_len) // block_size
143
+ k_block_id = (kv_idx - context_len) // block_size
144
+ block_attend_ctx = is_q_block & is_k_ctx
145
+ block_attend_same = is_q_block & (~is_k_ctx) & (q_block_id == k_block_id)
146
+
147
+ return ctx_visible | (block_attend_ctx | block_attend_same)
148
+ ```
149
+
150
+ **验证**: 手动逐元素对比 BlockMask 和 additive mask 输出,三组测试 (context_len=4/0, seq=12/16/64) pattern 完全一致。
151
+
152
+ **效果**: 不再 fallback 到 SDPA math backend,省去 `[bsz, heads, seq, seq]` attention scores 显存
153
+
154
+ ### 改动 4: chunked cross-entropy loss
155
+
156
+ **问题**: `[bsz, 2048, 151936]` bf16 logits ≈ 1.18GB,加梯度 ~2.4GB+
157
+
158
+ **修复**: 从非 LoRA 版 `dflash.py:419-478` 移植 chunked loss
159
+
160
+ 涉及文件:
161
+
162
+ 1. **`specforge/core/dflash_lora.py`**
163
+ - `__init__()`: 添加 `lm_head_chunk_size` 参数(默认 0 = 不启用)
164
+ - 新增 `_chunked_lm_loss()`: 分 chunk 过 lm_head + CE loss + gradient checkpointing
165
+ - 提取 `_full_lm_loss()`: 原始非 chunked 路径
166
+ - `forward()`: `lm_head_chunk_size > 0` 时走 chunked 路径
167
+
168
+ 2. **`specforge/modeling/draft/dflash_lora.py`**
169
+ - `forward()`: 新增 `output_hidden_states` 参数,True 时返回 last hidden state 而非 logits
170
+ - `get_lm_head()`: 穿透 PEFT 层级返回 `base_model.lm_head` 引用
171
+
172
+ 3. **`scripts/train_dflash_lora.py`**
173
+ - `parse_args()`: `--lm-head-chunk-size` 参数(默认 0,推荐 256)
174
+ - `build_model()`: 传递到 OnlineDFlashLoRAModel
175
+
176
+ Chunked loss 核心逻辑:
177
+ ```python
178
+ # 分 chunk 计算,每 chunk 用 gradient checkpointing(backward 时重算 logits,不存储)
179
+ for start in range(0, effective_len, chunk_size):
180
+ end = min(start + chunk_size, effective_len)
181
+ chunk_loss, chunk_weight = grad_checkpoint(
182
+ _chunk_ce, # lm_head + CE
183
+ hidden[:, start:end, :], # 只取当前 chunk
184
+ input_ids[:, start:end],
185
+ combined_mask[:, start:end],
186
+ use_reentrant=False,
187
+ )
188
+ total_loss += chunk_loss
189
+ total_weight += chunk_weight
190
+ loss = total_loss / total_weight
191
+ ```
192
+
193
+ **效果**: logits 峰值显存从 `O(seq_len × vocab_size)` 降至 `O(chunk_size × vocab_size)`,256 chunk → ~150MB vs 1.18GB
194
+
195
+ ---
196
+
197
+ ## 当前训练命令
198
+
199
+ ```bash
200
+ bash run_train_dflash_lora.sh 2 # 2 = GPU 数量
201
+ ```
202
+
203
+ 对应完整参数:
204
+ ```bash
205
+ torchrun --nproc_per_node 2 scripts/train_dflash_lora.py \
206
+ --model-path /workspace/Qwen3-8B \
207
+ --train-data-path /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K \
208
+ --output-dir outputs/qwen3-8b-dflash-lora \
209
+ --lora-config configs/qwen3-8b-dflash-lora.json \
210
+ --block-size 16 \
211
+ --max-length 2048 \
212
+ --batch-size 1 \
213
+ --num-epochs 3 \
214
+ --learning-rate 2e-4 \
215
+ --accumulation-steps 8 \
216
+ --loss-decay-gamma 7 \
217
+ --attention-backend flex_attention \
218
+ --lm-head-chunk-size 256 \
219
+ --gradient-checkpointing \
220
+ --chat-template qwen \
221
+ --log-interval 50 \
222
+ --save-interval 500
223
+ ```
224
+
225
+ ---
226
+
227
+ ## 待验证
228
+
229
+ - [ ] 跑 `bash run_train_dflash_lora.sh 2` 确认不再 OOM
230
+ - [ ] 确认无 SDPA math fallback warning
231
+ - [ ] 观察 GPU 显存峰值
232
+ - [ ] 确认 loss 下降和 accuracy 上升趋势正常
syxin/eval_accepted_length.md ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DFlash-LoRA-Inject 评测:Accepted Length & Accuracy
2
+
3
+ ## 为什么不能用 sglang 在线评测?
4
+
5
+ DFlash-LoRA-Inject 的推理需要**逐层注入 target 模型的 hidden states** 到 draft 模型中,
6
+ 这是 LoRA-Inject 训练时的核心机制。但 sglang 不支持这种推理模式:
7
+
8
+ | sglang 算法 | 问题 |
9
+ |---|---|
10
+ | `STANDALONE` | 把 draft 当独立自回归模型跑,**完全忽略 layer injection**。merged 模型 ≈ 原始 Qwen3-8B,accept_length 恒 ≈ 4.7,跟 LoRA 训没训没关系 |
11
+ | `DFLASH` | 期望 DFlash-b16 架构(5 层 + fc + hidden_norm),跟 LoRA-Inject(36 层全模型)结构不匹配 |
12
+
13
+ 因此必须**离线评测**:加载 target + draft 两个模型,手动实现带 layer injection 的 speculative decoding 循环。
14
+
15
+ ---
16
+
17
+ ## 基本信息
18
+
19
+ | 项目 | 路径 / 值 |
20
+ |---|---|
21
+ | conda 环境 | `spec` |
22
+ | 基座模型(target) | `/workspace/models/Qwen3-8B` |
23
+ | 训练输出(最终 ckpt) | `.../outputs/qwen3-8b-dflash-lora-inject/epoch_3_step_1400` |
24
+ | 合并后 draft 模型 | `.../outputs/qwen3-8b-dflash-lora-inject-merged` |
25
+ | 评测脚本 | `/workspace/hanrui/syxin_old/eval_dflash_lora_inject.py` |
26
+ | 本地数据集 | `/workspace/hanrui/datasets/{humaneval,mtbench,gsm8k}` |
27
+ | 结果输出目录 | `/workspace/hanrui/syxin_old/Specforge/benchmarks/results/` |
28
+ | GPU | 8 × H100 80GB(单卡即可,需 ~32GB 加载两个 8B 模型) |
29
+
30
+ ---
31
+
32
+ ## Step 1:合并 LoRA 权重
33
+
34
+ LoRA-Inject 训练只保存 adapter 权重,评测时需要完整模型。
35
+
36
+ ```bash
37
+ conda activate spec
38
+
39
+ python3 -c "
40
+ from peft import PeftModel
41
+ from transformers import AutoModelForCausalLM, AutoTokenizer
42
+ import torch, os
43
+
44
+ BASE = '/workspace/models/Qwen3-8B'
45
+ ADAPTER = '/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora-inject/epoch_3_step_1400'
46
+ MERGED = '/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora-inject-merged'
47
+
48
+ if os.path.exists(MERGED):
49
+ print(f'[skip] Merged model already exists: {MERGED}')
50
+ else:
51
+ print('[1/4] Loading base model to CPU ...')
52
+ model = AutoModelForCausalLM.from_pretrained(BASE, torch_dtype=torch.bfloat16, device_map='cpu')
53
+ print('[2/4] Loading LoRA adapter ...')
54
+ model = PeftModel.from_pretrained(model, ADAPTER)
55
+ print('[3/4] Merging weights ...')
56
+ model = model.merge_and_unload()
57
+ print('[4/4] Saving merged model ...')
58
+ os.makedirs(MERGED, exist_ok=True)
59
+ model.save_pretrained(MERGED, safe_serialization=True)
60
+ AutoTokenizer.from_pretrained(BASE).save_pretrained(MERGED)
61
+ print(f'Done. Merged model saved to: {MERGED}')
62
+ "
63
+ ```
64
+
65
+ > 耗时约 3–5 分钟,CPU 内存占用 ≈ 16 GB。已存在则自动跳过。
66
+
67
+ ---
68
+
69
+ ## Step 2:离线评测 accepted length
70
+
71
+ **不需要启动 sglang server**,直接跑:
72
+
73
+ ### 全部 Bench(推荐)
74
+
75
+ ```bash
76
+ bash /workspace/hanrui/syxin_old/run_bench_dflash.sh
77
+ ```
78
+
79
+ ### 单独跑 / 快速测试
80
+
81
+ ```bash
82
+ # 只跑 HumanEval
83
+ bash /workspace/hanrui/syxin_old/run_bench_dflash.sh humaneval
84
+
85
+ # 快速测试(每个 bench 20 条)
86
+ bash /workspace/hanrui/syxin_old/run_bench_dflash.sh --quick
87
+
88
+ # 指定 checkpoint
89
+ bash /workspace/hanrui/syxin_old/run_bench_dflash.sh --ckpt epoch_0_step_1000
90
+
91
+ # 组合
92
+ bash /workspace/hanrui/syxin_old/run_bench_dflash.sh humaneval gsm8k --quick
93
+ ```
94
+
95
+ ### 或者直接调 Python
96
+
97
+ ```bash
98
+ conda activate spec
99
+
100
+ python3 /workspace/hanrui/syxin_old/eval_dflash_lora_inject.py \
101
+ --benchmarks humaneval mtbench gsm8k \
102
+ --block-size 16 \
103
+ --max-new-tokens 512 \
104
+ --temperature 0.0
105
+ ```
106
+
107
+ ---
108
+
109
+ ## 结果文件说明
110
+
111
+ 结果保存在 `results/` 下,文件名示例:
112
+ ```
113
+ dflash_lora_inject_offline_epoch_3_step_1400_20260314_150000.json
114
+ ```
115
+
116
+ ```json
117
+ {
118
+ "model": "dflash-lora-inject/epoch_3_step_1400",
119
+ "block_size": 16,
120
+ "humaneval": {
121
+ "avg_accept_length": 3.42,
122
+ "total_tokens": 28500,
123
+ "latency": 120.5,
124
+ "throughput": 236.5,
125
+ "num_samples": 164,
126
+ "num_verify_rounds": 8320
127
+ },
128
+ "mtbench": { ... },
129
+ "gsm8k": { ... }
130
+ }
131
+ ```
132
+
133
+ | 字段 | 含义 |
134
+ |---|---|
135
+ | `avg_accept_length` | **核心指标**:平均每次 verify 接受的 token 数(含 injection)。越高越好,`1.0` = draft 完全无效 |
136
+ | `total_tokens` | 总生成 token 数 |
137
+ | `throughput` | tokens/s(离线评测,不含 batching 优化) |
138
+ | `num_verify_rounds` | 总验证轮数 |
139
+
140
+ ---
141
+
142
+ ## 对比 baseline
143
+
144
+ 对比未经 LoRA 训练的原始 Qwen3-8B 当 draft 的 accept_length:
145
+
146
+ ```bash
147
+ python3 /workspace/hanrui/syxin_old/eval_dflash_lora_inject.py \
148
+ --merged-path /workspace/models/Qwen3-8B \
149
+ --benchmarks humaneval mtbench gsm8k \
150
+ --num-samples 50
151
+ ```
152
+
153
+ > 这会用原始 Qwen3-8B 同时当 target 和 draft(带 injection),
154
+ > 对比 LoRA 训练前后 accept_length 是否有提升。
155
+
156
+ ---
157
+
158
+ ## 如何测其他 checkpoint
159
+
160
+ ```bash
161
+ # 方法 1:直接加载 adapter(自动 merge,不保存)
162
+ python3 /workspace/hanrui/syxin_old/eval_dflash_lora_inject.py \
163
+ --ckpt epoch_0_step_1000 \
164
+ --benchmarks humaneval --num-samples 50
165
+
166
+ # 方法 2:预先 merge 到不同目录
167
+ python3 -c "
168
+ from peft import PeftModel
169
+ from transformers import AutoModelForCausalLM, AutoTokenizer
170
+ import torch, os
171
+ BASE = '/workspace/models/Qwen3-8B'
172
+ ADAPTER = '/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora-inject/epoch_0_step_1000'
173
+ MERGED = '/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora-inject-merged-epoch_0_step_1000'
174
+ model = AutoModelForCausalLM.from_pretrained(BASE, torch_dtype=torch.bfloat16, device_map='cpu')
175
+ model = PeftModel.from_pretrained(model, ADAPTER).merge_and_unload()
176
+ os.makedirs(MERGED, exist_ok=True)
177
+ model.save_pretrained(MERGED, safe_serialization=True)
178
+ AutoTokenizer.from_pretrained(BASE).save_pretrained(MERGED)
179
+ "
180
+
181
+ python3 /workspace/hanrui/syxin_old/eval_dflash_lora_inject.py \
182
+ --merged-path .../qwen3-8b-dflash-lora-inject-merged-epoch_0_step_1000 \
183
+ --benchmarks humaneval --num-samples 50
184
+ ```
185
+
186
+ 可用 checkpoint:`epoch_0_step_500` / `epoch_0_step_1000` / `epoch_0_step_1400` / `epoch_2_step_34500` / `epoch_2_step_35000` / `epoch_3_step_1400`
187
+
188
+ ---
189
+
190
+ ## 常见问题
191
+
192
+ ### Q1:accept_length 和 STANDALONE 模式下差不多(都 ≈ 4.7)
193
+
194
+ 这说明 layer injection 没有真正起作用。检查:
195
+ - 评测脚本确实用的是 `eval_dflash_lora_inject.py`(离线),不是 sglang bench
196
+ - merged 模型确实是 LoRA-Inject 版本(不是原始 Qwen3-8B)
197
+
198
+ ### Q2:OOM(单卡放不下两个 8B 模型)
199
+
200
+ 两个 bf16 的 Qwen3-8B ≈ 32GB,单卡 H100 80GB 够用。如果 OOM:
201
+ - 检查是否有其他进程占用显存
202
+ - 减小 `--max-new-tokens`(试 256)
203
+ - 减小 `--num-samples`
204
+
205
+ ### Q3:数据集下载失败(无外网)
206
+
207
+ 评测脚本优先读本地文件:
208
+
209
+ | bench | 本地文件 |
210
+ |---|---|
211
+ | GSM8K | `/workspace/hanrui/datasets/gsm8k/test.jsonl` |
212
+ | MT-Bench | `/workspace/hanrui/datasets/mtbench/question.jsonl` |
213
+ | HumanEval | `/workspace/hanrui/datasets/humaneval/test.jsonl` |
214
+
215
+ ---
216
+
217
+ *基座:`/workspace/models/Qwen3-8B` | 最终 ckpt:`epoch_3_step_1400` | block_size:16*
syxin/eval_dflash_b16_baseline.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Offline evaluation for DFlash-b16 baseline: measure accepted length.
4
+ 8 GPUs parallel, each GPU loads target + draft independently.
5
+
6
+ Usage:
7
+ # 8 GPUs
8
+ torchrun --nproc_per_node 8 eval_dflash_b16_baseline.py
9
+
10
+ # quick test
11
+ torchrun --nproc_per_node 8 eval_dflash_b16_baseline.py --num-samples 20
12
+
13
+ # single GPU
14
+ python3 eval_dflash_b16_baseline.py --benchmarks humaneval
15
+ """
16
+ import argparse
17
+ import json
18
+ import os
19
+ import sys
20
+ import time
21
+ from typing import List, Optional, Tuple
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.distributed as dist
26
+ from tqdm import tqdm
27
+ from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, DynamicCache
28
+
29
+ # Add DFlash model path so we can import utils
30
+ sys.path.insert(0, "/workspace/models/Qwen3-8B-DFlash-b16")
31
+ from utils import extract_context_feature, sample
32
+
33
+ # ──────────────────────────────────────────────────────────────────
34
+ BASE_MODEL = "/workspace/models/Qwen3-8B"
35
+ DRAFT_MODEL = "/workspace/models/Qwen3-8B-DFlash-b16"
36
+ RESULT_DIR = "/workspace/hanrui/syxin_old/Specforge/benchmarks/results"
37
+
38
+
39
+ # ──────────────────────────────────────────────────────────────────
40
+ # Distributed helpers
41
+ # ──────────────────────────────────────────────────────────────────
42
+ def is_distributed():
43
+ return dist.is_available() and dist.is_initialized()
44
+
45
+ def get_rank():
46
+ return dist.get_rank() if is_distributed() else 0
47
+
48
+ def get_world_size():
49
+ return dist.get_world_size() if is_distributed() else 1
50
+
51
+ def is_main():
52
+ return get_rank() == 0
53
+
54
+ def print_rank0(*args, **kwargs):
55
+ if is_main():
56
+ print(*args, **kwargs)
57
+
58
+ def split_list(lst, rank, world_size):
59
+ return [x for i, x in enumerate(lst) if i % world_size == rank]
60
+
61
+
62
+ # ──────────────────────────────────────────────────────────────────
63
+ # Prompts
64
+ # ──────────────────────────────────────────────────────────────────
65
def load_prompts(bench_name: str, num_samples: Optional[int] = None) -> List[str]:
    """Load evaluation prompts for one benchmark.

    Prefers the local JSONL copies under ``/workspace/hanrui/datasets`` (the
    cluster has no external network); falls back to downloading via the
    HuggingFace ``datasets`` library when the local file is missing.

    Args:
        bench_name: One of ``"humaneval"``, ``"mtbench"``, ``"gsm8k"``.
        num_samples: Optional cap on the number of prompts returned.

    Returns:
        List of prompt strings, truncated to ``num_samples`` when given.
        An unrecognized ``bench_name`` with no local file yields an empty list.
    """
    local_paths = {
        "humaneval": "/workspace/hanrui/datasets/humaneval/test.jsonl",
        "mtbench": "/workspace/hanrui/datasets/mtbench/question.jsonl",
        "gsm8k": "/workspace/hanrui/datasets/gsm8k/test.jsonl",
    }
    prompts = []
    path = local_paths.get(bench_name)
    if path and os.path.exists(path):
        # Explicit encoding: prompts contain non-ASCII text and the platform
        # default encoding is not guaranteed to be UTF-8.
        with open(path, encoding="utf-8") as f:
            for line in f:
                # Tolerate blank/trailing lines in hand-maintained JSONL files
                # (json.loads("") would raise).
                if not line.strip():
                    continue
                item = json.loads(line)
                if bench_name == "humaneval":
                    p = f"Write a solution to the following problem and make sure that it passes the tests:\n```python\n{item['prompt']}\n```"
                elif bench_name == "mtbench":
                    p = item.get("turns", [item.get("prompt", "")])[0]
                elif bench_name == "gsm8k":
                    p = item["question"] + "\nPlease reason step by step, and put your final answer within \\boxed{}."
                else:
                    p = str(item)
                prompts.append(p)
    else:
        # Online fallback; requires network access and the `datasets` package.
        from datasets import load_dataset
        if bench_name == "humaneval":
            ds = load_dataset("openai/openai_humaneval", split="test")
            prompts = [f"Write a solution to the following problem and make sure that it passes the tests:\n```python\n{x['prompt']}\n```" for x in ds]
        elif bench_name == "mtbench":
            ds = load_dataset("HuggingFaceH4/mt_bench_prompts", split="train")
            prompts = [x["prompt"][0] for x in ds]
        elif bench_name == "gsm8k":
            ds = load_dataset("openai/gsm8k", "main", split="test")
            prompts = [x["question"] + "\nPlease reason step by step, and put your final answer within \\boxed{}." for x in ds]
    if num_samples is not None:
        prompts = prompts[:num_samples]
    return prompts
100
+
101
+
102
+ # ──────────────────────────────────────────────────────────────────
103
+ # spec_generate with acceptance_lengths returned
104
+ # (Same logic as DFlashDraftModel.spec_generate but returns accept lens)
105
+ # ──────────────────────────────────────────────────────────────────
106
+ @torch.inference_mode()
107
+ def spec_generate_b16(
108
+ draft_model,
109
+ target_model: nn.Module,
110
+ input_ids: torch.LongTensor,
111
+ max_new_tokens: int = 512,
112
+ temperature: float = 0.0,
113
+ stop_token_ids: Optional[List[int]] = None,
114
+ ) -> Tuple[torch.Tensor, List[int]]:
115
+ """Same as DFlashDraftModel.spec_generate but also returns acceptance_lengths."""
116
+ draft_model.eval()
117
+ device = target_model.device if hasattr(target_model, 'device') else input_ids.device
118
+ num_input_tokens = input_ids.shape[1]
119
+ max_length = num_input_tokens + max_new_tokens
120
+ block_size = draft_model.block_size
121
+ mask_token_id = draft_model.mask_token_id
122
+
123
+ output_ids = torch.full(
124
+ (1, max_length + block_size), mask_token_id,
125
+ dtype=torch.long, device=device,
126
+ )
127
+ position_ids = torch.arange(output_ids.shape[1], device=device).unsqueeze(0)
128
+
129
+ past_key_values_target = DynamicCache()
130
+ past_key_values_draft = DynamicCache()
131
+
132
+ # Prefill
133
+ output = target_model(
134
+ input_ids,
135
+ position_ids=position_ids[:, :num_input_tokens],
136
+ past_key_values=past_key_values_target,
137
+ use_cache=True,
138
+ logits_to_keep=1,
139
+ output_hidden_states=True,
140
+ )
141
+ output_ids[:, :num_input_tokens] = input_ids
142
+ output_ids[:, num_input_tokens:num_input_tokens + 1] = sample(output.logits, temperature)
143
+ target_hidden = extract_context_feature(output.hidden_states, draft_model.target_layer_ids)
144
+
145
+ # Decode
146
+ acceptance_lengths = []
147
+ start = num_input_tokens
148
+ while start < max_length:
149
+ block_output_ids = output_ids[:, start:start + block_size].clone()
150
+ block_position_ids = position_ids[:, start:start + block_size]
151
+ noise_embedding = target_model.model.embed_tokens(block_output_ids)
152
+
153
+ draft_logits = target_model.lm_head(
154
+ draft_model(
155
+ target_hidden=target_hidden,
156
+ noise_embedding=noise_embedding,
157
+ position_ids=position_ids[:, past_key_values_draft.get_seq_length():start + block_size],
158
+ past_key_values=past_key_values_draft,
159
+ use_cache=True,
160
+ is_causal=False,
161
+ )[:, -block_size + 1:, :]
162
+ )
163
+ past_key_values_draft.crop(start)
164
+ block_output_ids[:, 1:] = sample(draft_logits)
165
+
166
+ output = target_model(
167
+ block_output_ids,
168
+ position_ids=block_position_ids,
169
+ past_key_values=past_key_values_target,
170
+ use_cache=True,
171
+ output_hidden_states=True,
172
+ )
173
+
174
+ posterior = sample(output.logits, temperature)
175
+ acceptance_length = (
176
+ (block_output_ids[:, 1:] == posterior[:, :-1])
177
+ .cumprod(dim=1).sum(dim=1)[0].item()
178
+ )
179
+ output_ids[:, start:start + int(acceptance_length) + 1] = block_output_ids[:, :int(acceptance_length) + 1]
180
+ output_ids[:, start + int(acceptance_length) + 1] = posterior[:, int(acceptance_length)]
181
+ start += int(acceptance_length) + 1
182
+ past_key_values_target.crop(start)
183
+ target_hidden = extract_context_feature(
184
+ output.hidden_states, draft_model.target_layer_ids
185
+ )[:, :int(acceptance_length) + 1, :]
186
+ acceptance_lengths.append(int(acceptance_length) + 1)
187
+
188
+ if stop_token_ids is not None and any(
189
+ sid in output_ids[:, num_input_tokens:start] for sid in stop_token_ids
190
+ ):
191
+ break
192
+
193
+ output_ids = output_ids[:, :max_length]
194
+ output_ids = output_ids[:, output_ids[0] != mask_token_id]
195
+ if stop_token_ids is not None:
196
+ stop_t = torch.tensor(stop_token_ids, device=output_ids.device)
197
+ stop_idx = torch.isin(output_ids[0][num_input_tokens:], stop_t).nonzero(as_tuple=True)[0]
198
+ if stop_idx.numel() > 0:
199
+ output_ids = output_ids[:, :num_input_tokens + stop_idx[0] + 1]
200
+
201
+ return output_ids, acceptance_lengths
202
+
203
+
204
+ # ──────────────────────────────────────────────────────────────────
205
def parse_args():
    """Build and parse the CLI options for the DFlash-b16 baseline eval."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-model", default=BASE_MODEL)
    parser.add_argument("--draft-model", default=DRAFT_MODEL)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--benchmarks", nargs="+", default=["humaneval", "mtbench", "gsm8k"])
    parser.add_argument("--num-samples", type=int, default=None)
    parser.add_argument("--output-dir", default=RESULT_DIR)
    return parser.parse_args()
215
+
216
+
217
def main():
    """Entry point: run the DFlash-b16 baseline over the selected benchmarks.

    Expects torchrun-style env vars (LOCAL_RANK / WORLD_SIZE) for multi-GPU
    runs; falls back to a single-GPU run when they are absent. Per-GPU
    accept-length/token counts are all-reduced before reporting; rank 0
    writes a JSON result file.
    """
    args = parse_args()

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    # NCCL process group only when launched with more than one process.
    if world_size > 1:
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(local_rank)

    device = f"cuda:{local_rank}"
    rank = get_rank()  # NOTE(review): helper defined earlier in this file — assumed to return global rank

    print_rank0(f"Running DFlash-b16 baseline on {world_size} GPU(s)")

    # ── Load models ──
    print_rank0(f"Loading target: {args.base_model}")
    target_model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        torch_dtype=torch.bfloat16,
        device_map=device,
        trust_remote_code=True,
    )
    target_model.eval()

    print_rank0(f"Loading DFlash-b16 draft: {args.draft_model}")
    draft_model = AutoModel.from_pretrained(
        args.draft_model,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    ).to(device)
    draft_model.eval()

    tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
    stop_token_ids = [tokenizer.eos_token_id]

    print_rank0(f"DFlash-b16: block_size={draft_model.block_size}, "
                f"target_layer_ids={draft_model.target_layer_ids}, "
                f"num_layers={len(draft_model.layers)}")

    # ── Run benchmarks ──
    results = {"model": "Qwen3-8B-DFlash-b16", "type": "baseline",
               "block_size": draft_model.block_size}

    for bench_name in args.benchmarks:
        print_rank0(f"\n{'='*60}")
        print_rank0(f"Benchmark: {bench_name} ({world_size} GPUs)")
        print_rank0(f"{'='*60}")

        # Shard prompts across ranks; load_prompts/split_list are defined
        # earlier in this file (not visible here).
        all_prompts = load_prompts(bench_name, args.num_samples)
        my_prompts = split_list(all_prompts, rank, world_size)
        print_rank0(f"Total {len(all_prompts)} prompts, ~{len(my_prompts)} per GPU")

        local_accept_lengths = []  # accepted tokens per verify round, this rank only
        local_tokens = 0           # generated-token count, this rank only
        t0 = time.time()

        # Progress bar only on rank 0 to keep multi-GPU logs readable.
        iterator = tqdm(my_prompts, desc=f"[GPU{rank}] {bench_name}", unit="sample",
                        disable=(rank != 0))
        for prompt in iterator:
            messages = [{"role": "user", "content": prompt}]
            text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)

            output_ids, accept_lens = spec_generate_b16(
                draft_model=draft_model,
                target_model=target_model,
                input_ids=input_ids,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                stop_token_ids=stop_token_ids,
            )

            local_accept_lengths.extend(accept_lens)
            num_gen = output_ids.shape[1] - input_ids.shape[1]
            local_tokens += num_gen

            # Live running average in the progress bar (rank 0 only).
            if rank == 0 and len(local_accept_lengths) > 0:
                avg = sum(local_accept_lengths) / len(local_accept_lengths)
                iterator.set_postfix(accept_len=f"{avg:.2f}", tokens=local_tokens, gen=num_gen)

        elapsed = time.time() - t0

        # ── Gather ──
        # All-reduce scalar sums so every rank ends up with the global totals.
        if world_size > 1:
            local_sum = torch.tensor(sum(local_accept_lengths), dtype=torch.float64, device=device)
            local_count = torch.tensor(len(local_accept_lengths), dtype=torch.long, device=device)
            local_tok = torch.tensor(local_tokens, dtype=torch.long, device=device)
            dist.all_reduce(local_sum, op=dist.ReduceOp.SUM)
            dist.all_reduce(local_count, op=dist.ReduceOp.SUM)
            dist.all_reduce(local_tok, op=dist.ReduceOp.SUM)
            total_accept_sum = local_sum.item()
            total_count = local_count.item()
            total_tokens = local_tok.item()
        else:
            total_accept_sum = sum(local_accept_lengths)
            total_count = len(local_accept_lengths)
            total_tokens = local_tokens

        # max(..., 1) guards against a zero-round division.
        avg_accept_length = total_accept_sum / max(total_count, 1)
        # NOTE(review): "throughput" here divides the GLOBAL token count by this
        # rank's wall time, i.e. it is an aggregate across GPUs as printed below.
        throughput = total_tokens / elapsed if elapsed > 0 else 0

        print_rank0(f"\n{bench_name} Results:")
        print_rank0(f"  Avg Accept Length: {avg_accept_length:.3f}")
        print_rank0(f"  Total tokens: {total_tokens}")
        print_rank0(f"  Latency: {elapsed:.1f}s")
        print_rank0(f"  Throughput: {throughput:.1f} tok/s (aggregate {world_size} GPUs)")
        print_rank0(f"  Num verify rounds: {total_count}")
        print_rank0(f"  Num samples: {len(all_prompts)}")

        results[bench_name] = {
            "avg_accept_length": avg_accept_length,
            "total_tokens": total_tokens,
            "latency": elapsed,
            "throughput": throughput,
            "num_samples": len(all_prompts),
            "num_verify_rounds": total_count,
            "num_gpus": world_size,
        }

    # ── Save ──
    # Only the main process writes the timestamped JSON report.
    if is_main():
        os.makedirs(args.output_dir, exist_ok=True)
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        result_file = os.path.join(
            args.output_dir,
            f"dflash_b16_baseline_offline_{timestamp}.json",
        )
        with open(result_file, "w") as f:
            json.dump(results, f, indent=2)
        print(f"\nResults saved to: {result_file}")

    if world_size > 1:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
syxin/eval_dflash_lora_inject.py ADDED
@@ -0,0 +1,627 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Offline evaluation for DFlash-LoRA-Inject: measure accepted length & speedup.
4
+ Aligned with official DFlash benchmark.py methodology.
5
+
6
+ Unlike DFlash-b16 which uses a small 5-layer draft model with fc/hidden_norm,
7
+ LoRA-Inject uses a full Qwen3-8B with LoRA adapters that receives target hidden
8
+ states via layer-by-layer injection.
9
+
10
+ Usage:
11
+ conda activate spec
12
+
13
+ # 8 GPU parallel (default, all 10 benchmarks)
14
+ torchrun --nproc_per_node 8 eval_dflash_lora_inject.py
15
+
16
+ # single GPU
17
+ python3 eval_dflash_lora_inject.py
18
+
19
+ # specific checkpoint / benchmark
20
+ torchrun --nproc_per_node 8 eval_dflash_lora_inject.py --ckpt epoch_0_step_1000 --datasets humaneval
21
+
22
+ # quick test
23
+ torchrun --nproc_per_node 8 eval_dflash_lora_inject.py --max-samples 20
24
+ """
25
+ import argparse
26
+ import json
27
+ import os
28
+ import random
29
+ import sys
30
+ import time
31
+ import warnings
32
+ from itertools import chain
33
+ from types import SimpleNamespace
34
+ from typing import List, Optional, Tuple
35
+
36
+ import numpy as np
37
+ import torch
38
+ import torch.nn as nn
39
+ import torch.distributed as dist
40
+ from peft import PeftModel
41
+ from tqdm import tqdm
42
+ from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
43
+
44
+ # Import official dataset loader
45
+ sys.path.insert(0, "/workspace/hanrui/dflash")
46
+ from model.utils import load_and_process_dataset
47
+
48
+ # ──────────────────────────────────────────────────────────────────
49
+ # Config defaults
50
+ # ──────────────────────────────────────────────────────────────────
51
# Paths and defaults for the LoRA-Inject evaluation; all local to this cluster.
BASE_MODEL = "/workspace/models/Qwen3-8B"
ADAPTER_ROOT = "/workspace/hanrui/syxin/Specforge/outputs/qwen3-8b-dflash-lora-inject"
DEFAULT_CKPT = "epoch_3_step_1400"
MASK_TOKEN_ID = 151669  # Qwen3 <|mask|> — placeholder id for not-yet-generated positions
BLOCK_SIZE = 16  # speculative block width (draft proposes block_size-1 tokens per round)
RESULT_DIR = "/workspace/hanrui/syxin/Specforge/benchmarks/results"

# Official benchmark tasks (from run_benchmark.sh)
# Maps task name -> number of samples evaluated per task.
OFFICIAL_TASKS = {
    "gsm8k": 128,
    "math500": 128,
    "aime24": 30,
    "aime25": 30,
    "humaneval": 164,
    "mbpp": 128,
    "livecodebench": 128,
    "swe-bench": 128,
    "mt-bench": 80,
    "alpaca": 128,
}
71
+
72
+
73
+ # ──────────────────────────────────────────────────────────────────
74
+ # CUDA-synchronised timer (matches official benchmark.py)
75
+ # ──────────────────────────────────────────────────────────────────
76
def cuda_time() -> float:
    """Timestamp taken only after every queued CUDA kernel has finished.

    Ensures timing deltas measure GPU work, not just async launch time
    (matches the official benchmark.py timer).
    """
    torch.cuda.synchronize()
    timestamp = time.perf_counter()
    return timestamp
79
+
80
+
81
def has_flash_attn() -> bool:
    """Return True when the ``flash_attn`` package is importable.

    On failure a one-line warning is printed so the caller's fallback to
    SDPA attention is visible in the logs.
    """
    try:
        import flash_attn  # noqa: F401
    except ImportError:
        print("[WARN] flash_attn not installed, falling back to sdpa.")
        return False
    return True
88
+
89
+
90
+ # ──────────────────────────────────────────────────────────────────
91
+ # Distributed helpers (mirrors official distributed.py)
92
+ # ──────────────────────────────────────────────────────────────────
93
def dist_init():
    """Initialise the NCCL process group when launched under torchrun.

    A plain (non-torchrun) invocation has no RANK in the environment; in
    that case we warn and run single-process instead of raising.
    """
    if "RANK" in os.environ:
        dist.init_process_group(backend="nccl", init_method="env://")
        return
    warnings.warn("RANK not set. Skipping distributed init.")
98
+
99
def _env_int(name, default):
    # Shared helper: read an integer rank/size variable set by torchrun,
    # falling back when running single-process.
    return int(os.environ.get(name, default))

def dist_rank():
    """Global rank of this process (0 when not distributed)."""
    return _env_int("RANK", 0)

def dist_size():
    """Total number of processes (1 when not distributed)."""
    return _env_int("WORLD_SIZE", 1)

def dist_local_rank():
    """Rank of this process on its node (0 when not distributed)."""
    return _env_int("LOCAL_RANK", 0)

def dist_is_main():
    """True for the single process that should log and write results."""
    return dist_rank() == 0
110
+
111
def dist_gather(obj, dst=0):
    """Gather one picklable object per rank onto rank ``dst``.

    Returns the list of all ranks' objects on the main rank, None on the
    others. When torch.distributed was never initialised, behaves as a
    single-rank gather and returns ``[obj]``.
    """
    if not dist.is_initialized():
        return [obj]
    if not dist_is_main():
        dist.gather_object(obj, dst=dst)
        return None
    collected = [None] * dist_size()
    dist.gather_object(obj, collected, dst=dst)
    return collected
121
+
122
def print_rank0(*args, **kwargs):
    """print() that is a no-op on every rank except the main one."""
    if not dist_is_main():
        return
    print(*args, **kwargs)
125
+
126
+
127
+ # ──────────────────────────────────────────────────────────────────
128
+ # Sampling (matches official model/utils.py::sample)
129
+ # ──────────────────────────────────────────────────────────────────
130
def sample(logits: torch.Tensor, temperature: float = 0.0) -> torch.Tensor:
    """Sample token ids from ``logits`` of shape (batch, seq, vocab).

    Greedy argmax when temperature is (near) zero, otherwise multinomial
    sampling from the temperature-scaled softmax. Returns a LongTensor of
    shape (batch, seq). Matches official model/utils.py::sample.
    """
    # Treat tiny temperatures as exact greedy decoding.
    if temperature < 1e-5:
        return torch.argmax(logits, dim=-1)
    batch, length, vocab = logits.shape
    flat = logits.reshape(-1, vocab) / temperature
    distribution = torch.softmax(flat, dim=-1)
    drawn = torch.multinomial(distribution, num_samples=1)
    return drawn.view(batch, length)
138
+
139
+
140
+ # ──────────────────────────────────────────────────────────────────
141
+ # Build DFlash attention mask (vectorized, no Python loops)
142
+ # ──────────────────────────────────────────────────────────────────
143
+ def build_dflash_mask(ctx_len: int, block_size: int, device, dtype=torch.bfloat16):
144
+ """
145
+ Build DFlash attention mask for [context | block] sequence.
146
+ - Context part: standard causal
147
+ - Block part: each token sees all context + all tokens in same block (bidirectional)
148
+ """
149
+ full_len = ctx_len + block_size
150
+ neg_inf = torch.finfo(dtype).min
151
+
152
+ mask = torch.full((1, 1, full_len, full_len), neg_inf, device=device, dtype=dtype)
153
+
154
+ if ctx_len > 0:
155
+ ctx_rows = torch.arange(ctx_len, device=device)
156
+ ctx_cols = torch.arange(ctx_len, device=device)
157
+ causal = ctx_cols.unsqueeze(0) <= ctx_rows.unsqueeze(1)
158
+ mask[0, 0, :ctx_len, :ctx_len].masked_fill_(causal, 0)
159
+
160
+ if ctx_len > 0:
161
+ mask[0, 0, ctx_len:, :ctx_len] = 0
162
+ mask[0, 0, ctx_len:, ctx_len:] = 0
163
+
164
+ return mask
165
+
166
+
167
+ # ──────────────────────────────────────────────────────────────────
168
+ # Pure autoregressive generation (target model only, no draft)
169
+ # Used for AR baseline timing -- avoids inflating AR time with draft overhead.
170
+ # ──────────────────────────────────────────────────────────────────
171
@torch.inference_mode()
def ar_generate(
    target_model: nn.Module,
    input_ids: torch.LongTensor,
    max_new_tokens: int = 2048,
    mask_token_id: int = MASK_TOKEN_ID,
    temperature: float = 0.0,
    stop_token_ids: Optional[List[int]] = None,
) -> SimpleNamespace:
    """
    Pure autoregressive generation using only the target model.
    Mirrors official benchmark.py with block_size=1 (no draft model involved).
    Returns SimpleNamespace matching official dflash_generate output format.

    Assumes batch size 1 (input_ids is (1, seq)); the post-processing below
    indexes row 0 only. Timing uses cuda_time(), so this path requires CUDA.
    """
    device = input_ids.device
    num_input_tokens = input_ids.shape[1]
    max_length = num_input_tokens + max_new_tokens

    # Output buffer pre-filled with the mask id; unwritten positions are
    # stripped at the end. +1 slot so the final sampled token always fits.
    output_ids = torch.full(
        (1, max_length + 1), mask_token_id,
        dtype=torch.long, device=device,
    )
    output_ids[:, :num_input_tokens] = input_ids
    position_ids = torch.arange(output_ids.shape[1], device=device).unsqueeze(0)
    past_key_values = DynamicCache()

    # Prefill
    prefill_start = cuda_time()
    output = target_model(
        input_ids,
        position_ids=position_ids[:, :num_input_tokens],
        past_key_values=past_key_values,
        use_cache=True,
        logits_to_keep=1,  # only the last position's logits are needed
        output_hidden_states=False,
    )
    first_token = sample(output.logits, temperature)
    output_ids[:, num_input_tokens:num_input_tokens + 1] = first_token
    time_to_first_token = cuda_time() - prefill_start

    # Decode (autoregressive, one token at a time)
    decode_start = cuda_time()
    start = num_input_tokens

    while start < max_length:
        cur_token = output_ids[:, start:start + 1]
        cur_pos = position_ids[:, start:start + 1]

        output = target_model(
            cur_token,
            position_ids=cur_pos,
            past_key_values=past_key_values,
            use_cache=True,
            output_hidden_states=False,
        )

        next_token = sample(output.logits, temperature)
        start += 1
        output_ids[:, start:start + 1] = next_token
        # Keep the cache length in lockstep with `start` (mirrors the
        # speculative path, where crop actually truncates).
        past_key_values.crop(start)

        # Check stop tokens (matches official: check all generated)
        if stop_token_ids is not None and any(
            sid in output_ids[:, num_input_tokens:] for sid in stop_token_ids
        ):
            break

    # Drop the spare slot, then remove any remaining mask placeholders.
    output_ids = output_ids[:, :max_length]
    output_ids = output_ids[:, output_ids[0] != mask_token_id]
    # Truncate at the first stop token among the generated positions.
    if stop_token_ids is not None:
        stop_t = torch.tensor(stop_token_ids, device=output_ids.device)
        stop_idx = torch.isin(output_ids[0][num_input_tokens:], stop_t).nonzero(as_tuple=True)[0]
        if stop_idx.numel() > 0:
            output_ids = output_ids[:, :num_input_tokens + stop_idx[0] + 1]

    num_output_tokens = output_ids.shape[1] - num_input_tokens
    total_decode_time = cuda_time() - decode_start
    time_per_output_token = total_decode_time / max(num_output_tokens, 1)

    return SimpleNamespace(
        output_ids=output_ids,
        num_input_tokens=num_input_tokens,
        num_output_tokens=num_output_tokens,
        time_to_first_token=time_to_first_token,
        time_per_output_token=time_per_output_token,
        acceptance_lengths=[1] * max(num_output_tokens, 0),  # AR: always 1
    )
258
+
259
+
260
+ # ──────────────────────────────────────────────────────────────────
261
+ # Core: spec_generate with layer-by-layer injection (KV-cached)
262
+ # ──────────────────────────────────────────────────────────────────
263
@torch.inference_mode()
def spec_generate_inject(
    target_model: nn.Module,
    draft_model: nn.Module,
    input_ids: torch.LongTensor,
    max_new_tokens: int = 2048,
    block_size: int = 16,
    mask_token_id: int = MASK_TOKEN_ID,
    temperature: float = 0.0,
    stop_token_ids: Optional[List[int]] = None,
) -> SimpleNamespace:
    """
    Speculative generation using DFlash-LoRA-Inject inference pattern.
    Returns SimpleNamespace matching official dflash_generate output format.

    Per round: the draft (LoRA-merged full model) fills a block of mask
    tokens in one bidirectional pass, conditioning each of its layers on the
    target model's per-layer hidden states for the accepted context
    ("layer-by-layer injection"); the target then verifies the block in a
    single forward. Assumes batch size 1; requires CUDA (cuda_time).
    Assumes draft and target share the same layer count — TODO confirm.
    """
    device = input_ids.device
    num_input_tokens = input_ids.shape[1]
    max_length = num_input_tokens + max_new_tokens

    draft_layers = draft_model.model.layers
    draft_norm = draft_model.model.norm
    draft_lm_head = draft_model.lm_head
    rotary_emb = draft_model.model.rotary_emb
    num_layers = len(draft_layers)

    # Buffer padded by block_size so a full block plus the bonus token can
    # always be written near max_length without bounds checks.
    output_ids = torch.full(
        (1, max_length + block_size), mask_token_id,
        dtype=torch.long, device=device,
    )
    output_ids[:, :num_input_tokens] = input_ids

    # ── Prefill: target with KV cache + hidden states ──
    prefill_start = cuda_time()
    target_kv = DynamicCache()
    target_output = target_model(
        input_ids,
        past_key_values=target_kv,
        use_cache=True,
        output_hidden_states=True,
    )
    first_token = sample(target_output.logits[:, -1:, :], temperature)
    output_ids[:, num_input_tokens] = first_token.squeeze()

    # hidden_states[0] is the embedding output; index i+1 is the output of
    # decoder layer i — that is what gets injected into draft layer i.
    ctx_hidden_per_layer = [
        target_output.hidden_states[i + 1]
        for i in range(num_layers)
    ]

    time_to_first_token = cuda_time() - prefill_start

    # Decode
    decode_start = cuda_time()
    acceptance_lengths = []
    start = num_input_tokens
    draft_prefill = True

    while start < max_length:
        # Last block may be shorter than block_size.
        end = min(start + block_size, max_length)
        actual_block_size = end - start

        # Slot 0 holds the already-committed token; the rest are mask ids.
        block_ids = output_ids[:, start:end].clone()

        # ── Draft: forward with layer-by-layer injection ──
        draft_hidden = draft_model.model.embed_tokens(block_ids)
        ctx_len = ctx_hidden_per_layer[0].shape[1]

        # Causal over context, bidirectional within the block.
        dflash_mask = build_dflash_mask(ctx_len, actual_block_size, device)
        combined_pos = torch.arange(ctx_len + actual_block_size, device=device).unsqueeze(0)

        # rotary_emb only reads dtype/device from its first arg, so an
        # uninitialised tensor of the right shape is sufficient here.
        dummy_combined = torch.empty(1, ctx_len + actual_block_size, draft_hidden.shape[-1],
                                     device=device, dtype=torch.bfloat16)
        position_embeddings = rotary_emb(dummy_combined, combined_pos)

        for layer_idx in range(num_layers):
            # Prepend the target's layer-idx context hiddens, run the draft
            # layer over [context | block], keep only the block outputs.
            target_ctx = ctx_hidden_per_layer[layer_idx]
            combined = torch.cat([target_ctx, draft_hidden], dim=1)

            layer_output = draft_layers[layer_idx](
                combined,
                attention_mask=dflash_mask,
                position_ids=combined_pos,
                position_embeddings=position_embeddings,
            )
            if isinstance(layer_output, tuple):
                layer_output = layer_output[0]
            draft_hidden = layer_output[:, ctx_len:, :]

        draft_hidden = draft_norm(draft_hidden)
        draft_logits = draft_lm_head(draft_hidden)

        # Position i predicts token i+1, so the draft fills slots 1..end.
        draft_predictions = sample(draft_logits[:, :-1, :], temperature)
        block_ids[:, 1:actual_block_size] = draft_predictions[:, :actual_block_size - 1]

        # Exclude draft's first prefill from decode timing (matches official pattern)
        if draft_prefill:
            draft_prefill = False
            decode_start = cuda_time()

        # ── Verify: target forward on block tokens (with KV cache) ──
        position_ids_block = torch.arange(
            start, start + actual_block_size, device=device
        ).unsqueeze(0)

        target_verify = target_model(
            block_ids,
            position_ids=position_ids_block,
            past_key_values=target_kv,
            use_cache=True,
            output_hidden_states=True,
        )
        target_tokens = sample(target_verify.logits, temperature)

        # Acceptance
        # Longest prefix of draft tokens that matches the target's own choices.
        matches = (block_ids[:, 1:actual_block_size] == target_tokens[:, :actual_block_size - 1])
        acceptance_length = int(matches.cumprod(dim=1).sum(dim=1)[0].item())

        # Commit accepted prefix, plus the target's "bonus" token after it.
        output_ids[:, start:start + acceptance_length + 1] = block_ids[:, :acceptance_length + 1]
        output_ids[:, start + acceptance_length + 1] = target_tokens[:, acceptance_length]

        # Roll the target KV cache back to the accepted length.
        accepted_end = start + acceptance_length + 1
        target_kv.crop(accepted_end)

        # Extend the injected context with the accepted positions' hiddens.
        for i in range(num_layers):
            new_hidden = target_verify.hidden_states[i + 1][:, :acceptance_length + 1, :]
            ctx_hidden_per_layer[i] = torch.cat([ctx_hidden_per_layer[i], new_hidden], dim=1)

        start += acceptance_length + 1
        acceptance_lengths.append(acceptance_length + 1)

        # Official: check ALL generated tokens
        if stop_token_ids is not None and any(
            sid in output_ids[:, num_input_tokens:] for sid in stop_token_ids
        ):
            break

    # Trim to generated length, drop remaining mask placeholders, then
    # truncate at the first stop token among the generated positions.
    output_ids = output_ids[:, :min(start, max_length)]
    output_ids = output_ids[:, output_ids[0] != mask_token_id]
    if stop_token_ids is not None:
        stop_t = torch.tensor(stop_token_ids, device=output_ids.device)
        stop_idx = torch.isin(output_ids[0][num_input_tokens:], stop_t).nonzero(as_tuple=True)[0]
        if stop_idx.numel() > 0:
            output_ids = output_ids[:, :num_input_tokens + stop_idx[0] + 1]

    num_output_tokens = output_ids.shape[1] - num_input_tokens
    total_decode_time = cuda_time() - decode_start
    time_per_output_token = total_decode_time / max(num_output_tokens, 1)

    return SimpleNamespace(
        output_ids=output_ids,
        num_input_tokens=num_input_tokens,
        num_output_tokens=num_output_tokens,
        time_to_first_token=time_to_first_token,
        time_per_output_token=time_per_output_token,
        acceptance_lengths=acceptance_lengths,
    )
418
+
419
+
420
+ # ──────────────────────────────────────────────────────────────────
421
+ # Main
422
+ # ──────────────────────────────────────────────────────────────────
423
def parse_args():
    """Parse CLI options for the offline LoRA-Inject evaluation."""
    parser = argparse.ArgumentParser(description="Offline eval for DFlash-LoRA-Inject (aligned with official)")
    parser.add_argument("--base-model", default=BASE_MODEL)
    parser.add_argument("--adapter-root", default=ADAPTER_ROOT)
    parser.add_argument("--ckpt", default=DEFAULT_CKPT, help="Checkpoint folder name")
    parser.add_argument(
        "--merged-path",
        default="/workspace/hanrui/syxin/Specforge/outputs/qwen3-8b-dflash-lora-inject-merged",
        help="Path to pre-merged model. If None, will merge on the fly.",
    )
    parser.add_argument("--block-size", type=int, default=BLOCK_SIZE)
    parser.add_argument(
        "--max-new-tokens", type=int, default=2048,
        help="Max new tokens per turn (official shell uses 2048)",
    )
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument(
        "--datasets", nargs="+", default=list(OFFICIAL_TASKS.keys()),
        help="Benchmarks to run (default: all 10 official tasks)",
    )
    parser.add_argument(
        "--max-samples", type=int, default=None,
        help="Override max samples per dataset (None = use official per-task counts)",
    )
    parser.add_argument("--output-dir", default=RESULT_DIR)
    return parser.parse_args()
441
+
442
+
443
def main():
    """Entry point: benchmark DFlash-LoRA-Inject against an AR baseline.

    For every sample we run both pure autoregressive decoding and the
    speculative inject path on the same input, so per-dataset speedup is
    t_AR / t_spec under identical conditions. Multi-GPU sharding follows
    the official benchmark.py (round-robin indices, gather to rank 0).
    """
    args = parse_args()

    # Fix random seeds (matches official)
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # ── Init distributed ──
    dist_init()
    torch.cuda.set_device(dist_local_rank())
    device = torch.device(f"cuda:{dist_local_rank()}")

    print_rank0(f"Running on {dist_size()} GPU(s)")

    # Detect flash_attn (only for target model; draft needs sdpa for custom DFlash mask)
    installed_flash_attn = has_flash_attn()
    target_attn_impl = "flash_attention_2" if installed_flash_attn else "sdpa"
    draft_attn_impl = "sdpa"  # DFlash injection uses custom attention mask
    print_rank0(f"Using attn_implementation: target={target_attn_impl}, draft={draft_attn_impl}")

    # ── Load models ──
    print_rank0(f"Loading target model: {args.base_model}")
    target_model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        torch_dtype=torch.bfloat16,
        attn_implementation=target_attn_impl,
        device_map=device,
        trust_remote_code=True,
    )
    target_model.eval()

    # Draft: prefer a pre-merged checkpoint; otherwise merge LoRA on the fly.
    if args.merged_path and os.path.isdir(args.merged_path):
        print_rank0(f"Loading pre-merged draft model: {args.merged_path}")
        draft_model = AutoModelForCausalLM.from_pretrained(
            args.merged_path,
            torch_dtype=torch.bfloat16,
            attn_implementation=draft_attn_impl,
            device_map=device,
            trust_remote_code=True,
        )
    else:
        adapter_path = os.path.join(args.adapter_root, args.ckpt)
        print_rank0(f"Loading base + LoRA adapter: {adapter_path}")
        draft_model = AutoModelForCausalLM.from_pretrained(
            args.base_model,
            torch_dtype=torch.bfloat16,
            attn_implementation=draft_attn_impl,
            device_map=device,
            trust_remote_code=True,
        )
        draft_model = PeftModel.from_pretrained(draft_model, adapter_path)
        # Fold LoRA weights into the base so inference has no adapter overhead.
        draft_model = draft_model.merge_and_unload()
    draft_model.eval()

    tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
    stop_token_ids = [tokenizer.eos_token_id]

    block_size = args.block_size

    # ── Run benchmarks ──
    all_results = {"model": f"dflash-lora-inject/{args.ckpt}", "block_size": block_size}

    for dataset_name in args.datasets:
        print_rank0(f"\n{'=' * 60}")
        print_rank0(f"Benchmark: {dataset_name} ({dist_size()} GPUs)")
        print_rank0(f"{'=' * 60}")

        # Load dataset using official loader
        dataset = load_and_process_dataset(dataset_name)

        # Sample selection: official uses shuffle(seed=0).select()
        max_samples = args.max_samples if args.max_samples is not None else OFFICIAL_TASKS.get(dataset_name)
        if max_samples is not None and len(dataset) > max_samples:
            dataset = dataset.shuffle(seed=0).select(range(max_samples))

        print_rank0(f"Total {len(dataset)} samples, distributed across {dist_size()} GPUs")

        responses = []
        # Round-robin sharding: rank r handles indices r, r+world, r+2*world, ...
        indices = range(dist_rank(), len(dataset), dist_size())

        iterator = tqdm(indices, desc=f"[GPU{dist_rank()}] {dataset_name}",
                        unit="sample", disable=not dist_is_main())

        for idx in iterator:
            instance = dataset[idx]

            # Multi-turn support (matches official benchmark.py)
            messages = []
            for turn_index, user_content in enumerate(instance["turns"]):
                messages.append({"role": "user", "content": user_content})
                input_text = tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True,
                    enable_thinking=False,
                )
                input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

                # Keyed by "block size": 1 = AR baseline, block_size = speculative.
                response = {}

                # AR baseline: pure target-only autoregressive (no draft overhead)
                response[1] = ar_generate(
                    target_model=target_model,
                    input_ids=input_ids,
                    max_new_tokens=args.max_new_tokens,
                    mask_token_id=MASK_TOKEN_ID,
                    temperature=args.temperature,
                    stop_token_ids=stop_token_ids,
                )

                # Speculative: DFlash-LoRA-Inject
                response[block_size] = spec_generate_inject(
                    target_model=target_model,
                    draft_model=draft_model,
                    input_ids=input_ids,
                    max_new_tokens=args.max_new_tokens,
                    block_size=block_size,
                    mask_token_id=MASK_TOKEN_ID,
                    temperature=args.temperature,
                    stop_token_ids=stop_token_ids,
                )

                # Append assistant response for multi-turn context
                spec_response = response[block_size]
                generated_ids = spec_response.output_ids[0, spec_response.num_input_tokens:]
                output_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
                messages.append({"role": "assistant", "content": output_text})
                responses.append(response)

            # Rolling accept-length over the last few responses for the bar.
            if dist_is_main() and responses:
                recent_tau = np.mean([np.mean(r[block_size].acceptance_lengths) for r in responses[-5:]])
                iterator.set_postfix(accept_len=f"{recent_tau:.2f}")

        # ── Gather to rank 0 (matches official) ──
        if dist_size() > 1:
            gathered = dist_gather(responses, dst=0)
            if not dist_is_main():
                continue
            responses = list(chain(*gathered))
        elif not dist_is_main():
            continue

        # ── Compute metrics (exact official formulas) ──
        t1 = np.mean([r[1].time_per_output_token for r in responses])
        tb = np.mean([r[block_size].time_per_output_token for r in responses])
        speedup = t1 / tb if tb > 0 else 0

        # Acceptance length: per-sample mean, then mean of means (official)
        tau = np.mean([np.mean(r[block_size].acceptance_lengths) for r in responses])

        # Histogram
        acceptance_lengths = list(chain(*[r[block_size].acceptance_lengths for r in responses]))
        histogram = [acceptance_lengths.count(b) / len(acceptance_lengths) for b in range(block_size + 1)]

        print_rank0(f"\n{dataset_name} Results:")
        print_rank0(f"  Decoding speedup: {speedup:.2f}x")
        print_rank0(f"  Average Acceptance length: {tau:.2f}")
        print_rank0(f"  Acceptance length histogram: {[f'{x * 100:.1f}%' for x in histogram]}")
        print_rank0(f"  Num responses: {len(responses)}")

        all_results[dataset_name] = {
            "decoding_speedup": speedup,
            "avg_accept_length": tau,
            "acceptance_histogram": histogram,
            "num_responses": len(responses),
            "num_gpus": dist_size(),
        }

    # ── Save results ──
    # Rank 0 writes one timestamped JSON covering all requested datasets.
    if dist_is_main():
        os.makedirs(args.output_dir, exist_ok=True)
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        result_file = os.path.join(
            args.output_dir,
            f"dflash_lora_inject_offline_{args.ckpt}_{timestamp}.json",
        )
        with open(result_file, "w") as f:
            json.dump(all_results, f, indent=2)
        print(f"\nResults saved to: {result_file}")


if __name__ == "__main__":
    main()
syxin/idea.md ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 现在关于target model的hidden state注入
2
+
3
+ dflash的做法是,抽5层的feature过一下fc然后concat到mask token对应的hidden state前面
4
+
5
+ 但是如果我们的draft是用lora的原始模型
6
+
7
+ 我们不用这样注入
8
+
9
+ 我们可以直接把target model的hidden state直接层对层拉过来
10
+
11
+ 我是把加了lora后的模型作为draft model用的
12
+
13
+ 它本质上还是一个speculative decode
14
+
15
+ 我的想法的核心是,因为这个draft model足够大,也和target model足够像,把他转为和dflash一样每次用mask直接生成16个token,可能能得到很长的accept len,以此获得加速
16
+
17
+ 而dflash能work的核心是,它在生成阶段是使用的部分target model的hidden state,注入到mask token的hidden state前面
18
+
19
+ 我们也用相同的做法
20
+
21
+ 带lora的模型,lora只负责让它能并行解码16个mask token,但是前面的上下文信息,依然用原始model跑出来的,通过注入放进draft的时候
22
+
23
+ 而且由于模型结构的一致,我们可以直接层对层注入进去
syxin/launch_train.sh ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Launch single-node (8 GPU) DFlash-LoRA-inject training via torchrun.
set -euo pipefail

cd /workspace/hanrui/syxin/Specforge

# Cache locations and allocator tuning shared by every rank.
export TORCHINDUCTOR_CACHE_DIR=/workspace/hanrui/cache/compiled_kernels
export SPECFORGE_DATA_NUM_PROC=16
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export PYTORCH_ALLOC_CONF=expandable_segments:True
export HF_DATASETS_CACHE=/workspace/hanrui/cache/hf_datasets
export HF_HOME=/workspace/hanrui/cache/hf_home

# Keep the training arguments in one array for readability.
TRAIN_ARGS=(
    --target-model-path /workspace/models/Qwen3-8B
    --target-model-backend hf
    --train-data-path /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K
    --output-dir outputs/qwen3-8b-sft-32gpu-v2
    --block-size 16
    --attention-backend additive
    --attn-implementation sdpa
    --max-length 2048
    --batch-size 4
    --accumulation-steps 8
    --num-epochs 3
    --learning-rate 5e-5
    --loss-decay-gamma 7
    --gradient-checkpointing
    --chat-template qwen
    --log-interval 50
    --save-interval 500
    --cache-dir /workspace/hanrui/cache
    --lora-rank 32
    --lora-alpha 64
    --lora-dropout 0.1
    --trust-remote-code
    --dataloader-num-workers 0
)

torchrun --nproc_per_node=8 scripts/train_dflash_lora_inject.py "${TRAIN_ARGS[@]}"
syxin/launch_train_wrapper.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""
Python wrapper to launch bash training script via torchrun
"""
import os
import subprocess
import sys


def _launch() -> int:
    """Run the multinode training bash script, forwarding all CLI args.

    Returns the bash script's exit code so the caller can propagate it.
    """
    bash_script = "/workspace/hanrui/syxin/run_train_multinode.sh"
    command = ["bash", bash_script, *sys.argv[1:]]
    # Inherit the full environment (torchrun/northjob distributed vars).
    completed = subprocess.run(command, env=os.environ.copy())
    return completed.returncode


if __name__ == "__main__":
    sys.exit(_launch())
syxin/list.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### 1. `train_dflash_lora.py`
2
+ * 加了lora,原来是调用小模型,现在是hidden states+lora预测。
3
+ * `dflash_lora_mask_fn`函数是在处理预测的那一块草稿Block时,可以同时看到这一块里的所有词。
4
+
5
+ ### 2. OOM优化
6
+ * 分片策略ZeRO-3,FSDP切分从`SHARD_GRAD_OP`升级到`FULL_SHARD`。
7
+ * `batch-size=1`,`accumulation-steps=8`。
8
+ * 参考之前的代码用了FlexAttention(`dflash_lora_mask_fn`)。
9
+ * `_chunked_lm_loss()`,把算loss切片成256块来算+梯度检查。
10
+
11
+ ### 运行
12
+ * bash /workspace/hanrui/junquan/SpecForge/scripts/run_train_dflash_lora.sh 2
syxin/merge_lora.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Step 1: Merge DFlash-LoRA adapter into base model.
3
+ Usage:
4
+ conda activate sglang
5
+ python3 merge_lora.py
6
+ python3 merge_lora.py --ckpt epoch_2_step_15000 # 测其他 checkpoint
7
+ """
8
+ import argparse
9
+ import os
10
+
11
+ import torch
12
+ from peft import PeftModel
13
+ from transformers import AutoModelForCausalLM, AutoTokenizer
14
+
15
+ BASE_MODEL = "/workspace/models/Qwen3-8B"
16
+ OUTPUT_ROOT = "/workspace/hanrui/syxin/Specforge/outputs/qwen3-8b-dflash-lora"
17
+ MERGE_ROOT = "/workspace/hanrui/syxin/Specforge/outputs/qwen3-8b-dflash-lora-merged"
18
+
19
+ def parse_args():
20
+ p = argparse.ArgumentParser()
21
+ p.add_argument("--ckpt", default="epoch_3_step_18576",
22
+ help="Checkpoint folder name under OUTPUT_ROOT")
23
+ p.add_argument("--merged-path", default=MERGE_ROOT,
24
+ help="Where to save the merged model")
25
+ return p.parse_args()
26
+
27
+
28
+ def main():
29
+ args = parse_args()
30
+ adapter_path = os.path.join(OUTPUT_ROOT, args.ckpt)
31
+ merged_path = args.merged_path
32
+
33
+ if os.path.exists(merged_path):
34
+ print(f"[skip] Merged model already exists: {merged_path}")
35
+ return
36
+
37
+ assert os.path.isdir(adapter_path), f"Adapter not found: {adapter_path}"
38
+
39
+ print(f"Base model : {BASE_MODEL}")
40
+ print(f"Adapter : {adapter_path}")
41
+ print(f"Output : {merged_path}")
42
+ print()
43
+
44
+ print("[1/4] Loading base model to CPU ...")
45
+ model = AutoModelForCausalLM.from_pretrained(
46
+ BASE_MODEL,
47
+ torch_dtype=torch.bfloat16,
48
+ device_map="cpu",
49
+ )
50
+
51
+ print("[2/4] Loading LoRA adapter ...")
52
+ model = PeftModel.from_pretrained(model, adapter_path)
53
+
54
+ print("[3/4] Merging weights ...")
55
+ model = model.merge_and_unload()
56
+
57
+ print("[4/4] Saving merged model ...")
58
+ os.makedirs(merged_path, exist_ok=True)
59
+ model.save_pretrained(merged_path, safe_serialization=True)
60
+ AutoTokenizer.from_pretrained(BASE_MODEL).save_pretrained(merged_path)
61
+
62
+ print(f"\nDone. Merged model saved to: {merged_path}")
63
+
64
+
65
+ if __name__ == "__main__":
66
+ main()
syxin/oom_fix_progress.md ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DFlash LoRA OOM 修复记录
2
+
3
+ ## OOM 根因分析
4
+
5
+ 1. **SHARD_GRAD_OP (ZeRO-2)** — 每卡持有完整 Qwen3-8B 参数 (~16GB bf16),参数未分片
6
+ 2. **SDPA + 4D additive mask** — FlashAttention 不支持 4D additive mask,fallback 到 math backend,每层 materialize 完整 attention scores (`bsz × 32heads × 2048 × 2048`)
7
+ 3. **大 vocab logits** — `[bsz, 2048, 151936]` bf16 ≈ 1.18GB,加上梯度和 boolean indexing 拷贝,峰值 ~3-4GB
8
+ 4. **机器只有 2 张 H100**,脚本默认 `NUM_GPUS=4`
9
+
10
+ ## 已完成的改动
11
+
12
+ ### 1. FSDP sharding 改为 FULL_SHARD (ZeRO-3)
13
+ - 文件: `SpecForge/scripts/train_dflash_lora.py:347`
14
+ - `ShardingStrategy.SHARD_GRAD_OP` → `ShardingStrategy.FULL_SHARD`
15
+ - 效果: 参数跨卡分片,每卡省 ~8-12GB
16
+
17
+ ### 2. 降 batch-size,提高 accumulation-steps
18
+ - 文件: `SpecForge/scripts/run_train_dflash_lora.sh`
19
+ - `--batch-size 2` → `1`,`--accumulation-steps 4` → `8`
20
+ - 效果: 等效 global batch size 不变,峰值显存减半
21
+
22
+ ## 待验证 / 后续优化
23
+
24
+ - [ ] 运行时传 `bash run_train_dflash_lora.sh 2` 确保用 2 卡
25
+ - [x] 如仍 OOM,考虑 chunked cross-entropy loss 避免大 vocab logits 全量 materialize
26
+ - [x] 长期可探索自定义 attention kernel 支持 block-sparse mask,绕过 SDPA math fallback
27
+
28
+ ### 3. flex_attention + BlockMask 替换 4D additive mask
29
+ - 文件: `SpecForge/specforge/core/dflash_lora.py`, `specforge/modeling/draft/dflash_lora.py`, `scripts/train_dflash_lora.py`
30
+ - 从非 LoRA 版 `dflash.py` 移植 `_get_or_create_block_mask()` 方法,适配 LoRA 场景 (Q_LEN == KV_LEN == seq_len)
31
+ - LoRA 版 mask: context causal + block bidirectional (非 LoRA 版是 [context, noise] concat KV)
32
+ - 用 `--attention-backend flex_attention` 启用 (默认),退回 `--attention-backend additive` 走原有 4D mask
33
+ - HuggingFace model 用 `attn_implementation="flex_attention"` 加载
34
+ - 效果: 不再 fallback 到 SDPA math backend,省去 `[bsz, heads, seq, seq]` attention scores 的显存
35
+
36
+ ### 4. chunked cross-entropy loss
37
+ - 文件: `SpecForge/specforge/core/dflash_lora.py`, `specforge/modeling/draft/dflash_lora.py`, `scripts/train_dflash_lora.py`
38
+ - 从非 LoRA 版 `dflash.py` 移植 `_chunked_lm_loss()` 方法
39
+ - 分 chunk 过 lm_head + CE loss + gradient checkpointing,避免 materialize 完整 `[bsz, seq, vocab]` logits
40
+ - 用 `--lm-head-chunk-size 256` 启用 (默认 0 = 不启用)
41
+ - `DFlashLoRADraftModel.forward()` 新增 `output_hidden_states` 参数,chunked 时返回 hidden states
42
+ - 效果: logits 峰值显存从 O(seq_len × vocab_size) 降至 O(chunk_size × vocab_size)
syxin/requirements.txt ADDED
File without changes
syxin/run_bench.sh ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Step 3: Run HumanEval / MT-Bench / GSM8K benchmarks.
# Run AFTER start_server.sh is up.
# Usage:
#   bash run_bench.sh                # all three benches, full dataset
#   bash run_bench.sh humaneval      # only humaneval
#   bash run_bench.sh mtbench gsm8k  # pick any subset

set -e

INTRANET_IP=10.1.1.131
PORT=30000
BASE_MODEL=/workspace/models/Qwen3-8B
MERGED=/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora-merged
BENCH_DIR=/workspace/hanrui/syxin_old/Specforge/benchmarks
RESULT_DIR=$BENCH_DIR/results

# ---- sanity check ----
echo "Checking server at http://$INTRANET_IP:$PORT ..."
curl -sf "http://$INTRANET_IP:$PORT/v1/models" > /dev/null || {
    echo "[ERROR] Server not reachable. Start it first: bash start_server.sh"
    exit 1
}
echo "Server OK."

mkdir -p "$RESULT_DIR"
cd "$BENCH_DIR"
export PYTHONPATH=/workspace/hanrui/syxin_old/Specforge:$PYTHONPATH

# ---- decide which benches to run ----
TARGETS=("$@")
if [ ${#TARGETS[@]} -eq 0 ]; then
    TARGETS=(humaneval mtbench gsm8k)
fi

# Build the benchmark spec list as a bash array instead of a space-joined
# string: each entry survives as a single argv word regardless of IFS.
BENCH_ARGS=()
for t in "${TARGETS[@]}"; do
    case $t in
        humaneval) BENCH_ARGS+=(humaneval:164) ;;
        mtbench)   BENCH_ARGS+=(mtbench:80) ;;
        gsm8k)     BENCH_ARGS+=(gsm8k:1319) ;;
        *)
            echo "[ERROR] Unknown bench: $t (choices: humaneval mtbench gsm8k)"
            exit 1
            ;;
    esac
done

TIMESTAMP=$(date +%Y%m%d_%H%M%S)
echo "Running: ${BENCH_ARGS[*]}"
echo "Results -> $RESULT_DIR"
echo ""

python3 bench_eagle3.py \
    --model-path "$BASE_MODEL" \
    --speculative-draft-model-path "$MERGED" \
    --host "$INTRANET_IP" \
    --port "$PORT" \
    --config-list "16,4,1,4" \
    --benchmark-list "${BENCH_ARGS[@]}" \
    --output-dir "$RESULT_DIR" \
    --name "dflash_lora_${TIMESTAMP}" \
    --skip-launch-server \
    2>&1 | tee "$RESULT_DIR/bench_${TIMESTAMP}.log"

echo ""
echo "Done. Latest result files:"
ls -lht "$RESULT_DIR"/*.jsonl 2>/dev/null | head -5
syxin/run_bench_dflash.sh ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Evaluate DFlash-LoRA-Inject accepted length (offline, 8 GPUs parallel).
# No sglang server needed. Each GPU loads its own target+draft and processes a shard.
#
# Usage:
#   bash run_bench_dflash.sh                         # 8 GPUs, all 3 benches
#   bash run_bench_dflash.sh humaneval               # only humaneval
#   bash run_bench_dflash.sh mtbench gsm8k           # pick any subset
#   bash run_bench_dflash.sh --quick                 # quick test (20 samples)
#   bash run_bench_dflash.sh --ckpt epoch_0_step_500 # specific checkpoint
#   NUM_GPUS=4 bash run_bench_dflash.sh              # use 4 GPUs

set -e

SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)
PYTHON=/workspace/miniconda3/envs/spec/bin/python3
RESULT_DIR=/workspace/hanrui/syxin_old/Specforge/benchmarks/results
NUM_GPUS=${NUM_GPUS:-8}

# ---- parse args ----
BENCHMARKS=()
EXTRA_ARGS=()
QUICK=false

for arg in "$@"; do
    case $arg in
        humaneval|mtbench|gsm8k)
            BENCHMARKS+=("$arg")
            ;;
        --quick)
            QUICK=true
            ;;
        *)
            EXTRA_ARGS+=("$arg")
            ;;
    esac
done

if [ ${#BENCHMARKS[@]} -eq 0 ]; then
    BENCHMARKS=(humaneval mtbench gsm8k)
fi

if [ "$QUICK" = true ]; then
    EXTRA_ARGS+=(--num-samples 20)
fi

TIMESTAMP=$(date +%Y%m%d_%H%M%S)

echo "============================================"
echo " DFlash-LoRA-Inject Offline Eval"
echo " GPUs       : $NUM_GPUS"
echo " benchmarks : ${BENCHMARKS[*]}"
echo " extra args : ${EXTRA_ARGS[*]}"
echo " results    : $RESULT_DIR"
echo "============================================"
echo ""

mkdir -p "$RESULT_DIR"

# Quote the array expansion so every benchmark name is passed as its own
# intact argv word (unquoted ${BENCHMARKS[@]} is subject to word splitting
# and globbing).
"$PYTHON" -m torch.distributed.run \
    --standalone \
    --nproc_per_node "$NUM_GPUS" \
    "$SCRIPT_DIR/eval_dflash_lora_inject.py" \
    --benchmarks "${BENCHMARKS[@]}" \
    --output-dir "$RESULT_DIR" \
    "${EXTRA_ARGS[@]}" \
    2>&1 | tee "$RESULT_DIR/bench_dflash_lora_inject_offline_${TIMESTAMP}.log"

echo ""
echo "Done. Latest result files:"
ls -lht "$RESULT_DIR"/*.json 2>/dev/null | head -5
syxin/run_bench_dflash_b16_baseline.sh ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# DFlash-b16 baseline: measure accepted length offline, 8 GPUs parallel.
# Usage:
#   bash run_bench_dflash_b16_baseline.sh              # 8 GPUs, all 3 benches
#   bash run_bench_dflash_b16_baseline.sh humaneval    # only humaneval
#   bash run_bench_dflash_b16_baseline.sh --quick      # 20 samples per bench
#   NUM_GPUS=4 bash run_bench_dflash_b16_baseline.sh   # 4 GPUs

set -e

SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)
PYTHON=/workspace/miniconda3/envs/spec/bin/python3
RESULT_DIR=/workspace/hanrui/syxin_old/Specforge/benchmarks/results
NUM_GPUS=${NUM_GPUS:-8}

BENCHMARKS=()
EXTRA_ARGS=()
QUICK=false

for arg in "$@"; do
    case $arg in
        humaneval|mtbench|gsm8k) BENCHMARKS+=("$arg") ;;
        --quick) QUICK=true ;;
        *) EXTRA_ARGS+=("$arg") ;;
    esac
done

if [ ${#BENCHMARKS[@]} -eq 0 ]; then
    BENCHMARKS=(humaneval mtbench gsm8k)
fi

if [ "$QUICK" = true ]; then
    EXTRA_ARGS+=(--num-samples 20)
fi

TIMESTAMP=$(date +%Y%m%d_%H%M%S)

echo "============================================"
echo " DFlash-b16 Baseline Offline Eval"
echo " GPUs       : $NUM_GPUS"
echo " draft      : /workspace/models/Qwen3-8B-DFlash-b16"
echo " benchmarks : ${BENCHMARKS[*]}"
echo " extra args : ${EXTRA_ARGS[*]}"
echo "============================================"
echo ""

mkdir -p "$RESULT_DIR"

# Quote the array expansion so every benchmark name is passed as its own
# intact argv word (unquoted ${BENCHMARKS[@]} undergoes word splitting).
"$PYTHON" -m torch.distributed.run \
    --standalone \
    --nproc_per_node "$NUM_GPUS" \
    "$SCRIPT_DIR/eval_dflash_b16_baseline.py" \
    --benchmarks "${BENCHMARKS[@]}" \
    --output-dir "$RESULT_DIR" \
    "${EXTRA_ARGS[@]}" \
    2>&1 | tee "$RESULT_DIR/bench_dflash_b16_baseline_${TIMESTAMP}.log"

echo ""
echo "Done. Latest result files:"
ls -lht "$RESULT_DIR"/*.json 2>/dev/null | head -5
syxin/run_qwen3_8b_sft_32gpu.sh ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Submit the qwen3-8b SFT job to the cluster through northjob.
export JOB_NAME='qwen3-8b-sft'
export GPU_NUMS=64
export TRAIN_SCRIPT='/workspace/hanrui/syxin/launch_train_wrapper.py'
export WORK_DIR='/workspace/hanrui/syxin/Specforge'

# Derive topology: a single node when fewer than 8 GPUs are requested,
# otherwise spread across nodes at 8 GPUs each.
if [ $GPU_NUMS -lt 8 ]; then
    export NNODES=1
    export GPU_NUMS_PER_NODE=$GPU_NUMS
else
    export NNODES=$((GPU_NUMS/8))
    export GPU_NUMS_PER_NODE=8
fi

# Use the northjob binary from the "spec" conda environment.
/workspace/miniconda3/envs/spec/bin/northjob \
    create \
    --job-type train \
    --nproc-per-node $GPU_NUMS_PER_NODE \
    --gpu-per-node $GPU_NUMS_PER_NODE \
    --nnodes $NNODES \
    --k8s-priority 3 \
    --k8s-queue bg-agentic-coding \
    --k8s-namespace bg-agentic-coding \
    --k8s-pvc-name i-xinsiyang-y4zy0sik0a \
    --k8s-pvc-mount-path /workspace \
    --k8s-no-reclaim \
    --k8s-images harbor.local.clusters/bp/megatron-bplm:25.03_fp8.ibgda.qwen3.next.fix_triton.fix_te.hf457.qwen3_vl \
    --job-name $JOB_NAME \
    --workspace $WORK_DIR \
    $TRAIN_SCRIPT $GPU_NUMS_PER_NODE
syxin/run_train_dflash_direct_inject.sh ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Train the DFlash direct-inject draft model on Qwen3-8B.
set -euo pipefail

ROOT_DIR=/workspace/hanrui/syxin_old/Specforge
NUM_GPUS=8
OUTPUT_DIR=$ROOT_DIR/outputs/qwen3-8b-dflash-direct-inject

# Positional overrides: [num_gpus] [output_dir] [extra training args...]
if [[ $# -ge 1 ]]; then
    NUM_GPUS=$1
    shift
fi
if [[ $# -ge 1 && "${1:0:1}" != "-" ]]; then
    OUTPUT_DIR=$1
    shift
fi
EXTRA_ARGS=("$@")

export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels
export SPECFORGE_DATA_NUM_PROC=16
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export PYTORCH_ALLOC_CONF=expandable_segments:True
export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}"

# Prefer the dedicated specforge interpreter when it exists.
DEFAULT_SPECFORGE_PY=/workspace/hanrui/specforge/bin/python3
if [[ -z "${PYTHON_BIN:-}" ]]; then
    if [[ -x "$DEFAULT_SPECFORGE_PY" ]]; then
        PYTHON_BIN="$DEFAULT_SPECFORGE_PY"
    else
        PYTHON_BIN=python3
    fi
fi

cd $ROOT_DIR

# Training hyperparameters grouped in one array for readability.
TRAIN_ARGS=(
    --target-model-path /workspace/models/Qwen3-8B
    --target-model-backend sglang
    --train-data-path /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K
    --output-dir "$OUTPUT_DIR"
    --block-size 16
    --num-draft-layers 36
    --attention-backend flex_attention
    --max-length 2048
    --batch-size 1
    --accumulation-steps 8
    --num-epochs 3
    --learning-rate 6e-4
    --loss-decay-gamma 7
    --lm-head-chunk-size 256
    --gradient-checkpointing
    --chat-template qwen
    --log-interval 50
    --save-interval 500
    --cache-dir "$ROOT_DIR/cache"
)

$PYTHON_BIN -m torch.distributed.run \
    --standalone \
    --nproc_per_node $NUM_GPUS \
    scripts/train_dflash.py \
    "${TRAIN_ARGS[@]}" \
    "${EXTRA_ARGS[@]}"
syxin/run_train_dflash_lora_inject.sh ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Train the DFlash-LoRA-inject draft model on Qwen3-8B.
set -euo pipefail

ROOT_DIR=/workspace/hanrui/syxin/Specforge
NUM_GPUS=8
OUTPUT_DIR=$ROOT_DIR/outputs/qwen3-8b-dflash-lora-inject
CACHE_DIR=/tmp/specforge_cache

# Positional overrides: [num_gpus] [output_dir] [extra training args...]
if [[ $# -ge 1 ]]; then
    NUM_GPUS=$1
    shift
fi
if [[ $# -ge 1 && "${1:0:1}" != "-" ]]; then
    OUTPUT_DIR=$1
    shift
fi
EXTRA_ARGS=("$@")

# Environment variables
export TORCHINDUCTOR_CACHE_DIR=/tmp/specforge_cache/compiled_kernels
export SPECFORGE_DATA_NUM_PROC=16
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export PYTORCH_ALLOC_CONF=expandable_segments:True
export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}"
export HF_DATASETS_CACHE=/tmp/specforge_cache/hf_datasets
export HF_HOME=/tmp/specforge_cache/hf_home

# Prefer the dedicated specforge interpreter when it exists.
DEFAULT_SPECFORGE_PY=/workspace/hanrui/specforge/bin/python3
if [[ -z "${PYTHON_BIN:-}" ]]; then
    if [[ -x "$DEFAULT_SPECFORGE_PY" ]]; then
        PYTHON_BIN="$DEFAULT_SPECFORGE_PY"
    else
        PYTHON_BIN=python3
    fi
fi

cd $ROOT_DIR

# Training hyperparameters grouped in one array for readability.
TRAIN_ARGS=(
    --target-model-path /workspace/models/Qwen3-8B
    --target-model-backend hf
    --train-data-path /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K
    --output-dir "$OUTPUT_DIR"
    --block-size 16
    --attention-backend additive
    --attn-implementation sdpa
    --max-length 2048
    --batch-size 8
    --accumulation-steps 8
    --num-epochs 3
    --learning-rate 5e-5
    --loss-decay-gamma 7
    --gradient-checkpointing
    --chat-template qwen
    --log-interval 50
    --save-interval 500
    --cache-dir "$CACHE_DIR"
    --lora-rank 32
    --lora-alpha 64
    --lora-dropout 0.1
    --trust-remote-code
    --dataloader-num-workers 0
    --early-stop
    --early-stop-patience 5
    --early-stop-min-delta 0.005
)

$PYTHON_BIN -m torch.distributed.run \
    --standalone \
    --nproc_per_node $NUM_GPUS \
    scripts/train_dflash_lora_inject.py \
    "${TRAIN_ARGS[@]}" \
    "${EXTRA_ARGS[@]}"
syxin/run_train_multinode.sh ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Multinode training entrypoint: northjob already sets up the distributed
# environment via torchrun, so the training script is invoked directly.
set -euo pipefail

ROOT_DIR=/workspace/hanrui/syxin/Specforge
NUM_GPUS=8
OUTPUT_DIR=$ROOT_DIR/outputs/qwen3-8b-sft-32gpu-v3
CACHE_DIR=/tmp/specforge_cache

# Positional overrides: [num_gpus] [output_dir] [extra training args...]
if [[ $# -ge 1 ]]; then
    NUM_GPUS=$1
    shift
fi
if [[ $# -ge 1 && "${1:0:1}" != "-" ]]; then
    OUTPUT_DIR=$1
    shift
fi
EXTRA_ARGS=("$@")

# Environment variables
export TORCHINDUCTOR_CACHE_DIR=/tmp/specforge_cache/compiled_kernels
export SPECFORGE_DATA_NUM_PROC=16
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export PYTORCH_ALLOC_CONF=expandable_segments:True
export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}"
export HF_DATASETS_CACHE=/tmp/specforge_cache/hf_datasets
export HF_HOME=/tmp/specforge_cache/hf_home

# Prefer the "spec" conda interpreter when it exists.
DEFAULT_SPECFORGE_PY=/workspace/miniconda3/envs/spec/bin/python3
if [[ -z "${PYTHON_BIN:-}" ]]; then
    if [[ -x "$DEFAULT_SPECFORGE_PY" ]]; then
        PYTHON_BIN="$DEFAULT_SPECFORGE_PY"
    else
        PYTHON_BIN=python3
    fi
fi

cd $ROOT_DIR

# Do NOT launch torch.distributed.run here — the northjob-provided torchrun
# already exported the distributed env vars; run the script directly.
TRAIN_ARGS=(
    --target-model-path /workspace/models/Qwen3-8B
    --target-model-backend hf
    --train-data-path /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K
    --output-dir "$OUTPUT_DIR"
    --block-size 16
    --attention-backend additive
    --attn-implementation sdpa
    --max-length 2048
    --batch-size 4
    --accumulation-steps 16
    --num-epochs 3
    --learning-rate 5e-5
    --loss-decay-gamma 7
    --gradient-checkpointing
    --chat-template qwen
    --log-interval 50
    --save-interval 500
    --cache-dir "$CACHE_DIR"
    --lora-rank 32
    --lora-alpha 64
    --lora-dropout 0.1
    --trust-remote-code
    --dataloader-num-workers 0
)

$PYTHON_BIN scripts/train_dflash_lora_inject.py \
    "${TRAIN_ARGS[@]}" \
    "${EXTRA_ARGS[@]}"
syxin/run_train_qwen3_8b_sft_32gpu.sh ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# SFT training entrypoint for northjob: the distributed environment is
# already set up by torchrun upstream, so the script is run directly.
set -euo pipefail

ROOT_DIR=/workspace/hanrui/syxin_old/Specforge
NUM_GPUS=8
OUTPUT_DIR=$ROOT_DIR/outputs/qwen3-8b-sft-32gpu-v2
CACHE_DIR=/tmp/specforge_cache_sft

# Positional overrides: [num_gpus] [output_dir] [extra training args...]
if [[ $# -ge 1 ]]; then
    NUM_GPUS=$1
    shift
fi
if [[ $# -ge 1 && "${1:0:1}" != "-" ]]; then
    OUTPUT_DIR=$1
    shift
fi
EXTRA_ARGS=("$@")

# Environment variables
export TORCHINDUCTOR_CACHE_DIR=/tmp/specforge_cache_sft/compiled_kernels
export SPECFORGE_DATA_NUM_PROC=16
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export PYTORCH_ALLOC_CONF=expandable_segments:True
export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}"
export HF_DATASETS_CACHE=/tmp/specforge_cache_sft/hf_datasets
export HF_HOME=/tmp/specforge_cache_sft/hf_home

# Prefer the dedicated specforge interpreter when it exists.
DEFAULT_SPECFORGE_PY=/workspace/hanrui/specforge/bin/python3
if [[ -z "${PYTHON_BIN:-}" ]]; then
    if [[ -x "$DEFAULT_SPECFORGE_PY" ]]; then
        PYTHON_BIN="$DEFAULT_SPECFORGE_PY"
    else
        PYTHON_BIN=python3
    fi
fi

cd $ROOT_DIR

# northjob launched us through torchrun already; invoke the script directly.
TRAIN_ARGS=(
    --target-model-path /workspace/models/Qwen3-8B
    --target-model-backend hf
    --train-data-path /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K
    --output-dir "$OUTPUT_DIR"
    --block-size 16
    --attention-backend additive
    --attn-implementation sdpa
    --max-length 2048
    --batch-size 8
    --accumulation-steps 8
    --num-epochs 3
    --learning-rate 5e-5
    --loss-decay-gamma 7
    --gradient-checkpointing
    --chat-template qwen
    --log-interval 50
    --save-interval 500
    --cache-dir "$CACHE_DIR"
    --lora-rank 32
    --lora-alpha 64
    --lora-dropout 0.1
    --trust-remote-code
    --dataloader-num-workers 0
)

$PYTHON_BIN $ROOT_DIR/scripts/train_dflash_lora_inject.py \
    "${TRAIN_ARGS[@]}" \
    "${EXTRA_ARGS[@]}"
syxin/server.log ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /workspace/hanrui/sglang/python/sglang/launch_server.py:51: UserWarning: 'python -m sglang.launch_server' is still supported, but 'sglang serve' is the recommended entrypoint.
2
+ Example: sglang serve --model-path <model> [options]
3
+ warnings.warn(
4
+ [2026-03-07 15:24:13] INFO server_args.py:2048: Attention backend not specified. Use fa3 backend by default.
5
+ [2026-03-07 15:24:13] WARNING server_args.py:2629: Max running requests is reset to 48 for speculative decoding. You can override this by explicitly setting --max-running-requests.
6
+ [2026-03-07 15:24:13] WARNING server_args.py:2650: Overlap scheduler is disabled when spec v2 is off or using unsupported speculative algorithm. You can set env SGLANG_ENABLE_SPEC_V2=True to enable the experimental overlap scheduler.
7
+ [2026-03-07 15:24:13] WARNING server_args.py:2712: speculative_num_draft_tokens is adjusted to speculative_num_steps + 1 when speculative_eagle_topk == 1
8
+ [2026-03-07 15:24:14] server_args=ServerArgs(model_path='/workspace/models/Qwen3-8B', tokenizer_path='/workspace/models/Qwen3-8B', tokenizer_mode='auto', tokenizer_worker_num=1, skip_tokenizer_init=False, load_format='auto', model_loader_extra_config='{}', trust_remote_code=True, context_length=None, is_embedding=False, enable_multimodal=None, revision=None, model_impl='auto', host='10.233.100.123', port=30000, fastapi_root_path='', grpc_mode=False, skip_server_warmup=False, warmups=None, nccl_port=None, checkpoint_engine_wait_weights_before_ready=False, ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, ssl_keyfile_password=None, enable_ssl_refresh=False, dtype='bfloat16', quantization=None, quantization_param_path=None, kv_cache_dtype='auto', enable_fp32_lm_head=False, modelopt_quant=None, modelopt_checkpoint_restore_path=None, modelopt_checkpoint_save_path=None, modelopt_export_path=None, quantize_and_serve=False, rl_quant_profile=None, mem_fraction_static=0.8, max_running_requests=48, max_queued_requests=None, max_total_tokens=None, chunked_prefill_size=8192, enable_dynamic_chunking=False, max_prefill_tokens=16384, prefill_max_requests=None, schedule_policy='fcfs', enable_priority_scheduling=False, disable_priority_preemption=False, default_priority_value=None, abort_on_priority_when_disabled=False, schedule_low_priority_values_first=False, priority_scheduling_preemption_threshold=10, schedule_conservativeness=1.0, page_size=1, swa_full_tokens_ratio=0.8, disable_hybrid_swa_memory=False, radix_eviction_policy='lru', enable_prefill_delayer=False, prefill_delayer_max_delay_passes=30, prefill_delayer_token_usage_low_watermark=None, prefill_delayer_forward_passes_buckets=None, prefill_delayer_wait_seconds_buckets=None, device='cuda', tp_size=4, pp_size=1, pp_max_micro_batch_size=None, pp_async_batch_depth=0, stream_interval=1, stream_output=False, enable_streaming_session=False, random_seed=551181117, constrained_json_whitespace_pattern=None, 
constrained_json_disable_any_whitespace=False, watchdog_timeout=300, soft_watchdog_timeout=None, dist_timeout=None, download_dir=None, model_checksum=None, base_gpu_id=0, gpu_id_step=1, sleep_on_idle=False, use_ray=False, custom_sigquit_handler=None, log_level='info', log_level_http=None, log_requests=False, log_requests_level=2, log_requests_format='text', log_requests_target=None, uvicorn_access_log_exclude_prefixes=[], crash_dump_folder=None, show_time_cost=False, enable_metrics=False, enable_metrics_for_all_schedulers=False, tokenizer_metrics_custom_labels_header='x-custom-labels', tokenizer_metrics_allowed_custom_labels=None, extra_metric_labels=None, bucket_time_to_first_token=None, bucket_inter_token_latency=None, bucket_e2e_request_latency=None, collect_tokens_histogram=False, prompt_tokens_buckets=None, generation_tokens_buckets=None, gc_warning_threshold_secs=0.0, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, enable_trace=False, otlp_traces_endpoint='localhost:4317', export_metrics_to_file=False, export_metrics_to_file_dir=None, api_key=None, admin_api_key=None, served_model_name='/workspace/models/Qwen3-8B', weight_version='default', chat_template=None, hf_chat_template_name=None, completion_template=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, tool_call_parser=None, tool_server=None, sampling_defaults='model', dp_size=1, load_balance_method='round_robin', attn_cp_size=1, moe_dp_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, enable_lora=None, enable_lora_overlap_loading=None, max_lora_rank=None, lora_target_modules=None, lora_paths=None, max_loaded_loras=None, max_loras_per_batch=8, lora_eviction_policy='lru', lora_backend='csgmv', max_lora_chunk_size=16, attention_backend='fa3', decode_attention_backend=None, prefill_attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', 
mm_attention_backend=None, fp8_gemm_runner_backend='auto', fp4_gemm_runner_backend='flashinfer_cutlass', nsa_prefill_backend=None, nsa_decode_backend=None, disable_flashinfer_autotune=False, mamba_backend='triton', speculative_algorithm='STANDALONE', speculative_draft_model_path='/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora-merged', speculative_draft_model_revision='main', speculative_draft_load_format=None, speculative_num_steps=4, speculative_eagle_topk=1, speculative_num_draft_tokens=5, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, speculative_attention_mode='prefill', speculative_draft_attention_backend=None, speculative_moe_runner_backend='auto', speculative_moe_a2a_backend=None, speculative_draft_model_quantization=None, speculative_ngram_min_match_window_size=1, speculative_ngram_max_match_window_size=12, speculative_ngram_min_bfs_breadth=1, speculative_ngram_max_bfs_breadth=10, speculative_ngram_match_type='BFS', speculative_ngram_branch_length=18, speculative_ngram_capacity=10000000, enable_multi_layer_eagle=False, ep_size=1, moe_a2a_backend='none', moe_runner_backend='auto', flashinfer_mxfp4_moe_precision='default', enable_flashinfer_allreduce_fusion=False, enable_aiter_allreduce_fusion=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm=None, init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, eplb_rebalance_layers_per_chunk=None, eplb_min_rebalancing_utilization_threshold=1.0, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, moe_dense_tp_size=None, elastic_ep_backend=None, enable_elastic_expert_backup=False, mooncake_ib_device=None, max_mamba_cache_size=None, mamba_ssm_dtype=None, mamba_full_memory_ratio=0.9, mamba_scheduler_strategy='no_buffer', mamba_track_interval=256, 
linear_attn_backend='triton', linear_attn_decode_backend=None, linear_attn_prefill_backend=None, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through', hicache_io_backend='kernel', hicache_mem_layout='layer_first', disable_hicache_numa_detect=False, hicache_storage_backend=None, hicache_storage_prefetch_policy='best_effort', hicache_storage_backend_extra_config=None, hierarchical_sparse_attention_extra_config=None, enable_lmcache=False, kt_weight_path=None, kt_method='AMXINT4', kt_cpuinfer=None, kt_threadpool_count=2, kt_num_gpu_experts=None, kt_max_deferred_experts_per_token=None, dllm_algorithm=None, dllm_algorithm_config=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, cpu_offload_gb=0, offload_group_size=-1, offload_num_in_group=1, offload_prefetch_step=1, offload_mode='cpu', multi_item_scoring_delimiter=None, disable_radix_cache=False, cuda_graph_max_bs=512, cuda_graph_bs=[1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 40, 44, 48, 52, 56, 60, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248, 256, 272, 288, 304, 320, 336, 352, 368, 384, 400, 416, 432, 448, 464, 480, 496, 512], disable_cuda_graph=False, disable_cuda_graph_padding=False, enable_profile_cuda_graph=False, enable_cudagraph_gc=False, enable_layerwise_nvtx_marker=False, enable_nccl_nvls=False, enable_symm_mem=False, disable_flashinfer_cutlass_moe_fp4_allgather=False, enable_tokenizer_batch_encode=False, disable_tokenizer_batch_decode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, enable_mscclpp=False, enable_torch_symm_mem=False, disable_overlap_schedule=True, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, enable_single_batch_overlap=False, 
tbo_token_distribution_threshold=0.48, enable_torch_compile=False, disable_piecewise_cuda_graph=True, enforce_piecewise_cuda_graph=False, enable_torch_compile_debug_mode=False, torch_compile_max_bs=32, piecewise_cuda_graph_max_tokens=8192, piecewise_cuda_graph_tokens=[4, 8, 12, 16, 20, 24, 28, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256, 288, 320, 352, 384, 416, 448, 480, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1280, 1536, 1792, 2048, 2304, 2560, 2816, 3072, 3328, 3584, 3840, 4096, 4608, 5120, 5632, 6144, 6656, 7168, 7680, 8192], piecewise_cuda_graph_compiler='eager', torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, triton_attention_split_tile_size=None, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, enable_weights_cpu_backup=False, enable_draft_weights_cpu_backup=False, allow_auto_truncate=False, enable_custom_logit_processor=False, flashinfer_mla_disable_ragged=False, disable_shared_experts_fusion=False, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, keep_mm_feature_on_device=False, enable_return_hidden_states=False, enable_return_routed_experts=False, scheduler_recv_interval=1, numa_node=None, enable_deterministic_inference=False, rl_on_policy_target=None, enable_attn_tp_input_scattered=False, enable_nsa_prefill_context_parallel=False, nsa_prefill_cp_mode='round-robin-split', enable_fused_qk_norm_rope=False, enable_precise_embedding_interpolation=False, enable_fused_moe_sum_all_reduce=False, enable_dynamic_batch_tokenizer=False, dynamic_batch_tokenizer_batch_size=32, dynamic_batch_tokenizer_batch_timeout=0.002, debug_tensor_dump_output_folder=None, debug_tensor_dump_layers=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_transfer_backend='mooncake', disaggregation_bootstrap_port=8998, 
disaggregation_ib_device=None, disaggregation_decode_enable_offload_kvcache=False, num_reserved_decode_tokens=512, disaggregation_decode_polling_interval=1, encoder_only=False, language_only=False, encoder_transfer_backend='zmq_to_scheduler', encoder_urls=[], enable_adaptive_dispatch_to_encoder=False, custom_weight_loader=[], weight_loader_disable_mmap=False, remote_instance_weight_loader_seed_instance_ip=None, remote_instance_weight_loader_seed_instance_service_port=None, remote_instance_weight_loader_send_weights_group_ports=None, remote_instance_weight_loader_backend='nccl', remote_instance_weight_loader_start_seed_via_transfer_engine=False, enable_pdmux=False, pdmux_config_path=None, sm_group_num=8, mm_max_concurrent_calls=32, mm_per_request_timeout=10.0, enable_broadcast_mm_inputs_process=False, enable_prefix_mm_cache=False, mm_enable_dp_encoder=False, mm_process_config={}, limit_mm_data_per_request=None, enable_mm_global_cache=False, decrypted_config_file=None, decrypted_draft_config_file=None, forward_hooks=None)
9
+ [2026-03-07 15:24:15] Using default HuggingFace chat template with detected content format: string
10
+ [2026-03-07 15:24:25 TP2] Mamba selective_state_update backend initialized: triton
11
+ [2026-03-07 15:24:25 TP2] Init torch distributed begin.
12
+ [2026-03-07 15:24:26 TP0] Mamba selective_state_update backend initialized: triton
13
+ [2026-03-07 15:24:26 TP0] Init torch distributed begin.
14
+ [2026-03-07 15:24:26 TP3] Mamba selective_state_update backend initialized: triton
15
+ [2026-03-07 15:24:26 TP1] Mamba selective_state_update backend initialized: triton
16
+ [2026-03-07 15:24:26 TP3] Init torch distributed begin.
17
+ [2026-03-07 15:24:26 TP1] Init torch distributed begin.
18
+ [Gloo] Rank 1 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
19
+ [Gloo] Rank 0 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
20
+ [Gloo] Rank 3 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
21
+ [Gloo] Rank 2 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
22
+ [Gloo] Rank 0 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
23
+ [Gloo] Rank 2 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
24
+ [Gloo] Rank 1 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
25
+ [Gloo] Rank 3 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
26
+ [2026-03-07 15:24:27 TP0] sglang is using nccl==2.27.5
27
+ [2026-03-07 15:24:29 TP0] Scheduler hit an exception: Traceback (most recent call last):
28
+ File "/workspace/hanrui/sglang/python/sglang/srt/managers/scheduler.py", line 3239, in run_scheduler_process
29
+ scheduler = Scheduler(
30
+ ^^^^^^^^^^
31
+ File "/workspace/hanrui/sglang/python/sglang/srt/managers/scheduler.py", line 365, in __init__
32
+ self.init_model_worker()
33
+ File "/workspace/hanrui/sglang/python/sglang/srt/managers/scheduler.py", line 561, in init_model_worker
34
+ self.init_tp_model_worker()
35
+ File "/workspace/hanrui/sglang/python/sglang/srt/managers/scheduler.py", line 519, in init_tp_model_worker
36
+ self.tp_worker = TpModelWorker(
37
+ ^^^^^^^^^^^^^^
38
+ File "/workspace/hanrui/sglang/python/sglang/srt/managers/tp_worker.py", line 258, in __init__
39
+ self._init_model_runner()
40
+ File "/workspace/hanrui/sglang/python/sglang/srt/managers/tp_worker.py", line 341, in _init_model_runner
41
+ self._model_runner = ModelRunner(
42
+ ^^^^^^^^^^^^
43
+ File "/workspace/hanrui/sglang/python/sglang/srt/model_executor/model_runner.py", line 395, in __init__
44
+ pre_model_load_memory = self.init_torch_distributed()
45
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
46
+ File "/workspace/hanrui/sglang/python/sglang/srt/model_executor/model_runner.py", line 813, in init_torch_distributed
47
+ initialize_model_parallel(
48
+ File "/workspace/hanrui/sglang/python/sglang/srt/distributed/parallel_state.py", line 1764, in initialize_model_parallel
49
+ _TP = init_model_parallel_group(
50
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^
51
+ File "/workspace/hanrui/sglang/python/sglang/srt/distributed/parallel_state.py", line 1450, in init_model_parallel_group
52
+ return GroupCoordinator(
53
+ ^^^^^^^^^^^^^^^^^
54
+ File "/workspace/hanrui/sglang/python/sglang/srt/distributed/parallel_state.py", line 357, in __init__
55
+ self.pynccl_comm = PyNcclCommunicator(
56
+ ^^^^^^^^^^^^^^^^^^^
57
+ File "/workspace/hanrui/sglang/python/sglang/srt/distributed/device_communicators/pynccl.py", line 113, in __init__
58
+ self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
59
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^
60
+ File "/workspace/hanrui/sglang/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py", line 401, in ncclCommInitRank
61
+ self.NCCL_CHECK(
62
+ File "/workspace/hanrui/sglang/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py", line 376, in NCCL_CHECK
63
+ raise RuntimeError(f"NCCL error: {error_str}")
64
+ RuntimeError: NCCL error: unhandled system error (run with NCCL_DEBUG=INFO for details)
65
+
66
+ [2026-03-07 15:24:29] Received sigquit from a child process. It usually means the child failed.
67
+ [2026-03-07 15:24:29 TP2] Scheduler hit an exception: Traceback (most recent call last):
68
+ File "/workspace/hanrui/sglang/python/sglang/srt/managers/scheduler.py", line 3239, in run_scheduler_process
69
+ scheduler = Scheduler(
70
+ ^^^^^^^^^^
71
+ File "/workspace/hanrui/sglang/python/sglang/srt/managers/scheduler.py", line 365, in __init__
72
+ self.init_model_worker()
73
+ File "/workspace/hanrui/sglang/python/sglang/srt/managers/scheduler.py", line 561, in init_model_worker
74
+ self.init_tp_model_worker()
75
+ File "/workspace/hanrui/sglang/python/sglang/srt/managers/scheduler.py", line 519, in init_tp_model_worker
76
+ self.tp_worker = TpModelWorker(
77
+ ^^^^^^^^^^^^^^
78
+ File "/workspace/hanrui/sglang/python/sglang/srt/managers/tp_worker.py", line 258, in __init__
79
+ self._init_model_runner()
80
+ File "/workspace/hanrui/sglang/python/sglang/srt/managers/tp_worker.py", line 341, in _init_model_runner
81
+ self._model_runner = ModelRunner(
82
+ ^^^^^^^^^^^^
83
+ File "/workspace/hanrui/sglang/python/sglang/srt/model_executor/model_runner.py", line 395, in __init__
84
+ pre_model_load_memory = self.init_torch_distributed()
85
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
86
+ File "/workspace/hanrui/sglang/python/sglang/srt/model_executor/model_runner.py", line 813, in init_torch_distributed
87
+ initialize_model_parallel(
88
+ File "/workspace/hanrui/sglang/python/sglang/srt/distributed/parallel_state.py", line 1764, in initialize_model_parallel
89
+ _TP = init_model_parallel_group(
90
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^
91
+ File "/workspace/hanrui/sglang/python/sglang/srt/distributed/parallel_state.py", line 1450, in init_model_parallel_group
92
+ return GroupCoordinator(
93
+ ^^^^^^^^^^^^^^^^^
94
+ File "/workspace/hanrui/sglang/python/sglang/srt/distributed/parallel_state.py", line 357, in __init__
95
+ self.pynccl_comm = PyNcclCommunicator(
96
+ ^^^^^^^^^^^^^^^^^^^
97
+ File "/workspace/hanrui/sglang/python/sglang/srt/distributed/device_communicators/pynccl.py", line 113, in __init__
98
+ self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
99
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^
100
+ File "/workspace/hanrui/sglang/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py", line 401, in ncclCommInitRank
101
+ self.NCCL_CHECK(
102
+ File "/workspace/hanrui/sglang/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py", line 376, in NCCL_CHECK
103
+ raise RuntimeError(f"NCCL error: {error_str}")
104
+ RuntimeError: NCCL error: unhandled system error (run with NCCL_DEBUG=INFO for details)
105
+
106
+ [2026-03-07 15:24:29 TP1] Scheduler hit an exception: Traceback (most recent call last):
107
+ File "/workspace/hanrui/sglang/python/sglang/srt/managers/scheduler.py", line 3239, in run_scheduler_process
108
+ scheduler = Scheduler(
109
+ ^^^^^^^^^^
110
+ File "/workspace/hanrui/sglang/python/sglang/srt/managers/scheduler.py", line 365, in __init__
111
+ self.init_model_worker()
112
+ File "/workspace/hanrui/sglang/python/sglang/srt/managers/scheduler.py", line 561, in init_model_worker
113
+ self.init_tp_model_worker()
114
+ File "/workspace/hanrui/sglang/python/sglang/srt/managers/scheduler.py", line 519, in init_tp_model_worker
115
+ self.tp_worker = TpModelWorker(
116
+ ^^^^^^^^^^^^^^
117
+ File "/workspace/hanrui/sglang/python/sglang/srt/managers/tp_worker.py", line 258, in __init__
118
+ self._init_model_runner()
119
+ File "/workspace/hanrui/sglang/python/sglang/srt/managers/tp_worker.py", line 341, in _init_model_runner
120
+ self._model_runner = ModelRunner(
121
+ ^^^^^^^^^^^^
122
+ File "/workspace/hanrui/sglang/python/sglang/srt/model_executor/model_runner.py", line 395, in __init__
123
+ pre_model_load_memory = self.init_torch_distributed()
124
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
125
+ File "/workspace/hanrui/sglang/python/sglang/srt/model_executor/model_runner.py", line 813, in init_torch_distributed
126
+ initialize_model_parallel(
127
+ File "/workspace/hanrui/sglang/python/sglang/srt/distributed/parallel_state.py", line 1764, in initialize_model_parallel
128
+ _TP = init_model_parallel_group(
129
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^
130
+ File "/workspace/hanrui/sglang/python/sglang/srt/distributed/parallel_state.py", line 1450, in init_model_parallel_group
131
+ return GroupCoordinator(
132
+ ^^^^^^^^^^^^^^^^^
133
+ File "/workspace/hanrui/sglang/python/sglang/srt/distributed/parallel_state.py", line 357, in __init__
134
+ self.pynccl_comm = PyNcclCommunicator(
135
+ ^^^^^^^^^^^^^^^^^^^
136
+ File "/workspace/hanrui/sglang/python/sglang/srt/distributed/device_communicators/pynccl.py", line 113, in __init__
137
+ self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
138
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^
139
+ File "/workspace/hanrui/sglang/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py", line 401, in ncclCommInitRank
140
+ self.NCCL_CHECK(
141
+ File "/workspace/hanrui/sglang/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py", line 376, in NCCL_CHECK
142
+ raise RuntimeError(f"NCCL error: {error_str}")
143
+ RuntimeError: NCCL error: unhandled system error (run with NCCL_DEBUG=INFO for details)
144
+
145
+ [2026-03-07 15:24:29] Received sigquit from a child process. It usually means the child failed.
146
+ [2026-03-07 15:24:29] Received sigquit from a child process. It usually means the child failed.
147
+ [2026-03-07 15:24:29 TP3] Scheduler hit an exception: Traceback (most recent call last):
148
+ File "/workspace/hanrui/sglang/python/sglang/srt/managers/scheduler.py", line 3239, in run_scheduler_process
149
+ scheduler = Scheduler(
150
+ ^^^^^^^^^^
151
+ File "/workspace/hanrui/sglang/python/sglang/srt/managers/scheduler.py", line 365, in __init__
152
+ self.init_model_worker()
153
+ File "/workspace/hanrui/sglang/python/sglang/srt/managers/scheduler.py", line 561, in init_model_worker
154
+ self.init_tp_model_worker()
155
+ File "/workspace/hanrui/sglang/python/sglang/srt/managers/scheduler.py", line 519, in init_tp_model_worker
156
+ self.tp_worker = TpModelWorker(
157
+ ^^^^^^^^^^^^^^
158
+ File "/workspace/hanrui/sglang/python/sglang/srt/managers/tp_worker.py", line 258, in __init__
159
+ self._init_model_runner()
160
+ File "/workspace/hanrui/sglang/python/sglang/srt/managers/tp_worker.py", line 341, in _init_model_runner
161
+ self._model_runner = ModelRunner(
162
+ ^^^^^^^^^^^^
163
+ File "/workspace/hanrui/sglang/python/sglang/srt/model_executor/model_runner.py", line 395, in __init__
164
+ pre_model_load_memory = self.init_torch_distributed()
165
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
166
+ File "/workspace/hanrui/sglang/python/sglang/srt/model_executor/model_runner.py", line 813, in init_torch_distributed
167
+ initialize_model_parallel(
168
+ File "/workspace/hanrui/sglang/python/sglang/srt/distributed/parallel_state.py", line 1764, in initialize_model_parallel
169
+ _TP = init_model_parallel_group(
170
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^
171
+ File "/workspace/hanrui/sglang/python/sglang/srt/distributed/parallel_state.py", line 1450, in init_model_parallel_group
172
+ return GroupCoordinator(
173
+ ^^^^^^^^^^^^^^^^^
174
+ File "/workspace/hanrui/sglang/python/sglang/srt/distributed/parallel_state.py", line 357, in __init__
175
+ self.pynccl_comm = PyNcclCommunicator(
176
+ ^^^^^^^^^^^^^^^^^^^
177
+ File "/workspace/hanrui/sglang/python/sglang/srt/distributed/device_communicators/pynccl.py", line 113, in __init__
178
+ self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
179
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^
180
+ File "/workspace/hanrui/sglang/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py", line 401, in ncclCommInitRank
181
+ self.NCCL_CHECK(
182
+ File "/workspace/hanrui/sglang/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py", line 376, in NCCL_CHECK
183
+ raise RuntimeError(f"NCCL error: {error_str}")
184
+ RuntimeError: NCCL error: unhandled system error (run with NCCL_DEBUG=INFO for details)
185
+
186
+ [2026-03-07 15:24:29] Received sigquit from a child process. It usually means the child failed.
syxin/start_server.sh ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Step 2: Launch SGLang server with STANDALONE speculative decoding.
3
+ # Usage:
4
+ # bash start_server.sh
5
+ # bash start_server.sh 8 # use tp=8
6
+
7
+ set -e
8
+
9
+ TP=${1:-2}
10
+
11
+ BASE_MODEL=/workspace/models/Qwen3-8B
12
+ MERGED=/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora-merged
13
+ INTRANET_IP=10.1.1.131
14
+ PORT=30000
15
+
16
+ if [ ! -d "$MERGED" ]; then
17
+ echo "[ERROR] Merged model not found: $MERGED"
18
+ echo " Run: conda activate sglang && python3 merge_lora.py"
19
+ exit 1
20
+ fi
21
+
22
+ echo "============================================"
23
+ echo " SGLang STANDALONE Speculative Decoding"
24
+ echo " target : $BASE_MODEL"
25
+ echo " draft : $MERGED"
26
+ echo " host : $INTRANET_IP:$PORT"
27
+ echo " tp : $TP"
28
+ echo "============================================"
29
+
30
+ /workspace/miniconda3/envs/sglang/bin/python3 -m sglang.launch_server \
31
+ --model-path $BASE_MODEL \
32
+ --speculative-algorithm STANDALONE \
33
+ --speculative-draft-model-path $MERGED \
34
+ --speculative-num-steps 4 \
35
+ --speculative-eagle-topk 1 \
36
+ --speculative-num-draft-tokens 4 \
37
+ --tp-size $TP \
38
+ --mem-fraction-static 0.30 \
39
+ --trust-remote-code \
40
+ --host $INTRANET_IP \
41
+ --port $PORT \
42
+ --dtype bfloat16
syxin/start_server_dflash.sh ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Evaluate DFlash-LoRA-Inject: measure accepted length OFFLINE.
3
+ # 8 GPUs parallel by default, each GPU runs a shard of prompts independently.
4
+ #
5
+ # WHY offline?
6
+ # sglang STANDALONE treats draft as an independent autoregressive model,
7
+ # completely ignoring the layer-by-layer injection that LoRA-Inject was
8
+ # trained with. Result: accept_length ≈ 4.7 for ALL models (no signal).
9
+ #
10
+ # sglang DFLASH expects the DFlash-b16 architecture (5-layer, fc+hidden_norm),
11
+ # which is structurally different from LoRA-Inject (full 36-layer + LoRA).
12
+ #
13
+ # So we run offline spec-generate with the correct injection pattern.
14
+ #
15
+ # Usage:
16
+ # bash start_server_dflash.sh # 8 GPUs, all benchmarks
17
+ # bash start_server_dflash.sh 4 # 4 GPUs
18
+ # bash start_server_dflash.sh 8 humaneval # specific benchmark
19
+ # bash start_server_dflash.sh 8 --num-samples 20 # quick test
20
+
21
+ set -e
22
+
23
+ SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)
24
+
25
+ NUM_GPUS=${1:-8}
26
+ shift 2>/dev/null || true
27
+
28
+ # ---- defaults ----
29
+ BASE_MODEL=/workspace/models/Qwen3-8B
30
+ ADAPTER_ROOT=/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora-inject
31
+ CKPT=epoch_3_step_1400
32
+ MERGED=/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora-inject-merged
33
+ RESULT_DIR=/workspace/hanrui/syxin_old/Specforge/benchmarks/results
34
+ PYTHON=/workspace/miniconda3/envs/spec/bin/python3
35
+
36
+ echo "============================================"
37
+ echo " DFlash-LoRA-Inject Offline Evaluation"
38
+ echo " target : $BASE_MODEL"
39
+ echo " ckpt : $CKPT"
40
+ echo " merged : $MERGED"
41
+ echo " GPUs : $NUM_GPUS"
42
+ echo "============================================"
43
+
44
+ $PYTHON -m torch.distributed.run \
45
+ --standalone \
46
+ --nproc_per_node $NUM_GPUS \
47
+ $SCRIPT_DIR/eval_dflash_lora_inject.py \
48
+ --base-model $BASE_MODEL \
49
+ --adapter-root $ADAPTER_ROOT \
50
+ --ckpt $CKPT \
51
+ --merged-path $MERGED \
52
+ --block-size 16 \
53
+ --output-dir $RESULT_DIR \
54
+ "$@"
syxin/step1.md ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Now I have a complete picture. Here's the full analysis and proposed plan.                                                               
2
+                                                                                                                                            
3
+   ---                                                                                                                                      
4
+   Current Logic of train_dflash.py                                                                                                         
5
+                                                                                                                                            
6
+   This script trains a small, standalone DFlash draft model (1–10 transformer layers) that learns to predict token blocks in parallel,     
7
+   guided by hidden states extracted from a frozen target model.
8
+
9
+   Architecture (two separate models)
10
+
11
+   1. Target model (frozen, e.g. Qwen3-8B/32B) — runs a full forward pass to produce hidden states at selected layers.
12
+   2. Draft model (DFlashDraftModel, ~1–10 layers) — a lightweight Qwen3-based decoder that takes noise embeddings + target hidden states
13
+   and predicts the block tokens.
14
+   3. Target embed_tokens + lm_head — loaded separately via TargetEmbeddingsAndHead to avoid duplicating the full target model in memory.
15
+
16
+   Key locations
17
+
18
+   ┌──────────────────────────┬────────────────────────────────────┬───────────────────────────────────────────────────────┐
19
+   │        Component         │                File                │                         Lines                         │
20
+   ├──────────────────────────┼────────────────────────────────────┼───────────────────────────────────────────────────────┤
21
+   │ Model init               │ scripts/train_dflash.py            │ build_models() L254–311                               │
22
+   ├──────────────────────────┼────────────────────────────────────┼───────────────────────────────────────────────────────┤
23
+   │ Target hidden extraction │ scripts/train_dflash.py            │ L644–647 (target_model.generate_dflash_data)          │
24
+   ├──────────────────────────┼────────────────────────────────────┼───────────────────────────────────────────────────────┤
25
+   │ Forward pass             │ specforge/core/dflash.py           │ OnlineDFlashModel.forward() L243–332                  │
26
+   ├──────────────────────────┼────────────────────────────────────┼───────────────────────────────────────────────────────┤
27
+   │ Loss calculation         │ specforge/core/dflash.py           │ _full_lm_loss() L382–417, _chunked_lm_loss() L419–478 │
28
+   ├──────────────────────────┼────────────────────────────────────┼───────────────────────────────────────────────────────┤
29
+   │ Loss mask                │ specforge/core/dflash.py           │ create_dflash_loss_mask() L481–509                    │
30
+   ├──────────────────────────┼────────────────────────────────────┼───────────────────────────────────────────────────────┤
31
+   │ Draft model architecture │ specforge/modeling/draft/dflash.py │ DFlashDraftModel L212–266                             │
32
+   ├──────────────────────────┼────────────────────────────────────┼───────────────────────────────────────────────────────┤
33
+   │ DFlash attention         │ specforge/modeling/draft/dflash.py │ Qwen3DFlashAttention L42–134                          │
34
+   └──────────────────────────┴────────────────────────────────────┴───────────────────────────────────────────────────────┘
35
+
36
+   Forward pass flow (per training step)
37
+
38
+   input_ids, attention_mask, loss_mask  →  target_model.generate_dflash_data()
39
+                                                 ↓
40
+                                        hidden_states (from target layers [1,9,17,25,33])
41
+                                                 ↓
42
+                                 OnlineDFlashModel.forward():
43
+                                   1. Truncate to block boundary
44
+                                   2. prepare_noise_input(): anchor tokens kept, rest → MASK
45
+                                   3. embed_tokens(noise_input_ids) → noise_embedding
46
+                                   4. Build DFlash attention mask (flex_attention or additive)
47
+                                   5. draft_model(noise_embedding, target_hidden, mask)
48
+                                   6. lm_head(hidden) → logits
49
+                                   7. CE loss on non-anchor positions (weighted by loss_mask × decay)
50
+
51
+   The draft model's custom Qwen3DFlashAttention concatenates [context_hidden, noise_hidden] as KV, with queries only from noise tokens. The
52
+    attention mask enforces: block tokens see all preceding blocks' context + bidirectional within their own block.
53
+
54
+   ---
55
+   What already exists: train_dflash_lora.py
56
+
57
+   Interestingly, the repo already has a LoRA variant at scripts/train_dflash_lora.py with its own model (DFlashLoRADraftModel) and wrapper
58
+   (OnlineDFlashLoRAModel). This is exactly the approach you described — Qwen3-8B + LoRA, no separate target model, 1-step diffusion
59
+   training. The key differences from train_dflash.py:
60
+
61
+   ┌─────────────────┬─────────────────────────────────────────────────────────────┬────────────────────────────────────────────────────┐
62
+   │     Aspect      │                       train_dflash.py                       │                train_dflash_lora.py                │
63
+   ├─────────────────┼─────────────────────────────────────────────────────────────┼────────────────────────────────────────────────────┤
64
+   │ Draft model     │ Small custom DFlashDraftModel (1–10 layers)                 │ Full Qwen3-8B + LoRA adapters                      │
65
+   ├─────────────────┼─────────────────────────────────────────────────────────────┼────────────────────────────────────────────────────┤
66
+   │ Target model    │ Separate frozen model for hidden state extraction           │ None — model uses its own representations          │
67
+   ├─────────────────┼─────────────────────────────────────────────────────────────┼────────────────────────────────────────────────────┤
68
+   │ Attention       │ Custom Qwen3DFlashAttention (Q from noise, KV from [ctx,    │ Standard HF attention with 4D additive DFlash mask │
69
+   │                 │ noise])                                                     │                                                    │
70
+   ├─────────────────┼─────────────────────────────────────────────────────────────┼────────────────────────────────────────────────────┤
71
+   │ Forward         │ draft_model(noise_emb, target_hidden, mask)                 │ model(noise_input_ids, 4d_mask, position_ids) →    │
72
+   │                 │                                                             │ logits                                             │
73
+   ├─────────────────┼─────────────────────────────────────────────────────────────┼────────────────────────────────────────────────────┤
74
+   │ Trainable       │ All draft model params                                      │ Only LoRA (q/k/v/o_proj), base frozen              │
75
+   │ params          │                                                             │                                                    │
76
+   ├─────────────────┼─────────────────────────────────────────────────────────────┼────────────────────────────────────────────────────┤
77
+   │ FSDP strategy   │ SHARD_GRAD_OP                                               │ FULL_SHARD                                         │
78
+   └─────────────────┴─────────────────────────────────────────────────────────────┴────────────────────────────────────────────────────┘
79
+
80
+   ---
81
+   Proposed Modification Plan
82
+
83
+   Since train_dflash_lora.py already implements the core idea, the plan focuses on what's missing or needs improvement to make it a proper
84
+   "1-step dLLM draft model" for your research:
85
+
86
+   Phase 1: Validate and extend the existing LoRA pipeline
87
+
88
+   1. Add MLP to LoRA targets — The current config only targets q_proj, k_proj, v_proj, o_proj. For stronger 1-step diffusion capability,
89
+   add gate_proj, up_proj, down_proj to lora_target_modules. This gives the model more capacity to learn the non-autoregressive distribution
90
+    shift.
91
+   2. Add multi-step noise schedule support — Currently the training is strictly 1-step (all non-anchors → MASK). For a proper diffusion/AR
92
+   fusion, add an option for a noise schedule where a fraction of block tokens are revealed (not just the anchor), controlled by a
93
+   noise_ratio parameter. This would modify prepare_noise_input() in OnlineDFlashLoRAModel:
94
+   # Instead of: all non-anchor → MASK
95
+   # Allow: randomly keep some non-anchor tokens with probability (1 - noise_ratio)
96
+   3. Add configurable context_len strategy — Currently context_len=0 treats the whole sequence as blocks. Add a --context-ratio arg that
97
+   dynamically sets context_len as a fraction of the sequence, so the model learns to condition on varying amounts of AR-decoded prefix.
98
+
99
+   Phase 2: Training logic improvements
100
+
101
+   4. Add KL divergence loss — In addition to CE loss against ground truth, add an optional KL loss against the base model's AR distribution
102
+    (teacher forcing). This regularizes the LoRA model to stay close to the original Qwen3-8B distribution. Modify
103
+   OnlineDFlashLoRAModel.forward():
104
+   # Compute base model logits (no_grad, no LoRA) as teacher
105
+   # KL(draft_logits || teacher_logits) on block positions
106
+   # total_loss = ce_loss + kl_weight * kl_loss
107
+   5. Add evaluation with speculative decoding metrics — The current accuracy metric is block-wise acceptance rate. Add an eval loop that
108
+   actually runs speculative decoding (draft → verify) to measure real speedup, using the LoRA model as the drafter and the base model (with
109
+    LoRA disabled) as the verifier.
110
+
111
+   Phase 3: Integration with train_dflash.py style features
112
+
113
+   6. Port random anchor sampling — train_dflash.py has --random-anchor (L147–156) which samples diverse anchor positions per sequence. This
114
+    is missing from the LoRA variant and would improve training diversity. Port _sample_anchor_positions and _build_blocks_from_anchors from
115
+    OnlineDFlashModel to OnlineDFlashLoRAModel.
116
+   7. Port chunked lm_head — The LoRA variant materializes full [bsz, seq_len, vocab_size] logits. For Qwen3-8B's 152K vocab, this is ~1.2GB
117
+    per sample at seq_len=2048. Port the _chunked_lm_loss from OnlineDFlashModel for memory efficiency.
118
+   8. Add tensor parallelism support — The LoRA script currently forces tp_size=1. For Qwen3-8B on multi-GPU setups, add TP support to shard
119
+    the base model across GPUs while keeping LoRA params on each rank.
120
+
121
+   Minimal starting point
122
+
123
+   If you want to just get training running immediately, the existing train_dflash_lora.py + configs/qwen3-8b-dflash-lora.json should work
124
+   out of the box:
125
+
126
+   torchrun --nproc_per_node=4 scripts/train_dflash_lora.py \
127
+     --model-path Qwen/Qwen3-8B \
128
+     --lora-config configs/qwen3-8b-dflash-lora.json \
129
+     --train-data-path /path/to/data.jsonl \
130
+     --output-dir ./outputs/qwen3-8b-dflash-lora \
131
+     --block-size 16 \
132
+     --gradient-checkpointing \
133
+     --learning-rate 2e-4 \
134
+     --batch-size 1 \
135
+     --max-length 2048 \
136
+     --num-epochs 3
137
+
138
+   Want me to start implementing any of these phases? I'd suggest starting with Phase 1 (noise schedule + MLP targets) since that's the most
139
+    impactful for the diffusion/AR fusion research angle.