Spaces:
Running
Running
| # llm_utils.py | |
| # ============================================================ | |
| # 类型:公共工具模块(乙负责) | |
| # 功能:OpenAI 客户端配置、JSON 安全解析、LLM 调用重试、输出校验 | |
| # 被 direction_analyzer.py 和 repo_evaluator.py 共用 | |
| # ============================================================ | |
| import json | |
| import re | |
| import time | |
| import sys | |
| import os | |
| from openai import OpenAI | |
# ---- Configuration ----
# The API key is read from the environment only. A secret must never be
# committed as a fallback value: anyone who can read this file could use it.
# When DEEPSEEK_API_KEY is unset the key is an empty string, so API calls are
# expected to fail authentication instead of silently using a shared key.
DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY", "")
DEEPSEEK_BASE_URL = "https://api.deepseek.com"
MODEL = "deepseek-v4-pro"
MAX_RETRIES = 3  # total number of attempts call_llm makes before giving up
RETRY_DELAY = 2.0  # base wait in seconds between attempts; doubled after each failure
_client = None  # lazily created client singleton, managed by get_client()
def get_client() -> OpenAI:
    """Return the shared OpenAI client, creating it on first use.

    The client is built from ``DEEPSEEK_API_KEY`` and ``DEEPSEEK_BASE_URL``
    and cached in the module-level ``_client`` so later calls reuse it.
    """
    global _client
    if _client is not None:
        return _client
    _client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url=DEEPSEEK_BASE_URL)
    return _client
| # ---- LLM 调用(带重试) ---- | |
def call_llm(
    system_prompt: str,
    user_prompt: str,
    temperature: float = 0.3,
    max_tokens: int = 8000,
    response_format: dict | None = None,
) -> str:
    """Call the chat-completions API with retries and return the reply text.

    One system message and one user message are sent to ``MODEL`` with a
    120-second timeout. Any exception raised by an attempt, including an
    empty reply, counts as a failure. Between attempts the function sleeps
    ``RETRY_DELAY * 2 ** attempt`` seconds.

    Args:
        system_prompt: Content of the system message.
        user_prompt: Content of the user message.
        temperature: Sampling temperature forwarded to the API.
        max_tokens: Maximum number of output tokens forwarded to the API.
        response_format: Optional OpenAI-compatible ``response_format``
            value, e.g. ``{"type": "json_object"}``; omitted from the
            request when ``None``.

    Returns:
        The reply text with surrounding whitespace stripped.

    Raises:
        RuntimeError: All ``MAX_RETRIES`` attempts failed. The last
            underlying exception is attached as ``__cause__``.
    """
    client = get_client()

    # The request is identical on every attempt, so build it once.
    request = {
        "model": MODEL,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "temperature": temperature,
        "max_tokens": max_tokens,
        "timeout": 120.0,
    }
    if response_format is not None:
        request["response_format"] = response_format

    last_error = None
    for attempt in range(MAX_RETRIES):
        try:
            response = client.chat.completions.create(**request)
            content = response.choices[0].message.content
            # An empty reply is treated like a failed call so it is retried.
            if content is None or not content.strip():
                raise ValueError("API 返回空内容")
            return content.strip()
        except Exception as e:
            last_error = e
            # No point sleeping after the final attempt.
            if attempt < MAX_RETRIES - 1:
                wait = RETRY_DELAY * (2 ** attempt)
                print(f" [重试 {attempt+1}/{MAX_RETRIES}] API 调用失败: {e},{wait:.0f}s 后重试...")
                time.sleep(wait)

    # Chain the last failure so its traceback is preserved for debugging.
    raise RuntimeError(
        f"LLM API 调用失败(已重试 {MAX_RETRIES} 次): {last_error}"
    ) from last_error
def call_llm_json(
    system_prompt: str,
    user_prompt: str,
    temperature: float = 0.3,
    max_tokens: int = 8000,
) -> str:
    """Call ``call_llm`` with JSON output mode requested.

    Convenience wrapper that forwards every argument unchanged and sets
    ``response_format={"type": "json_object"}``. The reply is still returned
    as raw text; presumably the API constrains it to valid JSON in this
    mode, but that depends on the provider, so parse it defensively.
    """
    json_mode = {"type": "json_object"}
    return call_llm(
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        temperature=temperature,
        max_tokens=max_tokens,
        response_format=json_mode,
    )
