""" HF Papers Tool — Discover papers, read their contents, and find linked resources. Operations: trending, search, paper_details, read_paper, find_datasets, find_models, find_collections, find_all_resources, citation_graph, snippet_search, recommend """ import asyncio import os import re import time from typing import Any import httpx from bs4 import BeautifulSoup, Tag from agent.tools.types import ToolResult HF_API = "https://huggingface.co/api" ARXIV_HTML = "https://arxiv.org/html" AR5IV_HTML = "https://ar5iv.labs.arxiv.org/html" DEFAULT_LIMIT = 10 MAX_LIMIT = 50 MAX_SUMMARY_LEN = 300 MAX_SECTION_PREVIEW_LEN = 280 MAX_SECTION_TEXT_LEN = 8000 SORT_MAP = { "downloads": "downloads", "likes": "likes", "trending": "trendingScore", } # --------------------------------------------------------------------------- # Semantic Scholar API # --------------------------------------------------------------------------- S2_API = "https://api.semanticscholar.org" S2_API_KEY = os.environ.get("S2_API_KEY") S2_HEADERS: dict[str, str] = {"x-api-key": S2_API_KEY} if S2_API_KEY else {} S2_TIMEOUT = 12 _s2_last_request: float = 0.0 # Shared response cache (survives across sessions, keyed by (path, params_tuple)) _s2_cache: dict[str, Any] = {} _S2_CACHE_MAX = 500 def _s2_paper_id(arxiv_id: str) -> str: """Convert bare arxiv ID to S2 format.""" return f"ARXIV:{arxiv_id}" def _s2_cache_key(path: str, params: dict | None) -> str: """Build a hashable cache key from path + sorted params.""" p = tuple(sorted((params or {}).items())) return f"{path}:{p}" async def _s2_request( client: httpx.AsyncClient, method: str, path: str, **kwargs: Any, ) -> httpx.Response | None: """S2 request with 2 retries on 429/5xx. Rate-limited only when using API key.""" global _s2_last_request url = f"{S2_API}{path}" kwargs.setdefault("headers", {}).update(S2_HEADERS) kwargs.setdefault("timeout", S2_TIMEOUT) for attempt in range(3): # Rate limit only when authenticated (1 req/s for search, 10 req/s for others) if S2_API_KEY: min_interval = 1.0 if "search" in path else 0.1 elapsed = time.monotonic() - _s2_last_request if elapsed < min_interval: await asyncio.sleep(min_interval - elapsed) _s2_last_request = time.monotonic() try: resp = await client.request(method, url, **kwargs) if resp.status_code == 429: if attempt < 2: await asyncio.sleep(60) continue return None if resp.status_code >= 500: if attempt < 2: await asyncio.sleep(3) continue return None return resp except (httpx.RequestError, httpx.HTTPStatusError): if attempt < 2: await asyncio.sleep(3) continue return None return None async def _s2_get_json( client: httpx.AsyncClient, path: str, params: dict | None = None, ) -> dict | None: """Cached S2 GET returning parsed JSON or None.""" key = _s2_cache_key(path, params) if key in _s2_cache: return _s2_cache[key] resp = await _s2_request(client, "GET", path, params=params or {}) if resp and resp.status_code == 200: data = resp.json() if len(_s2_cache) < _S2_CACHE_MAX: _s2_cache[key] = data return data return None async def _s2_get_paper( client: httpx.AsyncClient, arxiv_id: str, fields: str, ) -> dict | None: """Fetch a single paper from S2 by arxiv ID. Returns None on failure.""" return await _s2_get_json( client, f"/graph/v1/paper/{_s2_paper_id(arxiv_id)}", {"fields": fields}, ) # --------------------------------------------------------------------------- # HTML paper parsing # --------------------------------------------------------------------------- def _parse_paper_html(html: str) -> dict[str, Any]: """Parse arxiv HTML into structured sections. Returns: { "title": str, "abstract": str, "sections": [{"id": str, "title": str, "level": int, "text": str}], } """ soup = BeautifulSoup(html, "html.parser") # Title title_el = soup.find("h1", class_="ltx_title") title = title_el.get_text(strip=True).removeprefix("Title:") if title_el else "" # Abstract abstract_el = soup.find("div", class_="ltx_abstract") abstract = "" if abstract_el: # Skip the "Abstract" heading itself for child in abstract_el.children: if isinstance(child, Tag) and child.name in ("h6", "h2", "h3", "p", "span"): if child.get_text(strip=True).lower() == "abstract": continue if isinstance(child, Tag) and child.name == "p": abstract += child.get_text(separator=" ", strip=True) + " " abstract = abstract.strip() # Sections — collect h2/h3 headings and text between them sections: list[dict[str, Any]] = [] headings = soup.find_all(["h2", "h3"], class_=lambda c: c and "ltx_title" in c) for heading in headings: level = 2 if heading.name == "h2" else 3 heading_text = heading.get_text(separator=" ", strip=True) # Collect text from siblings until next heading of same or higher level text_parts: list[str] = [] sibling = heading.find_next_sibling() while sibling: if isinstance(sibling, Tag): if sibling.name in ("h2", "h3") and "ltx_title" in ( sibling.get("class") or [] ): break # Also stop at h2 if we're collecting h3 content if sibling.name == "h2" and level == 3: break text_parts.append(sibling.get_text(separator=" ", strip=True)) sibling = sibling.find_next_sibling() # Also check parent section element for contained paragraphs parent_section = heading.find_parent("section") if parent_section and not text_parts: for p in parent_section.find_all("p", recursive=False): text_parts.append(p.get_text(separator=" ", strip=True)) section_text = "\n\n".join(t for t in text_parts if t) # Extract section number from heading text (e.g., "4 Experiments" → "4") num_match = re.match(r"^([A-Z]?\d+(?:\.\d+)*)\s", heading_text) section_id = num_match.group(1) if num_match else "" sections.append( { "id": section_id, "title": heading_text, "level": level, "text": section_text, } ) return {"title": title, "abstract": abstract, "sections": sections} def _find_section(sections: list[dict], query: str) -> dict | None: """Find a section by number or name (fuzzy).""" query_lower = query.lower().strip() # Exact match on section number for s in sections: if s["id"] == query_lower or s["id"] == query: return s # Exact match on title for s in sections: if query_lower == s["title"].lower(): return s # Substring match on title for s in sections: if query_lower in s["title"].lower(): return s # Number prefix match (e.g., "4" matches "4.1", "4.2", etc. — return parent) for s in sections: if s["id"].startswith(query_lower + ".") or s["id"] == query_lower: return s return None # --------------------------------------------------------------------------- # Formatting helpers # --------------------------------------------------------------------------- def _clean_description(text: str) -> str: """Strip HTML card artifacts and collapse whitespace from HF API descriptions.""" text = re.sub(r"[\t]+", " ", text) text = re.sub(r"\n{2,}", "\n", text) return text.strip() def _truncate(text: str, max_len: int) -> str: if len(text) <= max_len: return text return text[:max_len] + "..." def _format_paper_list( papers: list, title: str, date: str | None = None, query: str | None = None ) -> str: lines = [f"# {title}"] if date: lines[0] += f" ({date})" if query: lines.append(f"Filtered by: '{query}'") lines.append(f"Showing {len(papers)} paper(s)\n") for i, item in enumerate(papers, 1): paper = item.get("paper", item) arxiv_id = paper.get("id", "") paper_title = paper.get("title", "Unknown") upvotes = paper.get("upvotes", 0) summary = paper.get("ai_summary") or _truncate( paper.get("summary", ""), MAX_SUMMARY_LEN ) keywords = paper.get("ai_keywords") or [] github = paper.get("githubRepo") or "" stars = paper.get("githubStars") or 0 lines.append(f"## {i}. {paper_title}") lines.append(f"**arxiv_id:** {arxiv_id} | **upvotes:** {upvotes}") lines.append(f"https://huggingface.co/papers/{arxiv_id}") if keywords: lines.append(f"**Keywords:** {', '.join(keywords[:5])}") if github: lines.append(f"**GitHub:** {github} ({stars} stars)") if summary: lines.append(f"**Summary:** {_truncate(summary, MAX_SUMMARY_LEN)}") lines.append("") return "\n".join(lines) def _format_paper_detail(paper: dict, s2_data: dict | None = None) -> str: arxiv_id = paper.get("id", "") title = paper.get("title", "Unknown") upvotes = paper.get("upvotes", 0) ai_summary = paper.get("ai_summary") or "" summary = paper.get("summary", "") keywords = paper.get("ai_keywords") or [] github = paper.get("githubRepo") or "" stars = paper.get("githubStars") or 0 authors = paper.get("authors") or [] lines = [f"# {title}"] meta_parts = [f"**arxiv_id:** {arxiv_id}", f"**upvotes:** {upvotes}"] if s2_data: cites = s2_data.get("citationCount", 0) influential = s2_data.get("influentialCitationCount", 0) meta_parts.append(f"**citations:** {cites} ({influential} influential)") lines.append(" | ".join(meta_parts)) lines.append(f"https://huggingface.co/papers/{arxiv_id}") lines.append(f"https://arxiv.org/abs/{arxiv_id}") if authors: names = [a.get("name", "") for a in authors[:10]] author_str = ", ".join(n for n in names if n) if len(authors) > 10: author_str += f" (+{len(authors) - 10} more)" lines.append(f"**Authors:** {author_str}") if keywords: lines.append(f"**Keywords:** {', '.join(keywords)}") if s2_data and s2_data.get("s2FieldsOfStudy"): fields = [f["category"] for f in s2_data["s2FieldsOfStudy"] if f.get("category")] if fields: lines.append(f"**Fields:** {', '.join(fields)}") if s2_data and s2_data.get("venue"): lines.append(f"**Venue:** {s2_data['venue']}") if github: lines.append(f"**GitHub:** {github} ({stars} stars)") if s2_data and s2_data.get("tldr"): tldr_text = s2_data["tldr"].get("text", "") if tldr_text: lines.append(f"\n## TL;DR\n{tldr_text}") if ai_summary: lines.append(f"\n## AI Summary\n{ai_summary}") if summary: lines.append(f"\n## Abstract\n{_truncate(summary, 500)}") lines.append( "\n**Next:** Use read_paper to read specific sections, find_all_resources for linked datasets/models, " "or citation_graph to trace references and citations." ) return "\n".join(lines) def _format_read_paper_toc(parsed: dict[str, Any], arxiv_id: str) -> str: """Format TOC view: abstract + section list with previews.""" lines = [f"# {parsed['title']}"] lines.append(f"https://arxiv.org/abs/{arxiv_id}\n") if parsed["abstract"]: lines.append(f"## Abstract\n{parsed['abstract']}\n") lines.append("## Sections") for s in parsed["sections"]: prefix = " " if s["level"] == 3 else "" preview = ( _truncate(s["text"], MAX_SECTION_PREVIEW_LEN) if s["text"] else "(empty)" ) lines.append(f"{prefix}- **{s['title']}**: {preview}") lines.append( '\nCall read_paper with section parameter (e.g. section="4" or section="Experiments") to read a specific section.' ) return "\n".join(lines) def _format_read_paper_section(section: dict, arxiv_id: str) -> str: """Format a single section's full text.""" lines = [f"# {section['title']}"] lines.append(f"https://arxiv.org/abs/{arxiv_id}\n") text = section["text"] if len(text) > MAX_SECTION_TEXT_LEN: text = ( text[:MAX_SECTION_TEXT_LEN] + f"\n\n... (truncated at {MAX_SECTION_TEXT_LEN} chars)" ) lines.append(text if text else "(This section has no extractable text content.)") return "\n".join(lines) def _format_datasets(datasets: list, arxiv_id: str, sort: str) -> str: lines = [f"# Datasets linked to paper {arxiv_id}"] lines.append(f"https://huggingface.co/papers/{arxiv_id}") lines.append(f"Showing {len(datasets)} dataset(s), sorted by {sort}\n") for i, ds in enumerate(datasets, 1): ds_id = ds.get("id", "unknown") downloads = ds.get("downloads", 0) likes = ds.get("likes", 0) desc = _truncate(_clean_description(ds.get("description") or ""), MAX_SUMMARY_LEN) tags = ds.get("tags") or [] interesting = [t for t in tags if not t.startswith(("arxiv:", "region:"))][:5] lines.append(f"**{i}. [{ds_id}](https://huggingface.co/datasets/{ds_id})**") lines.append(f" Downloads: {downloads:,} | Likes: {likes}") if interesting: lines.append(f" Tags: {', '.join(interesting)}") if desc: lines.append(f" {desc}") lines.append("") if datasets: top = datasets[0].get("id", "") lines.append(f'**Inspect top dataset:** hf_inspect_dataset(dataset="{top}")') return "\n".join(lines) def _format_datasets_compact(datasets: list) -> str: if not datasets: return "## Datasets\nNone found" lines = [f"## Datasets ({len(datasets)})"] for ds in datasets: lines.append( f"- **{ds.get('id', '?')}** ({ds.get('downloads', 0):,} downloads)" ) return "\n".join(lines) def _format_models(models: list, arxiv_id: str, sort: str) -> str: lines = [f"# Models linked to paper {arxiv_id}"] lines.append(f"https://huggingface.co/papers/{arxiv_id}") lines.append(f"Showing {len(models)} model(s), sorted by {sort}\n") for i, m in enumerate(models, 1): model_id = m.get("id", "unknown") downloads = m.get("downloads", 0) likes = m.get("likes", 0) pipeline = m.get("pipeline_tag") or "" library = m.get("library_name") or "" lines.append(f"**{i}. [{model_id}](https://huggingface.co/{model_id})**") meta = f" Downloads: {downloads:,} | Likes: {likes}" if pipeline: meta += f" | Task: {pipeline}" if library: meta += f" | Library: {library}" lines.append(meta) lines.append("") return "\n".join(lines) def _format_models_compact(models: list) -> str: if not models: return "## Models\nNone found" lines = [f"## Models ({len(models)})"] for m in models: pipeline = m.get("pipeline_tag") or "" suffix = f" ({pipeline})" if pipeline else "" lines.append( f"- **{m.get('id', '?')}** ({m.get('downloads', 0):,} downloads){suffix}" ) return "\n".join(lines) def _format_collections(collections: list, arxiv_id: str) -> str: lines = [f"# Collections containing paper {arxiv_id}"] lines.append(f"Showing {len(collections)} collection(s)\n") for i, c in enumerate(collections, 1): slug = c.get("slug", "") title = c.get("title", "Untitled") upvotes = c.get("upvotes", 0) owner = c.get("owner", {}).get("name", "") desc = _truncate(c.get("description") or "", MAX_SUMMARY_LEN) num_items = len(c.get("items", [])) lines.append(f"**{i}. {title}**") lines.append(f" By: {owner} | Upvotes: {upvotes} | Items: {num_items}") lines.append(f" https://huggingface.co/collections/{slug}") if desc: lines.append(f" {desc}") lines.append("") return "\n".join(lines) def _format_collections_compact(collections: list) -> str: if not collections: return "## Collections\nNone found" lines = [f"## Collections ({len(collections)})"] for c in collections: title = c.get("title", "Untitled") owner = c.get("owner", {}).get("name", "") upvotes = c.get("upvotes", 0) lines.append(f"- **{title}** by {owner} ({upvotes} upvotes)") return "\n".join(lines) # --------------------------------------------------------------------------- # Operation handlers # --------------------------------------------------------------------------- def _error(message: str) -> ToolResult: return { "formatted": message, "totalResults": 0, "resultsShared": 0, "isError": True, } def _validate_arxiv_id(args: dict) -> str | None: """Return arxiv_id or None if missing.""" return args.get("arxiv_id") async def _op_trending(args: dict[str, Any], limit: int) -> ToolResult: date = args.get("date") query = args.get("query") params: dict[str, Any] = {"limit": limit if not query else max(limit * 3, 30)} if date: params["date"] = date async with httpx.AsyncClient(timeout=15) as client: resp = await client.get(f"{HF_API}/daily_papers", params=params) resp.raise_for_status() papers = resp.json() if query: q = query.lower() papers = [ p for p in papers if q in p.get("title", "").lower() or q in p.get("paper", {}).get("title", "").lower() or q in p.get("paper", {}).get("summary", "").lower() or any( q in kw.lower() for kw in (p.get("paper", {}).get("ai_keywords") or []) ) ] papers = papers[:limit] if not papers: msg = "No trending papers found" if query: msg += f" matching '{query}'" if date: msg += f" for {date}" return {"formatted": msg, "totalResults": 0, "resultsShared": 0} formatted = _format_paper_list(papers, "Trending Papers", date=date, query=query) return { "formatted": formatted, "totalResults": len(papers), "resultsShared": len(papers), } def _format_s2_paper_list(papers: list[dict], title: str) -> str: """Format a list of S2 paper results.""" lines = [f"# {title}"] lines.append(f"Showing {len(papers)} result(s)\n") for i, paper in enumerate(papers, 1): ptitle = paper.get("title") or "(untitled)" year = paper.get("year") or "?" cites = paper.get("citationCount", 0) venue = paper.get("venue") or "" ext_ids = paper.get("externalIds") or {} aid = ext_ids.get("ArXiv", "") tldr = (paper.get("tldr") or {}).get("text", "") lines.append(f"### {i}. {ptitle}") meta = [f"Year: {year}", f"Citations: {cites}"] if venue: meta.append(f"Venue: {venue}") if aid: meta.append(f"arxiv_id: {aid}") lines.append(" | ".join(meta)) if aid: lines.append(f"https://arxiv.org/abs/{aid}") if tldr: lines.append(f"**TL;DR:** {tldr}") lines.append("") lines.append("Use paper_details with arxiv_id for full info, or read_paper to read sections.") return "\n".join(lines) async def _s2_bulk_search(query: str, args: dict[str, Any], limit: int) -> ToolResult | None: """Search via S2 bulk endpoint with filters. Returns None on failure.""" params: dict[str, Any] = { "query": query, "limit": limit, "fields": "title,externalIds,year,citationCount,tldr,venue,publicationDate", } # Date filter date_from = args.get("date_from", "") date_to = args.get("date_to", "") if date_from or date_to: params["publicationDateOrYear"] = f"{date_from}:{date_to}" # Fields of study categories = args.get("categories") if categories: params["fieldsOfStudy"] = categories # Min citations min_cites = args.get("min_citations") if min_cites: params["minCitationCount"] = str(min_cites) # Sort sort_by = args.get("sort_by") if sort_by and sort_by != "relevance": params["sort"] = f"{sort_by}:desc" async with httpx.AsyncClient(timeout=15) as client: resp = await _s2_request(client, "GET", "/graph/v1/paper/search/bulk", params=params) if not resp or resp.status_code != 200: return None data = resp.json() papers = data.get("data") or [] if not papers: return { "formatted": f"No papers found for '{query}' with the given filters.", "totalResults": 0, "resultsShared": 0, } formatted = _format_s2_paper_list(papers[:limit], f"Papers matching '{query}' (Semantic Scholar)") return { "formatted": formatted, "totalResults": data.get("total", len(papers)), "resultsShared": min(limit, len(papers)), } async def _op_search(args: dict[str, Any], limit: int) -> ToolResult: query = args.get("query") if not query: return _error("'query' is required for search operation.") # Route to S2 when filters are present use_s2 = any(args.get(k) for k in ("date_from", "date_to", "categories", "min_citations", "sort_by")) if use_s2: result = await _s2_bulk_search(query, args, limit) if result is not None: return result # Fall back to HF search (without filters) if S2 fails async with httpx.AsyncClient(timeout=15) as client: resp = await client.get( f"{HF_API}/papers/search", params={"q": query, "limit": limit} ) resp.raise_for_status() papers = resp.json() if not papers: return { "formatted": f"No papers found for '{query}'", "totalResults": 0, "resultsShared": 0, } formatted = _format_paper_list(papers, f"Papers matching '{query}'") return { "formatted": formatted, "totalResults": len(papers), "resultsShared": len(papers), } async def _op_paper_details(args: dict[str, Any], limit: int) -> ToolResult: arxiv_id = _validate_arxiv_id(args) if not arxiv_id: return _error("'arxiv_id' is required for paper_details.") async with httpx.AsyncClient(timeout=15) as client: resp = await client.get(f"{HF_API}/papers/{arxiv_id}") resp.raise_for_status() paper = resp.json() return { "formatted": _format_paper_detail(paper), "totalResults": 1, "resultsShared": 1, } async def _op_read_paper(args: dict[str, Any], limit: int) -> ToolResult: arxiv_id = _validate_arxiv_id(args) if not arxiv_id: return _error("'arxiv_id' is required for read_paper.") section_query = args.get("section") # Try fetching HTML from arxiv, then ar5iv, then fallback to abstract parsed = None async with httpx.AsyncClient(timeout=30, follow_redirects=True) as client: for base_url in [ARXIV_HTML, AR5IV_HTML]: try: resp = await client.get(f"{base_url}/{arxiv_id}") if resp.status_code == 200: parsed = _parse_paper_html(resp.text) if parsed["sections"]: # Only use if we got real sections break parsed = None except httpx.RequestError: continue # Fallback: return abstract from HF API if not parsed or not parsed["sections"]: try: async with httpx.AsyncClient(timeout=15) as client: resp = await client.get(f"{HF_API}/papers/{arxiv_id}") resp.raise_for_status() paper = resp.json() abstract = paper.get("summary", "") title = paper.get("title", "") msg = f"# {title}\nhttps://arxiv.org/abs/{arxiv_id}\n\n" msg += f"## Abstract\n{abstract}\n\n" msg += "HTML version not available for this paper. Only abstract shown.\n" msg += f"PDF: https://arxiv.org/pdf/{arxiv_id}" return {"formatted": msg, "totalResults": 1, "resultsShared": 1} except Exception: return _error( f"Could not fetch paper {arxiv_id}. Check the arxiv ID is correct." ) # Return TOC or specific section if not section_query: formatted = _format_read_paper_toc(parsed, arxiv_id) return { "formatted": formatted, "totalResults": len(parsed["sections"]), "resultsShared": len(parsed["sections"]), } section = _find_section(parsed["sections"], section_query) if not section: available = "\n".join(f"- {s['title']}" for s in parsed["sections"]) return _error( f"Section '{section_query}' not found. Available sections:\n{available}" ) formatted = _format_read_paper_section(section, arxiv_id) return {"formatted": formatted, "totalResults": 1, "resultsShared": 1} # --------------------------------------------------------------------------- # Citation graph (Semantic Scholar) # --------------------------------------------------------------------------- def _format_citation_entry(entry: dict, show_context: bool = False) -> str: """Format a single citation/reference entry.""" paper = entry.get("citingPaper") or entry.get("citedPaper") or {} title = paper.get("title") or "(untitled)" year = paper.get("year") or "?" cites = paper.get("citationCount", 0) ext_ids = paper.get("externalIds") or {} aid = ext_ids.get("ArXiv", "") influential = " **[influential]**" if entry.get("isInfluential") else "" parts = [f"- **{title}** ({year}, {cites} cites){influential}"] if aid: parts[0] += f" arxiv:{aid}" if show_context: intents = entry.get("intents") or [] if intents: parts.append(f" Intent: {', '.join(intents)}") contexts = entry.get("contexts") or [] for ctx in contexts[:2]: if ctx: parts.append(f" > {_truncate(ctx, 200)}") return "\n".join(parts) def _format_citation_graph( arxiv_id: str, references: list[dict] | None, citations: list[dict] | None, ) -> str: lines = [f"# Citation Graph for {arxiv_id}"] lines.append(f"https://arxiv.org/abs/{arxiv_id}\n") if references is not None: lines.append(f"## References ({len(references)})") if references: for entry in references: lines.append(_format_citation_entry(entry)) else: lines.append("No references found.") lines.append("") if citations is not None: lines.append(f"## Citations ({len(citations)})") if citations: for entry in citations: lines.append(_format_citation_entry(entry, show_context=True)) else: lines.append("No citations found.") lines.append("") lines.append("**Tip:** Use paper_details with an arxiv_id from above to explore further.") return "\n".join(lines) async def _op_citation_graph(args: dict[str, Any], limit: int) -> ToolResult: arxiv_id = _validate_arxiv_id(args) if not arxiv_id: return _error("'arxiv_id' is required for citation_graph.") direction = args.get("direction", "both") s2_id = _s2_paper_id(arxiv_id) fields = "title,externalIds,year,citationCount,influentialCitationCount,contexts,intents,isInfluential" params = {"fields": fields, "limit": limit} async with httpx.AsyncClient(timeout=15) as client: refs, cites = None, None coros = [] if direction in ("references", "both"): coros.append(_s2_get_json(client, f"/graph/v1/paper/{s2_id}/references", params)) if direction in ("citations", "both"): coros.append(_s2_get_json(client, f"/graph/v1/paper/{s2_id}/citations", params)) results = await asyncio.gather(*coros, return_exceptions=True) idx = 0 if direction in ("references", "both"): r = results[idx] if isinstance(r, dict): refs = r.get("data", []) idx += 1 if direction in ("citations", "both"): r = results[idx] if isinstance(r, dict): cites = r.get("data", []) if refs is None and cites is None: return _error(f"Could not fetch citation data for {arxiv_id}. Paper may not be indexed by Semantic Scholar.") total = (len(refs) if refs else 0) + (len(cites) if cites else 0) return { "formatted": _format_citation_graph(arxiv_id, refs, cites), "totalResults": total, "resultsShared": total, } async def _op_find_datasets(args: dict[str, Any], limit: int) -> ToolResult: arxiv_id = _validate_arxiv_id(args) if not arxiv_id: return _error("'arxiv_id' is required for find_datasets.") sort = args.get("sort", "downloads") sort_key = SORT_MAP.get(sort, "downloads") async with httpx.AsyncClient(timeout=15) as client: resp = await client.get( f"{HF_API}/datasets", params={ "filter": f"arxiv:{arxiv_id}", "limit": limit, "sort": sort_key, "direction": -1, }, ) resp.raise_for_status() datasets = resp.json() if not datasets: return { "formatted": f"No datasets found linked to paper {arxiv_id}.\nhttps://huggingface.co/papers/{arxiv_id}", "totalResults": 0, "resultsShared": 0, } return { "formatted": _format_datasets(datasets, arxiv_id, sort), "totalResults": len(datasets), "resultsShared": len(datasets), } async def _op_find_models(args: dict[str, Any], limit: int) -> ToolResult: arxiv_id = _validate_arxiv_id(args) if not arxiv_id: return _error("'arxiv_id' is required for find_models.") sort = args.get("sort", "downloads") sort_key = SORT_MAP.get(sort, "downloads") async with httpx.AsyncClient(timeout=15) as client: resp = await client.get( f"{HF_API}/models", params={ "filter": f"arxiv:{arxiv_id}", "limit": limit, "sort": sort_key, "direction": -1, }, ) resp.raise_for_status() models = resp.json() if not models: return { "formatted": f"No models found linked to paper {arxiv_id}.\nhttps://huggingface.co/papers/{arxiv_id}", "totalResults": 0, "resultsShared": 0, } return { "formatted": _format_models(models, arxiv_id, sort), "totalResults": len(models), "resultsShared": len(models), } async def _op_find_collections(args: dict[str, Any], limit: int) -> ToolResult: arxiv_id = _validate_arxiv_id(args) if not arxiv_id: return _error("'arxiv_id' is required for find_collections.") async with httpx.AsyncClient(timeout=15) as client: resp = await client.get(f"{HF_API}/collections", params={"paper": arxiv_id}) resp.raise_for_status() collections = resp.json() if not collections: return { "formatted": f"No collections found containing paper {arxiv_id}.\nhttps://huggingface.co/papers/{arxiv_id}", "totalResults": 0, "resultsShared": 0, } collections = collections[:limit] return { "formatted": _format_collections(collections, arxiv_id), "totalResults": len(collections), "resultsShared": len(collections), } async def _op_find_all_resources(args: dict[str, Any], limit: int) -> ToolResult: arxiv_id = _validate_arxiv_id(args) if not arxiv_id: return _error("'arxiv_id' is required for find_all_resources.") per_cat = min(limit, 10) async with httpx.AsyncClient(timeout=15) as client: results = await asyncio.gather( client.get( f"{HF_API}/datasets", params={ "filter": f"arxiv:{arxiv_id}", "limit": per_cat, "sort": "downloads", "direction": -1, }, ), client.get( f"{HF_API}/models", params={ "filter": f"arxiv:{arxiv_id}", "limit": per_cat, "sort": "downloads", "direction": -1, }, ), client.get(f"{HF_API}/collections", params={"paper": arxiv_id}), return_exceptions=True, ) sections = [] total = 0 # Datasets if isinstance(results[0], Exception): sections.append(f"## Datasets\nError: {results[0]}") else: datasets = results[0].json() total += len(datasets) sections.append(_format_datasets_compact(datasets[:per_cat])) # Models if isinstance(results[1], Exception): sections.append(f"## Models\nError: {results[1]}") else: models = results[1].json() total += len(models) sections.append(_format_models_compact(models[:per_cat])) # Collections if isinstance(results[2], Exception): sections.append(f"## Collections\nError: {results[2]}") else: collections = results[2].json() total += len(collections) sections.append(_format_collections_compact(collections[:per_cat])) header = f"# Resources linked to paper {arxiv_id}\nhttps://huggingface.co/papers/{arxiv_id}\n" formatted = header + "\n\n".join(sections) return {"formatted": formatted, "totalResults": total, "resultsShared": total} # --------------------------------------------------------------------------- # Snippet search (Semantic Scholar) # --------------------------------------------------------------------------- def _format_snippets(snippets: list[dict], query: str) -> str: lines = [f"# Snippet Search: '{query}'"] lines.append(f"Found {len(snippets)} matching passage(s)\n") for i, item in enumerate(snippets, 1): paper = item.get("paper") or {} ptitle = paper.get("title") or "(untitled)" year = paper.get("year") or "?" cites = paper.get("citationCount", 0) ext_ids = paper.get("externalIds") or {} aid = ext_ids.get("ArXiv", "") snippet = item.get("snippet") or {} text = snippet.get("text", "") section = snippet.get("section") or "" lines.append(f"### {i}. {ptitle} ({year}, {cites} cites)") if aid: lines.append(f"arxiv:{aid}") if section: lines.append(f"Section: {section}") if text: lines.append(f"> {_truncate(text, 400)}") lines.append("") lines.append("Use paper_details or read_paper with arxiv_id to explore a paper further.") return "\n".join(lines) async def _op_snippet_search(args: dict[str, Any], limit: int) -> ToolResult: query = args.get("query") if not query: return _error("'query' is required for snippet_search.") params: dict[str, Any] = { "query": query, "limit": limit, "fields": "title,externalIds,year,citationCount", } # Optional filters (same as search) date_from = args.get("date_from", "") date_to = args.get("date_to", "") if date_from or date_to: params["publicationDateOrYear"] = f"{date_from}:{date_to}" if args.get("categories"): params["fieldsOfStudy"] = args["categories"] if args.get("min_citations"): params["minCitationCount"] = str(args["min_citations"]) async with httpx.AsyncClient(timeout=15) as client: resp = await _s2_request(client, "GET", "/graph/v1/snippet/search", params=params) if not resp or resp.status_code != 200: return _error("Snippet search failed. Semantic Scholar may be unavailable.") data = resp.json() snippets = data.get("data") or [] if not snippets: return { "formatted": f"No snippets found for '{query}'.", "totalResults": 0, "resultsShared": 0, } return { "formatted": _format_snippets(snippets, query), "totalResults": len(snippets), "resultsShared": len(snippets), } # --------------------------------------------------------------------------- # Recommendations (Semantic Scholar) # --------------------------------------------------------------------------- async def _op_recommend(args: dict[str, Any], limit: int) -> ToolResult: positive_ids = args.get("positive_ids") arxiv_id = _validate_arxiv_id(args) if not arxiv_id and not positive_ids: return _error("'arxiv_id' or 'positive_ids' is required for recommend.") fields = "title,externalIds,year,citationCount,tldr,venue" async with httpx.AsyncClient(timeout=15) as client: if positive_ids and not arxiv_id: # Multi-paper recommendations (POST, not cached) pos = [_s2_paper_id(pid.strip()) for pid in positive_ids.split(",") if pid.strip()] neg_raw = args.get("negative_ids", "") neg = [_s2_paper_id(pid.strip()) for pid in neg_raw.split(",") if pid.strip()] if neg_raw else [] resp = await _s2_request( client, "POST", "/recommendations/v1/papers/", json={"positivePaperIds": pos, "negativePaperIds": neg}, params={"fields": fields, "limit": limit}, ) if not resp or resp.status_code != 200: return _error("Recommendation request failed. Semantic Scholar may be unavailable.") data = resp.json() else: # Single-paper recommendations (cached) data = await _s2_get_json( client, f"/recommendations/v1/papers/forpaper/{_s2_paper_id(arxiv_id)}", {"fields": fields, "limit": limit, "from": "recent"}, ) if not data: return _error("Recommendation request failed. Semantic Scholar may be unavailable.") papers = data.get("recommendedPapers") or [] if not papers: return { "formatted": "No recommendations found.", "totalResults": 0, "resultsShared": 0, } title = f"Recommended papers based on {arxiv_id or positive_ids}" return { "formatted": _format_s2_paper_list(papers[:limit], title), "totalResults": len(papers), "resultsShared": min(limit, len(papers)), } # --------------------------------------------------------------------------- # Operation dispatch # --------------------------------------------------------------------------- _OPERATIONS = { "trending": _op_trending, "search": _op_search, "paper_details": _op_paper_details, "read_paper": _op_read_paper, "citation_graph": _op_citation_graph, "snippet_search": _op_snippet_search, "recommend": _op_recommend, "find_datasets": _op_find_datasets, "find_models": _op_find_models, "find_collections": _op_find_collections, "find_all_resources": _op_find_all_resources, } # --------------------------------------------------------------------------- # Tool spec + handler # --------------------------------------------------------------------------- HF_PAPERS_TOOL_SPEC = { "name": "hf_papers", "description": ( "Discover ML research papers, analyze citations, search paper contents, and find linked resources.\n\n" "Combines HuggingFace Hub, arXiv, and Semantic Scholar. Use for exploring research areas, " "finding datasets for a task, tracing citation chains, or implementing a paper's approach.\n\n" "Typical flows:\n" " search → read_paper → find_all_resources → hf_inspect_dataset\n" " search → paper_details → citation_graph → read_paper (trace influence)\n" " snippet_search → paper_details → read_paper (find specific claims)\n\n" "Operations:\n" "- trending: Get trending daily papers, optionally filter by topic keyword\n" "- search: Search papers. Uses HF by default (ML-tuned). Add date_from/min_citations/categories to use Semantic Scholar with filters\n" "- paper_details: Metadata, abstract, AI summary, github link\n" "- read_paper: Read paper contents — without section: abstract + TOC; with section: full text\n" "- citation_graph: Get references and citations for a paper with influence flags and citation intents\n" "- snippet_search: Semantic search over full-text passages from 12M+ papers\n" "- recommend: Find similar papers (single paper or positive/negative examples)\n" "- find_datasets: Find datasets linked to a paper\n" "- find_models: Find models linked to a paper\n" "- find_collections: Find collections that include a paper\n" "- find_all_resources: Parallel fetch of datasets + models + collections for a paper" ), "parameters": { "type": "object", "properties": { "operation": { "type": "string", "enum": list(_OPERATIONS.keys()), "description": "Operation to execute.", }, "query": { "type": "string", "description": ( "Search query. Required for: search, snippet_search. " "Optional for: trending (filters by keyword). " "Supports boolean syntax for Semantic Scholar: '\"exact phrase\" term1 | term2'." ), }, "arxiv_id": { "type": "string", "description": ( "ArXiv paper ID (e.g. '2305.18290'). " "Required for: paper_details, read_paper, citation_graph, find_datasets, find_models, find_collections, find_all_resources. " "Optional for: recommend (single-paper recs). Get IDs from search results first." ), }, "section": { "type": "string", "description": ( "Section name or number to read (e.g. '3', 'Experiments', '4.2'). " "Optional for: read_paper. Without this, returns abstract + TOC." ), }, "direction": { "type": "string", "enum": ["citations", "references", "both"], "description": "Direction for citation_graph. Default: both.", }, "date": { "type": "string", "description": "Date in YYYY-MM-DD format. Optional for: trending (defaults to recent papers).", }, "date_from": { "type": "string", "description": "Start date (YYYY-MM-DD). Triggers Semantic Scholar search. For: search, snippet_search.", }, "date_to": { "type": "string", "description": "End date (YYYY-MM-DD). Triggers Semantic Scholar search. For: search, snippet_search.", }, "categories": { "type": "string", "description": "Field of study filter (e.g. 'Computer Science'). Triggers Semantic Scholar search.", }, "min_citations": { "type": "integer", "description": "Minimum citation count filter. Triggers Semantic Scholar search.", }, "sort_by": { "type": "string", "enum": ["relevance", "citationCount", "publicationDate"], "description": "Sort order for Semantic Scholar search. Default: relevance.", }, "positive_ids": { "type": "string", "description": "Comma-separated arxiv IDs for multi-paper recommendations. For: recommend.", }, "negative_ids": { "type": "string", "description": "Comma-separated arxiv IDs as negative examples. For: recommend.", }, "sort": { "type": "string", "enum": ["downloads", "likes", "trending"], "description": ( "Sort order for find_datasets and find_models. Default: downloads." ), }, "limit": { "type": "integer", "description": "Maximum results to return (default: 10, max: 50).", }, }, "required": ["operation"], }, } async def hf_papers_handler(arguments: dict[str, Any]) -> tuple[str, bool]: """Handler for agent tool router.""" operation = arguments.get("operation") if not operation: return "'operation' parameter is required.", False handler = _OPERATIONS.get(operation) if not handler: valid = ", ".join(_OPERATIONS.keys()) return f"Unknown operation: '{operation}'. Valid: {valid}", False limit = min(arguments.get("limit", DEFAULT_LIMIT), MAX_LIMIT) try: result = await handler(arguments, limit) return result["formatted"], not result.get("isError", False) except httpx.HTTPStatusError as e: return f"API error: {e.response.status_code} — {e.response.text[:200]}", False except httpx.RequestError as e: return f"Request error: {e}", False except Exception as e: return f"Error in {operation}: {e}", False