#!/usr/bin/env python3
"""
Semantic analyzer 效果评估脚本
通过 HTTP 调用 /api/analyze-semantic 接口进行评估。
支持 submode:count / match_score(已废弃) / fill_blank
评估维度:
1. 生成的 top10 (token和概率) 的合理性
2. token_attention score 的合理性
3. 完全无关查询时的结果合理性
用法(从项目根目录运行):
python scripts/eval_semantic.py -c scripts/cases/eval_cases_short.json -o eval_result.jsonl
python scripts/eval_semantic.py --submode count fill_blank -o eval_result.jsonl
python scripts/eval_semantic.py --url http://localhost:5001
输出为 JSONL 格式,每完成一例追加一行;中断后可再次运行,从中断处续跑。
"""
import argparse
import json
import os
import sys
import time
from pathlib import Path
from typing import Optional, Tuple
# Hugging Face token (used for private Spaces; can also be supplied via the HF_TOKEN env var)
HF_TOKEN_ENV = "HF_TOKEN"
try:
    import requests
except ImportError:
    print("错误: 需要安装 requests 库")
    print("请运行: pip install requests")
    sys.exit(1)
# Built-in test cases: (name, query, text)
# "related" cases: query and text share the same topic
# "unrelated" cases: query and text are completely unrelated
TEST_CASES = [
    ("相关_AI", "人工智能", "人工智能正在改变我们的生活。机器学习、深度学习等技术在医疗、金融等领域广泛应用。"),
    ("相关_天气", "天气", "今天北京天气晴朗,气温适宜,适合户外活动。明天可能有小雨。"),
    ("无关_足球对AI", "足球比赛", "人工智能正在改变我们的生活。机器学习、深度学习等技术在医疗、金融等领域广泛应用。"),
    ("无关_烹饪对天气", "红烧肉做法", "今天北京天气晴朗,气温适宜,适合户外活动。明天可能有小雨。"),
]
DEFAULT_API_BASE = "http://localhost:5001"
def analyze_semantic_http(api_base: str, query: str, text: str, submode: Optional[str] = None, token: Optional[str] = None, timeout: int = 300) -> dict:
    """Call the analyze-semantic endpoint over HTTP and return the parsed JSON.

    Raises requests.HTTPError on a non-2xx response and RuntimeError when
    the server replies without success=True.
    """
    endpoint = f"{api_base.rstrip('/')}/api/analyze-semantic"
    body: dict = {"query": query, "text": text, "debug_info": True}
    if submode is not None:
        body["submode"] = submode
    request_headers = {"Content-Type": "application/json"}
    if token:
        request_headers["Authorization"] = f"Bearer {token}"
    response = requests.post(endpoint, json=body, headers=request_headers, timeout=timeout)
    response.raise_for_status()
    result = response.json()
    if result.get("success"):
        return result
    raise RuntimeError(result.get("message", "分析失败"))
def _load_jsonl(path: Path) -> list:
"""加载 JSONL 文件,用于断点续跑"""
if not path.exists():
return []
results = []
for line in path.read_text(encoding="utf-8").strip().split("\n"):
if not line:
continue
try:
results.append(json.loads(line))
except json.JSONDecodeError:
pass
return results
def _append_record(path: Path, record: dict) -> None:
"""追加单条记录到 JSONL 文件"""
with path.open("a", encoding="utf-8") as f:
f.write(json.dumps(record, ensure_ascii=False) + "\n")
def run_eval(
    api_base: str,
    submode: str,
    test_cases: list,
    token: Optional[str] = None,
    output_path: Optional[Path] = None,
    all_results: Optional[list] = None,
    completed: Optional[set] = None,
    max_retries: int = 3,
    timeout: int = 300,
) -> Tuple[list, bool]:
    """Evaluate every test case for one submode against the HTTP API.

    Returns (results, aborted). aborted is True when a case still fails
    after max_retries retries, in which case the remaining cases are skipped.

    completed (keys are (submode, case_name) tuples) is mutated in place so
    the caller can share resume state across submodes. Each successful record
    is appended to output_path (JSONL) for crash-safe resumption; failed
    records are deliberately NOT appended, so they are retried on resume.
    """
    # BUG FIX: the previous `completed = completed or set()` replaced a
    # caller-supplied *empty* set with a fresh one, so progress added here
    # was invisible to the caller. Test explicitly for None instead.
    if completed is None:
        completed = set()
    results = []
    for j, (name, query, text) in enumerate(test_cases):
        prog = f"[{j+1}/{len(test_cases)}]"
        if (submode, name) in completed:
            print(f"{prog} ⏭ 跳过: {submode} | {name}", flush=True)
            continue
        print(f"{prog} 执行: {submode} | {name}", flush=True)
        res = None
        last_error = None
        for attempt in range(max_retries + 1):
            try:
                res = analyze_semantic_http(api_base, query, text, submode, token=token, timeout=timeout)
                break
            except Exception as e:
                last_error = e
                if attempt < max_retries:
                    # Linear backoff: 3s, 6s, 9s, ...
                    wait = 3 * (attempt + 1)
                    print(f"{prog} 重试 {attempt + 1}/{max_retries},{wait}s 后... - {e}", flush=True)
                    time.sleep(wait)
        if res is None:
            # Retries exhausted: record the failure and abort this submode so
            # the next (resumed) run retries it.
            print(f"{prog} ✗ 失败(已重试 {max_retries} 次): {submode} | {name} - {last_error}", flush=True)
            record = {"submode": submode, "case": name, "query": query, "error": str(last_error)}
            results.append(record)
            if all_results is not None:
                all_results.append(record)
            completed.add((submode, name))
            print(f"\n⚠ 重试后仍失败,中断后续用例", flush=True)
            return results, True
        di = res.get("debug_info", {})
        topk_tokens = di.get("topk_tokens", [])
        topk_probs = di.get("topk_probs", [])
        token_attention = res.get("token_attention", [])
        # Max-normalization: score / max ∈ [0, 1]; the largest score maps to 1.
        score_max = max(a["score"] for a in token_attention) if token_attention else 0
        denom = score_max if score_max > 0 else 1
        # Top 10 attention entries ranked by score.
        sorted_attn = sorted(token_attention, key=lambda x: x["score"], reverse=True)[:10]
        top_scored = []
        for a in sorted_attn:
            score_norm = round(a["score"] / denom, 6)
            top_scored.append({
                "raw": a["raw"],
                "score": round(a["score"], 6),
                "score_norm": score_norm,
                "offset": a["offset"],
            })
        record = {
            "model": res.get("model", ""),
            "submode": submode,
            "case": name,
            "query": query,
            "text_preview": text[:80] + "..." if len(text) > 80 else text,
            "full_match_degree": res.get("full_match_degree", None),
            "top10_tokens": topk_tokens,
            "top10_probs": [round(p, 6) for p in topk_probs],
            "top10_scored_raw": top_scored,
            "score_stats": {
                "min": round(min(a["score"] for a in token_attention), 6) if token_attention else None,
                "max": round(score_max, 6) if token_attention else None,
                "mean": round(sum(a["score"] for a in token_attention) / len(token_attention), 6) if token_attention else None,
                "mean_norm": round(sum(a["score"] / denom for a in token_attention) / len(token_attention), 6) if token_attention else None,
            },
        }
        results.append(record)
        if all_results is not None:
            all_results.append(record)
        completed.add((submode, name))
        if output_path:
            _append_record(output_path, record)
        print(f"{prog} ✓ 完成: {submode} | {name}", flush=True)
    return results, False
