akseljoonas HF Staff committed on
Commit
84d6321
·
1 Parent(s): d81613f

Minimize S2 API calls: drop paper_details enrichment, add response cache, rate limit only with API key

Browse files
Files changed (1) hide show
  1. agent/tools/papers_tool.py +59 -45
agent/tools/papers_tool.py CHANGED
@@ -43,29 +43,41 @@ S2_HEADERS: dict[str, str] = {"x-api-key": S2_API_KEY} if S2_API_KEY else {}
43
  S2_TIMEOUT = 12
44
  _s2_last_request: float = 0.0
45
 
 
 
 
 
46
 
47
  def _s2_paper_id(arxiv_id: str) -> str:
48
  """Convert bare arxiv ID to S2 format."""
49
  return f"ARXIV:{arxiv_id}"
50
 
51
 
 
 
 
 
 
 
52
  async def _s2_request(
53
  client: httpx.AsyncClient,
54
  method: str,
55
  path: str,
56
  **kwargs: Any,
57
  ) -> httpx.Response | None:
58
- """Rate-limited S2 request with 2 retries on 429/5xx."""
59
  global _s2_last_request
60
  url = f"{S2_API}{path}"
61
  kwargs.setdefault("headers", {}).update(S2_HEADERS)
62
  kwargs.setdefault("timeout", S2_TIMEOUT)
63
 
64
  for attempt in range(3):
65
- # Rate limit: 1 request per second
66
- elapsed = time.monotonic() - _s2_last_request
67
- if elapsed < 1.0:
68
- await asyncio.sleep(1.0 - elapsed)
 
 
69
  _s2_last_request = time.monotonic()
70
 
71
  try:
@@ -89,18 +101,32 @@ async def _s2_request(
89
  return None
90
 
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  async def _s2_get_paper(
93
  client: httpx.AsyncClient, arxiv_id: str, fields: str,
94
  ) -> dict | None:
95
  """Fetch a single paper from S2 by arxiv ID. Returns None on failure."""
96
- resp = await _s2_request(
97
- client, "GET",
98
  f"/graph/v1/paper/{_s2_paper_id(arxiv_id)}",
99
- params={"fields": fields},
100
  )
101
- if resp and resp.status_code == 200:
102
- return resp.json()
103
- return None
104
 
105
 
106
  # ---------------------------------------------------------------------------
@@ -651,21 +677,13 @@ async def _op_paper_details(args: dict[str, Any], limit: int) -> ToolResult:
651
  if not arxiv_id:
652
  return _error("'arxiv_id' is required for paper_details.")
653
 
654
- s2_fields = "citationCount,influentialCitationCount,tldr,s2FieldsOfStudy,venue"
655
-
656
  async with httpx.AsyncClient(timeout=15) as client:
657
- hf_coro = client.get(f"{HF_API}/papers/{arxiv_id}")
658
- s2_coro = _s2_get_paper(client, arxiv_id, s2_fields)
659
- hf_resp, s2_data = await asyncio.gather(hf_coro, s2_coro, return_exceptions=True)
660
-
661
- if isinstance(hf_resp, Exception):
662
- raise hf_resp
663
- hf_resp.raise_for_status()
664
- paper = hf_resp.json()
665
- s2 = s2_data if isinstance(s2_data, dict) else None
666
 
667
  return {
668
- "formatted": _format_paper_detail(paper, s2),
669
  "totalResults": 1,
670
  "resultsShared": 1,
671
  }
@@ -800,32 +818,27 @@ async def _op_citation_graph(args: dict[str, Any], limit: int) -> ToolResult:
800
  direction = args.get("direction", "both")
801
  s2_id = _s2_paper_id(arxiv_id)
802
  fields = "title,externalIds,year,citationCount,influentialCitationCount,contexts,intents,isInfluential"
 
803
 
804
  async with httpx.AsyncClient(timeout=15) as client:
805
  refs, cites = None, None
806
  coros = []
807
  if direction in ("references", "both"):
808
- coros.append(
809
- _s2_request(client, "GET", f"/graph/v1/paper/{s2_id}/references",
810
- params={"fields": fields, "limit": limit})
811
- )
812
  if direction in ("citations", "both"):
813
- coros.append(
814
- _s2_request(client, "GET", f"/graph/v1/paper/{s2_id}/citations",
815
- params={"fields": fields, "limit": limit})
816
- )
817
 
818
  results = await asyncio.gather(*coros, return_exceptions=True)
819
  idx = 0
820
  if direction in ("references", "both"):
821
  r = results[idx]
822
- if isinstance(r, httpx.Response) and r.status_code == 200:
823
- refs = r.json().get("data", [])
824
  idx += 1
825
  if direction in ("citations", "both"):
826
  r = results[idx]
827
- if isinstance(r, httpx.Response) and r.status_code == 200:
828
- cites = r.json().get("data", [])
829
 
830
  if refs is None and cites is None:
831
  return _error(f"Could not fetch citation data for {arxiv_id}. Paper may not be indexed by Semantic Scholar.")
@@ -1088,7 +1101,7 @@ async def _op_recommend(args: dict[str, Any], limit: int) -> ToolResult:
1088
 
1089
  async with httpx.AsyncClient(timeout=15) as client:
1090
  if positive_ids and not arxiv_id:
1091
- # Multi-paper recommendations
1092
  pos = [_s2_paper_id(pid.strip()) for pid in positive_ids.split(",") if pid.strip()]
1093
  neg_raw = args.get("negative_ids", "")
1094
  neg = [_s2_paper_id(pid.strip()) for pid in neg_raw.split(",") if pid.strip()] if neg_raw else []
@@ -1097,17 +1110,18 @@ async def _op_recommend(args: dict[str, Any], limit: int) -> ToolResult:
1097
  json={"positivePaperIds": pos, "negativePaperIds": neg},
1098
  params={"fields": fields, "limit": limit},
1099
  )
 
 
 
