"""
HF Papers Tool — Discover papers, read their contents, and find linked resources.

Operations: trending, search, paper_details, read_paper,
find_datasets, find_models, find_collections, find_all_resources
"""

import asyncio
import re
from typing import Any

import httpx
from bs4 import BeautifulSoup, Tag

from agent.tools.types import ToolResult

HF_API = "https://huggingface.co/api"
ARXIV_HTML = "https://arxiv.org/html"
AR5IV_HTML = "https://ar5iv.labs.arxiv.org/html"

DEFAULT_LIMIT = 10
MAX_LIMIT = 50
MAX_SUMMARY_LEN = 300
MAX_SECTION_PREVIEW_LEN = 280
MAX_SECTION_TEXT_LEN = 8000

SORT_MAP = {
    "downloads": "downloads",
    "likes": "likes",
    "trending": "trendingScore",
}

# ---------------------------------------------------------------------------
# HTML paper parsing
# ---------------------------------------------------------------------------


def _parse_paper_html(html: str) -> dict[str, Any]:
    """Parse arXiv HTML into structured sections.

    Returns:
        {
            "title": str,
            "abstract": str,
            "sections": [{"id": str, "title": str, "level": int, "text": str}],
        }
    """
    soup = BeautifulSoup(html, "html.parser")

    # Title
    title_el = soup.find("h1", class_="ltx_title")
    title = title_el.get_text(strip=True).removeprefix("Title:") if title_el else ""

    # Abstract
    abstract_el = soup.find("div", class_="ltx_abstract")
    abstract = ""
    if abstract_el:
        for child in abstract_el.children:
            # Skip the "Abstract" heading itself
            if isinstance(child, Tag) and child.name in ("h6", "h2", "h3", "p", "span"):
                if child.get_text(strip=True).lower() == "abstract":
                    continue
            if isinstance(child, Tag) and child.name == "p":
                abstract += child.get_text(separator=" ", strip=True) + " "
    abstract = abstract.strip()

    # Sections — collect h2/h3 headings and the text between them
    sections: list[dict[str, Any]] = []
    headings = soup.find_all(["h2", "h3"], class_=lambda c: c and "ltx_title" in c)
    for heading in headings:
        level = 2 if heading.name == "h2" else 3
        heading_text = heading.get_text(separator=" ", strip=True)

        # Collect text from siblings until the next heading of the same or higher level
        text_parts: list[str] = []
        sibling = heading.find_next_sibling()
        while sibling:
            if isinstance(sibling, Tag):
                if sibling.name in ("h2", "h3") and "ltx_title" in (
                    sibling.get("class") or []
                ):
                    break
                # Also stop at an h2 if we're collecting h3 content
                if sibling.name == "h2" and level == 3:
                    break
                text_parts.append(sibling.get_text(separator=" ", strip=True))
            sibling = sibling.find_next_sibling()

        # Also check the parent <section> element for contained paragraphs
        parent_section = heading.find_parent("section")
        if parent_section and not text_parts:
            for p in parent_section.find_all("p", recursive=False):
                text_parts.append(p.get_text(separator=" ", strip=True))

        section_text = "\n\n".join(t for t in text_parts if t)

        # Extract the section number from the heading text (e.g., "4 Experiments" → "4")
        num_match = re.match(r"^([A-Z]?\d+(?:\.\d+)*)\s", heading_text)
        section_id = num_match.group(1) if num_match else ""

        sections.append(
            {
                "id": section_id,
                "title": heading_text,
                "level": level,
                "text": section_text,
            }
        )

    return {"title": title, "abstract": abstract, "sections": sections}


def _find_section(sections: list[dict], query: str) -> dict | None:
    """Find a section by number or name (fuzzy)."""
    query_lower = query.lower().strip()

    # Exact match on section number
    for s in sections:
        if s["id"] == query_lower or s["id"] == query:
            return s

    # Exact match on title
    for s in sections:
        if query_lower == s["title"].lower():
            return s

    # Substring match on title
    for s in sections:
        if query_lower in s["title"].lower():
            return s

    # Number prefix match (e.g., "4" matches "4.1", "4.2", etc. — returns the
    # first matching subsection)
    for s in sections:
        if s["id"].startswith(query_lower + ".") or s["id"] == query_lower:
            return s

    return None
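
# Matching priority, illustrated on a hypothetical sections list with ids "4"
# ("4 Experiments") and "4.1" ("4.1 Setup"):
#   _find_section(sections, "4") hits the exact-id match and returns "4 Experiments";
#   _find_section(sections, "setup") falls through to the title substring match;
#   _find_section(sections, "9") matches nothing and returns None.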


# ---------------------------------------------------------------------------
# Formatting helpers
# ---------------------------------------------------------------------------


def _clean_description(text: str) -> str:
    """Collapse runs of tabs and blank lines in HF API descriptions."""
    text = re.sub(r"\t+", " ", text)
    text = re.sub(r"\n{2,}", "\n", text)
    return text.strip()


def _truncate(text: str, max_len: int) -> str:
    if len(text) <= max_len:
        return text
    return text[:max_len] + "..."


def _format_paper_list(
    papers: list, title: str, date: str | None = None, query: str | None = None
) -> str:
    lines = [f"# {title}"]
    if date:
        lines[0] += f" ({date})"
    if query:
        lines.append(f"Filtered by: '{query}'")
    lines.append(f"Showing {len(papers)} paper(s)\n")
    for i, item in enumerate(papers, 1):
        paper = item.get("paper", item)
        arxiv_id = paper.get("id", "")
        paper_title = paper.get("title", "Unknown")
        upvotes = paper.get("upvotes", 0)
        summary = paper.get("ai_summary") or _truncate(
            paper.get("summary", ""), MAX_SUMMARY_LEN
        )
        keywords = paper.get("ai_keywords") or []
        github = paper.get("githubRepo") or ""
        stars = paper.get("githubStars") or 0
        lines.append(f"## {i}. {paper_title}")
        lines.append(f"**arxiv_id:** {arxiv_id} | **upvotes:** {upvotes}")
        lines.append(f"https://huggingface.co/papers/{arxiv_id}")
        if keywords:
            lines.append(f"**Keywords:** {', '.join(keywords[:5])}")
        if github:
            lines.append(f"**GitHub:** {github} ({stars} stars)")
        if summary:
            lines.append(f"**Summary:** {_truncate(summary, MAX_SUMMARY_LEN)}")
        lines.append("")
    return "\n".join(lines)


def _format_paper_detail(paper: dict) -> str:
    arxiv_id = paper.get("id", "")
    title = paper.get("title", "Unknown")
    upvotes = paper.get("upvotes", 0)
    ai_summary = paper.get("ai_summary") or ""
    summary = paper.get("summary", "")
    keywords = paper.get("ai_keywords") or []
    github = paper.get("githubRepo") or ""
    stars = paper.get("githubStars") or 0
    authors = paper.get("authors") or []
    lines = [f"# {title}"]
    lines.append(f"**arxiv_id:** {arxiv_id} | **upvotes:** {upvotes}")
    lines.append(f"https://huggingface.co/papers/{arxiv_id}")
    lines.append(f"https://arxiv.org/abs/{arxiv_id}")
    if authors:
        names = [a.get("name", "") for a in authors[:10]]
        author_str = ", ".join(n for n in names if n)
        if len(authors) > 10:
            author_str += f" (+{len(authors) - 10} more)"
        lines.append(f"**Authors:** {author_str}")
    if keywords:
        lines.append(f"**Keywords:** {', '.join(keywords)}")
    if github:
        lines.append(f"**GitHub:** {github} ({stars} stars)")
    if ai_summary:
        lines.append(f"\n## AI Summary\n{ai_summary}")
    if summary:
        lines.append(f"\n## Abstract\n{_truncate(summary, 500)}")
    lines.append(
        "\n**Next:** Use read_paper to read specific sections, or find_all_resources "
        "to discover linked datasets/models."
    )
    return "\n".join(lines)


def _format_read_paper_toc(parsed: dict[str, Any], arxiv_id: str) -> str:
    """Format the TOC view: abstract + section list with previews."""
    lines = [f"# {parsed['title']}"]
    lines.append(f"https://arxiv.org/abs/{arxiv_id}\n")
    if parsed["abstract"]:
        lines.append(f"## Abstract\n{parsed['abstract']}\n")
    lines.append("## Sections")
    for s in parsed["sections"]:
        prefix = " " if s["level"] == 3 else ""
        preview = (
            _truncate(s["text"], MAX_SECTION_PREVIEW_LEN) if s["text"] else "(empty)"
        )
        lines.append(f"{prefix}- **{s['title']}**: {preview}")
    lines.append(
        '\nCall read_paper with the section parameter (e.g. section="4" or '
        'section="Experiments") to read a specific section.'
    )
    return "\n".join(lines)


def _format_read_paper_section(section: dict, arxiv_id: str) -> str:
    """Format a single section's full text."""
    lines = [f"# {section['title']}"]
    lines.append(f"https://arxiv.org/abs/{arxiv_id}\n")
    text = section["text"]
    if len(text) > MAX_SECTION_TEXT_LEN:
        text = (
            text[:MAX_SECTION_TEXT_LEN]
            + f"\n\n... (truncated at {MAX_SECTION_TEXT_LEN} chars)"
        )
    lines.append(text if text else "(This section has no extractable text content.)")
    return "\n".join(lines)


def _format_datasets(datasets: list, arxiv_id: str, sort: str) -> str:
    lines = [f"# Datasets linked to paper {arxiv_id}"]
    lines.append(f"https://huggingface.co/papers/{arxiv_id}")
    lines.append(f"Showing {len(datasets)} dataset(s), sorted by {sort}\n")
    for i, ds in enumerate(datasets, 1):
        ds_id = ds.get("id", "unknown")
        downloads = ds.get("downloads", 0)
        likes = ds.get("likes", 0)
        desc = _truncate(
            _clean_description(ds.get("description") or ""), MAX_SUMMARY_LEN
        )
        tags = ds.get("tags") or []
        interesting = [t for t in tags if not t.startswith(("arxiv:", "region:"))][:5]
        lines.append(f"**{i}. [{ds_id}](https://huggingface.co/datasets/{ds_id})**")
        lines.append(f" Downloads: {downloads:,} | Likes: {likes}")
        if interesting:
            lines.append(f" Tags: {', '.join(interesting)}")
        if desc:
            lines.append(f" {desc}")
        lines.append("")
    if datasets:
        top = datasets[0].get("id", "")
        lines.append(f'**Inspect top dataset:** hf_inspect_dataset(dataset="{top}")')
    return "\n".join(lines)


def _format_datasets_compact(datasets: list) -> str:
    if not datasets:
        return "## Datasets\nNone found"
    lines = [f"## Datasets ({len(datasets)})"]
    for ds in datasets:
        lines.append(
            f"- **{ds.get('id', '?')}** ({ds.get('downloads', 0):,} downloads)"
        )
    return "\n".join(lines)


def _format_models(models: list, arxiv_id: str, sort: str) -> str:
    lines = [f"# Models linked to paper {arxiv_id}"]
    lines.append(f"https://huggingface.co/papers/{arxiv_id}")
    lines.append(f"Showing {len(models)} model(s), sorted by {sort}\n")
    for i, m in enumerate(models, 1):
        model_id = m.get("id", "unknown")
        downloads = m.get("downloads", 0)
        likes = m.get("likes", 0)
        pipeline = m.get("pipeline_tag") or ""
        library = m.get("library_name") or ""
        lines.append(f"**{i}. [{model_id}](https://huggingface.co/{model_id})**")
        meta = f" Downloads: {downloads:,} | Likes: {likes}"
        if pipeline:
            meta += f" | Task: {pipeline}"
        if library:
            meta += f" | Library: {library}"
        lines.append(meta)
        lines.append("")
    return "\n".join(lines)


def _format_models_compact(models: list) -> str:
    if not models:
        return "## Models\nNone found"
    lines = [f"## Models ({len(models)})"]
    for m in models:
        pipeline = m.get("pipeline_tag") or ""
        suffix = f" ({pipeline})" if pipeline else ""
        lines.append(
            f"- **{m.get('id', '?')}** ({m.get('downloads', 0):,} downloads){suffix}"
        )
    return "\n".join(lines)


def _format_collections(collections: list, arxiv_id: str) -> str:
    lines = [f"# Collections containing paper {arxiv_id}"]
    lines.append(f"Showing {len(collections)} collection(s)\n")
    for i, c in enumerate(collections, 1):
        slug = c.get("slug", "")
        title = c.get("title", "Untitled")
        upvotes = c.get("upvotes", 0)
        owner = c.get("owner", {}).get("name", "")
        desc = _truncate(c.get("description") or "", MAX_SUMMARY_LEN)
        num_items = len(c.get("items", []))
        lines.append(f"**{i}. {title}**")
        lines.append(f" By: {owner} | Upvotes: {upvotes} | Items: {num_items}")
        lines.append(f" https://huggingface.co/collections/{slug}")
        if desc:
            lines.append(f" {desc}")
        lines.append("")
    return "\n".join(lines)


def _format_collections_compact(collections: list) -> str:
    if not collections:
        return "## Collections\nNone found"
    lines = [f"## Collections ({len(collections)})"]
    for c in collections:
        title = c.get("title", "Untitled")
        owner = c.get("owner", {}).get("name", "")
        upvotes = c.get("upvotes", 0)
        lines.append(f"- **{title}** by {owner} ({upvotes} upvotes)")
    return "\n".join(lines)


# ---------------------------------------------------------------------------
# Operation handlers
# ---------------------------------------------------------------------------


def _error(message: str) -> ToolResult:
    return {
        "formatted": message,
        "totalResults": 0,
        "resultsShared": 0,
        "isError": True,
    }


def _validate_arxiv_id(args: dict) -> str | None:
    """Return the arxiv_id argument, or None if missing."""
    return args.get("arxiv_id")


async def _op_trending(args: dict[str, Any], limit: int) -> ToolResult:
    date = args.get("date")
    query = args.get("query")
    # When filtering client-side by query, over-fetch so enough papers survive the filter.
    params: dict[str, Any] = {"limit": limit if not query else max(limit * 3, 30)}
    if date:
        params["date"] = date
    async with httpx.AsyncClient(timeout=15) as client:
        resp = await client.get(f"{HF_API}/daily_papers", params=params)
        resp.raise_for_status()
        papers = resp.json()
    if query:
        q = query.lower()
        papers = [
            p
            for p in papers
            if q in p.get("title", "").lower()
            or q in p.get("paper", {}).get("title", "").lower()
            or q in p.get("paper", {}).get("summary", "").lower()
            or any(
                q in kw.lower() for kw in (p.get("paper", {}).get("ai_keywords") or [])
            )
        ]
    papers = papers[:limit]
    if not papers:
        msg = "No trending papers found"
        if query:
            msg += f" matching '{query}'"
        if date:
            msg += f" for {date}"
        return {"formatted": msg, "totalResults": 0, "resultsShared": 0}
    formatted = _format_paper_list(papers, "Trending Papers", date=date, query=query)
    return {
        "formatted": formatted,
        "totalResults": len(papers),
        "resultsShared": len(papers),
    }
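
# Example (hypothetical call): _op_trending({"query": "diffusion"}, 10) requests up
# to max(10 * 3, 30) = 30 daily papers, keeps only those mentioning "diffusion" in a
# title, summary, or AI keyword, and returns at most 10 of the survivors.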


async def _op_search(args: dict[str, Any], limit: int) -> ToolResult:
    query = args.get("query")
    if not query:
        return _error("'query' is required for the search operation.")
    async with httpx.AsyncClient(timeout=15) as client:
        resp = await client.get(
            f"{HF_API}/papers/search", params={"q": query, "limit": limit}
        )
        resp.raise_for_status()
        papers = resp.json()
    if not papers:
        return {
            "formatted": f"No papers found for '{query}'",
            "totalResults": 0,
            "resultsShared": 0,
        }
    formatted = _format_paper_list(papers, f"Papers matching '{query}'")
    return {
        "formatted": formatted,
        "totalResults": len(papers),
        "resultsShared": len(papers),
    }


async def _op_paper_details(args: dict[str, Any], limit: int) -> ToolResult:
    arxiv_id = _validate_arxiv_id(args)
    if not arxiv_id:
        return _error("'arxiv_id' is required for paper_details.")
    async with httpx.AsyncClient(timeout=15) as client:
        resp = await client.get(f"{HF_API}/papers/{arxiv_id}")
        resp.raise_for_status()
        paper = resp.json()
    return {
        "formatted": _format_paper_detail(paper),
        "totalResults": 1,
        "resultsShared": 1,
    }


async def _op_read_paper(args: dict[str, Any], limit: int) -> ToolResult:
    arxiv_id = _validate_arxiv_id(args)
    if not arxiv_id:
        return _error("'arxiv_id' is required for read_paper.")
    section_query = args.get("section")

    # Try fetching HTML from arxiv, then ar5iv, then fall back to the abstract
    parsed = None
    async with httpx.AsyncClient(timeout=30, follow_redirects=True) as client:
        for base_url in [ARXIV_HTML, AR5IV_HTML]:
            try:
                resp = await client.get(f"{base_url}/{arxiv_id}")
                if resp.status_code == 200:
                    parsed = _parse_paper_html(resp.text)
                    if parsed["sections"]:  # Only use if we got real sections
                        break
                    parsed = None
            except httpx.RequestError:
                continue

    # Fallback: return the abstract from the HF API
    if not parsed or not parsed["sections"]:
        try:
            async with httpx.AsyncClient(timeout=15) as client:
                resp = await client.get(f"{HF_API}/papers/{arxiv_id}")
                resp.raise_for_status()
                paper = resp.json()
            abstract = paper.get("summary", "")
            title = paper.get("title", "")
            msg = f"# {title}\nhttps://arxiv.org/abs/{arxiv_id}\n\n"
            msg += f"## Abstract\n{abstract}\n\n"
            msg += "HTML version not available for this paper. Only abstract shown.\n"
            msg += f"PDF: https://arxiv.org/pdf/{arxiv_id}"
            return {"formatted": msg, "totalResults": 1, "resultsShared": 1}
        except Exception:
            return _error(
                f"Could not fetch paper {arxiv_id}. Check the arxiv ID is correct."
            )

    # Return the TOC or a specific section
    if not section_query:
        formatted = _format_read_paper_toc(parsed, arxiv_id)
        return {
            "formatted": formatted,
            "totalResults": len(parsed["sections"]),
            "resultsShared": len(parsed["sections"]),
        }

    section = _find_section(parsed["sections"], section_query)
    if not section:
        available = "\n".join(f"- {s['title']}" for s in parsed["sections"])
        return _error(
            f"Section '{section_query}' not found. Available sections:\n{available}"
        )
    formatted = _format_read_paper_section(section, arxiv_id)
    return {"formatted": formatted, "totalResults": 1, "resultsShared": 1}


async def _op_find_datasets(args: dict[str, Any], limit: int) -> ToolResult:
    arxiv_id = _validate_arxiv_id(args)
    if not arxiv_id:
        return _error("'arxiv_id' is required for find_datasets.")
    sort = args.get("sort", "downloads")
    sort_key = SORT_MAP.get(sort, "downloads")
    async with httpx.AsyncClient(timeout=15) as client:
        resp = await client.get(
            f"{HF_API}/datasets",
            params={
                "filter": f"arxiv:{arxiv_id}",
                "limit": limit,
                "sort": sort_key,
                "direction": -1,
            },
        )
        resp.raise_for_status()
        datasets = resp.json()
    if not datasets:
        return {
            "formatted": (
                f"No datasets found linked to paper {arxiv_id}.\n"
                f"https://huggingface.co/papers/{arxiv_id}"
            ),
            "totalResults": 0,
            "resultsShared": 0,
        }
    return {
        "formatted": _format_datasets(datasets, arxiv_id, sort),
        "totalResults": len(datasets),
        "resultsShared": len(datasets),
    }


async def _op_find_models(args: dict[str, Any], limit: int) -> ToolResult:
    arxiv_id = _validate_arxiv_id(args)
    if not arxiv_id:
        return _error("'arxiv_id' is required for find_models.")
    sort = args.get("sort", "downloads")
    sort_key = SORT_MAP.get(sort, "downloads")
    async with httpx.AsyncClient(timeout=15) as client:
        resp = await client.get(
            f"{HF_API}/models",
            params={
                "filter": f"arxiv:{arxiv_id}",
                "limit": limit,
                "sort": sort_key,
                "direction": -1,
            },
        )
        resp.raise_for_status()
        models = resp.json()
    if not models:
        return {
            "formatted": (
                f"No models found linked to paper {arxiv_id}.\n"
                f"https://huggingface.co/papers/{arxiv_id}"
            ),
            "totalResults": 0,
            "resultsShared": 0,
        }
    return {
        "formatted": _format_models(models, arxiv_id, sort),
        "totalResults": len(models),
        "resultsShared": len(models),
    }


async def _op_find_collections(args: dict[str, Any], limit: int) -> ToolResult:
    arxiv_id = _validate_arxiv_id(args)
    if not arxiv_id:
        return _error("'arxiv_id' is required for find_collections.")
    async with httpx.AsyncClient(timeout=15) as client:
        resp = await client.get(f"{HF_API}/collections", params={"paper": arxiv_id})
        resp.raise_for_status()
        collections = resp.json()
    if not collections:
        return {
            "formatted": (
                f"No collections found containing paper {arxiv_id}.\n"
                f"https://huggingface.co/papers/{arxiv_id}"
            ),
            "totalResults": 0,
            "resultsShared": 0,
        }
    collections = collections[:limit]
    return {
        "formatted": _format_collections(collections, arxiv_id),
        "totalResults": len(collections),
        "resultsShared": len(collections),
    }


async def _op_find_all_resources(args: dict[str, Any], limit: int) -> ToolResult:
    arxiv_id = _validate_arxiv_id(args)
    if not arxiv_id:
        return _error("'arxiv_id' is required for find_all_resources.")
    per_cat = min(limit, 10)
    async with httpx.AsyncClient(timeout=15) as client:
        results = await asyncio.gather(
            client.get(
                f"{HF_API}/datasets",
                params={
                    "filter": f"arxiv:{arxiv_id}",
                    "limit": per_cat,
                    "sort": "downloads",
                    "direction": -1,
                },
            ),
            client.get(
                f"{HF_API}/models",
                params={
                    "filter": f"arxiv:{arxiv_id}",
                    "limit": per_cat,
                    "sort": "downloads",
                    "direction": -1,
                },
            ),
            client.get(f"{HF_API}/collections", params={"paper": arxiv_id}),
            return_exceptions=True,
        )
    sections = []
    total = 0

    # Datasets
    if isinstance(results[0], Exception):
        sections.append(f"## Datasets\nError: {results[0]}")
    else:
        datasets = results[0].json()
        total += len(datasets)
        sections.append(_format_datasets_compact(datasets[:per_cat]))

    # Models
    if isinstance(results[1], Exception):
        sections.append(f"## Models\nError: {results[1]}")
    else:
        models = results[1].json()
        total += len(models)
        sections.append(_format_models_compact(models[:per_cat]))

    # Collections
    if isinstance(results[2], Exception):
        sections.append(f"## Collections\nError: {results[2]}")
    else:
        collections = results[2].json()
        total += len(collections)
        sections.append(_format_collections_compact(collections[:per_cat]))

    header = f"# Resources linked to paper {arxiv_id}\nhttps://huggingface.co/papers/{arxiv_id}\n"
    formatted = header + "\n\n".join(sections)
    return {"formatted": formatted, "totalResults": total, "resultsShared": total}


# ---------------------------------------------------------------------------
# Operation dispatch
# ---------------------------------------------------------------------------

_OPERATIONS = {
    "trending": _op_trending,
    "search": _op_search,
    "paper_details": _op_paper_details,
    "read_paper": _op_read_paper,
    "find_datasets": _op_find_datasets,
    "find_models": _op_find_models,
    "find_collections": _op_find_collections,
    "find_all_resources": _op_find_all_resources,
}


# ---------------------------------------------------------------------------
# Tool spec + handler
# ---------------------------------------------------------------------------

HF_PAPERS_TOOL_SPEC = {
    "name": "hf_papers",
    "description": (
        "Discover ML research papers, find their linked resources (datasets, models, collections), "
        "and read paper contents on HuggingFace Hub and arXiv.\n\n"
        "Use this when exploring a research area, looking for datasets for a task, "
        "implementing a paper's approach, or trying to improve performance on something. "
        "Typical flow:\n"
        " hf_papers(search/trending) → hf_papers(read_paper) → hf_papers(find_all_resources) → hf_inspect_dataset\n\n"
        "Operations:\n"
        "- trending: Get trending daily papers, optionally filtered by a topic keyword\n"
        "- search: Full-text search for papers by query\n"
        "- paper_details: Get metadata, abstract, AI summary, and GitHub link for a paper\n"
        "- read_paper: Read paper contents — without section: returns abstract + table of contents; "
        "with section: returns full section text\n"
        "- find_datasets: Find datasets linked to a paper\n"
        "- find_models: Find models linked to a paper\n"
        "- find_collections: Find collections that include a paper\n"
        "- find_all_resources: Parallel fetch of datasets + models + collections for a paper (unified view)"
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "operation": {
                "type": "string",
                "enum": list(_OPERATIONS.keys()),
                "description": "Operation to execute.",
            },
            "query": {
                "type": "string",
                "description": (
                    "Search query. Required for: search. "
                    "Optional for: trending (filters results by keyword match on title, summary, and AI-generated keywords)."
                ),
            },
            "arxiv_id": {
                "type": "string",
                "description": (
                    "ArXiv paper ID (e.g. '2305.18290'). "
                    "Required for: paper_details, read_paper, find_datasets, find_models, find_collections, find_all_resources. "
                    "Get IDs from trending or search results first."
                ),
            },
            "section": {
                "type": "string",
                "description": (
                    "Section name or number to read (e.g. '3', 'Experiments', '4.2'). "
                    "Optional for: read_paper. Without this, read_paper returns the abstract + table of contents "
                    "so you can choose which section to read."
                ),
            },
            "date": {
                "type": "string",
                "description": "Date in YYYY-MM-DD format. Optional for: trending (defaults to recent papers).",
            },
            "sort": {
                "type": "string",
                "enum": ["downloads", "likes", "trending"],
                "description": (
                    "Sort order for find_datasets and find_models. Default: downloads. "
                    "Use 'downloads' for most-used, 'likes' for community favorites, 'trending' for recently popular."
                ),
            },
            "limit": {
                "type": "integer",
                "description": "Maximum results to return (default: 10, max: 50).",
            },
        },
        "required": ["operation"],
    },
}
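
# Example arguments payload that satisfies HF_PAPERS_TOOL_SPEC (hypothetical values):
#   {"operation": "find_models", "arxiv_id": "2305.18290", "sort": "likes", "limit": 5}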


async def hf_papers_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
    """Handler for agent tool router."""
    operation = arguments.get("operation")
    if not operation:
        return "'operation' parameter is required.", False
    handler = _OPERATIONS.get(operation)
    if not handler:
        valid = ", ".join(_OPERATIONS.keys())
        return f"Unknown operation: '{operation}'. Valid: {valid}", False
    limit = min(arguments.get("limit", DEFAULT_LIMIT), MAX_LIMIT)
    try:
        result = await handler(arguments, limit)
        return result["formatted"], not result.get("isError", False)
    except httpx.HTTPStatusError as e:
        return f"API error: {e.response.status_code} — {e.response.text[:200]}", False
    except httpx.RequestError as e:
        return f"Request error: {e}", False
    except Exception as e:
        return f"Error in {operation}: {e}", False