Lekr0 commited on
Commit
d522318
·
verified ·
1 Parent(s): 4024ed7

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. SpecForge-ext/benchmarks/README.md +67 -0
  2. SpecForge-ext/benchmarks/__init__.py +3 -0
  3. SpecForge-ext/benchmarks/bench_eagle3.py +268 -0
  4. SpecForge-ext/benchmarks/benchmarker/__init__.py +29 -0
  5. SpecForge-ext/benchmarks/benchmarker/__pycache__/__init__.cpython-310.pyc +0 -0
  6. SpecForge-ext/benchmarks/benchmarker/__pycache__/__init__.cpython-311.pyc +0 -0
  7. SpecForge-ext/benchmarks/benchmarker/__pycache__/__init__.cpython-312.pyc +0 -0
  8. SpecForge-ext/benchmarks/benchmarker/__pycache__/aime.cpython-310.pyc +0 -0
  9. SpecForge-ext/benchmarks/benchmarker/__pycache__/aime.cpython-311.pyc +0 -0
  10. SpecForge-ext/benchmarks/benchmarker/__pycache__/aime.cpython-312.pyc +0 -0
  11. SpecForge-ext/benchmarks/benchmarker/__pycache__/base.cpython-310.pyc +0 -0
  12. SpecForge-ext/benchmarks/benchmarker/__pycache__/base.cpython-311.pyc +0 -0
  13. SpecForge-ext/benchmarks/benchmarker/__pycache__/base.cpython-312.pyc +0 -0
  14. SpecForge-ext/benchmarks/benchmarker/__pycache__/ceval.cpython-310.pyc +0 -0
  15. SpecForge-ext/benchmarks/benchmarker/__pycache__/ceval.cpython-311.pyc +0 -0
  16. SpecForge-ext/benchmarks/benchmarker/__pycache__/financeqa.cpython-310.pyc +0 -0
  17. SpecForge-ext/benchmarks/benchmarker/__pycache__/financeqa.cpython-311.pyc +0 -0
  18. SpecForge-ext/benchmarks/benchmarker/__pycache__/gpqa.cpython-310.pyc +0 -0
  19. SpecForge-ext/benchmarks/benchmarker/__pycache__/gpqa.cpython-311.pyc +0 -0
  20. SpecForge-ext/benchmarks/benchmarker/__pycache__/gsm8k.cpython-310.pyc +0 -0
  21. SpecForge-ext/benchmarks/benchmarker/__pycache__/gsm8k.cpython-311.pyc +0 -0
  22. SpecForge-ext/benchmarks/benchmarker/__pycache__/humaneval.cpython-310.pyc +0 -0
  23. SpecForge-ext/benchmarks/benchmarker/__pycache__/humaneval.cpython-311.pyc +0 -0
  24. SpecForge-ext/benchmarks/benchmarker/__pycache__/livecodebench.cpython-310.pyc +0 -0
  25. SpecForge-ext/benchmarks/benchmarker/__pycache__/math500.cpython-310.pyc +0 -0
  26. SpecForge-ext/benchmarks/benchmarker/__pycache__/math500.cpython-311.pyc +0 -0
  27. SpecForge-ext/benchmarks/benchmarker/__pycache__/mmlu.cpython-310.pyc +0 -0
  28. SpecForge-ext/benchmarks/benchmarker/__pycache__/mmlu.cpython-311.pyc +0 -0
  29. SpecForge-ext/benchmarks/benchmarker/__pycache__/mmstar.cpython-310.pyc +0 -0
  30. SpecForge-ext/benchmarks/benchmarker/__pycache__/mmstar.cpython-311.pyc +0 -0
  31. SpecForge-ext/benchmarks/benchmarker/__pycache__/mtbench.cpython-310.pyc +0 -0
  32. SpecForge-ext/benchmarks/benchmarker/__pycache__/mtbench.cpython-311.pyc +0 -0
  33. SpecForge-ext/benchmarks/benchmarker/__pycache__/registry.cpython-310.pyc +0 -0
  34. SpecForge-ext/benchmarks/benchmarker/__pycache__/registry.cpython-311.pyc +0 -0
  35. SpecForge-ext/benchmarks/benchmarker/__pycache__/simpleqa.cpython-310.pyc +0 -0
  36. SpecForge-ext/benchmarks/benchmarker/__pycache__/simpleqa.cpython-311.pyc +0 -0
  37. SpecForge-ext/benchmarks/benchmarker/__pycache__/utils.cpython-310.pyc +0 -0
  38. SpecForge-ext/benchmarks/benchmarker/__pycache__/utils.cpython-311.pyc +0 -0
  39. SpecForge-ext/benchmarks/benchmarker/aime.py +133 -0
  40. SpecForge-ext/benchmarks/benchmarker/base.py +218 -0
  41. SpecForge-ext/benchmarks/benchmarker/ceval.py +267 -0
  42. SpecForge-ext/benchmarks/benchmarker/financeqa.py +59 -0
  43. SpecForge-ext/benchmarks/benchmarker/gpqa.py +85 -0
  44. SpecForge-ext/benchmarks/benchmarker/gsm8k.py +108 -0
  45. SpecForge-ext/benchmarks/benchmarker/humaneval.py +201 -0
  46. SpecForge-ext/benchmarks/benchmarker/livecodebench.py +46 -0
  47. SpecForge-ext/benchmarks/benchmarker/math500.py +122 -0
  48. SpecForge-ext/benchmarks/benchmarker/mmlu.py +82 -0
  49. SpecForge-ext/benchmarks/benchmarker/mmstar.py +185 -0
  50. SpecForge-ext/benchmarks/benchmarker/mtbench.py +70 -0
