ZZZyx3587 committed on
Commit
fa671bd
·
verified ·
1 Parent(s): 288ba88

Upload run.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. run.py +202 -4
run.py CHANGED
@@ -9,6 +9,10 @@ import sys
9
  import time
10
  import re
11
  import os
 
 
 
 
12
  from concurrent.futures import ThreadPoolExecutor, as_completed
13
 
14
  # 自动加载 .env 文件
@@ -52,6 +56,140 @@ from repo_evaluator import evaluate_repo
52
  from llm_utils import call_llm_json, parse_json_safe
53
 
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  def run(arxiv_url: str, top_n: int = 5) -> dict:
56
  """主入口:先搜索 GitHub → 再基于实际仓库归纳方法族 → 最后评估。
57
 
@@ -88,6 +226,20 @@ def run(arxiv_url: str, top_n: int = 5) -> dict:
88
  print(f" 分类: {', '.join(paper.get('categories', []))}")
89
  print(f" 耗时: {time.time() - t0:.1f}s")
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  # ================================================================
92
  # [2/5] 宽泛 GitHub 搜索(甲 Workflow)
93
  # ================================================================
@@ -96,12 +248,14 @@ def run(arxiv_url: str, top_n: int = 5) -> dict:
96
  t0 = time.time()
97
 
98
  broad_queries = _extract_broad_queries(title, abstract)
99
- print(f" 生成 {len(broad_queries)} 个宽泛搜索词:")
100
- for q in broad_queries:
 
 
101
  print(f" - {q}")
102
 
103
  try:
104
- broad_results = search_repos(broad_queries, max_per_keyword=5)
105
  except Exception as e:
106
  return {
107
  "paper": paper,
@@ -128,6 +282,19 @@ def run(arxiv_url: str, top_n: int = 5) -> dict:
128
  filtered_results = _filter_repos(title, abstract, broad_results)
129
  print(f" 耗时: {time.time() - t0:.1f}s")
130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  # ================================================================
132
  # [3/5] 基于仓库数据归纳方法族(Agent 1) ← 核心改动
133
  # ================================================================
@@ -136,7 +303,7 @@ def run(arxiv_url: str, top_n: int = 5) -> dict:
136
  t0 = time.time()
137
 
138
  try:
139
- direction = analyze_direction(title, abstract, filtered_results)
140
  except Exception as e:
141
  return {
142
  "paper": paper,
@@ -151,6 +318,9 @@ def run(arxiv_url: str, top_n: int = 5) -> dict:
151
  for mf in families:
152
  matched = mf.get("matched_repos", [])
153
  print(f" - {mf.get('family_name', '?')}: {len(matched)} 个仓库 {matched}")
 
 
 
154
  print(f" 耗时: {time.time() - t0:.1f}s")
155
 
156
  # ================================================================
@@ -435,6 +605,34 @@ def _filter_repos(title: str, abstract: str, repos: list[dict]) -> list[dict]:
435
  return filtered
436
 
437
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
438
  def _make_error_evaluation(error_msg: str) -> dict:
439
  """构造一个表示评估失败的 evaluation dict"""
440
  return {
 
9
  import time
10
  import re
11
  import os
12
+ import json
13
+ import urllib.request
14
+ import urllib.error
15
+ import xml.etree.ElementTree as ET
16
  from concurrent.futures import ThreadPoolExecutor, as_completed
17
 
18
  # 自动加载 .env 文件
 
56
  from llm_utils import call_llm_json, parse_json_safe
57
 
58
 
59
+ def _enrich_domain_context(title: str, abstract: str, categories: list[str]) -> str:
60
+ """从 arxiv 搜索同领域的综述和相关高引论文,扩充领域上下文。
61
+
62
+ 当输入论文不是综述时,普通研究论文的摘要不足以让 Agent 1 推断领域全景。
63
+ 此函数从 arxiv 搜索相关综述/survey 论文,提取摘要作为补充上下文。
64
+ 失败时返回空字符串(静默降级)。
65
+ """
66
+ if not categories:
67
+ return ""
68
+
69
+ primary_cat = categories[0]
70
+ # 用主要分类 + 关键词搜索综述/survey 论文
71
+ keywords = ["survey", "review", "comprehensive", "benchmark"]
72
+ all_abstracts = []
73
+
74
+ for kw in keywords[:2]: # 只搜两个关键词,避免触发 arxiv 限速
75
+ search_query = f"cat:{primary_cat}+AND+ti:{kw}"
76
+ api_url = (
77
+ f"http://export.arxiv.org/api/query"
78
+ f"?search_query={search_query}&sortBy=relevance&max_results=2"
79
+ )
80
+ req = urllib.request.Request(api_url, headers={"User-Agent": "ResearchRadar/1.0"})
81
+ try:
82
+ with urllib.request.urlopen(req, timeout=15) as resp:
83
+ xml_text = resp.read().decode("utf-8")
84
+ root = ET.fromstring(xml_text)
85
+ ns = {"atom": "http://www.w3.org/2005/Atom"}
86
+ for entry in root.findall("atom:entry", ns):
87
+ t = entry.find("atom:title", ns)
88
+ s = entry.find("atom:summary", ns)
89
+ if t is not None and t.text and s is not None and s.text:
90
+ all_abstracts.append({
91
+ "title": t.text.strip().replace("\n", " "),
92
+ "abstract": s.text.strip().replace("\n", " ")[:800],
93
+ })
94
+ except Exception:
95
+ continue
96
+
97
+ if not all_abstracts:
98
+ return ""
99
+
100
+ # 组装为上下文文本
101
+ lines = ["## 该领域的相关综述/调查论文(供参考领域全景)"]
102
+ for i, a in enumerate(all_abstracts[:4]):
103
+ lines.append(f"{i+1}. **{a['title']}**: {a['abstract']}")
104
+ return "\n\n".join(lines)
105
+
106
+
107
+ def _mine_comparison_algorithms(arxiv_id: str, title: str, abstract: str) -> tuple[list[str], dict[str, int]]:
108
+ """从 Semantic Scholar 引用网络中挖掘对比实验算法。
109
+
110
+ 核心思路:论文的 references(引用的论文)中通常包含对比实验的 baseline 方法,
111
+ 通过提取这些论文的标题作为额外搜索词,可以大幅提升 GitHub 搜索的覆盖度。
112
+
113
+ Returns:
114
+ (extra_queries, citation_map): 额外搜索词列表, {paper_title: citation_count}
115
+ """
116
+ s2_url = (
117
+ f"https://api.semanticscholar.org/graph/v1/paper/ArXiv:{arxiv_id}"
118
+ f"?fields=references.title,references.citationCount,references.abstract"
119
+ f"&limit=50"
120
+ )
121
+ req = urllib.request.Request(s2_url, headers={"User-Agent": "ResearchRadar/1.0"})
122
+ try:
123
+ with urllib.request.urlopen(req, timeout=20) as resp:
124
+ data = json.loads(resp.read().decode("utf-8"))
125
+ except Exception as e:
126
+ print(f" [WARN] Semantic Scholar API 不可用,跳过对比算法挖掘: {e}")
127
+ return [], {}
128
+
129
+ refs = data.get("references", [])
130
+ if not refs:
131
+ return [], {}
132
+
133
+ # 按引用量排序,取 top 15
134
+ refs.sort(key=lambda r: r.get("citationCount", 0), reverse=True)
135
+ top_refs = refs[:15]
136
+
137
+ # 构建标题列表供 LLM 识别方法论文
138
+ title_list = []
139
+ citation_map = {}
140
+ for r in top_refs:
141
+ t = (r.get("title") or "").strip()
142
+ cc = r.get("citationCount", 0)
143
+ if t and len(t) > 10:
144
+ title_list.append(f"- [{cc} cites] {t}")
145
+ citation_map[t] = cc
146
+
147
+ if len(title_list) < 3:
148
+ return [], citation_map
149
+
150
+ # 用 LLM 从引用论文标题中识别哪些是方法/算法论文
151
+ system_prompt = """你是学术论文分析专家。从引用论文列表中识别哪些是提出了具体算法/方法的论文。
152
+
153
+ 排除标准:
154
+ - 数据集/benchmark 论文(如 ImageNet, CIFAR, MVTec AD)
155
+ - 综述/survey 论文
156
+ - 纯理论/数学论文
157
+ - 框架/库论文(如 PyTorch, TensorFlow)
158
+
159
+ 保留标准:
160
+ - 提出了具体的模型/架构/算法名称
161
+ - 可以作为对比实验的 baseline 方法
162
+
163
+ 输出严格 JSON:
164
+ {"methods": ["方法名1", "方法名2"], "search_queries": ["method1 pytorch implementation", "method2 official code"]}
165
+
166
+ 方法名尽量使用论文中常用的英文缩写或全称。"""
167
+
168
+ user_prompt = f"""输入论文标题: {title[:200]}
169
+
170
+ 引用论文列表:
171
+ {chr(10).join(title_list)}
172
+
173
+ 请识别哪些是方法/算法论文,生成对应的 GitHub 搜索词。"""
174
+
175
+ try:
176
+ raw = call_llm_json(system_prompt, user_prompt, temperature=0.2, max_tokens=800)
177
+ data = parse_json_safe(raw, "comparison_miner")
178
+ methods = data.get("methods", [])
179
+ queries = data.get("search_queries", [])
180
+ print(f" 从引用网络识别到 {len(methods)} 个对比算法: {methods[:8]}")
181
+ return queries[:8], citation_map
182
+ except Exception as e:
183
+ print(f" [WARN] 对比算法识别失败,降级: {e}")
184
+ # 降级:直接用引用论文标题作为搜索词
185
+ fallback_queries = []
186
+ for t in list(citation_map.keys())[:5]:
187
+ short = re.sub(r'[^\w\s-]', '', t).strip()[:80]
188
+ if len(short) > 15:
189
+ fallback_queries.append(f"{short} pytorch implementation")
190
+ return fallback_queries, citation_map
191
+
192
+
193
  def run(arxiv_url: str, top_n: int = 5) -> dict:
194
  """主入口:先搜索 GitHub → 再基于实际仓库归纳方法族 → 最后评估。
