# ResearchRadar / llm_utils.py
# Uploaded by ZZZyx3587 with huggingface_hub (commit b29c06e, verified).
# llm_utils.py
# ============================================================
# 类型:公共工具模块(乙负责)
# 功能:OpenAI 客户端配置、JSON 安全解析、LLM 调用重试、输出校验
# 被 direction_analyzer.py 和 repo_evaluator.py 共用
# ============================================================
import json
import re
import time
import sys
import os
from openai import OpenAI
# ---- Configuration ----
# The API key is read from the environment only. Never ship a real key as a
# fallback default: anyone with the source can use (and bill) it, so any key
# that was ever committed should be revoked.
DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY", "")
DEEPSEEK_BASE_URL = os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com")
# NOTE(review): confirm this model id against DeepSeek's published model list;
# it can be overridden with the DEEPSEEK_MODEL environment variable.
MODEL = os.getenv("DEEPSEEK_MODEL", "deepseek-v4-pro")
MAX_RETRIES = 3  # total attempts per API call (the first try included)
RETRY_DELAY = 2.0  # base backoff in seconds; doubles after each failed attempt
# Lazily created client shared by all calls; see get_client().
_client = None
def get_client() -> OpenAI:
    """Return the shared OpenAI-compatible client for DeepSeek, creating it on first use.

    Returns:
        The module-level client instance; later calls reuse the same object.

    Raises:
        RuntimeError: DEEPSEEK_API_KEY is empty, so no authenticated client can
            be built. Failing here gives a clear message instead of an opaque
            authentication error from the first API request.
    """
    global _client
    if _client is None:
        if not DEEPSEEK_API_KEY:
            raise RuntimeError("未配置 DEEPSEEK_API_KEY 环境变量,无法调用 LLM API")
        _client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url=DEEPSEEK_BASE_URL)
    return _client
# ---- LLM 调用(带重试) ----
# HTTP status codes for which re-sending the identical request cannot succeed:
# malformed request, bad/missing credentials, permission denied, unknown
# model/endpoint, unprocessable parameters. Transient statuses (408, 409, 429,
# 5xx) and errors without a status (network failures, empty replies) are
# still retried.
_NON_RETRYABLE_STATUS = frozenset({400, 401, 403, 404, 422})


def _is_retryable(error: Exception) -> bool:
    """Return False when *error* carries an HTTP status that retrying cannot fix."""
    # The openai SDK's APIStatusError exposes the HTTP status as `status_code`;
    # exceptions without that attribute are treated as transient.
    return getattr(error, "status_code", None) not in _NON_RETRYABLE_STATUS


def call_llm(
    system_prompt: str,
    user_prompt: str,
    temperature: float = 0.3,
    max_tokens: int = 8000,
    response_format: dict | None = None,
    timeout: float = 120.0,
) -> str:
    """Call the DeepSeek chat completions API, retrying transient failures.

    Up to MAX_RETRIES attempts are made, sleeping RETRY_DELAY * 2**attempt
    seconds between them. An empty or whitespace-only reply counts as a
    failed attempt.

    Args:
        system_prompt: System prompt.
        user_prompt: User prompt.
        temperature: Sampling temperature.
        max_tokens: Maximum number of output tokens.
        response_format: OpenAI-compatible response_format, e.g.
            {"type": "json_object"}; left out of the request when None.
        timeout: Per-request timeout in seconds.

    Returns:
        The model's reply text with surrounding whitespace stripped.

    Raises:
        RuntimeError: every attempt failed, or a non-retryable HTTP error
            (400/401/403/404/422) occurred. The underlying exception is
            chained as ``__cause__``.
    """
    client = get_client()
    # The request is identical on every attempt, so build it once up front.
    request = {
        "model": MODEL,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "temperature": temperature,
        "max_tokens": max_tokens,
        "timeout": timeout,
    }
    if response_format is not None:
        request["response_format"] = response_format

    last_error: Exception | None = None
    attempts = 0
    for attempt in range(MAX_RETRIES):
        attempts = attempt + 1
        try:
            response = client.chat.completions.create(**request)
            content = response.choices[0].message.content
            if content is None or not content.strip():
                raise ValueError("API 返回空内容")
            return content.strip()
        except Exception as e:  # retry boundary: network, HTTP, empty/malformed reply
            last_error = e
            if not _is_retryable(e):
                break
            if attempts < MAX_RETRIES:
                # Exponential backoff: RETRY_DELAY, 2*RETRY_DELAY, 4*RETRY_DELAY, ...
                wait = RETRY_DELAY * (2 ** attempt)
                print(f" [重试 {attempts}/{MAX_RETRIES}] API 调用失败: {e},{wait:.0f}s 后重试...")
                time.sleep(wait)
    raise RuntimeError(
        f"LLM API 调用失败(已重试 {attempts} 次): {last_error}"
    ) from last_error