SpecForge-ext/benchmarks/README.md ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Benchmarking for Speculative Decoding
2
+
3
+ ## Overview
4
+
5
+ We provide a unified script to test the performance of speculative decoding with the EAGLE3 algorithm on multiple datasets. Follow the steps below to run the benchmarks.
6
+
7
+ ## Run Benchmarks
8
+
9
+ ### Launch SGLang and Benchmarker Concurrently
10
+
11
+ `bench_eagle3.py` can launch a SGLang server process and a benchmarking process concurrently. This way you don't have to launch the SGLang server yourself: the script automatically handles the SGLang launch under each speculative decoding configuration. Some important arguments are:
12
+ - `--model-path`: the path to the target model.
13
+ - `--speculative-draft-model-path`: the path to the draft model.
14
+ - `--port`: the port to launch the SGLang server.
15
+ - `--trust-remote-code`: trust the remote code.
16
+ - `--mem-fraction-static`: the memory fraction for the static memory.
17
+ - `--tp-size`: the tensor parallelism size.
18
+ - `--attention-backend`: the attention backend.
19
+ - `--config-list`: the list of speculative decoding configuration to test, the format is `<batch-size>,<num-steps>,<topk>,<num-draft-tokens>`.
20
+ - `--benchmark-list`: the list of benchmarks to test, the format is `<benchmark-name>:<num-prompts>:<subset>`.
21
+
22
+ ```shell
23
+ python3 bench_eagle3.py \
24
+ --model-path meta-llama/Llama-3.1-8B-Instruct \
25
+ --speculative-draft-model-path lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B \
26
+ --port 30000 \
27
+ --trust-remote-code \
28
+ --mem-fraction-static 0.8 \
29
+ --tp-size 1 \
30
+ --attention-backend fa3 \
31
+ --config-list 1,0,0,0 1,3,1,4 \
32
+ --benchmark-list mtbench gsm8k:5 ceval:5:accountant \
33
+ --dtype bfloat16
34
+ ```
35
+
36
+ ### Launch Benchmarker Independently
37
+
38
+ If you want to launch the SGLang server independently, you can use the following command.
39
+
40
+ ```shell
41
+ # you can launch a server
42
+ python3 -m sglang.launch_server \
43
+ --model meta-llama/Llama-3.1-8B-Instruct \
44
+ --speculative-algorithm EAGLE3 \
45
+ --speculative-draft-model-path lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B \
46
+ --speculative-num-steps 3 \
47
+ --speculative-eagle-topk 1 \
48
+ --speculative-num-draft-tokens 4 \
49
+ --mem-fraction-static 0.75 \
50
+ --cuda-graph-max-bs 1 \
51
+ --tp 1 \
52
+ --trust-remote-code \
53
+ --host 0.0.0.0 \
54
+ --port 30000 \
55
+ --dtype bfloat16
56
+ ```
57
+
58
+ Then we can start benchmarking. Use the same host and port as the running SGLang server. Note that `--skip-launch-server` is required to skip launching the SGLang server.
59
+
60
+ ```bash
61
+ python bench_eagle3.py \
62
+ --model-path meta-llama/Llama-3.1-8B-Instruct \
63
+ --port 30000 \
64
+ --config-list 1,3,1,4 \
65
+ --benchmark-list mtbench:5 ceval:5:accountant gsm8k:5 humaneval:5 math500:5 mtbench:5 aime:1 \
66
+ --skip-launch-server
67
+ ```
SpecForge-ext/benchmarks/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """
2
+ Benchmark scripts for speculative decoding evaluation.
3
+ """
SpecForge-ext/benchmarks/bench_eagle3.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Usage:
4
+
5
+ # if you want to run benchmarks directly
6
+ # mtbench:20 means only run 20 samples in the dataset
7
+ python bench_eagle3.py \
8
+ --model meta-llama/Llama-3.1-8B-Instruct \
9
+ --speculative-algorithm EAGLE3 \
10
+ --speculative-draft-model-path lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B \
11
+ --port 30000 \
12
+ --config-list 1,0,0,0 1,3,1,4 \
13
+ --benchmark-list mtbench:20 \
14
+ --dtype bfloat16
15
+
16
+
17
+ Or, if you want to run SGLang on its own:
18
+
19
+ # launch sglang
20
+ python3 -m sglang.launch_server \
21
+ --model meta-llama/Llama-3.1-8B-Instruct \
22
+ --speculative-algorithm EAGLE3 \
23
+ --speculative-draft-model-path lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B \
24
+ --speculative-num-steps 3 \
25
+ --speculative-eagle-topk 1 \
26
+ --speculative-num-draft-tokens 4 \
27
+ --mem-fraction-static 0.75 \
28
+ --cuda-graph-max-bs 1 \
29
+ --tp 1 \
30
+ --trust-remote-code \
31
+ --host 0.0.0.0 \
32
+ --port 30000 \
33
+ --dtype bfloat16
34
+
35
+ # then run benchmarks
36
+ python bench_eagle3.py \
37
+ --model-path meta-llama/Llama-3.1-8B-Instruct \
38
+ --port 30000 \
39
+ --config-list 1,0,0,0 \
40
+ --benchmark-list mtbench:80 \
41
+ --dtype bfloat16 \
42
+ --skip-launch-server
43
+ """
44
+ import argparse
45
+ import json
46
+ import os
47
+ import time
48
+ from dataclasses import asdict
49
+ from typing import List
50
+
51
+ import requests
52
+ from benchmarker import BENCHMARKS
53
+ from sglang.srt.server_args import ServerArgs
54
+ from sglang.test.test_utils import kill_process_tree, popen_launch_server
55
+ from sglang.utils import wait_for_server
56
+
57
+
58
def parse_args():
    """Build the CLI: SGLang server arguments plus benchmark-specific options.

    Returns:
        argparse.Namespace with both the SGLang server args and the
        benchmark group args filled in.
    """
    cli = argparse.ArgumentParser()

    # All of SGLang's own server flags live in their own group.
    server_group = cli.add_argument_group("sglang")
    ServerArgs.add_cli_args(server_group)

    # Benchmark-runner specific flags.
    bench = cli.add_argument_group("benchmark")
    bench.add_argument("--skip-launch-server", action="store_true", default=False)
    bench.add_argument("--timeout-for-server-launch", type=int, default=600)
    bench.add_argument("--num-prompts", type=int, default=80)
    bench.add_argument("--output-dir", type=str, default="./results")
    bench.add_argument(
        "--config-list", type=str, nargs="+", default=["1,0,0,0", "1,3,1,4"]
    )
    bench.add_argument(
        "--name",
        type=str,
        default=None,
        help="name of this benchmark run, if provided, will be added to the output file name",
    )
    bench.add_argument(
        "--benchmark-list",
        type=str,
        nargs="+",
        default=[
            "mtbench:80",
            "gsm8k:200",
            "humaneval:200",
            "math500:200",
            "ceval:200",
        ],
        help=f"The list of benchmarks to run. The format is <benchmark-name>:<num-prompts>:<subset>,<subset>. We support the following benchmarks: {', '.join(BENCHMARKS.benchmarks.keys())}",
    )
    bench.add_argument(
        "--enable-multi-turn-conversation",
        action="store_true",
        default=False,
    )
    return cli.parse_args()
99
+
100
+
101
def launch_sglang_server(
    server_args: ServerArgs,
    base_url: str,
    batch_size: int,
    steps: int,
    topk: int,
    num_draft_tokens: int,
    timeout: int,
):
    """Launch an SGLang server subprocess for one speculative-decoding setting.

    Args:
        server_args: Parsed SGLang server arguments (model path, tp size, ...).
        base_url: URL the server should be reachable at.
        batch_size: Max concurrent requests / CUDA graph batch size.
        steps: Number of speculative draft steps; 0 disables speculation.
        topk: EAGLE top-k for drafting (used only when steps > 0).
        num_draft_tokens: Draft tokens per step (used only when steps > 0).
        timeout: Seconds to wait for the server to come up.

    Returns:
        The launched server process handle.
    """
    extra_args: List[str] = []

    # Speculative decoding flags are attached only when drafting is requested.
    if steps > 0:
        extra_args += [
            "--speculative-algorithm",
            "EAGLE3",
            "--speculative-num-steps",
            str(steps),
            "--speculative-eagle-topk",
            str(topk),
            "--speculative-num-draft-tokens",
            str(num_draft_tokens),
            "--speculative-draft-model-path",
            server_args.speculative_draft_model_path,
        ]

    extra_args += [
        "--cuda-graph-max-bs",
        str(batch_size),
        "--mem-fraction-static",
        str(server_args.mem_fraction_static),
        "--tp-size",
        str(server_args.tp_size),
        "--max-running-requests",
        str(batch_size),
    ]

    # Optional passthrough flags, forwarded only when set on server_args.
    if server_args.trust_remote_code:
        extra_args.append("--trust-remote-code")
    if server_args.disable_radix_cache:
        extra_args.append("--disable-radix-cache")
    if server_args.ep_size:
        extra_args += ["--ep-size", str(server_args.ep_size)]
    if server_args.attention_backend:
        extra_args += ["--attention-backend", server_args.attention_backend]
    if server_args.quantization:
        extra_args += ["--quantization", server_args.quantization]
    if server_args.dtype:
        extra_args += ["--dtype", server_args.dtype]

    return popen_launch_server(
        server_args.model_path,
        base_url,
        timeout=timeout,
        other_args=extra_args,
        env={
            "SGLANG_RECORD_STEP_TIME": "1",
            "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN": "1",
            **os.environ,
        },
    )
173
+
174
+
175
def send_flush_cache_request(base_url: str) -> None:
    """POST to the server's /flush_cache endpoint (called between benchmarks)."""
    requests.post(base_url + "/flush_cache")
177
+
178
+
179
def main():
    """Parse CLI args, optionally launch SGLang per speculative config, run the
    requested benchmarks against it, and dump all collected metrics to JSON.
    """
    args = parse_args()
    server_args: ServerArgs = ServerArgs.from_cli_args(args)
    # Each config string is "<batch_size>,<steps>,<topk>,<num_draft_tokens>".
    configs = [tuple(map(int, config.split(","))) for config in args.config_list]

    # Split each --benchmark-list spec into (bench_name, num_prompts, subset).
    benchmark_list = []
    for item in args.benchmark_list:
        splits = item.split(":")
        if len(splits) == 1:
            bench_name = splits[0]
            num_prompts = None
            subset = None
        elif len(splits) == 2:
            bench_name, num_prompts = splits
            subset = None
        elif len(splits) == 3:
            bench_name, num_prompts, subset = splits
            subset = subset.split(",")
        else:
            raise ValueError(f"Invalid benchmark list format: {item}")
        benchmark_list.append((bench_name, num_prompts, subset))
    if not benchmark_list:
        # Real exception instead of assert: assert is stripped under python -O.
        raise ValueError("the number of benchmark list is 0")

    base_url = f"http://localhost:{args.port}"

    results = {}
    # NOTE(review): this records the *draft* model path under the "model" key;
    # confirm whether the target model path was intended instead.
    results["model"] = server_args.speculative_draft_model_path

    def run_benchmarks(batch_size: int, steps: int, topk: int, num_draft_tokens: int):
        """Run every configured benchmark against the live server and append
        the resulting metrics to `results`."""
        for benchmark_name, num_prompts, subset in benchmark_list:
            print(
                f"Running benchmark {benchmark_name} with {num_prompts} prompts, batch size {batch_size}, steps {steps}, topk {topk}, num_draft_tokens {num_draft_tokens}, subset {subset}"
            )
            # Typo fix: was `benchmarkder_cls`.
            benchmarker_cls = BENCHMARKS.get(benchmark_name)
            num_prompts = int(num_prompts) if num_prompts is not None else None
            if subset is None:
                benchmarker = benchmarker_cls(num_samples=num_prompts)
            else:
                benchmarker = benchmarker_cls(num_samples=num_prompts, subset=subset)
            metrics_list = benchmarker.run(
                host=args.host, port=args.port, batch_size=batch_size
            )
            # Flush the server cache so one benchmark cannot warm the next.
            send_flush_cache_request(base_url)
            results.setdefault(benchmark_name, []).append(
                dict(
                    batch_size=batch_size,
                    steps=steps,
                    topk=topk,
                    num_draft_tokens=num_draft_tokens,
                    metrics=[asdict(metric) for metric in metrics_list],
                    num_samples=num_prompts,
                )
            )

    if args.skip_launch_server:
        # Server is already running; only the batch size is relevant here.
        batch_size = configs[0][0] if len(configs) > 0 else 8
        run_benchmarks(batch_size, None, None, None)
    else:
        # Iterate over each config, relaunching the server for each one.
        for batch_size, steps, topk, num_draft_tokens in configs:
            process = launch_sglang_server(
                server_args,
                base_url,
                batch_size,
                steps,
                topk,
                num_draft_tokens,
                args.timeout_for_server_launch,
            )
            wait_for_server(base_url)
            run_benchmarks(batch_size, steps, topk, num_draft_tokens)
            kill_process_tree(process.pid)
            process.wait()

    os.makedirs(args.output_dir, exist_ok=True)
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    # Bug fix: the file holds a single pretty-printed JSON object (json.dump
    # with indent), not JSON Lines, so use .json instead of the misleading
    # .jsonl extension.
    result_file = os.path.join(
        args.output_dir,
        f"{args.name + '_' if args.name else ''}results_{timestamp}.json",
    )
    with open(result_file, "w") as f:
        json.dump(results, f, indent=4)
    print(f"Results saved to {result_file}")
265
+
266
+
267
# Script entry point.
if __name__ == "__main__":
    main()
SpecForge-ext/benchmarks/benchmarker/__init__.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from .aime import AIMEBenchmarker
from .ceval import CEvalBenchmarker
from .financeqa import FinanceQABenchmarker
from .gpqa import GPQABenchmarker
from .gsm8k import GSM8KBenchmarker
from .humaneval import HumanEvalBenchmarker
from .livecodebench import LCBBenchmarker
from .math500 import Math500Benchmarker
from .mmlu import MMLUBenchmarker
from .mmstar import MMStarBenchmarker
from .mtbench import MTBenchBenchmarker
from .registry import BENCHMARKS
from .simpleqa import SimpleQABenchmarker

# Public API of the benchmarker package. BENCHMARKS is the registry mapping
# benchmark names (e.g. "aime", "gsm8k") to the Benchmarker classes below;
# importing the modules above triggers their @BENCHMARKS.register decorators.
__all__ = [
    "BENCHMARKS",
    "AIMEBenchmarker",
    "CEvalBenchmarker",
    "GSM8KBenchmarker",
    "HumanEvalBenchmarker",
    "Math500Benchmarker",
    "MTBenchBenchmarker",
    "MMStarBenchmarker",
    "GPQABenchmarker",
    "FinanceQABenchmarker",
    "MMLUBenchmarker",
    "LCBBenchmarker",
    "SimpleQABenchmarker",
]
SpecForge-ext/benchmarks/benchmarker/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (872 Bytes). View file
 
