Spaces:
Running
Running
Upload run.py with huggingface_hub
Browse files
run.py
CHANGED
|
@@ -9,6 +9,10 @@ import sys
|
|
| 9 |
import time
|
| 10 |
import re
|
| 11 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 13 |
|
| 14 |
# 自动加载 .env 文件
|
|
@@ -52,6 +56,140 @@ from repo_evaluator import evaluate_repo
|
|
| 52 |
from llm_utils import call_llm_json, parse_json_safe
|
| 53 |
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
def run(arxiv_url: str, top_n: int = 5) -> dict:
|
| 56 |
"""主入口:先搜索 GitHub → 再基于实际仓库归纳方法族 → 最后评估。
|
| 57 |
|
|
@@ -88,6 +226,20 @@ def run(arxiv_url: str, top_n: int = 5) -> dict:
|
|
| 88 |
print(f" 分类: {', '.join(paper.get('categories', []))}")
|
| 89 |
print(f" 耗时: {time.time() - t0:.1f}s")
|
| 90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
# ================================================================
|
| 92 |
# [2/5] 宽泛 GitHub 搜索(甲 Workflow)
|
| 93 |
# ================================================================
|
|
@@ -96,12 +248,14 @@ def run(arxiv_url: str, top_n: int = 5) -> dict:
|
|
| 96 |
t0 = time.time()
|
| 97 |
|
| 98 |
broad_queries = _extract_broad_queries(title, abstract)
|
| 99 |
-
|
| 100 |
-
|
|
|
|
|
|
|
| 101 |
print(f" - {q}")
|
| 102 |
|
| 103 |
try:
|
| 104 |
-
broad_results = search_repos(
|
| 105 |
except Exception as e:
|
| 106 |
return {
|
| 107 |
"paper": paper,
|
|
@@ -128,6 +282,19 @@ def run(arxiv_url: str, top_n: int = 5) -> dict:
|
|
| 128 |
filtered_results = _filter_repos(title, abstract, broad_results)
|
| 129 |
print(f" 耗时: {time.time() - t0:.1f}s")
|
| 130 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
# ================================================================
|
| 132 |
# [3/5] 基于仓库数据归纳方法族(Agent 1) ← 核心改动
|
| 133 |
# ================================================================
|
|
@@ -136,7 +303,7 @@ def run(arxiv_url: str, top_n: int = 5) -> dict:
|
|
| 136 |
t0 = time.time()
|
| 137 |
|
| 138 |
try:
|
| 139 |
-
direction = analyze_direction(title, abstract, filtered_results)
|
| 140 |
except Exception as e:
|
| 141 |
return {
|
| 142 |
"paper": paper,
|
|
@@ -151,6 +318,9 @@ def run(arxiv_url: str, top_n: int = 5) -> dict:
|
|
| 151 |
for mf in families:
|
| 152 |
matched = mf.get("matched_repos", [])
|
| 153 |
print(f" - {mf.get('family_name', '?')}: {len(matched)} 个仓库 {matched}")
|
|
|
|
|
|
|
|
|
|
| 154 |
print(f" 耗时: {time.time() - t0:.1f}s")
|
| 155 |
|
| 156 |
# ================================================================
|
|
@@ -435,6 +605,34 @@ def _filter_repos(title: str, abstract: str, repos: list[dict]) -> list[dict]:
|
|
| 435 |
return filtered
|
| 436 |
|
| 437 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 438 |
def _make_error_evaluation(error_msg: str) -> dict:
|
| 439 |
"""构造一个表示评估失败的 evaluation dict"""
|
| 440 |
return {
|
|
|
|
| 9 |
import time
|
| 10 |
import re
|
| 11 |
import os
|
| 12 |
+
import json
|
| 13 |
+
import urllib.request
|
| 14 |
+
import urllib.error
|
| 15 |
+
import xml.etree.ElementTree as ET
|
| 16 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 17 |
|
| 18 |
# 自动加载 .env 文件
|
|
|
|
| 56 |
from llm_utils import call_llm_json, parse_json_safe
|
| 57 |
|
| 58 |
|
| 59 |
+
def _enrich_domain_context(title: str, abstract: str, categories: list[str]) -> str:
|
| 60 |
+
"""从 arxiv 搜索同领域的综述和相关高引论文,扩充领域上下文。
|
| 61 |
+
|
| 62 |
+
当输入论文不是综述时,普通研究论文的摘要不足以让 Agent 1 推断领域全景。
|
| 63 |
+
此函数从 arxiv 搜索相关综述/survey 论文,提取摘要作为补充上下文。
|
| 64 |
+
失败时返回空字符串(静默降级)。
|
| 65 |
+
"""
|
| 66 |
+
if not categories:
|
| 67 |
+
return ""
|
| 68 |
+
|
| 69 |
+
primary_cat = categories[0]
|
| 70 |
+
# 用主要分类 + 关键词搜索综述/survey 论文
|
| 71 |
+
keywords = ["survey", "review", "comprehensive", "benchmark"]
|
| 72 |
+
all_abstracts = []
|
| 73 |
+
|
| 74 |
+
for kw in keywords[:2]: # 只搜两个关键词,避免触发 arxiv 限速
|
| 75 |
+
search_query = f"cat:{primary_cat}+AND+ti:{kw}"
|
| 76 |
+
api_url = (
|
| 77 |
+
f"http://export.arxiv.org/api/query"
|
| 78 |
+
f"?search_query={search_query}&sortBy=relevance&max_results=2"
|
| 79 |
+
)
|
| 80 |
+
req = urllib.request.Request(api_url, headers={"User-Agent": "ResearchRadar/1.0"})
|
| 81 |
+
try:
|
| 82 |
+
with urllib.request.urlopen(req, timeout=15) as resp:
|
| 83 |
+
xml_text = resp.read().decode("utf-8")
|
| 84 |
+
root = ET.fromstring(xml_text)
|
| 85 |
+
ns = {"atom": "http://www.w3.org/2005/Atom"}
|
| 86 |
+
for entry in root.findall("atom:entry", ns):
|
| 87 |
+
t = entry.find("atom:title", ns)
|
| 88 |
+
s = entry.find("atom:summary", ns)
|
| 89 |
+
if t is not None and t.text and s is not None and s.text:
|
| 90 |
+
all_abstracts.append({
|
| 91 |
+
"title": t.text.strip().replace("\n", " "),
|
| 92 |
+
"abstract": s.text.strip().replace("\n", " ")[:800],
|
| 93 |
+
})
|
| 94 |
+
except Exception:
|
| 95 |
+
continue
|
| 96 |
+
|
| 97 |
+
if not all_abstracts:
|
| 98 |
+
return ""
|
| 99 |
+
|
| 100 |
+
# 组装为上下文文本
|
| 101 |
+
lines = ["## 该领域的相关综述/调查论文(供参考领域全景)"]
|
| 102 |
+
for i, a in enumerate(all_abstracts[:4]):
|
| 103 |
+
lines.append(f"{i+1}. **{a['title']}**: {a['abstract']}")
|
| 104 |
+
return "\n\n".join(lines)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def _mine_comparison_algorithms(arxiv_id: str, title: str, abstract: str) -> tuple[list[str], dict[str, int]]:
    """Mine comparison-experiment (baseline) algorithms from the Semantic Scholar citation network.

    Core idea: a paper's references usually include the baseline methods used
    in its comparison experiments; extracting those reference titles as extra
    search queries greatly improves GitHub search coverage.

    Args:
        arxiv_id: arXiv identifier of the input paper.
        title: paper title, passed to the LLM for context.
        abstract: paper abstract (currently unused; kept for interface stability).

    Returns:
        (extra_queries, citation_map): extra GitHub search queries, and a
        mapping {reference_title: citation_count}.
    """
    s2_url = (
        f"https://api.semanticscholar.org/graph/v1/paper/ArXiv:{arxiv_id}"
        f"?fields=references.title,references.citationCount,references.abstract"
        f"&limit=50"
    )
    req = urllib.request.Request(s2_url, headers={"User-Agent": "ResearchRadar/1.0"})
    try:
        with urllib.request.urlopen(req, timeout=20) as resp:
            data = json.loads(resp.read().decode("utf-8"))
    except Exception as e:
        print(f" [WARN] Semantic Scholar API 不可用,跳过对比算法挖掘: {e}")
        return [], {}

    # BUGFIX: the S2 API can return null for the references list, for
    # individual entries, and for citationCount.  `r.get("citationCount", 0)`
    # returns None when the key is present with a null value, which made the
    # sort below raise TypeError (None vs int).  Normalize all three cases.
    refs = [r for r in (data.get("references") or []) if isinstance(r, dict)]
    if not refs:
        return [], {}

    # Sort by citation count and keep the top 15 references.
    refs.sort(key=lambda r: r.get("citationCount") or 0, reverse=True)
    top_refs = refs[:15]

    # Build a title list so the LLM can identify which are method papers.
    title_list = []
    citation_map: dict[str, int] = {}
    for r in top_refs:
        t = (r.get("title") or "").strip()
        cc = r.get("citationCount") or 0
        if t and len(t) > 10:
            title_list.append(f"- [{cc} cites] {t}")
            citation_map[t] = cc

    # Too few usable titles: not worth an LLM call.
    if len(title_list) < 3:
        return [], citation_map

    # Ask the LLM to pick out method/algorithm papers from the reference titles.
    system_prompt = """你是学术论文分析专家。从引用论文列表中识别哪些是提出了具体算法/方法的论文。

