Spaces:
Running
Running
| # run.py | |
| # ============================================================ | |
| # 类型:调度器(乙负责) | |
| # 功能:按"先搜再分析"流程串联 Workflow(甲)和 Agent(乙) | |
| # 用法:python run.py <arxiv_url> | |
| # 示例:python run.py https://arxiv.org/abs/2011.08785 | |
| # ============================================================ | |
import json
import os
import re
import sys
import time
import urllib.error
import urllib.parse
import urllib.request
import xml.etree.ElementTree as ET
from concurrent.futures import ThreadPoolExecutor, as_completed
# Auto-load the .env file sitting next to this script into os.environ.
# Variables already present in the environment take precedence.
_ENV_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".env")
if os.path.isfile(_ENV_PATH):
    with open(_ENV_PATH, "r", encoding="utf-8") as _f:
        for _line in _f:
            _line = _line.strip()
            # Only parse KEY=VALUE lines; skip blanks and '#' comments.
            if _line and not _line.startswith("#") and "=" in _line:
                _key, _val = _line.split("=", 1)
                # Never overwrite a variable the user already exported.
                if _key not in os.environ:
                    # Strip whitespace, then surrounding double and single quotes.
                    os.environ[_key] = _val.strip().strip('"').strip("'")
| # ---- 导入甲的 Workflow 模块 ---- | |
| try: | |
| from paper_fetcher import fetch_paper_info | |
| except ImportError as e: | |
| raise ImportError( | |
| f"缺少模块: paper_fetcher.py(甲负责),请等待甲交付。\n" | |
| f" 原始错误: {e}" | |
| ) from e | |
| try: | |
| from repo_searcher import search_repos | |
| except ImportError as e: | |
| raise ImportError( | |
| f"缺少模块: repo_searcher.py(甲负责),请等待甲交付。\n" | |
| f" 原始错误: {e}" | |
| ) from e | |
| try: | |
| from repo_fetcher import fetch_readme, fetch_dependencies | |
| except ImportError as e: | |
| raise ImportError( | |
| f"缺少模块: repo_fetcher.py(甲负责),请等待甲交付。\n" | |
| f" 原始错误: {e}" | |
| ) from e | |
| from direction_analyzer import analyze_direction | |
| from repo_evaluator import evaluate_repo | |
| from supervisor import supervise | |
| from llm_utils import call_llm_json, parse_json_safe | |
def _enrich_domain_context(title: str, abstract: str, categories: list[str]) -> str:
    """Search Semantic Scholar for related survey papers to enrich domain context.

    S2 covers all engineering fields (IEEE/ASME/ASCE journals etc.), not only
    arXiv.  Strategy: title keywords + "survey review" / "state of the art".

    Args:
        title: paper title, used to derive search keywords.
        abstract: paper abstract (currently unused; kept for interface stability).
        categories: arXiv categories (currently unused; same reason).

    Returns:
        A markdown-formatted context section, or "" on any failure
        (silent degradation by design).
    """
    all_abstracts = []
    # Derive a handful of core keywords from the title to seed the queries.
    title_keywords = _extract_title_keywords(title)
    search_queries = [
        f"{' '.join(title_keywords)} survey review",
        f"{' '.join(title_keywords)} state of the art",
    ]
    for query in search_queries:
        api_url = (
            f"https://api.semanticscholar.org/graph/v1/paper/search"
            # Fix: quote() belongs to urllib.parse; urllib.request.quote only
            # works through an undocumented CPython re-export.
            f"?query={urllib.parse.quote(query)}&limit=3"
            f"&fields=title,abstract,year,citationCount"
        )
        req = urllib.request.Request(api_url, headers={"User-Agent": "ResearchRadar/1.0"})
        try:
            with urllib.request.urlopen(req, timeout=15) as resp:
                data = json.loads(resp.read().decode("utf-8"))
            for paper in data.get("data", []):
                t = (paper.get("title") or "").strip()
                s = (paper.get("abstract") or "").strip()
                if t and s:
                    all_abstracts.append({
                        "title": t,
                        "abstract": s[:800],
                        "year": paper.get("year", ""),
                        "citations": paper.get("citationCount", 0),
                    })
        except Exception:
            # Best-effort enrichment: silently skip any query that fails.
            continue
    if not all_abstracts:
        return ""
    # Assemble the collected abstracts into one markdown context section.
    lines = ["## 该领域相关论文(来自 Semantic Scholar,覆盖全工科领域)"]
    for i, a in enumerate(all_abstracts[:4]):
        year_str = f" ({a['year']})" if a.get("year") else ""
        cites_str = f" [{a.get('citations', 0)} 引用]"
        lines.append(f"{i+1}. **{a['title']}**{year_str}{cites_str}: {a['abstract']}")
    return "\n\n".join(lines)
