Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- SpecForge-ext/benchmarks/README.md +67 -0
- SpecForge-ext/benchmarks/__init__.py +3 -0
- SpecForge-ext/benchmarks/bench_eagle3.py +268 -0
- SpecForge-ext/benchmarks/benchmarker/__init__.py +29 -0
- SpecForge-ext/benchmarks/benchmarker/__pycache__/__init__.cpython-310.pyc +0 -0
- SpecForge-ext/benchmarks/benchmarker/__pycache__/__init__.cpython-311.pyc +0 -0
- SpecForge-ext/benchmarks/benchmarker/__pycache__/__init__.cpython-312.pyc +0 -0
- SpecForge-ext/benchmarks/benchmarker/__pycache__/aime.cpython-310.pyc +0 -0
- SpecForge-ext/benchmarks/benchmarker/__pycache__/aime.cpython-311.pyc +0 -0
- SpecForge-ext/benchmarks/benchmarker/__pycache__/aime.cpython-312.pyc +0 -0
- SpecForge-ext/benchmarks/benchmarker/__pycache__/base.cpython-310.pyc +0 -0
- SpecForge-ext/benchmarks/benchmarker/__pycache__/base.cpython-311.pyc +0 -0
- SpecForge-ext/benchmarks/benchmarker/__pycache__/base.cpython-312.pyc +0 -0
- SpecForge-ext/benchmarks/benchmarker/__pycache__/ceval.cpython-310.pyc +0 -0
- SpecForge-ext/benchmarks/benchmarker/__pycache__/ceval.cpython-311.pyc +0 -0
- SpecForge-ext/benchmarks/benchmarker/__pycache__/financeqa.cpython-310.pyc +0 -0
- SpecForge-ext/benchmarks/benchmarker/__pycache__/financeqa.cpython-311.pyc +0 -0
- SpecForge-ext/benchmarks/benchmarker/__pycache__/gpqa.cpython-310.pyc +0 -0
- SpecForge-ext/benchmarks/benchmarker/__pycache__/gpqa.cpython-311.pyc +0 -0
- SpecForge-ext/benchmarks/benchmarker/__pycache__/gsm8k.cpython-310.pyc +0 -0
- SpecForge-ext/benchmarks/benchmarker/__pycache__/gsm8k.cpython-311.pyc +0 -0
- SpecForge-ext/benchmarks/benchmarker/__pycache__/humaneval.cpython-310.pyc +0 -0
- SpecForge-ext/benchmarks/benchmarker/__pycache__/humaneval.cpython-311.pyc +0 -0
- SpecForge-ext/benchmarks/benchmarker/__pycache__/livecodebench.cpython-310.pyc +0 -0
- SpecForge-ext/benchmarks/benchmarker/__pycache__/math500.cpython-310.pyc +0 -0
- SpecForge-ext/benchmarks/benchmarker/__pycache__/math500.cpython-311.pyc +0 -0
- SpecForge-ext/benchmarks/benchmarker/__pycache__/mmlu.cpython-310.pyc +0 -0
- SpecForge-ext/benchmarks/benchmarker/__pycache__/mmlu.cpython-311.pyc +0 -0
- SpecForge-ext/benchmarks/benchmarker/__pycache__/mmstar.cpython-310.pyc +0 -0
- SpecForge-ext/benchmarks/benchmarker/__pycache__/mmstar.cpython-311.pyc +0 -0
- SpecForge-ext/benchmarks/benchmarker/__pycache__/mtbench.cpython-310.pyc +0 -0
- SpecForge-ext/benchmarks/benchmarker/__pycache__/mtbench.cpython-311.pyc +0 -0
- SpecForge-ext/benchmarks/benchmarker/__pycache__/registry.cpython-310.pyc +0 -0
- SpecForge-ext/benchmarks/benchmarker/__pycache__/registry.cpython-311.pyc +0 -0
- SpecForge-ext/benchmarks/benchmarker/__pycache__/simpleqa.cpython-310.pyc +0 -0
- SpecForge-ext/benchmarks/benchmarker/__pycache__/simpleqa.cpython-311.pyc +0 -0
- SpecForge-ext/benchmarks/benchmarker/__pycache__/utils.cpython-310.pyc +0 -0
- SpecForge-ext/benchmarks/benchmarker/__pycache__/utils.cpython-311.pyc +0 -0
- SpecForge-ext/benchmarks/benchmarker/aime.py +133 -0
- SpecForge-ext/benchmarks/benchmarker/base.py +218 -0
- SpecForge-ext/benchmarks/benchmarker/ceval.py +267 -0
- SpecForge-ext/benchmarks/benchmarker/financeqa.py +59 -0
- SpecForge-ext/benchmarks/benchmarker/gpqa.py +85 -0
- SpecForge-ext/benchmarks/benchmarker/gsm8k.py +108 -0
- SpecForge-ext/benchmarks/benchmarker/humaneval.py +201 -0
- SpecForge-ext/benchmarks/benchmarker/livecodebench.py +46 -0
- SpecForge-ext/benchmarks/benchmarker/math500.py +122 -0
- SpecForge-ext/benchmarks/benchmarker/mmlu.py +82 -0
- SpecForge-ext/benchmarks/benchmarker/mmstar.py +185 -0
- SpecForge-ext/benchmarks/benchmarker/mtbench.py +70 -0
SpecForge-ext/benchmarks/README.md
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Benchmarking for Speculative Decoding
|
| 2 |
+
|
| 3 |
+
## Overview
|
| 4 |
+
|
| 5 |
+
We provided a unified script to test the performance of the Speculative Decoding with EAGLE3 algorithm on multiple datasets. You can follow the steps below to run the benchmarks.
|
| 6 |
+
|
| 7 |
+
## Run Benchmarks
|
| 8 |
+
|
| 9 |
+
### Launch SGLang and Benchmarker Concurrently
|
| 10 |
+
|
| 11 |
+
`bench_eagle3.py` can help you launch a SGLang server process and a Benchmarking process concurrently. In this way, you don't have to launch the SGLang server manually, this script will manually handle the SGLang launch under different speculative decoding configurations. Some important arguments are:
|
| 12 |
+
- `--model-path`: the path to the target model.
|
| 13 |
+
- `--speculative-draft-model-path`: the path to the draft model.
|
| 14 |
+
- `--port`: the port to launch the SGLang server.
|
| 15 |
+
- `--trust-remote-code`: trust the remote code.
|
| 16 |
+
- `--mem-fraction-static`: the memory fraction for the static memory.
|
| 17 |
+
- `--tp-size`: the tensor parallelism size.
|
| 18 |
+
- `--attention-backend`: the attention backend.
|
| 19 |
+
- `--config-list`: the list of speculative decoding configuration to test, the format is `<batch-size>,<num-steps>,<topk>,<num-draft-tokens>`.
|
| 20 |
+
- `--benchmark-list`: the list of benchmarks to test, the format is `<benchmark-name>:<num-prompts>:<subset>`.
|
| 21 |
+
|
| 22 |
+
```shell
|
| 23 |
+
python3 bench_eagle3.py \
|
| 24 |
+
--model-path meta-llama/Llama-3.1-8B-Instruct \
|
| 25 |
+
--speculative-draft-model-path lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B \
|
| 26 |
+
--port 30000 \
|
| 27 |
+
--trust-remote-code \
|
| 28 |
+
--mem-fraction-static 0.8 \
|
| 29 |
+
--tp-size 1 \
|
| 30 |
+
--attention-backend fa3 \
|
| 31 |
+
--config-list 1,0,0,0 1,3,1,4 \
|
| 32 |
+
--benchmark-list mtbench gsm8k:5 ceval:5:accountant \
|
| 33 |
+
--dtype bfloat16
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
### Launch Benchmarker Independently
|
| 37 |
+
|
| 38 |
+
If you want to launch the SGLang server independently, you can use the following command.
|
| 39 |
+
|
| 40 |
+
```shell
|
| 41 |
+
# you can launch a server
|
| 42 |
+
python3 -m sglang.launch_server \
|
| 43 |
+
--model meta-llama/Llama-3.1-8B-Instruct \
|
| 44 |
+
--speculative-algorithm EAGLE3 \
|
| 45 |
+
--speculative-draft-model-path lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B \
|
| 46 |
+
--speculative-num-steps 3 \
|
| 47 |
+
--speculative-eagle-topk 1 \
|
| 48 |
+
--speculative-num-draft-tokens 4 \
|
| 49 |
+
--mem-fraction-static 0.75 \
|
| 50 |
+
--cuda-graph-max-bs 1 \
|
| 51 |
+
--tp 1 \
|
| 52 |
+
--trust-remote-code \
|
| 53 |
+
--host 0.0.0.0 \
|
| 54 |
+
--port 30000 \
|
| 55 |
+
--dtype bfloat16
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
Then we can start benchmarking. Note that you should use the same host and port as the one used in the SGLang server. Note that `--skip-launch-server` is required to skip the launch of the SGLang server.
|
| 59 |
+
|
| 60 |
+
```bash
|
| 61 |
+
python bench_eagle3.py \
|
| 62 |
+
--model-path meta-llama/Llama-3.1-8B-Instruct \
|
| 63 |
+
--port 30000 \
|
| 64 |
+
--config-list 1,3,1,4 \
|
| 65 |
+
--benchmark-list mtbench:5 ceval:5:accountant gsm8k:5 humaneval:5 math500:5 mtbench:5 aime:1 \
|
| 66 |
+
--skip-launch-server
|
| 67 |
+
```
|
SpecForge-ext/benchmarks/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Benchmark scripts for speculative decoding evaluation.
|
| 3 |
+
"""
|
SpecForge-ext/benchmarks/bench_eagle3.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Usage:
|
| 4 |
+
|
| 5 |
+
# if you want to run benchmarks directly
|
| 6 |
+
# mtbench:20 means only run 20 samples in the dataset
|
| 7 |
+
python bench_eagle3.py \
|
| 8 |
+
--model meta-llama/Llama-3.1-8B-Instruct \
|
| 9 |
+
--speculative-algorithm EAGLE3 \
|
| 10 |
+
--speculative-draft-model-path lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B \
|
| 11 |
+
--port 30000 \
|
| 12 |
+
--config-list 1,0,0,0 1,3,1,4 \
|
| 13 |
+
--benchmark-list mtbench:20 \
|
| 14 |
+
--dtype bfloat16
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
or if you want run sglang alone.
|
| 18 |
+
|
| 19 |
+
# launch sglang
|
| 20 |
+
python3 -m sglang.launch_server \
|
| 21 |
+
--model meta-llama/Llama-3.1-8B-Instruct \
|
| 22 |
+
--speculative-algorithm EAGLE3 \
|
| 23 |
+
--speculative-draft-model-path lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B \
|
| 24 |
+
--speculative-num-steps 3 \
|
| 25 |
+
--speculative-eagle-topk 1 \
|
| 26 |
+
--speculative-num-draft-tokens 4 \
|
| 27 |
+
--mem-fraction-static 0.75 \
|
| 28 |
+
--cuda-graph-max-bs 1 \
|
| 29 |
+
--tp 1 \
|
| 30 |
+
--trust-remote-code \
|
| 31 |
+
--host 0.0.0.0 \
|
| 32 |
+
--port 30000 \
|
| 33 |
+
--dtype bfloat16
|
| 34 |
+
|
| 35 |
+
# then run benchmarks
|
| 36 |
+
python bench_eagle3.py \
|
| 37 |
+
--model-path meta-llama/Llama-3.1-8B-Instruct \
|
| 38 |
+
--port 30000 \
|
| 39 |
+
--config-list 1,0,0,0 \
|
| 40 |
+
--benchmark-list mtbench:80 \
|
| 41 |
+
--dtype bfloat16 \
|
| 42 |
+
--skip-launch-server
|
| 43 |
+
"""
|
| 44 |
+
import argparse
|
| 45 |
+
import json
|
| 46 |
+
import os
|
| 47 |
+
import time
|
| 48 |
+
from dataclasses import asdict
|
| 49 |
+
from typing import List
|
| 50 |
+
|
| 51 |
+
import requests
|
| 52 |
+
from benchmarker import BENCHMARKS
|
| 53 |
+
from sglang.srt.server_args import ServerArgs
|
| 54 |
+
from sglang.test.test_utils import kill_process_tree, popen_launch_server
|
| 55 |
+
from sglang.utils import wait_for_server
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def parse_args():
|
| 59 |
+
parser = argparse.ArgumentParser()
|
| 60 |
+
sglang_group = parser.add_argument_group("sglang")
|
| 61 |
+
ServerArgs.add_cli_args(sglang_group)
|
| 62 |
+
|
| 63 |
+
# make the follow args a group
|
| 64 |
+
benchmark_group = parser.add_argument_group("benchmark")
|
| 65 |
+
benchmark_group.add_argument(
|
| 66 |
+
"--skip-launch-server", action="store_true", default=False
|
| 67 |
+
)
|
| 68 |
+
benchmark_group.add_argument("--timeout-for-server-launch", type=int, default=600)
|
| 69 |
+
benchmark_group.add_argument("--num-prompts", type=int, default=80)
|
| 70 |
+
benchmark_group.add_argument("--output-dir", type=str, default="./results")
|
| 71 |
+
benchmark_group.add_argument(
|
| 72 |
+
"--config-list", type=str, nargs="+", default=["1,0,0,0", "1,3,1,4"]
|
| 73 |
+
)
|
| 74 |
+
benchmark_group.add_argument(
|
| 75 |
+
"--name",
|
| 76 |
+
type=str,
|
| 77 |
+
default=None,
|
| 78 |
+
help="name of this benchmark run, if provided, will be added to the output file name",
|
| 79 |
+
)
|
| 80 |
+
benchmark_group.add_argument(
|
| 81 |
+
"--benchmark-list",
|
| 82 |
+
type=str,
|
| 83 |
+
nargs="+",
|
| 84 |
+
default=[
|
| 85 |
+
"mtbench:80",
|
| 86 |
+
"gsm8k:200",
|
| 87 |
+
"humaneval:200",
|
| 88 |
+
"math500:200",
|
| 89 |
+
"ceval:200",
|
| 90 |
+
],
|
| 91 |
+
help=f"The list of benchmarks to run. The format is <benchmark-name>:<num-prompts>:<subset>,<subset>. We support the following benchmarks: {', '.join(BENCHMARKS.benchmarks.keys())}",
|
| 92 |
+
)
|
| 93 |
+
benchmark_group.add_argument(
|
| 94 |
+
"--enable-multi-turn-conversation",
|
| 95 |
+
action="store_true",
|
| 96 |
+
default=False,
|
| 97 |
+
)
|
| 98 |
+
return parser.parse_args()
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def launch_sglang_server(
|
| 102 |
+
server_args: ServerArgs,
|
| 103 |
+
base_url: str,
|
| 104 |
+
batch_size: int,
|
| 105 |
+
steps: int,
|
| 106 |
+
topk: int,
|
| 107 |
+
num_draft_tokens: int,
|
| 108 |
+
timeout: int,
|
| 109 |
+
):
|
| 110 |
+
"""
|
| 111 |
+
This function launches the SGLang server with the given server arguments.
|
| 112 |
+
"""
|
| 113 |
+
sglang_args: List[str] = []
|
| 114 |
+
if steps > 0:
|
| 115 |
+
sglang_args.extend(
|
| 116 |
+
[
|
| 117 |
+
"--speculative-algorithm",
|
| 118 |
+
"EAGLE3",
|
| 119 |
+
"--speculative-num-steps",
|
| 120 |
+
str(steps),
|
| 121 |
+
"--speculative-eagle-topk",
|
| 122 |
+
str(topk),
|
| 123 |
+
"--speculative-num-draft-tokens",
|
| 124 |
+
str(num_draft_tokens),
|
| 125 |
+
"--speculative-draft-model-path",
|
| 126 |
+
server_args.speculative_draft_model_path,
|
| 127 |
+
]
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
sglang_args.extend(
|
| 131 |
+
[
|
| 132 |
+
"--cuda-graph-max-bs",
|
| 133 |
+
str(batch_size),
|
| 134 |
+
"--mem-fraction-static",
|
| 135 |
+
str(server_args.mem_fraction_static),
|
| 136 |
+
"--tp-size",
|
| 137 |
+
str(server_args.tp_size),
|
| 138 |
+
"--max-running-requests",
|
| 139 |
+
str(batch_size),
|
| 140 |
+
]
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
if server_args.trust_remote_code:
|
| 144 |
+
sglang_args.extend(["--trust-remote-code"])
|
| 145 |
+
|
| 146 |
+
if server_args.disable_radix_cache:
|
| 147 |
+
sglang_args.extend(["--disable-radix-cache"])
|
| 148 |
+
|
| 149 |
+
if server_args.ep_size:
|
| 150 |
+
sglang_args.extend(["--ep-size", str(server_args.ep_size)])
|
| 151 |
+
|
| 152 |
+
if server_args.attention_backend:
|
| 153 |
+
sglang_args.extend(["--attention-backend", server_args.attention_backend])
|
| 154 |
+
|
| 155 |
+
if server_args.quantization:
|
| 156 |
+
sglang_args.extend(["--quantization", server_args.quantization])
|
| 157 |
+
|
| 158 |
+
if server_args.dtype:
|
| 159 |
+
sglang_args.extend(["--dtype", server_args.dtype])
|
| 160 |
+
|
| 161 |
+
process = popen_launch_server(
|
| 162 |
+
server_args.model_path,
|
| 163 |
+
base_url,
|
| 164 |
+
timeout=timeout,
|
| 165 |
+
other_args=sglang_args,
|
| 166 |
+
env={
|
| 167 |
+
"SGLANG_RECORD_STEP_TIME": "1",
|
| 168 |
+
"SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN": "1",
|
| 169 |
+
**os.environ,
|
| 170 |
+
},
|
| 171 |
+
)
|
| 172 |
+
return process
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def send_flush_cache_request(base_url: str):
|
| 176 |
+
requests.post(base_url + "/flush_cache")
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def main():
|
| 180 |
+
args = parse_args()
|
| 181 |
+
server_args: ServerArgs = ServerArgs.from_cli_args(args)
|
| 182 |
+
configs = [tuple(map(int, config.split(","))) for config in args.config_list]
|
| 183 |
+
|
| 184 |
+
# split the arg into list of (bench_name, num_prompts)
|
| 185 |
+
benchmark_list = []
|
| 186 |
+
for item in args.benchmark_list:
|
| 187 |
+
splits = item.split(":")
|
| 188 |
+
if len(splits) == 1:
|
| 189 |
+
bench_name = splits[0]
|
| 190 |
+
num_prompts = None
|
| 191 |
+
subset = None
|
| 192 |
+
elif len(splits) == 2:
|
| 193 |
+
bench_name, num_prompts = splits
|
| 194 |
+
subset = None
|
| 195 |
+
elif len(splits) == 3:
|
| 196 |
+
bench_name, num_prompts, subset = splits
|
| 197 |
+
subset = subset.split(",")
|
| 198 |
+
else:
|
| 199 |
+
raise ValueError(f"Invalid benchmark list format: {item}")
|
| 200 |
+
benchmark_list.append((bench_name, num_prompts, subset))
|
| 201 |
+
assert len(benchmark_list) != 0, "the number of benchmark list is 0"
|
| 202 |
+
|
| 203 |
+
base_url = f"http://localhost:{args.port}"
|
| 204 |
+
|
| 205 |
+
results = {}
|
| 206 |
+
results["model"] = server_args.speculative_draft_model_path
|
| 207 |
+
|
| 208 |
+
def run_benchmarks(batch_size: int, steps: int, topk: int, num_draft_tokens: int):
|
| 209 |
+
for benchmark_name, num_prompts, subset in benchmark_list:
|
| 210 |
+
print(
|
| 211 |
+
f"Running benchmark {benchmark_name} with {num_prompts} prompts, batch size {batch_size}, steps {steps}, topk {topk}, num_draft_tokens {num_draft_tokens}, subset {subset}"
|
| 212 |
+
)
|
| 213 |
+
benchmarkder_cls = BENCHMARKS.get(benchmark_name)
|
| 214 |
+
num_prompts = int(num_prompts) if num_prompts is not None else None
|
| 215 |
+
if subset is None:
|
| 216 |
+
benchmarker = benchmarkder_cls(num_samples=num_prompts)
|
| 217 |
+
else:
|
| 218 |
+
benchmarker = benchmarkder_cls(num_samples=num_prompts, subset=subset)
|
| 219 |
+
metrics_list = benchmarker.run(
|
| 220 |
+
host=args.host, port=args.port, batch_size=batch_size
|
| 221 |
+
)
|
| 222 |
+
send_flush_cache_request(base_url)
|
| 223 |
+
if benchmark_name not in results:
|
| 224 |
+
results[benchmark_name] = []
|
| 225 |
+
results[benchmark_name].append(
|
| 226 |
+
dict(
|
| 227 |
+
batch_size=batch_size,
|
| 228 |
+
steps=steps,
|
| 229 |
+
topk=topk,
|
| 230 |
+
num_draft_tokens=num_draft_tokens,
|
| 231 |
+
metrics=[asdict(metric) for metric in metrics_list],
|
| 232 |
+
num_samples=num_prompts,
|
| 233 |
+
)
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
if args.skip_launch_server:
|
| 237 |
+
batch_size = configs[0][0] if len(configs) > 0 else 8
|
| 238 |
+
run_benchmarks(batch_size, None, None, None)
|
| 239 |
+
else:
|
| 240 |
+
# we itearate over each config from args
|
| 241 |
+
for batch_size, steps, topk, num_draft_tokens in configs:
|
| 242 |
+
process = launch_sglang_server(
|
| 243 |
+
server_args,
|
| 244 |
+
base_url,
|
| 245 |
+
batch_size,
|
| 246 |
+
steps,
|
| 247 |
+
topk,
|
| 248 |
+
num_draft_tokens,
|
| 249 |
+
args.timeout_for_server_launch,
|
| 250 |
+
)
|
| 251 |
+
wait_for_server(base_url)
|
| 252 |
+
run_benchmarks(batch_size, steps, topk, num_draft_tokens)
|
| 253 |
+
kill_process_tree(process.pid)
|
| 254 |
+
process.wait()
|
| 255 |
+
|
| 256 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 257 |
+
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
| 258 |
+
result_file = os.path.join(
|
| 259 |
+
args.output_dir,
|
| 260 |
+
f"{args.name + '_' if args.name else ''}results_{timestamp}.jsonl",
|
| 261 |
+
)
|
| 262 |
+
with open(result_file, "w") as f:
|
| 263 |
+
json.dump(results, f, indent=4)
|
| 264 |
+
print(f"Results saved to {result_file}")
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
if __name__ == "__main__":
|
| 268 |
+
main()
|
SpecForge-ext/benchmarks/benchmarker/__init__.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .aime import AIMEBenchmarker
|
| 2 |
+
from .ceval import CEvalBenchmarker
|
| 3 |
+
from .financeqa import FinanceQABenchmarker
|
| 4 |
+
from .gpqa import GPQABenchmarker
|
| 5 |
+
from .gsm8k import GSM8KBenchmarker
|
| 6 |
+
from .humaneval import HumanEvalBenchmarker
|
| 7 |
+
from .livecodebench import LCBBenchmarker
|
| 8 |
+
from .math500 import Math500Benchmarker
|
| 9 |
+
from .mmlu import MMLUBenchmarker
|
| 10 |
+
from .mmstar import MMStarBenchmarker
|
| 11 |
+
from .mtbench import MTBenchBenchmarker
|
| 12 |
+
from .registry import BENCHMARKS
|
| 13 |
+
from .simpleqa import SimpleQABenchmarker
|
| 14 |
+
|
| 15 |
+
__all__ = [
|
| 16 |
+
"BENCHMARKS",
|
| 17 |
+
"AIMEBenchmarker",
|
| 18 |
+
"CEvalBenchmarker",
|
| 19 |
+
"GSM8KBenchmarker",
|
| 20 |
+
"HumanEvalBenchmarker",
|
| 21 |
+
"Math500Benchmarker",
|
| 22 |
+
"MTBenchBenchmarker",
|
| 23 |
+
"MMStarBenchmarker",
|
| 24 |
+
"GPQABenchmarker",
|
| 25 |
+
"FinanceQABenchmarker",
|
| 26 |
+
"MMLUBenchmarker",
|
| 27 |
+
"LCBBenchmarker",
|
| 28 |
+
"SimpleQABenchmarker",
|
| 29 |
+
]
|
SpecForge-ext/benchmarks/benchmarker/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (872 Bytes). View file
|
|
|
SpecForge-ext/benchmarks/benchmarker/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (1.11 kB). View file
|
|
|
SpecForge-ext/benchmarks/benchmarker/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (896 Bytes). View file
|
|
|
SpecForge-ext/benchmarks/benchmarker/__pycache__/aime.cpython-310.pyc
ADDED
|
Binary file (4.16 kB). View file
|
|
|
SpecForge-ext/benchmarks/benchmarker/__pycache__/aime.cpython-311.pyc
ADDED
|
Binary file (6.79 kB). View file
|
|
|
SpecForge-ext/benchmarks/benchmarker/__pycache__/aime.cpython-312.pyc
ADDED
|
Binary file (5.8 kB). View file
|
|
|
SpecForge-ext/benchmarks/benchmarker/__pycache__/base.cpython-310.pyc
ADDED
|
Binary file (6.47 kB). View file
|
|
|
SpecForge-ext/benchmarks/benchmarker/__pycache__/base.cpython-311.pyc
ADDED
|
Binary file (9.05 kB). View file
|
|
|
SpecForge-ext/benchmarks/benchmarker/__pycache__/base.cpython-312.pyc
ADDED
|
Binary file (8.11 kB). View file
|
|
|
SpecForge-ext/benchmarks/benchmarker/__pycache__/ceval.cpython-310.pyc
ADDED
|
Binary file (6.57 kB). View file
|
|
|
SpecForge-ext/benchmarks/benchmarker/__pycache__/ceval.cpython-311.pyc
ADDED
|
Binary file (11.7 kB). View file
|
|
|
SpecForge-ext/benchmarks/benchmarker/__pycache__/financeqa.cpython-310.pyc
ADDED
|
Binary file (2.07 kB). View file
|
|
|
SpecForge-ext/benchmarks/benchmarker/__pycache__/financeqa.cpython-311.pyc
ADDED
|
Binary file (3.34 kB). View file
|
|
|
SpecForge-ext/benchmarks/benchmarker/__pycache__/gpqa.cpython-310.pyc
ADDED
|
Binary file (3.24 kB). View file
|
|
|
SpecForge-ext/benchmarks/benchmarker/__pycache__/gpqa.cpython-311.pyc
ADDED
|
Binary file (5.36 kB). View file
|
|
|
SpecForge-ext/benchmarks/benchmarker/__pycache__/gsm8k.cpython-310.pyc
ADDED
|
Binary file (3.94 kB). View file
|
|
|
SpecForge-ext/benchmarks/benchmarker/__pycache__/gsm8k.cpython-311.pyc
ADDED
|
Binary file (6.72 kB). View file
|
|
|
SpecForge-ext/benchmarks/benchmarker/__pycache__/humaneval.cpython-310.pyc
ADDED
|
Binary file (4.88 kB). View file
|
|
|
SpecForge-ext/benchmarks/benchmarker/__pycache__/humaneval.cpython-311.pyc
ADDED
|
Binary file (8.95 kB). View file
|
|
|
SpecForge-ext/benchmarks/benchmarker/__pycache__/livecodebench.cpython-310.pyc
ADDED
|
Binary file (1.87 kB). View file
|
|
|
SpecForge-ext/benchmarks/benchmarker/__pycache__/math500.cpython-310.pyc
ADDED
|
Binary file (3.73 kB). View file
|
|
|
SpecForge-ext/benchmarks/benchmarker/__pycache__/math500.cpython-311.pyc
ADDED
|
Binary file (6.26 kB). View file
|
|
|
SpecForge-ext/benchmarks/benchmarker/__pycache__/mmlu.cpython-310.pyc
ADDED
|
Binary file (3.11 kB). View file
|
|
|
SpecForge-ext/benchmarks/benchmarker/__pycache__/mmlu.cpython-311.pyc
ADDED
|
Binary file (5.19 kB). View file
|
|
|
SpecForge-ext/benchmarks/benchmarker/__pycache__/mmstar.cpython-310.pyc
ADDED
|
Binary file (5.08 kB). View file
|
|
|
SpecForge-ext/benchmarks/benchmarker/__pycache__/mmstar.cpython-311.pyc
ADDED
|
Binary file (10.1 kB). View file
|
|
|
SpecForge-ext/benchmarks/benchmarker/__pycache__/mtbench.cpython-310.pyc
ADDED
|
Binary file (3.07 kB). View file
|
|
|
SpecForge-ext/benchmarks/benchmarker/__pycache__/mtbench.cpython-311.pyc
ADDED
|
Binary file (4.63 kB). View file
|
|
|
SpecForge-ext/benchmarks/benchmarker/__pycache__/registry.cpython-310.pyc
ADDED
|
Binary file (1.22 kB). View file
|
|
|
SpecForge-ext/benchmarks/benchmarker/__pycache__/registry.cpython-311.pyc
ADDED
|
Binary file (1.51 kB). View file
|
|
|
SpecForge-ext/benchmarks/benchmarker/__pycache__/simpleqa.cpython-310.pyc
ADDED
|
Binary file (1.81 kB). View file
|
|
|
SpecForge-ext/benchmarks/benchmarker/__pycache__/simpleqa.cpython-311.pyc
ADDED
|
Binary file (2.83 kB). View file
|
|
|
SpecForge-ext/benchmarks/benchmarker/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (8.59 kB). View file
|
|
|
SpecForge-ext/benchmarks/benchmarker/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (13.7 kB). View file
|
|
|
SpecForge-ext/benchmarks/benchmarker/aime.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AIME benchmark
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import re
|
| 6 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 7 |
+
|
| 8 |
+
from datasets import load_dataset
|
| 9 |
+
|
| 10 |
+
from .base import Benchmarker
|
| 11 |
+
from .registry import BENCHMARKS
|
| 12 |
+
from .utils import create_simple_sgl_function
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def extract_aime_answer(output: str) -> Optional[str]:
|
| 16 |
+
"""Extract final answer from AIME problem solution.
|
| 17 |
+
|
| 18 |
+
AIME answers are typically integers between 0 and 999, and are usually
|
| 19 |
+
in \boxed{} format.
|
| 20 |
+
"""
|
| 21 |
+
# Try to find answer in \boxed{} format
|
| 22 |
+
boxed_pattern = r"\\boxed\{([^}]+)\}"
|
| 23 |
+
match = re.search(boxed_pattern, output)
|
| 24 |
+
if match:
|
| 25 |
+
answer = match.group(1).strip()
|
| 26 |
+
# Extract number from the boxed content
|
| 27 |
+
numbers = re.findall(r"\d+", answer)
|
| 28 |
+
if numbers:
|
| 29 |
+
return numbers[-1] # Take the last number (usually the final answer)
|
| 30 |
+
return answer
|
| 31 |
+
|
| 32 |
+
# Try to find answer in \boxed format (without braces)
|
| 33 |
+
boxed_pattern2 = r"\\boxed\s+(\d+)"
|
| 34 |
+
match = re.search(boxed_pattern2, output)
|
| 35 |
+
if match:
|
| 36 |
+
return match.group(1).strip()
|
| 37 |
+
|
| 38 |
+
# Look for patterns like "The answer is 42" or "Answer: 123"
|
| 39 |
+
answer_patterns = [
|
| 40 |
+
r"(?:answer|Answer|ANSWER)[\s:]+(\d+)",
|
| 41 |
+
r"(?:final\s+answer|Final\s+Answer)[\s:]+(\d+)",
|
| 42 |
+
r"(?:is|equals?|=\s*)(\d+)\s*$",
|
| 43 |
+
]
|
| 44 |
+
for pattern in answer_patterns:
|
| 45 |
+
matches = re.findall(pattern, output, re.IGNORECASE)
|
| 46 |
+
if matches:
|
| 47 |
+
return matches[-1].strip()
|
| 48 |
+
|
| 49 |
+
# Fallback: extract the last integer in the text
|
| 50 |
+
numbers = re.findall(r"\b(\d+)\b", output)
|
| 51 |
+
if numbers:
|
| 52 |
+
# Filter to reasonable AIME answer range (0-999)
|
| 53 |
+
valid_numbers = [n for n in numbers if 0 <= int(n) <= 999]
|
| 54 |
+
if valid_numbers:
|
| 55 |
+
return valid_numbers[-1]
|
| 56 |
+
|
| 57 |
+
return None
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@BENCHMARKS.register("aime")
|
| 61 |
+
class AIMEBenchmarker(Benchmarker):
|
| 62 |
+
"""AIME benchmark implementation."""
|
| 63 |
+
|
| 64 |
+
def __init__(self, num_samples: Optional[int] = None):
|
| 65 |
+
super().__init__(num_samples, None)
|
| 66 |
+
|
| 67 |
+
def load_data(self) -> Tuple[List[Dict[str, Any]], List[Optional[str]]]:
|
| 68 |
+
"""Load and preprocess AIME dataset."""
|
| 69 |
+
dataset = load_dataset("Maxwell-Jia/AIME_2024")["train"]
|
| 70 |
+
questions = []
|
| 71 |
+
labels = []
|
| 72 |
+
for idx, q in enumerate(dataset):
|
| 73 |
+
if self.num_samples is not None and idx >= self.num_samples:
|
| 74 |
+
break
|
| 75 |
+
|
| 76 |
+
questions.append({"question": q["Problem"]})
|
| 77 |
+
# Extract answer from Answer field
|
| 78 |
+
answer = None
|
| 79 |
+
if "Answer" in q:
|
| 80 |
+
answer = str(q["Answer"]).strip()
|
| 81 |
+
elif "answer" in q:
|
| 82 |
+
answer = str(q["answer"]).strip()
|
| 83 |
+
labels.append(answer)
|
| 84 |
+
return questions, labels
|
| 85 |
+
|
| 86 |
+
def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]:
|
| 87 |
+
"""Extract answer from model output."""
|
| 88 |
+
return extract_aime_answer(output)
|
| 89 |
+
|
| 90 |
+
def compute_accuracy(
|
| 91 |
+
self, predictions: List[Any], labels: List[Any]
|
| 92 |
+
) -> Optional[float]:
|
| 93 |
+
"""Compute accuracy for AIME by comparing numeric answers."""
|
| 94 |
+
if not labels or len(labels) == 0:
|
| 95 |
+
return None
|
| 96 |
+
if all(label is None for label in labels):
|
| 97 |
+
return None
|
| 98 |
+
|
| 99 |
+
correct = 0
|
| 100 |
+
valid_count = 0
|
| 101 |
+
for pred, label in zip(predictions, labels):
|
| 102 |
+
if label is not None:
|
| 103 |
+
valid_count += 1
|
| 104 |
+
if pred is not None:
|
| 105 |
+
# Normalize answers for comparison
|
| 106 |
+
pred_normalized = str(pred).strip()
|
| 107 |
+
label_normalized = str(label).strip()
|
| 108 |
+
# Try exact match first
|
| 109 |
+
if pred_normalized == label_normalized:
|
| 110 |
+
correct += 1
|
| 111 |
+
else:
|
| 112 |
+
# Try numeric comparison
|
| 113 |
+
try:
|
| 114 |
+
pred_num = int(pred_normalized)
|
| 115 |
+
label_num = int(label_normalized)
|
| 116 |
+
if pred_num == label_num:
|
| 117 |
+
correct += 1
|
| 118 |
+
except ValueError:
|
| 119 |
+
pass
|
| 120 |
+
|
| 121 |
+
return correct / valid_count if valid_count > 0 else 0.0
|
| 122 |
+
|
| 123 |
+
def create_sgl_function(self):
|
| 124 |
+
"""Create SGL function for AIME with reasoning prompt."""
|
| 125 |
+
return create_simple_sgl_function(
|
| 126 |
+
function_name="reasoning_gen",
|
| 127 |
+
answer_key="answer",
|
| 128 |
+
user_prefix="\nPlease reason step by step, and put your final answer within \\boxed{}.",
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
def get_max_new_tokens(self) -> int:
|
| 132 |
+
"""AIME problems require more tokens."""
|
| 133 |
+
return 32768
|
SpecForge-ext/benchmarks/benchmarker/base.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Base class for benchmark implementations.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import time
|
| 6 |
+
from abc import ABC, abstractmethod
|
| 7 |
+
from argparse import Namespace
|
| 8 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple
|
| 9 |
+
|
| 10 |
+
from sglang import set_default_backend
|
| 11 |
+
from sglang.test.test_utils import select_sglang_backend
|
| 12 |
+
|
| 13 |
+
from .utils import compute_metrics
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Benchmarker(ABC):
|
| 17 |
+
"""
|
| 18 |
+
Base class for benchmark implementations.
|
| 19 |
+
|
| 20 |
+
Subclasses should implement:
|
| 21 |
+
- load_data(): Load and preprocess dataset
|
| 22 |
+
- create_sgl_function(): Create the SGL function for inference
|
| 23 |
+
|
| 24 |
+
Optional overrides:
|
| 25 |
+
- extract_answer(): Extract answer from model output (if needed)
|
| 26 |
+
- compute_accuracy(): Compute accuracy metric (if applicable)
|
| 27 |
+
- get_answer_keys(): Get list of answer keys for multi-turn conversations
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
num_samples: The number of samples to run the benchmark on. If not provided, all questions will be used.
|
| 31 |
+
subset: The subset of the dataset to run the benchmark on. If not provided, all subsets will be used.
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
def __init__(
|
| 35 |
+
self, num_samples: Optional[int] = None, subset: Optional[List[str]] = None
|
| 36 |
+
):
|
| 37 |
+
self.num_samples = num_samples
|
| 38 |
+
self.subset = subset
|
| 39 |
+
|
| 40 |
+
@abstractmethod
|
| 41 |
+
def load_data(self) -> Tuple[List[Dict[str, Any]], List[Any]]:
|
| 42 |
+
"""
|
| 43 |
+
Load and preprocess the dataset.
|
| 44 |
+
|
| 45 |
+
Returns:
|
| 46 |
+
Tuple of (questions, labels) where:
|
| 47 |
+
- questions: List of question dicts for SGL function
|
| 48 |
+
- labels: List of ground truth labels (can be None if not applicable)
|
| 49 |
+
"""
|
| 50 |
+
raise NotImplementedError
|
| 51 |
+
|
| 52 |
+
@abstractmethod
|
| 53 |
+
def create_sgl_function(self) -> Callable:
|
| 54 |
+
"""
|
| 55 |
+
Create the SGL function for inference.
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
SGL function decorated with @sgl.function
|
| 59 |
+
"""
|
| 60 |
+
raise NotImplementedError
|
| 61 |
+
|
| 62 |
+
def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[Any]:
|
| 63 |
+
"""
|
| 64 |
+
Extract answer from model output.
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
output: Raw model output string
|
| 68 |
+
label: Optional ground truth label for reference
|
| 69 |
+
|
| 70 |
+
Returns:
|
| 71 |
+
Extracted answer, or None if extraction fails
|
| 72 |
+
"""
|
| 73 |
+
return output
|
| 74 |
+
|
| 75 |
+
def compute_accuracy(
|
| 76 |
+
self, predictions: List[Any], labels: List[Any]
|
| 77 |
+
) -> Optional[float]:
|
| 78 |
+
"""
|
| 79 |
+
Compute accuracy metric.
|
| 80 |
+
|
| 81 |
+
Args:
|
| 82 |
+
predictions: List of predicted answers
|
| 83 |
+
labels: List of ground truth labels
|
| 84 |
+
|
| 85 |
+
Returns:
|
| 86 |
+
Accuracy score (0-1), or None if not applicable
|
| 87 |
+
"""
|
| 88 |
+
return None
|
| 89 |
+
|
| 90 |
+
def get_answer_keys(self) -> Optional[List[str]]:
|
| 91 |
+
"""
|
| 92 |
+
Get list of answer keys for multi-turn conversations.
|
| 93 |
+
|
| 94 |
+
Returns:
|
| 95 |
+
List of answer keys (e.g., ["answer_1", "answer_2"]), or None for single-turn
|
| 96 |
+
"""
|
| 97 |
+
return None
|
| 98 |
+
|
| 99 |
+
def get_max_new_tokens(self) -> int:
|
| 100 |
+
"""
|
| 101 |
+
Get maximum number of new tokens to generate.
|
| 102 |
+
|
| 103 |
+
Returns:
|
| 104 |
+
Maximum tokens (default: 2048)
|
| 105 |
+
"""
|
| 106 |
+
return 2048
|
| 107 |
+
|
| 108 |
+
def run(
|
| 109 |
+
self,
|
| 110 |
+
host: str,
|
| 111 |
+
port: int,
|
| 112 |
+
batch_size: int,
|
| 113 |
+
max_new_tokens: int = None,
|
| 114 |
+
num_runs: int = 1,
|
| 115 |
+
):
|
| 116 |
+
"""
|
| 117 |
+
Run the benchmark evaluation.
|
| 118 |
+
|
| 119 |
+
This method handles the common workflow:
|
| 120 |
+
1. Initialize backend
|
| 121 |
+
2. Load data
|
| 122 |
+
3. Create SGL function
|
| 123 |
+
4. Run inference loops
|
| 124 |
+
5. Compute metrics
|
| 125 |
+
6. Print results
|
| 126 |
+
|
| 127 |
+
Args:
|
| 128 |
+
host (str): The host of the SGLang server
|
| 129 |
+
port (int): The port of the SGLang server
|
| 130 |
+
batch_size (int): The number of prompts to process in parallel
|
| 131 |
+
num_samples (int): The number of samples to run the benchmark on. If not provided, all samples will be used.
|
| 132 |
+
max_new_tokens (int): Maximum number of new tokens to generate, default is 2048
|
| 133 |
+
num_runs (int): The number of times to run this benchmark, default is 1. You can set it to a larger number if you want to get more stable results.
|
| 134 |
+
"""
|
| 135 |
+
if not host.startswith(("http://", "https://")):
|
| 136 |
+
host = f"http://{host}"
|
| 137 |
+
# Initialize backend
|
| 138 |
+
sglang_args = Namespace(host=host, port=port, backend="srt-no-parallel")
|
| 139 |
+
set_default_backend(select_sglang_backend(sglang_args))
|
| 140 |
+
|
| 141 |
+
# Load data
|
| 142 |
+
questions, labels = self.load_data()
|
| 143 |
+
if len(questions) == 0:
|
| 144 |
+
print("No valid questions found. Please check the dataset format.")
|
| 145 |
+
return
|
| 146 |
+
|
| 147 |
+
# Create SGL function
|
| 148 |
+
sgl_function = self.create_sgl_function()
|
| 149 |
+
|
| 150 |
+
# Run evaluation loops
|
| 151 |
+
metrics_list = []
|
| 152 |
+
answer_keys = self.get_answer_keys()
|
| 153 |
+
max_new_tokens = max_new_tokens or self.get_max_new_tokens()
|
| 154 |
+
|
| 155 |
+
for _ in range(num_runs):
|
| 156 |
+
tic = time.perf_counter()
|
| 157 |
+
states = sgl_function.run_batch(
|
| 158 |
+
questions,
|
| 159 |
+
temperature=0,
|
| 160 |
+
max_new_tokens=max_new_tokens,
|
| 161 |
+
num_threads=batch_size,
|
| 162 |
+
progress_bar=True,
|
| 163 |
+
)
|
| 164 |
+
latency = time.perf_counter() - tic
|
| 165 |
+
|
| 166 |
+
# Extract predictions
|
| 167 |
+
predictions = []
|
| 168 |
+
primary_answer_key = answer_keys[0] if answer_keys else "answer"
|
| 169 |
+
for i in range(len(states)):
|
| 170 |
+
# Access answer from state object (states[i] supports dict-like access)
|
| 171 |
+
output = states[i][primary_answer_key]
|
| 172 |
+
if isinstance(output, str):
|
| 173 |
+
extracted = self.extract_answer(
|
| 174 |
+
output,
|
| 175 |
+
(labels[i] if labels and i < len(labels) else None),
|
| 176 |
+
)
|
| 177 |
+
else:
|
| 178 |
+
extracted = output
|
| 179 |
+
predictions.append(extracted)
|
| 180 |
+
|
| 181 |
+
# Compute accuracy if applicable
|
| 182 |
+
accuracy = None
|
| 183 |
+
# Check if we have a labels list (even if all labels are None)
|
| 184 |
+
has_labels_list = labels and len(labels) > 0
|
| 185 |
+
|
| 186 |
+
if has_labels_list:
|
| 187 |
+
# Always call compute_accuracy if we have a labels list
|
| 188 |
+
# This allows it to return None, which will be displayed in print_results
|
| 189 |
+
accuracy = self.compute_accuracy(predictions, labels)
|
| 190 |
+
if accuracy is not None:
|
| 191 |
+
valid_count = sum(1 for p in predictions if p is not None)
|
| 192 |
+
if valid_count < len(predictions):
|
| 193 |
+
print(
|
| 194 |
+
f"Warning: {len(predictions) - valid_count} predictions could not be extracted."
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
# Compute performance metrics
|
| 198 |
+
metrics = compute_metrics(
|
| 199 |
+
states,
|
| 200 |
+
latency,
|
| 201 |
+
answer_key=primary_answer_key,
|
| 202 |
+
additional_answer_keys=(
|
| 203 |
+
answer_keys[1:] if answer_keys and len(answer_keys) > 1 else None
|
| 204 |
+
),
|
| 205 |
+
)
|
| 206 |
+
# Always set accuracy if we have a labels list (even if compute_accuracy returns None)
|
| 207 |
+
# This allows print_results to show None when compute_accuracy returns None
|
| 208 |
+
if has_labels_list:
|
| 209 |
+
metrics.accuracy = (
|
| 210 |
+
accuracy # Can be None if compute_accuracy returns None
|
| 211 |
+
)
|
| 212 |
+
if accuracy is not None:
|
| 213 |
+
metrics.num_valid_predictions = sum(
|
| 214 |
+
1 for p in predictions if p is not None
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
metrics_list.append(metrics)
|
| 218 |
+
return metrics_list
|
SpecForge-ext/benchmarks/benchmarker/ceval.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
C-Eval benchmark evaluation script.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import re
|
| 6 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 7 |
+
|
| 8 |
+
from datasets import concatenate_datasets, load_dataset
|
| 9 |
+
|
| 10 |
+
from .base import Benchmarker
|
| 11 |
+
from .registry import BENCHMARKS
|
| 12 |
+
from .utils import create_simple_sgl_function
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def extract_answer(answer_str: str) -> str:
|
| 16 |
+
"""Extract the answer choice (A, B, C, D) from the model output."""
|
| 17 |
+
# Try to find the answer in various formats
|
| 18 |
+
answer_str = answer_str.strip().upper()
|
| 19 |
+
|
| 20 |
+
# Direct match for single letter
|
| 21 |
+
match = re.search(r"\b([ABCD])\b", answer_str)
|
| 22 |
+
if match:
|
| 23 |
+
return match.group(1)
|
| 24 |
+
|
| 25 |
+
# Try to find answer in parentheses or brackets
|
| 26 |
+
for pattern in [
|
| 27 |
+
r"\(([ABCD])\)",
|
| 28 |
+
r"\[([ABCD])\]",
|
| 29 |
+
r"答案[::]\s*([ABCD])",
|
| 30 |
+
r"Answer[::]\s*([ABCD])",
|
| 31 |
+
]:
|
| 32 |
+
match = re.search(pattern, answer_str, re.IGNORECASE)
|
| 33 |
+
if match:
|
| 34 |
+
return match.group(1).upper()
|
| 35 |
+
|
| 36 |
+
# Try to find the first occurrence of A, B, C, or D
|
| 37 |
+
match = re.search(r"([ABCD])", answer_str)
|
| 38 |
+
if match:
|
| 39 |
+
return match.group(1)
|
| 40 |
+
|
| 41 |
+
return None
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def format_question(question: str, options: List[str]) -> str:
|
| 45 |
+
"""Format the question with options."""
|
| 46 |
+
prompt = question + "\n\n选项:\n"
|
| 47 |
+
for i, option in enumerate(options):
|
| 48 |
+
prompt += f"{chr(65 + i)}. {option}\n"
|
| 49 |
+
prompt += "\n请从A、B、C、D中选择一个答案。"
|
| 50 |
+
return prompt
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@BENCHMARKS.register("ceval")
|
| 54 |
+
class CEvalBenchmarker(Benchmarker):
|
| 55 |
+
"""C-Eval benchmark implementation."""
|
| 56 |
+
|
| 57 |
+
def __init__(
|
| 58 |
+
self, num_samples: Optional[int] = None, subset: Optional[List[str]] = None
|
| 59 |
+
):
|
| 60 |
+
if subset is None:
|
| 61 |
+
subset = "all"
|
| 62 |
+
super().__init__(num_samples, subset)
|
| 63 |
+
|
| 64 |
+
def load_data(self) -> Tuple[List[Dict[str, Any]], List[str]]:
|
| 65 |
+
"""Load and preprocess C-Eval dataset."""
|
| 66 |
+
all_configs = [
|
| 67 |
+
"accountant",
|
| 68 |
+
"advanced_mathematics",
|
| 69 |
+
"art_studies",
|
| 70 |
+
"basic_medicine",
|
| 71 |
+
"business_administration",
|
| 72 |
+
"chinese_language_and_literature",
|
| 73 |
+
"civil_servant",
|
| 74 |
+
"clinical_medicine",
|
| 75 |
+
"college_chemistry",
|
| 76 |
+
"college_economics",
|
| 77 |
+
"college_physics",
|
| 78 |
+
"college_programming",
|
| 79 |
+
"computer_architecture",
|
| 80 |
+
"computer_network",
|
| 81 |
+
"discrete_mathematics",
|
| 82 |
+
"education_science",
|
| 83 |
+
"electrical_engineer",
|
| 84 |
+
"environmental_impact_assessment_engineer",
|
| 85 |
+
"fire_engineer",
|
| 86 |
+
"high_school_biology",
|
| 87 |
+
"high_school_chemistry",
|
| 88 |
+
"high_school_chinese",
|
| 89 |
+
"high_school_geography",
|
| 90 |
+
"high_school_history",
|
| 91 |
+
"high_school_mathematics",
|
| 92 |
+
"high_school_physics",
|
| 93 |
+
"high_school_politics",
|
| 94 |
+
"ideological_and_moral_cultivation",
|
| 95 |
+
"law",
|
| 96 |
+
"legal_professional",
|
| 97 |
+
"logic",
|
| 98 |
+
"mao_zedong_thought",
|
| 99 |
+
"marxism",
|
| 100 |
+
"metrology_engineer",
|
| 101 |
+
"middle_school_biology",
|
| 102 |
+
"middle_school_chemistry",
|
| 103 |
+
"middle_school_geography",
|
| 104 |
+
"middle_school_history",
|
| 105 |
+
"middle_school_mathematics",
|
| 106 |
+
"middle_school_physics",
|
| 107 |
+
"middle_school_politics",
|
| 108 |
+
"modern_chinese_history",
|
| 109 |
+
"operating_system",
|
| 110 |
+
"physician",
|
| 111 |
+
"plant_protection",
|
| 112 |
+
"probability_and_statistics",
|
| 113 |
+
"professional_tour_guide",
|
| 114 |
+
"sports_science",
|
| 115 |
+
"tax_accountant",
|
| 116 |
+
"teacher_qualification",
|
| 117 |
+
"urban_and_rural_planner",
|
| 118 |
+
"veterinary_medicine",
|
| 119 |
+
]
|
| 120 |
+
|
| 121 |
+
# Select configs to load
|
| 122 |
+
if self.subset == "all":
|
| 123 |
+
configs_to_load = all_configs
|
| 124 |
+
else:
|
| 125 |
+
for subset in self.subset:
|
| 126 |
+
assert (
|
| 127 |
+
subset in all_configs
|
| 128 |
+
), f"Subset {subset} not found in C-Eval dataset"
|
| 129 |
+
configs_to_load = self.subset
|
| 130 |
+
|
| 131 |
+
# Load datasets
|
| 132 |
+
try:
|
| 133 |
+
datasets = []
|
| 134 |
+
for config in configs_to_load:
|
| 135 |
+
try:
|
| 136 |
+
ds = load_dataset("ceval/ceval-exam", name=config, split="test")
|
| 137 |
+
datasets.append(ds)
|
| 138 |
+
print(f"Loaded config '{config}' with {len(ds)} samples")
|
| 139 |
+
except Exception as e:
|
| 140 |
+
print(f"Warning: Failed to load config '{config}': {e}")
|
| 141 |
+
if len(datasets) == 0:
|
| 142 |
+
raise ValueError("No configs could be loaded")
|
| 143 |
+
dataset = concatenate_datasets(datasets)
|
| 144 |
+
print(
|
| 145 |
+
f"Successfully loaded C-Eval dataset with all configs (total: {len(dataset)} samples)"
|
| 146 |
+
)
|
| 147 |
+
except Exception as e:
|
| 148 |
+
print(e)
|
| 149 |
+
print(f"Failed to load C-Eval dataset from 'ceval/ceval-exam': {e}")
|
| 150 |
+
print("Please ensure the dataset is available or install it manually.")
|
| 151 |
+
print("You can try: pip install datasets")
|
| 152 |
+
print("Or download from: https://huggingface.co/datasets/ceval/ceval-exam")
|
| 153 |
+
return [], []
|
| 154 |
+
|
| 155 |
+
# Process questions
|
| 156 |
+
questions = []
|
| 157 |
+
labels = []
|
| 158 |
+
for idx, item in enumerate(dataset):
|
| 159 |
+
if self.num_samples is not None and idx >= self.num_samples:
|
| 160 |
+
break
|
| 161 |
+
|
| 162 |
+
# Handle different dataset formats
|
| 163 |
+
question_text = None
|
| 164 |
+
if "question" in item:
|
| 165 |
+
question_text = item["question"]
|
| 166 |
+
elif "inputs" in item:
|
| 167 |
+
question_text = item["inputs"]
|
| 168 |
+
elif "problem" in item:
|
| 169 |
+
question_text = item["problem"]
|
| 170 |
+
elif "content" in item:
|
| 171 |
+
question_text = item["content"]
|
| 172 |
+
|
| 173 |
+
if not question_text:
|
| 174 |
+
continue
|
| 175 |
+
|
| 176 |
+
# Get options - C-Eval typically has options as a list or dict
|
| 177 |
+
options = None
|
| 178 |
+
if "options" in item:
|
| 179 |
+
options = item["options"]
|
| 180 |
+
if isinstance(options, dict):
|
| 181 |
+
# Convert dict to list in order A, B, C, D
|
| 182 |
+
options = [
|
| 183 |
+
options.get("A", ""),
|
| 184 |
+
options.get("B", ""),
|
| 185 |
+
options.get("C", ""),
|
| 186 |
+
options.get("D", ""),
|
| 187 |
+
]
|
| 188 |
+
elif isinstance(options, list):
|
| 189 |
+
# Ensure we have 4 options
|
| 190 |
+
while len(options) < 4:
|
| 191 |
+
options.append("")
|
| 192 |
+
elif "choices" in item:
|
| 193 |
+
options = item["choices"]
|
| 194 |
+
if isinstance(options, dict):
|
| 195 |
+
options = [
|
| 196 |
+
options.get("A", ""),
|
| 197 |
+
options.get("B", ""),
|
| 198 |
+
options.get("C", ""),
|
| 199 |
+
options.get("D", ""),
|
| 200 |
+
]
|
| 201 |
+
else:
|
| 202 |
+
# Try to construct options from A, B, C, D fields
|
| 203 |
+
options = [
|
| 204 |
+
item.get("A", item.get("option_A", "")),
|
| 205 |
+
item.get("B", item.get("option_B", "")),
|
| 206 |
+
item.get("C", item.get("option_C", "")),
|
| 207 |
+
item.get("D", item.get("option_D", "")),
|
| 208 |
+
]
|
| 209 |
+
|
| 210 |
+
# Filter out empty options
|
| 211 |
+
if options:
|
| 212 |
+
options = [str(opt).strip() for opt in options if opt]
|
| 213 |
+
if len(options) < 2: # Need at least 2 options
|
| 214 |
+
continue
|
| 215 |
+
else:
|
| 216 |
+
continue
|
| 217 |
+
|
| 218 |
+
# Get answer
|
| 219 |
+
answer = None
|
| 220 |
+
if "answer" in item:
|
| 221 |
+
answer = str(item["answer"]).upper().strip()
|
| 222 |
+
elif "target" in item:
|
| 223 |
+
answer = str(item["target"]).upper().strip()
|
| 224 |
+
elif "label" in item:
|
| 225 |
+
answer = str(item["label"]).upper().strip()
|
| 226 |
+
elif "correct" in item:
|
| 227 |
+
answer = str(item["correct"]).upper().strip()
|
| 228 |
+
|
| 229 |
+
# Validate answer
|
| 230 |
+
if answer and answer in ["A", "B", "C", "D"]:
|
| 231 |
+
# Format question
|
| 232 |
+
formatted_question = format_question(question_text, options)
|
| 233 |
+
questions.append({"question": formatted_question})
|
| 234 |
+
labels.append(answer)
|
| 235 |
+
|
| 236 |
+
if len(questions) == 0:
|
| 237 |
+
print("No valid questions found. Please check the dataset format.")
|
| 238 |
+
print(
|
| 239 |
+
"Sample item keys:",
|
| 240 |
+
list(dataset[0].keys()) if len(dataset) > 0 else "No items",
|
| 241 |
+
)
|
| 242 |
+
return [], []
|
| 243 |
+
|
| 244 |
+
return questions, labels
|
| 245 |
+
|
| 246 |
+
def create_sgl_function(self):
|
| 247 |
+
"""Create SGL function for C-Eval."""
|
| 248 |
+
return create_simple_sgl_function(
|
| 249 |
+
function_name="get_ceval_answer",
|
| 250 |
+
answer_key="answer",
|
| 251 |
+
max_tokens=self.get_max_new_tokens(),
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
def extract_answer(self, output: str, label: Any = None) -> str:
|
| 255 |
+
"""Extract answer choice from model output."""
|
| 256 |
+
return extract_answer(output)
|
| 257 |
+
|
| 258 |
+
def compute_accuracy(self, predictions: List[str], labels: List[str]) -> float:
|
| 259 |
+
"""Compute accuracy metric."""
|
| 260 |
+
correct = 0
|
| 261 |
+
valid_count = 0
|
| 262 |
+
for i in range(len(predictions)):
|
| 263 |
+
if predictions[i] is not None: # Only count valid predictions
|
| 264 |
+
valid_count += 1
|
| 265 |
+
if predictions[i] == labels[i]:
|
| 266 |
+
correct += 1
|
| 267 |
+
return correct / valid_count if valid_count > 0 else 0.0
|
SpecForge-ext/benchmarks/benchmarker/financeqa.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 2 |
+
|
| 3 |
+
from datasets import load_dataset
|
| 4 |
+
|
| 5 |
+
from .base import Benchmarker
|
| 6 |
+
from .registry import BENCHMARKS
|
| 7 |
+
from .utils import create_simple_sgl_function
|
| 8 |
+
|
| 9 |
+
QUESTION_PROMPT = """
|
| 10 |
+
Given the following context:
|
| 11 |
+
|
| 12 |
+
{context}
|
| 13 |
+
|
| 14 |
+
Can you answer the following question?
|
| 15 |
+
|
| 16 |
+
{question}
|
| 17 |
+
""".strip()
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def generate_question(row: Dict[str, Any]) -> str:
|
| 21 |
+
if row["context"] is None:
|
| 22 |
+
return row["question"].strip()
|
| 23 |
+
else:
|
| 24 |
+
question = QUESTION_PROMPT.format(
|
| 25 |
+
context=row["context"].strip(),
|
| 26 |
+
question=row["question"].strip(),
|
| 27 |
+
)
|
| 28 |
+
return question
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@BENCHMARKS.register("financeqa")
|
| 32 |
+
class FinanceQABenchmarker(Benchmarker):
|
| 33 |
+
"""FinanceQA benchmark implementation."""
|
| 34 |
+
|
| 35 |
+
def __init__(self, num_samples: Optional[int] = None):
|
| 36 |
+
super().__init__(num_samples, None)
|
| 37 |
+
|
| 38 |
+
def load_data(self) -> Tuple[List[Dict[str, Any]], List[int]]:
|
| 39 |
+
"""Load and preprocess FinanceQA dataset."""
|
| 40 |
+
# Read data
|
| 41 |
+
ds = load_dataset("AfterQuery/FinanceQA")["test"]
|
| 42 |
+
|
| 43 |
+
questions = []
|
| 44 |
+
labels = []
|
| 45 |
+
for i in range((len(ds))):
|
| 46 |
+
if self.num_samples is not None and i >= self.num_samples:
|
| 47 |
+
break
|
| 48 |
+
|
| 49 |
+
question_text = generate_question(ds[i])
|
| 50 |
+
questions.append({"question": question_text})
|
| 51 |
+
labels.append(None)
|
| 52 |
+
return questions, labels
|
| 53 |
+
|
| 54 |
+
def create_sgl_function(self):
|
| 55 |
+
return create_simple_sgl_function(
|
| 56 |
+
function_name="get_financeqa_answer",
|
| 57 |
+
answer_key="answer",
|
| 58 |
+
max_tokens=self.get_max_new_tokens(),
|
| 59 |
+
)
|
SpecForge-ext/benchmarks/benchmarker/gpqa.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 3 |
+
|
| 4 |
+
from datasets import load_dataset
|
| 5 |
+
|
| 6 |
+
from .base import Benchmarker
|
| 7 |
+
from .registry import BENCHMARKS
|
| 8 |
+
from .utils import create_simple_sgl_function
|
| 9 |
+
|
| 10 |
+
GPQA_QUERY_TEMPLATE = """
|
| 11 |
+
Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.
|
| 12 |
+
|
| 13 |
+
{Question}
|
| 14 |
+
|
| 15 |
+
A) {A}
|
| 16 |
+
B) {B}
|
| 17 |
+
C) {C}
|
| 18 |
+
D) {D}
|
| 19 |
+
""".strip()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def generate_question(row: Dict[str, Any]) -> str:
|
| 23 |
+
gold_index = random.randint(0, 3)
|
| 24 |
+
choices = [
|
| 25 |
+
row["Incorrect Answer 1"],
|
| 26 |
+
row["Incorrect Answer 2"],
|
| 27 |
+
row["Incorrect Answer 3"],
|
| 28 |
+
]
|
| 29 |
+
choices.insert(gold_index, row["Correct Answer"])
|
| 30 |
+
|
| 31 |
+
question = GPQA_QUERY_TEMPLATE.format(
|
| 32 |
+
Question=row["Question"].strip(),
|
| 33 |
+
A=choices[0].strip(),
|
| 34 |
+
B=choices[1].strip(),
|
| 35 |
+
C=choices[2].strip(),
|
| 36 |
+
D=choices[3].strip(),
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
# 0 means A, 1 means B, 2 means C, 3 means D
|
| 40 |
+
answer = ["A", "B", "C", "D"][gold_index]
|
| 41 |
+
return question, answer
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@BENCHMARKS.register("gpqa")
|
| 45 |
+
class GPQABenchmarker(Benchmarker):
|
| 46 |
+
"""GPQA benchmark implementation."""
|
| 47 |
+
|
| 48 |
+
def __init__(self, num_samples: Optional[int] = None):
|
| 49 |
+
super().__init__(num_samples, None)
|
| 50 |
+
|
| 51 |
+
def load_data(self) -> Tuple[List[Dict[str, Any]], List[int]]:
|
| 52 |
+
"""Load and preprocess GPQA dataset."""
|
| 53 |
+
# Read data
|
| 54 |
+
ds = load_dataset("Idavidrein/gpqa", "gpqa_main")["train"]
|
| 55 |
+
|
| 56 |
+
questions = []
|
| 57 |
+
labels = []
|
| 58 |
+
for i in range((len(ds))):
|
| 59 |
+
if self.num_samples is not None and i >= self.num_samples:
|
| 60 |
+
break
|
| 61 |
+
|
| 62 |
+
question_text, answer = generate_question(ds[i])
|
| 63 |
+
questions.append({"question": question_text})
|
| 64 |
+
labels.append(answer)
|
| 65 |
+
return questions, labels
|
| 66 |
+
|
| 67 |
+
def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[int]:
|
| 68 |
+
if "Answer: " not in output:
|
| 69 |
+
return None
|
| 70 |
+
return output.split("Answer: ")[1].strip()
|
| 71 |
+
|
| 72 |
+
def compute_accuracy(
|
| 73 |
+
self, predictions: List[Any], labels: List[Any]
|
| 74 |
+
) -> Optional[float]:
|
| 75 |
+
if not labels or len(labels) == 0:
|
| 76 |
+
return None
|
| 77 |
+
correct = sum(1 for pred, label in zip(predictions, labels) if pred == label)
|
| 78 |
+
return correct / len(labels) if len(labels) > 0 else 0.0
|
| 79 |
+
|
| 80 |
+
def create_sgl_function(self):
|
| 81 |
+
return create_simple_sgl_function(
|
| 82 |
+
function_name="get_gpqa_mcq_answer",
|
| 83 |
+
answer_key="answer",
|
| 84 |
+
max_tokens=self.get_max_new_tokens(),
|
| 85 |
+
)
|
SpecForge-ext/benchmarks/benchmarker/gsm8k.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GSM8K benchmark evaluation script.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import ast
|
| 6 |
+
import os
|
| 7 |
+
import re
|
| 8 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 9 |
+
|
| 10 |
+
from sglang.utils import download_and_cache_file, read_jsonl
|
| 11 |
+
|
| 12 |
+
from .base import Benchmarker
|
| 13 |
+
from .registry import BENCHMARKS
|
| 14 |
+
from .utils import create_few_shot_sgl_function
|
| 15 |
+
|
| 16 |
+
INVALID = -9999999
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def get_one_example(lines: List[Dict], i: int, include_answer: bool) -> str:
|
| 20 |
+
"""Format a single example."""
|
| 21 |
+
ret = "Question: " + lines[i]["question"] + "\nAnswer:"
|
| 22 |
+
if include_answer:
|
| 23 |
+
ret += " " + lines[i]["answer"]
|
| 24 |
+
return ret
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def get_few_shot_examples(lines: List[Dict], k: int) -> str:
|
| 28 |
+
"""Get few-shot examples as a string."""
|
| 29 |
+
ret = ""
|
| 30 |
+
for i in range(k):
|
| 31 |
+
ret += get_one_example(lines, i, True) + "\n\n"
|
| 32 |
+
return ret
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_answer_value(answer_str: str) -> int:
|
| 36 |
+
"""Extract numeric answer from model output."""
|
| 37 |
+
answer_str = answer_str.replace(",", "")
|
| 38 |
+
numbers = re.findall(r"\d+", answer_str)
|
| 39 |
+
if len(numbers) < 1:
|
| 40 |
+
return INVALID
|
| 41 |
+
try:
|
| 42 |
+
return ast.literal_eval(numbers[-1])
|
| 43 |
+
except SyntaxError:
|
| 44 |
+
return INVALID
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@BENCHMARKS.register("gsm8k")
|
| 48 |
+
class GSM8KBenchmarker(Benchmarker):
|
| 49 |
+
"""GSM8K benchmark implementation."""
|
| 50 |
+
|
| 51 |
+
def __init__(self, num_samples: Optional[int] = None):
|
| 52 |
+
super().__init__(num_samples, None)
|
| 53 |
+
|
| 54 |
+
def load_data(self) -> Tuple[List[Dict[str, Any]], List[int]]:
|
| 55 |
+
"""Load and preprocess GSM8K dataset."""
|
| 56 |
+
# 优先从本地数据目录读取
|
| 57 |
+
local_path = "/workspace/hanrui/datasets/gsm8k/test.jsonl"
|
| 58 |
+
|
| 59 |
+
if os.path.exists(local_path):
|
| 60 |
+
print(f"Loading GSM8K data from local: {local_path}")
|
| 61 |
+
lines = list(read_jsonl(local_path))
|
| 62 |
+
else:
|
| 63 |
+
# 如果本地不存在,从网络下载
|
| 64 |
+
print(f"Local data not found, downloading from GitHub...")
|
| 65 |
+
url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
|
| 66 |
+
data_path = download_and_cache_file(url)
|
| 67 |
+
lines = list(read_jsonl(data_path))
|
| 68 |
+
|
| 69 |
+
# Construct prompts
|
| 70 |
+
few_shot_examples = get_few_shot_examples(lines, 5)
|
| 71 |
+
|
| 72 |
+
questions = []
|
| 73 |
+
labels = []
|
| 74 |
+
for i in range((len(lines))):
|
| 75 |
+
if self.num_samples is not None and i >= self.num_samples:
|
| 76 |
+
break
|
| 77 |
+
|
| 78 |
+
question_text = get_one_example(lines, i, False)
|
| 79 |
+
questions.append({"question": question_text})
|
| 80 |
+
labels.append(get_answer_value(lines[i]["answer"]))
|
| 81 |
+
|
| 82 |
+
# Store few_shot_examples for use in create_sgl_function
|
| 83 |
+
self.few_shot_examples = few_shot_examples
|
| 84 |
+
|
| 85 |
+
assert all(l != INVALID for l in labels), "Some labels are invalid"
|
| 86 |
+
return questions, labels
|
| 87 |
+
|
| 88 |
+
def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[int]:
|
| 89 |
+
"""Extract numeric answer from model output."""
|
| 90 |
+
return get_answer_value(output)
|
| 91 |
+
|
| 92 |
+
def compute_accuracy(
|
| 93 |
+
self, predictions: List[Any], labels: List[Any]
|
| 94 |
+
) -> Optional[float]:
|
| 95 |
+
"""Compute accuracy for GSM8K by comparing numeric answers."""
|
| 96 |
+
if not labels or len(labels) == 0:
|
| 97 |
+
return None
|
| 98 |
+
correct = sum(1 for pred, label in zip(predictions, labels) if pred == label)
|
| 99 |
+
return correct / len(labels) if len(labels) > 0 else 0.0
|
| 100 |
+
|
| 101 |
+
def create_sgl_function(self):
|
| 102 |
+
"""Create SGL function for GSM8K with few-shot examples."""
|
| 103 |
+
return create_few_shot_sgl_function(
|
| 104 |
+
few_shot_examples=self.few_shot_examples,
|
| 105 |
+
function_name="few_shot_gsm8k",
|
| 106 |
+
answer_key="answer",
|
| 107 |
+
stop=["Question", "Assistant:", "<|separator|>"],
|
| 108 |
+
)
|
SpecForge-ext/benchmarks/benchmarker/humaneval.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HumanEval benchmark evaluation script.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
import re
|
| 8 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 9 |
+
|
| 10 |
+
from datasets import load_dataset
|
| 11 |
+
|
| 12 |
+
from .base import Benchmarker
|
| 13 |
+
from .registry import BENCHMARKS
|
| 14 |
+
from .utils import create_simple_sgl_function
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def extract_code_from_output(output: str) -> Optional[str]:
|
| 18 |
+
"""Extract Python code from model output.
|
| 19 |
+
|
| 20 |
+
Tries to extract code blocks or function definitions.
|
| 21 |
+
"""
|
| 22 |
+
# Try to find code in markdown code blocks
|
| 23 |
+
code_block_pattern = r"```(?:python)?\n(.*?)```"
|
| 24 |
+
match = re.search(code_block_pattern, output, re.DOTALL)
|
| 25 |
+
if match:
|
| 26 |
+
return match.group(1).strip()
|
| 27 |
+
|
| 28 |
+
# Try to find function definition (common in HumanEval)
|
| 29 |
+
# Look for "def " followed by code until the next def or end of string
|
| 30 |
+
def_pattern = r"(def\s+\w+\([^)]*\):.*?)(?=\n\ndef\s+|\Z)"
|
| 31 |
+
match = re.search(def_pattern, output, re.DOTALL)
|
| 32 |
+
if match:
|
| 33 |
+
return match.group(1).strip()
|
| 34 |
+
|
| 35 |
+
# Fallback: return the output as-is (might already be code)
|
| 36 |
+
return output.strip() if output.strip() else None
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def check_code_passes_tests(code: str, test_code: str, entry_point: str) -> bool:
|
| 40 |
+
"""Check if generated code passes the test cases.
|
| 41 |
+
|
| 42 |
+
This is a simplified version. For full evaluation, use the official
|
| 43 |
+
HumanEval evaluation framework.
|
| 44 |
+
|
| 45 |
+
HumanEval test code typically contains assertions that will raise
|
| 46 |
+
AssertionError if the code doesn't pass. If execution completes without
|
| 47 |
+
exceptions, the tests pass.
|
| 48 |
+
"""
|
| 49 |
+
try:
|
| 50 |
+
# Create a safe execution environment
|
| 51 |
+
namespace = {}
|
| 52 |
+
# Execute the code (function definition)
|
| 53 |
+
exec(code, namespace)
|
| 54 |
+
# Execute the test code (which contains assertions)
|
| 55 |
+
# If no exception is raised, the tests pass
|
| 56 |
+
exec(test_code, namespace)
|
| 57 |
+
return True
|
| 58 |
+
except AssertionError:
|
| 59 |
+
# Assertion failed - test didn't pass
|
| 60 |
+
return False
|
| 61 |
+
except Exception:
|
| 62 |
+
# Any other exception (syntax error, runtime error, etc.) means test failed
|
| 63 |
+
return False
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@BENCHMARKS.register("humaneval")
|
| 67 |
+
class HumanEvalBenchmarker(Benchmarker):
|
| 68 |
+
"""HumanEval benchmark implementation."""
|
| 69 |
+
|
| 70 |
+
def __init__(self, num_samples: Optional[int] = None):
|
| 71 |
+
"""Initialize benchmark and store test cases."""
|
| 72 |
+
super().__init__(num_samples, None)
|
| 73 |
+
self.test_cases = []
|
| 74 |
+
self.entry_points = []
|
| 75 |
+
|
| 76 |
+
def load_data(self) -> Tuple[List[Dict[str, Any]], List[Optional[Dict[str, str]]]]:
|
| 77 |
+
"""Load and preprocess HumanEval dataset."""
|
| 78 |
+
# 优先从本地数据目录读取
|
| 79 |
+
local_path = "/workspace/hanrui/datasets/humaneval/test.jsonl"
|
| 80 |
+
|
| 81 |
+
if os.path.exists(local_path):
|
| 82 |
+
print(f"Loading HumanEval data from local: {local_path}")
|
| 83 |
+
with open(local_path, 'r') as f:
|
| 84 |
+
dataset = [json.loads(line) for line in f]
|
| 85 |
+
else:
|
| 86 |
+
# 如果本地不存在,从 HuggingFace 下载
|
| 87 |
+
print(f"Local data not found, downloading from HuggingFace...")
|
| 88 |
+
dataset = load_dataset("openai/openai_humaneval")["test"]
|
| 89 |
+
|
| 90 |
+
questions = []
|
| 91 |
+
labels = []
|
| 92 |
+
self.test_cases = []
|
| 93 |
+
self.entry_points = []
|
| 94 |
+
|
| 95 |
+
for idx, q in enumerate(dataset):
|
| 96 |
+
if self.num_samples is not None and idx >= self.num_samples:
|
| 97 |
+
break
|
| 98 |
+
|
| 99 |
+
questions.append({"question": q["prompt"]})
|
| 100 |
+
|
| 101 |
+
# Store test case and entry point for evaluation
|
| 102 |
+
test_code = q.get("test", "")
|
| 103 |
+
entry_point = q.get("entry_point", "")
|
| 104 |
+
self.test_cases.append(test_code)
|
| 105 |
+
self.entry_points.append(entry_point)
|
| 106 |
+
|
| 107 |
+
# Store canonical solution as reference (optional, for comparison)
|
| 108 |
+
canonical_solution = q.get("canonical_solution", "")
|
| 109 |
+
labels.append(
|
| 110 |
+
{
|
| 111 |
+
"test": test_code,
|
| 112 |
+
"entry_point": entry_point,
|
| 113 |
+
"canonical_solution": canonical_solution,
|
| 114 |
+
}
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
return questions, labels
|
| 118 |
+
|
| 119 |
+
def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]:
|
| 120 |
+
"""Extract code from model output."""
|
| 121 |
+
return extract_code_from_output(output)
|
| 122 |
+
|
| 123 |
+
def compute_accuracy(
|
| 124 |
+
self, predictions: List[Any], labels: List[Any]
|
| 125 |
+
) -> Optional[float]:
|
| 126 |
+
"""Compute accuracy for HumanEval by checking if code passes tests.
|
| 127 |
+
|
| 128 |
+
Note: This is a simplified evaluation. For official pass@k metrics,
|
| 129 |
+
use the HumanEval evaluation framework.
|
| 130 |
+
"""
|
| 131 |
+
if not labels or len(labels) == 0:
|
| 132 |
+
return None
|
| 133 |
+
if all(label is None for label in labels):
|
| 134 |
+
return None
|
| 135 |
+
|
| 136 |
+
correct = 0
|
| 137 |
+
valid_count = 0
|
| 138 |
+
|
| 139 |
+
for i, (pred, label) in enumerate(zip(predictions, labels)):
|
| 140 |
+
if label is not None and isinstance(label, dict):
|
| 141 |
+
valid_count += 1
|
| 142 |
+
if pred is not None:
|
| 143 |
+
try:
|
| 144 |
+
# Get the prompt (function signature and docstring)
|
| 145 |
+
prompt = self.questions[i]["question"]
|
| 146 |
+
entry_point = label.get("entry_point", "")
|
| 147 |
+
|
| 148 |
+
# The prompt contains the function signature (e.g., "def function_name(...):")
|
| 149 |
+
# The generated code might be:
|
| 150 |
+
# 1. Just the function body (what we want) - need to combine with prompt
|
| 151 |
+
# 2. The complete function including signature - use as-is
|
| 152 |
+
# 3. Code in markdown blocks - already extracted by extract_code_from_output
|
| 153 |
+
|
| 154 |
+
pred_str = str(pred).strip()
|
| 155 |
+
|
| 156 |
+
# Check if pred already contains a complete function definition
|
| 157 |
+
# (starts with "def " and contains the entry_point function name)
|
| 158 |
+
if pred_str.startswith("def ") and entry_point:
|
| 159 |
+
# Check if this is the same function (by name)
|
| 160 |
+
func_name_match = re.match(r"def\s+(\w+)\s*\(", pred_str)
|
| 161 |
+
if (
|
| 162 |
+
func_name_match
|
| 163 |
+
and func_name_match.group(1) == entry_point
|
| 164 |
+
):
|
| 165 |
+
# Generated code includes complete function, use it as-is
|
| 166 |
+
full_code = pred_str
|
| 167 |
+
else:
|
| 168 |
+
# Different function or no match, combine with prompt
|
| 169 |
+
full_code = prompt + "\n" + pred_str
|
| 170 |
+
elif pred_str.startswith("def "):
|
| 171 |
+
# Has function definition but we can't verify entry_point, use as-is
|
| 172 |
+
full_code = pred_str
|
| 173 |
+
else:
|
| 174 |
+
# Generated code is just the body, combine with prompt
|
| 175 |
+
full_code = prompt + "\n" + pred_str
|
| 176 |
+
|
| 177 |
+
# Check if code passes tests
|
| 178 |
+
test_code = label.get("test", "")
|
| 179 |
+
|
| 180 |
+
if test_code and check_code_passes_tests(
|
| 181 |
+
full_code, test_code, entry_point
|
| 182 |
+
):
|
| 183 |
+
correct += 1
|
| 184 |
+
except Exception as e:
|
| 185 |
+
# If evaluation fails, consider it incorrect
|
| 186 |
+
# Uncomment for debugging: print(f"Error evaluating code {i}: {e}")
|
| 187 |
+
pass
|
| 188 |
+
|
| 189 |
+
return correct / valid_count if valid_count > 0 else 0.0
|
| 190 |
+
|
| 191 |
+
def create_sgl_function(self):
|
| 192 |
+
"""Create SGL function for HumanEval."""
|
| 193 |
+
return create_simple_sgl_function(
|
| 194 |
+
function_name="get_humaneval_answer",
|
| 195 |
+
answer_key="answer",
|
| 196 |
+
max_tokens=self.get_max_new_tokens(),
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
def get_max_new_tokens(self) -> int:
|
| 200 |
+
"""HumanEval code generation requires more tokens."""
|
| 201 |
+
return 1024
|
SpecForge-ext/benchmarks/benchmarker/livecodebench.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GSM8K benchmark evaluation script.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 6 |
+
|
| 7 |
+
from datasets import load_dataset
|
| 8 |
+
|
| 9 |
+
from .base import Benchmarker
|
| 10 |
+
from .registry import BENCHMARKS
|
| 11 |
+
from .utils import create_simple_sgl_function
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def generate_question(row: Dict[str, Any]) -> str:
|
| 15 |
+
question = row["question_content"].strip()
|
| 16 |
+
return question
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@BENCHMARKS.register("livecodebench")
|
| 20 |
+
class LCBBenchmarker(Benchmarker):
|
| 21 |
+
"""LiveCodeBench benchmark implementation."""
|
| 22 |
+
|
| 23 |
+
def __init__(self, num_samples: Optional[int] = None):
|
| 24 |
+
super().__init__(num_samples, None)
|
| 25 |
+
|
| 26 |
+
def load_data(self) -> Tuple[List[Dict[str, Any]], List[int]]:
|
| 27 |
+
# Read data
|
| 28 |
+
ds = load_dataset("livecodebench/code_generation")["test"]
|
| 29 |
+
|
| 30 |
+
questions = []
|
| 31 |
+
labels = []
|
| 32 |
+
for i in range((len(ds))):
|
| 33 |
+
if self.num_samples is not None and i >= self.num_samples:
|
| 34 |
+
break
|
| 35 |
+
|
| 36 |
+
question_text = generate_question(ds[i])
|
| 37 |
+
questions.append({"question": question_text})
|
| 38 |
+
labels.append(None)
|
| 39 |
+
return questions, labels
|
| 40 |
+
|
| 41 |
+
def create_sgl_function(self):
|
| 42 |
+
return create_simple_sgl_function(
|
| 43 |
+
function_name="get_livecodebench_answer",
|
| 44 |
+
answer_key="answer",
|
| 45 |
+
max_tokens=self.get_max_new_tokens(),
|
| 46 |
+
)
|
SpecForge-ext/benchmarks/benchmarker/math500.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MATH-500 benchmark evaluation script.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import re
|
| 6 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 7 |
+
|
| 8 |
+
from datasets import load_dataset
|
| 9 |
+
|
| 10 |
+
from .base import Benchmarker
|
| 11 |
+
from .registry import BENCHMARKS
|
| 12 |
+
from .utils import create_simple_sgl_function
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def extract_math_answer(output: str) -> Optional[str]:
|
| 16 |
+
"""Extract final answer from math problem solution.
|
| 17 |
+
|
| 18 |
+
Tries to extract answer from \boxed{} format first, then looks for
|
| 19 |
+
the last number in the output.
|
| 20 |
+
"""
|
| 21 |
+
# Try to find answer in \boxed{} format
|
| 22 |
+
boxed_pattern = r"\\boxed\{([^}]+)\}"
|
| 23 |
+
match = re.search(boxed_pattern, output)
|
| 24 |
+
if match:
|
| 25 |
+
return match.group(1).strip()
|
| 26 |
+
|
| 27 |
+
# Try to find answer in \boxed format (without braces)
|
| 28 |
+
boxed_pattern2 = r"\\boxed\s+([^\s]+)"
|
| 29 |
+
match = re.search(boxed_pattern2, output)
|
| 30 |
+
if match:
|
| 31 |
+
return match.group(1).strip()
|
| 32 |
+
|
| 33 |
+
# Try to find the last number (could be integer or decimal)
|
| 34 |
+
# Look for patterns like "The answer is 42" or "Answer: 3.14"
|
| 35 |
+
answer_patterns = [
|
| 36 |
+
r"(?:answer|Answer|ANSWER)[\s:]+([-+]?\d*\.?\d+)",
|
| 37 |
+
r"(?:is|equals?|=\s*)([-+]?\d*\.?\d+)\s*$",
|
| 38 |
+
]
|
| 39 |
+
for pattern in answer_patterns:
|
| 40 |
+
matches = re.findall(pattern, output, re.IGNORECASE)
|
| 41 |
+
if matches:
|
| 42 |
+
return matches[-1].strip()
|
| 43 |
+
|
| 44 |
+
# Fallback: extract the last number in the text
|
| 45 |
+
numbers = re.findall(r"[-+]?\d*\.?\d+", output)
|
| 46 |
+
if numbers:
|
| 47 |
+
return numbers[-1]
|
| 48 |
+
|
| 49 |
+
return None
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@BENCHMARKS.register("math500")
|
| 53 |
+
class Math500Benchmarker(Benchmarker):
|
| 54 |
+
"""MATH-500 benchmark implementation."""
|
| 55 |
+
|
| 56 |
+
def __init__(self, num_samples: Optional[int] = None):
|
| 57 |
+
super().__init__(num_samples, None)
|
| 58 |
+
|
| 59 |
+
def load_data(self) -> Tuple[List[Dict[str, Any]], List[Optional[str]]]:
|
| 60 |
+
"""Load and preprocess MATH-500 dataset."""
|
| 61 |
+
dataset = load_dataset("HuggingFaceH4/MATH-500")["test"]
|
| 62 |
+
questions = []
|
| 63 |
+
labels = []
|
| 64 |
+
for idx, q in enumerate(dataset):
|
| 65 |
+
if self.num_samples is not None and idx >= self.num_samples:
|
| 66 |
+
break
|
| 67 |
+
|
| 68 |
+
questions.append({"question": q["problem"]})
|
| 69 |
+
# Extract answer from solution or answer field
|
| 70 |
+
answer = None
|
| 71 |
+
if "answer" in q:
|
| 72 |
+
answer = str(q["answer"]).strip()
|
| 73 |
+
elif "solution" in q:
|
| 74 |
+
# Try to extract from solution
|
| 75 |
+
answer = extract_math_answer(q["solution"])
|
| 76 |
+
labels.append(answer)
|
| 77 |
+
return questions, labels
|
| 78 |
+
|
| 79 |
+
def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]:
|
| 80 |
+
"""Extract answer from model output."""
|
| 81 |
+
return extract_math_answer(output)
|
| 82 |
+
|
| 83 |
+
def compute_accuracy(
|
| 84 |
+
self, predictions: List[Any], labels: List[Any]
|
| 85 |
+
) -> Optional[float]:
|
| 86 |
+
"""Compute accuracy for MATH-500 by comparing answers."""
|
| 87 |
+
if not labels or len(labels) == 0:
|
| 88 |
+
return None
|
| 89 |
+
if all(label is None for label in labels):
|
| 90 |
+
return None
|
| 91 |
+
|
| 92 |
+
correct = 0
|
| 93 |
+
valid_count = 0
|
| 94 |
+
for pred, label in zip(predictions, labels):
|
| 95 |
+
if label is not None:
|
| 96 |
+
valid_count += 1
|
| 97 |
+
if pred is not None:
|
| 98 |
+
# Normalize answers for comparison (remove whitespace, handle different formats)
|
| 99 |
+
pred_normalized = str(pred).strip().lower()
|
| 100 |
+
label_normalized = str(label).strip().lower()
|
| 101 |
+
# Try exact match first
|
| 102 |
+
if pred_normalized == label_normalized:
|
| 103 |
+
correct += 1
|
| 104 |
+
else:
|
| 105 |
+
# Try numeric comparison if both are numbers
|
| 106 |
+
try:
|
| 107 |
+
pred_num = float(pred_normalized)
|
| 108 |
+
label_num = float(label_normalized)
|
| 109 |
+
if abs(pred_num - label_num) < 1e-6:
|
| 110 |
+
correct += 1
|
| 111 |
+
except ValueError:
|
| 112 |
+
pass
|
| 113 |
+
|
| 114 |
+
return correct / valid_count if valid_count > 0 else 0.0
|
| 115 |
+
|
| 116 |
+
def create_sgl_function(self):
|
| 117 |
+
"""Create SGL function for MATH-500."""
|
| 118 |
+
return create_simple_sgl_function(
|
| 119 |
+
function_name="get_math500_answer",
|
| 120 |
+
answer_key="answer",
|
| 121 |
+
max_tokens=self.get_max_new_tokens(),
|
| 122 |
+
)
|
SpecForge-ext/benchmarks/benchmarker/mmlu.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 2 |
+
|
| 3 |
+
from datasets import load_dataset
|
| 4 |
+
|
| 5 |
+
from .base import Benchmarker
|
| 6 |
+
from .registry import BENCHMARKS
|
| 7 |
+
from .utils import create_simple_sgl_function
|
| 8 |
+
|
| 9 |
+
GPQA_QUERY_TEMPLATE = """
|
| 10 |
+
Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.
|
| 11 |
+
|
| 12 |
+
{Question}
|
| 13 |
+
|
| 14 |
+
A) {A}
|
| 15 |
+
B) {B}
|
| 16 |
+
C) {C}
|
| 17 |
+
D) {D}
|
| 18 |
+
""".strip()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def generate_question(row: Dict[str, Any]) -> str:
|
| 22 |
+
choices = row["choices"]
|
| 23 |
+
question = GPQA_QUERY_TEMPLATE.format(
|
| 24 |
+
Question=row["question"].strip(),
|
| 25 |
+
A=choices[0].strip(),
|
| 26 |
+
B=choices[1].strip(),
|
| 27 |
+
C=choices[2].strip(),
|
| 28 |
+
D=choices[3].strip(),
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
# 0 means A, 1 means B, 2 means C, 3 means D
|
| 32 |
+
answer = ["A", "B", "C", "D"][row["answer"]]
|
| 33 |
+
print(answer)
|
| 34 |
+
return question, answer
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@BENCHMARKS.register("mmlu")
|
| 38 |
+
class MMLUBenchmarker(Benchmarker):
|
| 39 |
+
"""MMLU benchmark implementation."""
|
| 40 |
+
|
| 41 |
+
def __init__(
|
| 42 |
+
self, num_samples: Optional[int] = None, subset: Optional[List[str]] = None
|
| 43 |
+
):
|
| 44 |
+
if subset is None:
|
| 45 |
+
subset = ["all"]
|
| 46 |
+
super().__init__(num_samples, subset)
|
| 47 |
+
|
| 48 |
+
def load_data(self) -> Tuple[List[Dict[str, Any]], List[int]]:
|
| 49 |
+
# Read data
|
| 50 |
+
questions = []
|
| 51 |
+
labels = []
|
| 52 |
+
|
| 53 |
+
for subset in self.subset:
|
| 54 |
+
ds = load_dataset("cais/mmlu", subset)["test"]
|
| 55 |
+
for i in range((len(ds))):
|
| 56 |
+
if self.num_samples is not None and i >= self.num_samples:
|
| 57 |
+
break
|
| 58 |
+
|
| 59 |
+
question_text, answer = generate_question(ds[i])
|
| 60 |
+
questions.append({"question": question_text})
|
| 61 |
+
labels.append(answer)
|
| 62 |
+
return questions, labels
|
| 63 |
+
|
| 64 |
+
def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[int]:
|
| 65 |
+
if "Answer: " not in output:
|
| 66 |
+
return None
|
| 67 |
+
return output.split("Answer: ")[1].strip()
|
| 68 |
+
|
| 69 |
+
def compute_accuracy(
|
| 70 |
+
self, predictions: List[Any], labels: List[Any]
|
| 71 |
+
) -> Optional[float]:
|
| 72 |
+
if not labels or len(labels) == 0:
|
| 73 |
+
return None
|
| 74 |
+
correct = sum(1 for pred, label in zip(predictions, labels) if pred == label)
|
| 75 |
+
return correct / len(labels) if len(labels) > 0 else 0.0
|
| 76 |
+
|
| 77 |
+
def create_sgl_function(self):
|
| 78 |
+
return create_simple_sgl_function(
|
| 79 |
+
function_name="get_mmlu_answer",
|
| 80 |
+
answer_key="answer",
|
| 81 |
+
max_tokens=self.get_max_new_tokens(),
|
| 82 |
+
)
|
SpecForge-ext/benchmarks/benchmarker/mmstar.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MMStar benchmark evaluation script.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import re
|
| 7 |
+
import shutil
|
| 8 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 9 |
+
|
| 10 |
+
from datasets import load_dataset
|
| 11 |
+
|
| 12 |
+
from .base import Benchmarker
|
| 13 |
+
from .registry import BENCHMARKS
|
| 14 |
+
from .utils import create_image_sgl_function
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def extract_mmstar_answer(
|
| 18 |
+
output: str, options: Optional[List[str]] = None
|
| 19 |
+
) -> Optional[str]:
|
| 20 |
+
"""Extract answer from MMStar model output.
|
| 21 |
+
|
| 22 |
+
MMStar questions typically have multiple choice options (A, B, C, D, etc.)
|
| 23 |
+
"""
|
| 24 |
+
output_upper = output.strip().upper()
|
| 25 |
+
|
| 26 |
+
# Try to find answer choice (A, B, C, D, etc.)
|
| 27 |
+
# Direct match for single letter
|
| 28 |
+
match = re.search(r"\b([A-Z])\b", output_upper)
|
| 29 |
+
if match:
|
| 30 |
+
letter = match.group(1)
|
| 31 |
+
if options and len(options) > 0:
|
| 32 |
+
# Validate that the letter is within valid range
|
| 33 |
+
max_option = chr(64 + len(options)) # 'A' + (len-1)
|
| 34 |
+
if "A" <= letter <= max_option:
|
| 35 |
+
return letter
|
| 36 |
+
else:
|
| 37 |
+
# Assume A-D are valid
|
| 38 |
+
if "A" <= letter <= "D":
|
| 39 |
+
return letter
|
| 40 |
+
|
| 41 |
+
# Try to find answer in parentheses or brackets
|
| 42 |
+
for pattern in [
|
| 43 |
+
r"\(([A-Z])\)",
|
| 44 |
+
r"\[([A-Z])\]",
|
| 45 |
+
r"答案[::]\s*([A-Z])",
|
| 46 |
+
r"Answer[::]\s*([A-Z])",
|
| 47 |
+
r"选择[::]\s*([A-Z])",
|
| 48 |
+
]:
|
| 49 |
+
match = re.search(pattern, output_upper)
|
| 50 |
+
if match:
|
| 51 |
+
letter = match.group(1)
|
| 52 |
+
if options and len(options) > 0:
|
| 53 |
+
max_option = chr(64 + len(options))
|
| 54 |
+
if "A" <= letter <= max_option:
|
| 55 |
+
return letter
|
| 56 |
+
elif "A" <= letter <= "D":
|
| 57 |
+
return letter
|
| 58 |
+
|
| 59 |
+
return None
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@BENCHMARKS.register("mmstar")
|
| 63 |
+
class MMStarBenchmarker(Benchmarker):
|
| 64 |
+
"""MMStar benchmark implementation."""
|
| 65 |
+
|
| 66 |
+
def __init__(self, num_samples: Optional[int] = None):
|
| 67 |
+
super().__init__(num_samples, None)
|
| 68 |
+
"""Initialize benchmark and set up cache directory."""
|
| 69 |
+
self.cache_dir = None
|
| 70 |
+
self.options_list = [] # Store options for each question
|
| 71 |
+
|
| 72 |
+
def load_data(self) -> Tuple[List[Dict[str, Any]], List[Optional[str]]]:
|
| 73 |
+
"""Load and preprocess MMStar dataset."""
|
| 74 |
+
self.cache_dir = os.path.join(".cache", "mmstar_specforge")
|
| 75 |
+
image_dir = os.path.join(self.cache_dir, "images")
|
| 76 |
+
os.makedirs(self.cache_dir, exist_ok=True)
|
| 77 |
+
os.makedirs(image_dir, exist_ok=True)
|
| 78 |
+
print(f"Created temporary image directory: {self.cache_dir}")
|
| 79 |
+
|
| 80 |
+
dataset = load_dataset("Lin-Chen/MMStar")["val"]
|
| 81 |
+
questions = []
|
| 82 |
+
labels = []
|
| 83 |
+
self.options_list = []
|
| 84 |
+
|
| 85 |
+
for idx, q in enumerate(dataset):
|
| 86 |
+
if self.num_samples is not None and idx >= self.num_samples:
|
| 87 |
+
break
|
| 88 |
+
|
| 89 |
+
image = q["image"]
|
| 90 |
+
image_path = os.path.join(self.cache_dir, q["meta_info"]["image_path"])
|
| 91 |
+
image.convert("RGB").save(image_path, "JPEG")
|
| 92 |
+
|
| 93 |
+
# Extract question and options
|
| 94 |
+
question_full = q["question"]
|
| 95 |
+
if "Options:" in question_full:
|
| 96 |
+
question_text, options_text = question_full.split("Options:", 1)
|
| 97 |
+
question_text = question_text.strip()
|
| 98 |
+
# Parse options (typically A. option1 B. option2 etc.)
|
| 99 |
+
options = []
|
| 100 |
+
for line in options_text.strip().split("\n"):
|
| 101 |
+
line = line.strip()
|
| 102 |
+
if line and re.match(r"^[A-Z]\.", line):
|
| 103 |
+
option_text = re.sub(r"^[A-Z]\.\s*", "", line).strip()
|
| 104 |
+
options.append(option_text)
|
| 105 |
+
self.options_list.append(options)
|
| 106 |
+
else:
|
| 107 |
+
question_text = question_full.strip()
|
| 108 |
+
self.options_list.append([])
|
| 109 |
+
|
| 110 |
+
item = {
|
| 111 |
+
"image_path": image_path,
|
| 112 |
+
"question": question_text,
|
| 113 |
+
}
|
| 114 |
+
questions.append(item)
|
| 115 |
+
|
| 116 |
+
# Extract ground truth answer
|
| 117 |
+
answer = None
|
| 118 |
+
if "answer" in q:
|
| 119 |
+
answer = str(q["answer"]).strip().upper()
|
| 120 |
+
elif "correct_answer" in q:
|
| 121 |
+
answer = str(q["correct_answer"]).strip().upper()
|
| 122 |
+
elif "ground_truth" in q:
|
| 123 |
+
answer = str(q["ground_truth"]).strip().upper()
|
| 124 |
+
|
| 125 |
+
# Validate answer is a valid option letter
|
| 126 |
+
if answer and len(answer) == 1 and "A" <= answer <= "Z":
|
| 127 |
+
if self.options_list[-1]:
|
| 128 |
+
max_option = chr(64 + len(self.options_list[-1]))
|
| 129 |
+
if answer <= max_option:
|
| 130 |
+
labels.append(answer)
|
| 131 |
+
else:
|
| 132 |
+
labels.append(None)
|
| 133 |
+
else:
|
| 134 |
+
labels.append(answer)
|
| 135 |
+
else:
|
| 136 |
+
labels.append(None)
|
| 137 |
+
|
| 138 |
+
return questions, labels
|
| 139 |
+
|
| 140 |
+
def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]:
|
| 141 |
+
"""Extract answer from model output."""
|
| 142 |
+
# Use the options for the current question if available
|
| 143 |
+
# Note: We can't easily get the question index here, so we'll use a simpler approach
|
| 144 |
+
return extract_mmstar_answer(output)
|
| 145 |
+
|
| 146 |
+
def compute_accuracy(
|
| 147 |
+
self, predictions: List[Any], labels: List[Any]
|
| 148 |
+
) -> Optional[float]:
|
| 149 |
+
"""Compute accuracy for MMStar by comparing answer choices."""
|
| 150 |
+
if not labels or len(labels) == 0:
|
| 151 |
+
return None
|
| 152 |
+
if all(label is None for label in labels):
|
| 153 |
+
return None
|
| 154 |
+
|
| 155 |
+
correct = 0
|
| 156 |
+
valid_count = 0
|
| 157 |
+
for pred, label in zip(predictions, labels):
|
| 158 |
+
if label is not None:
|
| 159 |
+
valid_count += 1
|
| 160 |
+
if pred is not None:
|
| 161 |
+
# Normalize to uppercase for comparison
|
| 162 |
+
pred_normalized = str(pred).strip().upper()
|
| 163 |
+
label_normalized = str(label).strip().upper()
|
| 164 |
+
if pred_normalized == label_normalized:
|
| 165 |
+
correct += 1
|
| 166 |
+
|
| 167 |
+
return correct / valid_count if valid_count > 0 else 0.0
|
| 168 |
+
|
| 169 |
+
def create_sgl_function(self):
|
| 170 |
+
"""Create SGL function for MMStar (image-based Q&A)."""
|
| 171 |
+
return create_image_sgl_function(
|
| 172 |
+
function_name="get_mmstar_answer",
|
| 173 |
+
answer_key="answer",
|
| 174 |
+
max_tokens=self.get_max_new_tokens(),
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
def run(self, *args, **kwargs):
|
| 178 |
+
"""Run benchmark and clean up cache directory."""
|
| 179 |
+
try:
|
| 180 |
+
return super().run(*args, **kwargs)
|
| 181 |
+
finally:
|
| 182 |
+
# Clean up cache directory
|
| 183 |
+
if self.cache_dir and os.path.exists(self.cache_dir):
|
| 184 |
+
shutil.rmtree(self.cache_dir)
|
| 185 |
+
print(f"Deleted temporary directory: {self.cache_dir}")
|
SpecForge-ext/benchmarks/benchmarker/mtbench.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MT-Bench benchmark evaluation script.
|
| 3 |
+
Adapted from https://github.com/chromecast56/sglang/blob/6f145d2eadb93a116134f703358ce76f15381045/benchmark/mtbench/bench_sglang.py
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 8 |
+
|
| 9 |
+
from sglang.utils import download_and_cache_file, read_jsonl
|
| 10 |
+
|
| 11 |
+
from .base import Benchmarker
|
| 12 |
+
from .registry import BENCHMARKS
|
| 13 |
+
from .utils import create_multi_turn_sgl_function
|
| 14 |
+
|
| 15 |
+
SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@BENCHMARKS.register("mtbench")
|
| 19 |
+
class MTBenchBenchmarker(Benchmarker):
|
| 20 |
+
"""MT-Bench benchmark implementation."""
|
| 21 |
+
|
| 22 |
+
def __init__(
|
| 23 |
+
self, num_samples: Optional[int] = None, subset: Optional[List[str]] = None
|
| 24 |
+
):
|
| 25 |
+
# support categorical data for mtbench
|
| 26 |
+
if subset is None:
|
| 27 |
+
subset = ["all"]
|
| 28 |
+
super().__init__(num_samples, subset)
|
| 29 |
+
|
| 30 |
+
def load_data(self) -> Tuple[List[Dict[str, Any]], List[None]]:
|
| 31 |
+
"""Load and preprocess MT-Bench dataset."""
|
| 32 |
+
# 优先从本地数据目录读取
|
| 33 |
+
local_path = "/workspace/hanrui/datasets/mtbench/question.jsonl"
|
| 34 |
+
|
| 35 |
+
if os.path.exists(local_path):
|
| 36 |
+
print(f"Loading MT-Bench data from local: {local_path}")
|
| 37 |
+
questions_data = list(read_jsonl(local_path))
|
| 38 |
+
else:
|
| 39 |
+
# 如果本地不存在,从网络下载
|
| 40 |
+
print(f"Local data not found, downloading from GitHub...")
|
| 41 |
+
url = "https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl"
|
| 42 |
+
download_and_cache_file(url, filename="mtbench.jsonl")
|
| 43 |
+
questions_data = list(read_jsonl("mtbench.jsonl"))
|
| 44 |
+
|
| 45 |
+
questions_data = questions_data
|
| 46 |
+
|
| 47 |
+
questions = [
|
| 48 |
+
{"question_1": q["turns"][0], "question_2": q["turns"][1]}
|
| 49 |
+
for q in questions_data
|
| 50 |
+
]
|
| 51 |
+
# MT-Bench doesn't have labels for accuracy computation
|
| 52 |
+
labels = [None] * len(questions)
|
| 53 |
+
|
| 54 |
+
if self.num_samples is not None:
|
| 55 |
+
questions = questions[: self.num_samples]
|
| 56 |
+
labels = labels[: self.num_samples]
|
| 57 |
+
return questions, labels
|
| 58 |
+
|
| 59 |
+
def create_sgl_function(self):
|
| 60 |
+
"""Create SGL function for MT-Bench (2-turn conversation)."""
|
| 61 |
+
return create_multi_turn_sgl_function(
|
| 62 |
+
function_name="answer_mt_bench",
|
| 63 |
+
system_prompt=SYSTEM_PROMPT,
|
| 64 |
+
num_turns=2,
|
| 65 |
+
max_tokens=self.get_max_new_tokens(),
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
def get_answer_keys(self) -> List[str]:
|
| 69 |
+
"""Return answer keys for multi-turn conversation."""
|
| 70 |
+
return ["answer_1", "answer_2"]
|