"""
Evaluation script for the semantic analyzer.

Evaluates by calling the /api/analyze-semantic endpoint over HTTP.

Supported submodes: count / match_score (deprecated) / fill_blank

Evaluation dimensions:
1. Plausibility of the generated top-10 tokens and their probabilities
2. Plausibility of the token_attention scores
3. Plausibility of results for completely unrelated queries

Usage (run from the project root):
    python scripts/eval_semantic.py -c scripts/cases/eval_cases_short.json -o eval_result.jsonl
    python scripts/eval_semantic.py --submode count fill_blank -o eval_result.jsonl
    python scripts/eval_semantic.py --url http://localhost:5001

Output is JSONL: one line is appended per completed case. If a run is
interrupted, rerun with the same arguments to resume where it left off.
"""
|
|
import argparse
import json
import os
import sys
import time
from pathlib import Path
from typing import Optional, Tuple

try:
    import requests
except ImportError:
    print("Error: the requests library is required")
    print("Run: pip install requests")
    sys.exit(1)

HF_TOKEN_ENV = "HF_TOKEN"


# Built-in cases: (name, query, text). The queries in the first two cases are
# related to their texts; the last two are deliberately unrelated.
TEST_CASES = [
    ("相关_AI", "人工智能", "人工智能正在改变我们的生活。机器学习、深度学习等技术在医疗、金融等领域广泛应用。"),
    ("相关_天气", "天气", "今天北京天气晴朗,气温适宜,适合户外活动。明天可能有小雨。"),
    ("无关_足球对AI", "足球比赛", "人工智能正在改变我们的生活。机器学习、深度学习等技术在医疗、金融等领域广泛应用。"),
    ("无关_烹饪对天气", "红烧肉做法", "今天北京天气晴朗,气温适宜,适合户外活动。明天可能有小雨。"),
]
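
# A custom case file passed via --cases is a JSON array with the same fields,
# e.g. (illustrative):
# [
#     {"name": "my_case", "query": "some query", "text": "some text to analyze"}
# ]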
|
|
DEFAULT_API_BASE = "http://localhost:5001"


def analyze_semantic_http(api_base: str, query: str, text: str, submode: Optional[str] = None,
                          token: Optional[str] = None, timeout: int = 300) -> dict:
    """Call the analyze-semantic endpoint over HTTP."""
    url = f"{api_base.rstrip('/')}/api/analyze-semantic"
    payload: dict = {"query": query, "text": text, "debug_info": True}
    if submode is not None:
        payload["submode"] = submode
    headers = {"Content-Type": "application/json"}
    if token:
        headers["Authorization"] = f"Bearer {token}"

    resp = requests.post(url, json=payload, headers=headers, timeout=timeout)
    resp.raise_for_status()
    data = resp.json()
    if not data.get("success"):
        raise RuntimeError(data.get("message", "analysis failed"))
    return data
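
# Illustrative success response (only the fields this script consumes; values
# invented):
# {"success": true, "model": "...", "full_match_degree": 0.87,
#  "token_attention": [{"raw": "人工", "score": 1.23, "offset": 0}, ...],
#  "debug_info": {"topk_tokens": ["智能", ...], "topk_probs": [0.42, ...]}}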
|
|
|
|
def _load_jsonl(path: Path) -> list:
    """Load a JSONL file; used for resuming interrupted runs."""
    if not path.exists():
        return []
    results = []
    for line in path.read_text(encoding="utf-8").strip().split("\n"):
        if not line:
            continue
        try:
            results.append(json.loads(line))
        except json.JSONDecodeError:
            # Tolerate a truncated trailing line left by an interrupted run.
            pass
    return results


def _append_record(path: Path, record: dict) -> None:
    """Append a single record to the JSONL file."""
    with path.open("a", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")


def run_eval(
    api_base: str,
    submode: str,
    test_cases: list,
    token: Optional[str] = None,
    output_path: Optional[Path] = None,
    all_results: Optional[list] = None,
    completed: Optional[set] = None,
    max_retries: int = 3,
    timeout: int = 300,
) -> Tuple[list, bool]:
    """Return (results, aborted); aborted is True if a case still fails after retries."""
    if completed is None:
        completed = set()
    results = []
    for j, (name, query, text) in enumerate(test_cases):
        prog = f"[{j+1}/{len(test_cases)}]"
        if (submode, name) in completed:
            print(f"{prog} ⏭ skipped: {submode} | {name}", flush=True)
            continue
        print(f"{prog} running: {submode} | {name}", flush=True)
        res = None
        last_error = None
        for attempt in range(max_retries + 1):
            try:
                res = analyze_semantic_http(api_base, query, text, submode, token=token, timeout=timeout)
                break
            except Exception as e:
                last_error = e
                if attempt < max_retries:
                    wait = 3 * (attempt + 1)
                    print(f"{prog} retry {attempt + 1}/{max_retries} in {wait}s... - {e}", flush=True)
                    time.sleep(wait)
        if res is None:
            print(f"{prog} ✗ failed (after {max_retries} retries): {submode} | {name} - {last_error}", flush=True)
            record = {"submode": submode, "case": name, "query": query, "error": str(last_error)}
            results.append(record)
            if all_results is not None:
                all_results.append(record)
            completed.add((submode, name))
            # Note: the error record is not appended to output_path, so this
            # case will be retried on the next resumed run.
            print("\n⚠ still failing after retries, aborting remaining cases", flush=True)
            return results, True

        di = res.get("debug_info", {})
        topk_tokens = di.get("topk_tokens", [])
        topk_probs = di.get("topk_probs", [])
        token_attention = res.get("token_attention", [])

        # Normalize attention scores against the per-case maximum so scores
        # are comparable across cases regardless of absolute scale.
        score_max = max(a["score"] for a in token_attention) if token_attention else 0
        denom = score_max if score_max > 0 else 1

        # Keep the 10 highest-scoring tokens for manual inspection.
        sorted_attn = sorted(token_attention, key=lambda x: x["score"], reverse=True)[:10]
        top_scored = []
        for a in sorted_attn:
            score_norm = round(a["score"] / denom, 6)
            top_scored.append({
                "raw": a["raw"],
                "score": round(a["score"], 6),
                "score_norm": score_norm,
                "offset": a["offset"],
            })

        record = {
            "model": res.get("model", ""),
            "submode": submode,
            "case": name,
            "query": query,
            "text_preview": (text[:80] + "...") if len(text) > 80 else text,
            "full_match_degree": res.get("full_match_degree", None),
            "top10_tokens": topk_tokens,
            "top10_probs": [round(p, 6) for p in topk_probs],
            "top10_scored_raw": top_scored,
            "score_stats": {
                "min": round(min(a["score"] for a in token_attention), 6) if token_attention else None,
                "max": round(score_max, 6) if token_attention else None,
                "mean": round(sum(a["score"] for a in token_attention) / len(token_attention), 6) if token_attention else None,
                "mean_norm": round(sum(a["score"] / denom for a in token_attention) / len(token_attention), 6) if token_attention else None,
            },
        }
        results.append(record)
        if all_results is not None:
            all_results.append(record)
        completed.add((submode, name))
        if output_path:
            _append_record(output_path, record)
        print(f"{prog} ✓ done: {submode} | {name}", flush=True)

    return results, False


def main():
    parser = argparse.ArgumentParser(description="Evaluate the semantic analyzer (over HTTP)")
    parser.add_argument(
        "--submode",
        choices=["count", "match_score", "fill_blank"],
        nargs="+",
        default=None,
        help="instruct-model submode(s); if omitted, count and fill_blank are evaluated in order; match_score is deprecated",
    )
    parser.add_argument(
        "--output", "-o",
        type=Path,
        default=None,
        help="output JSONL path (supports resuming interrupted runs)",
    )
    parser.add_argument(
        "--url",
        default=DEFAULT_API_BASE,
        help=f"API base URL, default {DEFAULT_API_BASE}",
    )
    parser.add_argument(
        "--hf-token",
        type=str,
        default=None,
        help=f"Hugging Face token (for private Spaces; can also be set via the {HF_TOKEN_ENV} environment variable)",
    )
    parser.add_argument(
        "--cases", "-c",
        type=Path,
        nargs="+",
        default=None,
        help="custom test-case JSON file(s), format [{name, query, text}, ...]",
    )
    parser.add_argument(
        "--retries",
        type=int,
        default=3,
        help="number of automatic retries on failure, default 3",
    )
    parser.add_argument(
        "--timeout",
        type=int,
        default=300,
        help="per-request timeout in seconds, default 300",
    )
    args = parser.parse_args()

    api_base = args.url.rstrip("/")
    hf_token = args.hf_token or os.environ.get(HF_TOKEN_ENV)

    if args.cases:
        test_cases = []
        for path in args.cases:
            raw = json.loads(path.read_text(encoding="utf-8"))
            test_cases.extend([(c["name"], c["query"], (c["text"] or "").strip()) for c in raw])
        print(f"Loaded {len(test_cases)} cases from {len(args.cases)} file(s)")
    else:
        test_cases = TEST_CASES

    # Default matches the documented behavior: match_score is deprecated and
    # runs only when requested explicitly.
    submodes = args.submode if args.submode else ["count", "fill_blank"]
    all_results: list = []
    completed: set = set()
    if args.output and args.output.exists():
        all_results = _load_jsonl(args.output)
        # Only successful records count as completed; failed cases are rerun.
        completed = {(r["submode"], r["case"]) for r in all_results if "error" not in r}
        print(f"Loaded {len(all_results)} previous results, resuming from the interruption point")
    for sm in submodes:
        _, aborted = run_eval(
            api_base, sm, test_cases, token=hf_token,
            output_path=args.output, all_results=all_results,
            completed=completed, max_retries=args.retries, timeout=args.timeout,
        )
        if aborted:
            break
    if args.output:
        print(f"\n✅ Results written to {args.output} ({len(all_results)} records in total)")


if __name__ == "__main__":
    main()
|
|