| def _extract_title_keywords(title: str) -> list[str]: | |
| """从论文标题提取核心实义词作为搜索关键词。""" | |
| stops = { | |
| 'a', 'an', 'the', 'of', 'for', 'in', 'on', 'to', 'and', 'or', 'is', 'are', | |
| 'we', 'our', 'that', 'this', 'with', 'from', 'by', 'as', 'at', 'be', 'it', | |
| 'its', 'not', 'can', 'has', 'have', 'been', 'was', 'were', 'will', 'would', | |
| 'could', 'should', 'may', 'do', 'does', 'did', 'so', 'if', 'no', 'new', | |
| 'based', 'using', 'which', 'into', 'such', 'than', 'then', 'these', 'those', | |
| 'propose', 'present', 'method', 'approach', 'framework', 'novel', 'model', | |
| 'learning', 'deep', 'via', 'et', 'al', 'towards', 'toward', | |
| } | |
| words = [w.lower() for w in re.sub(r'[^\w\s-]', ' ', title).split() if len(w) >= 4 and w.lower() not in stops] | |
| # 去重保持顺序 | |
| seen = set() | |
| uniq = [] | |
| for w in words: | |
| if w not in seen: | |
| seen.add(w) | |
| uniq.append(w) | |
| return uniq[:5] | |
def _mine_comparison_algorithms(arxiv_id: str, title: str, abstract: str) -> tuple[list[str], dict[str, int]]:
    """Mine comparison/baseline algorithms from the S2 citation network.

    Rationale: a paper's references usually contain the baselines used in its
    comparison experiments; turning those reference titles into extra GitHub
    queries greatly improves search coverage.

    Args:
        arxiv_id: arXiv identifier of the input paper.
        title: paper title (truncated into the LLM prompt).
        abstract: paper abstract (currently unused; kept for interface stability).

    Returns:
        (extra_queries, citation_map): extra GitHub search queries and a
        {reference_title: citation_count} map.  Both empty on API failure.
    """
    s2_url = (
        f"https://api.semanticscholar.org/graph/v1/paper/ArXiv:{arxiv_id}"
        f"?fields=references.title,references.citationCount,references.abstract"
        f"&limit=50"
    )
    req = urllib.request.Request(s2_url, headers={"User-Agent": "ResearchRadar/1.0"})
    try:
        with urllib.request.urlopen(req, timeout=20) as resp:
            data = json.loads(resp.read().decode("utf-8"))
    except Exception as e:
        print(f" [WARN] Semantic Scholar API 不可用,跳过对比算法挖掘: {e}")
        return [], {}
    refs = data.get("references", [])
    if not refs:
        return [], {}
    # Sort by citation count, keep the top 15.
    # Bug fix: S2 may return citationCount as null; `or 0` avoids a TypeError
    # (None vs int) that previously escaped the try/except and crashed run().
    refs.sort(key=lambda r: r.get("citationCount") or 0, reverse=True)
    top_refs = refs[:15]
    # Build the reference-title list the LLM uses to spot method papers.
    title_list = []
    citation_map = {}
    for r in top_refs:
        t = (r.get("title") or "").strip()
        cc = r.get("citationCount") or 0
        if t and len(t) > 10:
            title_list.append(f"- [{cc} cites] {t}")
            citation_map[t] = cc
    if len(title_list) < 3:
        return [], citation_map
    # Ask the LLM which cited papers propose concrete methods/algorithms.
    system_prompt = """你是学术论文分析专家。从引用论文列表中识别哪些是提出了具体算法/方法的论文。
排除标准:
- 数据集/benchmark 论文(如 ImageNet, CIFAR, MVTec AD)
- 综述/survey 论文
- 纯理论/数学论文
- 框架/库论文(如 PyTorch, TensorFlow)
保留标准:
- 提出了具体的模型/架构/算法名称
- 可以作为对比实验的 baseline 方法
输出严格 JSON:
{"methods": ["方法名1", "方法名2"], "search_queries": ["method1 pytorch implementation", "method2 official code"]}
方法名尽量使用论文中常用的英文缩写或全称。"""
    user_prompt = f"""输入论文标题: {title[:200]}
引用论文列表:
{chr(10).join(title_list)}
请识别哪些是方法/算法论文,生成对应的 GitHub 搜索词。"""
    try:
        raw = call_llm_json(system_prompt, user_prompt, temperature=0.2, max_tokens=3000)
        data = parse_json_safe(raw, "comparison_miner")
        methods = data.get("methods", [])
        queries = data.get("search_queries", [])
        print(f" 从引用网络识别到 {len(methods)} 个对比算法: {methods[:8]}")
        return queries[:8], citation_map
    except Exception as e:
        print(f" [WARN] 对比算法识别失败,降级: {e}")
        # Fallback: use the cited-paper titles themselves as queries.
        fallback_queries = []
        for t in list(citation_map.keys())[:5]:
            short = re.sub(r'[^\w\s-]', '', t).strip()[:80]
            if len(short) > 15:
                fallback_queries.append(f"{short} pytorch implementation")
        return fallback_queries, citation_map