def call_llm_json(
    system_prompt: str,
    user_prompt: str,
    temperature: float = 0.3,
    max_tokens: int = 8000,
) -> str:
    """Call DeepSeek with JSON output mode switched on.

    Thin convenience wrapper: forwards every argument to ``call_llm`` and adds
    ``response_format={"type": "json_object"}`` so the API constrains the reply
    to a JSON object. The reply is still returned as raw text, so decode it
    with ``parse_json_safe``.

    NOTE: DeepSeek's JSON Output guide asks that the prompt itself mention
    "json" (ideally with an example of the expected shape). This wrapper does
    not add that, so the caller's prompts must.
    """
    json_mode = {"type": "json_object"}
    return call_llm(system_prompt, user_prompt, temperature, max_tokens, json_mode)
# ---- JSON 安全解析 ----
def parse_json_safe(raw: str, module_name: str = "") -> dict:
"""安全解析 LLM 返回的 JSON,7 种策略依次尝试。
LLM 可能返回:
1. 纯 JSON
2. ```json ... ``` 代码块包裹
3. ``` ... ``` 无语言标记包裹
4. 带前缀文字 + JSON
5. 带后缀文字 + JSON
6. 多行中夹杂 JSON 对象
7. 含未转义控制字符的 JSON
"""
# 全局预处理:替换未转义的控制字符(换行、制表符)
# JSON 字符串内不允许这些字符,替换为空格不会影响合法的 JSON 结构
raw = raw.replace('\r\n', ' ').replace('\n', ' ').replace('\r', ' ').replace('\t', ' ')
strategies_tried = []
# 策略 1:直接解析
try:
return json.loads(raw)
except json.JSONDecodeError as e:
strategies_tried.append(f"直接解析: {e}")
# 策略 2:提取 ```json ... ``` 代码块
match = re.search(r'```json\s*([\s\S]*?)```', raw)
if match:
try:
return json.loads(match.group(1).strip())
except json.JSONDecodeError:
pass
# 策略 3:提取 ``` ... ``` 任意代码块
match = re.search(r'```\s*([\s\S]*?)```', raw)
if match:
try:
return json.loads(match.group(1).strip())
except json.JSONDecodeError:
pass
# 策略 4:提取最外层 { ... }
match = re.search(r'\{[\s\S]*\}', raw)
if match:
try:
return json.loads(match.group(0))
except json.JSONDecodeError:
pass
# 策略 5:修复常见 JSON 错误(尾部逗号、单引号)
cleaned = raw.strip()
cleaned = re.sub(r',\s*}', '}', cleaned) # 移除尾部逗号
cleaned = re.sub(r',\s*]', ']', cleaned) # 移除数组尾部逗号
try:
return json.loads(cleaned)
except json.JSONDecodeError:
pass
# 策略 6:尝试找包含 "subfield" 或 "score" 的 { ... }
for keyword in ["subfield", "overall_score", "score"]:
pattern = r'\{[^{}]*"' + keyword + r'"[^{}]*\{[^{}]*\}[^{}]*\}|\{[^{}]*"' + keyword + r'"[^{}]*\}'
match = re.search(pattern, raw)
if match:
try:
return json.loads(match.group(0))
except json.JSONDecodeError:
pass
# 全部失败
module_tag = f"[{module_name}] " if module_name else ""
raise ValueError(
f"{module_tag}无法解析 LLM 返回的 JSON。\n"
f"原始返回(前 500 字符):\n{raw[:500]}"
)
# ---- 输出校验 ----
def validate_direction_output(data: dict) -> dict:
    """Validate direction_analyzer output and fill in defaults for missing fields.

    A field is treated as missing when its key is absent or its value is
    ``None`` (JSON ``null``), so null values never leak through to callers.
    List-valued fields that are not lists are replaced with ``[]``, and
    non-dict entries in ``method_families`` are dropped. Unknown keys are
    discarded. The input dict is not modified.

    Args:
        data: Parsed LLM output (normally from ``parse_json_safe``).

    Returns:
        A new dict with exactly the keys ``subfield``, ``subfield_trend``,
        ``method_families`` and ``broad_queries``.

    Raises:
        ValueError: ``data`` is not a dict.
    """
    if not isinstance(data, dict):
        raise ValueError(f"direction_analyzer 输出不是 dict,而是 {type(data).__name__}")

    def value_or(source: dict, key: str, default):
        # .get(key, default) alone would pass an explicit null through as None.
        value = source.get(key)
        return default if value is None else value

    def list_or_empty(source: dict, key: str) -> list:
        value = source.get(key)
        return value if isinstance(value, list) else []

    validated_families = [
        {
            "family_name": value_or(mf, "family_name", "未命名方法族"),
            "description": value_or(mf, "description", ""),
            "representative_work": value_or(mf, "representative_work", ""),
            "search_queries": list_or_empty(mf, "search_queries"),
            "matched_repos": list_or_empty(mf, "matched_repos"),
        }
        for mf in list_or_empty(data, "method_families")
        if isinstance(mf, dict)
    ]

    return {
        "subfield": value_or(data, "subfield", "未知子领域"),
        "subfield_trend": value_or(data, "subfield_trend", "暂无趋势分析"),
        "method_families": validated_families,
        "broad_queries": list_or_empty(data, "broad_queries"),
    }
