File size: 9,954 Bytes
494c9e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
#!/usr/bin/env python3
"""
Semantic analyzer 效果评估脚本

通过 HTTP 调用 /api/analyze-semantic 接口进行评估。

支持 submode:count / match_score(已废弃) / fill_blank

评估维度:
1. 生成的 top10 (token和概率) 的合理性
2. token_attention score 的合理性
3. 完全无关查询时的结果合理性

用法(从项目根目录运行):
  python scripts/eval_semantic.py -c scripts/cases/eval_cases_short.json -o eval_result.jsonl
  python scripts/eval_semantic.py --submode count fill_blank -o eval_result.jsonl
  python scripts/eval_semantic.py --url http://localhost:5001

输出为 JSONL 格式,每完成一例追加一行;中断后可再次运行,从中断处续跑。
"""

import argparse
import json
import os
import sys
import time
from pathlib import Path
from typing import Optional, Tuple

# Name of the environment variable holding a Hugging Face token
# (needed when the target Space is private; see --hf-token).
HF_TOKEN_ENV = "HF_TOKEN"

try:
    import requests
except ImportError:
    print("错误: 需要安装 requests 库")
    print("请运行: pip install requests")
    sys.exit(1)


# Built-in test cases: (name, query, text).
# "相关" (related) cases: query and text share the same topic.
# "无关" (unrelated) cases: query and text are completely unrelated,
# used to sanity-check scores for irrelevant queries.
TEST_CASES = [
    ("相关_AI", "人工智能", "人工智能正在改变我们的生活。机器学习、深度学习等技术在医疗、金融等领域广泛应用。"),
    ("相关_天气", "天气", "今天北京天气晴朗,气温适宜,适合户外活动。明天可能有小雨。"),
    ("无关_足球对AI", "足球比赛", "人工智能正在改变我们的生活。机器学习、深度学习等技术在医疗、金融等领域广泛应用。"),
    ("无关_烹饪对天气", "红烧肉做法", "今天北京天气晴朗,气温适宜,适合户外活动。明天可能有小雨。"),
]

# Default base URL of the local API server (override with --url).
DEFAULT_API_BASE = "http://localhost:5001"


def analyze_semantic_http(api_base: str, query: str, text: str, submode: Optional[str] = None, token: Optional[str] = None, timeout: int = 300) -> dict:
    """POST to the /api/analyze-semantic endpoint and return the parsed JSON.

    Raises requests.HTTPError for non-2xx responses, and RuntimeError when
    the server answers 2xx but reports ``success: false``.
    """
    endpoint = f"{api_base.rstrip('/')}/api/analyze-semantic"

    body: dict = {"query": query, "text": text, "debug_info": True}
    if submode is not None:
        body["submode"] = submode

    req_headers = {"Content-Type": "application/json"}
    if token:
        # Bearer auth is required for private Hugging Face Spaces.
        req_headers["Authorization"] = f"Bearer {token}"

    response = requests.post(endpoint, json=body, headers=req_headers, timeout=timeout)
    response.raise_for_status()

    parsed = response.json()
    if parsed.get("success"):
        return parsed
    raise RuntimeError(parsed.get("message", "分析失败"))


def _load_jsonl(path: Path) -> list:
    """加载 JSONL 文件,用于断点续跑"""
    if not path.exists():
        return []
    results = []
    for line in path.read_text(encoding="utf-8").strip().split("\n"):
        if not line:
            continue
        try:
            results.append(json.loads(line))
        except json.JSONDecodeError:
            pass
    return results