1100
  else:
1101
- # Single-paper recommendations
1102
- resp = await _s2_request(
1103
- client, "GET",
1104
  f"/recommendations/v1/papers/forpaper/{_s2_paper_id(arxiv_id)}",
1105
- params={"fields": fields, "limit": limit, "from": "recent"},
1106
  )
1107
-
1108
- if not resp or resp.status_code != 200:
1109
- return _error("Recommendation request failed. Semantic Scholar may be unavailable.")
1110
- data = resp.json()
1111
 
1112
  papers = data.get("recommendedPapers") or []
1113
  if not papers:
@@ -1161,7 +1175,7 @@ HF_PAPERS_TOOL_SPEC = {
1161
  "Operations:\n"
1162
  "- trending: Get trending daily papers, optionally filter by topic keyword\n"
1163
  "- search: Search papers. Uses HF by default (ML-tuned). Add date_from/min_citations/categories to use Semantic Scholar with filters\n"
1164
- "- paper_details: Metadata, abstract, AI summary, citation count, TL;DR\n"
1165
  "- read_paper: Read paper contents — without section: abstract + TOC; with section: full text\n"
1166
  "- citation_graph: Get references and citations for a paper with influence flags and citation intents\n"
1167
  "- snippet_search: Semantic search over full-text passages from 12M+ papers\n"
 
43
  S2_TIMEOUT = 12
44
  _s2_last_request: float = 0.0
45
 
46
+ # Shared response cache (survives across sessions, keyed by (path, params_tuple))
47
+ _s2_cache: dict[str, Any] = {}
48
+ _S2_CACHE_MAX = 500
49
+
50
 
51
  def _s2_paper_id(arxiv_id: str) -> str:
52
  """Convert bare arxiv ID to S2 format."""
53
  return f"ARXIV:{arxiv_id}"
54
 
55
 
56
+ def _s2_cache_key(path: str, params: dict | None) -> str:
57
+ """Build a hashable cache key from path + sorted params."""
58
+ p = tuple(sorted((params or {}).items()))
59
+ return f"{path}:{p}"
60
+
61
+
62
  async def _s2_request(
63
  client: httpx.AsyncClient,
64
  method: str,
65
  path: str,
66
  **kwargs: Any,
67
  ) -> httpx.Response | None:
68
+ """S2 request with 2 retries on 429/5xx. Rate-limited only when using API key."""
69
  global _s2_last_request
70
  url = f"{S2_API}{path}"
71
  kwargs.setdefault("headers", {}).update(S2_HEADERS)
72
  kwargs.setdefault("timeout", S2_TIMEOUT)
73
 
74
  for attempt in range(3):
75
+ # Rate limit only when authenticated (1 req/s for search, 10 req/s for others)
76
+ if S2_API_KEY:
77
+ min_interval = 1.0 if "search" in path else 0.1
78
+ elapsed = time.monotonic() - _s2_last_request
79
+ if elapsed < min_interval:
80
+ await asyncio.sleep(min_interval - elapsed)
81
  _s2_last_request = time.monotonic()
82
 
83
  try:
 
101
  return None
102
 
103
 
104
+ async def _s2_get_json(
105
+ client: httpx.AsyncClient, path: str, params: dict | None = None,
106
+ ) -> dict | None:
107
+ """Cached S2 GET returning parsed JSON or None."""
108
+ key = _s2_cache_key(path, params)
109
+ if key in _s2_cache:
110
+ return _s2_cache[key]
111
+
112
+ resp = await _s2_request(client, "GET", path, params=params or {})
113
+ if resp and resp.status_code == 200:
114
+ data = resp.json()
115
+ if len(_s2_cache) < _S2_CACHE_MAX:
116
+ _s2_cache[key] = data
117
+ return data
118
+ return None
119
+
120
+
121
  async def _s2_get_paper(
122
  client: httpx.AsyncClient, arxiv_id: str, fields: str,
123
  ) -> dict | None:
124
  """Fetch a single paper from S2 by arxiv ID. Returns None on failure."""
125
+ return await _s2_get_json(
126
+ client,
127
  f"/graph/v1/paper/{_s2_paper_id(arxiv_id)}",
128
+ {"fields": fields},
129
  )
 
 
 
130
 
131
 
132
  # ---------------------------------------------------------------------------
 
677
  if not arxiv_id:
678
  return _error("'arxiv_id' is required for paper_details.")
679
 
 
 
680
  async with httpx.AsyncClient(timeout=15) as client:
681
+ resp = await client.get(f"{HF_API}/papers/{arxiv_id}")
682
+ resp.raise_for_status()
683
+ paper = resp.json()
 
 
 
 
 
 
684
 
685
  return {
686
+ "formatted": _format_paper_detail(paper),
687
  "totalResults": 1,
688
  "resultsShared": 1,
689
  }
 
818
  direction = args.get("direction", "both")
819
  s2_id = _s2_paper_id(arxiv_id)
820
  fields = "title,externalIds,year,citationCount,influentialCitationCount,contexts,intents,isInfluential"
821
+ params = {"fields": fields, "limit": limit}
822
 
823
  async with httpx.AsyncClient(timeout=15) as client:
824
  refs, cites = None, None
825
  coros = []
826
  if direction in ("references", "both"):
827
+ coros.append(_s2_get_json(client, f"/graph/v1/paper/{s2_id}/references", params))
 
 
 
828
  if direction in ("citations", "both"):
829
+ coros.append(_s2_get_json(client, f"/graph/v1/paper/{s2_id}/citations", params))
 
 
 
830
 
831
  results = await asyncio.gather(*coros, return_exceptions=True)
832
  idx = 0
833
  if direction in ("references", "both"):
834
  r = results[idx]
835
+ if isinstance(r, dict):
836
+ refs = r.get("data", [])
837
  idx += 1
838
  if direction in ("citations", "both"):
839
  r = results[idx]
840
+ if isinstance(r, dict):
841
+ cites = r.get("data", [])
842
 
843
  if refs is None and cites is None:
844
  return _error(f"Could not fetch citation data for {arxiv_id}. Paper may not be indexed by Semantic Scholar.")
 
1101
 
1102
  async with httpx.AsyncClient(timeout=15) as client:
1103
  if positive_ids and not arxiv_id:
1104
+ # Multi-paper recommendations (POST, not cached)
1105
  pos = [_s2_paper_id(pid.strip()) for pid in positive_ids.split(",") if pid.strip()]
1106
  neg_raw = args.get("negative_ids", "")
1107
  neg = [_s2_paper_id(pid.strip()) for pid in neg_raw.split(",") if pid.strip()] if neg_raw else []
 
1110
  json={"positivePaperIds": pos, "negativePaperIds": neg},
1111
  params={"fields": fields, "limit": limit},
1112
  )
