from __future__ import annotations import json import os import re from pathlib import Path from typing import Any from urllib.error import HTTPError, URLError from urllib.parse import urlencode from urllib.request import Request, urlopen DEFAULT_LIMIT = 20 DEFAULT_TIMEOUT_SEC = 10 MAX_API_LIMIT = 100 MAX_PAGES = 10 MAX_TOTAL_FETCH = 500 MAX_QUERY_LENGTH = 300 BASE_API_URL = "https://huggingface.co/api" DATE_RE = re.compile(r"^\d{4}-\d{2}-\d{2}$") WEEK_RE = re.compile(r"^\d{4}-W\d{2}$") MONTH_RE = re.compile(r"^\d{4}-\d{2}$") SUBMITTER_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,38}$") ALLOWED_SORTS = {"publishedAt", "trending"} def _load_token() -> str | None: # Check for request-scoped token first (when running as MCP server) try: from fast_agent.mcp.auth.context import request_bearer_token ctx_token = request_bearer_token.get() if ctx_token: return ctx_token except ImportError: pass # Fall back to HF_TOKEN environment variable token = os.getenv("HF_TOKEN") if token: return token # Fall back to cached huggingface token file token_path = Path.home() / ".cache" / "huggingface" / "token" if token_path.exists(): token_value = token_path.read_text(encoding="utf-8").strip() return token_value or None return None def _max_results_from_env() -> int: raw = os.getenv("HF_MAX_RESULTS") if not raw: return DEFAULT_LIMIT try: value = int(raw) except ValueError: return DEFAULT_LIMIT return value if value > 0 else DEFAULT_LIMIT def _timeout_from_env() -> int: raw = os.getenv("HF_TIMEOUT_SEC") if not raw: return DEFAULT_TIMEOUT_SEC try: value = int(raw) except ValueError: return DEFAULT_TIMEOUT_SEC if value <= 0: return DEFAULT_TIMEOUT_SEC return min(value, DEFAULT_TIMEOUT_SEC) def _coerce_int(name: str, value: int | None, *, default: int) -> int: if value is None: return default try: resolved = int(value) except (TypeError, ValueError) as exc: raise ValueError(f"{name} must be an integer.") from exc return resolved def _normalize_date_param(name: str, value: str | None, pattern: re.Pattern[str]) -> str | None: if not value: return None cleaned = value.strip() if not cleaned: return None if not pattern.match(cleaned): raise ValueError(f"{name} must match {pattern.pattern}.") return cleaned def _normalize_submitter(value: str | None) -> str | None: if not value: return None cleaned = value.strip() if not cleaned: return None if not SUBMITTER_RE.match(cleaned): raise ValueError("submitter must be a valid HF username.") return cleaned def _normalize_sort(value: str | None) -> str | None: if not value: return None cleaned = value.strip() if cleaned not in ALLOWED_SORTS: allowed = ", ".join(sorted(ALLOWED_SORTS)) raise ValueError(f"sort must be one of: {allowed}.") return cleaned def _normalize_query(value: str | None) -> str | None: if value is None: return None cleaned = value.strip() if not cleaned: return None return cleaned[:MAX_QUERY_LENGTH] def _build_url(params: dict[str, Any]) -> str: query = urlencode({k: v for k, v in params.items() if v is not None}, doseq=True) return f"{BASE_API_URL}/daily_papers?{query}" if query else f"{BASE_API_URL}/daily_papers" def _request_json(url: str) -> list[dict[str, Any]]: headers = {"Accept": "application/json"} token = _load_token() if token: headers["Authorization"] = f"Bearer {token}" request = Request(url, headers=headers, method="GET") try: with urlopen(request, timeout=_timeout_from_env()) as response: raw = response.read() except HTTPError as exc: error_body = exc.read().decode("utf-8", errors="replace") raise RuntimeError(f"HF API error {exc.code} for {url}: {error_body}") from exc except URLError as exc: raise RuntimeError(f"HF API request failed for {url}: {exc}") from exc payload = json.loads(raw) if not isinstance(payload, list): raise RuntimeError("Unexpected response shape from /api/daily_papers") return payload def _extract_search_blob(item: dict[str, Any]) -> str: paper = item.get("paper") or {} authors = paper.get("authors") or [] author_names = [a.get("name", "") for a in authors if isinstance(a, dict)] ai_keywords = paper.get("ai_keywords") or [] if isinstance(ai_keywords, list): ai_keywords_text = " ".join(str(k) for k in ai_keywords) else: ai_keywords_text = str(ai_keywords) parts = [ item.get("title"), item.get("summary"), paper.get("title"), paper.get("summary"), paper.get("ai_summary"), ai_keywords_text, " ".join(author_names), paper.get("id"), paper.get("projectPage"), paper.get("githubRepo"), ] text = " ".join(str(part) for part in parts if part) return text.lower() def _matches_query(item: dict[str, Any], query: str) -> bool: tokens = [t for t in re.split(r"\s+", query.strip().lower()) if t] if not tokens: return True haystack = _extract_search_blob(item) return all(token in haystack for token in tokens) def _clamp_total_fetch(pages: int, per_page: int) -> tuple[int, int]: if per_page * pages <= MAX_TOTAL_FETCH: return pages, per_page if per_page > MAX_TOTAL_FETCH: return 1, MAX_TOTAL_FETCH max_pages = max(MAX_TOTAL_FETCH // per_page, 1) return min(pages, max_pages), per_page def hf_papers_search( query: str | None = None, *, date: str | None = None, week: str | None = None, month: str | None = None, submitter: str | None = None, sort: str | None = None, limit: int | None = None, page: int | None = None, max_pages: int | None = None, api_limit: int | None = None, ) -> dict[str, Any]: """ Search Hugging Face Daily Papers with optional local filtering. Args: query: Case-insensitive keyword search across title, summary, authors, AI summary/keywords, project page, repo link, and paper id. date: ISO date (YYYY-MM-DD). week: ISO week (YYYY-Www). month: ISO month (YYYY-MM). submitter: HF username of the submitter. sort: "publishedAt" or "trending". limit: Max results to return after filtering (default 20). page: Page index for the API (default 0). max_pages: Number of pages to fetch for local filtering (default 1). api_limit: Page size for the API (default 50, max 100). Returns: dict with query metadata and list of daily paper entries. """ resolved_limit = _coerce_int("limit", limit, default=_max_results_from_env()) if resolved_limit < 1: raise ValueError("limit must be >= 1.") start_page = _coerce_int("page", page, default=0) if start_page < 0: raise ValueError("page must be >= 0.") pages_to_fetch = _coerce_int("max_pages", max_pages, default=1) if pages_to_fetch < 1: raise ValueError("max_pages must be >= 1.") pages_to_fetch = min(pages_to_fetch, MAX_PAGES) per_page = _coerce_int("api_limit", api_limit, default=50) if per_page < 1: raise ValueError("api_limit must be >= 1.") per_page = min(per_page, MAX_API_LIMIT) pages_to_fetch, per_page = _clamp_total_fetch(pages_to_fetch, per_page) normalized_query = _normalize_query(query) params_base: dict[str, Any] = { "date": _normalize_date_param("date", date, DATE_RE), "week": _normalize_date_param("week", week, WEEK_RE), "month": _normalize_date_param("month", month, MONTH_RE), "submitter": _normalize_submitter(submitter), "sort": _normalize_sort(sort), "limit": per_page, } results: list[dict[str, Any]] = [] pages_fetched = 0 for page_index in range(start_page, start_page + pages_to_fetch): params = {**params_base, "p": page_index} url = _build_url(params) payload = _request_json(url) pages_fetched += 1 if normalized_query: filtered = [item for item in payload if _matches_query(item, normalized_query)] else: filtered = payload results.extend(filtered) if len(results) >= resolved_limit: break return { "query": normalized_query, "params": { **{k: v for k, v in params_base.items() if v is not None}, "page": start_page, "max_pages": pages_fetched, "api_limit": per_page, }, "returned": min(len(results), resolved_limit), "data": results[:resolved_limit], }