def _append_record(path: Path, record: dict) -> None:
    """追加单条记录到 JSONL 文件"""
    with path.open("a", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")


def run_eval(
    api_base: str,
    submode: str,
    test_cases: list,
    token: Optional[str] = None,
    output_path: Optional[Path] = None,
    all_results: Optional[list] = None,
    completed: Optional[set] = None,
    max_retries: int = 3,
    timeout: int = 300,
) -> Tuple[list, bool]:
    """Evaluate every test case under one submode via the HTTP API.

    Args:
        api_base: Base URL of the API server.
        submode: Submode sent with each request.
        test_cases: Sequence of (name, query, text) tuples.
        token: Optional bearer token for a private HF Space.
        output_path: If set, each successful record is appended as JSONL so
            an interrupted run can be resumed.
        all_results: Optional shared accumulator across submodes.
        completed: Set of (submode, case_name) pairs to skip; mutated in
            place so the caller observes progress.
        max_retries: Retries per case before aborting the whole run.
        timeout: Per-request timeout in seconds.

    Returns:
        (results, aborted): ``aborted`` is True when a case still failed
        after all retries, in which case remaining cases are skipped.
    """
    # BUG FIX: use an explicit None check. The previous
    # ``completed = completed or set()`` silently replaced a
    # caller-supplied *empty* set with a new object, so additions made
    # here were never visible to the caller.
    if completed is None:
        completed = set()
    results = []
    for j, (name, query, text) in enumerate(test_cases):
        prog = f"[{j+1}/{len(test_cases)}]"
        if (submode, name) in completed:
            print(f"{prog} ⏭ 跳过: {submode} | {name}", flush=True)
            continue
        print(f"{prog} 执行: {submode} | {name}", flush=True)
        res = None
        last_error = None
        for attempt in range(max_retries + 1):
            try:
                res = analyze_semantic_http(api_base, query, text, submode, token=token, timeout=timeout)
                break
            except Exception as e:
                last_error = e
                if attempt < max_retries:
                    wait = 3 * (attempt + 1)  # linear backoff: 3s, 6s, 9s, ...
                    print(f"{prog}   重试 {attempt + 1}/{max_retries},{wait}s 后... - {e}", flush=True)
                    time.sleep(wait)
        if res is None:
            # All retries exhausted. The failure record is kept in memory
            # only — deliberately NOT appended to output_path, so a resumed
            # run will retry this case instead of treating it as done.
            print(f"{prog} ✗ 失败(已重试 {max_retries} 次): {submode} | {name} - {last_error}", flush=True)
            record = {"submode": submode, "case": name, "query": query, "error": str(last_error)}
            results.append(record)
            if all_results is not None:
                all_results.append(record)
            completed.add((submode, name))
            print(f"\n⚠ 重试后仍失败,中断后续用例", flush=True)
            return results, True

        di = res.get("debug_info", {})
        topk_tokens = di.get("topk_tokens", [])
        topk_probs = di.get("topk_probs", [])
        token_attention = res.get("token_attention", [])

        # Max normalization: score / max ∈ [0, 1]; the largest score maps to 1.
        score_max = max(a["score"] for a in token_attention) if token_attention else 0
        denom = score_max if score_max > 0 else 1  # guard against division by zero

        # Top 10 attention entries by raw score.
        sorted_attn = sorted(token_attention, key=lambda x: x["score"], reverse=True)[:10]
        top_scored = []
        for a in sorted_attn:
            score_norm = round(a["score"] / denom, 6)
            top_scored.append({
                "raw": a["raw"],
                "score": round(a["score"], 6),
                "score_norm": score_norm,
                "offset": a["offset"],
            })

        record = {
            "model": res.get("model", ""),
            "submode": submode,
            "case": name,
            "query": query,
            "text_preview": text[:80] + "..." if len(text) > 80 else text,
            "full_match_degree": res.get("full_match_degree", None),
            "top10_tokens": topk_tokens,
            "top10_probs": [round(p, 6) for p in topk_probs],
            "top10_scored_raw": top_scored,
            "score_stats": {
                "min": round(min(a["score"] for a in token_attention), 6) if token_attention else None,
                "max": round(score_max, 6) if token_attention else None,
                "mean": round(sum(a["score"] for a in token_attention) / len(token_attention), 6) if token_attention else None,
                "mean_norm": round(sum(a["score"] / denom for a in token_attention) / len(token_attention), 6) if token_attention else None,
            },
        }
        results.append(record)
        if all_results is not None:
            all_results.append(record)
        completed.add((submode, name))
        if output_path:
            _append_record(output_path, record)
        print(f"{prog} ✓ 完成: {submode} | {name}", flush=True)

    return results, False


def main():
    """CLI entry point: parse arguments, load cases, evaluate each submode."""
    parser = argparse.ArgumentParser(description="评估 semantic analyzer 效果(HTTP)")
    parser.add_argument(
        "--submode",
        choices=["count", "match_score", "fill_blank"],
        nargs="+",
        default=None,
        help="instruct 模型子模式(可多个),不指定则依次评估 count/fill_blank;match_score 已废弃",
    )
    parser.add_argument(
        "--output", "-o",
        type=Path,
        default=None,
        help="结果输出 JSONL 路径(支持断点续跑)",
    )
    parser.add_argument(
        "--url",
        default=DEFAULT_API_BASE,
        help=f"API 地址,默认 {DEFAULT_API_BASE}",
    )
    parser.add_argument(
        "--hf-token",
        type=str,
        default=None,
        help=f"Hugging Face Token(用于Private Space,也可通过环境变量{HF_TOKEN_ENV}设置)",
    )
    parser.add_argument(
        "--cases", "-c",
        type=Path,
        nargs="+",
        default=None,
        help="自定义测试用例 JSON 文件,可指定多个,格式 [{name, query, text}, ...]",
    )
    parser.add_argument(
        "--retries",
        type=int,
        default=3,
        help="失败时自动重试次数,默认 3",
    )
    parser.add_argument(
        "--timeout",
        type=int,
        default=300,
        help="单次请求超时秒数,默认 300",
    )
    args = parser.parse_args()

    api_base = args.url.rstrip("/")
    # The CLI flag wins over the environment variable.
    hf_token = args.hf_token or os.environ.get(HF_TOKEN_ENV)

    if args.cases:
        test_cases = []
        for path in args.cases:
            raw = json.loads(path.read_text(encoding="utf-8"))
            # strip() mirrors the browser-side trim() during semantic
            # analysis, avoiding token-count differences.
            test_cases.extend([(c["name"], c["query"], (c["text"] or "").strip()) for c in raw])
        print(f"已加载 {len(test_cases)} 个用例,来自 {len(args.cases)} 个文件")
    else:
        test_cases = TEST_CASES

    # BUG FIX: the default previously included the deprecated "match_score"
    # submode, contradicting both the module docstring and the --submode
    # help text, which promise count/fill_blank only when unspecified.
    # match_score can still be requested explicitly via --submode.
    submodes = args.submode if args.submode else ["count", "fill_blank"]
    all_results: list = []
    completed: set = set()
    if args.output and args.output.exists():
        # Resume mode: records already on disk count as completed.
        all_results = _load_jsonl(args.output)
        completed = {(r["submode"], r["case"]) for r in all_results}
        print(f"已加载 {len(all_results)} 条历史结果,从中断处续跑")
    for sm in submodes:
        _, aborted = run_eval(
            api_base, sm, test_cases, token=hf_token,
            output_path=args.output, all_results=all_results,
            completed=completed, max_retries=args.retries, timeout=args.timeout,
        )
        if aborted:
            # A case failed even after retries; skip remaining submodes.
            break
    if args.output:
        print(f"\n✅ 结果已写入 {args.output}(共 {len(all_results)} 条)")


if __name__ == "__main__":
    main()