ResearchRadar / run.py
ZZZyx3587's picture
Upload run.py with huggingface_hub
908d065 verified
# run.py
# ============================================================
# 类型:调度器(乙负责)
# 功能:按"先搜再分析"流程串联 Workflow(甲)和 Agent(乙)
# 用法:python run.py <arxiv_url>
# 示例:python run.py https://arxiv.org/abs/2011.08785
# ============================================================
import sys
import time
import re
import os
import json
import urllib.request
import urllib.error
import xml.etree.ElementTree as ET
from concurrent.futures import ThreadPoolExecutor, as_completed
# 自动加载 .env 文件
_ENV_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".env")
if os.path.isfile(_ENV_PATH):
with open(_ENV_PATH, "r", encoding="utf-8") as _f:
for _line in _f:
_line = _line.strip()
if _line and not _line.startswith("#") and "=" in _line:
_key, _val = _line.split("=", 1)
if _key not in os.environ:
os.environ[_key] = _val.strip().strip('"').strip("'")
# ---- 导入甲的 Workflow 模块 ----
try:
from paper_fetcher import fetch_paper_info
except ImportError as e:
raise ImportError(
f"缺少模块: paper_fetcher.py(甲负责),请等待甲交付。\n"
f" 原始错误: {e}"
) from e
try:
from repo_searcher import search_repos
except ImportError as e:
raise ImportError(
f"缺少模块: repo_searcher.py(甲负责),请等待甲交付。\n"
f" 原始错误: {e}"
) from e
try:
from repo_fetcher import fetch_readme, fetch_dependencies
except ImportError as e:
raise ImportError(
f"缺少模块: repo_fetcher.py(甲负责),请等待甲交付。\n"
f" 原始错误: {e}"
) from e
from direction_analyzer import analyze_direction
from repo_evaluator import evaluate_repo
from supervisor import supervise
from llm_utils import call_llm_json, parse_json_safe
def _enrich_domain_context(title: str, abstract: str, categories: list[str]) -> str:
"""从 Semantic Scholar 搜索同领域相关论文,扩充领域上下文。
S2 覆盖全工科(含 IEEE/ASME/ASCE 等期刊),不限于 arXiv。
搜索策略:1) 论文标题关键词 + survey/review 2) 摘要关键词
失败时返回空字符串(静默降级)。
"""
all_abstracts = []
# 从标题提取 2-3 个核心关键词作为搜索词
title_keywords = _extract_title_keywords(title)
search_queries = [
f"{' '.join(title_keywords)} survey review",
f"{' '.join(title_keywords)} state of the art",
]
for query in search_queries[:2]:
api_url = (
f"https://api.semanticscholar.org/graph/v1/paper/search"
f"?query={urllib.request.quote(query)}&limit=3"
f"&fields=title,abstract,year,citationCount"
)
req = urllib.request.Request(api_url, headers={"User-Agent": "ResearchRadar/1.0"})
try:
with urllib.request.urlopen(req, timeout=15) as resp:
data = json.loads(resp.read().decode("utf-8"))
for paper in data.get("data", []):
t = (paper.get("title") or "").strip()
s = (paper.get("abstract") or "").strip()
if t and s:
all_abstracts.append({
"title": t,
"abstract": s[:800],
"year": paper.get("year", ""),
"citations": paper.get("citationCount", 0),
})
except Exception:
continue
if not all_abstracts:
return ""
# 组装为上下文文本
lines = ["## 该领域相关论文(来自 Semantic Scholar,覆盖全工科领域)"]
for i, a in enumerate(all_abstracts[:4]):
year_str = f" ({a['year']})" if a.get("year") else ""
cites_str = f" [{a.get('citations', 0)} 引用]"
lines.append(f"{i+1}. **{a['title']}**{year_str}{cites_str}: {a['abstract']}")
return "\n\n".join(lines)
def _extract_title_keywords(title: str) -> list[str]:
"""从论文标题提取核心实义词作为搜索关键词。"""
stops = {
'a', 'an', 'the', 'of', 'for', 'in', 'on', 'to', 'and', 'or', 'is', 'are',
'we', 'our', 'that', 'this', 'with', 'from', 'by', 'as', 'at', 'be', 'it',
'its', 'not', 'can', 'has', 'have', 'been', 'was', 'were', 'will', 'would',
'could', 'should', 'may', 'do', 'does', 'did', 'so', 'if', 'no', 'new',
'based', 'using', 'which', 'into', 'such', 'than', 'then', 'these', 'those',
'propose', 'present', 'method', 'approach', 'framework', 'novel', 'model',
'learning', 'deep', 'via', 'et', 'al', 'towards', 'toward',
}
words = [w.lower() for w in re.sub(r'[^\w\s-]', ' ', title).split() if len(w) >= 4 and w.lower() not in stops]
# 去重保持顺序
seen = set()
uniq = []
for w in words:
if w not in seen:
seen.add(w)
uniq.append(w)
return uniq[:5]
def _mine_comparison_algorithms(arxiv_id: str, title: str, abstract: str) -> tuple[list[str], dict[str, int]]:
"""从 Semantic Scholar 引用网络中挖掘对比实验算法。
核心思路:论文的 references(引用的论文)中通常包含对比实验的 baseline 方法,
通过提取这些论文的标题作为额外搜索词,可以大幅提升 GitHub 搜索的覆盖度。
Returns:
(extra_queries, citation_map): 额外搜索词列表, {paper_title: citation_count}
"""
s2_url = (
f"https://api.semanticscholar.org/graph/v1/paper/ArXiv:{arxiv_id}"
f"?fields=references.title,references.citationCount,references.abstract"
f"&limit=50"
)
req = urllib.request.Request(s2_url, headers={"User-Agent": "ResearchRadar/1.0"})
try:
with urllib.request.urlopen(req, timeout=20) as resp:
data = json.loads(resp.read().decode("utf-8"))
except Exception as e:
print(f" [WARN] Semantic Scholar API 不可用,跳过对比算法挖掘: {e}")
return [], {}
refs = data.get("references", [])
if not refs:
return [], {}
# 按引用量排序,取 top 15
refs.sort(key=lambda r: r.get("citationCount", 0), reverse=True)
top_refs = refs[:15]
# 构建标题列表供 LLM 识别方法论文
title_list = []
citation_map = {}
for r in top_refs:
t = (r.get("title") or "").strip()
cc = r.get("citationCount", 0)
if t and len(t) > 10:
title_list.append(f"- [{cc} cites] {t}")
citation_map[t] = cc
if len(title_list) < 3:
return [], citation_map
# 用 LLM 从引用论文标题中识别哪些是方法/算法论文
system_prompt = """你是学术论文分析专家。从引用论文列表中识别哪些是提出了具体算法/方法的论文。
排除标准:
- 数据集/benchmark 论文(如 ImageNet, CIFAR, MVTec AD)
- 综述/survey 论文
- 纯理论/数学论文
- 框架/库论文(如 PyTorch, TensorFlow)
保留标准:
- 提出了具体的模型/架构/算法名称
- 可以作为对比实验的 baseline 方法
输出严格 JSON:
{"methods": ["方法名1", "方法名2"], "search_queries": ["method1 pytorch implementation", "method2 official code"]}
方法名尽量使用论文中常用的英文缩写或全称。"""
user_prompt = f"""输入论文标题: {title[:200]}
引用论文列表:
{chr(10).join(title_list)}
请识别哪些是方法/算法论文,生成对应的 GitHub 搜索词。"""
try:
raw = call_llm_json(system_prompt, user_prompt, temperature=0.2, max_tokens=3000)
data = parse_json_safe(raw, "comparison_miner")
methods = data.get("methods", [])
queries = data.get("search_queries", [])
print(f" 从引用网络识别到 {len(methods)} 个对比算法: {methods[:8]}")
return queries[:8], citation_map
except Exception as e:
print(f" [WARN] 对比算法识别失败,降级: {e}")
# 降级:直接用引用论文标题作为搜索词
fallback_queries = []
for t in list(citation_map.keys())[:5]:
short = re.sub(r'[^\w\s-]', '', t).strip()[:80]
if len(short) > 15:
fallback_queries.append(f"{short} pytorch implementation")
return fallback_queries, citation_map
def _search_s2_papers(query: str, limit: int = 5) -> list[dict]:
"""在 Semantic Scholar 中搜索论文,覆盖全工科领域(IEEE/ASME/ASCE 等期刊)。
带 30 分钟 TTL 缓存,降低 S2 API 压力。
Returns:
list[dict]: [{"title": ..., "abstract": ..., "year": ..., "citations": ..., "url": ...}, ...]
"""
from cache import s2_cache
cache_key = f"s2_search:{query}:{limit}"
cached = s2_cache.get(cache_key)
if cached is not None:
return cached
api_url = (
f"https://api.semanticscholar.org/graph/v1/paper/search"
f"?query={urllib.request.quote(query)}&limit={limit}"
f"&fields=title,abstract,year,citationCount,url,externalIds"
)
req = urllib.request.Request(api_url, headers={"User-Agent": "ResearchRadar/1.0"})
try:
with urllib.request.urlopen(req, timeout=15) as resp:
data = json.loads(resp.read().decode("utf-8"))
except Exception:
return []
papers = []
for p in data.get("data", []):
title = (p.get("title") or "").strip()
if not title:
continue
papers.append({
"title": title,
"abstract": (p.get("abstract") or "").strip()[:1000],
"year": p.get("year"),
"citations": p.get("citationCount", 0),
"url": p.get("url", ""),
"arxiv_id": (p.get("externalIds") or {}).get("ArXiv", ""),
"doi": (p.get("externalIds") or {}).get("DOI", ""),
})
s2_cache.set(cache_key, papers)
return papers
def _find_input_paper_repo(title: str, arxiv_id: str, authors: list[str]) -> dict | None:
"""搜索输入论文自身的官方代码仓库。
搜索策略(按优先级):
1. GitHub 搜索 arxiv ID
2. GitHub 搜索论文标题(精确匹配)
3. GitHub 搜索一作姓名 + 论文核心关键词
Returns:
找到则返回符合格式的候选仓库 dict,找不到返回 None
"""
import requests
token = os.getenv("GITHUB_TOKEN", "")
headers = {"Accept": "application/vnd.github.v3+json"}
if token:
headers["Authorization"] = f"Bearer {token}"
# 提取核心关键词(取标题前 6 个实义词)
stops = {
'a', 'an', 'the', 'of', 'for', 'in', 'on', 'to', 'and', 'or', 'is', 'are',
'we', 'our', 'that', 'this', 'with', 'from', 'by', 'as', 'at', 'be', 'it',
'its', 'not', 'can', 'has', 'have', 'been', 'was', 'were', 'will', 'would',
'could', 'should', 'may', 'do', 'does', 'did', 'so', 'if', 'no', 'new',
'based', 'using', 'which', 'into', 'such', 'than', 'then', 'these', 'those',
'propose', 'present', 'method', 'approach', 'framework', 'novel', 'model',
'learning', 'deep', 'via', 'et', 'al', 'towards', 'toward',
}
title_words = [w for w in re.sub(r'[^\w\s-]', ' ', title).split()
if w.lower() not in stops and len(w) >= 3]
core_keywords = " ".join(title_words[:6])
search_queries = [
f'"{arxiv_id}" in:name,description,readme',
f'"{core_keywords}" in:name,description stars:>=3',
]
# 用一作姓名 + 关键词搜索
first_author = authors[0] if authors else ""
if first_author:
last_name = first_author.split()[-1] if first_author.split() else ""
if len(last_name) >= 3:
search_queries.append(f"{last_name} {core_keywords[:80]} in:name,description")
for query in search_queries:
url = "https://api.github.com/search/repositories"
params = {"q": query, "sort": "stars", "order": "desc", "per_page": 5}
try:
resp = requests.get(url, headers=headers, params=params, timeout=15)
if resp.status_code in (403, 429):
continue
resp.raise_for_status()
data = resp.json()
except Exception:
continue
# 筛选:标题必须包含核心关键词中至少 2 个
for item in data.get("items", []):
repo_title = (item.get("description") or "").lower()
repo_name = item.get("full_name", "").lower()
combined = f"{repo_name} {repo_title}"
matches = sum(1 for kw in title_words[:6] if kw.lower() in combined)
if matches >= 2:
return {
"full_name": item["full_name"],
"html_url": item["html_url"],
"description": item.get("description", ""),
"stars": item.get("stargazers_count", 0),
"language": item.get("language", ""),
"updated_at": item.get("updated_at", ""),
"topics": item.get("topics", []),
"match_keyword": f"本文代码: {query[:60]}",
}
return None
def run(arxiv_url: str, top_n: int = 5, progress=None) -> dict:
"""主入口:先搜索 GitHub → 再基于实际仓库归纳方法族 → 最后评估。
Args:
arxiv_url: arxiv 论文 URL
top_n: 最终评估的仓库数量
progress: gr.Progress 实例或 callable(fraction, desc),用于前端进度条
Returns:
dict: {"paper": {...}, "direction": {...}, "repos": [{...}], "error": "..."}
"""
t_start = time.time()
def _prog(frac: float, desc: str):
if progress:
progress(frac, desc=desc)
# ================================================================
# [1/6] 论文信息获取(甲 Workflow)
# ================================================================
print("=" * 60)
print("[1/6] 正在获取论文信息...")
_prog(0.05, "正在从 arxiv 获取论文信息...")
t0 = time.time()
try:
paper = fetch_paper_info(arxiv_url)
except Exception as e:
return {"error": f"获取论文信息失败: {e}"}
title = paper.get("title", "")
abstract = paper.get("abstract", "")
print(f" 标题: {title[:100]}...")
print(f" 作者: {', '.join(paper.get('authors', [])[:3])}")
print(f" 分类: {', '.join(paper.get('categories', []))}")
print(f" 耗时: {time.time() - t0:.1f}s")
# ================================================================
# [1.2] 搜索输入论文的官方代码
# ================================================================
print()
print("[1.2] 正在搜索输入论文的官方代码...")
_prog(0.10, "正在搜索论文自身的官方代码...")
t0 = time.time()
input_paper_repo = _find_input_paper_repo(title, arxiv_id, paper.get("authors", []))
if input_paper_repo:
print(f" 找到论文官方代码: {input_paper_repo['full_name']} (Stars: {input_paper_repo.get('stars', 0)})")
else:
print(f" 未找到论文官方代码")
print(f" 耗时: {time.time() - t0:.1f}s")
# ================================================================
# [1.5] 对比实验算法挖掘(从 S2 引用网络提取算法名)
# ================================================================
print()
print("[1.5] 正在从论文引用网络挖掘对比实验算法...")
_prog(0.12, "正在从引用网络挖掘对比算法...")
t0 = time.time()
arxiv_id = paper.get("arxiv_id", "")
comparison_queries, citation_map = _mine_comparison_algorithms(arxiv_id, title, abstract)
if comparison_queries:
print(f" 生成 {len(comparison_queries)} 个对比算法搜索词:")
for q in comparison_queries[:5]:
print(f" - {q}")
print(f" 耗时: {time.time() - t0:.1f}s")
# ================================================================
# [1.8] S2 论文搜索(覆盖全工科领域,含 IEEE/ASME 等期刊)
# ================================================================
print()
print("[1.8] 正在 Semantic Scholar 中搜索同领域论文(覆盖全工科)...")
_prog(0.18, "正在 Semantic Scholar 搜索同领域论文...")
t0 = time.time()
title_kws = _extract_title_keywords(title)
s2_papers = _search_s2_papers(" ".join(title_kws), limit=8)
# 从 S2 论文标题中提取额外搜索词
s2_extra_queries = []
if s2_papers:
print(f" 找到 {len(s2_papers)} 篇相关论文:")
for p in s2_papers[:5]:
cits = p.get("citations", 0)
print(f" - [{cits} cites] {p['title'][:80]}")
# 用高引论文标题作为搜索词
if cits >= 50:
s2_quoted = p['title'].strip()[:80]
# 去掉特殊字符
s2_clean = re.sub(r'[^\w\s-]', '', s2_quoted).strip()
if len(s2_clean) > 15:
s2_extra_queries.append(f"{s2_clean} github")
print(f" 额外生成 {len(s2_extra_queries)} 个 S2 来源搜索词")
print(f" 耗时: {time.time() - t0:.1f}s")
# ================================================================
# [2/6] 宽泛 GitHub 搜索(甲 Workflow)
# ================================================================
print()
print("[2/6] 正在宽泛搜索 GitHub(先搜仓库,再让 Agent 分析)...")
_prog(0.22, "正在 GitHub 搜索开源仓库...")
t0 = time.time()
broad_queries = _extract_broad_queries(title, abstract)
# 合并对比算法搜索词,去重
has_github_token = bool(os.getenv("GITHUB_TOKEN", ""))
max_queries = 15 if has_github_token else 8
all_queries = list(dict.fromkeys(s2_extra_queries + comparison_queries + broad_queries))[:max_queries]
if not has_github_token:
print(f" ⚠️ 未设置 GITHUB_TOKEN,搜索词限制为 {max_queries} 个以避免限速")
print(f" LLM: {len(broad_queries)} + S2引用: {len(comparison_queries)} + S2论文: {len(s2_extra_queries)} = {len(all_queries)} 个总搜索词:")
for q in all_queries:
print(f" - {q}")
try:
broad_results = search_repos(all_queries, max_per_keyword=5)
except Exception as e:
return {
"paper": paper,
"error": f"GitHub 搜索失败: {e}",
}
print(f" 去重后获得 {len(broad_results)} 个候选仓库")
print(f" 耗时: {time.time() - t0:.1f}s")
if not broad_results:
return {
"paper": paper,
"direction": {},
"repos": [],
"error": "未找到相关开源仓库。该方向可能太新或太冷门,暂无高质量开源实现。",
}
# ================================================================
# [2.5] LLM 过滤不相关仓库
# ================================================================
print()
print("[2.5] 正在用 LLM 过滤不相关仓库...")
_prog(0.32, "正在用 LLM 过滤不相关仓库...")
t0 = time.time()
filtered_results = _filter_repos(title, abstract, broad_results)
print(f" 耗时: {time.time() - t0:.1f}s")
# ================================================================
# [2.6] 领域上下文扩充(S2 搜索综述论文补充领域知识)
# ================================================================
print()
print("[2.6] 正在扩充领域上下文(搜索相关综述论文)...")
_prog(0.38, "正在扩充领域上下文...")
t0 = time.time()
domain_context = _enrich_domain_context(title, abstract, paper.get("categories", []))
if domain_context:
print(f" 获取到 {domain_context.count('**') // 2} 篇相关综述的摘要")
else:
print(f" 未找到相关综述(将仅基于论文摘要和仓库数据进行分析)")
print(f" 耗时: {time.time() - t0:.1f}s")
# ================================================================
# [3/6] 基于仓库数据归纳方法族(Agent 1)
# ================================================================
print()
print(f"[3/6] 正在分析 {len(filtered_results)} 个仓库,归纳方法族(Agent 1)...")
_prog(0.42, "正在用 LLM 归纳方法族...")
t0 = time.time()
try:
direction = analyze_direction(title, abstract, filtered_results, domain_context)
except Exception as e:
print(f" [WARN] Agent 1 方向解析失败,降级为基本分析: {e}")
direction = _make_fallback_direction(title, abstract, filtered_results)
subfield = direction.get("subfield", "未知")
families = direction.get("method_families", [])
print(f" 子领域: {subfield}")
print(f" 趋势: {direction.get('subfield_trend', '')[:80]}...")
print(f" 方法族 ({len(families)} 个):")
for mf in families:
matched = mf.get("matched_repos", [])
print(f" - {mf.get('family_name', '?')}: {len(matched)} 个仓库 {matched}")
# 防混淆校验:检查子领域是否与论文标题存在语义关联
_sanity_check_direction(title, subfield)
print(f" 耗时: {time.time() - t0:.1f}s")
# ================================================================
# [4/6] 筛选仓库 + 构建方法族归属映射
# ================================================================
print()
print(f"[4/6] 正在筛选并获取仓库详情...")
_prog(0.55, "正在筛选仓库并获取详情...")
t0 = time.time()
# 从 Agent 1 的 matched_repos 中建立 full_name → family_name 映射
repo_family_map = {}
for mf in families:
family_name = mf.get("family_name", "")
for repo_name in mf.get("matched_repos", []):
repo_family_map[repo_name] = family_name
# 从 filtered_results 中筛选:优先选有方法族归属的,再按 stars 排序
classified = [r for r in filtered_results if r["full_name"] in repo_family_map]
unclassified = [r for r in filtered_results if r["full_name"] not in repo_family_map]
classified.sort(key=lambda r: r.get("stars", 0), reverse=True)
unclassified.sort(key=lambda r: r.get("stars", 0), reverse=True)
# 论文自身代码优先排在最前面
if input_paper_repo:
candidates = [input_paper_repo] + (classified + unclassified)[:top_n - 1]
repo_family_map[input_paper_repo["full_name"]] = "本文代码"
else:
candidates = (classified + unclassified)[:top_n]
print(f" 有方法族归属: {len(classified)} 个,未归类: {len(unclassified)} 个")
if input_paper_repo:
print(f" 含论文自身代码: {input_paper_repo['full_name']}")
print(f" 最终选取 {len(candidates)} 个仓库:")
for i, c in enumerate(candidates):
family = repo_family_map.get(c["full_name"], "未归类")
print(f" {i+1}. {c['full_name']} [{family}] Stars:{c.get('stars', 0)}")
print(f" 耗时: {time.time() - t0:.1f}s")
# ================================================================
# [5/6] 仓库详情获取 + 评估(Agent 2,并行)
# ================================================================
print()
print(f"[5/6] 正在获取仓库详情并评估(Agent 2,{len(candidates)} 个仓库并行)...")
_prog(0.60, f"正在评估 {len(candidates)} 个开源仓库...")
t0 = time.time()
def _eval_single(repo, idx, total):
"""单个仓库评估(在线程中执行)"""
full_name = repo["full_name"]
try:
owner, name = full_name.split("/", 1)
except ValueError:
return idx, None
matched_family = repo_family_map.get(full_name, "")
family_tag = f"[{matched_family}]" if matched_family else "[未归类]"
print(f" ({idx+1}/{total}) {full_name} {family_tag}...")
try:
readme = fetch_readme(owner, name)
deps = fetch_dependencies(owner, name)
evaluation = evaluate_repo(repo, readme, deps, matched_family)
except Exception as e:
print(f" [WARN] {full_name} 评估失败: {e}")
evaluation = _make_error_evaluation(str(e))
return idx, {
"full_name": repo.get("full_name", ""),
"html_url": repo.get("html_url", ""),
"description": repo.get("description", ""),
"stars": repo.get("stars", 0),
"language": repo.get("language", ""),
"updated_at": repo.get("updated_at", ""),
"topics": repo.get("topics", []),
"match_keyword": repo.get("match_keyword", ""),
"method_family": matched_family,
"evaluation": evaluation,
}
n = len(candidates)
evaluated = [None] * n
with ThreadPoolExecutor(max_workers=min(n, 5)) as executor:
futures = {
executor.submit(_eval_single, repo, i, n): i
for i, repo in enumerate(candidates)
}
for future in as_completed(futures):
idx, result = future.result()
if result is not None:
evaluated[idx] = result
evaluated = [r for r in evaluated if r is not None]
# 按综合评分降序排列
evaluated.sort(key=lambda r: r["evaluation"].get("overall_score", 0), reverse=True)
print(f" 耗时: {time.time() - t0:.1f}s")
# ================================================================
# 汇总
# ================================================================
print()
print("=" * 60)
print("完成!")
best = evaluated[0] if evaluated else None
if best:
print(f" 最高分: {best['full_name']} ({best['evaluation'].get('overall_score', 0)}/100)")
# ================================================================
# 审核层:检查 Agent 输出质量和来源可靠性
# ================================================================
_prog(0.90, "正在进行质量审核...")
audit = supervise(title, abstract, direction, evaluated)
print(f"\n 审核结果: {audit['summary']}")
print(f" 综合质量评分: {audit['overall_score']}/100")
if audit["actions"]:
for action in audit["actions"]:
print(f" → {action}")
print(f" 总耗时: {time.time() - t_start:.1f}s")
print("=" * 60)
_prog(1.0, "分析完成,正在生成研报...")
return {
"paper": paper,
"direction": direction,
"repos": evaluated,
"audit": audit,
}
def _extract_broad_queries(title: str, abstract: str) -> list[str]:
"""用 LLM 从论文标题和摘要中生成 GitHub 搜索关键词。
相比规则提取,LLM 理解论文后能生成更精准的领域搜索词。
如果 LLM 调用失败,降级为规则提取。
"""
system_prompt = """你是学术论文搜索专家。根据论文标题和摘要,生成 6-8 个 GitHub 搜索查询,用于找到该研究方向的开源实现。
要求:
- 查询必须是英文,用空格分隔关键词
- 包含方法名/技术关键词 + 限定词(implementation / pytorch / code)
- 包含 1-2 个宽泛的领域查询(如 "anomaly detection pytorch library")
- 不要用引号或特殊符号
- 避免太泛的词(如单一个 "anomaly")
输出严格 JSON:
{"queries": ["query1", "query2", ...]}"""
user_prompt = f"""论文标题: {title[:300]}
摘要: {abstract[:800]}
请为该论文生成 GitHub 搜索查询。"""
try:
raw = call_llm_json(system_prompt, user_prompt, temperature=0.3, max_tokens=3000)
data = parse_json_safe(raw, "broad_queries")
queries = data.get("queries", [])
if isinstance(queries, list) and len(queries) >= 3:
return queries[:10]
except Exception as e:
print(f" [WARN] LLM 关键词生成失败,降级为规则提取: {e}")
return _extract_broad_queries_fallback(title, abstract)
def _extract_broad_queries_fallback(title: str, abstract: str) -> list[str]:
"""从论文标题和摘要中提取宽泛搜索关键词(规则降级版)。
策略:提取有意义的词组 + 搜索限定词,按多样性采样(不偏向标题前缀)。
不依赖 LLM,纯规则提取。
"""
abstract_first = abstract.split(".")[0] if abstract else ""
text = f"{title} {abstract_first}".lower()
text = re.sub(r'[^\w\s-]', ' ', text)
words = text.split()
stop_words = {
'a', 'an', 'the', 'of', 'for', 'in', 'on', 'to', 'and', 'or', 'is', 'are',
'we', 'our', 'that', 'this', 'with', 'from', 'by', 'as', 'at', 'be', 'it',
'its', 'not', 'can', 'has', 'have', 'been', 'was', 'were', 'will', 'would',
'could', 'should', 'may', 'do', 'does', 'did', 'so', 'if', 'no', 'new',
'based', 'using', 'which', 'into', 'such', 'than', 'then', 'these', 'those',
'propose', 'present', 'method', 'approach', 'framework', 'novel', 'model',
'learning', 'deep', 'via', 'et', 'al', 'all', 'we', 'you', 'they',
}
meaningful = [w for w in words if w not in stop_words and len(w) > 1]
bigrams = [f"{meaningful[i]} {meaningful[i+1]}" for i in range(len(meaningful) - 1)]
trigrams = [f"{meaningful[i]} {meaningful[i+1]} {meaningful[i+2]}" for i in range(len(meaningful) - 2)]
def _sample_diverse(items: list[str], n: int) -> list[str]:
if len(items) <= n:
return items
step = max(1, len(items) // n)
return [items[i] for i in range(0, len(items), step)][:n]
phrases = _sample_diverse(trigrams, 3) + _sample_diverse(bigrams, 4)
seen = set()
unique_phrases = []
for p in sorted(phrases, key=len, reverse=True):
if p not in seen:
seen.add(p)
unique_phrases.append(p)
qualifiers = [
"implementation pytorch",
"official code",
"pytorch github",
]
queries = []
for phrase in unique_phrases[:5]:
if len(phrase) > 5:
queries.append(f"{phrase} {qualifiers[0]}")
mid = len(meaningful) // 2
broad = " ".join(meaningful[max(0, mid-2):mid+2])
if len(broad) > 10:
queries.append(f"{broad} {qualifiers[2]}")
seen = set()
unique = []
for q in queries:
if q not in seen:
seen.add(q)
unique.append(q)
return unique[:10]
def _filter_repos(title: str, abstract: str, repos: list[dict]) -> list[dict]:
"""用 LLM 快速过滤不相关的仓库,只保留与论文方向相关的。
在 GitHub 关键词搜索之后、Agent 1 分析之前运行。
如果过滤后为空或 LLM 调用失败,返回原始列表作为降级方案。
"""
if len(repos) <= 3:
return repos
# 超过 20 个仓库时,只取 top 20(按 stars),避免 JSON 输出过长
if len(repos) > 20:
repos_for_filter = sorted(repos, key=lambda r: r.get("stars", 0), reverse=True)[:20]
else:
repos_for_filter = repos
# 构造缩略的仓库清单供 LLM 判断
repo_list_parts = []
for i, r in enumerate(repos_for_filter):
desc = (r.get("description") or "")[:120]
topics = ", ".join(r.get("topics", [])[:8])
repo_list_parts.append(
f"{i+1}. {r['full_name']} (Stars:{r.get('stars',0)})\n"
f" Description: {desc}\n"
f" Topics: {topics}"
)
repo_list_text = "\n".join(repo_list_parts)
system_prompt = """你是一个学术论文与开源代码匹配专家。任务是判断 GitHub 仓库是否与给定论文属于同一研究方向。
排除标准(满足任一即排除):
- 仓库解决的业务问题与论文完全不同(如论文做工业缺陷检测,仓库做医学图像/自动驾驶/人脸识别)
- 仓库使用的核心技术方法与论文完全无关
- 仓库是通用教程/课程作业/面试题集合
保留标准(满足任一即保留):
- 仓库实现的方法与论文方法同属一个技术范式
- 仓库可用于该论文的对比实验(baseline/comparison)
- 仓库是该领域的知名 benchmarking 库
输出严格 JSON(不要 reasoning 字段以节省 token):
{"relevant": ["owner/repo1", "owner/repo2"], "irrelevant": ["owner/repo3"]}"""
user_prompt = f"""## 论文信息
标题: {title[:200]}
摘要: {abstract[:500]}
## GitHub 仓库清单(共 {len(repos_for_filter)} 个)
{repo_list_text}
请判断每个仓库是否与该论文方向相关。"""
try:
raw = call_llm_json(system_prompt, user_prompt, temperature=0.1, max_tokens=10000)
data = parse_json_safe(raw, "filter_repos")
except Exception as e:
print(f" [WARN] 仓库过滤失败,保留全部: {e}")
return repos
relevant_set = set(data.get("relevant", []))
irrelevant = data.get("irrelevant", [])
if irrelevant:
print(f" [过滤] 剔除 {len(irrelevant)} 个不相关仓库: {irrelevant}")
filtered = [r for r in repos if r["full_name"] in relevant_set]
if len(filtered) < 2:
print(f" [WARN] 过滤后仅剩 {len(filtered)} 个仓库,回退到原始结果")
return repos
print(f" 过滤后保留 {len(filtered)}/{len(repos)} 个仓库")
return filtered
def _sanity_check_direction(title: str, subfield: str) -> None:
"""防混淆检查:验证子领域分析是否与论文标题存在最低限度的语义关联。
如果子领域关键词与标题完全无关,可能是 LLM 混淆了论文。
仅打印警告,不阻断流程。
"""
# 提取标题中的实义词(长度>=4,排除停用词)
title_stops = {
'a', 'an', 'the', 'of', 'for', 'in', 'on', 'to', 'and', 'or', 'is', 'are',
'we', 'our', 'that', 'this', 'with', 'from', 'by', 'as', 'at', 'be', 'it',
'its', 'not', 'can', 'has', 'have', 'been', 'was', 'were', 'will', 'would',
'could', 'should', 'may', 'do', 'does', 'did', 'so', 'if', 'no', 'new',
'based', 'using', 'which', 'into', 'such', 'than', 'then', 'these', 'those',
'propose', 'present', 'method', 'approach', 'framework', 'novel', 'model',
'learning', 'deep', 'via', 'et', 'al',
}
title_words = set(
w.lower() for w in re.sub(r'[^\w\s]', ' ', title).split()
if len(w) >= 4 and w.lower() not in title_stops
)
subfield_lower = subfield.lower()
overlap = [w for w in title_words if w in subfield_lower]
if not overlap and title_words:
print(f" ⚠️ [防混淆警告] 子领域\"{subfield}\"与论文标题无关键词重叠")
print(f" 标题关键词: {sorted(title_words)[:10]}")
print(f" 这可能是 LLM 混淆了论文,请人工核实分析结果。")
def _make_fallback_direction(title: str, abstract: str, repos: list[dict]) -> dict:
"""Agent 1 失败时的降级方向分析:基于搜索到的仓库名称推断子领域。"""
# 从仓库 topic/description 提取高频词作为子领域
all_words = []
for r in repos[:10]:
desc = (r.get("description") or "")
topics = " ".join(r.get("topics", []))
all_words.extend((desc + " " + topics).lower().split())
stops = {'a', 'an', 'the', 'of', 'for', 'in', 'on', 'to', 'and', 'or', 'is', 'are',
'we', 'our', 'that', 'this', 'with', 'from', 'by', 'as', 'at', 'be', 'it'}
meaningful = [w for w in all_words if w not in stops and len(w) >= 4]
word_freq = {}
for w in meaningful:
word_freq[w] = word_freq.get(w, 0) + 1
top_words = sorted(word_freq, key=word_freq.get, reverse=True)[:6]
return {
"subfield": f"基于仓库数据推断: {', '.join(top_words[:3])}" if top_words else "未知领域",
"subfield_trend": "(Agent 1 暂不可用,趋势分析跳过。后续版本将自动恢复。)",
"method_families": [],
"broad_queries": [],
}
def _make_error_evaluation(error_msg: str) -> dict:
"""构造一个表示评估失败的 evaluation dict"""
return {
"reproducibility_score": 0,
"benchmark_fitness_score": 0,
"overall_score": 0,
"verdict": "error",
"env_score": 0,
"doc_score": 0,
"code_score": 0,
"community_score": 0,
"dep_score": 0,
"benchmark_score": 0,
"reasoning": f"评估失败: {error_msg[:100]}",
"risks": ["评估过程出错,请手动检查该仓库"],
"benchmark_readiness": "not_ready",
"suggested_use": "评估失败,请手动检查",
}
# ============================================================
# 命令行入口
# ============================================================
if __name__ == "__main__":
# 修复 Windows 终端 emoji 编码问题
from llm_utils import fix_windows_encoding
fix_windows_encoding()
if len(sys.argv) < 2:
print("用法: python run.py <arxiv_url>")
print("示例: python run.py https://arxiv.org/abs/1706.03762")
print("示例: python run.py https://arxiv.org/abs/2011.08785")
sys.exit(1)
url = sys.argv[1]
result = run(url)
if result.get("error"):
print(f"\n{'='*60}")
print(f"运行未完全成功")
print(f"{'='*60}")
print(f" {result['error']}")
if result.get("paper"):
print(f"\n 论文信息已获取: {result['paper'].get('title', '')[:80]}")
if result.get("direction"):
print(f" 方向解析已完成: {result['direction'].get('subfield', '')}")
else:
from app import format_report
# 用 utf-8 编码输出避免 emoji 乱码
report = format_report(result)
print(report.encode(sys.stdout.encoding or 'utf-8', errors='replace').decode(sys.stdout.encoding or 'utf-8', errors='replace'))