SpecForge-ext/benchmarks/benchmarker/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.11 kB). View file
 
SpecForge-ext/benchmarks/benchmarker/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (896 Bytes). View file
 
SpecForge-ext/benchmarks/benchmarker/__pycache__/aime.cpython-310.pyc ADDED
Binary file (4.16 kB). View file
 
SpecForge-ext/benchmarks/benchmarker/__pycache__/aime.cpython-311.pyc ADDED
Binary file (6.79 kB). View file
 
SpecForge-ext/benchmarks/benchmarker/__pycache__/aime.cpython-312.pyc ADDED
Binary file (5.8 kB). View file
 
SpecForge-ext/benchmarks/benchmarker/__pycache__/base.cpython-310.pyc ADDED
Binary file (6.47 kB). View file
 
SpecForge-ext/benchmarks/benchmarker/__pycache__/base.cpython-311.pyc ADDED
Binary file (9.05 kB). View file
 
SpecForge-ext/benchmarks/benchmarker/__pycache__/base.cpython-312.pyc ADDED
Binary file (8.11 kB). View file
 
SpecForge-ext/benchmarks/benchmarker/__pycache__/ceval.cpython-310.pyc ADDED
Binary file (6.57 kB). View file
 
SpecForge-ext/benchmarks/benchmarker/__pycache__/ceval.cpython-311.pyc ADDED
Binary file (11.7 kB). View file
 
SpecForge-ext/benchmarks/benchmarker/__pycache__/financeqa.cpython-310.pyc ADDED
Binary file (2.07 kB). View file
 
SpecForge-ext/benchmarks/benchmarker/__pycache__/financeqa.cpython-311.pyc ADDED
Binary file (3.34 kB). View file
 
SpecForge-ext/benchmarks/benchmarker/__pycache__/gpqa.cpython-310.pyc ADDED
Binary file (3.24 kB). View file
 
SpecForge-ext/benchmarks/benchmarker/__pycache__/gpqa.cpython-311.pyc ADDED
Binary file (5.36 kB). View file
 
SpecForge-ext/benchmarks/benchmarker/__pycache__/gsm8k.cpython-310.pyc ADDED
Binary file (3.94 kB). View file
 
SpecForge-ext/benchmarks/benchmarker/__pycache__/gsm8k.cpython-311.pyc ADDED
Binary file (6.72 kB). View file
 
SpecForge-ext/benchmarks/benchmarker/__pycache__/humaneval.cpython-310.pyc ADDED
Binary file (4.88 kB). View file
 
SpecForge-ext/benchmarks/benchmarker/__pycache__/humaneval.cpython-311.pyc ADDED
Binary file (8.95 kB). View file
 
SpecForge-ext/benchmarks/benchmarker/__pycache__/livecodebench.cpython-310.pyc ADDED
Binary file (1.87 kB). View file
 
SpecForge-ext/benchmarks/benchmarker/__pycache__/math500.cpython-310.pyc ADDED
Binary file (3.73 kB). View file
 
SpecForge-ext/benchmarks/benchmarker/__pycache__/math500.cpython-311.pyc ADDED
Binary file (6.26 kB). View file
 
SpecForge-ext/benchmarks/benchmarker/__pycache__/mmlu.cpython-310.pyc ADDED
Binary file (3.11 kB). View file
 
SpecForge-ext/benchmarks/benchmarker/__pycache__/mmlu.cpython-311.pyc ADDED
Binary file (5.19 kB). View file
 
SpecForge-ext/benchmarks/benchmarker/__pycache__/mmstar.cpython-310.pyc ADDED
Binary file (5.08 kB). View file
 
SpecForge-ext/benchmarks/benchmarker/__pycache__/mmstar.cpython-311.pyc ADDED
Binary file (10.1 kB). View file
 
SpecForge-ext/benchmarks/benchmarker/__pycache__/mtbench.cpython-310.pyc ADDED
Binary file (3.07 kB). View file
 
SpecForge-ext/benchmarks/benchmarker/__pycache__/mtbench.cpython-311.pyc ADDED
Binary file (4.63 kB). View file
 
SpecForge-ext/benchmarks/benchmarker/__pycache__/registry.cpython-310.pyc ADDED
Binary file (1.22 kB). View file
 
SpecForge-ext/benchmarks/benchmarker/__pycache__/registry.cpython-311.pyc ADDED
Binary file (1.51 kB). View file
 
SpecForge-ext/benchmarks/benchmarker/__pycache__/simpleqa.cpython-310.pyc ADDED
Binary file (1.81 kB). View file
 
SpecForge-ext/benchmarks/benchmarker/__pycache__/simpleqa.cpython-311.pyc ADDED
Binary file (2.83 kB). View file
 
SpecForge-ext/benchmarks/benchmarker/__pycache__/utils.cpython-310.pyc ADDED
Binary file (8.59 kB). View file
 
SpecForge-ext/benchmarks/benchmarker/__pycache__/utils.cpython-311.pyc ADDED
Binary file (13.7 kB). View file
 
