akseljoonas (HF Staff) committed
Commit 64bb402 · 1 Parent(s): d90bfa3

file refactor, still works

Files changed (2)
  1. agent/tools/docs_tools.py +412 -615
  2. test_dataset_tools.py +0 -89
agent/tools/docs_tools.py CHANGED
@@ -1,9 +1,9 @@
 """
-Documentation search tools for the HF Agent
-Tools for exploring and fetching HuggingFace documentation and API specifications
 """
 
 import asyncio
 import os
 from typing import Any
 
@@ -14,23 +14,16 @@ from whoosh.fields import ID, TEXT, Schema
 from whoosh.filedb.filestore import RamStorage
 from whoosh.qparser import MultifieldParser, OrGroup
 
-# Cache for OpenAPI spec to avoid repeated fetches
-_openapi_spec_cache: dict[str, Any] | None = None
-
-# Simple in-memory caches for docs and search indexes
-_DOCS_CACHE: dict[str, list[dict[str, str]]] = {}
-_INDEX_CACHE: dict[str, tuple[Any, MultifieldParser]] = {}
-_CACHE_LOCK = asyncio.Lock()
 
-# Result limiting defaults
 DEFAULT_MAX_RESULTS = 20
 MAX_RESULTS_CAP = 50
 
-# Gradio documentation endpoints (hosted separately from HF docs)
 GRADIO_LLMS_TXT_URL = "https://gradio.app/llms.txt"
-GRADIO_EMBEDDING_SEARCH_URL = "https://playground-worker.pages.dev/api/prompt"
 
