# ml-intern / agent/tools/docs_tools.py
# Last change by akseljoonas (HF Staff), commit 7473699: "removing extra prints"
# Original listing size: 26.7 kB (raw / history / blame views)
"""
Documentation search tools for the HF Agent
Tools for exploring and fetching HuggingFace documentation and API specifications
"""
import asyncio
import os
from typing import Any
import httpx
from bs4 import BeautifulSoup
# Cache for OpenAPI spec to avoid repeated fetches
_openapi_spec_cache: dict[str, Any] | None = None
async def _fetch_html_page(hf_token: str, endpoint: str) -> str:
    """Download the HTML of ``https://huggingface.co/docs/<endpoint>``.

    The request carries ``hf_token`` as a Bearer token and follows redirects.

    Raises:
        httpx.HTTPStatusError: when the final response is not a success status.
        httpx.RequestError: on transport failures (including the 30 s timeout).
    """
    request_headers = {"Authorization": f"Bearer {hf_token}"}
    page_url = f"https://huggingface.co/docs/{endpoint}"
    async with httpx.AsyncClient(follow_redirects=True, timeout=30.0) as http:
        resp = await http.get(page_url, headers=request_headers)
    # The body is fully read by get(), so it is safe to inspect after close.
    resp.raise_for_status()
    return resp.text
def _parse_sidebar_navigation(html_content: str) -> list[dict[str, str]]:
    """Return every link in the docs sidebar as ``{"title", "url"}`` dicts.

    The sidebar is the first ``<nav>`` element whose class contains
    ``flex-auto``. Host-relative hrefs (starting with ``/``) are made absolute
    against ``https://huggingface.co``; any other href is kept verbatim.

    Raises:
        ValueError: when no matching ``<nav>`` element exists.
    """

    def _looks_like_sidebar(css_class):
        return css_class and "flex-auto" in css_class

    nav = BeautifulSoup(html_content, "html.parser").find(
        "nav", class_=_looks_like_sidebar
    )
    if not nav:
        raise ValueError("Could not find navigation sidebar")

    entries = []
    for anchor in nav.find_all("a", href=True):
        target = anchor["href"]
        absolute = target if not target.startswith("/") else f"https://huggingface.co{target}"
        entries.append({"title": anchor.get_text(strip=True), "url": absolute})
    return entries
async def _fetch_single_glimpse(
    client: httpx.AsyncClient, hf_token: str, item: dict[str, str]
) -> dict[str, str]:
    """Fetch a short preview (first 300 characters) of one docs page.

    The markdown variant of the page (``item["url"] + ".md"``) is requested.
    This is best-effort: any failure is reported inside the ``glimpse`` field
    rather than raised, so a single bad page cannot abort the whole batch.

    Returns:
        A dict with ``title``, ``url``, ``md_url`` and ``glimpse`` keys.
    """
    md_url = f"{item['url']}.md"
    entry = {"title": item["title"], "url": item["url"], "md_url": md_url}
    try:
        resp = await client.get(md_url, headers={"Authorization": f"Bearer {hf_token}"})
        resp.raise_for_status()
        body = resp.text
    except Exception as exc:
        entry["glimpse"] = f"[Could not fetch glimpse: {str(exc)[:50]}]"
        return entry
    preview = body[:300].strip()
    # Mark truncation only when the page really was longer than the preview.
    entry["glimpse"] = preview + "..." if len(body) > 300 else preview
    return entry
async def _fetch_all_glimpses(
    hf_token: str, nav_data: list[dict[str, str]]
) -> list[dict[str, str]]:
    """Fetch a glimpse for every navigation entry concurrently.

    All requests share one HTTP client. Results come back in the same order as
    ``nav_data`` (``asyncio.gather`` preserves argument order).
    """
    async with httpx.AsyncClient(follow_redirects=True, timeout=30.0) as shared_client:
        pending = (
            _fetch_single_glimpse(shared_client, hf_token, entry) for entry in nav_data
        )
        gathered = await asyncio.gather(*pending)
    return [*gathered]
