akseljoonas HF Staff commited on
Commit
d81613f
·
1 Parent(s): 210a259

Add Semantic Scholar integration: citation graph, filtered search, snippet search, recommendations

Browse files

Extends hf_papers with S2 backend:
- citation_graph: references + citations with influence flags and intents
- Smart search routing: date/category/citation filters via S2 bulk search
- snippet_search: semantic full-text passage search across 12M+ papers
- recommend: single/multi-paper recommendations with pos/neg examples
- paper_details enriched with citation count, TL;DR, fields of study
- Rate-limited S2 client with 2 retries on 429/5xx
- Research prompt: paper analysis checklist, new tool guidance
- Frontend shimmer labels for new operations

agent/tools/papers_tool.py CHANGED
@@ -2,11 +2,14 @@
2
  HF Papers Tool — Discover papers, read their contents, and find linked resources.
3
 
4
  Operations: trending, search, paper_details, read_paper,
5
- find_datasets, find_models, find_collections, find_all_resources
 
6
  """
7
 
8
  import asyncio
 
9
  import re
 
10
  from typing import Any
11
 
12
  import httpx
@@ -30,6 +33,75 @@ SORT_MAP = {
30
  "trending": "trendingScore",
31
  }
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  # ---------------------------------------------------------------------------
35
  # HTML paper parsing
@@ -193,7 +265,7 @@ def _format_paper_list(
193
  return "\n".join(lines)
194
 
195
 
196
- def _format_paper_detail(paper: dict) -> str:
197
  arxiv_id = paper.get("id", "")
198
  title = paper.get("title", "Unknown")
199
  upvotes = paper.get("upvotes", 0)
@@ -205,7 +277,12 @@ def _format_paper_detail(paper: dict) -> str:
205
  authors = paper.get("authors") or []
206
 
207
  lines = [f"# {title}"]
208
- lines.append(f"**arxiv_id:** {arxiv_id} | **upvotes:** {upvotes}")
 
 
 
 
 
209
  lines.append(f"https://huggingface.co/papers/{arxiv_id}")
210
  lines.append(f"https://arxiv.org/abs/{arxiv_id}")
211
 
@@ -218,16 +295,27 @@ def _format_paper_detail(paper: dict) -> str:
218
 
219
  if keywords:
220
  lines.append(f"**Keywords:** {', '.join(keywords)}")
 
 
 
 
 
 
221
  if github:
222
  lines.append(f"**GitHub:** {github} ({stars} stars)")
223
 
 
 
 
 
224
  if ai_summary:
225
  lines.append(f"\n## AI Summary\n{ai_summary}")
226
  if summary:
227
  lines.append(f"\n## Abstract\n{_truncate(summary, 500)}")
228
 
229
  lines.append(
230
- "\n**Next:** Use read_paper to read specific sections, or find_all_resources to discover linked datasets/models."
 
231
  )
232
  return "\n".join(lines)
233
 
@@ -441,11 +529,101 @@ async def _op_trending(args: dict[str, Any], limit: int) -> ToolResult:
441
  }
442
 
443
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
444
  async def _op_search(args: dict[str, Any], limit: int) -> ToolResult:
445
  query = args.get("query")
446
  if not query:
447
  return _error("'query' is required for search operation.")
448
 
 
 
 
 
 
 
 
 
449
  async with httpx.AsyncClient(timeout=15) as client:
450
  resp = await client.get(
451
  f"{HF_API}/papers/search", params={"q": query, "limit": limit}
@@ -473,13 +651,21 @@ async def _op_paper_details(args: dict[str, Any], limit: int) -> ToolResult:
473
  if not arxiv_id:
474
  return _error("'arxiv_id' is required for paper_details.")
475
 
 
 
476
  async with httpx.AsyncClient(timeout=15) as client:
477
- resp = await client.get(f"{HF_API}/papers/{arxiv_id}")
478
- resp.raise_for_status()
479
- paper = resp.json()
 
 
 
 
 
 
480
 
481
  return {
482
- "formatted": _format_paper_detail(paper),
483
  "totalResults": 1,
484
  "resultsShared": 1,
485
  }
@@ -545,6 +731,113 @@ async def _op_read_paper(args: dict[str, Any], limit: int) -> ToolResult:
545
  return {"formatted": formatted, "totalResults": 1, "resultsShared": 1}
546
 
547
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
548
  async def _op_find_datasets(args: dict[str, Any], limit: int) -> ToolResult:
549
  arxiv_id = _validate_arxiv_id(args)
550
  if not arxiv_id:
@@ -703,6 +996,135 @@ async def _op_find_all_resources(args: dict[str, Any], limit: int) -> ToolResult
703
  return {"formatted": formatted, "totalResults": total, "resultsShared": total}
704
 
705
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
706
  # ---------------------------------------------------------------------------
707
  # Operation dispatch
708
  # ---------------------------------------------------------------------------
@@ -712,6 +1134,9 @@ _OPERATIONS = {
712
  "search": _op_search,
713
  "paper_details": _op_paper_details,
714
  "read_paper": _op_read_paper,
 
 
 
715
  "find_datasets": _op_find_datasets,
716
  "find_models": _op_find_models,
717
  "find_collections": _op_find_collections,
@@ -726,22 +1151,25 @@ _OPERATIONS = {
726
  HF_PAPERS_TOOL_SPEC = {
727
  "name": "hf_papers",
728
  "description": (
729
- "Discover ML research papers, find their linked resources (datasets, models, collections), "
730
- "and read paper contents on HuggingFace Hub and arXiv.\n\n"
731
- "Use this when exploring a research area, looking for datasets for a task, "
732
- "implementing a paper's approach, or trying to improve performance on something. "
733
- "Typical flow:\n"
734
- " hf_papers(search/trending)hf_papers(read_paper)hf_papers(find_all_resources)hf_inspect_dataset\n\n"
 
735
  "Operations:\n"
736
  "- trending: Get trending daily papers, optionally filter by topic keyword\n"
737
- "- search: Full-text search for papers by query\n"
738
- "- paper_details: Get metadata, abstract, AI summary, and github link for a paper\n"
739
- "- read_paper: Read paper contents — without section: returns abstract + table of contents; "
740
- "with section: returns full section text\n"
 
 
741
  "- find_datasets: Find datasets linked to a paper\n"
742
  "- find_models: Find models linked to a paper\n"
743
  "- find_collections: Find collections that include a paper\n"
744
- "- find_all_resources: Parallel fetch of datasets + models + collections for a paper (unified view)"
745
  ),
746
  "parameters": {
747
  "type": "object",
@@ -754,36 +1182,69 @@ HF_PAPERS_TOOL_SPEC = {
754
  "query": {
755
  "type": "string",
756
  "description": (
757
- "Search query. Required for: search. "
758
- "Optional for: trending (filters results by keyword match on title, summary, and AI-generated keywords)."
 
759
  ),
760
  },
761
  "arxiv_id": {
762
  "type": "string",
763
  "description": (
764
  "ArXiv paper ID (e.g. '2305.18290'). "
765
- "Required for: paper_details, read_paper, find_datasets, find_models, find_collections, find_all_resources. "
766
- "Get IDs from trending or search results first."
767
  ),
768
  },
769
  "section": {
770
  "type": "string",
771
  "description": (
772
  "Section name or number to read (e.g. '3', 'Experiments', '4.2'). "
773
- "Optional for: read_paper. Without this, read_paper returns the abstract + table of contents "
774
- "so you can choose which section to read."
775
  ),
776
  },
 
 
 
 
 
777
  "date": {
778
  "type": "string",
779
  "description": "Date in YYYY-MM-DD format. Optional for: trending (defaults to recent papers).",
780
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
781
  "sort": {
782
  "type": "string",
783
  "enum": ["downloads", "likes", "trending"],
784
  "description": (
785
- "Sort order for find_datasets and find_models. Default: downloads. "
786
- "Use 'downloads' for most-used, 'likes' for community favorites, 'trending' for recently popular."
787
  ),
788
  },
789
  "limit": {
 
2
  HF Papers Tool — Discover papers, read their contents, and find linked resources.
3
 
4
  Operations: trending, search, paper_details, read_paper,
5
+ find_datasets, find_models, find_collections, find_all_resources,
6
+ citation_graph, snippet_search, recommend
7
  """
