ZZZyx3587 committed on
Commit
151e05c
·
verified ·
1 Parent(s): 4c14732

Upload run.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. run.py +99 -1
run.py CHANGED
@@ -259,6 +259,83 @@ def _search_s2_papers(query: str, limit: int = 5) -> list[dict]:
259
  return papers
260
 
261
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
  def run(arxiv_url: str, top_n: int = 5, progress=None) -> dict:
263
  """主入口:先搜索 GitHub → 再基于实际仓库归纳方法族 → 最后评估。
264
 
@@ -294,6 +371,20 @@ def run(arxiv_url: str, top_n: int = 5, progress=None) -> dict:
294
  print(f" 分类: {', '.join(paper.get('categories', []))}")
295
  print(f" 耗时: {time.time() - t0:.1f}s")
296
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
  # ================================================================
298
  # [1.5] 对比实验算法挖掘(从 S2 引用网络提取算法名)
299
  # ================================================================
@@ -443,9 +534,16 @@ def run(arxiv_url: str, top_n: int = 5, progress=None) -> dict:
443
  classified.sort(key=lambda r: r.get("stars", 0), reverse=True)
444
  unclassified.sort(key=lambda r: r.get("stars", 0), reverse=True)
445
 
446
- candidates = (classified + unclassified)[:top_n]
 
 
 
 
 
447
 
448
  print(f" 有方法族归属: {len(classified)} 个,未归类: {len(unclassified)} 个")
 
 
449
  print(f" 最终选取 {len(candidates)} 个仓库:")
450
  for i, c in enumerate(candidates):
451
  family = repo_family_map.get(c["full_name"], "未归类")
 
259
  return papers
260
 
261
 
262
+ def _find_input_paper_repo(title: str, arxiv_id: str, authors: list[str]) -> dict | None:
263
+ """搜索输入论文自身的官方代码仓库。
264
+
265
+ 搜索策略(按优先级):
266
+ 1. GitHub 搜索 arxiv ID
267
+ 2. GitHub 搜索论文标题(精确匹配)
268
+ 3. GitHub 搜索一作姓名 + 论文核心关键词
269
+
270
+ Returns:
271
+ 找到则返回符合格式的候选仓库 dict,找不到返回 None
272
+ """
273
+ import requests
274
+
275
+ token = os.getenv("GITHUB_TOKEN", "")
276
+ headers = {"Accept": "application/vnd.github.v3+json"}
277
+ if token:
278
+ headers["Authorization"] = f"Bearer {token}"
279
+
280
+ # 提取核心关键词(取标题前 6 个实义词)
281
+ stops = {
282
+ 'a', 'an', 'the', 'of', 'for', 'in', 'on', 'to', 'and', 'or', 'is', 'are',
283
+ 'we', 'our', 'that', 'this', 'with', 'from', 'by', 'as', 'at', 'be', 'it',
284
+ 'its', 'not', 'can', 'has', 'have', 'been', 'was', 'were', 'will', 'would',
285
+ 'could', 'should', 'may', 'do', 'does', 'did', 'so', 'if', 'no', 'new',
286
+ 'based', 'using', 'which', 'into', 'such', 'than', 'then', 'these', 'those',
287
+ 'propose', 'present', 'method', 'approach', 'framework', 'novel', 'model',
288
+ 'learning', 'deep', 'via', 'et', 'al', 'towards', 'toward',
289
+ }
290
+ title_words = [w for w in re.sub(r'[^\w\s-]', ' ', title).split()
291
+ if w.lower() not in stops and len(w) >= 3]
292
+ core_keywords = " ".join(title_words[:6])
293
+
294
+ search_queries = [
295
+ f'"{arxiv_id}" in:name,description,readme',
296
+ f'"{core_keywords}" in:name,description stars:>=3',
297
+ ]
298
+
299
+ # 用一作姓名 + 关键词搜索
300
+ first_author = authors[0] if authors else ""
301
+ if first_author:
302
+ last_name = first_author.split()[-1] if first_author.split() else ""
303
+ if len(last_name) >= 3:
304
+ search_queries.append(f"{last_name} {core_keywords[:80]} in:name,description")
305
+
306
+ for query in search_queries:
307
+ url = "https://api.github.com/search/repositories"
308
+ params = {"q": query, "sort": "stars", "order": "desc", "per_page": 5}
309
+ try:
310
+ resp = requests.get(url, headers=headers, params=params, timeout=15)
311
+ if resp.status_code in (403, 429):
312
+ continue
313
+ resp.raise_for_status()
314
+ data = resp.json()
315
+ except Exception:
316
+ continue
317
+
318
+ # 筛选:标题必须包含核心关键词中至少 2 个
319
+ for item in data.get("items", []):
320
+ repo_title = (item.get("description") or "").lower()
321
+ repo_name = item.get("full_name", "").lower()
322
+ combined = f"{repo_name} {repo_title}"
323
+ matches = sum(1 for kw in title_words[:6] if kw.lower() in combined)
324
+ if matches >= 2:
325
+ return {
326
+ "full_name": item["full_name"],
327
+ "html_url": item["html_url"],
328
+ "description": item.get("description", ""),
329
+ "stars": item.get("stargazers_count", 0),
330
+ "language": item.get("language", ""),
331
+ "updated_at": item.get("updated_at", ""),
332
+ "topics": item.get("topics", []),
333
+ "match_keyword": f"本文代码: {query[:60]}",
334
+ }
335
+
336
+ return None
337
+
338
+
339
  def run(arxiv_url: str, top_n: int = 5, progress=None) -> dict:
340
  """主入口:先搜索 GitHub → 再基于实际仓库归纳方法族 → 最后评估。
341
 
 
371
  print(f" 分类: {', '.join(paper.get('categories', []))}")
372
  print(f" 耗时: {time.time() - t0:.1f}s")
373
 
374
+ # ================================================================
375
+ # [1.2] 搜索输入论文的官方代码
376
+ # ================================================================
377
+ print()
378
+ print("[1.2] 正在搜索输入论文的官方代码...")
379
+ _prog(0.10, "正在搜索论文自身的官方代码...")
380
+ t0 = time.time()
381
+ input_paper_repo = _find_input_paper_repo(title, arxiv_id, paper.get("authors", []))
382
+ if input_paper_repo:
383
+ print(f" 找到论文官方代码: {input_paper_repo['full_name']} (Stars: {input_paper_repo.get('stars', 0)})")
384
+ else:
385
+ print(f" 未找到论文官方代码")
386
+ print(f" 耗时: {time.time() - t0:.1f}s")
387
+
388
  # ================================================================
389
  # [1.5] 对比实验算法挖掘(从 S2 引用网络提取算法名)
390
  # ================================================================
 
534
  classified.sort(key=lambda r: r.get("stars", 0), reverse=True)
535
  unclassified.sort(key=lambda r: r.get("stars", 0), reverse=True)
536
 
537
+ # 论文自身代码优先排在最前面
538
+ if input_paper_repo:
539
+ candidates = [input_paper_repo] + (classified + unclassified)[:top_n - 1]
540
+ repo_family_map[input_paper_repo["full_name"]] = "本文代码"
541
+ else:
542
+ candidates = (classified + unclassified)[:top_n]
543
 
544
  print(f" 有方法族归属: {len(classified)} 个,未归类: {len(unclassified)} 个")
545
+ if input_paper_repo:
546
+ print(f" 含论文自身代码: {input_paper_repo['full_name']}")
547
  print(f" 最终选取 {len(candidates)} 个仓库:")
548
  for i, c in enumerate(candidates):
549
  family = repo_family_map.get(c["full_name"], "未归类")