File size: 8,740 Bytes
03c63b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b29c06e
03c63b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b29c06e
03c63b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
# llm_utils.py
# ============================================================
# 类型:公共工具模块(乙负责)
# 功能:OpenAI 客户端配置、JSON 安全解析、LLM 调用重试、输出校验
# 被 direction_analyzer.py 和 repo_evaluator.py 共用
# ============================================================

import json
import re
import time
import sys
import os
from openai import OpenAI

# ---- Configuration ----
# The API key is read from the environment only. A real key must never be
# committed as a fallback default: anyone with read access to the source
# could use it. When the variable is unset the key is an empty string, so
# the credential has to be supplied via DEEPSEEK_API_KEY before calling the API.
DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY", "")
DEEPSEEK_BASE_URL = "https://api.deepseek.com"
MODEL = "deepseek-v4-pro"
MAX_RETRIES = 3          # maximum number of attempts for one API call
RETRY_DELAY = 2.0        # base delay between attempts (seconds); doubles each retry

# Lazily created client singleton, populated by get_client().
_client = None


def get_client() -> OpenAI:
    """Return the module-wide OpenAI client, constructing it on first use.

    The instance is cached in the module-level ``_client`` variable, so every
    later call returns the same object.
    """
    global _client
    if _client is not None:
        return _client
    _client = OpenAI(base_url=DEEPSEEK_BASE_URL, api_key=DEEPSEEK_API_KEY)
    return _client


# ---- LLM 调用(带重试) ----

def call_llm(
    system_prompt: str,
    user_prompt: str,
    temperature: float = 0.3,
    max_tokens: int = 8000,
    response_format: dict | None = None,
) -> str:
    """Call the chat-completions API, retrying on failure.

    Up to ``MAX_RETRIES`` attempts are made. After a failed attempt that is
    not the last one, the function sleeps ``RETRY_DELAY * 2 ** attempt``
    seconds before trying again.

    Args:
        system_prompt: Content of the system message.
        user_prompt: Content of the user message.
        temperature: Sampling temperature forwarded to the API.
        max_tokens: Maximum number of output tokens forwarded to the API.
        response_format: Optional OpenAI-compatible ``response_format``
            value, e.g. ``{"type": "json_object"}``. Omitted from the
            request when ``None``.

    Returns:
        The text of the first choice, stripped of surrounding whitespace.

    Raises:
        RuntimeError: Every attempt failed (including attempts that returned
            empty content). The last underlying exception is attached as
            ``__cause__``.
    """
    client = get_client()

    # The request is identical on every attempt, so build it once.
    request_kwargs = dict(
        model=MODEL,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=temperature,
        max_tokens=max_tokens,
        timeout=120.0,
    )
    if response_format is not None:
        request_kwargs["response_format"] = response_format

    last_error = None
    for attempt in range(MAX_RETRIES):
        try:
            response = client.chat.completions.create(**request_kwargs)
            content = response.choices[0].message.content
            # An empty answer is treated like a failed call so it is retried.
            if content is None or not content.strip():
                raise ValueError("API 返回空内容")
            return content.strip()
        except Exception as e:  # retry boundary: any failure triggers a retry
            last_error = e
            if attempt < MAX_RETRIES - 1:
                wait = RETRY_DELAY * (2 ** attempt)
                # Separator between the error text and the wait time keeps
                # the two values from running together in the log line.
                print(f"  [重试 {attempt+1}/{MAX_RETRIES}] API 调用失败: {e},{wait:.0f}s 后重试...")
                time.sleep(wait)

    raise RuntimeError(
        f"LLM API 调用失败(已重试 {MAX_RETRIES} 次): {last_error}"
    ) from last_error


def call_llm_json(
    system_prompt: str,
    user_prompt: str,
    temperature: float = 0.3,
    max_tokens: int = 8000,
) -> str:
    """Call the API through ``call_llm`` with JSON output mode requested.

    Convenience wrapper that forwards all arguments unchanged and sets
    ``response_format={"type": "json_object"}``. Returns whatever
    ``call_llm`` returns and raises whatever it raises.
    """
    json_mode = {"type": "json_object"}
    return call_llm(
        system_prompt,
        user_prompt,
        temperature,
        max_tokens,
        response_format=json_mode,
    )


# ---- JSON 安全解析 ----

