Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
Commit · 061ebdc
1
Parent(s): 838cd50
fix: pass user OAuth token to dataset inspection tool
Browse files
- agent/core/tools.py +7 -1
- agent/tools/dataset_tools.py +9 -4
agent/core/tools.py
CHANGED
|
@@ -62,7 +62,13 @@ warnings.filterwarnings(
|
|
| 62 |
"ignore", category=DeprecationWarning, module="aiohttp.connector"
|
| 63 |
)
|
| 64 |
|
| 65 |
-
NOT_ALLOWED_TOOL_NAMES = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
|
| 68 |
def convert_mcp_content_to_string(content: list) -> str:
|
|
|
|
| 62 |
"ignore", category=DeprecationWarning, module="aiohttp.connector"
|
| 63 |
)
|
| 64 |
|
| 65 |
+
NOT_ALLOWED_TOOL_NAMES = [
|
| 66 |
+
"hf_jobs",
|
| 67 |
+
"hf_doc_search",
|
| 68 |
+
"hf_doc_fetch",
|
| 69 |
+
"hf_whoami",
|
| 70 |
+
"paper_search",
|
| 71 |
+
]
|
| 72 |
|
| 73 |
|
| 74 |
def convert_mcp_content_to_string(content: list) -> str:
|
agent/tools/dataset_tools.py
CHANGED
|
@@ -11,6 +11,7 @@ from typing import Any, TypedDict
|
|
| 11 |
|
| 12 |
import httpx
|
| 13 |
|
|
|
|
| 14 |
from agent.tools.types import ToolResult
|
| 15 |
|
| 16 |
BASE_URL = "https://datasets-server.huggingface.co"
|
|
@@ -26,9 +27,9 @@ class SplitConfig(TypedDict):
|
|
| 26 |
splits: list[str]
|
| 27 |
|
| 28 |
|
| 29 |
-
def _get_headers() -> dict:
|
| 30 |
"""Get auth headers for private/gated datasets"""
|
| 31 |
-
token = os.environ.get("HF_TOKEN")
|
| 32 |
if token:
|
| 33 |
return {"Authorization": f"Bearer {token}"}
|
| 34 |
return {}
|
|
@@ -39,12 +40,13 @@ async def inspect_dataset(
|
|
| 39 |
config: str | None = None,
|
| 40 |
split: str | None = None,
|
| 41 |
sample_rows: int = 3,
|
|
|
|
| 42 |
) -> ToolResult:
|
| 43 |
"""
|
| 44 |
Get comprehensive dataset info in one call.
|
| 45 |
All API calls made in parallel for speed.
|
| 46 |
"""
|
| 47 |
-
headers = _get_headers()
|
| 48 |
output_parts = []
|
| 49 |
errors = []
|
| 50 |
|
|
@@ -424,7 +426,9 @@ HF_INSPECT_DATASET_TOOL_SPEC = {
|
|
| 424 |
}
|
| 425 |
|
| 426 |
|
| 427 |
-
async def hf_inspect_dataset_handler(
|
|
|
|
|
|
|
| 428 |
"""Handler for agent tool router"""
|
| 429 |
try:
|
| 430 |
result = await inspect_dataset(
|
|
@@ -432,6 +436,7 @@ async def hf_inspect_dataset_handler(arguments: dict[str, Any]) -> tuple[str, bo
|
|
| 432 |
config=arguments.get("config"),
|
| 433 |
split=arguments.get("split"),
|
| 434 |
sample_rows=min(arguments.get("sample_rows", 3), 10),
|
|
|
|
| 435 |
)
|
| 436 |
return result["formatted"], not result.get("isError", False)
|
| 437 |
except Exception as e:
|
|
|
|
| 11 |
|
| 12 |
import httpx
|
| 13 |
|
| 14 |
+
from agent.core.session import Session
|
| 15 |
from agent.tools.types import ToolResult
|
| 16 |
|
| 17 |
BASE_URL = "https://datasets-server.huggingface.co"
|
|
|
|
| 27 |
splits: list[str]
|
| 28 |
|
| 29 |
|
| 30 |
+
def _get_headers(token: str | None = None) -> dict:
|
| 31 |
"""Get auth headers for private/gated datasets"""
|
| 32 |
+
token = token or os.environ.get("HF_TOKEN")
|
| 33 |
if token:
|
| 34 |
return {"Authorization": f"Bearer {token}"}
|
| 35 |
return {}
|
|
|
|
| 40 |
config: str | None = None,
|
| 41 |
split: str | None = None,
|
| 42 |
sample_rows: int = 3,
|
| 43 |
+
hf_token: str | None = None,
|
| 44 |
) -> ToolResult:
|
| 45 |
"""
|
| 46 |
Get comprehensive dataset info in one call.
|
| 47 |
All API calls made in parallel for speed.
|
| 48 |
"""
|
| 49 |
+
headers = _get_headers(hf_token)
|
| 50 |
output_parts = []
|
| 51 |
errors = []
|
| 52 |
|
|
|
|
| 426 |
}
|
| 427 |
|
| 428 |
|
| 429 |
+
async def hf_inspect_dataset_handler(
|
| 430 |
+
arguments: dict[str, Any], session: Session = None
|
| 431 |
+
) -> tuple[str, bool]:
|
| 432 |
"""Handler for agent tool router"""
|
| 433 |
try:
|
| 434 |
result = await inspect_dataset(
|
|
|
|
| 436 |
config=arguments.get("config"),
|
| 437 |
split=arguments.get("split"),
|
| 438 |
sample_rows=min(arguments.get("sample_rows", 3), 10),
|
| 439 |
+
hf_token=session.hf_token,
|
| 440 |
)
|
| 441 |
return result["formatted"], not result.get("isError", False)
|
| 442 |
except Exception as e:
|