# llm_utils.py
# ============================================================
# Type: shared utility module (owned by engineer B)
# Purpose: OpenAI client configuration, safe JSON parsing,
#          LLM call retries, output validation
# Shared by direction_analyzer.py and repo_evaluator.py
# ============================================================
import json
import re
import time
import sys
import os
from openai import OpenAI
# ---- 配置 ----
DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY", "sk-a115c4388d444cf688d323855edb36da")
DEEPSEEK_BASE_URL = "https://api.deepseek.com"
MODEL = "deepseek-v4-pro"
MAX_RETRIES = 3 # API 调用失败时最多重试次数
RETRY_DELAY = 2.0 # 重试间隔(秒),指数增长
_client = None
def get_client() -> OpenAI:
    """Lazily build and cache the module-wide OpenAI client (singleton)."""
    global _client
    if _client is not None:
        return _client
    _client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url=DEEPSEEK_BASE_URL)
    return _client
# ---- LLM call (with retry) ----
def call_llm(
    system_prompt: str,
    user_prompt: str,
    temperature: float = 0.3,
    max_tokens: int = 8000,
    response_format: dict | None = None,
) -> str:
    """Call the DeepSeek API, retrying automatically on failure.

    Args:
        system_prompt: system prompt text.
        user_prompt: user prompt text.
        temperature: sampling temperature.
        max_tokens: maximum number of output tokens.
        response_format: OpenAI-compatible response_format,
            e.g. {"type": "json_object"}.

    Returns:
        The raw text returned by the LLM, stripped of surrounding whitespace.

    Raises:
        RuntimeError: when every retry has been exhausted.
    """
    client = get_client()
    last_error = None
    for attempt in range(MAX_RETRIES):
        try:
            request = {
                "model": MODEL,
                "messages": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                "temperature": temperature,
                "max_tokens": max_tokens,
                "timeout": 120.0,
            }
            # Only forward response_format when the caller asked for one.
            if response_format is not None:
                request["response_format"] = response_format
            reply = client.chat.completions.create(**request)
            text = reply.choices[0].message.content
            # An empty completion is treated as a failure so it triggers a retry.
            if text is None or not text.strip():
                raise ValueError("API 返回空内容")
            return text.strip()
        except Exception as exc:
            last_error = exc
            # Back off exponentially before every retry except the last failure.
            if attempt + 1 < MAX_RETRIES:
                wait = RETRY_DELAY * (2 ** attempt)
                print(f" [重试 {attempt+1}/{MAX_RETRIES}] API 调用失败: {exc},{wait:.0f}s 后重试...")
                time.sleep(wait)
    raise RuntimeError(f"LLM API 调用失败(已重试 {MAX_RETRIES} 次): {last_error}")
def call_llm_json(
    system_prompt: str,
    user_prompt: str,
    temperature: float = 0.3,
    max_tokens: int = 8000,
) -> str:
    """Call the DeepSeek API in forced-JSON output mode.

    Convenience wrapper around call_llm that sets
    response_format={"type": "json_object"}. In JSON mode the DeepSeek API
    constrains its output to valid JSON, greatly reducing format drift.
    """
    return call_llm(
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        temperature=temperature,
        max_tokens=max_tokens,
        response_format={"type": "json_object"},
    )