8
 
9
  import asyncio
10
+ import os
11
  import re
12
+ import time
13
  from typing import Any
14
 
15
  import httpx
 
33
  "trending": "trendingScore",
34
  }
35
 
36
+ # ---------------------------------------------------------------------------
37
+ # Semantic Scholar API
38
+ # ---------------------------------------------------------------------------
39
+
40
+ S2_API = "https://api.semanticscholar.org"
41
+ S2_API_KEY = os.environ.get("S2_API_KEY")
42
+ S2_HEADERS: dict[str, str] = {"x-api-key": S2_API_KEY} if S2_API_KEY else {}
43
+ S2_TIMEOUT = 12
44
+ _s2_last_request: float = 0.0
45
+
46
+
47
+ def _s2_paper_id(arxiv_id: str) -> str:
48
+ """Convert bare arxiv ID to S2 format."""
49
+ return f"ARXIV:{arxiv_id}"
50
+
51
+
52
+ async def _s2_request(
53
+ client: httpx.AsyncClient,
54
+ method: str,
55
+ path: str,
56
+ **kwargs: Any,
57
+ ) -> httpx.Response | None:
58
+ """Rate-limited S2 request with 2 retries on 429/5xx."""
59
+ global _s2_last_request
60
+ url = f"{S2_API}{path}"
61
+ kwargs.setdefault("headers", {}).update(S2_HEADERS)
62
+ kwargs.setdefault("timeout", S2_TIMEOUT)
63
+
64
+ for attempt in range(3):
65
+ # Rate limit: 1 request per second
66
+ elapsed = time.monotonic() - _s2_last_request
67
+ if elapsed < 1.0:
68
+ await asyncio.sleep(1.0 - elapsed)
69
+ _s2_last_request = time.monotonic()
70
+
71
+ try:
72
+ resp = await client.request(method, url, **kwargs)
73
+ if resp.status_code == 429:
74
+ if attempt < 2:
75
+ await asyncio.sleep(60)
76
+ continue
77
+ return None
78
+ if resp.status_code >= 500:
79
+ if attempt < 2:
80
+ await asyncio.sleep(3)
81
+ continue
82
+ return None
83
+ return resp
84
+ except (httpx.RequestError, httpx.HTTPStatusError):
85
+ if attempt < 2:
86
+ await asyncio.sleep(3)
87
+ continue
88
+ return None
89
+ return None
90
+
91
+
92
+ async def _s2_get_paper(
93
+ client: httpx.AsyncClient, arxiv_id: str, fields: str,
94
+ ) -> dict | None:
95
+ """Fetch a single paper from S2 by arxiv ID. Returns None on failure."""
96
+ resp = await _s2_request(
97
+ client, "GET",
98
+ f"/graph/v1/paper/{_s2_paper_id(arxiv_id)}",
99
+ params={"fields": fields},
100
+ )
101
+ if resp and resp.status_code == 200:
102
+ return resp.json()
103
+ return None
104
+
105
 