-# High-level endpoints that bundle related documentation sections
 COMPOSITE_ENDPOINTS: dict[str, list[str]] = {
     "optimum": [
         "optimum",
@@ -56,32 +49,34 @@ COMPOSITE_ENDPOINTS: dict[str, list[str]] = {
     ],
 }
 
-
-def _expand_endpoint(endpoint: str) -> list[str]:
-    return COMPOSITE_ENDPOINTS.get(endpoint, [endpoint])
-
-
 # ---------------------------------------------------------------------------
-# Gradio documentation helpers (uses gradio.app instead of HF docs)
 # ---------------------------------------------------------------------------
 
 
-async def _fetch_gradio_full_docs() -> str:
-    """Fetch Gradio's full documentation from llms.txt"""
-    async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
-        response = await client.get(GRADIO_LLMS_TXT_URL)
-        response.raise_for_status()
-        return response.text
 
 
-async def _search_gradio_docs(query: str) -> str:
     """
-    Run embedding search on Gradio's documentation via their API.
-    Returns the most relevant content for the query.
     """
     async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
-        response = await client.post(
-            GRADIO_EMBEDDING_SEARCH_URL,
             headers={
                 "Content-Type": "application/json",
                 "Origin": "https://gradio-docs-mcp.up.railway.app",
@@ -92,143 +87,96 @@ async def _search_gradio_docs(query: str) -> str:
92
  "FALLBACK_PROMPT": "No results found",
93
  },
94
  )
95
- response.raise_for_status()
96
- result = response.json()
97
- return result.get("SYS_PROMPT", "No results found")
98
 
99
 
100
- def _format_gradio_results(content: str, query: str | None = None) -> str:
101
- """Format Gradio documentation results"""
102
- header = "# Gradio Documentation\n\n"
103
- if query:
104
- header += f"Search query: '{query}'\n\n"
105
- header += "Source: https://gradio.app/docs\n\n---\n\n"
106
- return header + content
107
 
108
 
109
- async def _fetch_html_page(hf_token: str, endpoint: str) -> str:
110
- """Fetch the HTML page for a given endpoint"""
111
- base_url = "https://huggingface.co/docs"
112
- url = f"{base_url}/{endpoint}"
113
  headers = {"Authorization": f"Bearer {hf_token}"}
114
 
115
  async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
116
- response = await client.get(url, headers=headers)
117
- response.raise_for_status()
118
-
119
- return response.text
120
-
121
-
122
- def _parse_sidebar_navigation(html_content: str) -> list[dict[str, str]]:
123
- """Parse the sidebar navigation and extract all links"""
124
- soup = BeautifulSoup(html_content, "html.parser")
125
- sidebar = soup.find("nav", class_=lambda x: x and "flex-auto" in x)
 
 
 
 
 
126
 
127
- if not sidebar:
128
- raise ValueError("Could not find navigation sidebar")
129
 
130
- links = sidebar.find_all("a", href=True)
131
- nav_data = []
132
-
133
- for link in links:
134
- title = link.get_text(strip=True)
135
- href = link["href"]
136
-
137
- # Make URL absolute
138
- page_url = f"https://huggingface.co{href}" if href.startswith("/") else href
139
- nav_data.append({"title": title, "url": page_url})
140
-
141
- return nav_data
142
-
143
-
144
- async def _fetch_single_glimpse(
145
- client: httpx.AsyncClient, hf_token: str, item: dict[str, str]
146
- ) -> dict[str, str]:
147
- """Fetch a short glimpse for a single page"""
148
- md_url = f"{item['url']}.md"
149
- headers = {"Authorization": f"Bearer {hf_token}"}
150
-
151
- try:
152
- response = await client.get(md_url, headers=headers)
153
- response.raise_for_status()
154
-
155
- content = response.text.strip()
156
- snippet_length = 200
157
- glimpse = content[:snippet_length].strip()
158
- if len(content) > snippet_length:
159
- glimpse += "..."
160
-
161
- return {
162
- "title": item["title"],
163
- "url": item["url"],
164
- "md_url": md_url,
165
- "glimpse": glimpse,
166
- "content": content,
167
- }
168
- except Exception as e:
169
- return {
170
- "title": item["title"],
171
- "url": item["url"],
172
- "md_url": md_url,
173
- "glimpse": f"[Could not fetch glimpse: {str(e)[:50]}]",
174
- "content": "",
175
- }
176
-
177
-
178
- async def _fetch_all_glimpses(
179
- hf_token: str, nav_data: list[dict[str, str]]
180
- ) -> list[dict[str, str]]:
181
- """Fetch glimpses for all pages in parallel"""
182
- async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
183
- result_items = await asyncio.gather(
184
- *[_fetch_single_glimpse(client, hf_token, item) for item in nav_data]
185
- )
186
 
187
- return list(result_items)
 
 
 
188
 
 
 
 
189
 
190
- async def _load_single_endpoint(hf_token: str, endpoint: str) -> list[dict[str, str]]:
191
- """Fetch docs for a single endpoint."""
192
- html_content = await _fetch_html_page(hf_token, endpoint)
193
- nav_data = _parse_sidebar_navigation(html_content)
194
- if not nav_data:
195
- raise ValueError(f"No navigation links found for endpoint '{endpoint}'")
196
 
197
- docs = await _fetch_all_glimpses(hf_token, nav_data)
198
- for doc in docs:
199
- doc["section"] = endpoint
200
- return docs
201
 
202
 
203
- async def _get_docs(hf_token: str, endpoint: str) -> list[dict[str, str]]:
204
- """Return docs for a single endpoint or expanded composite."""
205
- async with _CACHE_LOCK:
206
- cached = _DOCS_CACHE.get(endpoint)
207
- if cached is not None:
208
- return cached
209
-
210
- docs: list[dict[str, str]] = []
211
- for member in _expand_endpoint(endpoint):
212
- async with _CACHE_LOCK:
213
- member_cached = _DOCS_CACHE.get(member)
214
- if member_cached is None:
215
- member_cached = await _load_single_endpoint(hf_token, member)
216
- async with _CACHE_LOCK:
217
- _DOCS_CACHE[member] = member_cached
218
- docs.extend(member_cached)
219
-
220
- async with _CACHE_LOCK:
221
- _DOCS_CACHE[endpoint] = docs
222
- return docs
223
-
224
-
225
- async def _ensure_index(
226
  endpoint: str, docs: list[dict[str, str]]
227
  ) -> tuple[Any, MultifieldParser]:
228
- async with _CACHE_LOCK:
229
- cached = _INDEX_CACHE.get(endpoint)
230
- if cached is not None:
231
- return cached
232
 
233
  analyzer = StemmingAnalyzer()
234
  schema = Schema(
@@ -260,226 +208,167 @@ async def _ensure_index(
260
  group=OrGroup,
261
  )
262
 
263
- async with _CACHE_LOCK:
264
- _INDEX_CACHE[endpoint] = (index, parser)
265
  return index, parser
266
 
267
 
268
  async def _search_docs(
269
- endpoint: str,
270
- docs: list[dict[str, str]],
271
- query: str,
272
- limit: int | None,
273
  ) -> tuple[list[dict[str, Any]], str | None]:
274
- """
275
- Run a Whoosh search over documentation entries.
276
-
277
- Returns (results, fallback_message). If fallback_message is not None, the caller
278
- should surface fallback information to the user.
279
- """
280
- index, parser = await _ensure_index(endpoint, docs)
281
 
282
  try:
283
  query_obj = parser.parse(query)
284
  except Exception:
285
- return (
286
- [],
287
- "Query contained unsupported syntax; showing default ordering instead.",
288
- )
289
 
290
  with index.searcher() as searcher:
291
- whoosh_results = searcher.search(query_obj, limit=limit or None)
292
- matches: list[dict[str, Any]] = []
293
- for hit in whoosh_results:
294
- matches.append(
295
- {
296
- "title": hit["title"],
297
- "url": hit["url"],
298
- "md_url": hit.get("md_url", ""),
299
- "section": hit.get("section", endpoint),
300
- "glimpse": hit["glimpse"],
301
- "score": round(hit.score, 2),
302
- }
303
- )
304
 
305
  if not matches:
306
- return [], "No strong matches found; showing default ordering instead."
307
-
308
  return matches, None
309
 
310
 
311
- def _format_exploration_results(
312
- endpoint: str,
313
- result_items: list[dict[str, str]],
314
- total_items: int,
315
- query: str | None = None,
316
- fallback_message: str | None = None,
317
- ) -> str:
318
- """Format the exploration results as a readable string"""
319
- base_url = "https://huggingface.co/docs"
320
- url = f"{base_url}/{endpoint}"
321
- result = f"Documentation structure for: {url}\n\n"
322
-
323
- if query:
324
- result += (
325
- f"Query: '{query}' → showing {len(result_items)} result(s)"
326
- f" out of {total_items} pages"
327
- )
328
- if fallback_message:
329
- result += f" ({fallback_message})"
330
- result += "\n\n"
331
- else:
332
- result += (
333
- f"Found {len(result_items)} page(s) (total available: {total_items}).\n\n"
334
- )
335
-
336
- for i, item in enumerate(result_items, 1):
337
- result += f"{i}. **{item['title']}**\n"
338
- result += f" URL: {item['url']}\n"
339
- result += f" Section: {item.get('section', endpoint)}\n"
340
- if query and "score" in item:
341
- result += f" Relevance score: {item['score']:.2f}\n"
342
- result += f" Glimpse: {item['glimpse']}\n\n"
343
-
344
- return result
345
 
346
 
347
- async def explore_hf_docs(
348
- hf_token: str,
349
  endpoint: str,
 
 
350
  query: str | None = None,
351
- max_results: int | None = None,
352
  ) -> str:
353
- """Main function to explore documentation structure"""
354
- cached_items = await _get_docs(hf_token, endpoint)
355
-
356
- total_count = len(cached_items)
357
- if max_results is None:
358
- limit = DEFAULT_MAX_RESULTS
359
- limit_note = f"Showing top {DEFAULT_MAX_RESULTS} results (set max_results to adjust)."
360
- else:
361
- limit = max_results if max_results > 0 else None
362
- limit_note = None
363
- if limit is None:
364
- return "Error: max_results must be greater than zero."
365
-
366
- if limit > MAX_RESULTS_CAP:
367
- limit_note = (
368
- f"Requested {limit} results but showing top {MAX_RESULTS_CAP} (maximum allowed)."
369
- )
370
- limit = MAX_RESULTS_CAP
371
-
372
- selected_items: list[dict[str, Any]]
373
- fallback_message: str | None = None
374
 
375
  if query:
376
- search_results, fallback_message = await _search_docs(
377
- endpoint,
378
- cached_items,
379
- query,
380
- limit,
381
- )
382
-
383
- if search_results:
384
- selected_items = search_results
385
- else:
386
- selected_items = cached_items[:limit] if limit else cached_items
387
  else:
388
- selected_items = cached_items[:limit] if limit else cached_items
 
 
 
 
 
 
 
 
 
 
 
389
 
390
- if not selected_items:
391
- return f"No documentation entries available for endpoint '{endpoint}'."
392
 
393
- note = None
394
- if fallback_message or limit_note:
395
- pieces = []
396
- if fallback_message:
397
- pieces.append(fallback_message)
398
- if limit_note:
399
- pieces.append(limit_note)
400
- note = "; ".join(pieces)
401
-
402
- result = _format_exploration_results(
403
- endpoint,
404
- selected_items,
405
- total_items=total_count,
406
- query=query,
407
- fallback_message=note,
408
- )
409
 
410
- return result
 
 
411
 
412
 
413
  async def explore_hf_docs_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
414
- """
415
- Explore the documentation structure for a given endpoint by parsing the sidebar navigation
416
-
417
- Args:
418
- arguments: Dictionary with 'endpoint' parameter (e.g., 'trl', 'transformers', etc.)
419
-
420
- Returns:
421
- Tuple of (structured_navigation_with_glimpses, success)
422
- """
423
- endpoint = arguments.get("endpoint", "")
424
  query = arguments.get("query")
425
  max_results = arguments.get("max_results")
426
 
427
  if not endpoint:
428
  return "Error: No endpoint provided", False
429
 
430
- endpoint = endpoint.lstrip("/")
431
-
432
- # Special handling for Gradio docs (hosted at gradio.app, not HF docs)
433
  if endpoint.lower() == "gradio":
434
  try:
435
  clean_query = (
436
  query.strip() if isinstance(query, str) and query.strip() else None
437
  )
 
 
438
  if clean_query:
439
- # Use embedding search for specific queries
440
- content = await _search_gradio_docs(clean_query)
441
- else:
442
- # Fetch full docs when no query provided
443
- content = await _fetch_gradio_full_docs()
444
- return _format_gradio_results(content, query=clean_query), True
445
  except httpx.HTTPStatusError as e:
446
- return (
447
- f"HTTP error fetching Gradio docs: {e.response.status_code}",
448
- False,
449
- )
450
  except httpx.RequestError as e:
451
  return f"Request error fetching Gradio docs: {str(e)}", False
452
  except Exception as e:
453
  return f"Error fetching Gradio docs: {str(e)}", False
454
 
455
- # Standard HF docs flow for all other endpoints
456
  hf_token = os.environ.get("HF_TOKEN")
457
-
458
  if not hf_token:
459
  return "Error: HF_TOKEN environment variable not set", False
460
 
461
  try:
462
- try:
463
- max_results_int = int(max_results) if max_results is not None else None
464
- except (TypeError, ValueError):
465
- return "Error: max_results must be an integer", False
466
-
467
- if max_results_int is not None and max_results_int <= 0:
468
- return "Error: max_results must be greater than zero", False
469
-
470
- result = await explore_hf_docs(
471
- hf_token,
472
- endpoint,
473
- query=query.strip() if isinstance(query, str) and query.strip() else None,
474
- max_results=max_results_int,
 
 
 
 
 
 
 
 
 
 
 
 
475
  )
476
- return result, True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
477
 
478
  except httpx.HTTPStatusError as e:
479
- return (
480
- f"HTTP error: {e.response.status_code} - {e.response.text[:200]}",
481
- False,
482
- )
483
  except httpx.RequestError as e:
484
  return f"Request error: {str(e)}", False
485
  except ValueError as e:
@@ -488,121 +377,98 @@ async def explore_hf_docs_handler(arguments: dict[str, Any]) -> tuple[str, bool]
488
  return f"Unexpected error: {str(e)}", False
489
 
490
 
491
- async def _fetch_openapi_spec() -> dict[str, Any]:
492
- """Fetch and cache the HuggingFace OpenAPI specification"""
493
- global _openapi_spec_cache
 
 
 
 
 
 
494
 
495
- if _openapi_spec_cache is not None:
496
- return _openapi_spec_cache
497
 
498
- url = "https://huggingface.co/.well-known/openapi.json"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
499
 
500
- async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
501
- response = await client.get(url)
502
- response.raise_for_status()
503
 
504
- spec = response.json()
505
- _openapi_spec_cache = spec
 
 
 
506
 
507
- return spec
 
 
 
 
 
508
 
509
 
510
  def _extract_all_tags(spec: dict[str, Any]) -> list[str]:
511
- """Extract all unique tags from the OpenAPI spec"""
512
  tags = set()
513
-
514
- # Get tags from the tags section
515
  for tag_obj in spec.get("tags", []):
516
  if "name" in tag_obj:
517
  tags.add(tag_obj["name"])
518
-
519
- # Also get tags from paths (in case some aren't in the tags section)
520
- for path, path_item in spec.get("paths", {}).items():
521
- for method, operation in path_item.items():
522
  if method in ["get", "post", "put", "delete", "patch", "head", "options"]:
523
- for tag in operation.get("tags", []):
524
  tags.add(tag)
525
-
526
- return sorted(list(tags))
527
-
528
-
529
- def _search_openapi_by_tag(spec: dict[str, Any], tag: str) -> list[dict[str, Any]]:
530
- """Search for API endpoints with a specific tag"""
531
- results = []
532
- paths = spec.get("paths", {})
533
- servers = spec.get("servers", [])
534
- base_url = (
535
- servers[0].get("url", "https://huggingface.co")
536
- if servers
537
- else "https://huggingface.co"
538
- )
539
-
540
- for path, path_item in paths.items():
541
- for method, operation in path_item.items():
542
- if method not in [
543
- "get",
544
- "post",
545
- "put",
546
- "delete",
547
- "patch",
548
- "head",
549
- "options",
550
- ]:
551
- continue
552
-
553
- operation_tags = operation.get("tags", [])
554
- if tag in operation_tags:
555
- # Extract parameters
556
- parameters = operation.get("parameters", [])
557
- request_body = operation.get("requestBody", {})
558
- responses = operation.get("responses", {})
559
-
560
- results.append(
561
- {
562
- "path": path,
563
- "method": method.upper(),
564
- "operationId": operation.get("operationId", ""),
565
- "summary": operation.get("summary", ""),
566
- "description": operation.get("description", ""),
567
- "parameters": parameters,
568
- "request_body": request_body,
569
- "responses": responses,
570
- "base_url": base_url,
571
- }
572
- )
573
-
574
- return results
575
 
576
 
577
  def _generate_curl_example(endpoint: dict[str, Any]) -> str:
578
- """Generate a curl command example for an endpoint"""
579
  method = endpoint["method"]
580
  path = endpoint["path"]
581
  base_url = endpoint["base_url"]
582
 
583
- # Build the full URL with example path parameters
584
  full_path = path
585
  for param in endpoint.get("parameters", []):
586
  if param.get("in") == "path" and param.get("required"):
587
- param_name = param["name"]
588
  example = param.get(
589
- "example", param.get("schema", {}).get("example", f"<{param_name}>")
590
  )
591
- full_path = full_path.replace(f"{{{param_name}}}", str(example))
592
 
593
  curl = f"curl -X {method} \\\n '{base_url}{full_path}'"
594
 
595
- # Add query parameters if any
596
  query_params = [p for p in endpoint.get("parameters", []) if p.get("in") == "query"]
597
  if query_params and query_params[0].get("required"):
598
  param = query_params[0]
599
  example = param.get("example", param.get("schema", {}).get("example", "value"))
600
  curl += f"?{param['name']}={example}"
601
 
602
- # Add headers
603
  curl += " \\\n -H 'Authorization: Bearer $HF_TOKEN'"
604
 
605
- # Add request body if applicable
606
  if method in ["POST", "PUT", "PATCH"] and endpoint.get("request_body"):
607
  content = endpoint["request_body"].get("content", {})
608
  if "application/json" in content:
@@ -610,8 +476,6 @@ def _generate_curl_example(endpoint: dict[str, Any]) -> str:
             schema = content["application/json"].get("schema", {})
             example = schema.get("example", "{}")
             if isinstance(example, dict):
-                import json
-
                 example = json.dumps(example, indent=2)
             curl += f" \\\n -d '{example}'"
 
 
@@ -619,72 +483,50 @@ def _generate_curl_example(endpoint: dict[str, Any]) -> str:
619
 
620
 
621
  def _format_parameters(parameters: list[dict[str, Any]]) -> str:
622
- """Format parameter information from OpenAPI spec"""
623
  if not parameters:
624
  return ""
625
 
626
- # Group parameters by type
627
  path_params = [p for p in parameters if p.get("in") == "path"]
628
  query_params = [p for p in parameters if p.get("in") == "query"]
629
  header_params = [p for p in parameters if p.get("in") == "header"]
630
 
631
  output = []
632
 
633
- if path_params:
634
- output.append("**Path Parameters:**")
635
- for param in path_params:
636
- name = param.get("name", "")
637
- required = " (required)" if param.get("required") else " (optional)"
638
- description = param.get("description", "")
639
- param_type = param.get("schema", {}).get("type", "string")
640
- example = param.get("example") or param.get("schema", {}).get("example", "")
641
-
642
- output.append(f"- `{name}` ({param_type}){required}: {description}")
643
- if example:
644
- output.append(f" Example: `{example}`")
645
-
646
- if query_params:
647
  if output:
648
  output.append("")
649
- output.append("**Query Parameters:**")
650
- for param in query_params:
651
- name = param.get("name", "")
652
- required = " (required)" if param.get("required") else " (optional)"
653
- description = param.get("description", "")
654
- param_type = param.get("schema", {}).get("type", "string")
655
- example = param.get("example") or param.get("schema", {}).get("example", "")
656
-
657
- output.append(f"- `{name}` ({param_type}){required}: {description}")
658
  if example:
659
  output.append(f" Example: `{example}`")
660
 
661
- if header_params:
662
- if output:
663
- output.append("")
664
- output.append("**Header Parameters:**")
665
- for param in header_params:
666
- name = param.get("name", "")
667
- required = " (required)" if param.get("required") else " (optional)"
668
- description = param.get("description", "")
669
-
670
- output.append(f"- `{name}`{required}: {description}")
671
-
672
  return "\n".join(output)
673
 
674
 
675
  def _format_response_info(responses: dict[str, Any]) -> str:
676
- """Format response information from OpenAPI spec"""
677
  if not responses:
678
  return "No response information available"
679
 
680
  output = []
681
- for status_code, response_obj in list(responses.items())[
682
- :3
683
- ]: # Show first 3 status codes
684
- desc = response_obj.get("description", "")
685
- output.append(f"- **{status_code}**: {desc}")
686
-
687
- content = response_obj.get("content", {})
688
  if "application/json" in content:
689
  schema = content["application/json"].get("schema", {})
690
  if "type" in schema:
@@ -694,72 +536,87 @@ def _format_response_info(responses: dict[str, Any]) -> str:
694
 
695
 
696
  def _format_openapi_results(results: list[dict[str, Any]], tag: str) -> str:
697
- """Format OpenAPI search results as markdown with curl examples"""
698
  if not results:
699
  return f"No API endpoints found with tag '{tag}'"
700
 
701
- output = f"# API Endpoints for tag: `{tag}`\n\n"
702
- output += f"Found {len(results)} endpoint(s)\n\n"
703
- output += "---\n\n"
704
 
705
- for i, endpoint in enumerate(results, 1):
706
- output += f"## {i}. {endpoint['method']} {endpoint['path']}\n\n"
707
 
708
- if endpoint["summary"]:
709
- output += f"**Summary:** {endpoint['summary']}\n\n"
710
 
711
- if endpoint["description"]:
712
- desc = endpoint["description"][:300]
713
- if len(endpoint["description"]) > 300:
714
  desc += "..."
715
- output += f"**Description:** {desc}\n\n"
716
 
717
- # Parameters
718
- params_info = _format_parameters(endpoint.get("parameters", []))
719
  if params_info:
720
- output += params_info + "\n\n"
721
 
722
- # Curl example
723
- output += "**Usage:**\n```bash\n"
724
- output += _generate_curl_example(endpoint)
725
- output += "\n```\n\n"
726
 
727
- # Response info
728
- output += "**Returns:**\n"
729
- output += _format_response_info(endpoint["responses"])
730
- output += "\n\n"
731
 
732
- output += "---\n\n"
733
-
734
- return output
735
 
736
 
737
  async def search_openapi_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
738
- """
739
- Search the HuggingFace OpenAPI specification by tag
740
-
741
- Args:
742
- arguments: Dictionary with 'tag' parameter
743
-
744
- Returns:
745
- Tuple of (search_results, success)
746
- """
747
  tag = arguments.get("tag", "")
748
-
749
  if not tag:
750
  return "Error: No tag provided", False
751
 
752
  try:
753
- # Fetch OpenAPI spec (cached after first fetch)
754
  spec = await _fetch_openapi_spec()
 
 
 
 
 
 
 
755
 
756
- # Search for endpoints with this tag
757
- results = _search_openapi_by_tag(spec, tag)
 
 
 
 
 
 
 
 
 
 
 
 
 
758
 
759
- # Format results
760
- formatted = _format_openapi_results(results, tag)
 
 
 
 
 
 
 
 
 
 
 
761
 
762
- return formatted, True
763
 
764
  except httpx.HTTPStatusError as e:
765
  return f"HTTP error fetching OpenAPI spec: {e.response.status_code}", False
@@ -769,61 +626,81 @@ async def search_openapi_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
769
  return f"Error searching OpenAPI spec: {str(e)}", False
770
 
771
 
772
- async def hf_docs_fetch_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
773
- """
774
- Fetch full documentation content from a specific HF docs page
775
-
776
- Args:
777
- arguments: Dictionary with 'url' parameter (full URL to the doc page)
778
-
779
- Returns:
780
- Tuple of (full_markdown_content, success)
781
- """
782
- url = arguments.get("url", "")
783
-
784
- if not url:
785
- return "Error: No URL provided", False
786
-
787
- # Get HF token from environment
788
- hf_token = os.environ.get("HF_TOKEN")
789
-
790
- if not hf_token:
791
- return (
792
- "Error: HF_TOKEN environment variable not set",
793
- False,
794
- )
795
-
796
- # Add .md extension if not already present
797
- if not url.endswith(".md"):
798
- url = f"{url}.md"
799
-
800
- try:
801
- # Make request with auth
802
- headers = {"Authorization": f"Bearer {hf_token}"}
803
-
804
- async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
805
- response = await client.get(url, headers=headers)
806
- response.raise_for_status()
807
-
808
- content = response.text
809
-
810
- # Return the markdown content directly
811
- result = f"Documentation from: {url}\n\n{content}"
812
 
813
- return result, True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
814
 
815
- except httpx.HTTPStatusError as e:
816
- return (
817
- f"HTTP error fetching {url}: {e.response.status_code} - {e.response.text[:200]}",
818
- False,
819
- )
820
- except httpx.RequestError as e:
821
- return f"Request error fetching {url}: {str(e)}", False
822
- except Exception as e:
823
- return f"Error fetching documentation: {str(e)}", False
824
 
 
 
 
825
 
826
- # Tool specifications for documentation search
 
 
 
 
 
 
 
 
 
 
 
 
827
 
828
  EXPLORE_HF_DOCS_TOOL_SPEC = {
829
  "name": "explore_hf_docs",
@@ -845,48 +722,7 @@ EXPLORE_HF_DOCS_TOOL_SPEC = {
845
  "properties": {
846
  "endpoint": {
847
  "type": "string",
848
- "enum": [
849
- "hub",
850
- "transformers",
851
- "diffusers",
852
- "datasets",
853
- "gradio",
854
- "trackio",
855
- "smolagents",
856
- "huggingface_hub",
857
- "huggingface.js",
858
- "transformers.js",
859
- "inference-providers",
860
- "inference-endpoints",
861
- "peft",
862
- "accelerate",
863
- "optimum",
864
- "tokenizers",
865
- "courses",
866
- "evaluate",
867
- "tasks",
868
- "dataset-viewer",
869
- "trl",
870
- "simulate",
871
- "sagemaker",
872
- "timm",
873
- "safetensors",
874
- "tgi",
875
- "setfit",
876
- "lerobot",
877
- "autotrain",
878
- "tei",
879
- "bitsandbytes",
880
- "sentence_transformers",
881
- "chat-ui",
882
- "leaderboards",
883
- "lighteval",
884
- "argilla",
885
- "distilabel",
886
- "microsoft-azure",
887
- "kernels",
888
- "google-cloud",
889
- ],
890
  "description": (
891
  "The documentation endpoint to explore. Each endpoint corresponds to a major section of the Hugging Face documentation:\n\n"
892
  "• courses — All Hugging Face courses (LLM, robotics, MCP, smol (llm training), agents, deep RL, computer vision, games, diffusion, 3D, audio) and the cookbook recipes. Probably the best place for examples.\n"
@@ -934,15 +770,13 @@ EXPLORE_HF_DOCS_TOOL_SPEC = {
             "query": {
                 "type": "string",
                 "description": (
-                    "Optional keyword query to rank and filter documentation pages. Fuzzy matching is used "
-                    "against titles, URLs, and glimpses to surface the most relevant content."
                 ),
             },
             "max_results": {
                 "type": "integer",
-                "description": (
-                    "Optional cap on number of results to return. Defaults to 20 when omitted and cannot exceed 50."
-                ),
                 "minimum": 1,
                 "maximum": 50,
             },
@@ -980,40 +814,3 @@ HF_DOCS_FETCH_TOOL_SPEC = {
         "required": ["url"],
     },
 }
-
-
-async def _get_api_search_tool_spec() -> dict[str, Any]:
-    """
-    Dynamically generate the OpenAPI tool spec with tag enum populated at runtime
-    This must be called async to fetch the OpenAPI spec and extract tags
-    """
-    spec = await _fetch_openapi_spec()
-    tags = _extract_all_tags(spec)
-
-    return {
-        "name": "search_hf_api_endpoints",
-        "description": (
-            "Search HuggingFace OpenAPI specification by tag to find API endpoints with curl examples. "
-            "**Use when:** (1) Need to interact with HF Hub API directly, (2) Building scripts for repo operations, "
-            "(3) Need authentication patterns, (4) Understanding API parameters and responses, "
-            "(5) Need curl examples for HTTP requests. "
-            "Returns: Endpoint paths, methods, parameters, curl examples with authentication, and response schemas. "
-            "**Pattern:** search_hf_api_endpoints (find endpoint) → use curl pattern in implementation. "
-            "Tags group related operations: repos, models, datasets, inference, spaces, etc. "
-            "**Note:** Each result includes curl example with $HF_TOKEN placeholder for authentication. "
-            "**For tool building:** This provides the API foundation for creating Hub interaction scripts."
-        ),
-        "parameters": {
-            "type": "object",
-            "properties": {
-                "tag": {
-                    "type": "string",
-                    "enum": tags,
-                    "description": (
-                        "The API tag to search for. Each tag groups related API endpoints. "
-                    ),
-                },
-            },
-            "required": ["tag"],
-        },
-    }
 
 """
+Documentation search tools for exploring HuggingFace and Gradio documentation.
 """
 
 import asyncio
+import json
 import os
 from typing import Any
 
 
 from whoosh.filedb.filestore import RamStorage
 from whoosh.qparser import MultifieldParser, OrGroup
 
+# ---------------------------------------------------------------------------
+# Configuration
+# ---------------------------------------------------------------------------
 
 DEFAULT_MAX_RESULTS = 20
 MAX_RESULTS_CAP = 50
 
 GRADIO_LLMS_TXT_URL = "https://gradio.app/llms.txt"
+GRADIO_SEARCH_URL = "https://playground-worker.pages.dev/api/prompt"
 
 COMPOSITE_ENDPOINTS: dict[str, list[str]] = {
     "optimum": [
         "optimum",
 
49
  ],
50
  }
51
 
 
 
 
 
 
52
  # ---------------------------------------------------------------------------
53
+ # Caches
54
  # ---------------------------------------------------------------------------
55
 
56
+ _docs_cache: dict[str, list[dict[str, str]]] = {}
57
+ _index_cache: dict[str, tuple[Any, MultifieldParser]] = {}
58
+ _cache_lock = asyncio.Lock()
59
+ _openapi_cache: dict[str, Any] | None = None
60
 
61
+ # ---------------------------------------------------------------------------
62
+ # Gradio Documentation
63
+ # ---------------------------------------------------------------------------
 
 
 
64
 
65
 
66
+ async def _fetch_gradio_docs(query: str | None = None) -> str:
67
  """
68
+ Fetch Gradio documentation.
69
+ Without query: Get full documentation from llms.txt
70
+ With query: Run embedding search on guides/demos for relevant content
71
  """
72
  async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
73
+ if not query:
74
+ resp = await client.get(GRADIO_LLMS_TXT_URL)
75
+ resp.raise_for_status()
76
+ return resp.text
77
+
78
+ resp = await client.post(
79
+ GRADIO_SEARCH_URL,
80
  headers={
81
  "Content-Type": "application/json",
82
  "Origin": "https://gradio-docs-mcp.up.railway.app",
 
87
  "FALLBACK_PROMPT": "No results found",
88
  },
89
  )
90
+ resp.raise_for_status()
91
+ return resp.json().get("SYS_PROMPT", "No results found")
 
92
 
93
 
94
+ # ---------------------------------------------------------------------------
95
+ # HF Documentation - Fetching
96
+ # ---------------------------------------------------------------------------
 
 
 
 
97
 
98
 
99
+ async def _fetch_endpoint_docs(hf_token: str, endpoint: str) -> list[dict[str, str]]:
100
+ """Fetch all docs for an endpoint by parsing sidebar and fetching each page."""
101
+ url = f"https://huggingface.co/docs/{endpoint}"
 
102
  headers = {"Authorization": f"Bearer {hf_token}"}
103
 
104
  async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
105
+ resp = await client.get(url, headers=headers)
106
+ resp.raise_for_status()
107
+
108
+ soup = BeautifulSoup(resp.text, "html.parser")
109
+ sidebar = soup.find("nav", class_=lambda x: x and "flex-auto" in x)
110
+ if not sidebar:
111
+ raise ValueError(f"Could not find navigation sidebar for '{endpoint}'")
112
+
113
+ nav_items = []
114
+ for link in sidebar.find_all("a", href=True):
115
+ href = link["href"]
116
+ page_url = f"https://huggingface.co{href}" if href.startswith("/") else href
117
+ nav_items.append({"title": link.get_text(strip=True), "url": page_url})
118
+
119
+ if not nav_items:
120
+ raise ValueError(f"No navigation links found for '{endpoint}'")
121
+
122
+ async def fetch_page(item: dict[str, str]) -> dict[str, str]:
123
+ md_url = f"{item['url']}.md"
124
+ try:
125
+ r = await client.get(md_url, headers=headers)
126
+ r.raise_for_status()
127
+ content = r.text.strip()
128
+ glimpse = content[:200] + "..." if len(content) > 200 else content
129
+ except Exception as e:
130
+ content, glimpse = "", f"[Could not fetch: {str(e)[:50]}]"
131
+ return {
132
+ "title": item["title"],
133
+ "url": item["url"],
134
+ "md_url": md_url,
135
+ "glimpse": glimpse,
136
+ "content": content,
137
+ "section": endpoint,
138
+ }
139
+
140
+ return list(await asyncio.gather(*[fetch_page(item) for item in nav_items]))
141
 
 
 
142
 
143
+ async def _get_docs(hf_token: str, endpoint: str) -> list[dict[str, str]]:
144
+ """Get docs for endpoint with caching. Expands composite endpoints."""
145
+ async with _cache_lock:
146
+ if endpoint in _docs_cache:
147
+ return _docs_cache[endpoint]
148
+
149
+ sub_endpoints = COMPOSITE_ENDPOINTS.get(endpoint, [endpoint])
150
+ all_docs: list[dict[str, str]] = []
151
+
152
+ for sub in sub_endpoints:
153
+ async with _cache_lock:
154
+ if sub in _docs_cache:
155
+ all_docs.extend(_docs_cache[sub])
156
+ continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
+ docs = await _fetch_endpoint_docs(hf_token, sub)
159
+ async with _cache_lock:
160
+ _docs_cache[sub] = docs
161
+ all_docs.extend(docs)
162
 
163
+ async with _cache_lock:
164
+ _docs_cache[endpoint] = all_docs
165
+ return all_docs
166
 
 
 
 
 
 
 
167
 
168
+ # ---------------------------------------------------------------------------
169
+ # HF Documentation - Search
170
+ # ---------------------------------------------------------------------------
 
171
 
172
 
173
+ async def _build_search_index(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  endpoint: str, docs: list[dict[str, str]]
175
  ) -> tuple[Any, MultifieldParser]:
176
+ """Build or retrieve cached Whoosh search index."""
177
+ async with _cache_lock:
178
+ if endpoint in _index_cache:
179
+ return _index_cache[endpoint]
180
 
181
  analyzer = StemmingAnalyzer()
182
  schema = Schema(
 
208
  group=OrGroup,
209
  )
210
 
211
+ async with _cache_lock:
212
+ _index_cache[endpoint] = (index, parser)
213
  return index, parser
214
 
215
 
216
  async def _search_docs(
217
+ endpoint: str, docs: list[dict[str, str]], query: str, limit: int
 
 
 
218
  ) -> tuple[list[dict[str, Any]], str | None]:
219
+ """Search docs using Whoosh. Returns (results, fallback_message)."""
220
+ index, parser = await _build_search_index(endpoint, docs)
 
 
 
 
 
221
 
222
  try:
223
  query_obj = parser.parse(query)
224
  except Exception:
225
+ return [], "Query contained unsupported syntax; showing default ordering."
 
 
 
226
 
227
  with index.searcher() as searcher:
228
+ results = searcher.search(query_obj, limit=limit)
229
+ matches = [
230
+ {
231
+ "title": hit["title"],
232
+ "url": hit["url"],
233
+ "md_url": hit.get("md_url", ""),
234
+ "section": hit.get("section", endpoint),
235
+ "glimpse": hit["glimpse"],
236
+ "score": round(hit.score, 2),
237
+ }
238
+ for hit in results
239
+ ]
 
240
 
241
  if not matches:
242
+ return [], "No strong matches found; showing default ordering."
 
243
  return matches, None
244
 
245
 
246
+ # ---------------------------------------------------------------------------
247
+ # HF Documentation - Formatting
248
+ # ---------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
 
250
 
251
+ def _format_results(
 
252
  endpoint: str,
253
+ items: list[dict[str, Any]],
254
+ total: int,
255
  query: str | None = None,
256
+ note: str | None = None,
257
  ) -> str:
258
+ """Format search results as readable text."""
259
+ base_url = f"https://huggingface.co/docs/{endpoint}"
260
+ out = f"Documentation structure for: {base_url}\n\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
 
262
  if query:
263
+ out += f"Query: '{query}' → showing {len(items)} result(s) out of {total} pages"
264
+ if note:
265
+ out += f" ({note})"
266
+ out += "\n\n"
 
 
 
 
 
 
 
267
  else:
268
+ out += f"Found {len(items)} page(s) (total available: {total}).\n"
269
+ if note:
270
+ out += f"({note})\n"
271
+ out += "\n"
272
+
273
+ for i, item in enumerate(items, 1):
274
+ out += f"{i}. **{item['title']}**\n"
275
+ out += f" URL: {item['url']}\n"
276
+ out += f" Section: {item.get('section', endpoint)}\n"
277
+ if query and "score" in item:
278
+ out += f" Relevance score: {item['score']:.2f}\n"
279
+ out += f" Glimpse: {item['glimpse']}\n\n"
280
 
281
+ return out
 
282
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
 
284
+ # ---------------------------------------------------------------------------
285
+ # Handlers
286
+ # ---------------------------------------------------------------------------
287
 
288
 
289
  async def explore_hf_docs_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
290
+ """Explore documentation structure with optional search query."""
291
+ endpoint = arguments.get("endpoint", "").lstrip("/")
 
 
 
 
 
 
 
 
292
  query = arguments.get("query")
293
  max_results = arguments.get("max_results")
294
 
295
  if not endpoint:
296
  return "Error: No endpoint provided", False
297
 
298
+ # Gradio uses its own API
 
 
299
  if endpoint.lower() == "gradio":
300
  try:
301
  clean_query = (
302
  query.strip() if isinstance(query, str) and query.strip() else None
303
  )
304
+ content = await _fetch_gradio_docs(clean_query)
305
+ header = "# Gradio Documentation\n\n"
306
  if clean_query:
307
+ header += f"Query: '{clean_query}'\n\n"
308
+ header += "Source: https://gradio.app/docs\n\n---\n\n"
309
+ return header + content, True
 
 
 
310
  except httpx.HTTPStatusError as e:
311
+ return f"HTTP error fetching Gradio docs: {e.response.status_code}", False
 
 
 
312
  except httpx.RequestError as e:
313
  return f"Request error fetching Gradio docs: {str(e)}", False
314
  except Exception as e:
315
  return f"Error fetching Gradio docs: {str(e)}", False
316
 
317
+ # HF docs
318
  hf_token = os.environ.get("HF_TOKEN")
 
319
  if not hf_token:
320
  return "Error: HF_TOKEN environment variable not set", False
321
 
322
  try:
323
+ max_results_int = int(max_results) if max_results is not None else None
324
+ except (TypeError, ValueError):
325
+ return "Error: max_results must be an integer", False
326
+
327
+ if max_results_int is not None and max_results_int <= 0:
328
+ return "Error: max_results must be greater than zero", False
329
+
330
+ try:
331
+ docs = await _get_docs(hf_token, endpoint)
332
+ total = len(docs)
333
+
334
+ # Determine limit
335
+ if max_results_int is None:
336
+ limit = DEFAULT_MAX_RESULTS
337
+ limit_note = f"Showing top {DEFAULT_MAX_RESULTS} results (set max_results to adjust)."
338
+ elif max_results_int > MAX_RESULTS_CAP:
339
+ limit = MAX_RESULTS_CAP
340
+ limit_note = f"Requested {max_results_int} but showing top {MAX_RESULTS_CAP} (maximum)."
341
+ else:
342
+ limit = max_results_int
343
+ limit_note = None
344
+
345
+ # Search or paginate
346
+ clean_query = (
347
+ query.strip() if isinstance(query, str) and query.strip() else None
348
  )
349
+ fallback_msg = None
350
+
351
+ if clean_query:
352
+ results, fallback_msg = await _search_docs(
353
+ endpoint, docs, clean_query, limit
354
+ )
355
+ if not results:
356
+ results = docs[:limit]
357
+ else:
358
+ results = docs[:limit]
359
+
360
+ # Combine notes
361
+ notes = []
362
+ if fallback_msg:
363
+ notes.append(fallback_msg)
364
+ if limit_note:
365
+ notes.append(limit_note)
366
+ note = "; ".join(notes) if notes else None
367
+
368
+ return _format_results(endpoint, results, total, clean_query, note), True
369
 
370
  except httpx.HTTPStatusError as e:
371
+ return f"HTTP error: {e.response.status_code} - {e.response.text[:200]}", False
 
 
 
372
  except httpx.RequestError as e:
373
  return f"Request error: {str(e)}", False
374
  except ValueError as e:
 
377
  return f"Unexpected error: {str(e)}", False
378
 
379
 
380
+ async def hf_docs_fetch_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
381
+ """Fetch full markdown content of a documentation page."""
382
+ url = arguments.get("url", "")
383
+ if not url:
384
+ return "Error: No URL provided", False
385
+
386
+ hf_token = os.environ.get("HF_TOKEN")
387
+ if not hf_token:
388
+ return "Error: HF_TOKEN environment variable not set", False
389
 
390
+ if not url.endswith(".md"):
391
+ url = f"{url}.md"
392
 
393
+ try:
394
+ async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
395
+ resp = await client.get(
396
+ url, headers={"Authorization": f"Bearer {hf_token}"}
397
+ )
398
+ resp.raise_for_status()
399
+ return f"Documentation from: {url}\n\n{resp.text}", True
400
+ except httpx.HTTPStatusError as e:
401
+ return (
402
+ f"HTTP error fetching {url}: {e.response.status_code} - {e.response.text[:200]}",
403
+ False,
404
+ )
405
+ except httpx.RequestError as e:
406
+ return f"Request error fetching {url}: {str(e)}", False
407
+ except Exception as e:
408
+ return f"Error fetching documentation: {str(e)}", False
409
+
410
+
411
+ # ---------------------------------------------------------------------------
412
+ # OpenAPI Search
413
+ # ---------------------------------------------------------------------------
414
 
 
 
 
415
 
416
+ async def _fetch_openapi_spec() -> dict[str, Any]:
417
+ """Fetch and cache HuggingFace OpenAPI specification."""
418
+ global _openapi_cache
419
+ if _openapi_cache is not None:
420
+ return _openapi_cache
421
 
422
+ async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
423
+ resp = await client.get("https://huggingface.co/.well-known/openapi.json")
424
+ resp.raise_for_status()
425
+
426
+ _openapi_cache = resp.json()
427
+ return _openapi_cache
428
 
429
 
430
  def _extract_all_tags(spec: dict[str, Any]) -> list[str]:
431
+ """Extract all unique tags from OpenAPI spec."""
432
  tags = set()
 
 
433
  for tag_obj in spec.get("tags", []):
434
  if "name" in tag_obj:
435
  tags.add(tag_obj["name"])
436
+ for path_item in spec.get("paths", {}).values():
437
+ for method, op in path_item.items():
 
 
438
  if method in ["get", "post", "put", "delete", "patch", "head", "options"]:
439
+ for tag in op.get("tags", []):
440
  tags.add(tag)
441
+ return sorted(tags)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
442
 
443
 
444
  def _generate_curl_example(endpoint: dict[str, Any]) -> str:
445
+ """Generate curl command example for an endpoint."""
446
  method = endpoint["method"]
447
  path = endpoint["path"]
448
  base_url = endpoint["base_url"]
449
 
450
+ # Build URL with path parameters
451
  full_path = path
452
  for param in endpoint.get("parameters", []):
453
  if param.get("in") == "path" and param.get("required"):
454
+ name = param["name"]
455
  example = param.get(
456
+ "example", param.get("schema", {}).get("example", f"<{name}>")
457
  )
458
+ full_path = full_path.replace(f"{{{name}}}", str(example))
459
 
460
  curl = f"curl -X {method} \\\n '{base_url}{full_path}'"
461
 
462
+ # Add query parameters
463
  query_params = [p for p in endpoint.get("parameters", []) if p.get("in") == "query"]
464
  if query_params and query_params[0].get("required"):
465
  param = query_params[0]
466
  example = param.get("example", param.get("schema", {}).get("example", "value"))
467
  curl += f"?{param['name']}={example}"
468
 
 
469
  curl += " \\\n -H 'Authorization: Bearer $HF_TOKEN'"
470
 
471
+ # Add request body
472
  if method in ["POST", "PUT", "PATCH"] and endpoint.get("request_body"):
473
  content = endpoint["request_body"].get("content", {})
474
  if "application/json" in content:
 
476
  schema = content["application/json"].get("schema", {})
477
  example = schema.get("example", "{}")
478
  if isinstance(example, dict):
 
 
479
  example = json.dumps(example, indent=2)
480
  curl += f" \\\n -d '{example}'"
481
 
 
483
 
484
 
485
  def _format_parameters(parameters: list[dict[str, Any]]) -> str:
486
+ """Format parameter information from OpenAPI spec."""
487
  if not parameters:
488
  return ""
489
 
 
490
  path_params = [p for p in parameters if p.get("in") == "path"]
491
  query_params = [p for p in parameters if p.get("in") == "query"]
492
  header_params = [p for p in parameters if p.get("in") == "header"]
493
 
494
  output = []
495
 
496
+ for label, params in [
497
+ ("Path Parameters", path_params),
498
+ ("Query Parameters", query_params),
499
+ ("Header Parameters", header_params),
500
+ ]:
501
+ if not params:
502
+ continue
 
 
 
 
 
 
 
503
  if output:
504
  output.append("")
505
+ output.append(f"**{label}:**")
506
+ for p in params:
507
+ name = p.get("name", "")
508
+ required = " (required)" if p.get("required") else " (optional)"
509
+ desc = p.get("description", "")
510
+ ptype = p.get("schema", {}).get("type", "string")
511
+ example = p.get("example") or p.get("schema", {}).get("example", "")
512
+
513
+ output.append(f"- `{name}` ({ptype}){required}: {desc}")
514
  if example:
515
  output.append(f" Example: `{example}`")
516
 
 
 
 
 
 
 
 
 
 
 
 
517
  return "\n".join(output)
518
 
519
 
520
  def _format_response_info(responses: dict[str, Any]) -> str:
521
+ """Format response information from OpenAPI spec."""
522
  if not responses:
523
  return "No response information available"
524
 
525
  output = []
526
+ for status, resp_obj in list(responses.items())[:3]:
527
+ desc = resp_obj.get("description", "")
528
+ output.append(f"- **{status}**: {desc}")
529
+ content = resp_obj.get("content", {})
 
 
 
530
  if "application/json" in content:
531
  schema = content["application/json"].get("schema", {})
532
  if "type" in schema:
 
536
 
537
 
538
  def _format_openapi_results(results: list[dict[str, Any]], tag: str) -> str:
539
+ """Format OpenAPI search results with curl examples."""
540
  if not results:
541
  return f"No API endpoints found with tag '{tag}'"
542
 
543
+ out = f"# API Endpoints for tag: `{tag}`\n\n"
544
+ out += f"Found {len(results)} endpoint(s)\n\n---\n\n"
 
545
 
546
+ for i, ep in enumerate(results, 1):
547
+ out += f"## {i}. {ep['method']} {ep['path']}\n\n"
548
 
549
+ if ep["summary"]:
550
+ out += f"**Summary:** {ep['summary']}\n\n"
551
 
552
+ if ep["description"]:
553
+ desc = ep["description"][:300]
554
+ if len(ep["description"]) > 300:
555
  desc += "..."
556
+ out += f"**Description:** {desc}\n\n"
557
 
558
+ params_info = _format_parameters(ep.get("parameters", []))
 
559
  if params_info:
560
+ out += params_info + "\n\n"
561
 
562
+ out += "**Usage:**\n```bash\n"
563
+ out += _generate_curl_example(ep)
564
+ out += "\n```\n\n"
 
565
 
566
+ out += "**Returns:**\n"
567
+ out += _format_response_info(ep["responses"])
568
+ out += "\n\n---\n\n"
 
569
 
570
+ return out
 
 
571
 
572
 
573
  async def search_openapi_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
574
+ """Search HuggingFace OpenAPI specification by tag."""
 
 
 
 
 
 
 
 
575
  tag = arguments.get("tag", "")
 
576
  if not tag:
577
  return "Error: No tag provided", False
578
 
579
  try:
 
580
  spec = await _fetch_openapi_spec()
581
+ paths = spec.get("paths", {})
582
+ servers = spec.get("servers", [])
583
+ base_url = (
584
+ servers[0].get("url", "https://huggingface.co")
585
+ if servers
586
+ else "https://huggingface.co"
587
+ )
588
 
589
+ results = []
590
+ for path, path_item in paths.items():
591
+ for method, op in path_item.items():
592
+ if method not in [
593
+ "get",
594
+ "post",
595
+ "put",
596
+ "delete",
597
+ "patch",
598
+ "head",
599
+ "options",
600
+ ]:
601
+ continue
602
+ if tag not in op.get("tags", []):
603
+ continue
604
 
605
+ results.append(
606
+ {
607
+ "path": path,
608
+ "method": method.upper(),
609
+ "operationId": op.get("operationId", ""),
610
+ "summary": op.get("summary", ""),
611
+ "description": op.get("description", ""),
612
+ "parameters": op.get("parameters", []),
613
+ "request_body": op.get("requestBody", {}),
614
+ "responses": op.get("responses", {}),
615
+ "base_url": base_url,
616
+ }
617
+ )
618
 
619
+ return _format_openapi_results(results, tag), True
620
 
621
  except httpx.HTTPStatusError as e:
622
  return f"HTTP error fetching OpenAPI spec: {e.response.status_code}", False
 
626
  return f"Error searching OpenAPI spec: {str(e)}", False
627
 
628
 
629
+ async def _get_api_search_tool_spec() -> dict[str, Any]:
630
+ """Generate OpenAPI tool spec with tags populated at runtime."""
631
+ spec = await _fetch_openapi_spec()
632
+ tags = _extract_all_tags(spec)
 
 
 
 
 
 
 
633
 
634
+ return {
635
+ "name": "search_hf_api_endpoints",
636
+ "description": (
637
+ "Search HuggingFace OpenAPI specification by tag to find API endpoints with curl examples. "
638
+ "**Use when:** (1) Need to interact with HF Hub API directly, (2) Building scripts for repo operations, "
639
+ "(3) Need authentication patterns, (4) Understanding API parameters and responses, "
640
+ "(5) Need curl examples for HTTP requests. "
641
+ "Returns: Endpoint paths, methods, parameters, curl examples with authentication, and response schemas. "
642
+ "Tags group related operations: repos, models, datasets, inference, spaces, etc."
643
+ ),
644
+ "parameters": {
645
+ "type": "object",
646
+ "properties": {
647
+ "tag": {
648
+ "type": "string",
649
+ "enum": tags,
650
+ "description": "The API tag to search for. Each tag groups related API endpoints.",
651
+ },
652
+ },
653
+ "required": ["tag"],
654
+ },
655
+ }
656
 
 
 
 
 
 
 
 
 
 
657
 
658
+ # ---------------------------------------------------------------------------
659
+ # Tool Specifications
660
+ # ---------------------------------------------------------------------------
661
 
662
+ DOC_ENDPOINTS = [
663
+ "hub",
664
+ "transformers",
665
+ "diffusers",
666
+ "datasets",
667
+ "gradio",
668
+ "trackio",
669
+ "smolagents",
670
+ "huggingface_hub",
671
+ "huggingface.js",
672
+ "transformers.js",
673
+ "inference-providers",
674
+ "inference-endpoints",
675
+ "peft",
676
+ "accelerate",
677
+ "optimum",
678
+ "tokenizers",
679
+ "courses",
680
+ "evaluate",
681
+ "tasks",
682
+ "dataset-viewer",
683
+ "trl",
684
+ "simulate",
685
+ "sagemaker",
686
+ "timm",
687
+ "safetensors",
688
+ "tgi",
689
+ "setfit",
690
+ "lerobot",
691
+ "autotrain",
692
+ "tei",
693
+ "bitsandbytes",
694
+ "sentence_transformers",
695
+ "chat-ui",
696
+ "leaderboards",
697
+ "lighteval",
698
+ "argilla",
699
+ "distilabel",
700
+ "microsoft-azure",
701
+ "kernels",
702
+ "google-cloud",
703
+ ]
704
 
705
  EXPLORE_HF_DOCS_TOOL_SPEC = {
706
  "name": "explore_hf_docs",
 
722
  "properties": {
723
  "endpoint": {
724
  "type": "string",
725
+ "enum": DOC_ENDPOINTS,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
726
  "description": (
727
  "The documentation endpoint to explore. Each endpoint corresponds to a major section of the Hugging Face documentation:\n\n"
728
  "• courses — All Hugging Face courses (LLM, robotics, MCP, smol (llm training), agents, deep RL, computer vision, games, diffusion, 3D, audio) and the cookbook recipes. Probably the best place for examples.\n"
 
             "query": {
                 "type": "string",
                 "description": (
+                    "Optional keyword query to rank and filter documentation pages. "
+                    "For Gradio, use concise queries like 'how to use the image component' or 'audio component demo'."
                 ),
             },
             "max_results": {
                 "type": "integer",
+                "description": "Max results (default 20, max 50). Ignored for Gradio.",
                 "minimum": 1,
                 "maximum": 50,
             },
 
         "required": ["url"],
     },
 }
 
test_dataset_tools.py DELETED
@@ -1,89 +0,0 @@
1
- """
2
- Test script for hf_repo_files and hf_repo_git tools
3
- """
4
-
5
- import asyncio
6
- import sys
7
- from typing import TypedDict
8
- from unittest.mock import MagicMock
9
-
10
-
11
- # Mock the types module before importing
12
- class ToolResult(TypedDict, total=False):
13
- formatted: str
14
- totalResults: int
15
- resultsShared: int
16
- isError: bool
17
-
18
-
19
- mock_types = MagicMock()
20
- mock_types.ToolResult = ToolResult
21
- sys.modules["agent.tools.types"] = mock_types
22
-
23
- from agent.tools.hf_repo_files_tool import HfRepoFilesTool
24
- from agent.tools.hf_repo_git_tool import HfRepoGitTool
25
-
26
-
27
- async def test_hf_repo_files():
28
- """Test hf_repo_files tool"""
29
- print("=" * 60)
30
- print("Testing hf_repo_files")
31
- print("=" * 60)
32
-
33
- tool = HfRepoFilesTool()
34
-
35
- # Test list
36
- print("\n→ list files in gpt2:")
37
- result = await tool.execute(
38
- {"operation": "list", "repo_id": "openai-community/gpt2"}
39
- )
40
- print(f" isError: {result.get('isError', False)}")
41
- print(f" totalResults: {result['totalResults']}")
42
- # Just show first few lines
43
- lines = result["formatted"].split("\n")
44
- print(" Output (first 5 lines):\n" + "\n".join(f" {line}" for line in lines))
45
-
46
- # Test read
47
- print("\n→ read config.json from gpt2:")
48
- result = await tool.execute(
49
- {"operation": "read", "repo_id": "openai-community/gpt2", "path": "config.json"}
50
- )
51
- print(f" isError: {result.get('isError', False)}")
52
- lines = result["formatted"].split("\n")
53
- print(" Output (first 10 lines):\n" + "\n".join(f" {line}" for line in lines))
54
-
55
-
56
- async def test_hf_repo_git():
57
- """Test hf_repo_git tool"""
58
- print("\n" + "=" * 60)
59
- print("Testing hf_repo_git")
60
- print("=" * 60)
61
-
62
- tool = HfRepoGitTool()
63
-
64
- # Test list_refs
65
- print("\n→ list_refs for gpt2:")
66
- result = await tool.execute(
67
- {"operation": "list_refs", "repo_id": "openai-community/gpt2"}
68
- )
69
- print(f" isError: {result.get('isError', False)}")
70
- print(
71
- " Output:\n"
72
- + "\n".join(f" {line}" for line in result["formatted"].split("\n"))
73
- )
74
-
75
- # Test help (no operation)
76
- print("\n→ help (no operation):")
77
- result = await tool.execute({})
78
- print(f" isError: {result.get('isError', False)}")
79
- lines = result["formatted"].split("\n")[:6]
80
- print(" Output (first 6 lines):\n" + "\n".join(f" {line}" for line in lines))
81
-
82
-
83
- if __name__ == "__main__":
84
- print("\nHF Repo Tools Test\n")
85
- asyncio.run(test_hf_repo_files())
86
- asyncio.run(test_hf_repo_git())
87
- print("\n" + "=" * 60)
88
- print("Tests complete!")
89
- print("=" * 60)