def _format_exploration_results(
    endpoint: str, result_items: list[dict[str, str]]
) -> str:
    """Render the explored docs pages as a numbered, human-readable listing.

    Each item must provide ``title``, ``url`` and ``glimpse`` keys.
    """
    docs_root = f"https://huggingface.co/docs/{endpoint}"
    pieces = [
        f"Documentation structure for: {docs_root}\n\n",
        f"Found {len(result_items)} pages:\n\n",
    ]
    for index, entry in enumerate(result_items, start=1):
        pieces.append(
            f"{index}. **{entry['title']}**\n"
            f" URL: {entry['url']}\n"
            f" Glimpse: {entry['glimpse']}\n\n"
        )
    return "".join(pieces)
async def explore_hf_docs(hf_token: str, endpoint: str) -> str:
    """Build a browsable overview of one documentation section.

    The steps are: download the section's HTML page, extract the sidebar
    links, fetch a short glimpse of each linked page in parallel, and render
    everything as text.

    Raises:
        ValueError: when the sidebar is missing or contains no links.
        httpx.HTTPStatusError / httpx.RequestError: when the HTML page fetch fails.
    """
    page_html = await _fetch_html_page(hf_token, endpoint)
    links = _parse_sidebar_navigation(page_html)
    if not links:
        raise ValueError(f"No navigation links found for endpoint '{endpoint}'")
    glimpses = await _fetch_all_glimpses(hf_token, links)
    return _format_exploration_results(endpoint, glimpses)
async def explore_hf_docs_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
    """
    Explore the documentation structure for a given endpoint by parsing the sidebar navigation.

    Args:
        arguments: Dictionary with an 'endpoint' parameter (e.g. 'trl', 'transformers').
            Surrounding whitespace and leading slashes are ignored.

    Returns:
        Tuple of (structured_navigation_with_glimpses, success). On failure the
        first element is a human-readable error message and success is False.
    """
    # Normalise *before* validating. Otherwise inputs such as "/" pass the
    # emptiness check and are then reduced to "", which silently fetches the
    # docs landing page instead of reporting the missing endpoint.
    endpoint = str(arguments.get("endpoint") or "").strip().lstrip("/")
    if not endpoint:
        return "Error: No endpoint provided", False

    # Get HF token from environment
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        return "Error: HF_TOKEN environment variable not set", False

    try:
        result = await explore_hf_docs(hf_token, endpoint)
    except httpx.HTTPStatusError as e:
        return (
            f"HTTP error: {e.response.status_code} - {e.response.text[:200]}",
            False,
        )
    except httpx.RequestError as e:
        return f"Request error: {str(e)}", False
    except ValueError as e:
        return f"Error: {str(e)}", False
    except Exception as e:
        # Tool boundary: report rather than crash the agent loop.
        return f"Unexpected error: {str(e)}", False
    return result, True
async def _fetch_openapi_spec() -> dict[str, Any]:
    """Return the Hugging Face OpenAPI document, downloading it at most once.

    The parsed JSON is stored in the module-level ``_openapi_spec_cache``, and
    later calls return it without touching the network.

    Raises:
        httpx.HTTPStatusError / httpx.RequestError: when the download fails.
    """
    global _openapi_spec_cache
    if _openapi_spec_cache is None:
        async with httpx.AsyncClient(follow_redirects=True, timeout=30.0) as http:
            resp = await http.get("https://huggingface.co/.well-known/openapi.json")
            resp.raise_for_status()
            _openapi_spec_cache = resp.json()
    return _openapi_spec_cache
def _extract_all_tags(spec: dict[str, Any]) -> list[str]:
    """Return the sorted, de-duplicated set of tag names used in an OpenAPI spec.

    Tags are collected from the top-level ``tags`` list (entries with a
    ``name``) and from every HTTP operation under ``paths``. The second source
    catches tags that are used but never declared.
    """
    http_verbs = {"get", "post", "put", "delete", "patch", "head", "options"}
    declared = {entry["name"] for entry in spec.get("tags", []) if "name" in entry}
    used = {
        tag_name
        for path_item in spec.get("paths", {}).values()
        for verb, operation in path_item.items()
        if verb in http_verbs
        for tag_name in operation.get("tags", [])
    }
    return sorted(declared | used)