def validate_eval_output(data: dict) -> dict:
    """Validate repo_evaluator output in place and fill defaults for missing fields.

    Score fields are coerced to ``int``:
    - ints pass through;
    - finite floats are truncated;
    - numeric strings such as ``"85"`` are parsed;
    - anything else (missing, ``None``, bools, NaN/Infinity, non-numeric text)
      becomes ``0``.
    String fields that are absent or ``None`` receive a default, and a
    non-list ``risks`` becomes ``[]``.

    Args:
        data: Parsed LLM output; it is modified in place.

    Returns:
        The same ``data`` object, normalized.

    Raises:
        ValueError: ``data`` is not a dict.
    """
    import math

    if not isinstance(data, dict):
        raise ValueError(f"repo_evaluator 输出不是 dict,而是 {type(data).__name__}")

    def to_score(value) -> int:
        if isinstance(value, bool):
            # bool is an int subclass, but True/False are not meaningful scores.
            return 0
        if isinstance(value, int):
            return value
        if isinstance(value, str):
            try:
                value = float(value.strip())
            except ValueError:
                return 0
        # json.loads accepts NaN/Infinity, and int() of those raises, so
        # only finite floats are converted.
        if isinstance(value, float) and math.isfinite(value):
            return int(value)
        return 0

    # Numeric fields default to 0.
    score_fields = (
        "reproducibility_score", "benchmark_fitness_score", "overall_score",
        "env_score", "doc_score", "code_score", "community_score", "dep_score",
        "benchmark_score",
    )
    for field in score_fields:
        data[field] = to_score(data.get(field))

    # String fields: unlike dict.setdefault, also replace an explicit null.
    string_defaults = {
        "verdict": "error",
        "reasoning": "评估未返回分析",
        "benchmark_readiness": "not_ready",
        "suggested_use": "请手动评估该仓库",
    }
    for key, default in string_defaults.items():
        if data.get(key) is None:
            data[key] = default

    # List fields.
    if not isinstance(data.get("risks"), list):
        data["risks"] = []
    return data
# ---- Windows 终端编码 ----
def fix_windows_encoding():
    """Switch stdout and stderr to UTF-8 on Windows so Chinese text is not garbled.

    No-op on other platforms. Best-effort: a stream that is missing (``None``),
    lacks ``reconfigure()`` (e.g. replaced by a wrapper), or refuses
    reconfiguration is left unchanged.
    """
    if sys.platform != "win32":
        return
    # stderr is included because error messages and tracebacks are written there.
    for stream in (sys.stdout, sys.stderr):
        reconfigure = getattr(stream, "reconfigure", None)
        if reconfigure is None:
            continue
        try:
            reconfigure(encoding="utf-8", errors="replace")
        except (TypeError, ValueError, OSError):
            # e.g. a closed stream or a wrapper with an incompatible signature;
            # keep the current encoding rather than crash at startup.
            continue