| # ---- JSON 安全解析 ---- | |
# Sentinel that distinguishes "could not parse" from a successfully parsed JSON null.
_PARSE_FAILED = object()


def _try_json(text: str):
    """Return the decoded JSON value of *text*, or ``_PARSE_FAILED`` if invalid."""
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        return _PARSE_FAILED


def _strip_trailing_commas(text: str) -> str:
    """Remove commas that directly precede a closing ``}`` or ``]``."""
    text = re.sub(r',\s*}', '}', text)
    return re.sub(r',\s*]', ']', text)


def _candidate_texts(text: str):
    """Yield, in priority order, the substrings worth feeding to ``json.loads``."""
    # 1. The whole text.
    yield text
    # 2./3. Contents of a ```json fenced block, then of any fenced block.
    for fence_pattern in (r'```json\s*([\s\S]*?)```', r'```\s*([\s\S]*?)```'):
        fenced = re.search(fence_pattern, text)
        if fenced:
            yield fenced.group(1).strip()
    # 4. Greedy span from the first "{" to the last "}".
    braces = re.search(r'\{[\s\S]*\}', text)
    if braces:
        yield braces.group(0)
    # 5. The whole text with trailing commas removed.
    yield _strip_trailing_commas(text.strip())
    # 6. A small object (at most one nested level) containing a known key.
    for keyword in ["subfield", "overall_score", "score"]:
        pattern = r'\{[^{}]*"' + keyword + r'"[^{}]*\{[^{}]*\}[^{}]*\}|\{[^{}]*"' + keyword + r'"[^{}]*\}'
        keyed = re.search(pattern, text)
        if keyed:
            yield keyed.group(0)


def _decode_first_object(text: str):
    """Decode the JSON object starting at the first ``{``, ignoring what follows it."""
    start = text.find("{")
    if start == -1:
        return _PARSE_FAILED
    try:
        value, _end = json.JSONDecoder().raw_decode(text, start)
    except json.JSONDecodeError:
        return _PARSE_FAILED
    return value


def parse_json_safe(raw: str, module_name: str = "") -> dict:
    """Parse JSON out of an LLM reply, trying progressively looser strategies.

    Raw newlines, carriage returns and tabs are first replaced with spaces,
    because they are illegal inside JSON strings and harmless between tokens.
    The strategies are then tried in this order, returning the first success:

    1. the whole text;
    2. the contents of a ```json fenced block;
    3. the contents of any ``` fenced block;
    4. the span from the first ``{`` to the last ``}``;
    5. the whole text with trailing commas removed;
    6. a small object containing ``"subfield"``, ``"overall_score"`` or
       ``"score"``;
    7. the object starting at the first ``{`` with trailing text ignored,
       on the original text and then on the comma-repaired text. This covers
       replies with braces in the trailing text and fenced blocks that
       contain trailing commas.

    Strategies 1-6 are unchanged in order and result, so any reply that
    parsed before yields the same value.

    Args:
        raw: The text returned by the LLM.
        module_name: Optional tag prefixed to the error message.

    Returns:
        The decoded JSON value. This is normally a dict, but strategies 1-3
        and 5 return whatever JSON value the text holds (list, number, ...),
        so callers should still check the type.

    Raises:
        ValueError: No strategy produced valid JSON.
    """
    raw = raw.replace('\r\n', ' ').replace('\n', ' ').replace('\r', ' ').replace('\t', ' ')

    for candidate in _candidate_texts(raw):
        value = _try_json(candidate)
        if value is not _PARSE_FAILED:
            return value

    # Last resort: only the object at the first "{" is accepted, so a broken
    # outer object is never replaced by one of its nested fragments.
    for text in (raw, _strip_trailing_commas(raw)):
        value = _decode_first_object(text)
        if value is not _PARSE_FAILED:
            return value

    module_tag = f"[{module_name}] " if module_name else ""
    raise ValueError(
        f"{module_tag}无法解析 LLM 返回的 JSON。\n"
        f"原始返回(前 500 字符):\n{raw[:500]}"
    )