def parse_json_safe(raw: str, module_name: str = "") -> dict:
    """安全解析 LLM 返回的 JSON,7 种策略依次尝试。

    LLM 可能返回:
    1. 纯 JSON
    2. ```json ... ``` 代码块包裹
    3. ``` ... ``` 无语言标记包裹
    4. 带前缀文字 + JSON
    5. 带后缀文字 + JSON
    6. 多行中夹杂 JSON 对象
    7. 含未转义控制字符的 JSON
    """
    # 全局预处理:替换未转义的控制字符(换行、制表符)
    # JSON 字符串内不允许这些字符,替换为空格不会影响合法的 JSON 结构
    raw = raw.replace('\r\n', ' ').replace('\n', ' ').replace('\r', ' ').replace('\t', ' ')

    strategies_tried = []

    # 策略 1:直接解析
    try:
        return json.loads(raw)
    except json.JSONDecodeError as e:
        strategies_tried.append(f"直接解析: {e}")

    # 策略 2:提取 ```json ... ``` 代码块
    match = re.search(r'```json\s*([\s\S]*?)```', raw)
    if match:
        try:
            return json.loads(match.group(1).strip())
        except json.JSONDecodeError:
            pass

    # 策略 3:提取 ``` ... ``` 任意代码块
    match = re.search(r'```\s*([\s\S]*?)```', raw)
    if match:
        try:
            return json.loads(match.group(1).strip())
        except json.JSONDecodeError:
            pass

    # 策略 4:提取最外层 { ... }
    match = re.search(r'\{[\s\S]*\}', raw)
    if match:
        try:
            return json.loads(match.group(0))
        except json.JSONDecodeError:
            pass

    # 策略 5:修复常见 JSON 错误(尾部逗号、单引号)
    cleaned = raw.strip()
    cleaned = re.sub(r',\s*}', '}', cleaned)    # 移除尾部逗号
    cleaned = re.sub(r',\s*]', ']', cleaned)    # 移除数组尾部逗号
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        pass

    # 策略 6:尝试找包含 "subfield" 或 "score" 的 { ... }
    for keyword in ["subfield", "overall_score", "score"]:
        pattern = r'\{[^{}]*"' + keyword + r'"[^{}]*\{[^{}]*\}[^{}]*\}|\{[^{}]*"' + keyword + r'"[^{}]*\}'
        match = re.search(pattern, raw)
        if match:
            try:
                return json.loads(match.group(0))
            except json.JSONDecodeError:
                pass

    # 全部失败
    module_tag = f"[{module_name}] " if module_name else ""
    raise ValueError(
        f"{module_tag}无法解析 LLM 返回的 JSON。\n"
        f"原始返回(前 500 字符):\n{raw[:500]}"
    )


# ---- 输出校验 ----

def validate_direction_output(data: dict) -> dict:
    """Normalize direction_analyzer output into a dict with a fixed shape.

    Missing keys receive default values, list-valued fields that are not
    lists become empty lists, and method-family entries that are not dicts
    are dropped. The input is not modified; a new dict is returned.

    Raises:
        ValueError: ``data`` is not a dict.
    """
    if not isinstance(data, dict):
        raise ValueError(f"direction_analyzer 输出不是 dict,而是 {type(data).__name__}")

    def list_or_empty(value):
        # Keep real lists as they are; anything else collapses to [].
        return value if isinstance(value, list) else []

    normalized_families = [
        {
            "family_name": entry.get("family_name", "未命名方法族"),
            "description": entry.get("description", ""),
            "representative_work": entry.get("representative_work", ""),
            "search_queries": list_or_empty(entry.get("search_queries")),
            "matched_repos": list_or_empty(entry.get("matched_repos")),
        }
        for entry in list_or_empty(data.get("method_families", []))
        if isinstance(entry, dict)
    ]
    queries = list_or_empty(data.get("broad_queries", []))

    return {
        "subfield": data.get("subfield", "未知子领域"),
        "subfield_trend": data.get("subfield_trend", "暂无趋势分析"),
        "method_families": normalized_families,
        "broad_queries": queries,
    }


def _coerce_score(value) -> int:
    """Convert a score value to int, returning 0 when it is not usable."""
    if isinstance(value, str):
        # Models sometimes emit numbers as strings ("85"); accept those.
        try:
            value = float(value.strip())
        except ValueError:
            return 0
    if isinstance(value, (int, float)):
        try:
            return int(value)
        except (ValueError, OverflowError):
            # int() rejects NaN and +/-Infinity, which json.loads accepts.
            return 0
    return 0


def validate_eval_output(data: dict) -> dict:
    """Validate repo_evaluator output in place, filling defaults.

    Score fields are converted to ``int`` (floats are truncated, numeric
    strings are parsed); missing, non-numeric, NaN or infinite scores become
    0. Missing text fields receive placeholder values, and ``risks`` becomes
    an empty list when it is missing or not a list.

    Args:
        data: Parsed evaluator output. It is modified in place.

    Returns:
        The same ``data`` object after normalization.

    Raises:
        ValueError: ``data`` is not a dict.
    """
    if not isinstance(data, dict):
        raise ValueError(f"repo_evaluator 输出不是 dict,而是 {type(data).__name__}")

    # Numeric fields default to 0.
    int_fields = [
        "reproducibility_score", "benchmark_fitness_score", "overall_score",
        "env_score", "doc_score", "code_score", "community_score", "dep_score",
        "benchmark_score",
    ]
    for field in int_fields:
        data[field] = _coerce_score(data.get(field))

    # Text fields: only filled in when the key is absent.
    data.setdefault("verdict", "error")
    data.setdefault("reasoning", "评估未返回分析")
    data.setdefault("benchmark_readiness", "not_ready")
    data.setdefault("suggested_use", "请手动评估该仓库")

    # List field.
    if not isinstance(data.get("risks"), list):
        data["risks"] = []

    return data


# ---- Windows 终端编码 ----

def fix_windows_encoding():
    """Switch stdout to UTF-8 on Windows so Chinese text prints correctly.

    Does nothing on other platforms. Any failure while reconfiguring stdout
    is ignored on purpose: this is a best-effort display fix.
    """
    if sys.platform != "win32":
        return
    try:
        sys.stdout.reconfigure(errors="replace", encoding="utf-8")
    except Exception:
        pass