| from __future__ import annotations |
|
|
| import json |
| import os |
| import re |
| from pathlib import Path |
| from typing import Any |
| from urllib.error import HTTPError, URLError |
| from urllib.parse import urlencode |
| from urllib.request import Request, urlopen |
|
|
# Root of the Hugging Face HTTP API; endpoint paths are appended to it.
BASE_API_URL = "https://huggingface.co/api"

# Fallbacks used when the matching environment variable is unset or invalid.
DEFAULT_LIMIT = 20
DEFAULT_TIMEOUT_SEC = 10

# Upper bounds applied to the paging arguments and to the free-text query.
MAX_API_LIMIT = 100
MAX_PAGES = 10
MAX_TOTAL_FETCH = 500
MAX_QUERY_LENGTH = 300

# Accepted shapes for the date-like filters and for the submitter name.
DATE_RE = re.compile(r"^\d{4}-\d{2}-\d{2}$")
WEEK_RE = re.compile(r"^\d{4}-W\d{2}$")
MONTH_RE = re.compile(r"^\d{4}-\d{2}$")
SUBMITTER_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,38}$")

# Values accepted for the ``sort`` argument.
ALLOWED_SORTS = {
    "publishedAt",
    "trending",
}
|
|
|
|
def _load_token() -> str | None:
    """Return the first Hugging Face token that can be found, or ``None``.

    Sources are tried in this order:

    1. ``request_bearer_token.get()`` from ``fast_agent`` when that package
       is importable and the value is truthy,
    2. the ``HF_TOKEN`` environment variable,
    3. the ``~/.cache/huggingface/token`` file.

    Lookup is best effort: a missing or unreadable token file yields
    ``None`` instead of an exception.

    Returns:
        The token with surrounding whitespace removed (for the environment
        and file sources), or ``None`` when no source provides one.
    """
    try:
        from fast_agent.mcp.auth.context import request_bearer_token

        ctx_token = request_bearer_token.get()
        if ctx_token:
            return ctx_token
    except ImportError:
        # fast_agent is optional; fall through to the other sources.
        pass

    # Strip like the file source does: a stray newline in the variable
    # would otherwise end up inside the Authorization header value.
    env_token = (os.getenv("HF_TOKEN") or "").strip()
    if env_token:
        return env_token

    token_path = Path.home() / ".cache" / "huggingface" / "token"
    try:
        file_token = token_path.read_text(encoding="utf-8").strip()
    except (OSError, UnicodeDecodeError):
        # Missing, unreadable, non-file, or non-UTF-8 path: no file token.
        return None
    return file_token or None
|
|
|
|
def _max_results_from_env() -> int:
    """Return the default result cap, optionally taken from ``HF_MAX_RESULTS``.

    The environment value is used only when it parses as a positive integer;
    an unset, empty, non-numeric, or non-positive value yields
    ``DEFAULT_LIMIT``.
    """
    configured = os.getenv("HF_MAX_RESULTS")
    if configured:
        try:
            parsed = int(configured)
        except ValueError:
            parsed = 0  # handled like any other non-positive value
        if parsed > 0:
            return parsed
    return DEFAULT_LIMIT
|
|
|
|
def _timeout_from_env() -> int:
    """Return the request timeout in seconds, honouring ``HF_TIMEOUT_SEC``.

    The environment value can only shorten the timeout: a positive integer
    below ``DEFAULT_TIMEOUT_SEC`` is returned as is, and every other value
    (unset, empty, non-numeric, non-positive, or at/above the default)
    yields ``DEFAULT_TIMEOUT_SEC``.
    """
    configured = os.getenv("HF_TIMEOUT_SEC")
    if configured:
        try:
            seconds = int(configured)
        except ValueError:
            seconds = 0  # handled like any other non-positive value
        if 0 < seconds < DEFAULT_TIMEOUT_SEC:
            return seconds
    return DEFAULT_TIMEOUT_SEC