def main():
    """CLI entry point: parse arguments, then evaluate each submode in turn."""
    parser = argparse.ArgumentParser(description="评估 semantic analyzer 效果(HTTP)")
    parser.add_argument(
        "--submode",
        choices=["count", "match_score", "fill_blank"],
        nargs="+",
        default=None,
        help="instruct 模型子模式(可多个),不指定则依次评估 count/fill_blank;match_score 已废弃",
    )
    parser.add_argument(
        "--output", "-o",
        type=Path,
        default=None,
        help="结果输出 JSONL 路径(支持断点续跑)",
    )
    parser.add_argument(
        "--url",
        default=DEFAULT_API_BASE,
        help=f"API 地址,默认 {DEFAULT_API_BASE}",
    )
    parser.add_argument(
        "--hf-token",
        type=str,
        default=None,
        help=f"Hugging Face Token(用于Private Space,也可通过环境变量{HF_TOKEN_ENV}设置)",
    )
    parser.add_argument(
        "--cases", "-c",
        type=Path,
        nargs="+",
        default=None,
        help="自定义测试用例 JSON 文件,可指定多个,格式 [{name, query, text}, ...]",
    )
    parser.add_argument(
        "--retries",
        type=int,
        default=3,
        help="失败时自动重试次数,默认 3",
    )
    parser.add_argument(
        "--timeout",
        type=int,
        default=300,
        help="单次请求超时秒数,默认 300",
    )
    args = parser.parse_args()
    api_base = args.url.rstrip("/")
    # Command-line token takes precedence over the environment variable.
    hf_token = args.hf_token or os.environ.get(HF_TOKEN_ENV)
    if args.cases:
        test_cases = []
        for path in args.cases:
            raw = json.loads(path.read_text(encoding="utf-8"))
            # strip() mirrors the browser-side trim() applied before semantic
            # analysis, avoiding token-count differences between the two paths.
            test_cases.extend([(c["name"], c["query"], (c["text"] or "").strip()) for c in raw])
        print(f"已加载 {len(test_cases)} 个用例,来自 {len(args.cases)} 个文件")
    else:
        test_cases = TEST_CASES
    # BUG FIX: the default previously included the deprecated "match_score"
    # submode, contradicting the --submode help text above. It remains
    # available when requested explicitly via --submode match_score.
    submodes = args.submode if args.submode else ["count", "fill_blank"]
    all_results: list = []
    completed: set = set()
    if args.output and args.output.exists():
        all_results = _load_jsonl(args.output)
        # .get() keeps resume tolerant of hand-edited or malformed records.
        completed = {(r.get("submode"), r.get("case")) for r in all_results}
        print(f"已加载 {len(all_results)} 条历史结果,从中断处续跑")
    for sm in submodes:
        _, aborted = run_eval(
            api_base, sm, test_cases, token=hf_token,
            output_path=args.output, all_results=all_results,
            completed=completed, max_retries=args.retries, timeout=args.timeout,
        )
        if aborted:
            # A case failed even after retries; stop instead of burning time
            # on the remaining submodes.
            break
    if args.output:
        print(f"\n✅ 结果已写入 {args.output}(共 {len(all_results)} 条)")
if __name__ == "__main__":
    main()