SpecForge-ext/benchmarks/benchmarker/aime.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AIME benchmark
3
+ """
4
+
5
+ import re
6
+ from typing import Any, Dict, List, Optional, Tuple
7
+
8
+ from datasets import load_dataset
9
+
10
+ from .base import Benchmarker
11
+ from .registry import BENCHMARKS
12
+ from .utils import create_simple_sgl_function
13
+
14
+
15
def extract_aime_answer(output: str) -> Optional[str]:
    r"""Extract the final numeric answer from an AIME solution.

    AIME answers are integers between 0 and 999 and are usually given in
    \boxed{} format; several fallbacks are tried in decreasing order of
    reliability. (Bug fix: the docstring is now a raw string — previously
    ``\boxed`` embedded a literal backspace character via ``\b``.)

    Args:
        output: Raw model output text.

    Returns:
        The extracted answer as a string, or None if nothing was found.
    """
    # 1) Answer in \boxed{...} format.
    match = re.search(r"\\boxed\{([^}]+)\}", output)
    if match:
        answer = match.group(1).strip()
        # Extract the number from the boxed content.
        numbers = re.findall(r"\d+", answer)
        if numbers:
            return numbers[-1]  # the last number is usually the final answer
        return answer

    # 2) \boxed 123 (without braces).
    match = re.search(r"\\boxed\s+(\d+)", output)
    if match:
        return match.group(1).strip()

    # 3) Phrases like "The answer is 42" or "Answer: 123".
    answer_patterns = [
        r"(?:answer|Answer|ANSWER)[\s:]+(\d+)",
        r"(?:final\s+answer|Final\s+Answer)[\s:]+(\d+)",
        r"(?:is|equals?|=\s*)(\d+)\s*$",
    ]
    for pattern in answer_patterns:
        matches = re.findall(pattern, output, re.IGNORECASE)
        if matches:
            return matches[-1].strip()

    # 4) Fallback: the last integer within the valid AIME range (0-999).
    numbers = re.findall(r"\b(\d+)\b", output)
    if numbers:
        valid_numbers = [n for n in numbers if 0 <= int(n) <= 999]
        if valid_numbers:
            return valid_numbers[-1]

    return None
58
+
59
+
60
@BENCHMARKS.register("aime")
class AIMEBenchmarker(Benchmarker):
    """AIME benchmark implementation."""

    def __init__(self, num_samples: Optional[int] = None):
        # AIME has no subsets, so the subset argument is always None.
        super().__init__(num_samples, None)

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[Optional[str]]]:
        """Load the AIME 2024 dataset and return (questions, gold answers)."""
        dataset = load_dataset("Maxwell-Jia/AIME_2024")["train"]
        questions: List[Dict[str, Any]] = []
        labels: List[Optional[str]] = []
        for idx, row in enumerate(dataset):
            # Respect the optional sample cap.
            if self.num_samples is not None and idx >= self.num_samples:
                break
            questions.append({"question": row["Problem"]})
            # The gold answer may live under either capitalization.
            if "Answer" in row:
                gold = str(row["Answer"]).strip()
            elif "answer" in row:
                gold = str(row["answer"]).strip()
            else:
                gold = None
            labels.append(gold)
        return questions, labels

    def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]:
        """Delegate extraction to the module-level AIME answer parser."""
        return extract_aime_answer(output)

    def compute_accuracy(
        self, predictions: List[Any], labels: List[Any]
    ) -> Optional[float]:
        """Fraction of labeled samples answered correctly (string or integer match)."""
        # No usable labels at all -> accuracy is not applicable.
        if not labels or all(gold is None for gold in labels):
            return None

        correct = 0
        graded = 0
        for guess, gold in zip(predictions, labels):
            if gold is None:
                continue
            graded += 1
            if guess is None:
                continue
            guess_text = str(guess).strip()
            gold_text = str(gold).strip()
            if guess_text == gold_text:
                correct += 1
            else:
                # Fall back to numeric equality (e.g. "042" vs "42").
                try:
                    if int(guess_text) == int(gold_text):
                        correct += 1
                except ValueError:
                    pass

        return correct / graded if graded > 0 else 0.0

    def create_sgl_function(self):
        """Create SGL function for AIME with reasoning prompt."""
        return create_simple_sgl_function(
            function_name="reasoning_gen",
            answer_key="answer",
            user_prefix="\nPlease reason step by step, and put your final answer within \\boxed{}.",
        )

    def get_max_new_tokens(self) -> int:
        """AIME problems require more tokens."""
        return 32768
SpecForge-ext/benchmarks/benchmarker/base.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Base class for benchmark implementations.
3
+ """
4
+
5
+ import time
6
+ from abc import ABC, abstractmethod
7
+ from argparse import Namespace
8
+ from typing import Any, Callable, Dict, List, Optional, Tuple
9
+
10
+ from sglang import set_default_backend
11
+ from sglang.test.test_utils import select_sglang_backend
12
+
13
+ from .utils import compute_metrics
14
+
15
+
16
class Benchmarker(ABC):
    """
    Base class for benchmark implementations.

    Subclasses should implement:
    - load_data(): Load and preprocess dataset
    - create_sgl_function(): Create the SGL function for inference

    Optional overrides:
    - extract_answer(): Extract answer from model output (if needed)
    - compute_accuracy(): Compute accuracy metric (if applicable)
    - get_answer_keys(): Get list of answer keys for multi-turn conversations
    - get_max_new_tokens(): Per-benchmark generation budget

    Args:
        num_samples: The number of samples to run the benchmark on. If not provided, all questions will be used.
        subset: The subset of the dataset to run the benchmark on. If not provided, all subsets will be used.
    """

    def __init__(
        self, num_samples: Optional[int] = None, subset: Optional[List[str]] = None
    ):
        self.num_samples = num_samples
        self.subset = subset

    @abstractmethod
    def load_data(self) -> Tuple[List[Dict[str, Any]], List[Any]]:
        """
        Load and preprocess the dataset.

        Returns:
            Tuple of (questions, labels) where:
            - questions: List of question dicts for SGL function
            - labels: List of ground truth labels (can be None if not applicable)
        """
        raise NotImplementedError

    @abstractmethod
    def create_sgl_function(self) -> Callable:
        """
        Create the SGL function for inference.

        Returns:
            SGL function decorated with @sgl.function
        """
        raise NotImplementedError

    def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[Any]:
        """
        Extract answer from model output.

        Default implementation is the identity: the raw output is the answer.

        Args:
            output: Raw model output string
            label: Optional ground truth label for reference

        Returns:
            Extracted answer, or None if extraction fails
        """
        return output

    def compute_accuracy(
        self, predictions: List[Any], labels: List[Any]
    ) -> Optional[float]:
        """
        Compute accuracy metric.

        Default implementation returns None (accuracy not applicable).

        Args:
            predictions: List of predicted answers
            labels: List of ground truth labels

        Returns:
            Accuracy score (0-1), or None if not applicable
        """
        return None

    def get_answer_keys(self) -> Optional[List[str]]:
        """
        Get list of answer keys for multi-turn conversations.

        Returns:
            List of answer keys (e.g., ["answer_1", "answer_2"]), or None for single-turn
        """
        return None

    def get_max_new_tokens(self) -> int:
        """
        Get maximum number of new tokens to generate.

        Returns:
            Maximum tokens (default: 2048)
        """
        return 2048

    def run(
        self,
        host: str,
        port: int,
        batch_size: int,
        max_new_tokens: Optional[int] = None,
        num_runs: int = 1,
    ):
        """
        Run the benchmark evaluation.

        This method handles the common workflow:
        1. Initialize backend
        2. Load data
        3. Create SGL function
        4. Run inference loops
        5. Compute metrics
        6. Print results

        Args:
            host (str): The host of the SGLang server
            port (int): The port of the SGLang server
            batch_size (int): The number of prompts to process in parallel
            max_new_tokens (int): Maximum number of new tokens to generate, default is the benchmark's get_max_new_tokens()
            num_runs (int): The number of times to run this benchmark, default is 1. You can set it to a larger number if you want to get more stable results.

        Returns:
            List of metrics objects (one per run), or None if no data loaded.
        """
        # Accept bare hostnames as well as full URLs.
        if not host.startswith(("http://", "https://")):
            host = f"http://{host}"
        # Initialize backend
        sglang_args = Namespace(host=host, port=port, backend="srt-no-parallel")
        set_default_backend(select_sglang_backend(sglang_args))

        # Load data
        questions, labels = self.load_data()
        if len(questions) == 0:
            print("No valid questions found. Please check the dataset format.")
            return

        # Create SGL function
        sgl_function = self.create_sgl_function()

        # Run evaluation loops
        metrics_list = []
        answer_keys = self.get_answer_keys()
        max_new_tokens = max_new_tokens or self.get_max_new_tokens()

        for _ in range(num_runs):
            tic = time.perf_counter()
            # Greedy decoding (temperature=0) so results are deterministic.
            states = sgl_function.run_batch(
                questions,
                temperature=0,
                max_new_tokens=max_new_tokens,
                num_threads=batch_size,
                progress_bar=True,
            )
            latency = time.perf_counter() - tic

            # Extract predictions
            predictions = []
            primary_answer_key = answer_keys[0] if answer_keys else "answer"
            for i in range(len(states)):
                # Access answer from state object (states[i] supports dict-like access)
                output = states[i][primary_answer_key]
                if isinstance(output, str):
                    extracted = self.extract_answer(
                        output,
                        (labels[i] if labels and i < len(labels) else None),
                    )
                else:
                    # Non-string outputs are passed through unparsed.
                    extracted = output
                predictions.append(extracted)

            # Compute accuracy if applicable
            accuracy = None
            # Check if we have a labels list (even if all labels are None)
            has_labels_list = labels and len(labels) > 0

            if has_labels_list:
                # Always call compute_accuracy if we have a labels list.
                # It may return None, which will be displayed in print_results.
                accuracy = self.compute_accuracy(predictions, labels)
                if accuracy is not None:
                    valid_count = sum(1 for p in predictions if p is not None)
                    if valid_count < len(predictions):
                        print(
                            f"Warning: {len(predictions) - valid_count} predictions could not be extracted."
                        )

            # Compute performance metrics (throughput/latency etc.)
            metrics = compute_metrics(
                states,
                latency,
                answer_key=primary_answer_key,
                additional_answer_keys=(
                    answer_keys[1:] if answer_keys and len(answer_keys) > 1 else None
                ),
            )
            # Always set accuracy if we have a labels list (even if compute_accuracy returns None)
            if has_labels_list:
                metrics.accuracy = (
                    accuracy  # Can be None if compute_accuracy returns None
                )
                if accuracy is not None:
                    metrics.num_valid_predictions = sum(
                        1 for p in predictions if p is not None
                    )

            metrics_list.append(metrics)
        return metrics_list