|
|
|
|
def _coerce_int(name: str, value: int | None, *, default: int) -> int:
    """Coerce an optional numeric argument to ``int``.

    Args:
        name: Argument name used in the error message.
        value: Caller-supplied value; ``None`` selects ``default``.
        default: Value returned when ``value`` is ``None``.

    Returns:
        ``default`` when ``value`` is ``None``, otherwise ``int(value)``.

    Raises:
        ValueError: If ``int(value)`` fails for any reason.
    """
    if value is None:
        return default
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError) as exc:
        # int(float("inf")) raises OverflowError rather than ValueError;
        # fold it into the same error so callers see one exception type.
        raise ValueError(f"{name} must be an integer.") from exc
|
|
|
|
def _normalize_date_param(name: str, value: str | None, pattern: re.Pattern[str]) -> str | None:
    """Trim a date-like filter and check it against ``pattern``.

    Args:
        name: Argument name used in the error message.
        value: Raw filter value; falsy or whitespace-only means "not set".
        pattern: Compiled regex the trimmed value has to match.

    Returns:
        The trimmed value, or ``None`` when nothing was supplied.

    Raises:
        ValueError: If the trimmed value does not match ``pattern``.
    """
    trimmed = value.strip() if value else ""
    if not trimmed:
        return None
    if pattern.match(trimmed) is None:
        raise ValueError(f"{name} must match {pattern.pattern}.")
    return trimmed
|
|
|
|
def _normalize_submitter(value: str | None) -> str | None:
    """Trim the submitter filter and check it against ``SUBMITTER_RE``.

    Returns:
        The trimmed name, or ``None`` when the value is falsy or
        whitespace-only.

    Raises:
        ValueError: If the trimmed name does not match ``SUBMITTER_RE``.
    """
    trimmed = value.strip() if value else ""
    if not trimmed:
        return None
    if SUBMITTER_RE.match(trimmed) is None:
        raise ValueError("submitter must be a valid HF username.")
    return trimmed
|
|
|
|
def _normalize_sort(value: str | None) -> str | None:
    """Trim the sort key and check it against ``ALLOWED_SORTS``.

    Returns:
        The trimmed sort key, or ``None`` when the value is falsy or
        whitespace-only.

    Raises:
        ValueError: If the trimmed value is not one of ``ALLOWED_SORTS``.
    """
    if not value:
        return None
    cleaned = value.strip()
    if not cleaned:
        # Whitespace-only input counts as "not supplied", matching how the
        # date and submitter normalizers treat it, instead of being rejected.
        return None
    if cleaned not in ALLOWED_SORTS:
        allowed = ", ".join(sorted(ALLOWED_SORTS))
        raise ValueError(f"sort must be one of: {allowed}.")
    return cleaned
|
|
|
|
def _normalize_query(value: str | None) -> str | None:
    """Trim the free-text query and cut it to ``MAX_QUERY_LENGTH`` characters.

    Returns:
        The trimmed, length-limited query, or ``None`` when the value is
        ``None`` or contains only whitespace.
    """
    if value is None:
        return None
    trimmed = value.strip()
    return trimmed[:MAX_QUERY_LENGTH] if trimmed else None
|
|
|
|
def _build_url(params: dict[str, Any]) -> str:
    """Build the ``daily_papers`` endpoint URL for ``params``.

    Entries whose value is ``None`` are left out; the rest are URL-encoded
    with ``doseq=True``. The bare endpoint is returned when nothing remains
    to encode.
    """
    endpoint = f"{BASE_API_URL}/daily_papers"
    present = {key: val for key, val in params.items() if val is not None}
    encoded = urlencode(present, doseq=True)
    if not encoded:
        return endpoint
    return f"{endpoint}?{encoded}"