195
 
 
226
  print(f" 分类: {', '.join(paper.get('categories', []))}")
227
  print(f" 耗时: {time.time() - t0:.1f}s")
228
 
229
+ # ================================================================
230
+ # [1.5] 对比实验算法挖掘(从 S2 引用网络提取算法名)
231
+ # ================================================================
232
+ print()
233
+ print("[1.5] 正在从论文引用网络挖掘对比实验算法...")
234
+ t0 = time.time()
235
+ arxiv_id = paper.get("arxiv_id", "")
236
+ comparison_queries, citation_map = _mine_comparison_algorithms(arxiv_id, title, abstract)
237
+ if comparison_queries:
238
+ print(f" 生成 {len(comparison_queries)} 个对比算法搜索词:")
239
+ for q in comparison_queries[:5]:
240
+ print(f" - {q}")
241
+ print(f" 耗时: {time.time() - t0:.1f}s")
242
+
243
  # ================================================================
244
  # [2/5] 宽泛 GitHub 搜索(甲 Workflow)
245
  # ================================================================
 
248
  t0 = time.time()
249
 
250
  broad_queries = _extract_broad_queries(title, abstract)
251
+ # 合并对比算法搜索词,去重
252
+ all_queries = list(dict.fromkeys(comparison_queries + broad_queries))[:12]
253
+ print(f" 生成 {len(broad_queries)} 个宽泛搜索词 + {len(comparison_queries)} 个对比算法搜索词 = {len(all_queries)} 个总搜索词:")
254
+ for q in all_queries:
255
  print(f" - {q}")