def _search_s2_papers(query: str, limit: int = 5) -> list[dict]:
    """Search Semantic Scholar for papers across all engineering fields.

    Results are cached with a TTL (see cache.s2_cache) to lower S2 API load.

    Args:
        query: free-text search query.
        limit: maximum number of papers to request.

    Returns:
        list[dict]: [{"title", "abstract", "year", "citations", "url",
        "arxiv_id", "doi"}, ...]; empty list on API failure (not cached).
    """
    from cache import s2_cache
    cache_key = f"s2_search:{query}:{limit}"
    cached = s2_cache.get(cache_key)
    if cached is not None:
        return cached
    api_url = (
        f"https://api.semanticscholar.org/graph/v1/paper/search"
        # Fix: use the documented urllib.parse.quote instead of the
        # undocumented urllib.request re-export.
        f"?query={urllib.parse.quote(query)}&limit={limit}"
        f"&fields=title,abstract,year,citationCount,url,externalIds"
    )
    req = urllib.request.Request(api_url, headers={"User-Agent": "ResearchRadar/1.0"})
    try:
        with urllib.request.urlopen(req, timeout=15) as resp:
            data = json.loads(resp.read().decode("utf-8"))
    except Exception:
        # Transient failures are not cached so the next call can retry.
        return []
    papers = []
    for p in data.get("data", []):
        title = (p.get("title") or "").strip()
        if not title:
            continue
        papers.append({
            "title": title,
            "abstract": (p.get("abstract") or "").strip()[:1000],
            "year": p.get("year"),
            "citations": p.get("citationCount", 0),
            "url": p.get("url", ""),
            "arxiv_id": (p.get("externalIds") or {}).get("ArXiv", ""),
            "doi": (p.get("externalIds") or {}).get("DOI", ""),
        })
    s2_cache.set(cache_key, papers)
    return papers
def _find_input_paper_repo(title: str, arxiv_id: str, authors: list[str]) -> dict | None:
    """Locate the official code repository of the input paper itself.

    Search strategy, in priority order:
      1. GitHub search for the arXiv ID,
      2. GitHub search for the quoted title keywords,
      3. GitHub search for the first author's surname + core keywords.

    Returns:
        A candidate-repo dict in the standard format, or None when nothing
        plausible is found.
    """
    import requests
    gh_headers = {"Accept": "application/vnd.github.v3+json"}
    gh_token = os.getenv("GITHUB_TOKEN", "")
    if gh_token:
        gh_headers["Authorization"] = f"Bearer {gh_token}"
    # Core keywords = the first six content words of the title.
    stops = {
        'a', 'an', 'the', 'of', 'for', 'in', 'on', 'to', 'and', 'or', 'is', 'are',
        'we', 'our', 'that', 'this', 'with', 'from', 'by', 'as', 'at', 'be', 'it',
        'its', 'not', 'can', 'has', 'have', 'been', 'was', 'were', 'will', 'would',
        'could', 'should', 'may', 'do', 'does', 'did', 'so', 'if', 'no', 'new',
        'based', 'using', 'which', 'into', 'such', 'than', 'then', 'these', 'those',
        'propose', 'present', 'method', 'approach', 'framework', 'novel', 'model',
        'learning', 'deep', 'via', 'et', 'al', 'towards', 'toward',
    }
    content_words = []
    for w in re.sub(r'[^\w\s-]', ' ', title).split():
        if w.lower() not in stops and len(w) >= 3:
            content_words.append(w)
    core_keywords = " ".join(content_words[:6])
    queries = [
        f'"{arxiv_id}" in:name,description,readme',
        f'"{core_keywords}" in:name,description stars:>=3',
    ]
    # Optional third query: first-author surname + keywords.
    if authors:
        name_parts = authors[0].split()
        surname = name_parts[-1] if name_parts else ""
        if len(surname) >= 3:
            queries.append(f"{surname} {core_keywords[:80]} in:name,description")
    search_endpoint = "https://api.github.com/search/repositories"
    for q in queries:
        try:
            resp = requests.get(
                search_endpoint,
                headers=gh_headers,
                params={"q": q, "sort": "stars", "order": "desc", "per_page": 5},
                timeout=15,
            )
            # Rate-limited: move on to the next query instead of failing.
            if resp.status_code in (403, 429):
                continue
            resp.raise_for_status()
            payload = resp.json()
        except Exception:
            continue
        # Accept a hit only when name+description share at least two of the
        # six core keywords with the title (guards against coincidences).
        for item in payload.get("items", []):
            haystack = f'{item.get("full_name", "").lower()} {(item.get("description") or "").lower()}'
            hits = sum(1 for kw in content_words[:6] if kw.lower() in haystack)
            if hits >= 2:
                return {
                    "full_name": item["full_name"],
                    "html_url": item["html_url"],
                    "description": item.get("description", ""),
                    "stars": item.get("stargazers_count", 0),
                    "language": item.get("language", ""),
                    "updated_at": item.get("updated_at", ""),
                    "topics": item.get("topics", []),
                    "match_keyword": f"本文代码: {q[:60]}",
                }
    return None