|
|
|
|
def _request_json(url: str) -> list[dict[str, Any]]:
    """GET ``url`` and return the decoded JSON array.

    An ``Authorization: Bearer`` header is added when ``_load_token()``
    returns a token; the timeout comes from ``_timeout_from_env()``.

    Args:
        url: Fully built request URL.

    Returns:
        The decoded response body, which must be a JSON array.

    Raises:
        RuntimeError: On an HTTP error status, a transport failure
            (including timeouts while reading the body), a body that is
            not valid JSON, or a JSON value that is not an array.
    """
    headers = {"Accept": "application/json"}
    token = _load_token()
    if token:
        headers["Authorization"] = f"Bearer {token}"

    request = Request(url, headers=headers, method="GET")
    try:
        with urlopen(request, timeout=_timeout_from_env()) as response:
            raw = response.read()
    except HTTPError as exc:
        error_body = exc.read().decode("utf-8", errors="replace")
        raise RuntimeError(f"HF API error {exc.code} for {url}: {error_body}") from exc
    except OSError as exc:
        # URLError is an OSError subclass, so this keeps the original
        # handling and also covers timeouts/resets raised by response.read(),
        # which are not wrapped in URLError.
        raise RuntimeError(f"HF API request failed for {url}: {exc}") from exc

    try:
        payload = json.loads(raw)
    except ValueError as exc:
        # JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        raise RuntimeError(f"HF API returned invalid JSON for {url}: {exc}") from exc
    if not isinstance(payload, list):
        raise RuntimeError("Unexpected response shape from /api/daily_papers")
    return payload
|
|
|
|
def _extract_search_blob(item: dict[str, Any]) -> str:
    """Flatten the searchable fields of a daily-paper entry into one string.

    The blob joins, in order: the entry's title and summary, the nested
    paper's title, summary, AI summary, AI keywords, author names, id,
    project page, and GitHub repo. Empty or missing fields are skipped.

    Malformed entries are tolerated: a non-dict ``paper``, a non-list
    ``authors``, or an author whose ``name`` is null or not a string no
    longer raises.

    Args:
        item: One decoded entry from the API response.

    Returns:
        The lower-cased, space-joined text of all non-empty fields.
    """
    paper = item.get("paper")
    if not isinstance(paper, dict):
        paper = {}

    authors = paper.get("authors")
    if not isinstance(authors, list):
        authors = []
    # A JSON null or non-string name would make " ".join() raise, so every
    # name is coerced to text and null becomes the empty string.
    author_names = [
        str(author.get("name") or "")
        for author in authors
        if isinstance(author, dict)
    ]

    ai_keywords = paper.get("ai_keywords") or []
    if isinstance(ai_keywords, list):
        ai_keywords_text = " ".join(str(keyword) for keyword in ai_keywords)
    else:
        ai_keywords_text = str(ai_keywords)

    parts = [
        item.get("title"),
        item.get("summary"),
        paper.get("title"),
        paper.get("summary"),
        paper.get("ai_summary"),
        ai_keywords_text,
        " ".join(author_names),
        paper.get("id"),
        paper.get("projectPage"),
        paper.get("githubRepo"),
    ]

    text = " ".join(str(part) for part in parts if part)
    return text.lower()
|
|
|
|
def _matches_query(item: dict[str, Any], query: str) -> bool:
    """Tell whether every whitespace-separated query term occurs in ``item``.

    Terms are lower-cased and matched as plain substrings of the blob from
    ``_extract_search_blob``. A query with no terms matches everything.
    """
    needles = [piece for piece in re.split(r"\s+", query.strip().lower()) if piece]
    if not needles:
        return True
    blob = _extract_search_blob(item)
    for needle in needles:
        if needle not in blob:
            return False
    return True
|
|
|
|
def _clamp_total_fetch(pages: int, per_page: int) -> tuple[int, int]:
    """Shrink a paging plan so it stays within ``MAX_TOTAL_FETCH`` items.

    Args:
        pages: Requested number of pages.
        per_page: Requested page size.

    Returns:
        ``(pages, per_page)`` unchanged when their product fits the budget;
        ``(1, MAX_TOTAL_FETCH)`` when a single page alone exceeds it;
        otherwise the page count is reduced (never below one) and the page
        size is kept.
    """
    within_budget = pages * per_page <= MAX_TOTAL_FETCH
    if within_budget:
        return pages, per_page
    if per_page > MAX_TOTAL_FETCH:
        return 1, MAX_TOTAL_FETCH
    page_budget = MAX_TOTAL_FETCH // per_page
    if page_budget < 1:
        page_budget = 1
    return min(pages, page_budget), per_page