106
  # ---------------------------------------------------------------------------
107
  # HTML paper parsing
 
265
  return "\n".join(lines)
266
 
267
 
268
+ def _format_paper_detail(paper: dict, s2_data: dict | None = None) -> str:
269
  arxiv_id = paper.get("id", "")
270
  title = paper.get("title", "Unknown")
271
  upvotes = paper.get("upvotes", 0)
 
277
  authors = paper.get("authors") or []
278
 
279
  lines = [f"# {title}"]
280
+ meta_parts = [f"**arxiv_id:** {arxiv_id}", f"**upvotes:** {upvotes}"]
281
+ if s2_data:
282
+ cites = s2_data.get("citationCount", 0)
283
+ influential = s2_data.get("influentialCitationCount", 0)
284
+ meta_parts.append(f"**citations:** {cites} ({influential} influential)")
285
+ lines.append(" | ".join(meta_parts))
286
  lines.append(f"https://huggingface.co/papers/{arxiv_id}")
287
  lines.append(f"https://arxiv.org/abs/{arxiv_id}")
288
 
 
295
 
296
  if keywords:
297
  lines.append(f"**Keywords:** {', '.join(keywords)}")
298
+ if s2_data and s2_data.get("s2FieldsOfStudy"):
299
+ fields = [f["category"] for f in s2_data["s2FieldsOfStudy"] if f.get("category")]
300
+ if fields:
301
+ lines.append(f"**Fields:** {', '.join(fields)}")
302
+ if s2_data and s2_data.get("venue"):
303
+ lines.append(f"**Venue:** {s2_data['venue']}")
304
  if github:
305
  lines.append(f"**GitHub:** {github} ({stars} stars)")
306
 
307
+ if s2_data and s2_data.get("tldr"):
308
+ tldr_text = s2_data["tldr"].get("text", "")
309
+ if tldr_text:
310
+ lines.append(f"\n## TL;DR\n{tldr_text}")
311
  if ai_summary:
312
  lines.append(f"\n## AI Summary\n{ai_summary}")
313
  if summary:
314
  lines.append(f"\n## Abstract\n{_truncate(summary, 500)}")
315
 
316
  lines.append(
317
+ "\n**Next:** Use read_paper to read specific sections, find_all_resources for linked datasets/models, "
318
+ "or citation_graph to trace references and citations."
319
  )
320
  return "\n".join(lines)
321
 
 
529
  }
530
 
531
 