def run(arxiv_url: str, top_n: int = 5, progress=None) -> dict:
    """Main entry: search GitHub first, induce method families from the actual
    repositories, then evaluate the top candidates.

    Args:
        arxiv_url: arXiv paper URL.
        top_n: number of repositories to evaluate in the final stage.
        progress: gr.Progress instance or callable(fraction, desc=...) driving
            a frontend progress bar; may be None.

    Returns:
        dict: {"paper": {...}, "direction": {...}, "repos": [{...}],
        "audit": {...}} on success, or a dict containing an "error" key.
    """
    t_start = time.time()

    def _prog(frac: float, desc: str):
        # Forward progress updates only when a sink was supplied.
        if progress:
            progress(frac, desc=desc)

    # ================================================================
    # [1/6] Fetch paper metadata (workflow module)
    # ================================================================
    print("=" * 60)
    print("[1/6] 正在获取论文信息...")
    _prog(0.05, "正在从 arxiv 获取论文信息...")
    t0 = time.time()
    try:
        paper = fetch_paper_info(arxiv_url)
    except Exception as e:
        return {"error": f"获取论文信息失败: {e}"}
    title = paper.get("title", "")
    abstract = paper.get("abstract", "")
    # Bug fix: arxiv_id was previously first assigned in step [1.5] but is
    # already needed by step [1.2] (NameError); extract it right away.
    arxiv_id = paper.get("arxiv_id", "")
    print(f" 标题: {title[:100]}...")
    print(f" 作者: {', '.join(paper.get('authors', [])[:3])}")
    print(f" 分类: {', '.join(paper.get('categories', []))}")
    print(f" 耗时: {time.time() - t0:.1f}s")
    # ================================================================
    # [1.2] Look for the input paper's own official code
    # ================================================================
    print()
    print("[1.2] 正在搜索输入论文的官方代码...")
    _prog(0.10, "正在搜索论文自身的官方代码...")
    t0 = time.time()
    input_paper_repo = _find_input_paper_repo(title, arxiv_id, paper.get("authors", []))
    if input_paper_repo:
        print(f" 找到论文官方代码: {input_paper_repo['full_name']} (Stars: {input_paper_repo.get('stars', 0)})")
    else:
        print(f" 未找到论文官方代码")
    print(f" 耗时: {time.time() - t0:.1f}s")
    # ================================================================
    # [1.5] Mine comparison algorithms from the S2 citation network
    # ================================================================
    print()
    print("[1.5] 正在从论文引用网络挖掘对比实验算法...")
    _prog(0.12, "正在从引用网络挖掘对比算法...")
    t0 = time.time()
    comparison_queries, citation_map = _mine_comparison_algorithms(arxiv_id, title, abstract)
    if comparison_queries:
        print(f" 生成 {len(comparison_queries)} 个对比算法搜索词:")
        for q in comparison_queries[:5]:
            print(f" - {q}")
    print(f" 耗时: {time.time() - t0:.1f}s")
    # ================================================================
    # [1.8] S2 paper search (all engineering fields, incl. IEEE/ASME)
    # ================================================================
    print()
    print("[1.8] 正在 Semantic Scholar 中搜索同领域论文(覆盖全工科)...")
    _prog(0.18, "正在 Semantic Scholar 搜索同领域论文...")
    t0 = time.time()
    title_kws = _extract_title_keywords(title)
    s2_papers = _search_s2_papers(" ".join(title_kws), limit=8)
    # Turn highly-cited S2 paper titles into extra GitHub search queries.
    s2_extra_queries = []
    if s2_papers:
        print(f" 找到 {len(s2_papers)} 篇相关论文:")
        for p in s2_papers[:5]:
            cits = p.get("citations", 0)
            print(f" - [{cits} cites] {p['title'][:80]}")
            # Only well-cited papers (>= 50 citations) become search terms.
            if cits >= 50:
                s2_quoted = p['title'].strip()[:80]
                # Drop special characters so GitHub search accepts the query.
                s2_clean = re.sub(r'[^\w\s-]', '', s2_quoted).strip()
                if len(s2_clean) > 15:
                    s2_extra_queries.append(f"{s2_clean} github")
        print(f" 额外生成 {len(s2_extra_queries)} 个 S2 来源搜索词")
    print(f" 耗时: {time.time() - t0:.1f}s")
    # ================================================================
    # [2/6] Broad GitHub search (workflow module)
    # ================================================================
    print()
    print("[2/6] 正在宽泛搜索 GitHub(先搜仓库,再让 Agent 分析)...")
    _prog(0.22, "正在 GitHub 搜索开源仓库...")
    t0 = time.time()
    broad_queries = _extract_broad_queries(title, abstract)
    # Merge all query sources, de-duplicated; fewer queries without a token
    # to avoid GitHub rate limiting.
    has_github_token = bool(os.getenv("GITHUB_TOKEN", ""))
    max_queries = 15 if has_github_token else 8
    all_queries = list(dict.fromkeys(s2_extra_queries + comparison_queries + broad_queries))[:max_queries]
    if not has_github_token:
        print(f" ⚠️ 未设置 GITHUB_TOKEN,搜索词限制为 {max_queries} 个以避免限速")
    print(f" LLM: {len(broad_queries)} + S2引用: {len(comparison_queries)} + S2论文: {len(s2_extra_queries)} = {len(all_queries)} 个总搜索词:")
    for q in all_queries:
        print(f" - {q}")
    try:
        broad_results = search_repos(all_queries, max_per_keyword=5)
    except Exception as e:
        return {
            "paper": paper,
            "error": f"GitHub 搜索失败: {e}",
        }
    print(f" 去重后获得 {len(broad_results)} 个候选仓库")
    print(f" 耗时: {time.time() - t0:.1f}s")
    if not broad_results:
        return {
            "paper": paper,
            "direction": {},
            "repos": [],
            "error": "未找到相关开源仓库。该方向可能太新或太冷门,暂无高质量开源实现。",
        }
    # ================================================================
    # [2.5] LLM filter: drop unrelated repositories
    # ================================================================
    print()
    print("[2.5] 正在用 LLM 过滤不相关仓库...")
    _prog(0.32, "正在用 LLM 过滤不相关仓库...")
    t0 = time.time()
    filtered_results = _filter_repos(title, abstract, broad_results)
    print(f" 耗时: {time.time() - t0:.1f}s")
    # ================================================================
    # [2.6] Domain-context enrichment (S2 surveys)
    # ================================================================
    print()
    print("[2.6] 正在扩充领域上下文(搜索相关综述论文)...")
    _prog(0.38, "正在扩充领域上下文...")
    t0 = time.time()
    domain_context = _enrich_domain_context(title, abstract, paper.get("categories", []))
    if domain_context:
        # Each entry title is wrapped in '**...**', hence the '//' trick.
        print(f" 获取到 {domain_context.count('**') // 2} 篇相关综述的摘要")
    else:
        print(f" 未找到相关综述(将仅基于论文摘要和仓库数据进行分析)")
    print(f" 耗时: {time.time() - t0:.1f}s")
    # ================================================================
    # [3/6] Induce method families from repo data (Agent 1)
    # ================================================================
    print()
    print(f"[3/6] 正在分析 {len(filtered_results)} 个仓库,归纳方法族(Agent 1)...")
    _prog(0.42, "正在用 LLM 归纳方法族...")
    t0 = time.time()
    try:
        direction = analyze_direction(title, abstract, filtered_results, domain_context)
    except Exception as e:
        print(f" [WARN] Agent 1 方向解析失败,降级为基本分析: {e}")
        direction = _make_fallback_direction(title, abstract, filtered_results)
    subfield = direction.get("subfield", "未知")
    families = direction.get("method_families", [])
    print(f" 子领域: {subfield}")
    print(f" 趋势: {direction.get('subfield_trend', '')[:80]}...")
    print(f" 方法族 ({len(families)} 个):")
    for mf in families:
        matched = mf.get("matched_repos", [])
        print(f" - {mf.get('family_name', '?')}: {len(matched)} 个仓库 {matched}")
    # Anti-confusion check: warn if the subfield shares no keyword with the title.
    _sanity_check_direction(title, subfield)
    print(f" 耗时: {time.time() - t0:.1f}s")
    # ================================================================
    # [4/6] Select repos + build the family-membership map
    # ================================================================
    print()
    print(f"[4/6] 正在筛选并获取仓库详情...")
    _prog(0.55, "正在筛选仓库并获取详情...")
    t0 = time.time()
    # Build full_name -> family_name from Agent 1's matched_repos.
    repo_family_map = {}
    for mf in families:
        family_name = mf.get("family_name", "")
        for repo_name in mf.get("matched_repos", []):
            repo_family_map[repo_name] = family_name
    # Prefer repos with a family assignment, then sort each group by stars.
    classified = [r for r in filtered_results if r["full_name"] in repo_family_map]
    unclassified = [r for r in filtered_results if r["full_name"] not in repo_family_map]
    classified.sort(key=lambda r: r.get("stars", 0), reverse=True)
    unclassified.sort(key=lambda r: r.get("stars", 0), reverse=True)
    if input_paper_repo:
        # The paper's own repo always goes first.  Bug fix: remove it from
        # the broad-search results so it cannot be evaluated twice.
        own_name = input_paper_repo["full_name"]
        others = [r for r in classified + unclassified if r["full_name"] != own_name]
        candidates = [input_paper_repo] + others[:top_n - 1]
        repo_family_map[own_name] = "本文代码"
    else:
        candidates = (classified + unclassified)[:top_n]
    print(f" 有方法族归属: {len(classified)} 个,未归类: {len(unclassified)} 个")
    if input_paper_repo:
        print(f" 含论文自身代码: {input_paper_repo['full_name']}")
    print(f" 最终选取 {len(candidates)} 个仓库:")
    for i, c in enumerate(candidates):
        family = repo_family_map.get(c["full_name"], "未归类")
        print(f" {i+1}. {c['full_name']} [{family}] Stars:{c.get('stars', 0)}")
    print(f" 耗时: {time.time() - t0:.1f}s")
    # ================================================================
    # [5/6] Fetch repo details + evaluate (Agent 2, parallel)
    # ================================================================
    print()
    print(f"[5/6] 正在获取仓库详情并评估(Agent 2,{len(candidates)} 个仓库并行)...")
    _prog(0.60, f"正在评估 {len(candidates)} 个开源仓库...")
    t0 = time.time()

    def _eval_single(repo, idx, total):
        """Evaluate one repository (runs inside a worker thread)."""
        full_name = repo["full_name"]
        try:
            owner, name = full_name.split("/", 1)
        except ValueError:
            # Malformed repo name: drop this candidate.
            return idx, None
        matched_family = repo_family_map.get(full_name, "")
        family_tag = f"[{matched_family}]" if matched_family else "[未归类]"
        print(f" ({idx+1}/{total}) {full_name} {family_tag}...")
        try:
            readme = fetch_readme(owner, name)
            deps = fetch_dependencies(owner, name)
            evaluation = evaluate_repo(repo, readme, deps, matched_family)
        except Exception as e:
            # A failed evaluation yields a zeroed placeholder, not a crash.
            print(f" [WARN] {full_name} 评估失败: {e}")
            evaluation = _make_error_evaluation(str(e))
        return idx, {
            "full_name": repo.get("full_name", ""),
            "html_url": repo.get("html_url", ""),
            "description": repo.get("description", ""),
            "stars": repo.get("stars", 0),
            "language": repo.get("language", ""),
            "updated_at": repo.get("updated_at", ""),
            "topics": repo.get("topics", []),
            "match_keyword": repo.get("match_keyword", ""),
            "method_family": matched_family,
            "evaluation": evaluation,
        }

    n = len(candidates)
    evaluated = [None] * n
    # Bug fix: max(1, ...) — ThreadPoolExecutor rejects max_workers=0, which
    # would occur if the candidate list were ever empty (e.g. top_n=0).
    with ThreadPoolExecutor(max_workers=max(1, min(n, 5))) as executor:
        futures = {
            executor.submit(_eval_single, repo, i, n): i
            for i, repo in enumerate(candidates)
        }
        for future in as_completed(futures):
            idx, result = future.result()
            if result is not None:
                evaluated[idx] = result
    evaluated = [r for r in evaluated if r is not None]
    # Sort by overall score, best first.
    evaluated.sort(key=lambda r: r["evaluation"].get("overall_score", 0), reverse=True)
    print(f" 耗时: {time.time() - t0:.1f}s")
    # ================================================================
    # Summary
    # ================================================================
    print()
    print("=" * 60)
    print("完成!")
    best = evaluated[0] if evaluated else None
    if best:
        print(f" 最高分: {best['full_name']} ({best['evaluation'].get('overall_score', 0)}/100)")
    # ================================================================
    # Supervision layer: audit agent output quality and source reliability
    # ================================================================
    _prog(0.90, "正在进行质量审核...")
    audit = supervise(title, abstract, direction, evaluated)
    print(f"\n 审核结果: {audit['summary']}")
    print(f" 综合质量评分: {audit['overall_score']}/100")
    if audit["actions"]:
        for action in audit["actions"]:
            print(f" → {action}")
    print(f" 总耗时: {time.time() - t_start:.1f}s")
    print("=" * 60)
    _prog(1.0, "分析完成,正在生成研报...")
    return {
        "paper": paper,
        "direction": direction,
        "repos": evaluated,
        "audit": audit,
    }