def _search_openapi_by_tag(spec: dict[str, Any], tag: str) -> list[dict[str, Any]]:
"""Search for API endpoints with a specific tag"""
results = []
paths = spec.get("paths", {})
servers = spec.get("servers", [])
base_url = (
servers[0].get("url", "https://huggingface.co")
if servers
else "https://huggingface.co"
)
for path, path_item in paths.items():
for method, operation in path_item.items():
if method not in [
"get",
"post",
"put",
"delete",
"patch",
"head",
"options",
]:
continue
operation_tags = operation.get("tags", [])
if tag in operation_tags:
# Extract parameters
parameters = operation.get("parameters", [])
request_body = operation.get("requestBody", {})
responses = operation.get("responses", {})
results.append(
{
"path": path,
"method": method.upper(),
"operationId": operation.get("operationId", ""),
"summary": operation.get("summary", ""),
"description": operation.get("description", ""),
"parameters": parameters,
"request_body": request_body,
"responses": responses,
"base_url": base_url,
}
)
return results
def _shell_single_quote(text: str) -> str:
    """Wrap *text* in POSIX-shell single quotes, escaping embedded single quotes."""
    return "'" + text.replace("'", "'\\''") + "'"


def _generate_curl_example(endpoint: dict[str, Any]) -> str:
    """Generate a copy-pasteable curl command for an OpenAPI operation.

    Args:
        endpoint: An entry as produced by ``_search_openapi_by_tag``. It has the
            keys ``method``, ``path`` and ``base_url``, and optionally
            ``parameters`` and ``request_body``.

    Returns:
        A multi-line curl command. Path parameters are filled with their
        example (or a ``<name>`` placeholder). Every required query parameter
        is encoded inside the quoted URL. The Authorization header is
        double-quoted so the shell expands ``$HF_TOKEN``. For
        POST/PUT/PATCH operations with a JSON body, a ``-d`` payload is added.
    """
    import json
    from urllib.parse import urlencode

    method = endpoint["method"]
    parameters = endpoint.get("parameters") or []

    def _example(param: dict[str, Any], default: Any) -> Any:
        # Prefer a parameter-level example, then a schema-level one.
        return param.get("example", param.get("schema", {}).get("example", default))

    # OpenAPI requires every path parameter, so substitute all of them
    # rather than only those explicitly flagged "required".
    full_path = endpoint["path"]
    for param in parameters:
        if param.get("in") == "path":
            name = param["name"]
            full_path = full_path.replace(f"{{{name}}}", str(_example(param, f"<{name}>")))

    url = f"{endpoint['base_url']}{full_path}"
    required_query = [
        (p["name"], str(_example(p, "value")))
        for p in parameters
        if p.get("in") == "query" and p.get("required")
    ]
    if required_query:
        url += "?" + urlencode(required_query, safe="/")

    # Keep the whole URL (including "?" and "&") inside one quoted word.
    curl = f"curl -X {method} \\\n {_shell_single_quote(url)}"
    # Double quotes: single quotes would send the literal text "$HF_TOKEN".
    curl += ' \\\n -H "Authorization: Bearer $HF_TOKEN"'

    request_body = endpoint.get("request_body")
    if method in ("POST", "PUT", "PATCH") and request_body:
        content = request_body.get("content", {})
        if "application/json" in content:
            curl += " \\\n -H 'Content-Type: application/json'"
            schema = content["application/json"].get("schema", {})
            example = schema.get("example", "{}")
            if not isinstance(example, str):
                example = json.dumps(example, indent=2)
            curl += f" \\\n -d {_shell_single_quote(example)}"

    return curl
def _format_parameters(parameters: list[dict[str, Any]]) -> str:
    """Render OpenAPI parameters as Markdown, grouped by location.

    Sections appear in the order path, query, header, separated by a blank
    line, and locations without parameters are omitted. Path and query entries
    show the schema type (``string`` when absent) plus an example line when a
    truthy example exists. Header entries show only name, requiredness and
    description. Returns ``""`` when nothing is rendered.
    """
    if not parameters:
        return ""

    # (location, section heading, whether to show type and example)
    sections = (
        ("path", "**Path Parameters:**", True),
        ("query", "**Query Parameters:**", True),
        ("header", "**Header Parameters:**", False),
    )
    lines: list[str] = []
    for location, heading, detailed in sections:
        group = [p for p in parameters if p.get("in") == location]
        if not group:
            continue
        if lines:
            lines.append("")
        lines.append(heading)
        for p in group:
            name = p.get("name", "")
            flag = " (required)" if p.get("required") else " (optional)"
            desc = p.get("description", "")
            if not detailed:
                lines.append(f"- `{name}`{flag}: {desc}")
                continue
            kind = p.get("schema", {}).get("type", "string")
            sample = p.get("example") or p.get("schema", {}).get("example", "")
            lines.append(f"- `{name}` ({kind}){flag}: {desc}")
            if sample:
                lines.append(f" Example: `{sample}`")
    return "\n".join(lines)