| # ---- 输出校验 ---- | |
def validate_direction_output(data: dict) -> dict:
    """Normalize the direction_analyzer output into a fixed-shape dict.

    A new dict is built with the keys ``subfield``, ``subfield_trend``,
    ``method_families`` and ``broad_queries``. Missing scalar fields get a
    placeholder string, list fields that are missing or not lists become
    ``[]``, and method-family entries that are not dicts are dropped.

    Raises:
        ValueError: ``data`` is not a dict.
    """
    if not isinstance(data, dict):
        raise ValueError(f"direction_analyzer 输出不是 dict,而是 {type(data).__name__}")

    def _list_or_empty(value):
        # Anything that is not already a list is replaced by an empty list.
        return value if isinstance(value, list) else []

    method_families = [
        {
            "family_name": entry.get("family_name", "未命名方法族"),
            "description": entry.get("description", ""),
            "representative_work": entry.get("representative_work", ""),
            "search_queries": _list_or_empty(entry.get("search_queries")),
            "matched_repos": _list_or_empty(entry.get("matched_repos")),
        }
        for entry in _list_or_empty(data.get("method_families", []))
        if isinstance(entry, dict)
    ]

    return {
        "subfield": data.get("subfield", "未知子领域"),
        "subfield_trend": data.get("subfield_trend", "暂无趋势分析"),
        "method_families": method_families,
        "broad_queries": _list_or_empty(data.get("broad_queries", [])),
    }
def _coerce_score(value) -> int:
    """Convert a score to ``int``, returning 0 for anything unusable."""
    if isinstance(value, str):
        # LLMs sometimes quote numbers ("85"); recover them instead of zeroing.
        try:
            value = float(value)
        except ValueError:
            return 0
    if not isinstance(value, (int, float)):
        return 0
    try:
        return int(value)  # floats are truncated toward zero
    except (ValueError, OverflowError):
        # NaN raises ValueError and +/-Infinity raises OverflowError;
        # json.loads accepts both literals, so they can reach this point.
        return 0


def validate_eval_output(data: dict) -> dict:
    """Normalize the repo_evaluator output in place and return it.

    * Every score field is coerced to ``int``. Missing, non-numeric, NaN and
      infinite values become 0; numeric strings such as ``"85"`` are parsed.
    * ``verdict``, ``reasoning``, ``benchmark_readiness`` and
      ``suggested_use`` get a default when missing or ``None``.
    * ``risks`` is replaced by ``[]`` when missing or not a list.

    Args:
        data: Parsed evaluator output; it is modified in place.

    Returns:
        The same ``data`` object after normalization.

    Raises:
        ValueError: ``data`` is not a dict.
    """
    if not isinstance(data, dict):
        raise ValueError(f"repo_evaluator 输出不是 dict,而是 {type(data).__name__}")

    score_fields = [
        "reproducibility_score", "benchmark_fitness_score", "overall_score",
        "env_score", "doc_score", "code_score", "community_score", "dep_score",
        "benchmark_score",
    ]
    for field in score_fields:
        data[field] = _coerce_score(data.get(field))

    string_defaults = {
        "verdict": "error",
        "reasoning": "评估未返回分析",
        "benchmark_readiness": "not_ready",
        "suggested_use": "请手动评估该仓库",
    }
    for key, default in string_defaults.items():
        # An explicit JSON null is treated the same as a missing key.
        if data.get(key) is None:
            data[key] = default

    if not isinstance(data.get("risks"), list):
        data["risks"] = []
    return data
| # ---- Windows 终端编码 ---- | |
def fix_windows_encoding():
    """Switch stdout to UTF-8 on Windows so Chinese text prints correctly.

    Does nothing on other platforms. The reconfiguration is best-effort: any
    failure (for example a replaced stdout without ``reconfigure``) is
    ignored so that startup never aborts because of console settings.
    """
    if sys.platform != "win32":
        return
    try:
        sys.stdout.reconfigure(encoding="utf-8", errors="replace")
    except Exception:
        # Deliberately ignored: garbled output is preferable to a crash here.
        return