排除标准:
- 数据集/benchmark 论文(如 ImageNet, CIFAR, MVTec AD)
- 综述/survey 论文
- 纯理论/数学论文
- 框架/库论文(如 PyTorch, TensorFlow)

保留标准:
- 提出了具体的模型/架构/算法名称
- 可以作为对比实验的 baseline 方法

输出严格 JSON:
{"methods": ["方法名1", "方法名2"], "search_queries": ["method1 pytorch implementation", "method2 official code"]}

方法名尽量使用论文中常用的英文缩写或全称。"""

    user_prompt = f"""输入论文标题: {title[:200]}

引用论文列表:
{chr(10).join(title_list)}

请识别哪些是方法/算法论文,生成对应的 GitHub 搜索词。"""

    try:
        raw = call_llm_json(system_prompt, user_prompt, temperature=0.2, max_tokens=800)
        data = parse_json_safe(raw, "comparison_miner")
        methods = data.get("methods", [])
        queries = data.get("search_queries", [])
        print(f" 从引用网络识别到 {len(methods)} 个对比算法: {methods[:8]}")
        return queries[:8], citation_map
    except Exception as e:
        print(f" [WARN] 对比算法识别失败,降级: {e}")
        # Fallback: use the raw reference titles themselves as search queries.
        fallback_queries = []
        for t in list(citation_map.keys())[:5]:
            short = re.sub(r'[^\w\s-]', '', t).strip()[:80]
            if len(short) > 15:
                fallback_queries.append(f"{short} pytorch implementation")
        return fallback_queries, citation_map
|
| 191 |
+
|
| 192 |
+
|
| 193 |
def run(arxiv_url: str, top_n: int = 5) -> dict:
|
| 194 |
"""主入口:先搜索 GitHub → 再基于实际仓库归纳方法族 → 最后评估。
|
| 195 |
|
|
|
|
| 226 |
print(f" 分类: {', '.join(paper.get('categories', []))}")
|
| 227 |
print(f" 耗时: {time.time() - t0:.1f}s")
|
| 228 |
|
| 229 |
+
# ================================================================
|
| 230 |
+
# [1.5] 对比实验算法挖掘(从 S2 引用网络提取算法名)
|
| 231 |
+
# ================================================================
|
| 232 |
+
print()
|
| 233 |
+
print("[1.5] 正在从论文引用网络挖掘对比实验算法...")
|
| 234 |
+
t0 = time.time()
|
| 235 |
+
arxiv_id = paper.get("arxiv_id", "")
|
| 236 |
+
comparison_queries, citation_map = _mine_comparison_algorithms(arxiv_id, title, abstract)
|
| 237 |
+
if comparison_queries:
|
| 238 |
+
print(f" 生成 {len(comparison_queries)} 个对比算法搜索词:")
|
| 239 |
+
for q in comparison_queries[:5]:
|
| 240 |
+
print(f" - {q}")
|
| 241 |
+
print(f" 耗时: {time.time() - t0:.1f}s")
|
| 242 |
+
|
| 243 |
# ================================================================
|
| 244 |
# [2/5] 宽泛 GitHub 搜索(甲 Workflow)
|
| 245 |
# ================================================================
|
|
|
|
| 248 |
t0 = time.time()
|
| 249 |
|
| 250 |
broad_queries = _extract_broad_queries(title, abstract)
|
| 251 |
+
# 合并对比算法搜索词,去重
|
| 252 |
+
all_queries = list(dict.fromkeys(comparison_queries + broad_queries))[:12]
|
| 253 |
+
print(f" 生成 {len(broad_queries)} 个宽泛搜索词 + {len(comparison_queries)} 个对比算法搜索词 = {len(all_queries)} 个总搜索词:")
|
| 254 |
+
for q in all_queries:
|
| 255 |
print(f" - {q}")
|
| 256 |
|
| 257 |
try:
|
| 258 |
+
broad_results = search_repos(all_queries, max_per_keyword=5)
|
| 259 |
except Exception as e:
|
| 260 |
return {
|
| 261 |
"paper": paper,
|
|
|
|
| 282 |
filtered_results = _filter_repos(title, abstract, broad_results)
|
| 283 |
print(f" 耗时: {time.time() - t0:.1f}s")
|
| 284 |
|
| 285 |
+
# ================================================================
|
| 286 |
+
# [2.6] 领域上下文扩充(从 arxiv 搜索综述论文补充领域知识)
|
| 287 |
+
# ================================================================
|
| 288 |
+
print()
|
| 289 |
+
print("[2.6] 正在扩充领域上下文(搜索相关综述论文)...")
|
| 290 |
+
t0 = time.time()
|
| 291 |
+
domain_context = _enrich_domain_context(title, abstract, paper.get("categories", []))
|
| 292 |
+
if domain_context:
|
| 293 |
+
print(f" 获取到 {domain_context.count('**') // 2} 篇相关综述的摘要")
|
| 294 |
+
else:
|
| 295 |
+
print(f" 未找到相关综述(将仅基于论文摘要和仓库数据进行分析)")
|
| 296 |
+
print(f" 耗时: {time.time() - t0:.1f}s")
|
| 297 |
+
|
| 298 |
# ================================================================
|
| 299 |
# [3/5] 基于仓库数据归纳方法族(Agent 1) ← 核心改动
|
| 300 |
# ================================================================
|
|
|
|
| 303 |
t0 = time.time()
|
| 304 |
|
| 305 |
try:
|
| 306 |
+
direction = analyze_direction(title, abstract, filtered_results, domain_context)
|
| 307 |
except Exception as e:
|
| 308 |
return {
|
| 309 |
"paper": paper,
|
|
|
|
| 318 |
for mf in families:
|
| 319 |
matched = mf.get("matched_repos", [])
|
| 320 |
print(f" - {mf.get('family_name', '?')}: {len(matched)} 个仓库 {matched}")
|
| 321 |
+
|
| 322 |
+
# 防混淆校验:检查子领域是否与论文标题存在语义关联
|
| 323 |
+
_sanity_check_direction(title, subfield)
|
| 324 |
print(f" 耗时: {time.time() - t0:.1f}s")
|
| 325 |
|
| 326 |
# ================================================================
|
|
|
|
| 605 |
return filtered
|
| 606 |
|
| 607 |
|
| 608 |
+
def _sanity_check_direction(title: str, subfield: str) -> None:
|
| 609 |
+
"""防混淆检查:验证子领域分析是否与论文标题存在最低限度的语义��联。
|
| 610 |
+
|
| 611 |
+
如果子领域关键词与标题完全无关,可能是 LLM 混淆了论文。
|
| 612 |
+
仅打印警告,不阻断流程。
|
| 613 |
+
"""
|
| 614 |
+
# 提取标题中的实义词(长度>=4,排除停用词)
|
| 615 |
+
title_stops = {
|
| 616 |
+
'a', 'an', 'the', 'of', 'for', 'in', 'on', 'to', 'and', 'or', 'is', 'are',
|
| 617 |
+
'we', 'our', 'that', 'this', 'with', 'from', 'by', 'as', 'at', 'be', 'it',
|
| 618 |
+
'its', 'not', 'can', 'has', 'have', 'been', 'was', 'were', 'will', 'would',
|
| 619 |
+
'could', 'should', 'may', 'do', 'does', 'did', 'so', 'if', 'no', 'new',
|
| 620 |
+
'based', 'using', 'which', 'into', 'such', 'than', 'then', 'these', 'those',
|
| 621 |
+
'propose', 'present', 'method', 'approach', 'framework', 'novel', 'model',
|
| 622 |
+
'learning', 'deep', 'via', 'et', 'al',
|
| 623 |
+
}
|
| 624 |
+
title_words = set(
|
| 625 |
+
w.lower() for w in re.sub(r'[^\w\s]', ' ', title).split()
|
| 626 |
+
if len(w) >= 4 and w.lower() not in title_stops
|
| 627 |
+
)
|
| 628 |
+
subfield_lower = subfield.lower()
|
| 629 |
+
overlap = [w for w in title_words if w in subfield_lower]
|
| 630 |
+
if not overlap and title_words:
|
| 631 |
+
print(f" ⚠️ [防混淆警告] 子领域\"{subfield}\"与论文标题无关键词重叠")
|
| 632 |
+
print(f" 标题关键词: {sorted(title_words)[:10]}")
|
| 633 |
+
print(f" 这可能是 LLM 混淆了论文,请人工核实分析结果。")
|
| 634 |
+
|
| 635 |
+
|
| 636 |
def _make_error_evaluation(error_msg: str) -> dict:
|
| 637 |
"""构造一个表示评估失败的 evaluation dict"""
|
| 638 |
return {
|