|
|
|
|
def hf_papers_search(
    query: str | None = None,
    *,
    date: str | None = None,
    week: str | None = None,
    month: str | None = None,
    submitter: str | None = None,
    sort: str | None = None,
    limit: int | None = None,
    page: int | None = None,
    max_pages: int | None = None,
    api_limit: int | None = None,
) -> dict[str, Any]:
    """
    Search Hugging Face Daily Papers with optional local filtering.

    Pages are fetched one after another starting at ``page``. Fetching stops
    once ``limit`` matches are collected, the page budget is used up, or the
    API returns an empty page.

    Args:
        query: Case-insensitive keyword search across title, summary, authors,
            AI summary/keywords, project page, repo link, and paper id.
            Every whitespace-separated term must occur; the query is cut to
            MAX_QUERY_LENGTH characters.
        date: ISO date (YYYY-MM-DD).
        week: ISO week (YYYY-Www).
        month: ISO month (YYYY-MM).
        submitter: HF username of the submitter.
        sort: "publishedAt" or "trending".
        limit: Max results to return after filtering (default: the
            HF_MAX_RESULTS environment variable when valid, else 20).
        page: Page index for the API (default 0).
        max_pages: Number of pages to fetch for local filtering (default 1,
            capped at MAX_PAGES).
        api_limit: Page size for the API (default 50, max 100).

    Returns:
        dict with the keys ``query`` (normalized query or None), ``params``
        (the filters sent, plus ``page``, ``api_limit`` and ``max_pages``,
        the latter being the number of pages actually fetched),
        ``returned`` (number of entries in ``data``) and ``data`` (the
        matching entries, at most ``limit``).

    Raises:
        ValueError: If a numeric argument is not an integer or is out of
            range, or a filter value fails validation.
        RuntimeError: If an API request fails (raised by ``_request_json``).
    """
    resolved_limit = _coerce_int("limit", limit, default=_max_results_from_env())
    if resolved_limit < 1:
        raise ValueError("limit must be >= 1.")

    start_page = _coerce_int("page", page, default=0)
    if start_page < 0:
        raise ValueError("page must be >= 0.")

    pages_to_fetch = _coerce_int("max_pages", max_pages, default=1)
    if pages_to_fetch < 1:
        raise ValueError("max_pages must be >= 1.")
    pages_to_fetch = min(pages_to_fetch, MAX_PAGES)

    per_page = _coerce_int("api_limit", api_limit, default=50)
    if per_page < 1:
        raise ValueError("api_limit must be >= 1.")
    per_page = min(per_page, MAX_API_LIMIT)

    # Keep pages * page size within the overall fetch budget.
    pages_to_fetch, per_page = _clamp_total_fetch(pages_to_fetch, per_page)

    normalized_query = _normalize_query(query)

    params_base: dict[str, Any] = {
        "date": _normalize_date_param("date", date, DATE_RE),
        "week": _normalize_date_param("week", week, WEEK_RE),
        "month": _normalize_date_param("month", month, MONTH_RE),
        "submitter": _normalize_submitter(submitter),
        "sort": _normalize_sort(sort),
        "limit": per_page,
    }

    results: list[dict[str, Any]] = []
    pages_fetched = 0
    for page_index in range(start_page, start_page + pages_to_fetch):
        payload = _request_json(_build_url({**params_base, "p": page_index}))
        pages_fetched += 1

        if not payload:
            # An empty page is taken to mean the listing is exhausted, so
            # requesting later pages would only cost extra round trips.
            break

        if normalized_query:
            results.extend(
                item for item in payload if _matches_query(item, normalized_query)
            )
        else:
            results.extend(payload)

        if len(results) >= resolved_limit:
            break

    data = results[:resolved_limit]
    return {
        "query": normalized_query,
        "params": {
            **{k: v for k, v in params_base.items() if v is not None},
            "page": start_page,
            "max_pages": pages_fetched,
            "api_limit": per_page,
        },
        "returned": len(data),
        "data": data,
    }
|
|