1113
+ if not resp or resp.status_code != 200:
1114
+ return _error("Recommendation request failed. Semantic Scholar may be unavailable.")
1115
+ data = resp.json()
1116
  else:
1117
+ # Single-paper recommendations (cached)
1118
+ data = await _s2_get_json(
1119
+ client,
1120
  f"/recommendations/v1/papers/forpaper/{_s2_paper_id(arxiv_id)}",
1121
+ {"fields": fields, "limit": limit, "from": "recent"},
1122
  )
1123
+ if not data:
1124
+ return _error("Recommendation request failed. Semantic Scholar may be unavailable.")
 
 
1125
 
1126
  papers = data.get("recommendedPapers") or []
1127
  if not papers:
 
1175
  "Operations:\n"
1176
  "- trending: Get trending daily papers, optionally filter by topic keyword\n"
1177
  "- search: Search papers. Uses HF by default (ML-tuned). Add date_from/min_citations/categories to use Semantic Scholar with filters\n"
1178
+ "- paper_details: Metadata, abstract, AI summary, github link\n"
1179
  "- read_paper: Read paper contents — without section: abstract + TOC; with section: full text\n"
1180
  "- citation_graph: Get references and citations for a paper with influence flags and citation intents\n"
1181
  "- snippet_search: Semantic search over full-text passages from 12M+ papers\n"