# ---- Safe JSON parsing ----
def parse_json_safe(raw: str, module_name: str = "") -> dict:
    """Safely parse JSON returned by an LLM, trying several strategies in order.

    The LLM may return:
      1. bare JSON
      2. JSON wrapped in a ```json ... ``` code fence
      3. JSON wrapped in a ``` ... ``` fence without a language tag
      4. leading prose + JSON
      5. trailing prose + JSON
      6. a JSON object buried among other lines
      7. JSON containing unescaped control characters

    Args:
        raw: raw LLM output.
        module_name: optional tag prepended to the error message.

    Returns:
        The parsed JSON object.

    Raises:
        ValueError: when every strategy fails; the message now lists the
            strategies that were attempted (previously collected but dropped).
    """
    # Global preprocessing: replace unescaped control characters (newlines,
    # tabs). JSON string values may not contain them, and replacing them with
    # spaces does not affect the structure of otherwise-valid JSON.
    raw = raw.replace('\r\n', ' ').replace('\n', ' ').replace('\r', ' ').replace('\t', ' ')
    strategies_tried = []

    def _attempt(label: str, candidate: str):
        """Try json.loads on candidate; on failure, retry after stripping
        trailing commas (a common LLM mistake). Records failures in
        strategies_tried and returns None when both passes fail."""
        try:
            return json.loads(candidate)
        except json.JSONDecodeError as e:
            strategies_tried.append(f"{label}: {e}")
        # Remove trailing commas before '}' / ']' — previously this repair
        # was applied only to the whole raw string, never to extracted spans.
        cleaned = re.sub(r',\s*}', '}', candidate)
        cleaned = re.sub(r',\s*]', ']', cleaned)
        if cleaned != candidate:
            try:
                return json.loads(cleaned)
            except json.JSONDecodeError:
                pass
        return None

    # Strategy 1: parse directly (with trailing-comma repair as fallback).
    result = _attempt("直接解析", raw.strip())
    if result is not None:
        return result
    # Strategy 2: extract a ```json ... ``` fenced block.
    match = re.search(r'```json\s*([\s\S]*?)```', raw)
    if match:
        result = _attempt("```json 代码块", match.group(1).strip())
        if result is not None:
            return result
    # Strategy 3: extract any ``` ... ``` fenced block.
    match = re.search(r'```\s*([\s\S]*?)```', raw)
    if match:
        result = _attempt("``` 代码块", match.group(1).strip())
        if result is not None:
            return result
    # Strategy 4: extract the outermost { ... } span.
    match = re.search(r'\{[\s\S]*\}', raw)
    if match:
        result = _attempt("最外层 {...}", match.group(0))
        if result is not None:
            return result
    # Strategy 5: look for a { ... } containing a known key such as
    # "subfield" or "score" (handles JSON buried mid-text among braces).
    for keyword in ["subfield", "overall_score", "score"]:
        pattern = r'\{[^{}]*"' + keyword + r'"[^{}]*\{[^{}]*\}[^{}]*\}|\{[^{}]*"' + keyword + r'"[^{}]*\}'
        match = re.search(pattern, raw)
        if match:
            result = _attempt(f'含 "{keyword}" 的对象', match.group(0))
            if result is not None:
                return result
    # Everything failed — report what was tried to aid debugging.
    module_tag = f"[{module_name}] " if module_name else ""
    tried = "\n".join(f"  - {s}" for s in strategies_tried)
    raise ValueError(
        f"{module_tag}无法解析 LLM 返回的 JSON。\n"
        f"尝试过的策略:\n{tried}\n"
        f"原始返回(前 500 字符):\n{raw[:500]}"
    )
# ---- Output validation ----
def validate_direction_output(data: dict) -> dict:
    """Validate direction_analyzer output, filling missing fields with defaults.

    Builds and returns a new dict; the input is not mutated. Non-dict items in
    "method_families" are dropped, and any non-list value is coerced to [].
    """
    if not isinstance(data, dict):
        raise ValueError(f"direction_analyzer 输出不是 dict,而是 {type(data).__name__}")

    def _as_list(value):
        # Anything that is not a list (including a missing key's None) -> [].
        return value if isinstance(value, list) else []

    # Normalize each method-family entry, skipping anything that isn't a dict.
    families = [
        {
            "family_name": entry.get("family_name", "未命名方法族"),
            "description": entry.get("description", ""),
            "representative_work": entry.get("representative_work", ""),
            "search_queries": _as_list(entry.get("search_queries")),
            "matched_repos": _as_list(entry.get("matched_repos")),
        }
        for entry in _as_list(data.get("method_families", []))
        if isinstance(entry, dict)
    ]

    return {
        "subfield": data.get("subfield", "未知子领域"),
        "subfield_trend": data.get("subfield_trend", "暂无趋势分析"),
        "method_families": families,
        "broad_queries": _as_list(data.get("broad_queries", [])),
    }
def validate_eval_output(data: dict) -> dict:
    """Validate repo_evaluator output, filling missing fields with defaults.

    Note: mutates the dict passed in and returns that same object.
    """
    if not isinstance(data, dict):
        raise ValueError(f"repo_evaluator 输出不是 dict,而是 {type(data).__name__}")
    # Numeric fields: default to 0, otherwise truncate to int.
    for key in (
        "reproducibility_score", "benchmark_fitness_score", "overall_score",
        "env_score", "doc_score", "code_score", "community_score", "dep_score",
        "benchmark_score",
    ):
        value = data.get(key)
        data[key] = int(value) if isinstance(value, (int, float)) else 0
    # String fields: fill in fallbacks only when absent.
    for key, fallback in (
        ("verdict", "error"),
        ("reasoning", "评估未返回分析"),
        ("benchmark_readiness", "not_ready"),
        ("suggested_use", "请手动评估该仓库"),
    ):
        data.setdefault(key, fallback)
    # List fields: "risks" must be a list.
    if not isinstance(data.get("risks"), list):
        data["risks"] = []
    return data
# ---- Windows terminal encoding ----
def fix_windows_encoding():
    """Work around garbled Chinese output in Windows terminals.

    On win32, reconfigure stdout to UTF-8 with replacement of unencodable
    characters; any failure (e.g. stdout replaced by a non-reconfigurable
    stream) is deliberately ignored. No-op on other platforms.
    """
    if sys.platform != "win32":
        return
    try:
        sys.stdout.reconfigure(encoding="utf-8", errors="replace")
    except Exception:
        pass