ZZZyx3587 commited on
Commit
097cbfc
·
verified ·
1 Parent(s): 3b1ebc4

Upload run.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. run.py +27 -18
run.py CHANGED
@@ -251,29 +251,28 @@ def _search_s2_papers(query: str, limit: int = 5) -> list[dict]:
251
  return papers
252
 
253
 
254
- def run(arxiv_url: str, top_n: int = 5) -> dict:
255
  """主入口:先搜索 GitHub → 再基于实际仓库归纳方法族 → 最后评估。
256
 
257
- 新流程(方案 B):
258
- 第一层:论文信息获取 (Workflow)
259
- 第二层:宽泛 GitHub 搜索 (Workflow)
260
- 第三层:基于仓库数据归纳方法族 (Agent 1) ← 不再是 LLM 凭记忆
261
- 第四层:仓库评估 (Agent 2)
262
-
263
  Args:
264
  arxiv_url: arxiv 论文 URL
265
  top_n: 最终评估的仓库数量
 
266
 
267
  Returns:
268
  dict: {"paper": {...}, "direction": {...}, "repos": [{...}], "error": "..."}
269
  """
270
  t_start = time.time()
 
 
 
271
 
272
  # ================================================================
273
- # [1/5] 论文信息获取(甲 Workflow)
274
  # ================================================================
275
  print("=" * 60)
276
- print("[1/5] 正在获取论文信息...")
 
277
  t0 = time.time()
278
  try:
279
  paper = fetch_paper_info(arxiv_url)
@@ -292,6 +291,7 @@ def run(arxiv_url: str, top_n: int = 5) -> dict:
292
  # ================================================================
293
  print()
294
  print("[1.5] 正在从论文引用网络挖掘对比实验算法...")
 
295
  t0 = time.time()
296
  arxiv_id = paper.get("arxiv_id", "")
297
  comparison_queries, citation_map = _mine_comparison_algorithms(arxiv_id, title, abstract)
@@ -306,6 +306,7 @@ def run(arxiv_url: str, top_n: int = 5) -> dict:
306
  # ================================================================
307
  print()
308
  print("[1.8] 正在 Semantic Scholar 中搜索同领域论文(覆盖全工科)...")
 
309
  t0 = time.time()
310
  title_kws = _extract_title_keywords(title)
311
  s2_papers = _search_s2_papers(" ".join(title_kws), limit=8)
@@ -327,10 +328,11 @@ def run(arxiv_url: str, top_n: int = 5) -> dict:
327
  print(f" 耗时: {time.time() - t0:.1f}s")
328
 
329
  # ================================================================
330
- # [2/5] 宽泛 GitHub 搜索(甲 Workflow)
331
  # ================================================================
332
  print()
333
- print("[2/5] 正在宽泛搜索 GitHub(先搜仓库,再让 Agent 分析)...")
 
334
  t0 = time.time()
335
 
336
  broad_queries = _extract_broad_queries(title, abstract)
@@ -364,15 +366,17 @@ def run(arxiv_url: str, top_n: int = 5) -> dict:
364
  # ================================================================
365
  print()
366
  print("[2.5] 正在用 LLM 过滤不相关仓库...")
 
367
  t0 = time.time()
368
  filtered_results = _filter_repos(title, abstract, broad_results)
369
  print(f" 耗时: {time.time() - t0:.1f}s")
370
 
371
  # ================================================================
372
- # [2.6] 领域上下文扩充( arxiv 搜索综述论文补充领域知识)
373
  # ================================================================
374
  print()
375
  print("[2.6] 正在扩充领域上下文(搜索相关综述论文)...")
 
376
  t0 = time.time()
377
  domain_context = _enrich_domain_context(title, abstract, paper.get("categories", []))
378
  if domain_context:
@@ -382,10 +386,11 @@ def run(arxiv_url: str, top_n: int = 5) -> dict:
382
  print(f" 耗时: {time.time() - t0:.1f}s")