256
 
257
  try:
258
+ broad_results = search_repos(all_queries, max_per_keyword=5)
259
  except Exception as e:
260
  return {
261
  "paper": paper,
 
282
  filtered_results = _filter_repos(title, abstract, broad_results)
283
  print(f" 耗时: {time.time() - t0:.1f}s")
284
 
285
+ # ================================================================
286
+ # [2.6] 领域上下文扩充(从 arxiv 搜索综述论文补充领域知识)
287
+ # ================================================================
288
+ print()
289
+ print("[2.6] 正在扩充领域上下文(搜索相关综述论文)...")
290
+ t0 = time.time()
291
+ domain_context = _enrich_domain_context(title, abstract, paper.get("categories", []))
292
+ if domain_context:
293
+ print(f" 获取到 {domain_context.count('**') // 2} 篇相关综述的摘要")
294
+ else:
295
+ print(f" 未找到相关综述(将仅基于论文摘要和仓库数据进行分析)")
296
+ print(f" 耗时: {time.time() - t0:.1f}s")
297
+
298
  # ================================================================
299
  # [3/5] 基于仓库数据归纳方法族(Agent 1) ← 核心改动
300
  # ================================================================
 
303
  t0 = time.time()
304
 
305
  try:
306
+ direction = analyze_direction(title, abstract, filtered_results, domain_context)
307
  except Exception as e:
308
  return {
309
  "paper": paper,
 
318
  for mf in families:
319
  matched = mf.get("matched_repos", [])
320
  print(f" - {mf.get('family_name', '?')}: {len(matched)} 个仓库 {matched}")
321
+
322
+ # 防混淆校验:检查子领域是否与论文标题存在语义关联
323
+ _sanity_check_direction(title, subfield)
324
  print(f" 耗时: {time.time() - t0:.1f}s")
325
 
326
  # ================================================================
 
605
  return filtered
606
 
607
 
608
+ def _sanity_check_direction(title: str, subfield: str) -> None:
609
+ """防混淆检查:验证子领域分析是否与论文标题存在最低限度的语义��联。
610
+
611
+ 如果子领域关键词与标题完全无关,可能是 LLM 混淆了论文。
612
+ 仅打印警告,不阻断流程。
613
+ """
614
+ # 提取标题中的实义词(长度>=4,排除停用词)
615
+ title_stops = {
616
+ 'a', 'an', 'the', 'of', 'for', 'in', 'on', 'to', 'and', 'or', 'is', 'are',
617
+ 'we', 'our', 'that', 'this', 'with', 'from', 'by', 'as', 'at', 'be', 'it',
618
+ 'its', 'not', 'can', 'has', 'have', 'been', 'was', 'were', 'will', 'would',
619
+ 'could', 'should', 'may', 'do', 'does', 'did', 'so', 'if', 'no', 'new',
620
+ 'based', 'using', 'which', 'into', 'such', 'than', 'then', 'these', 'those',
621
+ 'propose', 'present', 'method', 'approach', 'framework', 'novel', 'model',
622
+ 'learning', 'deep', 'via', 'et', 'al',
623
+ }
624
+ title_words = set(
625
+ w.lower() for w in re.sub(r'[^\w\s]', ' ', title).split()
626
+ if len(w) >= 4 and w.lower() not in title_stops
627
+ )
628
+ subfield_lower = subfield.lower()
629
+ overlap = [w for w in title_words if w in subfield_lower]
630
+ if not overlap and title_words:
631
+ print(f" ⚠️ [防混淆警告] 子领域\"{subfield}\"与论文标题无关键词重叠")
632
+ print(f" 标题关键词: {sorted(title_words)[:10]}")
633
+ print(f" 这可能是 LLM 混淆了论文,请人工核实分析结果。")
634
+
635
+
636
  def _make_error_evaluation(error_msg: str) -> dict:
637
  """构造一个表示评估失败的 evaluation dict"""
638
  return {