def _format_response_info(responses: dict[str, Any]) -> str:
    """Summarise up to the first three documented responses as Markdown bullets.

    Each bullet shows the status code and description. When the response has
    an ``application/json`` schema with a ``type``, that type is added on an
    indented ``Returns:`` line.
    """
    if not responses:
        return "No response information available"

    lines: list[str] = []
    for shown, (status_code, response_obj) in enumerate(responses.items()):
        if shown == 3:  # keep the summary short
            break
        lines.append(f"- **{status_code}**: {response_obj.get('description', '')}")
        content = response_obj.get("content", {})
        if "application/json" not in content:
            continue
        schema = content["application/json"].get("schema", {})
        if "type" in schema:
            lines.append(f" Returns: {schema['type']}")
    return "\n".join(lines)
def _format_openapi_results(results: list[dict[str, Any]], tag: str) -> str:
    """Render tagged OpenAPI endpoints as a Markdown report.

    For each endpoint the report contains a heading, an optional summary, a
    description truncated to 300 characters, the parameter list, a curl usage
    example and a short response summary, separated by horizontal rules.
    """
    if not results:
        return f"No API endpoints found with tag '{tag}'"

    sections = [
        f"# API Endpoints for tag: `{tag}`\n\n",
        f"Found {len(results)} endpoint(s)\n\n",
        "---\n\n",
    ]
    for number, ep in enumerate(results, start=1):
        sections.append(f"## {number}. {ep['method']} {ep['path']}\n\n")
        if ep["summary"]:
            sections.append(f"**Summary:** {ep['summary']}\n\n")
        full_desc = ep["description"]
        if full_desc:
            clipped = full_desc[:300] + ("..." if len(full_desc) > 300 else "")
            sections.append(f"**Description:** {clipped}\n\n")
        param_md = _format_parameters(ep.get("parameters", []))
        if param_md:
            sections.append(f"{param_md}\n\n")
        sections.append(f"**Usage:**\n```bash\n{_generate_curl_example(ep)}\n```\n\n")
        sections.append(f"**Returns:**\n{_format_response_info(ep['responses'])}\n\n---\n\n")
    return "".join(sections)
async def search_openapi_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
    """
    Search the HuggingFace OpenAPI specification by tag.

    Args:
        arguments: Dictionary with a 'tag' parameter.

    Returns:
        Tuple of (search_results, success). On failure the first element is an
        error message and success is False.
    """
    tag = arguments.get("tag", "")
    if not tag:
        return "Error: No tag provided", False

    try:
        # The spec is downloaded once and then served from the module cache.
        spec = await _fetch_openapi_spec()
        matches = _search_openapi_by_tag(spec, tag)
        return _format_openapi_results(matches, tag), True
    except httpx.HTTPStatusError as exc:
        return f"HTTP error fetching OpenAPI spec: {exc.response.status_code}", False
    except httpx.RequestError as exc:
        return f"Request error: {str(exc)}", False
    except Exception as exc:
        return f"Error searching OpenAPI spec: {str(exc)}", False
def _to_markdown_url(url: str) -> str:
    """Return the raw-markdown (``.md``) URL for a Hugging Face docs page URL.

    - A ``#fragment`` is dropped. It is never sent to the server, and naively
      appending ``.md`` would place the suffix inside the fragment.
    - A trailing ``/`` is removed first, so ``.../page/`` maps to
      ``.../page.md`` rather than ``.../page/.md``.
    - A query string is preserved.
    - A host-relative path such as ``/docs/trl/dpo_trainer`` is resolved
      against ``https://huggingface.co``.
    - A path that already ends in ``.md`` is left as is.
    """
    from urllib.parse import urlsplit, urlunsplit

    parts = urlsplit(url)
    scheme, netloc = parts.scheme, parts.netloc
    if not netloc and parts.path.startswith("/"):
        scheme, netloc = "https", "huggingface.co"
    path = parts.path.rstrip("/")
    if not path.endswith(".md"):
        path = f"{path}.md"
    return urlunsplit((scheme, netloc, path, parts.query, ""))