SpecForge-ext/benchmarks/benchmarker/ceval.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ C-Eval benchmark evaluation script.
3
+ """
4
+
5
+ import re
6
+ from typing import Any, Dict, List, Optional, Tuple
7
+
8
+ from datasets import concatenate_datasets, load_dataset
9
+
10
+ from .base import Benchmarker
11
+ from .registry import BENCHMARKS
12
+ from .utils import create_simple_sgl_function
13
+
14
+
15
+ def extract_answer(answer_str: str) -> str:
16
+ """Extract the answer choice (A, B, C, D) from the model output."""
17
+ # Try to find the answer in various formats
18
+ answer_str = answer_str.strip().upper()
19
+
20
+ # Direct match for single letter
21
+ match = re.search(r"\b([ABCD])\b", answer_str)
22
+ if match:
23
+ return match.group(1)
24
+
25
+ # Try to find answer in parentheses or brackets
26
+ for pattern in [
27
+ r"\(([ABCD])\)",
28
+ r"\[([ABCD])\]",
29
+ r"答案[::]\s*([ABCD])",
30
+ r"Answer[::]\s*([ABCD])",
31
+ ]:
32
+ match = re.search(pattern, answer_str, re.IGNORECASE)
33
+ if match:
34
+ return match.group(1).upper()
35
+
36
+ # Try to find the first occurrence of A, B, C, or D
37
+ match = re.search(r"([ABCD])", answer_str)
38
+ if match:
39
+ return match.group(1)
40
+
41
+ return None
42
+
43
+
44
+ def format_question(question: str, options: List[str]) -> str:
45
+ """Format the question with options."""
46
+ prompt = question + "\n\n选项:\n"
47
+ for i, option in enumerate(options):
48
+ prompt += f"{chr(65 + i)}. {option}\n"
49
+ prompt += "\n请从A、B、C、D中选择一个答案。"
50
+ return prompt
51
+
52
+
53
+ @BENCHMARKS.register("ceval")
54
+ class CEvalBenchmarker(Benchmarker):
55
+ """C-Eval benchmark implementation."""
56
+
57
+ def __init__(
58
+ self, num_samples: Optional[int] = None, subset: Optional[List[str]] = None
59
+ ):
60
+ if subset is None:
61
+ subset = "all"
62
+ super().__init__(num_samples, subset)
63
+
64
+ def load_data(self) -> Tuple[List[Dict[str, Any]], List[str]]:
65
+ """Load and preprocess C-Eval dataset."""
66
+ all_configs = [
67
+ "accountant",
68
+ "advanced_mathematics",
69
+ "art_studies",
70
+ "basic_medicine",
71
+ "business_administration",
72
+ "chinese_language_and_literature",
73
+ "civil_servant",
74
+ "clinical_medicine",
75
+ "college_chemistry",
76
+ "college_economics",
77
+ "college_physics",
78
+ "college_programming",
79
+ "computer_architecture",
80
+ "computer_network",
81
+ "discrete_mathematics",
82
+ "education_science",
83
+ "electrical_engineer",
84
+ "environmental_impact_assessment_engineer",
85
+ "fire_engineer",
86
+ "high_school_biology",
87
+ "high_school_chemistry",
88
+ "high_school_chinese",
89
+ "high_school_geography",
90
+ "high_school_history",
91
+ "high_school_mathematics",
92
+ "high_school_physics",
93
+ "high_school_politics",
94
+ "ideological_and_moral_cultivation",
95
+ "law",
96
+ "legal_professional",
97
+ "logic",
98
+ "mao_zedong_thought",
99
+ "marxism",
100
+ "metrology_engineer",
101
+ "middle_school_biology",
102
+ "middle_school_chemistry",
103
+ "middle_school_geography",
104
+ "middle_school_history",
105
+ "middle_school_mathematics",
106
+ "middle_school_physics",
107
+ "middle_school_politics",
108
+ "modern_chinese_history",
109
+ "operating_system",
110
+ "physician",
111
+ "plant_protection",
112
+ "probability_and_statistics",
113
+ "professional_tour_guide",
114
+ "sports_science",
115
+ "tax_accountant",
116
+ "teacher_qualification",
117
+ "urban_and_rural_planner",
118
+ "veterinary_medicine",
119
+ ]
120
+
121
+ # Select configs to load
122
+ if self.subset == "all":
123
+ configs_to_load = all_configs
124
+ else:
125
+ for subset in self.subset:
126
+ assert (
127
+ subset in all_configs
128
+ ), f"Subset {subset} not found in C-Eval dataset"
129
+ configs_to_load = self.subset
130
+
131
+ # Load datasets
132
+ try:
133
+ datasets = []
134
+ for config in configs_to_load:
135
+ try:
136
+ ds = load_dataset("ceval/ceval-exam", name=config, split="test")
137
+ datasets.append(ds)
138
+ print(f"Loaded config '{config}' with {len(ds)} samples")
139
+ except Exception as e:
140
+ print(f"Warning: Failed to load config '{config}': {e}")
141
+ if len(datasets) == 0:
142
+ raise ValueError("No configs could be loaded")
143
+ dataset = concatenate_datasets(datasets)
144
+ print(
145
+ f"Successfully loaded C-Eval dataset with all configs (total: {len(dataset)} samples)"
146
+ )
147
+ except Exception as e:
148
+ print(e)
149
+ print(f"Failed to load C-Eval dataset from 'ceval/ceval-exam': {e}")
150
+ print("Please ensure the dataset is available or install it manually.")
151
+ print("You can try: pip install datasets")
152
+ print("Or download from: https://huggingface.co/datasets/ceval/ceval-exam")
153
+ return [], []
154
+
155
+ # Process questions
156
+ questions = []
157
+ labels = []
158
+ for idx, item in enumerate(dataset):
159
+ if self.num_samples is not None and idx >= self.num_samples:
160
+ break
161
+
162
+ # Handle different dataset formats
163
+ question_text = None
164
+ if "question" in item:
165
+ question_text = item["question"]
166
+ elif "inputs" in item:
167
+ question_text = item["inputs"]
168
+ elif "problem" in item:
169
+ question_text = item["problem"]
170
+ elif "content" in item:
171
+ question_text = item["content"]
172
+
173
+ if not question_text:
174
+ continue
175
+
176
+ # Get options - C-Eval typically has options as a list or dict
177
+ options = None
178
+ if "options" in item:
179
+ options = item["options"]
180
+ if isinstance(options, dict):
181
+ # Convert dict to list in order A, B, C, D
182
+ options = [
183
+ options.get("A", ""),
184
+ options.get("B", ""),
185
+ options.get("C", ""),
186
+ options.get("D", ""),
187
+ ]
188
+ elif isinstance(options, list):
189
+ # Ensure we have 4 options
190
+ while len(options) < 4:
191
+ options.append("")
192
+ elif "choices" in item:
193
+ options = item["choices"]
194
+ if isinstance(options, dict):
195
+ options = [
196
+ options.get("A", ""),
197
+ options.get("B", ""),
198
+ options.get("C", ""),
199
+ options.get("D", ""),
200
+ ]
201
+ else:
202
+ # Try to construct options from A, B, C, D fields
203
+ options = [
204
+ item.get("A", item.get("option_A", "")),
205
+ item.get("B", item.get("option_B", "")),
206
+ item.get("C", item.get("option_C", "")),
207
+ item.get("D", item.get("option_D", "")),
208
+ ]
209
+
210
+ # Filter out empty options
211
+ if options:
212
+ options = [str(opt).strip() for opt in options if opt]
213
+ if len(options) < 2: # Need at least 2 options
214
+ continue
215
+ else:
216
+ continue
217
+
218
+ # Get answer
219
+ answer = None
220
+ if "answer" in item:
221
+ answer = str(item["answer"]).upper().strip()
222
+ elif "target" in item:
223
+ answer = str(item["target"]).upper().strip()
224
+ elif "label" in item:
225
+ answer = str(item["label"]).upper().strip()
226
+ elif "correct" in item:
227
+ answer = str(item["correct"]).upper().strip()
228
+
229
+ # Validate answer
230
+ if answer and answer in ["A", "B", "C", "D"]:
231
+ # Format question
232
+ formatted_question = format_question(question_text, options)
233
+ questions.append({"question": formatted_question})
234
+ labels.append(answer)
235
+
236
+ if len(questions) == 0:
237
+ print("No valid questions found. Please check the dataset format.")
238
+ print(
239
+ "Sample item keys:",
240
+ list(dataset[0].keys()) if len(dataset) > 0 else "No items",
241
+ )
242
+ return [], []
243
+
244
+ return questions, labels
245
+
246
+ def create_sgl_function(self):
247
+ """Create SGL function for C-Eval."""
248
+ return create_simple_sgl_function(
249
+ function_name="get_ceval_answer",
250
+ answer_key="answer",
251
+ max_tokens=self.get_max_new_tokens(),
252
+ )
253
+
254
+ def extract_answer(self, output: str, label: Any = None) -> str:
255
+ """Extract answer choice from model output."""
256
+ return extract_answer(output)
257
+
258
+ def compute_accuracy(self, predictions: List[str], labels: List[str]) -> float:
259
+ """Compute accuracy metric."""
260
+ correct = 0
261
+ valid_count = 0
262
+ for i in range(len(predictions)):
263
+ if predictions[i] is not None: # Only count valid predictions
264
+ valid_count += 1
265
+ if predictions[i] == labels[i]:
266
+ correct += 1
267
+ return correct / valid_count if valid_count > 0 else 0.0
SpecForge-ext/benchmarks/benchmarker/financeqa.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional, Tuple
2
+
3
+ from datasets import load_dataset
4
+
5
+ from .base import Benchmarker
6
+ from .registry import BENCHMARKS
7
+ from .utils import create_simple_sgl_function
8
+
9
+ QUESTION_PROMPT = """
10
+ Given the following context:
11
+
12
+ {context}
13
+
14
+ Can you answer the following question?
15
+
16
+ {question}
17
+ """.strip()
18
+
19
+
20
+ def generate_question(row: Dict[str, Any]) -> str:
21
+ if row["context"] is None:
22
+ return row["question"].strip()
23
+ else:
24
+ question = QUESTION_PROMPT.format(
25
+ context=row["context"].strip(),
26
+ question=row["question"].strip(),
27
+ )
28
+ return question
29
+
30
+
31
+ @BENCHMARKS.register("financeqa")
32
+ class FinanceQABenchmarker(Benchmarker):
33
+ """FinanceQA benchmark implementation."""
34
+
35
+ def __init__(self, num_samples: Optional[int] = None):
36
+ super().__init__(num_samples, None)
37
+
38
+ def load_data(self) -> Tuple[List[Dict[str, Any]], List[int]]:
39
+ """Load and preprocess FinanceQA dataset."""
40
+ # Read data
41
+ ds = load_dataset("AfterQuery/FinanceQA")["test"]
42
+
43
+ questions = []
44
+ labels = []
45
+ for i in range((len(ds))):
46
+ if self.num_samples is not None and i >= self.num_samples:
47
+ break
48
+
49
+ question_text = generate_question(ds[i])
50
+ questions.append({"question": question_text})
51
+ labels.append(None)
52
+ return questions, labels
53
+
54
+ def create_sgl_function(self):
55
+ return create_simple_sgl_function(
56
+ function_name="get_financeqa_answer",
57
+ answer_key="answer",
58
+ max_tokens=self.get_max_new_tokens(),
59
+ )
SpecForge-ext/benchmarks/benchmarker/gpqa.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from typing import Any, Dict, List, Optional, Tuple
3
+
4
+ from datasets import load_dataset
5
+
6
+ from .base import Benchmarker
7
+ from .registry import BENCHMARKS
8
+ from .utils import create_simple_sgl_function
9
+
10
+ GPQA_QUERY_TEMPLATE = """
11
+ Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.
12
+
13
+ {Question}
14
+
15
+ A) {A}
16
+ B) {B}
17
+ C) {C}
18
+ D) {D}
19
+ """.strip()
20
+
21
+
22
+ def generate_question(row: Dict[str, Any]) -> str:
23
+ gold_index = random.randint(0, 3)
24
+ choices = [
25
+ row["Incorrect Answer 1"],
26
+ row["Incorrect Answer 2"],
27
+ row["Incorrect Answer 3"],
28
+ ]
29
+ choices.insert(gold_index, row["Correct Answer"])
30
+
31
+ question = GPQA_QUERY_TEMPLATE.format(
32
+ Question=row["Question"].strip(),
33
+ A=choices[0].strip(),
34
+ B=choices[1].strip(),
35
+ C=choices[2].strip(),
36
+ D=choices[3].strip(),
37
+ )
38
+
39
+ # 0 means A, 1 means B, 2 means C, 3 means D
40
+ answer = ["A", "B", "C", "D"][gold_index]
41
+ return question, answer
42
+
43
+
44
+ @BENCHMARKS.register("gpqa")
45
+ class GPQABenchmarker(Benchmarker):
46
+ """GPQA benchmark implementation."""
47
+
48
+ def __init__(self, num_samples: Optional[int] = None):
49
+ super().__init__(num_samples, None)
50
+
51
+ def load_data(self) -> Tuple[List[Dict[str, Any]], List[int]]:
52
+ """Load and preprocess GPQA dataset."""
53
+ # Read data
54
+ ds = load_dataset("Idavidrein/gpqa", "gpqa_main")["train"]
55
+
56
+ questions = []
57
+ labels = []
58
+ for i in range((len(ds))):
59
+ if self.num_samples is not None and i >= self.num_samples:
60
+ break
61
+
62
+ question_text, answer = generate_question(ds[i])
63
+ questions.append({"question": question_text})
64
+ labels.append(answer)
65
+ return questions, labels
66
+
67
+ def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[int]:
68
+ if "Answer: " not in output:
69
+ return None
70
+ return output.split("Answer: ")[1].strip()
71
+
72
+ def compute_accuracy(
73
+ self, predictions: List[Any], labels: List[Any]
74
+ ) -> Optional[float]:
75
+ if not labels or len(labels) == 0:
76
+ return None
77
+ correct = sum(1 for pred, label in zip(predictions, labels) if pred == label)
78
+ return correct / len(labels) if len(labels) > 0 else 0.0
79
+
80
+ def create_sgl_function(self):
81
+ return create_simple_sgl_function(
82
+ function_name="get_gpqa_mcq_answer",
83
+ answer_key="answer",
84
+ max_tokens=self.get_max_new_tokens(),
85
+ )
SpecForge-ext/benchmarks/benchmarker/gsm8k.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GSM8K benchmark evaluation script.
3
+ """
4
+
5
+ import ast
6
+ import os
7
+ import re
8
+ from typing import Any, Dict, List, Optional, Tuple
9
+
10
+ from sglang.utils import download_and_cache_file, read_jsonl
11
+
12
+ from .base import Benchmarker
13
+ from .registry import BENCHMARKS
14
+ from .utils import create_few_shot_sgl_function
15
+
16
+ INVALID = -9999999
17
+
18
+
19
+ def get_one_example(lines: List[Dict], i: int, include_answer: bool) -> str:
20
+ """Format a single example."""
21
+ ret = "Question: " + lines[i]["question"] + "\nAnswer:"
22
+ if include_answer:
23
+ ret += " " + lines[i]["answer"]
24
+ return ret
25
+
26
+
27
+ def get_few_shot_examples(lines: List[Dict], k: int) -> str:
28
+ """Get few-shot examples as a string."""
29
+ ret = ""
30
+ for i in range(k):
31
+ ret += get_one_example(lines, i, True) + "\n\n"
32
+ return ret
33
+
34
+
35
+ def get_answer_value(answer_str: str) -> int:
36
+ """Extract numeric answer from model output."""
37
+ answer_str = answer_str.replace(",", "")
38
+ numbers = re.findall(r"\d+", answer_str)
39
+ if len(numbers) < 1:
40
+ return INVALID
41
+ try:
42
+ return ast.literal_eval(numbers[-1])
43
+ except SyntaxError:
44
+ return INVALID
45
+
46
+
47
+ @BENCHMARKS.register("gsm8k")
48
+ class GSM8KBenchmarker(Benchmarker):
49
+ """GSM8K benchmark implementation."""
50
+
51
+ def __init__(self, num_samples: Optional[int] = None):
52
+ super().__init__(num_samples, None)
53
+
54
+ def load_data(self) -> Tuple[List[Dict[str, Any]], List[int]]:
55
+ """Load and preprocess GSM8K dataset."""
56
+ # 优先从本地数据目录读取
57
+ local_path = "/workspace/hanrui/datasets/gsm8k/test.jsonl"
58
+
59
+ if os.path.exists(local_path):
60
+ print(f"Loading GSM8K data from local: {local_path}")
61
+ lines = list(read_jsonl(local_path))
62
+ else:
63
+ # 如果本地不存在,从网络下载
64
+ print(f"Local data not found, downloading from GitHub...")
65
+ url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
66
+ data_path = download_and_cache_file(url)
67
+ lines = list(read_jsonl(data_path))
68
+
69
+ # Construct prompts
70
+ few_shot_examples = get_few_shot_examples(lines, 5)
71
+
72
+ questions = []
73
+ labels = []
74
+ for i in range((len(lines))):
75
+ if self.num_samples is not None and i >= self.num_samples:
76
+ break
77
+
78
+ question_text = get_one_example(lines, i, False)
79
+ questions.append({"question": question_text})
80
+ labels.append(get_answer_value(lines[i]["answer"]))
81
+
82
+ # Store few_shot_examples for use in create_sgl_function
83
+ self.few_shot_examples = few_shot_examples
84
+
85
+ assert all(l != INVALID for l in labels), "Some labels are invalid"
86
+ return questions, labels
87
+
88
+ def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[int]:
89
+ """Extract numeric answer from model output."""
90
+ return get_answer_value(output)
91
+
92
+ def compute_accuracy(
93
+ self, predictions: List[Any], labels: List[Any]
94
+ ) -> Optional[float]:
95
+ """Compute accuracy for GSM8K by comparing numeric answers."""
96
+ if not labels or len(labels) == 0:
97
+ return None
98
+ correct = sum(1 for pred, label in zip(predictions, labels) if pred == label)
99
+ return correct / len(labels) if len(labels) > 0 else 0.0
100
+
101
+ def create_sgl_function(self):
102
+ """Create SGL function for GSM8K with few-shot examples."""
103
+ return create_few_shot_sgl_function(
104
+ few_shot_examples=self.few_shot_examples,
105
+ function_name="few_shot_gsm8k",
106
+ answer_key="answer",
107
+ stop=["Question", "Assistant:", "<|separator|>"],
108
+ )
SpecForge-ext/benchmarks/benchmarker/humaneval.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HumanEval benchmark evaluation script.
3
+ """
4
+
5
+ import json
6
+ import os
7
+ import re
8
+ from typing import Any, Dict, List, Optional, Tuple
9
+
10
+ from datasets import load_dataset
11
+
12
+ from .base import Benchmarker
13
+ from .registry import BENCHMARKS
14
+ from .utils import create_simple_sgl_function
15
+
16
+
17
+ def extract_code_from_output(output: str) -> Optional[str]:
18
+ """Extract Python code from model output.
19
+
20
+ Tries to extract code blocks or function definitions.
21
+ """
22
+ # Try to find code in markdown code blocks
23
+ code_block_pattern = r"```(?:python)?\n(.*?)```"
24
+ match = re.search(code_block_pattern, output, re.DOTALL)
25
+ if match:
26
+ return match.group(1).strip()
27
+
28
+ # Try to find function definition (common in HumanEval)
29
+ # Look for "def " followed by code until the next def or end of string
30
+ def_pattern = r"(def\s+\w+\([^)]*\):.*?)(?=\n\ndef\s+|\Z)"
31
+ match = re.search(def_pattern, output, re.DOTALL)
32
+ if match:
33
+ return match.group(1).strip()
34
+
35
+ # Fallback: return the output as-is (might already be code)
36
+ return output.strip() if output.strip() else None
37
+
38
+
39
+ def check_code_passes_tests(code: str, test_code: str, entry_point: str) -> bool:
40
+ """Check if generated code passes the test cases.
41
+
42
+ This is a simplified version. For full evaluation, use the official
43
+ HumanEval evaluation framework.
44
+
45
+ HumanEval test code typically contains assertions that will raise
46
+ AssertionError if the code doesn't pass. If execution completes without
47
+ exceptions, the tests pass.
48
+ """
49
+ try:
50
+ # Create a safe execution environment
51
+ namespace = {}
52
+ # Execute the code (function definition)
53
+ exec(code, namespace)
54
+ # Execute the test code (which contains assertions)
55
+ # If no exception is raised, the tests pass
56
+ exec(test_code, namespace)
57
+ return True
58
+ except AssertionError:
59
+ # Assertion failed - test didn't pass
60
+ return False
61
+ except Exception:
62
+ # Any other exception (syntax error, runtime error, etc.) means test failed
63
+ return False
64
+
65
+
66
+ @BENCHMARKS.register("humaneval")
67
+ class HumanEvalBenchmarker(Benchmarker):
68
+ """HumanEval benchmark implementation."""
69
+
70
+ def __init__(self, num_samples: Optional[int] = None):
71
+ """Initialize benchmark and store test cases."""
72
+ super().__init__(num_samples, None)
73
+ self.test_cases = []
74
+ self.entry_points = []
75
+
76
+ def load_data(self) -> Tuple[List[Dict[str, Any]], List[Optional[Dict[str, str]]]]:
77
+ """Load and preprocess HumanEval dataset."""
78
+ # 优先从本地数据目录读取
79
+ local_path = "/workspace/hanrui/datasets/humaneval/test.jsonl"
80
+
81
+ if os.path.exists(local_path):
82
+ print(f"Loading HumanEval data from local: {local_path}")
83
+ with open(local_path, 'r') as f:
84
+ dataset = [json.loads(line) for line in f]
85
+ else:
86
+ # 如果本地不存在,从 HuggingFace 下载
87
+ print(f"Local data not found, downloading from HuggingFace...")
88
+ dataset = load_dataset("openai/openai_humaneval")["test"]
89
+
90
+ questions = []
91
+ labels = []
92
+ self.test_cases = []
93
+ self.entry_points = []
94
+
95
+ for idx, q in enumerate(dataset):
96
+ if self.num_samples is not None and idx >= self.num_samples:
97
+ break
98
+
99
+ questions.append({"question": q["prompt"]})
100
+
101
+ # Store test case and entry point for evaluation
102
+ test_code = q.get("test", "")
103
+ entry_point = q.get("entry_point", "")
104
+ self.test_cases.append(test_code)
105
+ self.entry_points.append(entry_point)
106
+
107
+ # Store canonical solution as reference (optional, for comparison)
108
+ canonical_solution = q.get("canonical_solution", "")
109
+ labels.append(
110
+ {
111
+ "test": test_code,
112
+ "entry_point": entry_point,
113
+ "canonical_solution": canonical_solution,
114
+ }
115
+ )
116
+
117
+ return questions, labels
118
+
119
+ def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]:
120
+ """Extract code from model output."""
121
+ return extract_code_from_output(output)
122
+
123
+ def compute_accuracy(
124
+ self, predictions: List[Any], labels: List[Any]
125
+ ) -> Optional[float]:
126
+ """Compute accuracy for HumanEval by checking if code passes tests.
127
+
128
+ Note: This is a simplified evaluation. For official pass@k metrics,
129
+ use the HumanEval evaluation framework.
130
+ """
131
+ if not labels or len(labels) == 0:
132
+ return None
133
+ if all(label is None for label in labels):
134
+ return None
135
+
136
+ correct = 0
137
+ valid_count = 0
138
+
139
+ for i, (pred, label) in enumerate(zip(predictions, labels)):
140
+ if label is not None and isinstance(label, dict):
141
+ valid_count += 1
142
+ if pred is not None:
143
+ try:
144
+ # Get the prompt (function signature and docstring)
145
+ prompt = self.questions[i]["question"]
146
+ entry_point = label.get("entry_point", "")
147
+
148
+ # The prompt contains the function signature (e.g., "def function_name(...):")
149
+ # The generated code might be:
150
+ # 1. Just the function body (what we want) - need to combine with prompt
151
+ # 2. The complete function including signature - use as-is
152
+ # 3. Code in markdown blocks - already extracted by extract_code_from_output
153
+
154
+ pred_str = str(pred).strip()
155
+
156
+ # Check if pred already contains a complete function definition
157
+ # (starts with "def " and contains the entry_point function name)
158
+ if pred_str.startswith("def ") and entry_point:
159
+ # Check if this is the same function (by name)
160
+ func_name_match = re.match(r"def\s+(\w+)\s*\(", pred_str)
161
+ if (
162
+ func_name_match
163
+ and func_name_match.group(1) == entry_point
164
+ ):
165
+ # Generated code includes complete function, use it as-is
166
+ full_code = pred_str
167
+ else:
168
+ # Different function or no match, combine with prompt
169
+ full_code = prompt + "\n" + pred_str
170
+ elif pred_str.startswith("def "):
171
+ # Has function definition but we can't verify entry_point, use as-is
172
+ full_code = pred_str
173
+ else:
174
+ # Generated code is just the body, combine with prompt
175
+ full_code = prompt + "\n" + pred_str
176
+
177
+ # Check if code passes tests
178
+ test_code = label.get("test", "")
179
+
180
+ if test_code and check_code_passes_tests(
181
+ full_code, test_code, entry_point
182
+ ):
183
+ correct += 1
184
+ except Exception as e:
185
+ # If evaluation fails, consider it incorrect
186
+ # Uncomment for debugging: print(f"Error evaluating code {i}: {e}")
187
+ pass
188
+
189
+ return correct / valid_count if valid_count > 0 else 0.0
190
+
191
+ def create_sgl_function(self):
192
+ """Create SGL function for HumanEval."""
193
+ return create_simple_sgl_function(
194
+ function_name="get_humaneval_answer",
195
+ answer_key="answer",
196
+ max_tokens=self.get_max_new_tokens(),
197
+ )
198
+
199
+ def get_max_new_tokens(self) -> int:
200
+ """HumanEval code generation requires more tokens."""
201
+ return 1024
SpecForge-ext/benchmarks/benchmarker/livecodebench.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GSM8K benchmark evaluation script.
3
+ """
4
+
5
+ from typing import Any, Dict, List, Optional, Tuple
6
+
7
+ from datasets import load_dataset
8
+
9
+ from .base import Benchmarker
10
+ from .registry import BENCHMARKS
11
+ from .utils import create_simple_sgl_function
12
+
13
+
14
+ def generate_question(row: Dict[str, Any]) -> str:
15
+ question = row["question_content"].strip()
16
+ return question
17
+
18
+
19
+ @BENCHMARKS.register("livecodebench")
20
+ class LCBBenchmarker(Benchmarker):
21
+ """LiveCodeBench benchmark implementation."""
22
+
23
+ def __init__(self, num_samples: Optional[int] = None):
24
+ super().__init__(num_samples, None)
25
+
26
+ def load_data(self) -> Tuple[List[Dict[str, Any]], List[int]]:
27
+ # Read data
28
+ ds = load_dataset("livecodebench/code_generation")["test"]
29
+
30
+ questions = []
31
+ labels = []
32
+ for i in range((len(ds))):
33
+ if self.num_samples is not None and i >= self.num_samples:
34
+ break
35
+
36
+ question_text = generate_question(ds[i])
37
+ questions.append({"question": question_text})
38
+ labels.append(None)
39
+ return questions, labels
40
+
41
+ def create_sgl_function(self):
42
+ return create_simple_sgl_function(
43
+ function_name="get_livecodebench_answer",
44
+ answer_key="answer",
45
+ max_tokens=self.get_max_new_tokens(),
46
+ )
SpecForge-ext/benchmarks/benchmarker/math500.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MATH-500 benchmark evaluation script.
3
+ """
4
+
5
+ import re
6
+ from typing import Any, Dict, List, Optional, Tuple
7
+
8
+ from datasets import load_dataset
9
+
10
+ from .base import Benchmarker
11
+ from .registry import BENCHMARKS
12
+ from .utils import create_simple_sgl_function
13
+
14
+
15
+ def extract_math_answer(output: str) -> Optional[str]:
16
+ """Extract final answer from math problem solution.
17
+
18
+ Tries to extract answer from \boxed{} format first, then looks for
19
+ the last number in the output.
20
+ """
21
+ # Try to find answer in \boxed{} format
22
+ boxed_pattern = r"\\boxed\{([^}]+)\}"
23
+ match = re.search(boxed_pattern, output)
24
+ if match:
25
+ return match.group(1).strip()
26
+
27
+ # Try to find answer in \boxed format (without braces)
28
+ boxed_pattern2 = r"\\boxed\s+([^\s]+)"
29
+ match = re.search(boxed_pattern2, output)
30
+ if match:
31
+ return match.group(1).strip()
32
+
33
+ # Try to find the last number (could be integer or decimal)
34
+ # Look for patterns like "The answer is 42" or "Answer: 3.14"
35
+ answer_patterns = [
36
+ r"(?:answer|Answer|ANSWER)[\s:]+([-+]?\d*\.?\d+)",
37
+ r"(?:is|equals?|=\s*)([-+]?\d*\.?\d+)\s*$",
38
+ ]
39
+ for pattern in answer_patterns:
40
+ matches = re.findall(pattern, output, re.IGNORECASE)
41
+ if matches:
42
+ return matches[-1].strip()
43
+
44
+ # Fallback: extract the last number in the text
45
+ numbers = re.findall(r"[-+]?\d*\.?\d+", output)
46
+ if numbers:
47
+ return numbers[-1]
48
+
49
+ return None
50
+
51
+
52
+ @BENCHMARKS.register("math500")
53
+ class Math500Benchmarker(Benchmarker):
54
+ """MATH-500 benchmark implementation."""
55
+
56
+ def __init__(self, num_samples: Optional[int] = None):
57
+ super().__init__(num_samples, None)
58
+
59
+ def load_data(self) -> Tuple[List[Dict[str, Any]], List[Optional[str]]]:
60
+ """Load and preprocess MATH-500 dataset."""
61
+ dataset = load_dataset("HuggingFaceH4/MATH-500")["test"]
62
+ questions = []
63
+ labels = []
64
+ for idx, q in enumerate(dataset):
65
+ if self.num_samples is not None and idx >= self.num_samples:
66
+ break
67
+
68
+ questions.append({"question": q["problem"]})
69
+ # Extract answer from solution or answer field
70
+ answer = None
71
+ if "answer" in q:
72
+ answer = str(q["answer"]).strip()
73
+ elif "solution" in q:
74
+ # Try to extract from solution
75
+ answer = extract_math_answer(q["solution"])
76
+ labels.append(answer)
77
+ return questions, labels
78
+
79
+ def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]:
80
+ """Extract answer from model output."""
81
+ return extract_math_answer(output)
82
+
83
+ def compute_accuracy(
84
+ self, predictions: List[Any], labels: List[Any]
85
+ ) -> Optional[float]:
86
+ """Compute accuracy for MATH-500 by comparing answers."""
87
+ if not labels or len(labels) == 0:
88
+ return None
89
+ if all(label is None for label in labels):
90
+ return None
91
+
92
+ correct = 0
93
+ valid_count = 0
94
+ for pred, label in zip(predictions, labels):
95
+ if label is not None:
96
+ valid_count += 1
97
+ if pred is not None:
98
+ # Normalize answers for comparison (remove whitespace, handle different formats)
99
+ pred_normalized = str(pred).strip().lower()
100
+ label_normalized = str(label).strip().lower()
101
+ # Try exact match first
102
+ if pred_normalized == label_normalized:
103
+ correct += 1
104
+ else:
105
+ # Try numeric comparison if both are numbers
106
+ try:
107
+ pred_num = float(pred_normalized)
108
+ label_num = float(label_normalized)
109
+ if abs(pred_num - label_num) < 1e-6:
110
+ correct += 1
111
+ except ValueError:
112
+ pass
113
+
114
+ return correct / valid_count if valid_count > 0 else 0.0
115
+
116
+ def create_sgl_function(self):
117
+ """Create SGL function for MATH-500."""
118
+ return create_simple_sgl_function(
119
+ function_name="get_math500_answer",
120
+ answer_key="answer",
121
+ max_tokens=self.get_max_new_tokens(),
122
+ )
SpecForge-ext/benchmarks/benchmarker/mmlu.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional, Tuple
2
+
3
+ from datasets import load_dataset
4
+
5
+ from .base import Benchmarker
6
+ from .registry import BENCHMARKS
7
+ from .utils import create_simple_sgl_function
8
+
9
+ GPQA_QUERY_TEMPLATE = """
10
+ Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.
11
+
12
+ {Question}
13
+
14
+ A) {A}
15
+ B) {B}
16
+ C) {C}
17
+ D) {D}
18
+ """.strip()
19
+
20
+
21
+ def generate_question(row: Dict[str, Any]) -> str:
22
+ choices = row["choices"]
23
+ question = GPQA_QUERY_TEMPLATE.format(
24
+ Question=row["question"].strip(),
25
+ A=choices[0].strip(),
26
+ B=choices[1].strip(),
27
+ C=choices[2].strip(),
28
+ D=choices[3].strip(),
29
+ )
30
+
31
+ # 0 means A, 1 means B, 2 means C, 3 means D
32
+ answer = ["A", "B", "C", "D"][row["answer"]]
33
+ print(answer)
34
+ return question, answer
35
+
36
+
37
+ @BENCHMARKS.register("mmlu")
38
+ class MMLUBenchmarker(Benchmarker):
39
+ """MMLU benchmark implementation."""
40
+
41
+ def __init__(
42
+ self, num_samples: Optional[int] = None, subset: Optional[List[str]] = None
43
+ ):
44
+ if subset is None:
45
+ subset = ["all"]
46
+ super().__init__(num_samples, subset)
47
+
48
+ def load_data(self) -> Tuple[List[Dict[str, Any]], List[int]]:
49
+ # Read data
50
+ questions = []
51
+ labels = []
52
+
53
+ for subset in self.subset:
54
+ ds = load_dataset("cais/mmlu", subset)["test"]
55
+ for i in range((len(ds))):
56
+ if self.num_samples is not None and i >= self.num_samples:
57
+ break
58
+
59
+ question_text, answer = generate_question(ds[i])
60
+ questions.append({"question": question_text})
61
+ labels.append(answer)
62
+ return questions, labels
63
+
64
+ def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[int]:
65
+ if "Answer: " not in output:
66
+ return None
67
+ return output.split("Answer: ")[1].strip()
68
+
69
+ def compute_accuracy(
70
+ self, predictions: List[Any], labels: List[Any]
71
+ ) -> Optional[float]:
72
+ if not labels or len(labels) == 0:
73
+ return None
74
+ correct = sum(1 for pred, label in zip(predictions, labels) if pred == label)
75
+ return correct / len(labels) if len(labels) > 0 else 0.0
76
+
77
+ def create_sgl_function(self):
78
+ return create_simple_sgl_function(
79
+ function_name="get_mmlu_answer",
80
+ answer_key="answer",
81
+ max_tokens=self.get_max_new_tokens(),
82
+ )
SpecForge-ext/benchmarks/benchmarker/mmstar.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MMStar benchmark evaluation script.
3
+ """
4
+
5
+ import os
6
+ import re
7
+ import shutil
8
+ from typing import Any, Dict, List, Optional, Tuple
9
+
10
+ from datasets import load_dataset
11
+
12
+ from .base import Benchmarker
13
+ from .registry import BENCHMARKS
14
+ from .utils import create_image_sgl_function
15
+
16
+
17
def extract_mmstar_answer(
    output: str, options: Optional[List[str]] = None
) -> Optional[str]:
    """Extract a multiple-choice answer letter from MMStar model output.

    Args:
        output: Raw model output text.
        options: Option texts for this question; when given, the valid letter
            range is "A" through chr(ord("A") + len(options) - 1). Otherwise
            "A"-"D" is assumed.

    Returns:
        The extracted answer letter, or None if no valid letter is found.
    """
    output_upper = output.strip().upper()

    def _is_valid(letter: str) -> bool:
        # A letter is accepted only if it falls inside the option range.
        if options:
            return "A" <= letter <= chr(64 + len(options))  # 'A' + (len-1)
        return "A" <= letter <= "D"

    # First try a standalone single letter (covers "B", "(B)", "B.", etc.).
    match = re.search(r"\b([A-Z])\b", output_upper)
    if match and _is_valid(match.group(1)):
        return match.group(1)

    # Fall back to explicit answer markers. BUGFIX: these patterns are
    # matched against the upper-cased output, so the Latin keywords must be
    # upper-case too — the previous mixed-case "Answer[::]" pattern could
    # never match and the fallback was dead code.
    for pattern in [
        r"\(([A-Z])\)",
        r"\[([A-Z])\]",
        r"答案[::]\s*([A-Z])",
        r"ANSWER[::]\s*([A-Z])",
        r"选择[::]\s*([A-Z])",
    ]:
        match = re.search(pattern, output_upper)
        if match and _is_valid(match.group(1)):
            return match.group(1)

    return None
