# llm_utils.py # ============================================================ # 类型:公共工具模块(乙负责) # 功能:OpenAI 客户端配置、JSON 安全解析、LLM 调用重试、输出校验 # 被 direction_analyzer.py 和 repo_evaluator.py 共用 # ============================================================ import json import re import time import sys import os from openai import OpenAI # ---- 配置 ---- DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY", "sk-a115c4388d444cf688d323855edb36da") DEEPSEEK_BASE_URL = "https://api.deepseek.com" MODEL = "deepseek-v4-pro" MAX_RETRIES = 3 # API 调用失败时最多重试次数 RETRY_DELAY = 2.0 # 重试间隔(秒),指数增长 _client = None def get_client() -> OpenAI: """获取 OpenAI 客户端单例""" global _client if _client is None: _client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url=DEEPSEEK_BASE_URL) return _client # ---- LLM 调用(带重试) ---- def call_llm( system_prompt: str, user_prompt: str, temperature: float = 0.3, max_tokens: int = 8000, response_format: dict | None = None, ) -> str: """调用 DeepSeek API,自动重试。 Args: system_prompt: 系统提示词 user_prompt: 用户提示词 temperature: 温度参数 max_tokens: 最大输出 token 数 response_format: OpenAI 兼容的 response_format,如 {"type": "json_object"} Returns: LLM 返回的原始文本 Raises: RuntimeError: 重试次数用尽后仍失败 """ client = get_client() last_error = None for attempt in range(MAX_RETRIES): try: kwargs = dict( model=MODEL, messages=[ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, ], temperature=temperature, max_tokens=max_tokens, timeout=120.0, ) if response_format is not None: kwargs["response_format"] = response_format response = client.chat.completions.create(**kwargs) content = response.choices[0].message.content if content is None or not content.strip(): raise ValueError("API 返回空内容") return content.strip() except Exception as e: last_error = e if attempt < MAX_RETRIES - 1: wait = RETRY_DELAY * (2 ** attempt) print(f" [重试 {attempt+1}/{MAX_RETRIES}] API 调用失败: {e},{wait:.0f}s 后重试...") time.sleep(wait) raise RuntimeError(f"LLM API 调用失败(已重试 {MAX_RETRIES} 次): {last_error}") def call_llm_json( system_prompt: str, user_prompt: str, temperature: float = 0.3, max_tokens: int = 8000, ) -> str: """调用 DeepSeek API 并强制 JSON 输出模式。 这是 call_llm 的便捷封装,自动设置 response_format={"type": "json_object"}。 DeepSeek API 在 JSON 模式下会约束输出为合法 JSON,大幅减少格式偏差。 """ return call_llm( system_prompt=system_prompt, user_prompt=user_prompt, temperature=temperature, max_tokens=max_tokens, response_format={"type": "json_object"}, ) # ---- JSON 安全解析 ---- def parse_json_safe(raw: str, module_name: str = "") -> dict: """安全解析 LLM 返回的 JSON,7 种策略依次尝试。 LLM 可能返回: 1. 纯 JSON 2. ```json ... ``` 代码块包裹 3. ``` ... ``` 无语言标记包裹 4. 带前缀文字 + JSON 5. 带后缀文字 + JSON 6. 多行中夹杂 JSON 对象 7. 含未转义控制字符的 JSON """ # 全局预处理:替换未转义的控制字符(换行、制表符) # JSON 字符串内不允许这些字符,替换为空格不会影响合法的 JSON 结构 raw = raw.replace('\r\n', ' ').replace('\n', ' ').replace('\r', ' ').replace('\t', ' ') strategies_tried = [] # 策略 1:直接解析 try: return json.loads(raw) except json.JSONDecodeError as e: strategies_tried.append(f"直接解析: {e}") # 策略 2:提取 ```json ... ``` 代码块 match = re.search(r'```json\s*([\s\S]*?)```', raw) if match: try: return json.loads(match.group(1).strip()) except json.JSONDecodeError: pass # 策略 3:提取 ``` ... ``` 任意代码块 match = re.search(r'```\s*([\s\S]*?)```', raw) if match: try: return json.loads(match.group(1).strip()) except json.JSONDecodeError: pass # 策略 4:提取最外层 { ... } match = re.search(r'\{[\s\S]*\}', raw) if match: try: return json.loads(match.group(0)) except json.JSONDecodeError: pass # 策略 5:修复常见 JSON 错误(尾部逗号、单引号) cleaned = raw.strip() cleaned = re.sub(r',\s*}', '}', cleaned) # 移除尾部逗号 cleaned = re.sub(r',\s*]', ']', cleaned) # 移除数组尾部逗号 try: return json.loads(cleaned) except json.JSONDecodeError: pass # 策略 6:尝试找包含 "subfield" 或 "score" 的 { ... } for keyword in ["subfield", "overall_score", "score"]: pattern = r'\{[^{}]*"' + keyword + r'"[^{}]*\{[^{}]*\}[^{}]*\}|\{[^{}]*"' + keyword + r'"[^{}]*\}' match = re.search(pattern, raw) if match: try: return json.loads(match.group(0)) except json.JSONDecodeError: pass # 全部失败 module_tag = f"[{module_name}] " if module_name else "" raise ValueError( f"{module_tag}无法解析 LLM 返回的 JSON。\n" f"原始返回(前 500 字符):\n{raw[:500]}" ) # ---- 输出校验 ---- def validate_direction_output(data: dict) -> dict: """校验 direction_analyzer 的输出,填充缺失字段为默认值。""" if not isinstance(data, dict): raise ValueError(f"direction_analyzer 输出不是 dict,而是 {type(data).__name__}") # 确保 method_families 是 list families = data.get("method_families", []) if not isinstance(families, list): families = [] # 校验每个方法族 validated_families = [] for mf in families: if not isinstance(mf, dict): continue validated_families.append({ "family_name": mf.get("family_name", "未命名方法族"), "description": mf.get("description", ""), "representative_work": mf.get("representative_work", ""), "search_queries": mf.get("search_queries", []) if isinstance(mf.get("search_queries"), list) else [], "matched_repos": mf.get("matched_repos", []) if isinstance(mf.get("matched_repos"), list) else [], }) # 确保 broad_queries 是 list broad_queries = data.get("broad_queries", []) if not isinstance(broad_queries, list): broad_queries = [] return { "subfield": data.get("subfield", "未知子领域"), "subfield_trend": data.get("subfield_trend", "暂无趋势分析"), "method_families": validated_families, "broad_queries": broad_queries, } def validate_eval_output(data: dict) -> dict: """校验 repo_evaluator 的输出,填充缺失字段为默认值。""" if not isinstance(data, dict): raise ValueError(f"repo_evaluator 输出不是 dict,而是 {type(data).__name__}") # 数值字段默认 0 int_fields = [ "reproducibility_score", "benchmark_fitness_score", "overall_score", "env_score", "doc_score", "code_score", "community_score", "dep_score", "benchmark_score", ] for field in int_fields: if field not in data or not isinstance(data[field], (int, float)): data[field] = 0 else: data[field] = int(data[field]) # 字符串字段 data.setdefault("verdict", "error") data.setdefault("reasoning", "评估未返回分析") data.setdefault("benchmark_readiness", "not_ready") data.setdefault("suggested_use", "请手动评估该仓库") # 列表字段 if "risks" not in data or not isinstance(data["risks"], list): data["risks"] = [] return data # ---- Windows 终端编码 ---- def fix_windows_encoding(): """修复 Windows 终端中文显示乱码问题""" if sys.platform == "win32": try: sys.stdout.reconfigure(encoding="utf-8", errors="replace") except Exception: pass