async def hf_docs_fetch_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
    """
    Fetch full documentation content from a specific HF docs page.

    Args:
        arguments: Dictionary with a 'url' parameter (full URL to the doc page).
            The URL is mapped to its ``.md`` variant (see ``_to_markdown_url``).

    Returns:
        Tuple of (full_markdown_content, success). On failure the first element
        is an error message and success is False.
    """
    url = str(arguments.get("url") or "").strip()
    if not url:
        return "Error: No URL provided", False

    # Get HF token from environment
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        return (
            "Error: HF_TOKEN environment variable not set",
            False,
        )

    url = _to_markdown_url(url)

    try:
        headers = {"Authorization": f"Bearer {hf_token}"}
        async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
            response = await client.get(url, headers=headers)
            response.raise_for_status()
            content = response.text
    except httpx.HTTPStatusError as e:
        return (
            f"HTTP error fetching {url}: {e.response.status_code} - {e.response.text[:200]}",
            False,
        )
    except httpx.RequestError as e:
        return f"Request error fetching {url}: {str(e)}", False
    except Exception as e:
        # Tool boundary: report rather than crash the agent loop.
        return f"Error fetching documentation: {str(e)}", False

    # Return the markdown content directly
    return f"Documentation from: {url}\n\n{content}", True
# Tool specifications for documentation search

# Ordered mapping of docs endpoint -> one-line summary. It is the single
# source for both the tool's ``enum`` and its human-readable description,
# so the two lists cannot drift apart.
_HF_DOCS_ENDPOINTS: dict[str, str] = {
    "hub": "Find answers to questions about models/datasets/spaces, auth, versioning, metadata.",
    "transformers": "Core model library: architectures, configs, tokenizers, training & inference APIs.",
    "diffusers": "Diffusion pipelines, schedulers, fine-tuning, training, and deployment patterns.",
    "datasets": "Dataset loading, streaming, processing, Arrow format, Hub integration.",
    "gradio": "UI components and demos for interacting with ML models.",
    "trackio": "Experiment tracking, metrics logging, and run comparison.",
    "smolagents": "Lightweight agent abstractions and tool-using patterns.",
    "huggingface_hub": "Python client for Hub operations (auth, upload/download, repo management).",
    "huggingface.js": "JS/TS client for Hub APIs in browser and Node.",
    "transformers.js": "Run Transformer models in browser/Node via WebGPU/WASM.",
    "inference-providers": "Unified interface for third-party inference backends.",
    "inference-endpoints": "Managed, scalable model deployments on HF infrastructure.",
    "peft": "Parameter-efficient fine-tuning methods (LoRA, adapters, etc.).",
    "accelerate": "Hardware-agnostic, distributed and mixed-precision training orchestration.",
    "optimum": "Hardware-aware optimization and model export tooling.",
    "optimum-habana": "Training and inference on Habana Gaudi accelerators.",
    "optimum-neuron": "Optimization workflows for AWS Inferentia/Trainium.",
    "optimum-intel": "Intel CPU/GPU optimizations (OpenVINO, IPEX).",
    "optimum-executorch": "Exporting models to ExecuTorch for edge/mobile.",
    "optimum-tpu": "TPU-specific training and optimization paths.",
    "tokenizers": "Fast tokenizer internals, training, and low-level APIs.",
    "llm-course": "End-to-end LLM concepts, training, and deployment.",
    "robotics-course": "Learning-based robotics foundations.",
    "mcp-course": "Model Context Protocol concepts and usage.",
    "smol-course": "Small-model and efficiency-focused workflows.",
    "agents-course": "Tool-using, planning, and multi-step agent design.",
    "deep-rl-course": "Deep reinforcement learning foundations.",
    "computer-vision-course": "Vision models, datasets, and pipelines.",
    "evaluate": "Metrics, evaluation workflows, and training-loop integration.",
    "tasks": "Canonical task definitions and model categorization.",
    "dataset-viewer": "Dataset preview, streaming views, and viewer internals.",
    "trl": "RLHF, DPO, PPO, and SFT utilities for LLMs.",
    "simulate": "Experimental simulation tools and workflows.",
    "sagemaker": "Deploying Hugging Face models on AWS SageMaker.",
    "timm": "Image model zoo and utilities via HF integrations.",
    "safetensors": "Safe, fast tensor serialization format.",
    "tgi": "High-throughput text generation server for LLMs.",
    "setfit": "Few-shot text classification via sentence embeddings.",
    "audio-course": "Speech and audio models, datasets, and tasks.",
    "lerobot": "Robotics datasets, policies, and learning workflows.",
    "autotrain": "No/low-code model training on Hugging Face.",
    "tei": "Optimized inference server for embedding workloads.",
    "bitsandbytes": "Quantization and memory-efficient optimizers.",
    "cookbook": "Practical, task-oriented recipes across the ecosystem.",
    "sentence_transformers": "Embedding models, training recipes, similarity/search workflows.",
    "ml-games-course": "Game-based ML and reinforcement learning experiments.",
    "diffusion-course": "Diffusion model theory and hands-on practice.",
    "ml-for-3d-course": "3D representations, models, and learning techniques.",
    "chat-ui": "Reference chat interfaces for LLM deployment.",
    "leaderboards": "Evaluation leaderboards and submission mechanics.",
    "lighteval": "Lightweight, reproducible LLM evaluation framework.",
    "argilla": "Data annotation, feedback, and human-in-the-loop workflows.",
    "distilabel": "Synthetic data generation and distillation pipelines.",
    "microsoft-azure": "Azure deployment and integration guides.",
    "kernels": "Loading optimized compute kernels (e.g. GPU kernels) directly from the Hub.",
    "google-cloud": "GCP deployment and serving workflows.",
}

EXPLORE_HF_DOCS_TOOL_SPEC = {
    "name": "explore_hf_docs",
    "description": (
        "Explore the Hugging Face documentation at a glance. "
        "Select an endpoint from the available options and get a list of all documentation pages "
        "with their titles, URLs, and a 300-character glimpse of each page. "
        "Use this to discover what documentation is available before fetching specific pages."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "endpoint": {
                "type": "string",
                "enum": list(_HF_DOCS_ENDPOINTS),
                "description": (
                    "The documentation endpoint to explore. Each endpoint corresponds to a major section of the Hugging Face documentation:\n\n"
                    + "".join(
                        f"• {name} — {summary}\n"
                        for name, summary in _HF_DOCS_ENDPOINTS.items()
                    )
                ),
            },
        },
        "required": ["endpoint"],
    },
}
# Tool schema for fetch_hf_docs: takes one required "url" string argument.
HF_DOCS_FETCH_TOOL_SPEC = dict(
    name="fetch_hf_docs",
    description=" ".join(
        [
            "Fetch the full content of a specific HF documentation page.",
            "Provide the full URL to the doc page (e.g., from explore_hf_docs results).",
            "Returns the complete markdown content of that page.",
            "Use explore_hf_docs first to discover available pages.",
        ]
    ),
    parameters=dict(
        type="object",
        properties=dict(
            url=dict(
                type="string",
                description=" ".join(
                    [
                        "The full URL to the documentation page.",
                        "Example: 'https://huggingface.co/docs/trl/dpo_trainer'",
                        "The .md extension will be added automatically if not present.",
                    ]
                ),
            ),
        ),
        required=["url"],
    ),
)
async def _get_api_search_tool_spec() -> dict[str, Any]:
    """
    Build the search_hf_api_endpoints tool schema with its tag enum filled at runtime.

    The allowed tags come from the (cached) OpenAPI spec, so this must be
    awaited. It may therefore trigger the initial spec download.
    """
    available_tags = _extract_all_tags(await _fetch_openapi_spec())
    tag_property = {
        "type": "string",
        "enum": available_tags,
        "description": "The API tag to search for. Each tag groups related API endpoints. ",
    }
    tool_description = " ".join(
        [
            "Search the HuggingFace OpenAPI specification by tag to find related API endpoints.",
            "Returns all endpoints with the specified tag including curl examples showing how to use them.",
            "Each result includes the endpoint path, summary, usage example with curl, and response information.",
        ]
    )
    return {
        "name": "search_hf_api_endpoints",
        "description": tool_description,
        "parameters": {
            "type": "object",
            "properties": {"tag": tag_property},
            "required": ["tag"],
        },
    }