Spaces:
Running
Running
Upload run.py with huggingface_hub
Browse files
run.py
CHANGED
|
@@ -251,29 +251,28 @@ def _search_s2_papers(query: str, limit: int = 5) -> list[dict]:
|
|
| 251 |
return papers
|
| 252 |
|
| 253 |
|
| 254 |
-
def run(arxiv_url: str, top_n: int = 5) -> dict:
|
| 255 |
"""主入口:先搜索 GitHub → 再基于实际仓库归纳方法族 → 最后评估。
|
| 256 |
|
| 257 |
-
新流程(方案 B):
|
| 258 |
-
第一层:论文信息获取 (Workflow)
|
| 259 |
-
第二层:宽泛 GitHub 搜索 (Workflow)
|
| 260 |
-
第三层:基于仓库数据归纳方法族 (Agent 1) ← 不再是 LLM 凭记忆
|
| 261 |
-
第四层:仓库评估 (Agent 2)
|
| 262 |
-
|
| 263 |
Args:
|
| 264 |
arxiv_url: arxiv 论文 URL
|
| 265 |
top_n: 最终评估的仓库数量
|
|
|
|
| 266 |
|
| 267 |
Returns:
|
| 268 |
dict: {"paper": {...}, "direction": {...}, "repos": [{...}], "error": "..."}
|
| 269 |
"""
|
| 270 |
t_start = time.time()
|
|
|
|
|
|
|
|
|
|
| 271 |
|
| 272 |
# ================================================================
|
| 273 |
-
# [1/
|
| 274 |
# ================================================================
|
| 275 |
print("=" * 60)
|
| 276 |
-
print("[1/
|
|
|
|
| 277 |
t0 = time.time()
|
| 278 |
try:
|
| 279 |
paper = fetch_paper_info(arxiv_url)
|
|
@@ -292,6 +291,7 @@ def run(arxiv_url: str, top_n: int = 5) -> dict:
|
|
| 292 |
# ================================================================
|
| 293 |
print()
|
| 294 |
print("[1.5] 正在从论文引用网络挖掘对比实验算法...")
|
|
|
|
| 295 |
t0 = time.time()
|
| 296 |
arxiv_id = paper.get("arxiv_id", "")
|
| 297 |
comparison_queries, citation_map = _mine_comparison_algorithms(arxiv_id, title, abstract)
|
|
@@ -306,6 +306,7 @@ def run(arxiv_url: str, top_n: int = 5) -> dict:
|
|
| 306 |
# ================================================================
|
| 307 |
print()
|
| 308 |
print("[1.8] 正在 Semantic Scholar 中搜索同领域论文(覆盖全工科)...")
|
|
|
|
| 309 |
t0 = time.time()
|
| 310 |
title_kws = _extract_title_keywords(title)
|
| 311 |
s2_papers = _search_s2_papers(" ".join(title_kws), limit=8)
|
|
@@ -327,10 +328,11 @@ def run(arxiv_url: str, top_n: int = 5) -> dict:
|
|
| 327 |
print(f" 耗时: {time.time() - t0:.1f}s")
|
| 328 |
|
| 329 |
# ================================================================
|
| 330 |
-
# [2/
|
| 331 |
# ================================================================
|
| 332 |
print()
|
| 333 |
-
print("[2/
|
|
|
|
| 334 |
t0 = time.time()
|
| 335 |
|
| 336 |
broad_queries = _extract_broad_queries(title, abstract)
|
|
@@ -364,15 +366,17 @@ def run(arxiv_url: str, top_n: int = 5) -> dict:
|
|
| 364 |
# ================================================================
|
| 365 |
print()
|
| 366 |
print("[2.5] 正在用 LLM 过滤不相关仓库...")
|
|
|
|
| 367 |
t0 = time.time()
|
| 368 |
filtered_results = _filter_repos(title, abstract, broad_results)
|
| 369 |
print(f" 耗时: {time.time() - t0:.1f}s")
|
| 370 |
|
| 371 |
# ================================================================
|
| 372 |
-
# [2.6] 领域上下文扩充(
|
| 373 |
# ================================================================
|
| 374 |
print()
|
| 375 |
print("[2.6] 正在扩充领域上下文(搜索相关综述论文)...")
|
|
|
|
| 376 |
t0 = time.time()
|
| 377 |
domain_context = _enrich_domain_context(title, abstract, paper.get("categories", []))
|
| 378 |
if domain_context:
|
|
@@ -382,10 +386,11 @@ def run(arxiv_url: str, top_n: int = 5) -> dict:
|
|
| 382 |
print(f" 耗时: {time.time() - t0:.1f}s")
|
| 383 |
|
| 384 |
# ================================================================
|
| 385 |
-
# [3/
|
| 386 |
# ================================================================
|
| 387 |
print()
|
| 388 |
-
print(f"[3/
|
|
|
|
| 389 |
t0 = time.time()
|
| 390 |
|
| 391 |
try:
|
|
@@ -410,10 +415,11 @@ def run(arxiv_url: str, top_n: int = 5) -> dict:
|
|
| 410 |
print(f" 耗时: {time.time() - t0:.1f}s")
|
| 411 |
|
| 412 |
# ================================================================
|
| 413 |
-
# [4/
|
| 414 |
# ================================================================
|
| 415 |
print()
|
| 416 |
-
print(f"[4/
|
|
|
|
| 417 |
t0 = time.time()
|
| 418 |
|
| 419 |
# 从 Agent 1 的 matched_repos 中建立 full_name → family_name 映射
|
|
@@ -439,10 +445,11 @@ def run(arxiv_url: str, top_n: int = 5) -> dict:
|
|
| 439 |
print(f" 耗时: {time.time() - t0:.1f}s")
|
| 440 |
|
| 441 |
# ================================================================
|
| 442 |
-
# [5/
|
| 443 |
# ================================================================
|
| 444 |
print()
|
| 445 |
-
print(f"[5/
|
|
|
|
| 446 |
t0 = time.time()
|
| 447 |
|
| 448 |
def _eval_single(repo, idx, total):
|
|
@@ -508,6 +515,7 @@ def run(arxiv_url: str, top_n: int = 5) -> dict:
|
|
| 508 |
# ================================================================
|
| 509 |
# 审核层:检查 Agent 输出质量和来源可靠性
|
| 510 |
# ================================================================
|
|
|
|
| 511 |
audit = supervise(title, abstract, direction, evaluated)
|
| 512 |
print(f"\n 审核结果: {audit['summary']}")
|
| 513 |
print(f" 综合质量评分: {audit['overall_score']}/100")
|
|
@@ -518,6 +526,7 @@ def run(arxiv_url: str, top_n: int = 5) -> dict:
|
|
| 518 |
print(f" 总耗时: {time.time() - t_start:.1f}s")
|
| 519 |
print("=" * 60)
|
| 520 |
|
|
|
|
| 521 |
return {
|
| 522 |
"paper": paper,
|
| 523 |
"direction": direction,
|
|
|
|
| 251 |
return papers
|
| 252 |
|
| 253 |
|
| 254 |
+
def run(arxiv_url: str, top_n: int = 5, progress=None) -> dict:
|
| 255 |
"""主入口:先搜索 GitHub → 再基于实际仓库归纳方法族 → 最后评估。
|
| 256 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
Args:
|
| 258 |
arxiv_url: arxiv 论文 URL
|
| 259 |
top_n: 最终评估的仓库数量
|
| 260 |
+
progress: gr.Progress 实例或 callable(fraction, desc),用于前端进度条
|
| 261 |
|
| 262 |
Returns:
|
| 263 |
dict: {"paper": {...}, "direction": {...}, "repos": [{...}], "error": "..."}
|
| 264 |
"""
|
| 265 |
t_start = time.time()
|
| 266 |
+
def _prog(frac: float, desc: str):
|
| 267 |
+
if progress:
|
| 268 |
+
progress(frac, desc=desc)
|
| 269 |
|
| 270 |
# ================================================================
|
| 271 |
+
# [1/6] 论文信息获取(甲 Workflow)
|
| 272 |
# ================================================================
|
| 273 |
print("=" * 60)
|
| 274 |
+
print("[1/6] 正在获取论文信息...")
|
| 275 |
+
_prog(0.05, "正在从 arxiv 获取论文信息...")
|
| 276 |
t0 = time.time()
|
| 277 |
try:
|
| 278 |
paper = fetch_paper_info(arxiv_url)
|
|
|
|
| 291 |
# ================================================================
|
| 292 |
print()
|
| 293 |
print("[1.5] 正在从论文引用网络挖掘对比实验算法...")
|
| 294 |
+
_prog(0.12, "正在从引用网络挖掘对比算法...")
|
| 295 |
t0 = time.time()
|
| 296 |
arxiv_id = paper.get("arxiv_id", "")
|
| 297 |
comparison_queries, citation_map = _mine_comparison_algorithms(arxiv_id, title, abstract)
|
|
|
|
| 306 |
# ================================================================
|
| 307 |
print()
|
| 308 |
print("[1.8] 正在 Semantic Scholar 中搜索同领域论文(覆盖全工科)...")
|
| 309 |
+
_prog(0.18, "正在 Semantic Scholar 搜索同领域论文...")
|
| 310 |
t0 = time.time()
|
| 311 |
title_kws = _extract_title_keywords(title)
|
| 312 |
s2_papers = _search_s2_papers(" ".join(title_kws), limit=8)
|
|
|
|
| 328 |
print(f" 耗时: {time.time() - t0:.1f}s")
|
| 329 |
|
| 330 |
# ================================================================
|
| 331 |
+
# [2/6] 宽泛 GitHub 搜索(甲 Workflow)
|
| 332 |
# ================================================================
|
| 333 |
print()
|
| 334 |
+
print("[2/6] 正在宽泛搜索 GitHub(先搜仓库,再让 Agent 分析)...")
|
| 335 |
+
_prog(0.22, "正在 GitHub 搜索开源仓库...")
|
| 336 |
t0 = time.time()
|
| 337 |
|
| 338 |
broad_queries = _extract_broad_queries(title, abstract)
|
|
|
|
| 366 |
# ================================================================
|
| 367 |
print()
|
| 368 |
print("[2.5] 正在用 LLM 过滤不相关仓库...")
|
| 369 |
+
_prog(0.32, "正在用 LLM 过滤不相关仓库...")
|
| 370 |
t0 = time.time()
|
| 371 |
filtered_results = _filter_repos(title, abstract, broad_results)
|
| 372 |
print(f" 耗时: {time.time() - t0:.1f}s")
|
| 373 |
|
| 374 |
# ================================================================
|
| 375 |
+
# [2.6] 领域上下文扩充(S2 搜索综述论文补充领域知识)
|
| 376 |
# ================================================================
|
| 377 |
print()
|
| 378 |
print("[2.6] 正在扩充领域上下文(搜索相关综述论文)...")
|
| 379 |
+
_prog(0.38, "正在扩充领域上下文...")
|
| 380 |
t0 = time.time()
|
| 381 |
domain_context = _enrich_domain_context(title, abstract, paper.get("categories", []))
|
| 382 |
if domain_context:
|
|
|
|
| 386 |
print(f" 耗时: {time.time() - t0:.1f}s")
|
| 387 |
|
| 388 |
# ================================================================
|
| 389 |
+
# [3/6] 基于仓库数据归纳方法族(Agent 1)
|
| 390 |
# ================================================================
|
| 391 |
print()
|
| 392 |
+
print(f"[3/6] 正在分析 {len(filtered_results)} 个仓库,归纳方法族(Agent 1)...")
|
| 393 |
+
_prog(0.42, "正在用 LLM 归纳方法族...")
|
| 394 |
t0 = time.time()
|
| 395 |
|
| 396 |
try:
|
|
|
|
| 415 |
print(f" 耗时: {time.time() - t0:.1f}s")
|
| 416 |
|
| 417 |
# ================================================================
|
| 418 |
+
# [4/6] 筛选仓库 + 构建��法族归属映射
|
| 419 |
# ================================================================
|
| 420 |
print()
|
| 421 |
+
print(f"[4/6] 正在筛选并获取仓库详情...")
|
| 422 |
+
_prog(0.55, "正在筛选仓库并获取详情...")
|
| 423 |
t0 = time.time()
|
| 424 |
|
| 425 |
# 从 Agent 1 的 matched_repos 中建立 full_name → family_name 映射
|
|
|
|
| 445 |
print(f" 耗时: {time.time() - t0:.1f}s")
|
| 446 |
|
| 447 |
# ================================================================
|
| 448 |
+
# [5/6] 仓库详情获取 + 评估(Agent 2,并行)
|
| 449 |
# ================================================================
|
| 450 |
print()
|
| 451 |
+
print(f"[5/6] 正在获取仓库详情并评估(Agent 2,{len(candidates)} 个仓库并行)...")
|
| 452 |
+
_prog(0.60, f"正在评估 {len(candidates)} 个开源仓库...")
|
| 453 |
t0 = time.time()
|
| 454 |
|
| 455 |
def _eval_single(repo, idx, total):
|
|
|
|
| 515 |
# ================================================================
|
| 516 |
# 审核层:检查 Agent 输出质量和来源可靠性
|
| 517 |
# ================================================================
|
| 518 |
+
_prog(0.90, "正在进行质量审核...")
|
| 519 |
audit = supervise(title, abstract, direction, evaluated)
|
| 520 |
print(f"\n 审核结果: {audit['summary']}")
|
| 521 |
print(f" 综合质量评分: {audit['overall_score']}/100")
|
|
|
|
| 526 |
print(f" 总耗时: {time.time() - t_start:.1f}s")
|
| 527 |
print("=" * 60)
|
| 528 |
|
| 529 |
+
_prog(1.0, "分析完成,正在生成研报...")
|
| 530 |
return {
|
| 531 |
"paper": paper,
|
| 532 |
"direction": direction,
|