def _extract_broad_queries(title: str, abstract: str) -> list[str]:
    """Generate GitHub search queries from the title/abstract via the LLM.

    An LLM that has read the paper produces sharper domain queries than rule
    extraction; when the call fails or yields fewer than three queries we
    fall back to the rule-based extractor.
    """
    system_prompt = """你是学术论文搜索专家。根据论文标题和摘要,生成 6-8 个 GitHub 搜索查询,用于找到该研究方向的开源实现。
要求:
- 查询必须是英文,用空格分隔关键词
- 包含方法名/技术关键词 + 限定词(implementation / pytorch / code)
- 包含 1-2 个宽泛的领域查询(如 "anomaly detection pytorch library")
- 不要用引号或特殊符号
- 避免太泛的词(如单一个 "anomaly")
输出严格 JSON:
{"queries": ["query1", "query2", ...]}"""
    user_prompt = f"""论文标题: {title[:300]}
摘要: {abstract[:800]}
请为该论文生成 GitHub 搜索查询。"""
    try:
        response = call_llm_json(system_prompt, user_prompt, temperature=0.3, max_tokens=3000)
        parsed = parse_json_safe(response, "broad_queries")
        llm_queries = parsed.get("queries", [])
        # Accept the LLM output only when it is a list with >= 3 entries.
        if isinstance(llm_queries, list) and len(llm_queries) >= 3:
            return llm_queries[:10]
    except Exception as e:
        print(f" [WARN] LLM 关键词生成失败,降级为规则提取: {e}")
    return _extract_broad_queries_fallback(title, abstract)