532
+ def _format_s2_paper_list(papers: list[dict], title: str) -> str:
533
+ """Format a list of S2 paper results."""
534
+ lines = [f"# {title}"]
535
+ lines.append(f"Showing {len(papers)} result(s)\n")
536
+
537
+ for i, paper in enumerate(papers, 1):
538
+ ptitle = paper.get("title") or "(untitled)"
539
+ year = paper.get("year") or "?"
540
+ cites = paper.get("citationCount", 0)
541
+ venue = paper.get("venue") or ""
542
+ ext_ids = paper.get("externalIds") or {}
543
+ aid = ext_ids.get("ArXiv", "")
544
+ tldr = (paper.get("tldr") or {}).get("text", "")
545
+
546
+ lines.append(f"### {i}. {ptitle}")
547
+ meta = [f"Year: {year}", f"Citations: {cites}"]
548
+ if venue:
549
+ meta.append(f"Venue: {venue}")
550
+ if aid:
551
+ meta.append(f"arxiv_id: {aid}")
552
+ lines.append(" | ".join(meta))
553
+ if aid:
554
+ lines.append(f"https://arxiv.org/abs/{aid}")
555
+ if tldr:
556
+ lines.append(f"**TL;DR:** {tldr}")
557
+ lines.append("")
558
+
559
+ lines.append("Use paper_details with arxiv_id for full info, or read_paper to read sections.")
560
+ return "\n".join(lines)
561
+
562
+
563
+ async def _s2_bulk_search(query: str, args: dict[str, Any], limit: int) -> ToolResult | None:
564
+ """Search via S2 bulk endpoint with filters. Returns None on failure."""
565
+ params: dict[str, Any] = {
566
+ "query": query,
567
+ "limit": limit,
568
+ "fields": "title,externalIds,year,citationCount,tldr,venue,publicationDate",
569
+ }
570
+
571
+ # Date filter
572
+ date_from = args.get("date_from", "")
573
+ date_to = args.get("date_to", "")
574
+ if date_from or date_to:
575
+ params["publicationDateOrYear"] = f"{date_from}:{date_to}"
576
+
577
+ # Fields of study
578
+ categories = args.get("categories")
579
+ if categories:
580
+ params["fieldsOfStudy"] = categories
581
+
582
+ # Min citations
583
+ min_cites = args.get("min_citations")
584
+ if min_cites:
585
+ params["minCitationCount"] = str(min_cites)
586
+
587
+ # Sort
588
+ sort_by = args.get("sort_by")
589
+ if sort_by and sort_by != "relevance":
590
+ params["sort"] = f"{sort_by}:desc"
591
+
592
+ async with httpx.AsyncClient(timeout=15) as client:
593
+ resp = await _s2_request(client, "GET", "/graph/v1/paper/search/bulk", params=params)
594
+ if not resp or resp.status_code != 200:
595
+ return None
596
+ data = resp.json()
597
+
598
+ papers = data.get("data") or []
599
+ if not papers:
600
+ return {
601
+ "formatted": f"No papers found for '{query}' with the given filters.",
602
+ "totalResults": 0,
603
+ "resultsShared": 0,
604
+ }
605
+
606
+ formatted = _format_s2_paper_list(papers[:limit], f"Papers matching '{query}' (Semantic Scholar)")
607
+ return {
608
+ "formatted": formatted,
609
+ "totalResults": data.get("total", len(papers)),
610
+ "resultsShared": min(limit, len(papers)),
611
+ }
612
+
613
+
614
  async def _op_search(args: dict[str, Any], limit: int) -> ToolResult:
615
  query = args.get("query")
616
  if not query:
617
  return _error("'query' is required for search operation.")
618
 
619
+ # Route to S2 when filters are present
620
+ use_s2 = any(args.get(k) for k in ("date_from", "date_to", "categories", "min_citations", "sort_by"))
621
+ if use_s2:
622
+ result = await _s2_bulk_search(query, args, limit)
623
+ if result is not None:
624
+ return result
625
+ # Fall back to HF search (without filters) if S2 fails
626
+
627
  async with httpx.AsyncClient(timeout=15) as client:
628
  resp = await client.get(
629
  f"{HF_API}/papers/search", params={"q": query, "limit": limit}
 
651
  if not arxiv_id:
652
  return _error("'arxiv_id' is required for paper_details.")
653
 
654
+ s2_fields = "citationCount,influentialCitationCount,tldr,s2FieldsOfStudy,venue"
655
+
656
  async with httpx.AsyncClient(timeout=15) as client:
657
+ hf_coro = client.get(f"{HF_API}/papers/{arxiv_id}")
658
+ s2_coro = _s2_get_paper(client, arxiv_id, s2_fields)
659
+ hf_resp, s2_data = await asyncio.gather(hf_coro, s2_coro, return_exceptions=True)
660
+
661
+ if isinstance(hf_resp, Exception):
662
+ raise hf_resp
663
+ hf_resp.raise_for_status()
664
+ paper = hf_resp.json()
665
+ s2 = s2_data if isinstance(s2_data, dict) else None
666
 
667
  return {
668
+ "formatted": _format_paper_detail(paper, s2),
669
  "totalResults": 1,
670
  "resultsShared": 1,
671
  }
 
731
  return {"formatted": formatted, "totalResults": 1, "resultsShared": 1}
732
 
733
 
734
+ # ---------------------------------------------------------------------------
735
+ # Citation graph (Semantic Scholar)
736
+ # ---------------------------------------------------------------------------
737
+
738
+
739
+ def _format_citation_entry(entry: dict, show_context: bool = False) -> str:
740
+ """Format a single citation/reference entry."""
741
+ paper = entry.get("citingPaper") or entry.get("citedPaper") or {}
742
+ title = paper.get("title") or "(untitled)"
743
+ year = paper.get("year") or "?"
744
+ cites = paper.get("citationCount", 0)
745
+ ext_ids = paper.get("externalIds") or {}
746
+ aid = ext_ids.get("ArXiv", "")
747
+ influential = " **[influential]**" if entry.get("isInfluential") else ""
748
+
749
+ parts = [f"- **{title}** ({year}, {cites} cites){influential}"]
750
+ if aid:
751
+ parts[0] += f" arxiv:{aid}"
752
+
753
+ if show_context:
754
+ intents = entry.get("intents") or []
755
+ if intents:
756
+ parts.append(f" Intent: {', '.join(intents)}")
757
+ contexts = entry.get("contexts") or []
758
+ for ctx in contexts[:2]:
759
+ if ctx:
760
+ parts.append(f" > {_truncate(ctx, 200)}")
761
+
762
+ return "\n".join(parts)
763
+
764
+
765
+ def _format_citation_graph(
766
+ arxiv_id: str,
767
+ references: list[dict] | None,
768
+ citations: list[dict] | None,
769
+ ) -> str:
770
+ lines = [f"# Citation Graph for {arxiv_id}"]
771
+ lines.append(f"https://arxiv.org/abs/{arxiv_id}\n")
772
+
773
+ if references is not None:
774
+ lines.append(f"## References ({len(references)})")
775
+ if references:
776
+ for entry in references:
777
+ lines.append(_format_citation_entry(entry))
778
+ else:
779
+ lines.append("No references found.")
780
+ lines.append("")
781
+
782
+ if citations is not None:
783
+ lines.append(f"## Citations ({len(citations)})")
784
+ if citations:
785
+ for entry in citations:
786
+ lines.append(_format_citation_entry(entry, show_context=True))
787
+ else:
788
+ lines.append("No citations found.")
789
+ lines.append("")
790
+
791
+ lines.append("**Tip:** Use paper_details with an arxiv_id from above to explore further.")
792
+ return "\n".join(lines)
793
+
794
+
795
+ async def _op_citation_graph(args: dict[str, Any], limit: int) -> ToolResult:
796
+ arxiv_id = _validate_arxiv_id(args)
797
+ if not arxiv_id:
798
+ return _error("'arxiv_id' is required for citation_graph.")
799
+
800
+ direction = args.get("direction", "both")
801
+ s2_id = _s2_paper_id(arxiv_id)
802
+ fields = "title,externalIds,year,citationCount,influentialCitationCount,contexts,intents,isInfluential"
803
+
804
+ async with httpx.AsyncClient(timeout=15) as client:
805
+ refs, cites = None, None
806
+ coros = []
807
+ if direction in ("references", "both"):
808
+ coros.append(
809
+ _s2_request(client, "GET", f"/graph/v1/paper/{s2_id}/references",
810
+ params={"fields": fields, "limit": limit})
811
+ )
812
+ if direction in ("citations", "both"):
813
+ coros.append(
814
+ _s2_request(client, "GET", f"/graph/v1/paper/{s2_id}/citations",
815
+ params={"fields": fields, "limit": limit})
816
+ )
817
+
818
+ results = await asyncio.gather(*coros, return_exceptions=True)
819
+ idx = 0
820
+ if direction in ("references", "both"):
821
+ r = results[idx]
822
+ if isinstance(r, httpx.Response) and r.status_code == 200:
823
+ refs = r.json().get("data", [])
824
+ idx += 1
825
+ if direction in ("citations", "both"):
826
+ r = results[idx]
827
+ if isinstance(r, httpx.Response) and r.status_code == 200:
828
+ cites = r.json().get("data", [])
829
+
830
+ if refs is None and cites is None:
831
+ return _error(f"Could not fetch citation data for {arxiv_id}. Paper may not be indexed by Semantic Scholar.")
832
+
833
+ total = (len(refs) if refs else 0) + (len(cites) if cites else 0)
834
+ return {
835
+ "formatted": _format_citation_graph(arxiv_id, refs, cites),
836
+ "totalResults": total,
837
+ "resultsShared": total,
838
+ }
839
+
840
+
841
  async def _op_find_datasets(args: dict[str, Any], limit: int) -> ToolResult:
842
  arxiv_id = _validate_arxiv_id(args)
843
  if not arxiv_id:
 
996
  return {"formatted": formatted, "totalResults": total, "resultsShared": total}
997
 
998
 
999
+ # ---------------------------------------------------------------------------
1000
+ # Snippet search (Semantic Scholar)
1001
+ # ---------------------------------------------------------------------------
1002
+
1003
+
1004
+ def _format_snippets(snippets: list[dict], query: str) -> str:
1005
+ lines = [f"# Snippet Search: '{query}'"]
1006
+ lines.append(f"Found {len(snippets)} matching passage(s)\n")
1007
+
1008
+ for i, item in enumerate(snippets, 1):
1009
+ paper = item.get("paper") or {}
1010
+ ptitle = paper.get("title") or "(untitled)"
1011
+ year = paper.get("year") or "?"
1012
+ cites = paper.get("citationCount", 0)
1013
+ ext_ids = paper.get("externalIds") or {}
1014
+ aid = ext_ids.get("ArXiv", "")
1015
+
1016
+ snippet = item.get("snippet") or {}
1017
+ text = snippet.get("text", "")
1018
+ section = snippet.get("section") or ""
1019
+
1020
+ lines.append(f"### {i}. {ptitle} ({year}, {cites} cites)")
1021
+ if aid:
1022
+ lines.append(f"arxiv:{aid}")
1023
+ if section:
1024
+ lines.append(f"Section: {section}")
1025
+ if text:
1026
+ lines.append(f"> {_truncate(text, 400)}")
1027
+ lines.append("")
1028
+
1029
+ lines.append("Use paper_details or read_paper with arxiv_id to explore a paper further.")
1030
+ return "\n".join(lines)
1031
+
1032
+
1033
+ async def _op_snippet_search(args: dict[str, Any], limit: int) -> ToolResult:
1034
+ query = args.get("query")
1035
+ if not query:
1036
+ return _error("'query' is required for snippet_search.")
1037
+
1038
+ params: dict[str, Any] = {
1039
+ "query": query,
1040
+ "limit": limit,
1041
+ "fields": "title,externalIds,year,citationCount",
1042
+ }
1043
+
1044
+ # Optional filters (same as search)
1045
+ date_from = args.get("date_from", "")
1046
+ date_to = args.get("date_to", "")
1047
+ if date_from or date_to:
1048
+ params["publicationDateOrYear"] = f"{date_from}:{date_to}"
1049
+ if args.get("categories"):
1050
+ params["fieldsOfStudy"] = args["categories"]
1051
+ if args.get("min_citations"):
1052
+ params["minCitationCount"] = str(args["min_citations"])
1053
+
1054
+ async with httpx.AsyncClient(timeout=15) as client:
1055
+ resp = await _s2_request(client, "GET", "/graph/v1/snippet/search", params=params)
1056
+ if not resp or resp.status_code != 200:
1057
+ return _error("Snippet search failed. Semantic Scholar may be unavailable.")
1058
+ data = resp.json()
1059
+
1060
+ snippets = data.get("data") or []
1061
+ if not snippets:
1062
+ return {
1063
+ "formatted": f"No snippets found for '{query}'.",
1064
+ "totalResults": 0,
1065
+ "resultsShared": 0,
1066
+ }
1067
+
1068
+ return {
1069
+ "formatted": _format_snippets(snippets, query),
1070
+ "totalResults": len(snippets),
1071
+ "resultsShared": len(snippets),
1072
+ }
1073
+
1074
+
1075
+ # ---------------------------------------------------------------------------
1076
+ # Recommendations (Semantic Scholar)
1077
+ # ---------------------------------------------------------------------------
1078
+
1079
+
1080
+ async def _op_recommend(args: dict[str, Any], limit: int) -> ToolResult:
1081
+ positive_ids = args.get("positive_ids")
1082
+ arxiv_id = _validate_arxiv_id(args)
1083
+
1084
+ if not arxiv_id and not positive_ids:
1085
+ return _error("'arxiv_id' or 'positive_ids' is required for recommend.")
1086
+
1087
+ fields = "title,externalIds,year,citationCount,tldr,venue"
1088
+
1089
+ async with httpx.AsyncClient(timeout=15) as client:
1090
+ if positive_ids and not arxiv_id:
1091
+ # Multi-paper recommendations
1092
+ pos = [_s2_paper_id(pid.strip()) for pid in positive_ids.split(",") if pid.strip()]
1093
+ neg_raw = args.get("negative_ids", "")
1094
+ neg = [_s2_paper_id(pid.strip()) for pid in neg_raw.split(",") if pid.strip()] if neg_raw else []
1095
+ resp = await _s2_request(
1096
+ client, "POST", "/recommendations/v1/papers/",
1097
+ json={"positivePaperIds": pos, "negativePaperIds": neg},
1098
+ params={"fields": fields, "limit": limit},
1099
+ )
1100
+ else:
1101
+ # Single-paper recommendations
1102
+ resp = await _s2_request(
1103
+ client, "GET",
1104
+ f"/recommendations/v1/papers/forpaper/{_s2_paper_id(arxiv_id)}",
1105
+ params={"fields": fields, "limit": limit, "from": "recent"},
1106
+ )
1107
+
1108
+ if not resp or resp.status_code != 200:
1109
+ return _error("Recommendation request failed. Semantic Scholar may be unavailable.")
1110
+ data = resp.json()
1111
+
1112
+ papers = data.get("recommendedPapers") or []
1113
+ if not papers:
1114
+ return {
1115
+ "formatted": "No recommendations found.",
1116
+ "totalResults": 0,
1117
+ "resultsShared": 0,
1118
+ }
1119
+
1120
+ title = f"Recommended papers based on {arxiv_id or positive_ids}"
1121
+ return {
1122
+ "formatted": _format_s2_paper_list(papers[:limit], title),
1123
+ "totalResults": len(papers),
1124
+ "resultsShared": min(limit, len(papers)),
1125
+ }
1126
+
1127
+
1128
  # ---------------------------------------------------------------------------
1129
  # Operation dispatch
1130
  # ---------------------------------------------------------------------------
 
1134
  "search": _op_search,
1135
  "paper_details": _op_paper_details,
1136
  "read_paper": _op_read_paper,
1137
+ "citation_graph": _op_citation_graph,
1138
+ "snippet_search": _op_snippet_search,
1139
+ "recommend": _op_recommend,
1140
  "find_datasets": _op_find_datasets,
1141
  "find_models": _op_find_models,
1142
  "find_collections": _op_find_collections,
 
1151
  HF_PAPERS_TOOL_SPEC = {
1152
  "name": "hf_papers",
1153
  "description": (
1154
+ "Discover ML research papers, analyze citations, search paper contents, and find linked resources.\n\n"
1155
+ "Combines HuggingFace Hub, arXiv, and Semantic Scholar. Use for exploring research areas, "
1156
+ "finding datasets for a task, tracing citation chains, or implementing a paper's approach.\n\n"
1157
+ "Typical flows:\n"
1158
+ " search → read_paper → find_all_resources → hf_inspect_dataset\n"
1159
+ " search → paper_detailscitation_graphread_paper (trace influence)\n"
1160
+ " snippet_search → paper_details → read_paper (find specific claims)\n\n"
1161
  "Operations:\n"
1162
  "- trending: Get trending daily papers, optionally filter by topic keyword\n"
1163
+ "- search: Search papers. Uses HF by default (ML-tuned). Add date_from/min_citations/categories to use Semantic Scholar with filters\n"
1164
+ "- paper_details: Metadata, abstract, AI summary, citation count, TL;DR\n"
1165
+ "- read_paper: Read paper contents — without section: abstract + TOC; with section: full text\n"
1166
+ "- citation_graph: Get references and citations for a paper with influence flags and citation intents\n"
1167
+ "- snippet_search: Semantic search over full-text passages from 12M+ papers\n"
1168
+ "- recommend: Find similar papers (single paper or positive/negative examples)\n"
1169
  "- find_datasets: Find datasets linked to a paper\n"
1170
  "- find_models: Find models linked to a paper\n"
1171
  "- find_collections: Find collections that include a paper\n"
1172
+ "- find_all_resources: Parallel fetch of datasets + models + collections for a paper"
1173
  ),
1174
  "parameters": {
1175
  "type": "object",
 
1182
  "query": {
1183
  "type": "string",
1184
  "description": (
1185
+ "Search query. Required for: search, snippet_search. "
1186
+ "Optional for: trending (filters by keyword). "
1187
+ "Supports boolean syntax for Semantic Scholar: '\"exact phrase\" term1 | term2'."
1188
  ),
1189
  },
1190
  "arxiv_id": {
1191
  "type": "string",
1192
  "description": (
1193
  "ArXiv paper ID (e.g. '2305.18290'). "
1194
+ "Required for: paper_details, read_paper, citation_graph, find_datasets, find_models, find_collections, find_all_resources. "
1195
+ "Optional for: recommend (single-paper recs). Get IDs from search results first."
1196
  ),
1197
  },
1198
  "section": {
1199
  "type": "string",
1200
  "description": (
1201
  "Section name or number to read (e.g. '3', 'Experiments', '4.2'). "
1202
+ "Optional for: read_paper. Without this, returns abstract + TOC."
 
1203
  ),
1204
  },
1205
+ "direction": {
1206
+ "type": "string",
1207
+ "enum": ["citations", "references", "both"],
1208
+ "description": "Direction for citation_graph. Default: both.",
1209
+ },
1210
  "date": {
1211
  "type": "string",
1212
  "description": "Date in YYYY-MM-DD format. Optional for: trending (defaults to recent papers).",
1213
  },
1214
+ "date_from": {
1215
+ "type": "string",
1216
+ "description": "Start date (YYYY-MM-DD). Triggers Semantic Scholar search. For: search, snippet_search.",
1217
+ },
1218
+ "date_to": {
1219
+ "type": "string",
1220
+ "description": "End date (YYYY-MM-DD). Triggers Semantic Scholar search. For: search, snippet_search.",
1221
+ },
1222
+ "categories": {
1223
+ "type": "string",
1224
+ "description": "Field of study filter (e.g. 'Computer Science'). Triggers Semantic Scholar search.",
1225
+ },
1226
+ "min_citations": {
1227
+ "type": "integer",
1228
+ "description": "Minimum citation count filter. Triggers Semantic Scholar search.",
1229
+ },
1230
+ "sort_by": {
1231
+ "type": "string",
1232
+ "enum": ["relevance", "citationCount", "publicationDate"],
1233
+ "description": "Sort order for Semantic Scholar search. Default: relevance.",
1234
+ },
1235
+ "positive_ids": {
1236
+ "type": "string",
1237
+ "description": "Comma-separated arxiv IDs for multi-paper recommendations. For: recommend.",
1238
+ },
1239
+ "negative_ids": {
1240
+ "type": "string",
1241
+ "description": "Comma-separated arxiv IDs as negative examples. For: recommend.",
1242
+ },
1243
  "sort": {
1244
  "type": "string",
1245
  "enum": ["downloads", "likes", "trending"],
1246
  "description": (
1247
+ "Sort order for find_datasets and find_models. Default: downloads."
 
1248
  ),
1249
  },
1250
  "limit": {
agent/tools/research_tool.py CHANGED
@@ -85,12 +85,30 @@ ML moves fast — a method from 6 months ago may already be obsolete.
85
  - DPO: needs "prompt", "chosen", "rejected"
86
  - GRPO: needs "prompt" only
87
 
88
- ## Papers
89
- - `hf_papers`: Search papers, get details, find linked datasets/models
 
 
 
 
 
 
90
 
91
  ## Hub repo inspection
92
  - `hf_repo_files`: List/read files in any HF repo (model, dataset, space)
93
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  # Correct research pattern for ML tasks
95
 
96
  ```
 
85
  - DPO: needs "prompt", "chosen", "rejected"
86
  - GRPO: needs "prompt" only
87
 
88
+ ## Papers & citations
89
+ - `hf_papers(operation="search", query=...)`: Search papers (HF-tuned for ML)
90
+ - `hf_papers(operation="search", query=..., min_citations=50, sort_by="citationCount")`: Find highly-cited papers via Semantic Scholar
91
+ - `hf_papers(operation="search", query=..., date_from="2024-01-01")`: Search with date filter
92
+ - `hf_papers(operation="paper_details", arxiv_id=...)`: Metadata, citations, TL;DR
93
+ - `hf_papers(operation="citation_graph", arxiv_id=...)`: References + citations with influence flags and intents
94
+ - `hf_papers(operation="snippet_search", query=...)`: Semantic search across 12M+ full-text paper passages
95
+ - `hf_papers(operation="recommend", arxiv_id=...)`: Find related papers
96
 
97
  ## Hub repo inspection
98
  - `hf_repo_files`: List/read files in any HF repo (model, dataset, space)
99
 
100
+ # Paper analysis checklist
101
+
102
+ When reading a paper, always extract:
103
+ - **Key claims**: What does the paper propose or demonstrate?
104
+ - **Methodology**: Architecture, training setup, key techniques
105
+ - **Results**: Benchmark numbers, comparisons to baselines
106
+ - **Limitations**: What the authors acknowledge or what seems missing
107
+
108
+ Use `citation_graph` to trace influence: check what a breakthrough paper cites (foundations)
109
+ and who cites it (impact and extensions). Use `snippet_search` to verify claims across
110
+ papers (e.g., "does method X consistently outperform Y?").
111
+
112
  # Correct research pattern for ML tasks
113
 
114
  ```
frontend/src/components/Chat/ActivityStatusBar.tsx CHANGED
@@ -34,9 +34,15 @@ function formatResearchStatus(raw: string): string {
34
  if (typeof v === 'string') args[k] = v;
35
  }
36
  } catch {
 
37
  for (const m of jsonStr.matchAll(/"(\w+)":\s*"([^"]*)"/g)) {
38
  args[m[1]] = m[2];
39
  }
 
 
 
 
 
40
  }
41
  }
42
 
@@ -62,12 +68,15 @@ function formatResearchStatus(raw: string): string {
62
  }
63
  if (toolName === 'hf_papers') {
64
  const op = args.operation as string;
65
- const detail = (args.query) || (args.arxiv_id);
66
  const opLabels: Record<string, string> = {
67
  trending: 'Browsing trending papers',
68
  search: 'Searching papers',
69
  paper_details: 'Reading paper details',
70
  read_paper: 'Reading paper',
 
 
 
71
  find_datasets: 'Finding paper datasets',
72
  find_models: 'Finding paper models',
73
  find_collections: 'Finding paper collections',
 
34
  if (typeof v === 'string') args[k] = v;
35
  }
36
  } catch {
37
+ // JSON is likely truncated — extract complete "key": "value" pairs
38
  for (const m of jsonStr.matchAll(/"(\w+)":\s*"([^"]*)"/g)) {
39
  args[m[1]] = m[2];
40
  }
41
+ // Also try to extract a truncated value for known keys if not found yet
42
+ if (!args.query && !args.arxiv_id) {
43
+ const partial = jsonStr.match(/"(query|arxiv_id)":\s*"([^"]*)/);
44
+ if (partial) args[partial[1]] = partial[2];
45
+ }
46
  }
47
  }
48
 
 
68
  }
69
  if (toolName === 'hf_papers') {
70
  const op = args.operation as string;
71
+ const detail = (args.query) || (args.arxiv_id) || (args.positive_ids);
72
  const opLabels: Record<string, string> = {
73
  trending: 'Browsing trending papers',
74
  search: 'Searching papers',
75
  paper_details: 'Reading paper details',
76
  read_paper: 'Reading paper',
77
+ citation_graph: 'Tracing citations',
78
+ snippet_search: 'Searching paper passages',
79
+ recommend: 'Finding similar papers',
80
  find_datasets: 'Finding paper datasets',
81
  find_models: 'Finding paper models',
82
  find_collections: 'Finding paper collections',