383
 
384
  # ================================================================
385
- # [3/5] 基于仓库数据归纳方法族(Agent 1) ← 核心改动
386
  # ================================================================
387
  print()
388
- print(f"[3/5] 正在分析 {len(filtered_results)} 个仓库,归纳方法族(Agent 1)...")
 
389
  t0 = time.time()
390
 
391
  try:
@@ -410,10 +415,11 @@ def run(arxiv_url: str, top_n: int = 5) -> dict:
410
  print(f" 耗时: {time.time() - t0:.1f}s")
411
 
412
  # ================================================================
413
- # [4/5] 筛选仓库 + 构建法族归属映射
414
  # ================================================================
415
  print()
416
- print(f"[4/5] 正在筛选并获取仓库详情...")
 
417
  t0 = time.time()
418
 
419
  # 从 Agent 1 的 matched_repos 中建立 full_name → family_name 映射
@@ -439,10 +445,11 @@ def run(arxiv_url: str, top_n: int = 5) -> dict:
439
  print(f" 耗时: {time.time() - t0:.1f}s")
440
 
441
  # ================================================================
442
- # [5/5] 仓库详情获取 + 评估(Agent 2,并行)
443
  # ================================================================
444
  print()
445
- print(f"[5/5] 正在获取仓库详情并评估(Agent 2,{len(candidates)} 个仓库并行)...")
 
446
  t0 = time.time()
447
 
448
  def _eval_single(repo, idx, total):
@@ -508,6 +515,7 @@ def run(arxiv_url: str, top_n: int = 5) -> dict:
508
  # ================================================================
509
  # 审核层:检查 Agent 输出质量和来源可靠性
510
  # ================================================================
 
511
  audit = supervise(title, abstract, direction, evaluated)
512
  print(f"\n 审核结果: {audit['summary']}")
513
  print(f" 综合质量评分: {audit['overall_score']}/100")
@@ -518,6 +526,7 @@ def run(arxiv_url: str, top_n: int = 5) -> dict:
518
  print(f" 总耗时: {time.time() - t_start:.1f}s")
519
  print("=" * 60)
