akseljoonas HF Staff committed on
Commit
061ebdc
·
1 Parent(s): 838cd50

fix: pass user OAuth token to dataset inspection tool

Browse files
agent/core/tools.py CHANGED
@@ -62,7 +62,13 @@ warnings.filterwarnings(
62
  "ignore", category=DeprecationWarning, module="aiohttp.connector"
63
  )
64
 
65
- NOT_ALLOWED_TOOL_NAMES = ["hf_jobs", "hf_doc_search", "hf_doc_fetch", "hf_whoami", "paper_search"]
 
 
 
 
 
 
66
 
67
 
68
  def convert_mcp_content_to_string(content: list) -> str:
 
62
  "ignore", category=DeprecationWarning, module="aiohttp.connector"
63
  )
64
 
65
+ NOT_ALLOWED_TOOL_NAMES = [
66
+ "hf_jobs",
67
+ "hf_doc_search",
68
+ "hf_doc_fetch",
69
+ "hf_whoami",
70
+ "paper_search",
71
+ ]
72
 
73
 
74
  def convert_mcp_content_to_string(content: list) -> str:
agent/tools/dataset_tools.py CHANGED
@@ -11,6 +11,7 @@ from typing import Any, TypedDict
11
 
12
  import httpx
13
 
 
14
  from agent.tools.types import ToolResult
15
 
16
  BASE_URL = "https://datasets-server.huggingface.co"
@@ -26,9 +27,9 @@ class SplitConfig(TypedDict):
26
  splits: list[str]
27
 
28
 
29
- def _get_headers() -> dict:
30
  """Get auth headers for private/gated datasets"""
31
- token = os.environ.get("HF_TOKEN")
32
  if token:
33
  return {"Authorization": f"Bearer {token}"}
34
  return {}
@@ -39,12 +40,13 @@ async def inspect_dataset(
39
  config: str | None = None,
40
  split: str | None = None,
41
  sample_rows: int = 3,
 
42
  ) -> ToolResult:
43
  """
44
  Get comprehensive dataset info in one call.
45
  All API calls made in parallel for speed.
46
  """
47
- headers = _get_headers()
48
  output_parts = []
49
  errors = []
50
 
@@ -424,7 +426,9 @@ HF_INSPECT_DATASET_TOOL_SPEC = {
424
  }
425
 
426
 
427
- async def hf_inspect_dataset_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
 
 
428
  """Handler for agent tool router"""
429
  try:
430
  result = await inspect_dataset(
@@ -432,6 +436,7 @@ async def hf_inspect_dataset_handler(arguments: dict[str, Any]) -> tuple[str, bo
432
  config=arguments.get("config"),
433
  split=arguments.get("split"),
434
  sample_rows=min(arguments.get("sample_rows", 3), 10),
 
435
  )
436
  return result["formatted"], not result.get("isError", False)
437
  except Exception as e:
 
11
 
12
  import httpx
13
 
14
+ from agent.core.session import Session
15
  from agent.tools.types import ToolResult
16
 
17
  BASE_URL = "https://datasets-server.huggingface.co"
 
27
  splits: list[str]
28
 
29
 
30
+ def _get_headers(token: str | None = None) -> dict:
31
  """Get auth headers for private/gated datasets"""
32
+ token = token or os.environ.get("HF_TOKEN")
33
  if token:
34
  return {"Authorization": f"Bearer {token}"}
35
  return {}
 
40
  config: str | None = None,
41
  split: str | None = None,
42
  sample_rows: int = 3,
43
+ hf_token: str | None = None,
44
  ) -> ToolResult:
45
  """
46
  Get comprehensive dataset info in one call.
47
  All API calls made in parallel for speed.
48
  """
49
+ headers = _get_headers(hf_token)
50
  output_parts = []
51
  errors = []
52
 
 
426
  }
427
 
428
 
429
+ async def hf_inspect_dataset_handler(
430
+ arguments: dict[str, Any], session: Session = None
431
+ ) -> tuple[str, bool]:
432
  """Handler for agent tool router"""
433
  try:
434
  result = await inspect_dataset(
 
436
  config=arguments.get("config"),
437
  split=arguments.get("split"),
438
  sample_rows=min(arguments.get("sample_rows", 3), 10),
439
+ hf_token=session.hf_token,
440
  )
441
  return result["formatted"], not result.get("isError", False)
442
  except Exception as e: