# run.py # ============================================================ # 类型:调度器(乙负责) # 功能:按"先搜再分析"流程串联 Workflow(甲)和 Agent(乙) # 用法:python run.py # 示例:python run.py https://arxiv.org/abs/2011.08785 # ============================================================ import sys import time import re import os import json import urllib.request import urllib.error import xml.etree.ElementTree as ET from concurrent.futures import ThreadPoolExecutor, as_completed # 自动加载 .env 文件 _ENV_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".env") if os.path.isfile(_ENV_PATH): with open(_ENV_PATH, "r", encoding="utf-8") as _f: for _line in _f: _line = _line.strip() if _line and not _line.startswith("#") and "=" in _line: _key, _val = _line.split("=", 1) if _key not in os.environ: os.environ[_key] = _val.strip().strip('"').strip("'") # ---- 导入甲的 Workflow 模块 ---- try: from paper_fetcher import fetch_paper_info except ImportError as e: raise ImportError( f"缺少模块: paper_fetcher.py(甲负责),请等待甲交付。\n" f" 原始错误: {e}" ) from e try: from repo_searcher import search_repos except ImportError as e: raise ImportError( f"缺少模块: repo_searcher.py(甲负责),请等待甲交付。\n" f" 原始错误: {e}" ) from e try: from repo_fetcher import fetch_readme, fetch_dependencies except ImportError as e: raise ImportError( f"缺少模块: repo_fetcher.py(甲负责),请等待甲交付。\n" f" 原始错误: {e}" ) from e from direction_analyzer import analyze_direction from repo_evaluator import evaluate_repo from supervisor import supervise from llm_utils import call_llm_json, parse_json_safe def _enrich_domain_context(title: str, abstract: str, categories: list[str]) -> str: """从 Semantic Scholar 搜索同领域相关论文,扩充领域上下文。 S2 覆盖全工科(含 IEEE/ASME/ASCE 等期刊),不限于 arXiv。 搜索策略:1) 论文标题关键词 + survey/review 2) 摘要关键词 失败时返回空字符串(静默降级)。 """ all_abstracts = [] # 从标题提取 2-3 个核心关键词作为搜索词 title_keywords = _extract_title_keywords(title) search_queries = [ f"{' '.join(title_keywords)} survey review", f"{' '.join(title_keywords)} state of the art", ] for query in search_queries[:2]: api_url = ( f"https://api.semanticscholar.org/graph/v1/paper/search" f"?query={urllib.request.quote(query)}&limit=3" f"&fields=title,abstract,year,citationCount" ) req = urllib.request.Request(api_url, headers={"User-Agent": "ResearchRadar/1.0"}) try: with urllib.request.urlopen(req, timeout=15) as resp: data = json.loads(resp.read().decode("utf-8")) for paper in data.get("data", []): t = (paper.get("title") or "").strip() s = (paper.get("abstract") or "").strip() if t and s: all_abstracts.append({ "title": t, "abstract": s[:800], "year": paper.get("year", ""), "citations": paper.get("citationCount", 0), }) except Exception: continue if not all_abstracts: return "" # 组装为上下文文本 lines = ["## 该领域相关论文(来自 Semantic Scholar,覆盖全工科领域)"] for i, a in enumerate(all_abstracts[:4]): year_str = f" ({a['year']})" if a.get("year") else "" cites_str = f" [{a.get('citations', 0)} 引用]" lines.append(f"{i+1}. **{a['title']}**{year_str}{cites_str}: {a['abstract']}") return "\n\n".join(lines) def _extract_title_keywords(title: str) -> list[str]: """从论文标题提取核心实义词作为搜索关键词。""" stops = { 'a', 'an', 'the', 'of', 'for', 'in', 'on', 'to', 'and', 'or', 'is', 'are', 'we', 'our', 'that', 'this', 'with', 'from', 'by', 'as', 'at', 'be', 'it', 'its', 'not', 'can', 'has', 'have', 'been', 'was', 'were', 'will', 'would', 'could', 'should', 'may', 'do', 'does', 'did', 'so', 'if', 'no', 'new', 'based', 'using', 'which', 'into', 'such', 'than', 'then', 'these', 'those', 'propose', 'present', 'method', 'approach', 'framework', 'novel', 'model', 'learning', 'deep', 'via', 'et', 'al', 'towards', 'toward', } words = [w.lower() for w in re.sub(r'[^\w\s-]', ' ', title).split() if len(w) >= 4 and w.lower() not in stops] # 去重保持顺序 seen = set() uniq = [] for w in words: if w not in seen: seen.add(w) uniq.append(w) return uniq[:5] def _mine_comparison_algorithms(arxiv_id: str, title: str, abstract: str) -> tuple[list[str], dict[str, int]]: """从 Semantic Scholar 引用网络中挖掘对比实验算法。 核心思路:论文的 references(引用的论文)中通常包含对比实验的 baseline 方法, 通过提取这些论文的标题作为额外搜索词,可以大幅提升 GitHub 搜索的覆盖度。 Returns: (extra_queries, citation_map): 额外搜索词列表, {paper_title: citation_count} """ s2_url = ( f"https://api.semanticscholar.org/graph/v1/paper/ArXiv:{arxiv_id}" f"?fields=references.title,references.citationCount,references.abstract" f"&limit=50" ) req = urllib.request.Request(s2_url, headers={"User-Agent": "ResearchRadar/1.0"}) try: with urllib.request.urlopen(req, timeout=20) as resp: data = json.loads(resp.read().decode("utf-8")) except Exception as e: print(f" [WARN] Semantic Scholar API 不可用,跳过对比算法挖掘: {e}") return [], {} refs = data.get("references", []) if not refs: return [], {} # 按引用量排序,取 top 15 refs.sort(key=lambda r: r.get("citationCount", 0), reverse=True) top_refs = refs[:15] # 构建标题列表供 LLM 识别方法论文 title_list = [] citation_map = {} for r in top_refs: t = (r.get("title") or "").strip() cc = r.get("citationCount", 0) if t and len(t) > 10: title_list.append(f"- [{cc} cites] {t}") citation_map[t] = cc if len(title_list) < 3: return [], citation_map # 用 LLM 从引用论文标题中识别哪些是方法/算法论文 system_prompt = """你是学术论文分析专家。从引用论文列表中识别哪些是提出了具体算法/方法的论文。 排除标准: - 数据集/benchmark 论文(如 ImageNet, CIFAR, MVTec AD) - 综述/survey 论文 - 纯理论/数学论文 - 框架/库论文(如 PyTorch, TensorFlow) 保留标准: - 提出了具体的模型/架构/算法名称 - 可以作为对比实验的 baseline 方法 输出严格 JSON: {"methods": ["方法名1", "方法名2"], "search_queries": ["method1 pytorch implementation", "method2 official code"]} 方法名尽量使用论文中常用的英文缩写或全称。""" user_prompt = f"""输入论文标题: {title[:200]} 引用论文列表: {chr(10).join(title_list)} 请识别哪些是方法/算法论文,生成对应的 GitHub 搜索词。""" try: raw = call_llm_json(system_prompt, user_prompt, temperature=0.2, max_tokens=3000) data = parse_json_safe(raw, "comparison_miner") methods = data.get("methods", []) queries = data.get("search_queries", []) print(f" 从引用网络识别到 {len(methods)} 个对比算法: {methods[:8]}") return queries[:8], citation_map except Exception as e: print(f" [WARN] 对比算法识别失败,降级: {e}") # 降级:直接用引用论文标题作为搜索词 fallback_queries = [] for t in list(citation_map.keys())[:5]: short = re.sub(r'[^\w\s-]', '', t).strip()[:80] if len(short) > 15: fallback_queries.append(f"{short} pytorch implementation") return fallback_queries, citation_map def _search_s2_papers(query: str, limit: int = 5) -> list[dict]: """在 Semantic Scholar 中搜索论文,覆盖全工科领域(IEEE/ASME/ASCE 等期刊)。 带 30 分钟 TTL 缓存,降低 S2 API 压力。 Returns: list[dict]: [{"title": ..., "abstract": ..., "year": ..., "citations": ..., "url": ...}, ...] """ from cache import s2_cache cache_key = f"s2_search:{query}:{limit}" cached = s2_cache.get(cache_key) if cached is not None: return cached api_url = ( f"https://api.semanticscholar.org/graph/v1/paper/search" f"?query={urllib.request.quote(query)}&limit={limit}" f"&fields=title,abstract,year,citationCount,url,externalIds" ) req = urllib.request.Request(api_url, headers={"User-Agent": "ResearchRadar/1.0"}) try: with urllib.request.urlopen(req, timeout=15) as resp: data = json.loads(resp.read().decode("utf-8")) except Exception: return [] papers = [] for p in data.get("data", []): title = (p.get("title") or "").strip() if not title: continue papers.append({ "title": title, "abstract": (p.get("abstract") or "").strip()[:1000], "year": p.get("year"), "citations": p.get("citationCount", 0), "url": p.get("url", ""), "arxiv_id": (p.get("externalIds") or {}).get("ArXiv", ""), "doi": (p.get("externalIds") or {}).get("DOI", ""), }) s2_cache.set(cache_key, papers) return papers def _find_input_paper_repo(title: str, arxiv_id: str, authors: list[str]) -> dict | None: """搜索输入论文自身的官方代码仓库。 搜索策略(按优先级): 1. GitHub 搜索 arxiv ID 2. GitHub 搜索论文标题(精确匹配) 3. GitHub 搜索一作姓名 + 论文核心关键词 Returns: 找到则返回符合格式的候选仓库 dict,找不到返回 None """ import requests token = os.getenv("GITHUB_TOKEN", "") headers = {"Accept": "application/vnd.github.v3+json"} if token: headers["Authorization"] = f"Bearer {token}" # 提取核心关键词(取标题前 6 个实义词) stops = { 'a', 'an', 'the', 'of', 'for', 'in', 'on', 'to', 'and', 'or', 'is', 'are', 'we', 'our', 'that', 'this', 'with', 'from', 'by', 'as', 'at', 'be', 'it', 'its', 'not', 'can', 'has', 'have', 'been', 'was', 'were', 'will', 'would', 'could', 'should', 'may', 'do', 'does', 'did', 'so', 'if', 'no', 'new', 'based', 'using', 'which', 'into', 'such', 'than', 'then', 'these', 'those', 'propose', 'present', 'method', 'approach', 'framework', 'novel', 'model', 'learning', 'deep', 'via', 'et', 'al', 'towards', 'toward', } title_words = [w for w in re.sub(r'[^\w\s-]', ' ', title).split() if w.lower() not in stops and len(w) >= 3] core_keywords = " ".join(title_words[:6]) search_queries = [ f'"{arxiv_id}" in:name,description,readme', f'"{core_keywords}" in:name,description stars:>=3', ] # 用一作姓名 + 关键词搜索 first_author = authors[0] if authors else "" if first_author: last_name = first_author.split()[-1] if first_author.split() else "" if len(last_name) >= 3: search_queries.append(f"{last_name} {core_keywords[:80]} in:name,description") for query in search_queries: url = "https://api.github.com/search/repositories" params = {"q": query, "sort": "stars", "order": "desc", "per_page": 5} try: resp = requests.get(url, headers=headers, params=params, timeout=15) if resp.status_code in (403, 429): continue resp.raise_for_status() data = resp.json() except Exception: continue # 筛选:标题必须包含核心关键词中至少 2 个 for item in data.get("items", []): repo_title = (item.get("description") or "").lower() repo_name = item.get("full_name", "").lower() combined = f"{repo_name} {repo_title}" matches = sum(1 for kw in title_words[:6] if kw.lower() in combined) if matches >= 2: return { "full_name": item["full_name"], "html_url": item["html_url"], "description": item.get("description", ""), "stars": item.get("stargazers_count", 0), "language": item.get("language", ""), "updated_at": item.get("updated_at", ""), "topics": item.get("topics", []), "match_keyword": f"本文代码: {query[:60]}", } return None def run(arxiv_url: str, top_n: int = 5, progress=None) -> dict: """主入口:先搜索 GitHub → 再基于实际仓库归纳方法族 → 最后评估。 Args: arxiv_url: arxiv 论文 URL top_n: 最终评估的仓库数量 progress: gr.Progress 实例或 callable(fraction, desc),用于前端进度条 Returns: dict: {"paper": {...}, "direction": {...}, "repos": [{...}], "error": "..."} """ t_start = time.time() def _prog(frac: float, desc: str): if progress: progress(frac, desc=desc) # ================================================================ # [1/6] 论文信息获取(甲 Workflow) # ================================================================ print("=" * 60) print("[1/6] 正在获取论文信息...") _prog(0.05, "正在从 arxiv 获取论文信息...") t0 = time.time() try: paper = fetch_paper_info(arxiv_url) except Exception as e: return {"error": f"获取论文信息失败: {e}"} title = paper.get("title", "") abstract = paper.get("abstract", "") print(f" 标题: {title[:100]}...") print(f" 作者: {', '.join(paper.get('authors', [])[:3])}") print(f" 分类: {', '.join(paper.get('categories', []))}") print(f" 耗时: {time.time() - t0:.1f}s") # ================================================================ # [1.2] 搜索输入论文的官方代码 # ================================================================ print() print("[1.2] 正在搜索输入论文的官方代码...") _prog(0.10, "正在搜索论文自身的官方代码...") t0 = time.time() input_paper_repo = _find_input_paper_repo(title, arxiv_id, paper.get("authors", [])) if input_paper_repo: print(f" 找到论文官方代码: {input_paper_repo['full_name']} (Stars: {input_paper_repo.get('stars', 0)})") else: print(f" 未找到论文官方代码") print(f" 耗时: {time.time() - t0:.1f}s") # ================================================================ # [1.5] 对比实验算法挖掘(从 S2 引用网络提取算法名) # ================================================================ print() print("[1.5] 正在从论文引用网络挖掘对比实验算法...") _prog(0.12, "正在从引用网络挖掘对比算法...") t0 = time.time() arxiv_id = paper.get("arxiv_id", "") comparison_queries, citation_map = _mine_comparison_algorithms(arxiv_id, title, abstract) if comparison_queries: print(f" 生成 {len(comparison_queries)} 个对比算法搜索词:") for q in comparison_queries[:5]: print(f" - {q}") print(f" 耗时: {time.time() - t0:.1f}s") # ================================================================ # [1.8] S2 论文搜索(覆盖全工科领域,含 IEEE/ASME 等期刊) # ================================================================ print() print("[1.8] 正在 Semantic Scholar 中搜索同领域论文(覆盖全工科)...") _prog(0.18, "正在 Semantic Scholar 搜索同领域论文...") t0 = time.time() title_kws = _extract_title_keywords(title) s2_papers = _search_s2_papers(" ".join(title_kws), limit=8) # 从 S2 论文标题中提取额外搜索词 s2_extra_queries = [] if s2_papers: print(f" 找到 {len(s2_papers)} 篇相关论文:") for p in s2_papers[:5]: cits = p.get("citations", 0) print(f" - [{cits} cites] {p['title'][:80]}") # 用高引论文标题作为搜索词 if cits >= 50: s2_quoted = p['title'].strip()[:80] # 去掉特殊字符 s2_clean = re.sub(r'[^\w\s-]', '', s2_quoted).strip() if len(s2_clean) > 15: s2_extra_queries.append(f"{s2_clean} github") print(f" 额外生成 {len(s2_extra_queries)} 个 S2 来源搜索词") print(f" 耗时: {time.time() - t0:.1f}s") # ================================================================ # [2/6] 宽泛 GitHub 搜索(甲 Workflow) # ================================================================ print() print("[2/6] 正在宽泛搜索 GitHub(先搜仓库,再让 Agent 分析)...") _prog(0.22, "正在 GitHub 搜索开源仓库...") t0 = time.time() broad_queries = _extract_broad_queries(title, abstract) # 合并对比算法搜索词,去重 has_github_token = bool(os.getenv("GITHUB_TOKEN", "")) max_queries = 15 if has_github_token else 8 all_queries = list(dict.fromkeys(s2_extra_queries + comparison_queries + broad_queries))[:max_queries] if not has_github_token: print(f" ⚠️ 未设置 GITHUB_TOKEN,搜索词限制为 {max_queries} 个以避免限速") print(f" LLM: {len(broad_queries)} + S2引用: {len(comparison_queries)} + S2论文: {len(s2_extra_queries)} = {len(all_queries)} 个总搜索词:") for q in all_queries: print(f" - {q}") try: broad_results = search_repos(all_queries, max_per_keyword=5) except Exception as e: return { "paper": paper, "error": f"GitHub 搜索失败: {e}", } print(f" 去重后获得 {len(broad_results)} 个候选仓库") print(f" 耗时: {time.time() - t0:.1f}s") if not broad_results: return { "paper": paper, "direction": {}, "repos": [], "error": "未找到相关开源仓库。该方向可能太新或太冷门,暂无高质量开源实现。", } # ================================================================ # [2.5] LLM 过滤不相关仓库 # ================================================================ print() print("[2.5] 正在用 LLM 过滤不相关仓库...") _prog(0.32, "正在用 LLM 过滤不相关仓库...") t0 = time.time() filtered_results = _filter_repos(title, abstract, broad_results) print(f" 耗时: {time.time() - t0:.1f}s") # ================================================================ # [2.6] 领域上下文扩充(S2 搜索综述论文补充领域知识) # ================================================================ print() print("[2.6] 正在扩充领域上下文(搜索相关综述论文)...") _prog(0.38, "正在扩充领域上下文...") t0 = time.time() domain_context = _enrich_domain_context(title, abstract, paper.get("categories", [])) if domain_context: print(f" 获取到 {domain_context.count('**') // 2} 篇相关综述的摘要") else: print(f" 未找到相关综述(将仅基于论文摘要和仓库数据进行分析)") print(f" 耗时: {time.time() - t0:.1f}s") # ================================================================ # [3/6] 基于仓库数据归纳方法族(Agent 1) # ================================================================ print() print(f"[3/6] 正在分析 {len(filtered_results)} 个仓库,归纳方法族(Agent 1)...") _prog(0.42, "正在用 LLM 归纳方法族...") t0 = time.time() try: direction = analyze_direction(title, abstract, filtered_results, domain_context) except Exception as e: print(f" [WARN] Agent 1 方向解析失败,降级为基本分析: {e}") direction = _make_fallback_direction(title, abstract, filtered_results) subfield = direction.get("subfield", "未知") families = direction.get("method_families", []) print(f" 子领域: {subfield}") print(f" 趋势: {direction.get('subfield_trend', '')[:80]}...") print(f" 方法族 ({len(families)} 个):") for mf in families: matched = mf.get("matched_repos", []) print(f" - {mf.get('family_name', '?')}: {len(matched)} 个仓库 {matched}") # 防混淆校验:检查子领域是否与论文标题存在语义关联 _sanity_check_direction(title, subfield) print(f" 耗时: {time.time() - t0:.1f}s") # ================================================================ # [4/6] 筛选仓库 + 构建方法族归属映射 # ================================================================ print() print(f"[4/6] 正在筛选并获取仓库详情...") _prog(0.55, "正在筛选仓库并获取详情...") t0 = time.time() # 从 Agent 1 的 matched_repos 中建立 full_name → family_name 映射 repo_family_map = {} for mf in families: family_name = mf.get("family_name", "") for repo_name in mf.get("matched_repos", []): repo_family_map[repo_name] = family_name # 从 filtered_results 中筛选:优先选有方法族归属的,再按 stars 排序 classified = [r for r in filtered_results if r["full_name"] in repo_family_map] unclassified = [r for r in filtered_results if r["full_name"] not in repo_family_map] classified.sort(key=lambda r: r.get("stars", 0), reverse=True) unclassified.sort(key=lambda r: r.get("stars", 0), reverse=True) # 论文自身代码优先排在最前面 if input_paper_repo: candidates = [input_paper_repo] + (classified + unclassified)[:top_n - 1] repo_family_map[input_paper_repo["full_name"]] = "本文代码" else: candidates = (classified + unclassified)[:top_n] print(f" 有方法族归属: {len(classified)} 个,未归类: {len(unclassified)} 个") if input_paper_repo: print(f" 含论文自身代码: {input_paper_repo['full_name']}") print(f" 最终选取 {len(candidates)} 个仓库:") for i, c in enumerate(candidates): family = repo_family_map.get(c["full_name"], "未归类") print(f" {i+1}. {c['full_name']} [{family}] Stars:{c.get('stars', 0)}") print(f" 耗时: {time.time() - t0:.1f}s") # ================================================================ # [5/6] 仓库详情获取 + 评估(Agent 2,并行) # ================================================================ print() print(f"[5/6] 正在获取仓库详情并评估(Agent 2,{len(candidates)} 个仓库并行)...") _prog(0.60, f"正在评估 {len(candidates)} 个开源仓库...") t0 = time.time() def _eval_single(repo, idx, total): """单个仓库评估(在线程中执行)""" full_name = repo["full_name"] try: owner, name = full_name.split("/", 1) except ValueError: return idx, None matched_family = repo_family_map.get(full_name, "") family_tag = f"[{matched_family}]" if matched_family else "[未归类]" print(f" ({idx+1}/{total}) {full_name} {family_tag}...") try: readme = fetch_readme(owner, name) deps = fetch_dependencies(owner, name) evaluation = evaluate_repo(repo, readme, deps, matched_family) except Exception as e: print(f" [WARN] {full_name} 评估失败: {e}") evaluation = _make_error_evaluation(str(e)) return idx, { "full_name": repo.get("full_name", ""), "html_url": repo.get("html_url", ""), "description": repo.get("description", ""), "stars": repo.get("stars", 0), "language": repo.get("language", ""), "updated_at": repo.get("updated_at", ""), "topics": repo.get("topics", []), "match_keyword": repo.get("match_keyword", ""), "method_family": matched_family, "evaluation": evaluation, } n = len(candidates) evaluated = [None] * n with ThreadPoolExecutor(max_workers=min(n, 5)) as executor: futures = { executor.submit(_eval_single, repo, i, n): i for i, repo in enumerate(candidates) } for future in as_completed(futures): idx, result = future.result() if result is not None: evaluated[idx] = result evaluated = [r for r in evaluated if r is not None] # 按综合评分降序排列 evaluated.sort(key=lambda r: r["evaluation"].get("overall_score", 0), reverse=True) print(f" 耗时: {time.time() - t0:.1f}s") # ================================================================ # 汇总 # ================================================================ print() print("=" * 60) print("完成!") best = evaluated[0] if evaluated else None if best: print(f" 最高分: {best['full_name']} ({best['evaluation'].get('overall_score', 0)}/100)") # ================================================================ # 审核层:检查 Agent 输出质量和来源可靠性 # ================================================================ _prog(0.90, "正在进行质量审核...") audit = supervise(title, abstract, direction, evaluated) print(f"\n 审核结果: {audit['summary']}") print(f" 综合质量评分: {audit['overall_score']}/100") if audit["actions"]: for action in audit["actions"]: print(f" → {action}") print(f" 总耗时: {time.time() - t_start:.1f}s") print("=" * 60) _prog(1.0, "分析完成,正在生成研报...") return { "paper": paper, "direction": direction, "repos": evaluated, "audit": audit, } def _extract_broad_queries(title: str, abstract: str) -> list[str]: """用 LLM 从论文标题和摘要中生成 GitHub 搜索关键词。 相比规则提取,LLM 理解论文后能生成更精准的领域搜索词。 如果 LLM 调用失败,降级为规则提取。 """ system_prompt = """你是学术论文搜索专家。根据论文标题和摘要,生成 6-8 个 GitHub 搜索查询,用于找到该研究方向的开源实现。 要求: - 查询必须是英文,用空格分隔关键词 - 包含方法名/技术关键词 + 限定词(implementation / pytorch / code) - 包含 1-2 个宽泛的领域查询(如 "anomaly detection pytorch library") - 不要用引号或特殊符号 - 避免太泛的词(如单一个 "anomaly") 输出严格 JSON: {"queries": ["query1", "query2", ...]}""" user_prompt = f"""论文标题: {title[:300]} 摘要: {abstract[:800]} 请为该论文生成 GitHub 搜索查询。""" try: raw = call_llm_json(system_prompt, user_prompt, temperature=0.3, max_tokens=3000) data = parse_json_safe(raw, "broad_queries") queries = data.get("queries", []) if isinstance(queries, list) and len(queries) >= 3: return queries[:10] except Exception as e: print(f" [WARN] LLM 关键词生成失败,降级为规则提取: {e}") return _extract_broad_queries_fallback(title, abstract) def _extract_broad_queries_fallback(title: str, abstract: str) -> list[str]: """从论文标题和摘要中提取宽泛搜索关键词(规则降级版)。 策略:提取有意义的词组 + 搜索限定词,按多样性采样(不偏向标题前缀)。 不依赖 LLM,纯规则提取。 """ abstract_first = abstract.split(".")[0] if abstract else "" text = f"{title} {abstract_first}".lower() text = re.sub(r'[^\w\s-]', ' ', text) words = text.split() stop_words = { 'a', 'an', 'the', 'of', 'for', 'in', 'on', 'to', 'and', 'or', 'is', 'are', 'we', 'our', 'that', 'this', 'with', 'from', 'by', 'as', 'at', 'be', 'it', 'its', 'not', 'can', 'has', 'have', 'been', 'was', 'were', 'will', 'would', 'could', 'should', 'may', 'do', 'does', 'did', 'so', 'if', 'no', 'new', 'based', 'using', 'which', 'into', 'such', 'than', 'then', 'these', 'those', 'propose', 'present', 'method', 'approach', 'framework', 'novel', 'model', 'learning', 'deep', 'via', 'et', 'al', 'all', 'we', 'you', 'they', } meaningful = [w for w in words if w not in stop_words and len(w) > 1] bigrams = [f"{meaningful[i]} {meaningful[i+1]}" for i in range(len(meaningful) - 1)] trigrams = [f"{meaningful[i]} {meaningful[i+1]} {meaningful[i+2]}" for i in range(len(meaningful) - 2)] def _sample_diverse(items: list[str], n: int) -> list[str]: if len(items) <= n: return items step = max(1, len(items) // n) return [items[i] for i in range(0, len(items), step)][:n] phrases = _sample_diverse(trigrams, 3) + _sample_diverse(bigrams, 4) seen = set() unique_phrases = [] for p in sorted(phrases, key=len, reverse=True): if p not in seen: seen.add(p) unique_phrases.append(p) qualifiers = [ "implementation pytorch", "official code", "pytorch github", ] queries = [] for phrase in unique_phrases[:5]: if len(phrase) > 5: queries.append(f"{phrase} {qualifiers[0]}") mid = len(meaningful) // 2 broad = " ".join(meaningful[max(0, mid-2):mid+2]) if len(broad) > 10: queries.append(f"{broad} {qualifiers[2]}") seen = set() unique = [] for q in queries: if q not in seen: seen.add(q) unique.append(q) return unique[:10] def _filter_repos(title: str, abstract: str, repos: list[dict]) -> list[dict]: """用 LLM 快速过滤不相关的仓库,只保留与论文方向相关的。 在 GitHub 关键词搜索之后、Agent 1 分析之前运行。 如果过滤后为空或 LLM 调用失败,返回原始列表作为降级方案。 """ if len(repos) <= 3: return repos # 超过 20 个仓库时,只取 top 20(按 stars),避免 JSON 输出过长 if len(repos) > 20: repos_for_filter = sorted(repos, key=lambda r: r.get("stars", 0), reverse=True)[:20] else: repos_for_filter = repos # 构造缩略的仓库清单供 LLM 判断 repo_list_parts = [] for i, r in enumerate(repos_for_filter): desc = (r.get("description") or "")[:120] topics = ", ".join(r.get("topics", [])[:8]) repo_list_parts.append( f"{i+1}. {r['full_name']} (Stars:{r.get('stars',0)})\n" f" Description: {desc}\n" f" Topics: {topics}" ) repo_list_text = "\n".join(repo_list_parts) system_prompt = """你是一个学术论文与开源代码匹配专家。任务是判断 GitHub 仓库是否与给定论文属于同一研究方向。 排除标准(满足任一即排除): - 仓库解决的业务问题与论文完全不同(如论文做工业缺陷检测,仓库做医学图像/自动驾驶/人脸识别) - 仓库使用的核心技术方法与论文完全无关 - 仓库是通用教程/课程作业/面试题集合 保留标准(满足任一即保留): - 仓库实现的方法与论文方法同属一个技术范式 - 仓库可用于该论文的对比实验(baseline/comparison) - 仓库是该领域的知名 benchmarking 库 输出严格 JSON(不要 reasoning 字段以节省 token): {"relevant": ["owner/repo1", "owner/repo2"], "irrelevant": ["owner/repo3"]}""" user_prompt = f"""## 论文信息 标题: {title[:200]} 摘要: {abstract[:500]} ## GitHub 仓库清单(共 {len(repos_for_filter)} 个) {repo_list_text} 请判断每个仓库是否与该论文方向相关。""" try: raw = call_llm_json(system_prompt, user_prompt, temperature=0.1, max_tokens=10000) data = parse_json_safe(raw, "filter_repos") except Exception as e: print(f" [WARN] 仓库过滤失败,保留全部: {e}") return repos relevant_set = set(data.get("relevant", [])) irrelevant = data.get("irrelevant", []) if irrelevant: print(f" [过滤] 剔除 {len(irrelevant)} 个不相关仓库: {irrelevant}") filtered = [r for r in repos if r["full_name"] in relevant_set] if len(filtered) < 2: print(f" [WARN] 过滤后仅剩 {len(filtered)} 个仓库,回退到原始结果") return repos print(f" 过滤后保留 {len(filtered)}/{len(repos)} 个仓库") return filtered def _sanity_check_direction(title: str, subfield: str) -> None: """防混淆检查:验证子领域分析是否与论文标题存在最低限度的语义关联。 如果子领域关键词与标题完全无关,可能是 LLM 混淆了论文。 仅打印警告,不阻断流程。 """ # 提取标题中的实义词(长度>=4,排除停用词) title_stops = { 'a', 'an', 'the', 'of', 'for', 'in', 'on', 'to', 'and', 'or', 'is', 'are', 'we', 'our', 'that', 'this', 'with', 'from', 'by', 'as', 'at', 'be', 'it', 'its', 'not', 'can', 'has', 'have', 'been', 'was', 'were', 'will', 'would', 'could', 'should', 'may', 'do', 'does', 'did', 'so', 'if', 'no', 'new', 'based', 'using', 'which', 'into', 'such', 'than', 'then', 'these', 'those', 'propose', 'present', 'method', 'approach', 'framework', 'novel', 'model', 'learning', 'deep', 'via', 'et', 'al', } title_words = set( w.lower() for w in re.sub(r'[^\w\s]', ' ', title).split() if len(w) >= 4 and w.lower() not in title_stops ) subfield_lower = subfield.lower() overlap = [w for w in title_words if w in subfield_lower] if not overlap and title_words: print(f" ⚠️ [防混淆警告] 子领域\"{subfield}\"与论文标题无关键词重叠") print(f" 标题关键词: {sorted(title_words)[:10]}") print(f" 这可能是 LLM 混淆了论文,请人工核实分析结果。") def _make_fallback_direction(title: str, abstract: str, repos: list[dict]) -> dict: """Agent 1 失败时的降级方向分析:基于搜索到的仓库名称推断子领域。""" # 从仓库 topic/description 提取高频词作为子领域 all_words = [] for r in repos[:10]: desc = (r.get("description") or "") topics = " ".join(r.get("topics", [])) all_words.extend((desc + " " + topics).lower().split()) stops = {'a', 'an', 'the', 'of', 'for', 'in', 'on', 'to', 'and', 'or', 'is', 'are', 'we', 'our', 'that', 'this', 'with', 'from', 'by', 'as', 'at', 'be', 'it'} meaningful = [w for w in all_words if w not in stops and len(w) >= 4] word_freq = {} for w in meaningful: word_freq[w] = word_freq.get(w, 0) + 1 top_words = sorted(word_freq, key=word_freq.get, reverse=True)[:6] return { "subfield": f"基于仓库数据推断: {', '.join(top_words[:3])}" if top_words else "未知领域", "subfield_trend": "(Agent 1 暂不可用,趋势分析跳过。后续版本将自动恢复。)", "method_families": [], "broad_queries": [], } def _make_error_evaluation(error_msg: str) -> dict: """构造一个表示评估失败的 evaluation dict""" return { "reproducibility_score": 0, "benchmark_fitness_score": 0, "overall_score": 0, "verdict": "error", "env_score": 0, "doc_score": 0, "code_score": 0, "community_score": 0, "dep_score": 0, "benchmark_score": 0, "reasoning": f"评估失败: {error_msg[:100]}", "risks": ["评估过程出错,请手动检查该仓库"], "benchmark_readiness": "not_ready", "suggested_use": "评估失败,请手动检查", } # ============================================================ # 命令行入口 # ============================================================ if __name__ == "__main__": # 修复 Windows 终端 emoji 编码问题 from llm_utils import fix_windows_encoding fix_windows_encoding() if len(sys.argv) < 2: print("用法: python run.py ") print("示例: python run.py https://arxiv.org/abs/1706.03762") print("示例: python run.py https://arxiv.org/abs/2011.08785") sys.exit(1) url = sys.argv[1] result = run(url) if result.get("error"): print(f"\n{'='*60}") print(f"运行未完全成功") print(f"{'='*60}") print(f" {result['error']}") if result.get("paper"): print(f"\n 论文信息已获取: {result['paper'].get('title', '')[:80]}") if result.get("direction"): print(f" 方向解析已完成: {result['direction'].get('subfield', '')}") else: from app import format_report # 用 utf-8 编码输出避免 emoji 乱码 report = format_report(result) print(report.encode(sys.stdout.encoding or 'utf-8', errors='replace').decode(sys.stdout.encoding or 'utf-8', errors='replace'))