# run.py
# ============================================================
# Role:  Orchestrator (owned by partner B)
# What:  Chains partner A's Workflow modules and partner B's
#        Agents in a "search first, then analyze" pipeline
# Usage: python run.py <arxiv_url>
# e.g.:  python run.py https://arxiv.org/abs/2011.08785
# ============================================================
import sys
import time
import re
import os
import json
import urllib.request
import urllib.parse
from concurrent.futures import ThreadPoolExecutor, as_completed
# Auto-load a .env file sitting next to this script
_ENV_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".env")
if os.path.isfile(_ENV_PATH):
with open(_ENV_PATH, "r", encoding="utf-8") as _f:
for _line in _f:
_line = _line.strip()
if _line and not _line.startswith("#") and "=" in _line:
_key, _val = _line.split("=", 1)
if _key not in os.environ:
os.environ[_key] = _val.strip().strip('"').strip("'")
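
# Illustrative .env contents (only GITHUB_TOKEN is read directly in this file;
# any other variable names depend on what llm_utils and the other modules
# actually read, so treat them as placeholders):
#
#   GITHUB_TOKEN=ghp_your_token_here   # raises GitHub API rate limits
#   # LLM credentials consumed by llm_utils (exact names may differ):
#   # OPENAI_API_KEY=sk-...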
# ---- Import partner A's Workflow modules ----
try:
    from paper_fetcher import fetch_paper_info
except ImportError as e:
    raise ImportError(
        "Missing module: paper_fetcher.py (owned by partner A); wait for A's delivery.\n"
        f"  Original error: {e}"
    ) from e
try:
    from repo_searcher import search_repos
except ImportError as e:
    raise ImportError(
        "Missing module: repo_searcher.py (owned by partner A); wait for A's delivery.\n"
        f"  Original error: {e}"
    ) from e
try:
    from repo_fetcher import fetch_readme, fetch_dependencies
except ImportError as e:
    raise ImportError(
        "Missing module: repo_fetcher.py (owned by partner A); wait for A's delivery.\n"
        f"  Original error: {e}"
    ) from e
from direction_analyzer import analyze_direction
from repo_evaluator import evaluate_repo
from supervisor import supervise
from llm_utils import call_llm_json, parse_json_safe
def _enrich_domain_context(title: str, abstract: str, categories: list[str]) -> str:
    """Search Semantic Scholar for related papers to enrich the domain context.

    S2 covers engineering broadly (IEEE/ASME/ASCE journals etc.), not just arXiv.
    Search strategy: 1) title keywords + survey/review  2) abstract keywords.
    Returns an empty string on failure (silent degradation).
    """
    all_abstracts = []
    # Take 2-3 core keywords from the title as the search terms
    title_keywords = _extract_title_keywords(title)
    search_queries = [
        f"{' '.join(title_keywords)} survey review",
        f"{' '.join(title_keywords)} state of the art",
    ]
    for query in search_queries[:2]:
        api_url = (
            f"https://api.semanticscholar.org/graph/v1/paper/search"
            f"?query={urllib.parse.quote(query)}&limit=3"
            f"&fields=title,abstract,year,citationCount"
        )
        req = urllib.request.Request(api_url, headers={"User-Agent": "ResearchRadar/1.0"})
        try:
            with urllib.request.urlopen(req, timeout=15) as resp:
                data = json.loads(resp.read().decode("utf-8"))
            for paper in data.get("data", []):
                t = (paper.get("title") or "").strip()
                s = (paper.get("abstract") or "").strip()
                if t and s:
                    all_abstracts.append({
                        "title": t,
                        "abstract": s[:800],
                        "year": paper.get("year", ""),
                        "citations": paper.get("citationCount", 0),
                    })
        except Exception:
            continue
    if not all_abstracts:
        return ""
    # Assemble into a context text block
    lines = ["## Related papers in this field (from Semantic Scholar, all engineering domains)"]
    for i, a in enumerate(all_abstracts[:4]):
        year_str = f" ({a['year']})" if a.get("year") else ""
        cites_str = f" [{a.get('citations', 0)} citations]"
        lines.append(f"{i+1}. **{a['title']}**{year_str}{cites_str}: {a['abstract']}")
    return "\n\n".join(lines)
def _extract_title_keywords(title: str) -> list[str]:
"""从论文标题提取核心实义词作为搜索关键词。"""
stops = {
'a', 'an', 'the', 'of', 'for', 'in', 'on', 'to', 'and', 'or', 'is', 'are',
'we', 'our', 'that', 'this', 'with', 'from', 'by', 'as', 'at', 'be', 'it',
'its', 'not', 'can', 'has', 'have', 'been', 'was', 'were', 'will', 'would',
'could', 'should', 'may', 'do', 'does', 'did', 'so', 'if', 'no', 'new',
'based', 'using', 'which', 'into', 'such', 'than', 'then', 'these', 'those',
'propose', 'present', 'method', 'approach', 'framework', 'novel', 'model',
'learning', 'deep', 'via', 'et', 'al', 'towards', 'toward',
}
words = [w.lower() for w in re.sub(r'[^\w\s-]', ' ', title).split() if len(w) >= 4 and w.lower() not in stops]
    # Deduplicate while preserving order
seen = set()
uniq = []
for w in words:
if w not in seen:
seen.add(w)
uniq.append(w)
return uniq[:5]
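
# Worked example of the rules above (not a stored test case):
#   _extract_title_keywords("PaDiM: a Patch Distribution Modeling Framework "
#                           "for Anomaly Detection and Localization")
#   -> ['padim', 'patch', 'distribution', 'modeling', 'anomaly']
#   ("framework" is a stop word, "for"/"and" are shorter than 4 chars,
#    and the result is capped at 5 terms)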
def _mine_comparison_algorithms(arxiv_id: str, title: str, abstract: str) -> tuple[list[str], dict[str, int]]:
    """Mine comparison-experiment algorithms from the Semantic Scholar citation graph.

    Core idea: a paper's references usually include the baseline methods used in
    its comparison experiments, so extracting those reference titles as extra
    search terms greatly improves GitHub search coverage.

    Returns:
        (extra_queries, citation_map): extra search queries, {paper_title: citation_count}
    """
    s2_url = (
        f"https://api.semanticscholar.org/graph/v1/paper/ArXiv:{arxiv_id}"
        f"?fields=references.title,references.citationCount,references.abstract"
        f"&limit=50"
    )
    req = urllib.request.Request(s2_url, headers={"User-Agent": "ResearchRadar/1.0"})
    try:
        with urllib.request.urlopen(req, timeout=20) as resp:
            data = json.loads(resp.read().decode("utf-8"))
    except Exception as e:
        print(f"  [WARN] Semantic Scholar API unavailable, skipping baseline mining: {e}")
        return [], {}
    refs = data.get("references", [])
    if not refs:
        return [], {}
    # Sort by citation count (S2 may return null for it) and keep the top 15
    refs.sort(key=lambda r: r.get("citationCount") or 0, reverse=True)
    top_refs = refs[:15]
    # Build a title list for the LLM to identify the method papers
    title_list = []
    citation_map = {}
    for r in top_refs:
        t = (r.get("title") or "").strip()
        cc = r.get("citationCount") or 0
        if t and len(t) > 10:
            title_list.append(f"- [{cc} cites] {t}")
            citation_map[t] = cc
    if len(title_list) < 3:
        return [], citation_map
    # Ask the LLM which cited papers propose a concrete method/algorithm
    system_prompt = """You are an expert in analyzing academic papers. From the list of cited papers, identify the ones that propose a concrete algorithm/method.
Exclude:
- dataset/benchmark papers (e.g. ImageNet, CIFAR, MVTec AD)
- survey/review papers
- purely theoretical/mathematical papers
- framework/library papers (e.g. PyTorch, TensorFlow)
Keep:
- papers that introduce a named model/architecture/algorithm
- methods usable as baselines in comparison experiments
Output strict JSON:
{"methods": ["method1", "method2"], "search_queries": ["method1 pytorch implementation", "method2 official code"]}
Use the English abbreviation or full name the community commonly uses for each method."""
    user_prompt = f"""Input paper title: {title[:200]}
Cited papers:
{chr(10).join(title_list)}
Identify which are method/algorithm papers and generate the matching GitHub search queries."""
    try:
        raw = call_llm_json(system_prompt, user_prompt, temperature=0.2, max_tokens=3000)
        data = parse_json_safe(raw, "comparison_miner")
        methods = data.get("methods", [])
        queries = data.get("search_queries", [])
        print(f"  Identified {len(methods)} comparison algorithms from the citation graph: {methods[:8]}")
        return queries[:8], citation_map
    except Exception as e:
        print(f"  [WARN] Baseline identification failed, degrading: {e}")
        # Fallback: use the cited paper titles directly as search queries
        fallback_queries = []
        for t in list(citation_map.keys())[:5]:
            short = re.sub(r'[^\w\s-]', '', t).strip()[:80]
            if len(short) > 15:
                fallback_queries.append(f"{short} pytorch implementation")
        return fallback_queries, citation_map
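
# Illustrative parsed result for an industrial anomaly-detection paper
# (method names picked for illustration; the real output depends on the LLM):
#   methods        -> ["SPADE", "PatchCore", "STFPM"]
#   search_queries -> ["SPADE anomaly detection pytorch",
#                      "PatchCore official code", ...]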
def _search_s2_papers(query: str, limit: int = 5) -> list[dict]:
    """Search for papers on Semantic Scholar, which covers all engineering
    fields (IEEE/ASME/ASCE journals etc.).

    Results are cached with a 30-minute TTL to reduce S2 API load.

    Returns:
        list[dict]: [{"title": ..., "abstract": ..., "year": ..., "citations": ..., "url": ...}, ...]
    """
    from cache import s2_cache
    cache_key = f"s2_search:{query}:{limit}"
    cached = s2_cache.get(cache_key)
    if cached is not None:
        return cached
    api_url = (
        f"https://api.semanticscholar.org/graph/v1/paper/search"
        f"?query={urllib.parse.quote(query)}&limit={limit}"
        f"&fields=title,abstract,year,citationCount,url,externalIds"
    )
    req = urllib.request.Request(api_url, headers={"User-Agent": "ResearchRadar/1.0"})
    try:
        with urllib.request.urlopen(req, timeout=15) as resp:
            data = json.loads(resp.read().decode("utf-8"))
    except Exception:
        return []
papers = []
for p in data.get("data", []):
title = (p.get("title") or "").strip()
if not title:
continue
papers.append({
"title": title,
"abstract": (p.get("abstract") or "").strip()[:1000],
"year": p.get("year"),
"citations": p.get("citationCount", 0),
"url": p.get("url", ""),
"arxiv_id": (p.get("externalIds") or {}).get("ArXiv", ""),
"doi": (p.get("externalIds") or {}).get("DOI", ""),
})
s2_cache.set(cache_key, papers)
return papers
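
# cache.py is owned by another module; a minimal sketch of the s2_cache
# interface assumed above (the real implementation may differ):
#
#   import time
#
#   class TTLCache:
#       """In-memory cache whose entries expire after ttl_seconds."""
#       def __init__(self, ttl_seconds: float = 1800.0):  # 30 minutes
#           self.ttl = ttl_seconds
#           self._store = {}  # key -> (expiry_timestamp, value)
#
#       def get(self, key):
#           hit = self._store.get(key)
#           if hit is None or hit[0] < time.time():
#               return None  # missing or expired
#           return hit[1]
#
#       def set(self, key, value):
#           self._store[key] = (time.time() + self.ttl, value)
#
#   s2_cache = TTLCache()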
def _find_input_paper_repo(title: str, arxiv_id: str, authors: list[str]) -> dict | None:
    """Search for the input paper's own official code repository.

    Search strategy (in priority order):
    1. GitHub search for the arXiv ID
    2. GitHub search for the paper title (exact match)
    3. GitHub search for the first author's name + core title keywords

    Returns:
        A candidate-repo dict in the standard format if found, else None.
    """
import requests
token = os.getenv("GITHUB_TOKEN", "")
headers = {"Accept": "application/vnd.github.v3+json"}
if token:
headers["Authorization"] = f"Bearer {token}"
    # Extract core keywords (first 6 content words of the title)
stops = {
'a', 'an', 'the', 'of', 'for', 'in', 'on', 'to', 'and', 'or', 'is', 'are',
'we', 'our', 'that', 'this', 'with', 'from', 'by', 'as', 'at', 'be', 'it',
'its', 'not', 'can', 'has', 'have', 'been', 'was', 'were', 'will', 'would',
'could', 'should', 'may', 'do', 'does', 'did', 'so', 'if', 'no', 'new',
'based', 'using', 'which', 'into', 'such', 'than', 'then', 'these', 'those',
'propose', 'present', 'method', 'approach', 'framework', 'novel', 'model',
'learning', 'deep', 'via', 'et', 'al', 'towards', 'toward',
}
title_words = [w for w in re.sub(r'[^\w\s-]', ' ', title).split()
if w.lower() not in stops and len(w) >= 3]
core_keywords = " ".join(title_words[:6])
search_queries = [
f'"{arxiv_id}" in:name,description,readme',
f'"{core_keywords}" in:name,description stars:>=3',
]
    # Search by the first author's last name + keywords
first_author = authors[0] if authors else ""
if first_author:
last_name = first_author.split()[-1] if first_author.split() else ""
if len(last_name) >= 3:
search_queries.append(f"{last_name} {core_keywords[:80]} in:name,description")
for query in search_queries:
url = "https://api.github.com/search/repositories"
params = {"q": query, "sort": "stars", "order": "desc", "per_page": 5}
try:
resp = requests.get(url, headers=headers, params=params, timeout=15)
if resp.status_code in (403, 429):
continue
resp.raise_for_status()
data = resp.json()
except Exception:
continue
        # Filter: name/description must contain at least 2 of the core keywords
for item in data.get("items", []):
repo_title = (item.get("description") or "").lower()
repo_name = item.get("full_name", "").lower()
combined = f"{repo_name} {repo_title}"
matches = sum(1 for kw in title_words[:6] if kw.lower() in combined)
if matches >= 2:
return {
"full_name": item["full_name"],
"html_url": item["html_url"],
"description": item.get("description", ""),
"stars": item.get("stargazers_count", 0),
"language": item.get("language", ""),
"updated_at": item.get("updated_at", ""),
"topics": item.get("topics", []),
"match_keyword": f"本文代码: {query[:60]}",
}
return None
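
# Illustrative queries built above for arXiv:2011.08785 ("PaDiM: a Patch
# Distribution Modeling Framework for Anomaly Detection and Localization",
# first author Defard) -- shown for illustration only:
#   '"2011.08785" in:name,description,readme'
#   '"PaDiM Patch Distribution Modeling Anomaly Detection" in:name,description stars:>=3'
#   'Defard PaDiM Patch Distribution Modeling Anomaly Detection in:name,description'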
def run(arxiv_url: str, top_n: int = 5, progress=None) -> dict:
    """Main entry point: search GitHub first, derive method families from the
    actual repositories found, then evaluate the top candidates.

    Args:
        arxiv_url: arXiv paper URL
        top_n: number of repositories to evaluate in the final stage
        progress: gr.Progress instance or callable(fraction, desc) driving the UI progress bar

    Returns:
        dict: {"paper": {...}, "direction": {...}, "repos": [{...}], "error": "..."}
    """
    t_start = time.time()

    def _prog(frac: float, desc: str):
        if progress:
            progress(frac, desc=desc)

    # ================================================================
    # [1/6] Fetch paper metadata (A's Workflow)
    # ================================================================
    print("=" * 60)
    print("[1/6] Fetching paper metadata...")
    _prog(0.05, "Fetching paper metadata from arXiv...")
    t0 = time.time()
    try:
        paper = fetch_paper_info(arxiv_url)
    except Exception as e:
        return {"error": f"Failed to fetch paper metadata: {e}"}
    title = paper.get("title", "")
    abstract = paper.get("abstract", "")
    # arxiv_id is needed by both step [1.2] and step [1.5] below
    arxiv_id = paper.get("arxiv_id", "")
    print(f"  Title: {title[:100]}...")
    print(f"  Authors: {', '.join(paper.get('authors', [])[:3])}")
    print(f"  Categories: {', '.join(paper.get('categories', []))}")
    print(f"  Elapsed: {time.time() - t0:.1f}s")
    # ================================================================
    # [1.2] Search for the input paper's official code
    # ================================================================
    print()
    print("[1.2] Searching for the paper's own official code...")
    _prog(0.10, "Searching for the paper's official code...")
    t0 = time.time()
    input_paper_repo = _find_input_paper_repo(title, arxiv_id, paper.get("authors", []))
    if input_paper_repo:
        print(f"  Found official code: {input_paper_repo['full_name']} (Stars: {input_paper_repo.get('stars', 0)})")
    else:
        print("  No official code found for this paper")
    print(f"  Elapsed: {time.time() - t0:.1f}s")
    # ================================================================
    # [1.5] Mine comparison algorithms (from the S2 citation graph)
    # ================================================================
    print()
    print("[1.5] Mining comparison algorithms from the paper's citation graph...")
    _prog(0.12, "Mining comparison algorithms from the citation graph...")
    t0 = time.time()
    comparison_queries, citation_map = _mine_comparison_algorithms(arxiv_id, title, abstract)
    if comparison_queries:
        print(f"  Generated {len(comparison_queries)} baseline search queries:")
        for q in comparison_queries[:5]:
            print(f"    - {q}")
    print(f"  Elapsed: {time.time() - t0:.1f}s")
    # ================================================================
    # [1.8] S2 paper search (all engineering fields, incl. IEEE/ASME journals)
    # ================================================================
    print()
    print("[1.8] Searching Semantic Scholar for papers in the same field...")
    _prog(0.18, "Searching Semantic Scholar for related papers...")
    t0 = time.time()
    title_kws = _extract_title_keywords(title)
    s2_papers = _search_s2_papers(" ".join(title_kws), limit=8)
    # Derive extra search queries from the S2 paper titles
    s2_extra_queries = []
    if s2_papers:
        print(f"  Found {len(s2_papers)} related papers:")
        for p in s2_papers[:5]:
            cits = p.get("citations", 0)
            print(f"    - [{cits} cites] {p['title'][:80]}")
            # Use highly cited paper titles as search queries
            if cits >= 50:
                s2_quoted = p['title'].strip()[:80]
                # Strip special characters
                s2_clean = re.sub(r'[^\w\s-]', '', s2_quoted).strip()
                if len(s2_clean) > 15:
                    s2_extra_queries.append(f"{s2_clean} github")
        print(f"  Derived {len(s2_extra_queries)} extra queries from S2 results")
    print(f"  Elapsed: {time.time() - t0:.1f}s")
    # ================================================================
    # [2/6] Broad GitHub search (A's Workflow)
    # ================================================================
    print()
    print("[2/6] Broad GitHub search (search repos first, then let the Agents analyze)...")
    _prog(0.22, "Searching GitHub for open-source repositories...")
    t0 = time.time()
    broad_queries = _extract_broad_queries(title, abstract)
    # Merge in the baseline queries, deduplicating while preserving order
    has_github_token = bool(os.getenv("GITHUB_TOKEN", ""))
    max_queries = 15 if has_github_token else 8
    all_queries = list(dict.fromkeys(s2_extra_queries + comparison_queries + broad_queries))[:max_queries]
    if not has_github_token:
        print(f"  ⚠️ GITHUB_TOKEN is not set; capping at {max_queries} queries to avoid rate limiting")
    print(f"  LLM: {len(broad_queries)} + S2 refs: {len(comparison_queries)} + S2 papers: {len(s2_extra_queries)} = {len(all_queries)} total queries:")
    for q in all_queries:
        print(f"    - {q}")
    try:
        broad_results = search_repos(all_queries, max_per_keyword=5)
    except Exception as e:
        return {
            "paper": paper,
            "error": f"GitHub search failed: {e}",
        }
    print(f"  Got {len(broad_results)} candidate repositories after deduplication")
    print(f"  Elapsed: {time.time() - t0:.1f}s")
    if not broad_results:
        return {
            "paper": paper,
            "direction": {},
            "repos": [],
            "error": "No related open-source repositories found. The direction may be too new or too niche to have high-quality open-source implementations yet.",
        }
    # ================================================================
    # [2.5] LLM filtering of irrelevant repositories
    # ================================================================
    print()
    print("[2.5] Filtering irrelevant repositories with the LLM...")
    _prog(0.32, "Filtering irrelevant repositories with the LLM...")
    t0 = time.time()
    filtered_results = _filter_repos(title, abstract, broad_results)
    print(f"  Elapsed: {time.time() - t0:.1f}s")
    # ================================================================
    # [2.6] Domain-context enrichment (S2 survey search for background knowledge)
    # ================================================================
    print()
    print("[2.6] Enriching domain context (searching for related surveys)...")
    _prog(0.38, "Enriching domain context...")
    t0 = time.time()
    domain_context = _enrich_domain_context(title, abstract, paper.get("categories", []))
    if domain_context:
        print(f"  Retrieved abstracts from {domain_context.count('**') // 2} related surveys")
    else:
        print("  No related surveys found (analysis will rely on the abstract and repo data only)")
    print(f"  Elapsed: {time.time() - t0:.1f}s")
    # ================================================================
    # [3/6] Derive method families from repository data (Agent 1)
    # ================================================================
    print()
    print(f"[3/6] Analyzing {len(filtered_results)} repositories to derive method families (Agent 1)...")
    _prog(0.42, "Deriving method families with the LLM...")
    t0 = time.time()
    try:
        direction = analyze_direction(title, abstract, filtered_results, domain_context)
    except Exception as e:
        print(f"  [WARN] Agent 1 direction analysis failed, degrading to a basic analysis: {e}")
        direction = _make_fallback_direction(title, abstract, filtered_results)
    subfield = direction.get("subfield", "unknown")
    families = direction.get("method_families", [])
    print(f"  Subfield: {subfield}")
    print(f"  Trend: {direction.get('subfield_trend', '')[:80]}...")
    print(f"  Method families ({len(families)}):")
    for mf in families:
        matched = mf.get("matched_repos", [])
        print(f"    - {mf.get('family_name', '?')}: {len(matched)} repos {matched}")
    # Anti-confusion check: does the subfield relate semantically to the paper title?
    _sanity_check_direction(title, subfield)
    print(f"  Elapsed: {time.time() - t0:.1f}s")
    # ================================================================
    # [4/6] Select repositories + build the method-family membership map
    # ================================================================
    print()
    print("[4/6] Selecting repositories and fetching details...")
    _prog(0.55, "Selecting repositories and fetching details...")
    t0 = time.time()
    # Build a full_name -> family_name map from Agent 1's matched_repos
    repo_family_map = {}
    for mf in families:
        family_name = mf.get("family_name", "")
        for repo_name in mf.get("matched_repos", []):
            repo_family_map[repo_name] = family_name
    # Select from filtered_results: prefer repos with a family, then sort by stars
    classified = [r for r in filtered_results if r["full_name"] in repo_family_map]
    unclassified = [r for r in filtered_results if r["full_name"] not in repo_family_map]
    classified.sort(key=lambda r: r.get("stars", 0), reverse=True)
    unclassified.sort(key=lambda r: r.get("stars", 0), reverse=True)
    # The paper's own code always goes first (and is not listed twice if the
    # broad search also found it)
    if input_paper_repo:
        rest = [r for r in classified + unclassified
                if r["full_name"] != input_paper_repo["full_name"]]
        candidates = [input_paper_repo] + rest[:top_n - 1]
        repo_family_map[input_paper_repo["full_name"]] = "This paper's code"
    else:
        candidates = (classified + unclassified)[:top_n]
    print(f"  With a method family: {len(classified)}, unclassified: {len(unclassified)}")
    if input_paper_repo:
        print(f"  Including the paper's own code: {input_paper_repo['full_name']}")
    print(f"  Final selection of {len(candidates)} repositories:")
    for i, c in enumerate(candidates):
        family = repo_family_map.get(c["full_name"], "unclassified")
        print(f"    {i+1}. {c['full_name']} [{family}] Stars:{c.get('stars', 0)}")
    print(f"  Elapsed: {time.time() - t0:.1f}s")
    # ================================================================
    # [5/6] Fetch repository details + evaluate (Agent 2, in parallel)
    # ================================================================
    print()
    print(f"[5/6] Fetching details and evaluating (Agent 2, {len(candidates)} repos in parallel)...")
    _prog(0.60, f"Evaluating {len(candidates)} open-source repositories...")
    t0 = time.time()

    def _eval_single(repo, idx, total):
        """Evaluate a single repository (runs in a worker thread)."""
        full_name = repo["full_name"]
        try:
            owner, name = full_name.split("/", 1)
        except ValueError:
            return idx, None
        matched_family = repo_family_map.get(full_name, "")
        family_tag = f"[{matched_family}]" if matched_family else "[unclassified]"
        print(f"  ({idx+1}/{total}) {full_name} {family_tag}...")
        try:
            readme = fetch_readme(owner, name)
            deps = fetch_dependencies(owner, name)
            evaluation = evaluate_repo(repo, readme, deps, matched_family)
        except Exception as e:
            print(f"  [WARN] Evaluation of {full_name} failed: {e}")
            evaluation = _make_error_evaluation(str(e))
        return idx, {
            "full_name": repo.get("full_name", ""),
            "html_url": repo.get("html_url", ""),
            "description": repo.get("description", ""),
            "stars": repo.get("stars", 0),
            "language": repo.get("language", ""),
            "updated_at": repo.get("updated_at", ""),
            "topics": repo.get("topics", []),
            "match_keyword": repo.get("match_keyword", ""),
            "method_family": matched_family,
            "evaluation": evaluation,
        }

    n = len(candidates)
    evaluated = [None] * n
    with ThreadPoolExecutor(max_workers=min(n, 5)) as executor:
        futures = {
            executor.submit(_eval_single, repo, i, n): i
            for i, repo in enumerate(candidates)
        }
        for future in as_completed(futures):
            idx, result = future.result()
            if result is not None:
                evaluated[idx] = result
    evaluated = [r for r in evaluated if r is not None]
    # Sort by overall score, descending
    evaluated.sort(key=lambda r: r["evaluation"].get("overall_score", 0), reverse=True)
    print(f"  Elapsed: {time.time() - t0:.1f}s")
    # ================================================================
    # Summary
    # ================================================================
    print()
    print("=" * 60)
    print("Done!")
    best = evaluated[0] if evaluated else None
    if best:
        print(f"  Top score: {best['full_name']} ({best['evaluation'].get('overall_score', 0)}/100)")
    # ================================================================
    # Supervision layer: check Agent output quality and source reliability
    # ================================================================
    _prog(0.90, "Running the quality audit...")
    audit = supervise(title, abstract, direction, evaluated)
    print(f"\n  Audit result: {audit['summary']}")
    print(f"  Overall quality score: {audit['overall_score']}/100")
    if audit["actions"]:
        for action in audit["actions"]:
            print(f"    → {action}")
    print(f"  Total elapsed: {time.time() - t_start:.1f}s")
    print("=" * 60)
    _prog(1.0, "Analysis complete, generating the report...")
    return {
        "paper": paper,
        "direction": direction,
        "repos": evaluated,
        "audit": audit,
    }
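
# Illustrative programmatic use (from Python rather than the CLI):
#
#   result = run("https://arxiv.org/abs/2011.08785", top_n=5)
#   if not result.get("error"):
#       for repo in result["repos"]:
#           print(repo["full_name"], repo["evaluation"].get("overall_score", 0))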
def _extract_broad_queries(title: str, abstract: str) -> list[str]:
    """Generate GitHub search queries from the title and abstract with an LLM.

    Compared with rule-based extraction, an LLM that has read the paper produces
    more precise domain queries. Falls back to rule-based extraction if the LLM
    call fails.
    """
    system_prompt = """You are an academic code-search expert. Given a paper title and abstract, generate 6-8 GitHub search queries for finding open-source implementations in that research direction.
Requirements:
- queries must be in English, keywords separated by spaces
- include method/technique keywords plus a qualifier (implementation / pytorch / code)
- include 1-2 broad domain queries (e.g. "anomaly detection pytorch library")
- no quotes or special symbols
- avoid overly generic single words (e.g. just "anomaly")
Output strict JSON:
{"queries": ["query1", "query2", ...]}"""
    user_prompt = f"""Paper title: {title[:300]}
Abstract: {abstract[:800]}
Generate GitHub search queries for this paper."""
    try:
        raw = call_llm_json(system_prompt, user_prompt, temperature=0.3, max_tokens=3000)
        data = parse_json_safe(raw, "broad_queries")
        queries = data.get("queries", [])
        if isinstance(queries, list) and len(queries) >= 3:
            return queries[:10]
    except Exception as e:
        print(f"  [WARN] LLM query generation failed, falling back to rules: {e}")
    return _extract_broad_queries_fallback(title, abstract)
def _extract_broad_queries_fallback(title: str, abstract: str) -> list[str]:
    """Extract broad search queries from the title and abstract (rule-based fallback).

    Strategy: pull meaningful phrases plus search qualifiers, sampled for
    diversity (not biased toward the title prefix). No LLM involved.
    """
abstract_first = abstract.split(".")[0] if abstract else ""
text = f"{title} {abstract_first}".lower()
text = re.sub(r'[^\w\s-]', ' ', text)
words = text.split()
stop_words = {
'a', 'an', 'the', 'of', 'for', 'in', 'on', 'to', 'and', 'or', 'is', 'are',
'we', 'our', 'that', 'this', 'with', 'from', 'by', 'as', 'at', 'be', 'it',
'its', 'not', 'can', 'has', 'have', 'been', 'was', 'were', 'will', 'would',
'could', 'should', 'may', 'do', 'does', 'did', 'so', 'if', 'no', 'new',
'based', 'using', 'which', 'into', 'such', 'than', 'then', 'these', 'those',
'propose', 'present', 'method', 'approach', 'framework', 'novel', 'model',
        'learning', 'deep', 'via', 'et', 'al', 'all', 'you', 'they',
}
meaningful = [w for w in words if w not in stop_words and len(w) > 1]
bigrams = [f"{meaningful[i]} {meaningful[i+1]}" for i in range(len(meaningful) - 1)]
trigrams = [f"{meaningful[i]} {meaningful[i+1]} {meaningful[i+2]}" for i in range(len(meaningful) - 2)]
def _sample_diverse(items: list[str], n: int) -> list[str]:
if len(items) <= n:
return items
step = max(1, len(items) // n)
return [items[i] for i in range(0, len(items), step)][:n]
phrases = _sample_diverse(trigrams, 3) + _sample_diverse(bigrams, 4)
seen = set()
unique_phrases = []
for p in sorted(phrases, key=len, reverse=True):
if p not in seen:
seen.add(p)
unique_phrases.append(p)
qualifiers = [
"implementation pytorch",
"official code",
"pytorch github",
]
queries = []
for phrase in unique_phrases[:5]:
if len(phrase) > 5:
queries.append(f"{phrase} {qualifiers[0]}")
mid = len(meaningful) // 2
broad = " ".join(meaningful[max(0, mid-2):mid+2])
if len(broad) > 10:
queries.append(f"{broad} {qualifiers[2]}")
seen = set()
unique = []
for q in queries:
if q not in seen:
seen.add(q)
unique.append(q)
return unique[:10]
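
# Illustrative fallback output (invented) for a PaDiM-style title and abstract:
#   ["patch distribution modeling implementation pytorch",
#    "anomaly detection localization implementation pytorch",
#    "distribution modeling anomaly pytorch github"]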
def _filter_repos(title: str, abstract: str, repos: list[dict]) -> list[dict]:
    """Quickly filter out irrelevant repositories with the LLM, keeping only
    those related to the paper's direction.

    Runs after the GitHub keyword search and before the Agent 1 analysis.
    If the filter empties the list or the LLM call fails, the original list is
    returned as the fallback.
    """
    if len(repos) <= 3:
        return repos
    # With more than 20 repos, keep only the top 20 by stars to bound JSON output size
    if len(repos) > 20:
        repos_for_filter = sorted(repos, key=lambda r: r.get("stars", 0), reverse=True)[:20]
    else:
        repos_for_filter = repos
    # Build an abridged repo listing for the LLM to judge
    repo_list_parts = []
    for i, r in enumerate(repos_for_filter):
        desc = (r.get("description") or "")[:120]
        topics = ", ".join(r.get("topics", [])[:8])
        repo_list_parts.append(
            f"{i+1}. {r['full_name']} (Stars:{r.get('stars',0)})\n"
            f"   Description: {desc}\n"
            f"   Topics: {topics}"
        )
    repo_list_text = "\n".join(repo_list_parts)
    system_prompt = """You are an expert at matching academic papers with open-source code. Decide whether each GitHub repository belongs to the same research direction as the given paper.
Exclude a repository if any of these hold:
- it solves a completely different application problem (e.g. the paper does industrial defect detection while the repo does medical imaging / autonomous driving / face recognition)
- its core technical approach is unrelated to the paper's
- it is a generic tutorial / coursework / interview-question collection
Keep a repository if any of these hold:
- it implements a method in the same technical paradigm as the paper
- it could serve as a baseline/comparison in the paper's experiments
- it is a well-known benchmarking library in the field
Output strict JSON (no reasoning field, to save tokens):
{"relevant": ["owner/repo1", "owner/repo2"], "irrelevant": ["owner/repo3"]}"""
    user_prompt = f"""## Paper
Title: {title[:200]}
Abstract: {abstract[:500]}
## GitHub repositories ({len(repos_for_filter)} total)
{repo_list_text}
Judge whether each repository is relevant to the paper's direction."""
    try:
        raw = call_llm_json(system_prompt, user_prompt, temperature=0.1, max_tokens=10000)
        data = parse_json_safe(raw, "filter_repos")
    except Exception as e:
        print(f"  [WARN] Repository filtering failed, keeping all: {e}")
        return repos
    relevant_set = set(data.get("relevant", []))
    irrelevant = data.get("irrelevant", [])
    if irrelevant:
        print(f"  [filter] Dropped {len(irrelevant)} irrelevant repositories: {irrelevant}")
    filtered = [r for r in repos if r["full_name"] in relevant_set]
    if len(filtered) < 2:
        print(f"  [WARN] Only {len(filtered)} repositories left after filtering, reverting to the original list")
        return repos
    print(f"  Kept {len(filtered)}/{len(repos)} repositories after filtering")
    return filtered
def _sanity_check_direction(title: str, subfield: str) -> None:
    """Anti-confusion check: verify the subfield shares at least minimal
    semantic overlap with the paper title.

    If the subfield keywords are completely unrelated to the title, the LLM may
    have mixed the paper up with another one. This only prints a warning and
    never blocks the pipeline.
    """
    # Content words from the title (length >= 4, stop words excluded)
    title_stops = {
        'a', 'an', 'the', 'of', 'for', 'in', 'on', 'to', 'and', 'or', 'is', 'are',
        'we', 'our', 'that', 'this', 'with', 'from', 'by', 'as', 'at', 'be', 'it',
        'its', 'not', 'can', 'has', 'have', 'been', 'was', 'were', 'will', 'would',
        'could', 'should', 'may', 'do', 'does', 'did', 'so', 'if', 'no', 'new',
        'based', 'using', 'which', 'into', 'such', 'than', 'then', 'these', 'those',
        'propose', 'present', 'method', 'approach', 'framework', 'novel', 'model',
        'learning', 'deep', 'via', 'et', 'al',
    }
    title_words = set(
        w.lower() for w in re.sub(r'[^\w\s]', ' ', title).split()
        if len(w) >= 4 and w.lower() not in title_stops
    )
    subfield_lower = subfield.lower()
    overlap = [w for w in title_words if w in subfield_lower]
    if not overlap and title_words:
        print(f"  ⚠️ [anti-confusion warning] Subfield \"{subfield}\" shares no keywords with the paper title")
        print(f"     Title keywords: {sorted(title_words)[:10]}")
        print("     The LLM may have confused the paper; please verify the analysis manually.")
def _make_fallback_direction(title: str, abstract: str, repos: list[dict]) -> dict:
    """Degraded direction analysis for when Agent 1 fails: infer the subfield
    from the repositories that were found."""
    # Use high-frequency words from repo topics/descriptions as the subfield
    all_words = []
    for r in repos[:10]:
        desc = (r.get("description") or "")
        topics = " ".join(r.get("topics", []))
        all_words.extend((desc + " " + topics).lower().split())
    stops = {'a', 'an', 'the', 'of', 'for', 'in', 'on', 'to', 'and', 'or', 'is', 'are',
             'we', 'our', 'that', 'this', 'with', 'from', 'by', 'as', 'at', 'be', 'it'}
    meaningful = [w for w in all_words if w not in stops and len(w) >= 4]
    word_freq = {}
    for w in meaningful:
        word_freq[w] = word_freq.get(w, 0) + 1
    top_words = sorted(word_freq, key=word_freq.get, reverse=True)[:6]
    return {
        "subfield": f"Inferred from repository data: {', '.join(top_words[:3])}" if top_words else "unknown field",
        "subfield_trend": "(Agent 1 is unavailable; trend analysis skipped. It will resume automatically once Agent 1 is back.)",
        "method_families": [],
        "broad_queries": [],
    }
def _make_error_evaluation(error_msg: str) -> dict:
    """Build an evaluation dict representing a failed evaluation."""
    return {
        "reproducibility_score": 0,
        "benchmark_fitness_score": 0,
        "overall_score": 0,
        "verdict": "error",
        "env_score": 0,
        "doc_score": 0,
        "code_score": 0,
        "community_score": 0,
        "dep_score": 0,
        "benchmark_score": 0,
        "reasoning": f"Evaluation failed: {error_msg[:100]}",
        "risks": ["The evaluation errored out; inspect this repository manually"],
        "benchmark_readiness": "not_ready",
        "suggested_use": "Evaluation failed; inspect manually",
    }
# ============================================================
# Command-line entry point
# ============================================================
if __name__ == "__main__":
    # Work around emoji encoding issues in Windows terminals
    from llm_utils import fix_windows_encoding
    fix_windows_encoding()
    if len(sys.argv) < 2:
        print("Usage:   python run.py <arxiv_url>")
        print("Example: python run.py https://arxiv.org/abs/1706.03762")
        print("Example: python run.py https://arxiv.org/abs/2011.08785")
        sys.exit(1)
    url = sys.argv[1]
    result = run(url)
    if result.get("error"):
        print(f"\n{'='*60}")
        print("The run did not fully succeed")
        print(f"{'='*60}")
        print(f"  {result['error']}")
        if result.get("paper"):
            print(f"\n  Paper metadata fetched: {result['paper'].get('title', '')[:80]}")
        if result.get("direction"):
            print(f"  Direction analysis finished: {result['direction'].get('subfield', '')}")
    else:
        from app import format_report
        # Round-trip through the terminal encoding to avoid emoji mojibake
        report = format_report(result)
        enc = sys.stdout.encoding or "utf-8"
        print(report.encode(enc, errors="replace").decode(enc, errors="replace"))