60
+
61
+
62
@BENCHMARKS.register("mmstar")
class MMStarBenchmarker(Benchmarker):
    """MMStar multimodal multiple-choice benchmark implementation."""

    def __init__(self, num_samples: Optional[int] = None):
        """Initialize benchmark state; the image cache dir is created in load_data."""
        super().__init__(num_samples, None)
        # Temporary directory holding decoded images; cleaned up in run().
        self.cache_dir = None
        # Parsed option texts per question, parallel to the questions list.
        self.options_list = []

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[Optional[str]]]:
        """Load and preprocess the MMStar validation split.

        Decodes each sample's image to a JPEG under a temporary cache
        directory, splits the question text from its "Options:" section, and
        records the ground-truth answer letter (None when absent or invalid).
        """
        self.cache_dir = os.path.join(".cache", "mmstar_specforge")
        image_dir = os.path.join(self.cache_dir, "images")
        os.makedirs(self.cache_dir, exist_ok=True)
        # NOTE(review): images are saved relative to cache_dir using
        # meta_info["image_path"] — presumably that relative path already
        # starts with "images/"; confirm against the dataset schema.
        os.makedirs(image_dir, exist_ok=True)
        print(f"Created temporary image directory: {self.cache_dir}")

        dataset = load_dataset("Lin-Chen/MMStar")["val"]
        questions = []
        labels = []
        self.options_list = []

        for idx, q in enumerate(dataset):
            # Respect the per-run sample cap, if any.
            if self.num_samples is not None and idx >= self.num_samples:
                break

            image = q["image"]
            image_path = os.path.join(self.cache_dir, q["meta_info"]["image_path"])
            image.convert("RGB").save(image_path, "JPEG")

            # Split question text from its options block, if present.
            question_full = q["question"]
            if "Options:" in question_full:
                question_text, options_text = question_full.split("Options:", 1)
                question_text = question_text.strip()
                # Parse options formatted one per line as "A. option text".
                options = []
                for line in options_text.strip().split("\n"):
                    line = line.strip()
                    if line and re.match(r"^[A-Z]\.", line):
                        option_text = re.sub(r"^[A-Z]\.\s*", "", line).strip()
                        options.append(option_text)
                self.options_list.append(options)
            else:
                question_text = question_full.strip()
                self.options_list.append([])

            item = {
                "image_path": image_path,
                "question": question_text,
            }
            questions.append(item)

            # Extract the ground-truth answer from whichever field is present.
            answer = None
            if "answer" in q:
                answer = str(q["answer"]).strip().upper()
            elif "correct_answer" in q:
                answer = str(q["correct_answer"]).strip().upper()
            elif "ground_truth" in q:
                answer = str(q["ground_truth"]).strip().upper()

            # Keep the label only if it is a single letter within the option range.
            if answer and len(answer) == 1 and "A" <= answer <= "Z":
                if self.options_list[-1]:
                    max_option = chr(64 + len(self.options_list[-1]))
                    if answer <= max_option:
                        labels.append(answer)
                    else:
                        labels.append(None)
                else:
                    labels.append(answer)
            else:
                labels.append(None)

        return questions, labels

    def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]:
        """Extract an answer letter from model output.

        NOTE: the per-question options are not available here (there is no
        question index), so extraction assumes the default A-D range.
        """
        return extract_mmstar_answer(output)

    def compute_accuracy(
        self, predictions: List[Any], labels: List[Any]
    ) -> Optional[float]:
        """Compute accuracy over the samples that have a ground-truth label.

        Returns None when there are no labels at all (nothing to score).
        """
        if not labels or len(labels) == 0:
            return None
        if all(label is None for label in labels):
            return None

        correct = 0
        valid_count = 0
        for pred, label in zip(predictions, labels):
            if label is not None:
                valid_count += 1
                if pred is not None:
                    # Normalize both sides to uppercase for comparison.
                    if str(pred).strip().upper() == str(label).strip().upper():
                        correct += 1

        return correct / valid_count if valid_count > 0 else 0.0

    def create_sgl_function(self):
        """Create SGL function for MMStar (image-based Q&A)."""
        return create_image_sgl_function(
            function_name="get_mmstar_answer",
            answer_key="answer",
            max_tokens=self.get_max_new_tokens(),
        )

    def run(self, *args, **kwargs):
        """Run the benchmark, then always remove the temporary image directory."""
        try:
            return super().run(*args, **kwargs)
        finally:
            if self.cache_dir and os.path.exists(self.cache_dir):
                shutil.rmtree(self.cache_dir)
                print(f"Deleted temporary directory: {self.cache_dir}")
SpecForge-ext/benchmarks/benchmarker/mtbench.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MT-Bench benchmark evaluation script.
3
+ Adapted from https://github.com/chromecast56/sglang/blob/6f145d2eadb93a116134f703358ce76f15381045/benchmark/mtbench/bench_sglang.py
4
+ """
5
+
6
+ import os
7
+ from typing import Any, Dict, List, Optional, Tuple
8
+
9
+ from sglang.utils import download_and_cache_file, read_jsonl
10
+
11
+ from .base import Benchmarker
12
+ from .registry import BENCHMARKS
13
+ from .utils import create_multi_turn_sgl_function
14
+
15
+ SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
16
+
17
+
18
@BENCHMARKS.register("mtbench")
class MTBenchBenchmarker(Benchmarker):
    """MT-Bench benchmark implementation (2-turn open-ended conversations)."""

    # Site-specific local copy of the question file; falls back to download.
    # TODO(review): make this configurable instead of a hard-coded user path.
    _LOCAL_QUESTION_PATH = "/workspace/hanrui/datasets/mtbench/question.jsonl"

    def __init__(
        self, num_samples: Optional[int] = None, subset: Optional[List[str]] = None
    ):
        # Support categorical data for MT-Bench; default to all categories.
        if subset is None:
            subset = ["all"]
        super().__init__(num_samples, subset)

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[None]]:
        """Load MT-Bench questions, preferring the local copy over downloading.

        Returns:
            Tuple of (two-turn question dicts, all-None labels — MT-Bench is
            open-ended and has no ground truth for accuracy).
        """
        local_path = self._LOCAL_QUESTION_PATH
        if os.path.exists(local_path):
            print(f"Loading MT-Bench data from local: {local_path}")
            questions_data = list(read_jsonl(local_path))
        else:
            # Local copy missing: fetch the official question file.
            print("Local data not found, downloading from GitHub...")
            url = "https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl"
            download_and_cache_file(url, filename="mtbench.jsonl")
            questions_data = list(read_jsonl("mtbench.jsonl"))

        questions = [
            {"question_1": q["turns"][0], "question_2": q["turns"][1]}
            for q in questions_data
        ]
        # MT-Bench doesn't have labels for accuracy computation.
        labels: List[None] = [None] * len(questions)

        if self.num_samples is not None:
            questions = questions[: self.num_samples]
            labels = labels[: self.num_samples]
        return questions, labels

    def create_sgl_function(self):
        """Create SGL function for MT-Bench (2-turn conversation)."""
        return create_multi_turn_sgl_function(
            function_name="answer_mt_bench",
            system_prompt=SYSTEM_PROMPT,
            num_turns=2,
            max_tokens=self.get_max_new_tokens(),
        )

    def get_answer_keys(self) -> List[str]:
        """Return the per-turn answer keys for the two-turn conversation."""
        return ["answer_1", "answer_2"]