akseljoonas HF Staff committed on
Commit 64bf289 · Parent(s): ed56d3d

feat: add hf_papers tool for paper discovery and reading
agent/core/tools.py CHANGED
@@ -47,6 +47,7 @@ from agent.tools.hf_repo_git_tool import (
     hf_repo_git_handler,
 )
 from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, hf_jobs_handler
+from agent.tools.papers_tool import HF_PAPERS_TOOL_SPEC, hf_papers_handler
 from agent.tools.plan_tool import PLAN_TOOL_SPEC, plan_tool_handler
 from agent.tools.sandbox_tool import get_sandbox_tools
 
@@ -224,7 +225,11 @@ class ToolRouter:
 
     @observe(name="call_tool")
     async def call_tool(
-        self, tool_name: str, arguments: dict[str, Any], session: Any = None, tool_call_id: str | None = None
+        self,
+        tool_name: str,
+        arguments: dict[str, Any],
+        session: Any = None,
+        tool_call_id: str | None = None,
     ) -> tuple[str, bool]:
         """
         Call a tool and return (output_string, success_bool).
@@ -242,7 +247,9 @@ class ToolRouter:
         if "session" in sig.parameters:
             # Check if handler also accepts tool_call_id parameter
             if "tool_call_id" in sig.parameters:
-                return await tool.handler(arguments, session=session, tool_call_id=tool_call_id)
+                return await tool.handler(
+                    arguments, session=session, tool_call_id=tool_call_id
+                )
             return await tool.handler(arguments, session=session)
         return await tool.handler(arguments)
 
@@ -282,6 +289,13 @@ def create_builtin_tools() -> list[ToolSpec]:
             parameters=HF_DOCS_FETCH_TOOL_SPEC["parameters"],
             handler=hf_docs_fetch_handler,
         ),
+        # Paper discovery and reading
+        ToolSpec(
+            name=HF_PAPERS_TOOL_SPEC["name"],
+            description=HF_PAPERS_TOOL_SPEC["description"],
+            parameters=HF_PAPERS_TOOL_SPEC["parameters"],
+            handler=hf_papers_handler,
+        ),
         # Dataset inspection tool (unified)
         ToolSpec(
             name=HF_INSPECT_DATASET_TOOL_SPEC["name"],
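For context, a standalone sketch of the dispatch rule this hunk reformats (the helper name is hypothetical; the branch logic mirrors call_tool above, and the new hf_papers_handler accepts only arguments, so it is dispatched by the final branch):

    import inspect

    async def dispatch(handler, arguments, session=None, tool_call_id=None):
        # Inspect the handler's signature to decide which kwargs it can accept
        sig = inspect.signature(handler)
        if "session" in sig.parameters:
            if "tool_call_id" in sig.parameters:
                return await handler(
                    arguments, session=session, tool_call_id=tool_call_id
                )
            return await handler(arguments, session=session)
        return await handler(arguments)  # e.g. hf_papers_handler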
agent/prompts/system_prompt_v3.yaml CHANGED
@@ -16,6 +16,9 @@ system_prompt: |
 
   Skip research only for trivial non-code operations.
 
+  For open-ended research tasks (improving model performance, finding the best approach for a task, exploring a field, implementing a paper's method):
+  hf_papers(trending/search) → hf_papers(read_paper) → hf_papers(find_all_resources) → hf_inspect_dataset
+
   # Mistakes you WILL make without research
 
   HALLUCINATED IMPORTS: You will import from modules that were renamed or removed. Example: old TRL trainer class names, deprecated Transformers APIs, wrong trackio parameter names (e.g. `run_name` instead of `name`). Fix: read a current example script first.
@@ -39,6 +42,7 @@ system_prompt: |
   # When writing ML code
 
   Required sequence before any training/fine-tuning/inference script:
+  0. (When exploring approaches or finding ideas): hf_papers to discover papers, read methodology, and find linked datasets/models
   1. Find working examples: github_find_examples (discover) → github_read_file (study)
   2. Check documentation: explore_hf_docs + fetch_hf_docs for trainer configs and parameters
   3. Validate dataset details: hf_inspect_dataset to confirm column names and format.
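The research flow above maps onto concrete handler calls; a minimal sketch, assuming the module import path from this commit (2305.18290 is the DPO paper, the same ID the integration tests below use):

    import asyncio

    from agent.tools.papers_tool import hf_papers_handler

    async def research_flow():
        # 1. Discover papers on the topic
        text, ok = await hf_papers_handler(
            {"operation": "search", "query": "direct preference optimization", "limit": 3}
        )
        # 2. Read: without a section returns abstract + TOC; with one, its full text
        text, ok = await hf_papers_handler(
            {"operation": "read_paper", "arxiv_id": "2305.18290", "section": "Experiments"}
        )
        # 3. Fan out to linked datasets, models, and collections in one call
        text, ok = await hf_papers_handler(
            {"operation": "find_all_resources", "arxiv_id": "2305.18290"}
        )
        print(ok, text[:400])

    asyncio.run(research_flow())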
agent/tools/papers_tool.py ADDED
@@ -0,0 +1,813 @@
"""
HF Papers Tool - Discover papers, read their contents, and find linked resources.

Operations: trending, search, paper_details, read_paper,
            find_datasets, find_models, find_collections, find_all_resources
"""

import asyncio
import re
from typing import Any

import httpx
from bs4 import BeautifulSoup, Tag

from agent.tools.types import ToolResult

HF_API = "https://huggingface.co/api"
ARXIV_HTML = "https://arxiv.org/html"
AR5IV_HTML = "https://ar5iv.labs.arxiv.org/html"

DEFAULT_LIMIT = 10
MAX_LIMIT = 50
MAX_SUMMARY_LEN = 300
MAX_SECTION_PREVIEW_LEN = 280
MAX_SECTION_TEXT_LEN = 8000

SORT_MAP = {
    "downloads": "downloads",
    "likes": "likes",
    "trending": "trendingScore",
}


# ---------------------------------------------------------------------------
# HTML paper parsing
# ---------------------------------------------------------------------------


def _parse_paper_html(html: str) -> dict[str, Any]:
    """Parse arxiv HTML into structured sections.

    Returns:
        {
            "title": str,
            "abstract": str,
            "sections": [{"id": str, "title": str, "level": int, "text": str}],
        }
    """
    soup = BeautifulSoup(html, "html.parser")

    # Title
    title_el = soup.find("h1", class_="ltx_title")
    title = title_el.get_text(strip=True).removeprefix("Title:") if title_el else ""

    # Abstract
    abstract_el = soup.find("div", class_="ltx_abstract")
    abstract = ""
    if abstract_el:
        # Skip the "Abstract" heading itself
        for child in abstract_el.children:
            if isinstance(child, Tag) and child.name in ("h6", "h2", "h3", "p", "span"):
                if child.get_text(strip=True).lower() == "abstract":
                    continue
            if isinstance(child, Tag) and child.name == "p":
                abstract += child.get_text(separator=" ", strip=True) + " "
    abstract = abstract.strip()

    # Sections - collect h2/h3 headings and text between them
    sections: list[dict[str, Any]] = []
    headings = soup.find_all(["h2", "h3"], class_=lambda c: c and "ltx_title" in c)

    for heading in headings:
        level = 2 if heading.name == "h2" else 3
        heading_text = heading.get_text(separator=" ", strip=True)

        # Collect text from siblings until next heading of same or higher level
        text_parts: list[str] = []
        sibling = heading.find_next_sibling()
        while sibling:
            if isinstance(sibling, Tag):
                if sibling.name in ("h2", "h3") and "ltx_title" in (
                    sibling.get("class") or []
                ):
                    break
                # Also stop at h2 if we're collecting h3 content
                if sibling.name == "h2" and level == 3:
                    break
                text_parts.append(sibling.get_text(separator=" ", strip=True))
            sibling = sibling.find_next_sibling()

        # Also check parent section element for contained paragraphs
        parent_section = heading.find_parent("section")
        if parent_section and not text_parts:
            for p in parent_section.find_all("p", recursive=False):
                text_parts.append(p.get_text(separator=" ", strip=True))

        section_text = "\n\n".join(t for t in text_parts if t)

        # Extract section number from heading text (e.g., "4 Experiments" → "4")
        num_match = re.match(r"^([A-Z]?\d+(?:\.\d+)*)\s", heading_text)
        section_id = num_match.group(1) if num_match else ""

        sections.append(
            {
                "id": section_id,
                "title": heading_text,
                "level": level,
                "text": section_text,
            }
        )

    return {"title": title, "abstract": abstract, "sections": sections}


def _find_section(sections: list[dict], query: str) -> dict | None:
    """Find a section by number or name (fuzzy)."""
    query_lower = query.lower().strip()

    # Exact match on section number
    for s in sections:
        if s["id"] == query_lower or s["id"] == query:
            return s

    # Exact match on title
    for s in sections:
        if query_lower == s["title"].lower():
            return s

    # Substring match on title
    for s in sections:
        if query_lower in s["title"].lower():
            return s

    # Number prefix match (e.g., "4" matches "4.1", "4.2", etc. - return parent)
    for s in sections:
        if s["id"].startswith(query_lower + ".") or s["id"] == query_lower:
            return s

    return None


# ---------------------------------------------------------------------------
# Formatting helpers
# ---------------------------------------------------------------------------


def _truncate(text: str, max_len: int) -> str:
    if len(text) <= max_len:
        return text
    return text[:max_len] + "..."


def _format_paper_list(
    papers: list, title: str, date: str | None = None, query: str | None = None
) -> str:
    lines = [f"# {title}"]
    if date:
        lines[0] += f" ({date})"
    if query:
        lines.append(f"Filtered by: '{query}'")
    lines.append(f"Showing {len(papers)} paper(s)\n")

    for i, item in enumerate(papers, 1):
        paper = item.get("paper", item)
        arxiv_id = paper.get("id", "")
        paper_title = paper.get("title", "Unknown")
        upvotes = paper.get("upvotes", 0)
        summary = paper.get("ai_summary") or _truncate(
            paper.get("summary", ""), MAX_SUMMARY_LEN
        )
        keywords = paper.get("ai_keywords") or []
        github = paper.get("githubRepo") or ""
        stars = paper.get("githubStars") or 0

        lines.append(f"## {i}. {paper_title}")
        lines.append(f"**arxiv_id:** {arxiv_id} | **upvotes:** {upvotes}")
        lines.append(f"https://huggingface.co/papers/{arxiv_id}")
        if keywords:
            lines.append(f"**Keywords:** {', '.join(keywords[:5])}")
        if github:
            lines.append(f"**GitHub:** {github} ({stars} stars)")
        if summary:
            lines.append(f"**Summary:** {_truncate(summary, MAX_SUMMARY_LEN)}")
        lines.append("")

    return "\n".join(lines)


def _format_paper_detail(paper: dict) -> str:
    arxiv_id = paper.get("id", "")
    title = paper.get("title", "Unknown")
    upvotes = paper.get("upvotes", 0)
    ai_summary = paper.get("ai_summary") or ""
    summary = paper.get("summary", "")
    keywords = paper.get("ai_keywords") or []
    github = paper.get("githubRepo") or ""
    stars = paper.get("githubStars") or 0
    authors = paper.get("authors") or []

    lines = [f"# {title}"]
    lines.append(f"**arxiv_id:** {arxiv_id} | **upvotes:** {upvotes}")
    lines.append(f"https://huggingface.co/papers/{arxiv_id}")
    lines.append(f"https://arxiv.org/abs/{arxiv_id}")

    if authors:
        names = [a.get("name", "") for a in authors[:10]]
        author_str = ", ".join(n for n in names if n)
        if len(authors) > 10:
            author_str += f" (+{len(authors) - 10} more)"
        lines.append(f"**Authors:** {author_str}")

    if keywords:
        lines.append(f"**Keywords:** {', '.join(keywords)}")
    if github:
        lines.append(f"**GitHub:** {github} ({stars} stars)")

    if ai_summary:
        lines.append(f"\n## AI Summary\n{ai_summary}")
    if summary:
        lines.append(f"\n## Abstract\n{_truncate(summary, 500)}")

    lines.append(
        "\n**Next:** Use read_paper to read specific sections, or find_all_resources to discover linked datasets/models."
    )
    return "\n".join(lines)


def _format_read_paper_toc(parsed: dict[str, Any], arxiv_id: str) -> str:
    """Format TOC view: abstract + section list with previews."""
    lines = [f"# {parsed['title']}"]
    lines.append(f"https://arxiv.org/abs/{arxiv_id}\n")

    if parsed["abstract"]:
        lines.append(f"## Abstract\n{parsed['abstract']}\n")

    lines.append("## Sections")
    for s in parsed["sections"]:
        prefix = "  " if s["level"] == 3 else ""
        preview = (
            _truncate(s["text"], MAX_SECTION_PREVIEW_LEN) if s["text"] else "(empty)"
        )
        lines.append(f"{prefix}- **{s['title']}**: {preview}")

    lines.append(
        '\n**Tip:** Call read_paper with section parameter (e.g. section="4" or section="Experiments") to read a specific section.'
    )
    return "\n".join(lines)


def _format_read_paper_section(section: dict, arxiv_id: str) -> str:
    """Format a single section's full text."""
    lines = [f"# {section['title']}"]
    lines.append(f"https://arxiv.org/abs/{arxiv_id}\n")

    text = section["text"]
    if len(text) > MAX_SECTION_TEXT_LEN:
        text = (
            text[:MAX_SECTION_TEXT_LEN]
            + f"\n\n... (truncated at {MAX_SECTION_TEXT_LEN} chars)"
        )

    lines.append(text if text else "(This section has no extractable text content.)")
    return "\n".join(lines)


def _format_datasets(datasets: list, arxiv_id: str, sort: str) -> str:
    lines = [f"# Datasets linked to paper {arxiv_id}"]
    lines.append(f"https://huggingface.co/papers/{arxiv_id}")
    lines.append(f"Showing {len(datasets)} dataset(s), sorted by {sort}\n")

    for i, ds in enumerate(datasets, 1):
        ds_id = ds.get("id", "unknown")
        downloads = ds.get("downloads", 0)
        likes = ds.get("likes", 0)
        desc = _truncate(ds.get("description") or "", MAX_SUMMARY_LEN)
        tags = ds.get("tags") or []
        interesting = [t for t in tags if not t.startswith(("arxiv:", "region:"))][:5]

        lines.append(f"**{i}. [{ds_id}](https://huggingface.co/datasets/{ds_id})**")
        lines.append(f"  Downloads: {downloads:,} | Likes: {likes}")
        if interesting:
            lines.append(f"  Tags: {', '.join(interesting)}")
        if desc:
            lines.append(f"  {desc}")
        lines.append("")

    if datasets:
        top = datasets[0].get("id", "")
        lines.append(f'**Inspect top dataset:** hf_inspect_dataset(dataset="{top}")')
    return "\n".join(lines)


def _format_datasets_compact(datasets: list) -> str:
    if not datasets:
        return "## Datasets\nNone found"
    lines = [f"## Datasets ({len(datasets)})"]
    for ds in datasets:
        lines.append(
            f"- **{ds.get('id', '?')}** ({ds.get('downloads', 0):,} downloads)"
        )
    return "\n".join(lines)


def _format_models(models: list, arxiv_id: str, sort: str) -> str:
    lines = [f"# Models linked to paper {arxiv_id}"]
    lines.append(f"https://huggingface.co/papers/{arxiv_id}")
    lines.append(f"Showing {len(models)} model(s), sorted by {sort}\n")

    for i, m in enumerate(models, 1):
        model_id = m.get("id", "unknown")
        downloads = m.get("downloads", 0)
        likes = m.get("likes", 0)
        pipeline = m.get("pipeline_tag") or ""
        library = m.get("library_name") or ""

        lines.append(f"**{i}. [{model_id}](https://huggingface.co/{model_id})**")
        meta = f"  Downloads: {downloads:,} | Likes: {likes}"
        if pipeline:
            meta += f" | Task: {pipeline}"
        if library:
            meta += f" | Library: {library}"
        lines.append(meta)
        lines.append("")

    return "\n".join(lines)


def _format_models_compact(models: list) -> str:
    if not models:
        return "## Models\nNone found"
    lines = [f"## Models ({len(models)})"]
    for m in models:
        pipeline = m.get("pipeline_tag") or ""
        suffix = f" ({pipeline})" if pipeline else ""
        lines.append(
            f"- **{m.get('id', '?')}** ({m.get('downloads', 0):,} downloads){suffix}"
        )
    return "\n".join(lines)


def _format_collections(collections: list, arxiv_id: str) -> str:
    lines = [f"# Collections containing paper {arxiv_id}"]
    lines.append(f"Showing {len(collections)} collection(s)\n")

    for i, c in enumerate(collections, 1):
        slug = c.get("slug", "")
        title = c.get("title", "Untitled")
        upvotes = c.get("upvotes", 0)
        owner = c.get("owner", {}).get("name", "")
        desc = _truncate(c.get("description") or "", MAX_SUMMARY_LEN)
        num_items = len(c.get("items", []))

        lines.append(f"**{i}. {title}**")
        lines.append(f"  By: {owner} | Upvotes: {upvotes} | Items: {num_items}")
        lines.append(f"  https://huggingface.co/collections/{slug}")
        if desc:
            lines.append(f"  {desc}")
        lines.append("")

    return "\n".join(lines)


def _format_collections_compact(collections: list) -> str:
    if not collections:
        return "## Collections\nNone found"
    lines = [f"## Collections ({len(collections)})"]
    for c in collections:
        title = c.get("title", "Untitled")
        owner = c.get("owner", {}).get("name", "")
        upvotes = c.get("upvotes", 0)
        lines.append(f"- **{title}** by {owner} ({upvotes} upvotes)")
    return "\n".join(lines)


# ---------------------------------------------------------------------------
# Operation handlers
# ---------------------------------------------------------------------------


def _error(message: str) -> ToolResult:
    return {
        "formatted": message,
        "totalResults": 0,
        "resultsShared": 0,
        "isError": True,
    }


def _validate_arxiv_id(args: dict) -> str | None:
    """Return arxiv_id or None if missing."""
    return args.get("arxiv_id")


async def _op_trending(args: dict[str, Any], limit: int) -> ToolResult:
    date = args.get("date")
    query = args.get("query")

    params: dict[str, Any] = {"limit": limit if not query else max(limit * 3, 30)}
    if date:
        params["date"] = date

    async with httpx.AsyncClient(timeout=15) as client:
        resp = await client.get(f"{HF_API}/daily_papers", params=params)
        resp.raise_for_status()
        papers = resp.json()

    if query:
        q = query.lower()
        papers = [
            p
            for p in papers
            if q in p.get("title", "").lower()
            or q in p.get("paper", {}).get("title", "").lower()
            or q in p.get("paper", {}).get("summary", "").lower()
            or any(
                q in kw.lower() for kw in (p.get("paper", {}).get("ai_keywords") or [])
            )
        ]

    papers = papers[:limit]
    if not papers:
        msg = "No trending papers found"
        if query:
            msg += f" matching '{query}'"
        if date:
            msg += f" for {date}"
        return {"formatted": msg, "totalResults": 0, "resultsShared": 0}

    formatted = _format_paper_list(papers, "Trending Papers", date=date, query=query)
    return {
        "formatted": formatted,
        "totalResults": len(papers),
        "resultsShared": len(papers),
    }


async def _op_search(args: dict[str, Any], limit: int) -> ToolResult:
    query = args.get("query")
    if not query:
        return _error("'query' is required for search operation.")

    async with httpx.AsyncClient(timeout=15) as client:
        resp = await client.get(
            f"{HF_API}/papers/search", params={"q": query, "limit": limit}
        )
        resp.raise_for_status()
        papers = resp.json()

    if not papers:
        return {
            "formatted": f"No papers found for '{query}'",
            "totalResults": 0,
            "resultsShared": 0,
        }

    formatted = _format_paper_list(papers, f"Papers matching '{query}'")
    return {
        "formatted": formatted,
        "totalResults": len(papers),
        "resultsShared": len(papers),
    }


async def _op_paper_details(args: dict[str, Any], limit: int) -> ToolResult:
    arxiv_id = _validate_arxiv_id(args)
    if not arxiv_id:
        return _error("'arxiv_id' is required for paper_details.")

    async with httpx.AsyncClient(timeout=15) as client:
        resp = await client.get(f"{HF_API}/papers/{arxiv_id}")
        resp.raise_for_status()
        paper = resp.json()

    return {
        "formatted": _format_paper_detail(paper),
        "totalResults": 1,
        "resultsShared": 1,
    }


async def _op_read_paper(args: dict[str, Any], limit: int) -> ToolResult:
    arxiv_id = _validate_arxiv_id(args)
    if not arxiv_id:
        return _error("'arxiv_id' is required for read_paper.")

    section_query = args.get("section")

    # Try fetching HTML from arxiv, then ar5iv, then fallback to abstract
    parsed = None
    async with httpx.AsyncClient(timeout=30, follow_redirects=True) as client:
        for base_url in [ARXIV_HTML, AR5IV_HTML]:
            try:
                resp = await client.get(f"{base_url}/{arxiv_id}")
                if resp.status_code == 200:
                    parsed = _parse_paper_html(resp.text)
                    if parsed["sections"]:  # Only use if we got real sections
                        break
                    parsed = None
            except httpx.RequestError:
                continue

    # Fallback: return abstract from HF API
    if not parsed or not parsed["sections"]:
        try:
            async with httpx.AsyncClient(timeout=15) as client:
                resp = await client.get(f"{HF_API}/papers/{arxiv_id}")
                resp.raise_for_status()
                paper = resp.json()
            abstract = paper.get("summary", "")
            title = paper.get("title", "")
            msg = f"# {title}\nhttps://arxiv.org/abs/{arxiv_id}\n\n"
            msg += f"## Abstract\n{abstract}\n\n"
            msg += "HTML version not available for this paper. Only abstract shown.\n"
            msg += f"PDF: https://arxiv.org/pdf/{arxiv_id}"
            return {"formatted": msg, "totalResults": 1, "resultsShared": 1}
        except Exception:
            return _error(
                f"Could not fetch paper {arxiv_id}. Check the arxiv ID is correct."
            )

    # Return TOC or specific section
    if not section_query:
        formatted = _format_read_paper_toc(parsed, arxiv_id)
        return {
            "formatted": formatted,
            "totalResults": len(parsed["sections"]),
            "resultsShared": len(parsed["sections"]),
        }

    section = _find_section(parsed["sections"], section_query)
    if not section:
        available = "\n".join(f"- {s['title']}" for s in parsed["sections"])
        return _error(
            f"Section '{section_query}' not found. Available sections:\n{available}"
        )

    formatted = _format_read_paper_section(section, arxiv_id)
    return {"formatted": formatted, "totalResults": 1, "resultsShared": 1}


async def _op_find_datasets(args: dict[str, Any], limit: int) -> ToolResult:
    arxiv_id = _validate_arxiv_id(args)
    if not arxiv_id:
        return _error("'arxiv_id' is required for find_datasets.")

    sort = args.get("sort", "downloads")
    sort_key = SORT_MAP.get(sort, "downloads")

    async with httpx.AsyncClient(timeout=15) as client:
        resp = await client.get(
            f"{HF_API}/datasets",
            params={
                "other": f"arxiv:{arxiv_id}",
                "limit": limit,
                "sort": sort_key,
                "direction": -1,
            },
        )
        resp.raise_for_status()
        datasets = resp.json()

    if not datasets:
        return {
            "formatted": f"No datasets found linked to paper {arxiv_id}.\nhttps://huggingface.co/papers/{arxiv_id}",
            "totalResults": 0,
            "resultsShared": 0,
        }

    return {
        "formatted": _format_datasets(datasets, arxiv_id, sort),
        "totalResults": len(datasets),
        "resultsShared": len(datasets),
    }


async def _op_find_models(args: dict[str, Any], limit: int) -> ToolResult:
    arxiv_id = _validate_arxiv_id(args)
    if not arxiv_id:
        return _error("'arxiv_id' is required for find_models.")

    sort = args.get("sort", "downloads")
    sort_key = SORT_MAP.get(sort, "downloads")

    async with httpx.AsyncClient(timeout=15) as client:
        resp = await client.get(
            f"{HF_API}/models",
            params={
                "other": f"arxiv:{arxiv_id}",
                "limit": limit,
                "sort": sort_key,
                "direction": -1,
            },
        )
        resp.raise_for_status()
        models = resp.json()

    if not models:
        return {
            "formatted": f"No models found linked to paper {arxiv_id}.\nhttps://huggingface.co/papers/{arxiv_id}",
            "totalResults": 0,
            "resultsShared": 0,
        }

    return {
        "formatted": _format_models(models, arxiv_id, sort),
        "totalResults": len(models),
        "resultsShared": len(models),
    }


async def _op_find_collections(args: dict[str, Any], limit: int) -> ToolResult:
    arxiv_id = _validate_arxiv_id(args)
    if not arxiv_id:
        return _error("'arxiv_id' is required for find_collections.")

    async with httpx.AsyncClient(timeout=15) as client:
        resp = await client.get(f"{HF_API}/collections", params={"paper": arxiv_id})
        resp.raise_for_status()
        collections = resp.json()

    if not collections:
        return {
            "formatted": f"No collections found containing paper {arxiv_id}.\nhttps://huggingface.co/papers/{arxiv_id}",
            "totalResults": 0,
            "resultsShared": 0,
        }

    collections = collections[:limit]
    return {
        "formatted": _format_collections(collections, arxiv_id),
        "totalResults": len(collections),
        "resultsShared": len(collections),
    }


async def _op_find_all_resources(args: dict[str, Any], limit: int) -> ToolResult:
    arxiv_id = _validate_arxiv_id(args)
    if not arxiv_id:
        return _error("'arxiv_id' is required for find_all_resources.")

    per_cat = min(limit, 10)

    async with httpx.AsyncClient(timeout=15) as client:
        results = await asyncio.gather(
            client.get(
                f"{HF_API}/datasets",
                params={
                    "other": f"arxiv:{arxiv_id}",
                    "limit": per_cat,
                    "sort": "downloads",
                    "direction": -1,
                },
            ),
            client.get(
                f"{HF_API}/models",
                params={
                    "other": f"arxiv:{arxiv_id}",
                    "limit": per_cat,
                    "sort": "downloads",
                    "direction": -1,
                },
            ),
            client.get(f"{HF_API}/collections", params={"paper": arxiv_id}),
            return_exceptions=True,
        )

    sections = []
    total = 0

    # Datasets
    if isinstance(results[0], Exception):
        sections.append(f"## Datasets\nError: {results[0]}")
    else:
        datasets = results[0].json()
        total += len(datasets)
        sections.append(_format_datasets_compact(datasets[:per_cat]))

    # Models
    if isinstance(results[1], Exception):
        sections.append(f"## Models\nError: {results[1]}")
    else:
        models = results[1].json()
        total += len(models)
        sections.append(_format_models_compact(models[:per_cat]))

    # Collections
    if isinstance(results[2], Exception):
        sections.append(f"## Collections\nError: {results[2]}")
    else:
        collections = results[2].json()
        total += len(collections)
        sections.append(_format_collections_compact(collections[:per_cat]))

    header = f"# Resources linked to paper {arxiv_id}\nhttps://huggingface.co/papers/{arxiv_id}\n"
    formatted = header + "\n\n".join(sections)
    return {"formatted": formatted, "totalResults": total, "resultsShared": total}


# ---------------------------------------------------------------------------
# Operation dispatch
# ---------------------------------------------------------------------------

_OPERATIONS = {
    "trending": _op_trending,
    "search": _op_search,
    "paper_details": _op_paper_details,
    "read_paper": _op_read_paper,
    "find_datasets": _op_find_datasets,
    "find_models": _op_find_models,
    "find_collections": _op_find_collections,
    "find_all_resources": _op_find_all_resources,
}


# ---------------------------------------------------------------------------
# Tool spec + handler
# ---------------------------------------------------------------------------

HF_PAPERS_TOOL_SPEC = {
    "name": "hf_papers",
    "description": (
        "Discover ML research papers, find their linked resources (datasets, models, collections), "
        "and read paper contents on HuggingFace Hub and arXiv.\n\n"
        "Use this when exploring a research area, looking for datasets for a task, "
        "implementing a paper's approach, or trying to improve performance on something. "
        "Typical flow:\n"
        "  hf_papers(search/trending) → hf_papers(read_paper) → hf_papers(find_all_resources) → hf_inspect_dataset\n\n"
        "Operations:\n"
        "- trending: Get trending daily papers, optionally filter by topic keyword\n"
        "- search: Full-text search for papers by query\n"
        "- paper_details: Get metadata, abstract, AI summary, and github link for a paper\n"
        "- read_paper: Read paper contents - without section: returns abstract + table of contents; "
        "with section: returns full section text\n"
        "- find_datasets: Find datasets linked to a paper\n"
        "- find_models: Find models linked to a paper\n"
        "- find_collections: Find collections that include a paper\n"
        "- find_all_resources: Parallel fetch of datasets + models + collections for a paper (unified view)"
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "operation": {
                "type": "string",
                "enum": list(_OPERATIONS.keys()),
                "description": "Operation to execute.",
            },
            "query": {
                "type": "string",
                "description": (
                    "Search query. Required for: search. "
                    "Optional for: trending (filters results by keyword match on title, summary, and AI-generated keywords)."
                ),
            },
            "arxiv_id": {
                "type": "string",
                "description": (
                    "ArXiv paper ID (e.g. '2305.18290'). "
                    "Required for: paper_details, read_paper, find_datasets, find_models, find_collections, find_all_resources. "
                    "Get IDs from trending or search results first."
                ),
            },
            "section": {
                "type": "string",
                "description": (
                    "Section name or number to read (e.g. '3', 'Experiments', '4.2'). "
                    "Optional for: read_paper. Without this, read_paper returns the abstract + table of contents "
                    "so you can choose which section to read."
                ),
            },
            "date": {
                "type": "string",
                "description": "Date in YYYY-MM-DD format. Optional for: trending (defaults to recent papers).",
            },
            "sort": {
                "type": "string",
                "enum": ["downloads", "likes", "trending"],
                "description": (
                    "Sort order for find_datasets and find_models. Default: downloads. "
                    "Use 'downloads' for most-used, 'likes' for community favorites, 'trending' for recently popular."
                ),
            },
            "limit": {
                "type": "integer",
                "description": "Maximum results to return (default: 10, max: 50).",
            },
        },
        "required": ["operation"],
    },
}


async def hf_papers_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
    """Handler for agent tool router."""
    operation = arguments.get("operation")
    if not operation:
        return "'operation' parameter is required.", False

    handler = _OPERATIONS.get(operation)
    if not handler:
        valid = ", ".join(_OPERATIONS.keys())
        return f"Unknown operation: '{operation}'. Valid: {valid}", False

    limit = min(arguments.get("limit", DEFAULT_LIMIT), MAX_LIMIT)

    try:
        result = await handler(arguments, limit)
        return result["formatted"], not result.get("isError", False)
    except httpx.HTTPStatusError as e:
        return f"API error: {e.response.status_code} - {e.response.text[:200]}", False
    except httpx.RequestError as e:
        return f"Request error: {e}", False
    except Exception as e:
        return f"Error in {operation}: {e}", False
tests/integration/tools/test_papers_integration.py ADDED
@@ -0,0 +1,268 @@
#!/usr/bin/env python3
"""
Integration tests for HF Papers Tool
Tests with real HF and arXiv APIs - all endpoints are public, no auth required.
"""
import asyncio
import sys

sys.path.insert(0, ".")

from agent.tools.papers_tool import hf_papers_handler

# ANSI color codes
GREEN = "\033[92m"
YELLOW = "\033[93m"
RED = "\033[91m"
BLUE = "\033[94m"
RESET = "\033[0m"


def print_test(msg):
    print(f"{BLUE}[TEST]{RESET} {msg}")


def print_success(msg):
    print(f"{GREEN}✓{RESET} {msg}")


def print_warning(msg):
    print(f"{YELLOW}⚠{RESET} {msg}")


def print_error(msg):
    print(f"{RED}✗{RESET} {msg}")


def print_snippet(output, length=600):
    """Print a snippet of raw test output."""
    out = output[:length].replace("\n", "\\n")
    if len(output) > length:
        out += "..."
    print(f"{YELLOW}[RAW OUTPUT SNIPPET]{RESET} {out}")


passed = 0
failed = 0


async def run_tool(args: dict) -> tuple[str, bool]:
    """Call the handler and return (output, success)."""
    return await hf_papers_handler(args)


async def check(name: str, args: dict, *, expect_success: bool = True, expect_in: list[str] | None = None) -> str:
    """Run a tool call, validate, and track pass/fail.

    Prints a snippet of raw output of each test."""
    global passed, failed
    print_test(name)
    output, success = await run_tool(args)
    print_snippet(output)

    if success != expect_success:
        print_error(f"Expected success={expect_success}, got {success}")
        print(f"  Output: {output[:300]}")
        failed += 1
        return output

    if expect_in:
        missing = [s for s in expect_in if s.lower() not in output.lower()]
        if missing:
            print_error(f"Missing expected strings: {missing}")
            print(f"  Output: {output[:300]}")
            failed += 1
            return output

    print_success(f"OK ({len(output)} chars)")
    passed += 1
    return output


# ---------------------------------------------------------------------------
# Test suites
# ---------------------------------------------------------------------------


async def test_paper_discovery():
    print(f"\n{YELLOW}{'=' * 70}{RESET}")
    print(f"{YELLOW}Test Suite 1: Paper Discovery{RESET}")
    print(f"{YELLOW}{'=' * 70}{RESET}\n")

    # Trending papers
    output = await check(
        "trending (limit=3)",
        {"operation": "trending", "limit": 3},
        expect_in=["Trending Papers"],
    )

    # Trending with keyword filter
    await check(
        "trending with query='language'",
        {"operation": "trending", "query": "language", "limit": 5},
    )

    # Search
    await check(
        "search 'direct preference optimization'",
        {"operation": "search", "query": "direct preference optimization", "limit": 3},
        expect_in=["preference"],
    )

    # Paper details (DPO paper)
    await check(
        "paper_details for 2305.18290 (DPO paper)",
        {"operation": "paper_details", "arxiv_id": "2305.18290"},
        expect_in=["2305.18290", "Direct Preference"],
    )


async def test_read_paper():
    print(f"\n{YELLOW}{'=' * 70}{RESET}")
    print(f"{YELLOW}Test Suite 2: Read Paper{RESET}")
    print(f"{YELLOW}{'=' * 70}{RESET}\n")

    # Read paper TOC (no section specified)
    output = await check(
        "read_paper TOC for 2305.18290",
        {"operation": "read_paper", "arxiv_id": "2305.18290"},
        expect_in=["Sections", "Abstract"],
    )

    # Read specific section by number
    await check(
        "read_paper section='4' (DPO paper)",
        {"operation": "read_paper", "arxiv_id": "2305.18290", "section": "4"},
    )

    # Read specific section by name
    await check(
        "read_paper section='Experiments'",
        {"operation": "read_paper", "arxiv_id": "2305.18290", "section": "Experiments"},
    )

    # Fallback for a paper that might not have HTML
    # Using a very old paper ID - may or may not have HTML
    await check(
        "read_paper fallback (old paper 1706.03762 - Attention Is All You Need)",
        {"operation": "read_paper", "arxiv_id": "1706.03762"},
        expect_in=["Attention"],
    )


async def test_linked_resources():
    print(f"\n{YELLOW}{'=' * 70}{RESET}")
    print(f"{YELLOW}Test Suite 3: Linked Resources{RESET}")
    print(f"{YELLOW}{'=' * 70}{RESET}\n")

    # Find datasets linked to DPO paper
    await check(
        "find_datasets for 2305.18290",
        {"operation": "find_datasets", "arxiv_id": "2305.18290", "limit": 5},
    )

    # Find models linked to DPO paper
    await check(
        "find_models for 2305.18290",
        {"operation": "find_models", "arxiv_id": "2305.18290", "limit": 5},
    )

    # Find collections
    await check(
        "find_collections for 2305.18290",
        {"operation": "find_collections", "arxiv_id": "2305.18290"},
    )

    # Find all resources (parallel fan-out)
    await check(
        "find_all_resources for 2305.18290",
        {"operation": "find_all_resources", "arxiv_id": "2305.18290"},
        expect_in=["Datasets", "Models", "Collections"],
    )


async def test_edge_cases():
    print(f"\n{YELLOW}{'=' * 70}{RESET}")
    print(f"{YELLOW}Test Suite 4: Edge Cases{RESET}")
    print(f"{YELLOW}{'=' * 70}{RESET}\n")

    # Search with no results
    await check(
        "search gibberish query",
        {"operation": "search", "query": "xyzzyplugh_nonexistent_9999"},
        expect_in=["No papers found"],
    )

    # Missing required param
    await check(
        "search without query → error",
        {"operation": "search"},
        expect_success=False,
        expect_in=["required"],
    )

    # Missing arxiv_id
    await check(
        "find_datasets without arxiv_id → error",
        {"operation": "find_datasets"},
        expect_success=False,
        expect_in=["required"],
    )

    # Invalid arxiv_id
    await check(
        "paper_details with nonexistent ID",
        {"operation": "paper_details", "arxiv_id": "0000.00000"},
        expect_success=False,
    )

    # Invalid operation
    await check(
        "invalid operation → error",
        {"operation": "nonexistent_op"},
        expect_success=False,
        expect_in=["Unknown operation"],
    )

    # read_paper with nonexistent section
    await check(
        "read_paper with bad section name",
        {"operation": "read_paper", "arxiv_id": "2305.18290", "section": "Nonexistent Section XYZ"},
        expect_success=False,
        expect_in=["not found"],
    )


async def main():
    print("=" * 70)
    print(f"{BLUE}HF Papers Tool - Integration Tests{RESET}")
    print("=" * 70)
    print(f"{BLUE}All APIs are public, no authentication required.{RESET}\n")

    try:
        await test_paper_discovery()
        await test_read_paper()
        await test_linked_resources()
        await test_edge_cases()
    except Exception as e:
        print_error(f"Test suite crashed: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)

    # Summary
    print(f"\n{'=' * 70}")
    total = passed + failed
    if failed == 0:
        print(f"{GREEN}✓ All {total} tests passed!{RESET}")
    else:
        print(f"{RED}✗ {failed}/{total} tests failed{RESET}")
        print(f"{GREEN}✓ {passed}/{total} tests passed{RESET}")

    print(f"{'=' * 70}\n")

    if failed > 0:
        sys.exit(1)


if __name__ == "__main__":
    asyncio.run(main())
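A note on running these tests (the invocation is an assumption; the script resolves imports via sys.path.insert(0, "."), so it should be launched from the repo root):

    python tests/integration/tools/test_papers_integration.py

The process exits non-zero when any check fails, so the script can gate CI directly.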