| def _extract_broad_queries_fallback(title: str, abstract: str) -> list[str]: | |
| """从论文标题和摘要中提取宽泛搜索关键词(规则降级版)。 | |
| 策略:提取有意义的词组 + 搜索限定词,按多样性采样(不偏向标题前缀)。 | |
| 不依赖 LLM,纯规则提取。 | |
| """ | |
| abstract_first = abstract.split(".")[0] if abstract else "" | |
| text = f"{title} {abstract_first}".lower() | |
| text = re.sub(r'[^\w\s-]', ' ', text) | |
| words = text.split() | |
| stop_words = { | |
| 'a', 'an', 'the', 'of', 'for', 'in', 'on', 'to', 'and', 'or', 'is', 'are', | |
| 'we', 'our', 'that', 'this', 'with', 'from', 'by', 'as', 'at', 'be', 'it', | |
| 'its', 'not', 'can', 'has', 'have', 'been', 'was', 'were', 'will', 'would', | |
| 'could', 'should', 'may', 'do', 'does', 'did', 'so', 'if', 'no', 'new', | |
| 'based', 'using', 'which', 'into', 'such', 'than', 'then', 'these', 'those', | |
| 'propose', 'present', 'method', 'approach', 'framework', 'novel', 'model', | |
| 'learning', 'deep', 'via', 'et', 'al', 'all', 'we', 'you', 'they', | |
| } | |
| meaningful = [w for w in words if w not in stop_words and len(w) > 1] | |
| bigrams = [f"{meaningful[i]} {meaningful[i+1]}" for i in range(len(meaningful) - 1)] | |
| trigrams = [f"{meaningful[i]} {meaningful[i+1]} {meaningful[i+2]}" for i in range(len(meaningful) - 2)] | |
| def _sample_diverse(items: list[str], n: int) -> list[str]: | |
| if len(items) <= n: | |
| return items | |
| step = max(1, len(items) // n) | |
| return [items[i] for i in range(0, len(items), step)][:n] | |
| phrases = _sample_diverse(trigrams, 3) + _sample_diverse(bigrams, 4) | |
| seen = set() | |
| unique_phrases = [] | |
| for p in sorted(phrases, key=len, reverse=True): | |
| if p not in seen: | |
| seen.add(p) | |
| unique_phrases.append(p) | |
| qualifiers = [ | |
| "implementation pytorch", | |
| "official code", | |
| "pytorch github", | |
| ] | |
| queries = [] | |
| for phrase in unique_phrases[:5]: | |
| if len(phrase) > 5: | |
| queries.append(f"{phrase} {qualifiers[0]}") | |
| mid = len(meaningful) // 2 | |
| broad = " ".join(meaningful[max(0, mid-2):mid+2]) | |
| if len(broad) > 10: | |
| queries.append(f"{broad} {qualifiers[2]}") | |
| seen = set() | |
| unique = [] | |
| for q in queries: | |
| if q not in seen: | |
| seen.add(q) | |
| unique.append(q) | |
| return unique[:10] | |
| def _filter_repos(title: str, abstract: str, repos: list[dict]) -> list[dict]: | |
| """用 LLM 快速过滤不相关的仓库,只保留与论文方向相关的。 | |
| 在 GitHub 关键词搜索之后、Agent 1 分析之前运行。 | |
| 如果过滤后为空或 LLM 调用失败,返回原始列表作为降级方案。 | |
| """ | |
| if len(repos) <= 3: | |
| return repos | |
| # 超过 20 个仓库时,只取 top 20(按 stars),避免 JSON 输出过长 | |
| if len(repos) > 20: | |
| repos_for_filter = sorted(repos, key=lambda r: r.get("stars", 0), reverse=True)[:20] | |
| else: | |
| repos_for_filter = repos | |
| # 构造缩略的仓库清单供 LLM 判断 | |
| repo_list_parts = [] | |
| for i, r in enumerate(repos_for_filter): | |
| desc = (r.get("description") or "")[:120] | |
| topics = ", ".join(r.get("topics", [])[:8]) | |
| repo_list_parts.append( | |
| f"{i+1}. {r['full_name']} (Stars:{r.get('stars',0)})\n" | |
| f" Description: {desc}\n" | |
| f" Topics: {topics}" | |
| ) | |
| repo_list_text = "\n".join(repo_list_parts) | |
| system_prompt = """你是一个学术论文与开源代码匹配专家。任务是判断 GitHub 仓库是否与给定论文属于同一研究方向。 | |
| 排除标准(满足任一即排除): | |
| - 仓库解决的业务问题与论文完全不同(如论文做工业缺陷检测,仓库做医学图像/自动驾驶/人脸识别) | |
| - 仓库使用的核心技术方法与论文完全无关 | |
| - 仓库是通用教程/课程作业/面试题集合 | |
| 保留标准(满足任一即保留): | |
| - 仓库实现的方法与论文方法同属一个技术范式 | |
| - 仓库可用于该论文的对比实验(baseline/comparison) | |
| - 仓库是该领域的知名 benchmarking 库 | |
| 输出严格 JSON(不要 reasoning 字段以节省 token): | |
| {"relevant": ["owner/repo1", "owner/repo2"], "irrelevant": ["owner/repo3"]}""" | |
| user_prompt = f"""## 论文信息 | |
| 标题: {title[:200]} | |
| 摘要: {abstract[:500]} | |
| ## GitHub 仓库清单(共 {len(repos_for_filter)} 个) | |
| {repo_list_text} | |
| 请判断每个仓库是否与该论文方向相关。""" | |
| try: | |
| raw = call_llm_json(system_prompt, user_prompt, temperature=0.1, max_tokens=10000) | |
| data = parse_json_safe(raw, "filter_repos") | |
| except Exception as e: | |
| print(f" [WARN] 仓库过滤失败,保留全部: {e}") | |
| return repos | |
| relevant_set = set(data.get("relevant", [])) | |
| irrelevant = data.get("irrelevant", []) | |
| if irrelevant: | |
| print(f" [过滤] 剔除 {len(irrelevant)} 个不相关仓库: {irrelevant}") | |
| filtered = [r for r in repos if r["full_name"] in relevant_set] | |
| if len(filtered) < 2: | |
| print(f" [WARN] 过滤后仅剩 {len(filtered)} 个仓库,回退到原始结果") | |
| return repos | |
| print(f" 过滤后保留 {len(filtered)}/{len(repos)} 个仓库") | |
| return filtered | |
| def _sanity_check_direction(title: str, subfield: str) -> None: | |
| """防混淆检查:验证子领域分析是否与论文标题存在最低限度的语义关联。 | |
| 如果子领域关键词与标题完全无关,可能是 LLM 混淆了论文。 | |
| 仅打印警告,不阻断流程。 | |
| """ | |
| # 提取标题中的实义词(长度>=4,排除停用词) | |
| title_stops = { | |
| 'a', 'an', 'the', 'of', 'for', 'in', 'on', 'to', 'and', 'or', 'is', 'are', | |
| 'we', 'our', 'that', 'this', 'with', 'from', 'by', 'as', 'at', 'be', 'it', | |
| 'its', 'not', 'can', 'has', 'have', 'been', 'was', 'were', 'will', 'would', | |
| 'could', 'should', 'may', 'do', 'does', 'did', 'so', 'if', 'no', 'new', | |
| 'based', 'using', 'which', 'into', 'such', 'than', 'then', 'these', 'those', | |
| 'propose', 'present', 'method', 'approach', 'framework', 'novel', 'model', | |
| 'learning', 'deep', 'via', 'et', 'al', | |
| } | |
| title_words = set( | |
| w.lower() for w in re.sub(r'[^\w\s]', ' ', title).split() | |
| if len(w) >= 4 and w.lower() not in title_stops | |
| ) | |
| subfield_lower = subfield.lower() | |
| overlap = [w for w in title_words if w in subfield_lower] | |
| if not overlap and title_words: | |
| print(f" ⚠️ [防混淆警告] 子领域\"{subfield}\"与论文标题无关键词重叠") | |
| print(f" 标题关键词: {sorted(title_words)[:10]}") | |
| print(f" 这可能是 LLM 混淆了论文,请人工核实分析结果。") | |
| def _make_fallback_direction(title: str, abstract: str, repos: list[dict]) -> dict: | |
| """Agent 1 失败时的降级方向分析:基于搜索到的仓库名称推断子领域。""" | |
| # 从仓库 topic/description 提取高频词作为子领域 | |
| all_words = [] | |
| for r in repos[:10]: | |
| desc = (r.get("description") or "") | |
| topics = " ".join(r.get("topics", [])) | |
| all_words.extend((desc + " " + topics).lower().split()) | |
| stops = {'a', 'an', 'the', 'of', 'for', 'in', 'on', 'to', 'and', 'or', 'is', 'are', | |
| 'we', 'our', 'that', 'this', 'with', 'from', 'by', 'as', 'at', 'be', 'it'} | |
| meaningful = [w for w in all_words if w not in stops and len(w) >= 4] | |
| word_freq = {} | |
| for w in meaningful: | |
| word_freq[w] = word_freq.get(w, 0) + 1 | |
| top_words = sorted(word_freq, key=word_freq.get, reverse=True)[:6] | |
| return { | |
| "subfield": f"基于仓库数据推断: {', '.join(top_words[:3])}" if top_words else "未知领域", | |
| "subfield_trend": "(Agent 1 暂不可用,趋势分析跳过。后续版本将自动恢复。)", | |
| "method_families": [], | |
| "broad_queries": [], | |
| } | |
| def _make_error_evaluation(error_msg: str) -> dict: | |
| """构造一个表示评估失败的 evaluation dict""" | |
| return { | |
| "reproducibility_score": 0, | |
| "benchmark_fitness_score": 0, | |
| "overall_score": 0, | |
| "verdict": "error", | |
| "env_score": 0, | |
| "doc_score": 0, | |
| "code_score": 0, | |
| "community_score": 0, | |
| "dep_score": 0, | |
| "benchmark_score": 0, | |
| "reasoning": f"评估失败: {error_msg[:100]}", | |
| "risks": ["评估过程出错,请手动检查该仓库"], | |
| "benchmark_readiness": "not_ready", | |
| "suggested_use": "评估失败,请手动检查", | |
| } | |
# ============================================================
# Command-line entry point
# ============================================================
if __name__ == "__main__":
    # Work around emoji encoding issues in Windows terminals.
    from llm_utils import fix_windows_encoding
    fix_windows_encoding()
    if len(sys.argv) < 2:
        print("用法: python run.py <arxiv_url>")
        print("示例: python run.py https://arxiv.org/abs/1706.03762")
        print("示例: python run.py https://arxiv.org/abs/2011.08785")
        sys.exit(1)
    url = sys.argv[1]
    result = run(url)
    if result.get("error"):
        print(f"\n{'='*60}")
        print(f"运行未完全成功")
        print(f"{'='*60}")
        print(f" {result['error']}")
        # Partial results may still be useful; surface whatever was obtained.
        if result.get("paper"):
            print(f"\n 论文信息已获取: {result['paper'].get('title', '')[:80]}")
        if result.get("direction"):
            print(f" 方向解析已完成: {result['direction'].get('subfield', '')}")
    else:
        from app import format_report
        # Round-trip the report through the stdout encoding with errors
        # replaced so emoji cannot crash printing on narrow terminals.
        report = format_report(result)
        print(report.encode(sys.stdout.encoding or 'utf-8', errors='replace').decode(sys.stdout.encoding or 'utf-8', errors='replace'))