520
 
 
521
  return {
522
  "paper": paper,
523
  "direction": direction,
 
251
  return papers
252
 
253
 
254
+ def run(arxiv_url: str, top_n: int = 5, progress=None) -> dict:
255
  """主入口:先搜索 GitHub → 再基于实际仓库归纳方法族 → 最后评估。
256
 
 
 
 
 
 
 
257
  Args:
258
  arxiv_url: arxiv 论文 URL
259
  top_n: 最终评估的仓库数量
260
+ progress: gr.Progress 实例或 callable(fraction, desc),用于前端进度条
261
 
262
  Returns:
263
  dict: {"paper": {...}, "direction": {...}, "repos": [{...}], "error": "..."}
264
  """
265
  t_start = time.time()
266
+ def _prog(frac: float, desc: str):
267
+ if progress:
268
+ progress(frac, desc=desc)
269
 
270
  # ================================================================
271
+ # [1/6] 论文信息获取(甲 Workflow)
272
  # ================================================================
273
  print("=" * 60)
274
+ print("[1/6] 正在获取论文信息...")
275
+ _prog(0.05, "正在从 arxiv 获取论文信息...")
276
  t0 = time.time()
277
  try:
278
  paper = fetch_paper_info(arxiv_url)
 
291
  # ================================================================
292
  print()
293
  print("[1.5] 正在从论文引用网络挖掘对比实验算法...")
294
+ _prog(0.12, "正在从引用网络挖掘对比算法...")
295
  t0 = time.time()
296
  arxiv_id = paper.get("arxiv_id", "")
297
  comparison_queries, citation_map = _mine_comparison_algorithms(arxiv_id, title, abstract)
 
306
  # ================================================================
307
  print()
308
  print("[1.8] 正在 Semantic Scholar 中搜索同领域论文(覆盖全工科)...")
309
+ _prog(0.18, "正在 Semantic Scholar 搜索同领域论文...")
310
  t0 = time.time()
311
  title_kws = _extract_title_keywords(title)
312
  s2_papers = _search_s2_papers(" ".join(title_kws), limit=8)
 
328
  print(f" 耗时: {time.time() - t0:.1f}s")
329
 
330
  # ================================================================
331
+ # [2/6] 宽泛 GitHub 搜索(甲 Workflow)
332
  # ================================================================
333
  print()
334
+ print("[2/6] 正在宽泛搜索 GitHub(先搜仓库,再让 Agent 分析)...")
335
+ _prog(0.22, "正在 GitHub 搜索开源仓库...")
336
  t0 = time.time()
337
 
338
  broad_queries = _extract_broad_queries(title, abstract)
 
366
  # ================================================================
367
  print()
368
  print("[2.5] 正在用 LLM 过滤不相关仓库...")
369
+ _prog(0.32, "正在用 LLM 过滤不相关仓库...")
370
  t0 = time.time()
371
  filtered_results = _filter_repos(title, abstract, broad_results)
372
  print(f" 耗时: {time.time() - t0:.1f}s")
373
 
374
  # ================================================================
375
+ # [2.6] 领域上下文扩充(S2 搜索综述论文补充领域知识)
376
  # ================================================================
377
  print()
378
  print("[2.6] 正在扩充领域上下文(搜索相关综述论文)...")
379
+ _prog(0.38, "正在扩充领域上下文...")
380
  t0 = time.time()
381
  domain_context = _enrich_domain_context(title, abstract, paper.get("categories", []))
382
  if domain_context:
 
386
  print(f" 耗时: {time.time() - t0:.1f}s")
387
 
388
  # ================================================================
389
+ # [3/6] 基于仓库数据归纳方法族(Agent 1)
390
  # ================================================================
391
  print()
392
+ print(f"[3/6] 正在分析 {len(filtered_results)} 个仓库,归纳方法族(Agent 1)...")
393
+ _prog(0.42, "正在用 LLM 归纳方法族...")
394
  t0 = time.time()
395
 
396
  try:
 
415
  print(f" 耗时: {time.time() - t0:.1f}s")
416
 
417
  # ================================================================
418
+ # [4/6] 筛选仓库 + 构建��法族归属映射
419
  # ================================================================
420
  print()
421
+ print(f"[4/6] 正在筛选并获取仓库详情...")
422
+ _prog(0.55, "正在筛选仓库并获取详情...")
423
  t0 = time.time()
424
 
425
  # 从 Agent 1 的 matched_repos 中建立 full_name → family_name 映射
 
445
  print(f" 耗时: {time.time() - t0:.1f}s")
446
 
447
  # ================================================================
448
+ # [5/6] 仓库详情获取 + 评估(Agent 2,并行)
449
  # ================================================================
450
  print()
451
+ print(f"[5/6] 正在获取仓库详情并评估(Agent 2,{len(candidates)} 个仓库并行)...")
452
+ _prog(0.60, f"正在评估 {len(candidates)} 个开源仓库...")
453
  t0 = time.time()
454
 
455
  def _eval_single(repo, idx, total):
 
515
  # ================================================================
516
  # 审核层:检查 Agent 输出质量和来源可靠性
517
  # ================================================================
518
+ _prog(0.90, "正在进行质量审核...")
519
  audit = supervise(title, abstract, direction, evaluated)
520
  print(f"\n 审核结果: {audit['summary']}")
521
  print(f" 综合质量评分: {audit['overall_score']}/100")
 
526
  print(f" 总耗时: {time.time() - t_start:.1f}s")
527
  print("=" * 60)
528
 
529
+ _prog(1.0, "分析完成,正在生成研报...")
530
  return {
531
  "paper": paper,
532
  "direction": direction,