Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
Merge pull request #14 from huggingface/dataset_tools
Browse files- agent/core/tools.py +11 -0
- agent/tools/__init__.py +6 -0
- agent/tools/dataset_tools.py +405 -0
agent/core/tools.py
CHANGED
|
@@ -13,6 +13,10 @@ from lmnr import observe
|
|
| 13 |
from mcp.types import EmbeddedResource, ImageContent, TextContent
|
| 14 |
|
| 15 |
from agent.config import MCPServerConfig
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
from agent.tools.docs_tools import (
|
| 17 |
EXPLORE_HF_DOCS_TOOL_SPEC,
|
| 18 |
HF_DOCS_FETCH_TOOL_SPEC,
|
|
@@ -257,6 +261,13 @@ def create_builtin_tools() -> list[ToolSpec]:
|
|
| 257 |
parameters=HF_DOCS_FETCH_TOOL_SPEC["parameters"],
|
| 258 |
handler=hf_docs_fetch_handler,
|
| 259 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
# Planning and job management tools
|
| 261 |
ToolSpec(
|
| 262 |
name=PLAN_TOOL_SPEC["name"],
|
|
|
|
| 13 |
from mcp.types import EmbeddedResource, ImageContent, TextContent
|
| 14 |
|
| 15 |
from agent.config import MCPServerConfig
|
| 16 |
+
from agent.tools.dataset_tools import (
|
| 17 |
+
HF_INSPECT_DATASET_TOOL_SPEC,
|
| 18 |
+
hf_inspect_dataset_handler,
|
| 19 |
+
)
|
| 20 |
from agent.tools.docs_tools import (
|
| 21 |
EXPLORE_HF_DOCS_TOOL_SPEC,
|
| 22 |
HF_DOCS_FETCH_TOOL_SPEC,
|
|
|
|
| 261 |
parameters=HF_DOCS_FETCH_TOOL_SPEC["parameters"],
|
| 262 |
handler=hf_docs_fetch_handler,
|
| 263 |
),
|
| 264 |
+
# Dataset inspection tool (unified)
|
| 265 |
+
ToolSpec(
|
| 266 |
+
name=HF_INSPECT_DATASET_TOOL_SPEC["name"],
|
| 267 |
+
description=HF_INSPECT_DATASET_TOOL_SPEC["description"],
|
| 268 |
+
parameters=HF_INSPECT_DATASET_TOOL_SPEC["parameters"],
|
| 269 |
+
handler=hf_inspect_dataset_handler,
|
| 270 |
+
),
|
| 271 |
# Planning and job management tools
|
| 272 |
ToolSpec(
|
| 273 |
name=PLAN_TOOL_SPEC["name"],
|
agent/tools/__init__.py
CHANGED
|
@@ -18,6 +18,10 @@ from agent.tools.github_search_code import (
|
|
| 18 |
GITHUB_SEARCH_CODE_TOOL_SPEC,
|
| 19 |
github_search_code_handler,
|
| 20 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, HfJobsTool, hf_jobs_handler
|
| 22 |
from agent.tools.types import ToolResult
|
| 23 |
|
|
@@ -34,4 +38,6 @@ __all__ = [
|
|
| 34 |
"github_read_file_handler",
|
| 35 |
"GITHUB_SEARCH_CODE_TOOL_SPEC",
|
| 36 |
"github_search_code_handler",
|
|
|
|
|
|
|
| 37 |
]
|
|
|
|
| 18 |
GITHUB_SEARCH_CODE_TOOL_SPEC,
|
| 19 |
github_search_code_handler,
|
| 20 |
)
|
| 21 |
+
from agent.tools.dataset_tools import (
|
| 22 |
+
HF_INSPECT_DATASET_TOOL_SPEC,
|
| 23 |
+
hf_inspect_dataset_handler,
|
| 24 |
+
)
|
| 25 |
from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, HfJobsTool, hf_jobs_handler
|
| 26 |
from agent.tools.types import ToolResult
|
| 27 |
|
|
|
|
| 38 |
"github_read_file_handler",
|
| 39 |
"GITHUB_SEARCH_CODE_TOOL_SPEC",
|
| 40 |
"github_search_code_handler",
|
| 41 |
+
"HF_INSPECT_DATASET_TOOL_SPEC",
|
| 42 |
+
"hf_inspect_dataset_handler",
|
| 43 |
]
|
agent/tools/dataset_tools.py
ADDED
|
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Dataset Inspection Tool - Comprehensive dataset analysis in one call
|
| 3 |
+
|
| 4 |
+
Combines /is-valid, /splits, /info, /first-rows, and /parquet endpoints
|
| 5 |
+
to provide everything needed for ML tasks in a single tool call.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import asyncio
|
| 9 |
+
import os
|
| 10 |
+
from typing import Any
|
| 11 |
+
|
| 12 |
+
import httpx
|
| 13 |
+
|
| 14 |
+
from agent.tools.types import ToolResult
|
| 15 |
+
|
| 16 |
+
BASE_URL = "https://datasets-server.huggingface.co"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _get_headers() -> dict:
|
| 20 |
+
"""Get auth headers for private/gated datasets"""
|
| 21 |
+
token = os.environ.get("HF_TOKEN")
|
| 22 |
+
if token:
|
| 23 |
+
return {"Authorization": f"Bearer {token}"}
|
| 24 |
+
return {}
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
async def inspect_dataset(
    dataset: str,
    config: str | None = None,
    split: str | None = None,
    sample_rows: int = 3,
) -> ToolResult:
    """
    Get comprehensive dataset info in one call.

    Queries the datasets-server /is-valid, /splits, /parquet, /info and
    /first-rows endpoints. Calls within each phase run in parallel; phase 2
    depends on the config/split resolved in phase 1.

    Args:
        dataset: Dataset ID in "org/name" form.
        config: Config/subset name; auto-detected from /splits when omitted.
        split: Split used for sample rows; auto-detected when omitted.
        sample_rows: Number of sample rows to include in the report.

    Returns:
        ToolResult dict whose "formatted" key holds a markdown report.
        "isError" is True only when no section at all could be produced;
        individual endpoint failures are appended as warnings instead.
    """
    headers = _get_headers()
    output_parts: list[str] = []
    errors: list[str] = []

    async with httpx.AsyncClient(timeout=15, headers=headers) as client:
        # Phase 1: structure endpoints have no interdependencies -> parallel.
        results = await asyncio.gather(
            client.get(f"{BASE_URL}/is-valid", params={"dataset": dataset}),
            client.get(f"{BASE_URL}/splits", params={"dataset": dataset}),
            client.get(f"{BASE_URL}/parquet", params={"dataset": dataset}),
            return_exceptions=True,
        )

        # Process is-valid.
        if not isinstance(results[0], Exception):
            try:
                output_parts.append(_format_status(results[0].json()))
            except Exception as e:
                errors.append(f"is-valid: {e}")

        # Process splits and auto-detect config/split.
        configs = []
        if not isinstance(results[1], Exception):
            try:
                configs = _extract_configs(results[1].json())
                if not config:
                    config = configs[0]["name"] if configs else "default"
                if not split:
                    # Bug fix: pick the first split of the *selected* config.
                    # The previous code always read configs[0]["splits"][0],
                    # which returned a split from the wrong config whenever
                    # the caller passed an explicit config other than the
                    # first one listed.
                    selected = next(
                        (c for c in configs if c["name"] == config), None
                    )
                    if selected and selected["splits"]:
                        split = selected["splits"][0]
                    else:
                        split = "train"
                output_parts.append(_format_structure(configs))
            except Exception as e:
                errors.append(f"splits: {e}")

        # Fallbacks when /splits failed entirely.
        if not config:
            config = "default"
        if not split:
            split = "train"

        # Process parquet (section is appended at the end of the report).
        parquet_section = None
        if not isinstance(results[2], Exception):
            try:
                parquet_section = _format_parquet_files(results[2].json())
            except Exception:
                pass  # Silently skip if no parquet

        # Phase 2: content endpoints need the resolved config/split -> parallel.
        content_results = await asyncio.gather(
            client.get(
                f"{BASE_URL}/info", params={"dataset": dataset, "config": config}
            ),
            client.get(
                f"{BASE_URL}/first-rows",
                params={"dataset": dataset, "config": config, "split": split},
                timeout=30,  # first-rows can be slow on large datasets
            ),
            return_exceptions=True,
        )

        # Process info (schema).
        if not isinstance(content_results[0], Exception):
            try:
                output_parts.append(_format_schema(content_results[0].json(), config))
            except Exception as e:
                errors.append(f"info: {e}")

        # Process sample rows.
        if not isinstance(content_results[1], Exception):
            try:
                output_parts.append(
                    _format_samples(
                        content_results[1].json(), config, split, sample_rows
                    )
                )
            except Exception as e:
                errors.append(f"rows: {e}")

        # Add parquet section at the end if available.
        if parquet_section:
            output_parts.append(parquet_section)

    # Combine output; surface per-endpoint failures as warnings.
    formatted = f"# {dataset}\n\n" + "\n\n".join(output_parts)
    if errors:
        formatted += f"\n\n**Warnings:** {'; '.join(errors)}"

    return {
        "formatted": formatted,
        "totalResults": 1,
        "resultsShared": 1,
        "isError": len(output_parts) == 0,
    }
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def _format_status(data: dict) -> str:
|
| 140 |
+
"""Format /is-valid response as status line"""
|
| 141 |
+
available = [
|
| 142 |
+
k
|
| 143 |
+
for k in ["viewer", "preview", "search", "filter", "statistics"]
|
| 144 |
+
if data.get(k)
|
| 145 |
+
]
|
| 146 |
+
if available:
|
| 147 |
+
return f"## Status\n✓ Valid ({', '.join(available)})"
|
| 148 |
+
return "## Status\n✗ Dataset may have issues"
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def _extract_configs(splits_data: dict) -> list[dict]:
|
| 152 |
+
"""Group splits by config"""
|
| 153 |
+
configs: dict[str, dict] = {}
|
| 154 |
+
for s in splits_data.get("splits", []):
|
| 155 |
+
cfg = s.get("config", "default")
|
| 156 |
+
if cfg not in configs:
|
| 157 |
+
configs[cfg] = {"name": cfg, "splits": []}
|
| 158 |
+
configs[cfg]["splits"].append(s.get("split"))
|
| 159 |
+
return list(configs.values())
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def _format_structure(configs: list) -> str:
|
| 163 |
+
"""Format splits as markdown table"""
|
| 164 |
+
lines = ["## Structure", "| Config | Split |", "|--------|-------|"]
|
| 165 |
+
for cfg in configs:
|
| 166 |
+
for split_name in cfg["splits"]:
|
| 167 |
+
lines.append(f"| {cfg['name']} | {split_name} |")
|
| 168 |
+
return "\n".join(lines)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def _format_schema(info: dict, config: str) -> str:
|
| 172 |
+
"""Extract features and format as table"""
|
| 173 |
+
features = info.get("dataset_info", {}).get("features", {})
|
| 174 |
+
lines = [f"## Schema ({config})", "| Column | Type |", "|--------|------|"]
|
| 175 |
+
for col_name, col_info in features.items():
|
| 176 |
+
col_type = _get_type_str(col_info)
|
| 177 |
+
lines.append(f"| {col_name} | {col_type} |")
|
| 178 |
+
return "\n".join(lines)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def _get_type_str(col_info: dict) -> str:
|
| 182 |
+
"""Convert feature info to readable type string"""
|
| 183 |
+
dtype = col_info.get("dtype") or col_info.get("_type", "unknown")
|
| 184 |
+
if col_info.get("_type") == "ClassLabel":
|
| 185 |
+
names = col_info.get("names", [])
|
| 186 |
+
if names and len(names) <= 5:
|
| 187 |
+
return f"ClassLabel ({', '.join(f'{n}={i}' for i, n in enumerate(names))})"
|
| 188 |
+
return f"ClassLabel ({len(names)} classes)"
|
| 189 |
+
return str(dtype)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def _format_samples(rows_data: dict, config: str, split: str, limit: int) -> str:
|
| 193 |
+
"""Format sample rows, truncate long values"""
|
| 194 |
+
rows = rows_data.get("rows", [])[:limit]
|
| 195 |
+
lines = [f"## Sample Rows ({config}/{split})"]
|
| 196 |
+
|
| 197 |
+
messages_col_data = None
|
| 198 |
+
|
| 199 |
+
for i, row_wrapper in enumerate(rows, 1):
|
| 200 |
+
row = row_wrapper.get("row", {})
|
| 201 |
+
lines.append(f"**Row {i}:**")
|
| 202 |
+
for key, val in row.items():
|
| 203 |
+
# Check for messages column and capture first one for format analysis
|
| 204 |
+
if key.lower() == "messages" and messages_col_data is None:
|
| 205 |
+
messages_col_data = val
|
| 206 |
+
|
| 207 |
+
val_str = str(val)
|
| 208 |
+
if len(val_str) > 150:
|
| 209 |
+
val_str = val_str[:150] + "..."
|
| 210 |
+
lines.append(f"- {key}: {val_str}")
|
| 211 |
+
|
| 212 |
+
# If we found a messages column, add format analysis
|
| 213 |
+
if messages_col_data is not None:
|
| 214 |
+
messages_format = _format_messages_structure(messages_col_data)
|
| 215 |
+
if messages_format:
|
| 216 |
+
lines.append("")
|
| 217 |
+
lines.append(messages_format)
|
| 218 |
+
|
| 219 |
+
return "\n".join(lines)
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def _format_messages_structure(messages_data: Any) -> str | None:
    """
    Analyze and format the structure of a messages column.
    Common in chat/instruction datasets.

    Returns a markdown section listing the roles seen, which common
    OpenAI-style message keys are present, whether tool calls/results
    appear, and one representative example message — or None when the
    input cannot be interpreted as a non-empty list of messages.
    """
    import json

    # Parse if string (some datasets store messages as serialized JSON).
    if isinstance(messages_data, str):
        try:
            messages_data = json.loads(messages_data)
        except json.JSONDecodeError:
            return None

    # Only a non-empty list of messages can be analyzed.
    if not isinstance(messages_data, list) or not messages_data:
        return None

    lines = ["## Messages Column Format"]

    # Analyze message structure: collect roles, key names, and tool usage.
    roles_seen = set()
    has_tool_calls = False
    has_tool_results = False
    message_keys = set()

    for msg in messages_data:
        if not isinstance(msg, dict):
            continue

        message_keys.update(msg.keys())

        role = msg.get("role", "")
        if role:
            roles_seen.add(role)

        # Key *presence* (even with a None value) counts as tool-call usage here.
        if "tool_calls" in msg or "function_call" in msg:
            has_tool_calls = True
        if role in ("tool", "function") or msg.get("tool_call_id"):
            has_tool_results = True

    # Format the analysis
    lines.append(
        f"**Roles:** {', '.join(sorted(roles_seen)) if roles_seen else 'unknown'}"
    )

    # Show common message keys with presence indicators
    common_keys = [
        "role",
        "content",
        "tool_calls",
        "tool_call_id",
        "name",
        "function_call",
    ]
    key_status = []
    for key in common_keys:
        if key in message_keys:
            key_status.append(f"{key} ✓")
        else:
            key_status.append(f"{key} ✗")
    lines.append(f"**Message keys:** {', '.join(key_status)}")

    if has_tool_calls:
        lines.append("**Tool calls:** ✓ Present")
    if has_tool_results:
        lines.append("**Tool results:** ✓ Present")

    # Show example message structure
    # Priority: 1) message with tool_calls, 2) first assistant message, 3) first non-system message
    example = None
    fallback = None
    for msg in messages_data:
        if not isinstance(msg, dict):
            continue
        role = msg.get("role", "")
        # Check for actual tool_calls/function_call values (not None) —
        # stricter than the presence check above; breaks immediately so a
        # tool-calling message wins even over an earlier assistant message.
        if msg.get("tool_calls") or msg.get("function_call"):
            example = msg
            break
        if role == "assistant" and example is None:
            example = msg
        elif role != "system" and fallback is None:
            fallback = msg
    if example is None:
        example = fallback

    if example:
        lines.append("")
        lines.append("**Example message structure:**")
        # Build a copy with truncated content but keep all keys
        example_clean = {}
        for key, val in example.items():
            if key == "content" and isinstance(val, str) and len(val) > 100:
                example_clean[key] = val[:100] + "..."
            else:
                example_clean[key] = val
        lines.append("```json")
        lines.append(json.dumps(example_clean, indent=2, ensure_ascii=False))
        lines.append("```")

    return "\n".join(lines)
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def _format_parquet_files(data: dict) -> str | None:
|
| 326 |
+
"""Format parquet file info, return None if no files"""
|
| 327 |
+
files = data.get("parquet_files", [])
|
| 328 |
+
if not files:
|
| 329 |
+
return None
|
| 330 |
+
|
| 331 |
+
# Group by config/split
|
| 332 |
+
groups: dict[str, dict] = {}
|
| 333 |
+
for f in files:
|
| 334 |
+
key = f"{f.get('config', 'default')}/{f.get('split', 'train')}"
|
| 335 |
+
if key not in groups:
|
| 336 |
+
groups[key] = {"count": 0, "size": 0}
|
| 337 |
+
groups[key]["count"] += 1
|
| 338 |
+
groups[key]["size"] += f.get("size", 0)
|
| 339 |
+
|
| 340 |
+
lines = ["## Files (Parquet)"]
|
| 341 |
+
for key, info in groups.items():
|
| 342 |
+
size_mb = info["size"] / (1024 * 1024)
|
| 343 |
+
lines.append(f"- {key}: {info['count']} file(s) ({size_mb:.1f} MB)")
|
| 344 |
+
return "\n".join(lines)
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
# Tool specification exposed to the agent tool router. The description and
# parameter schema are sent verbatim to the model; sample_rows is capped at
# 10 by hf_inspect_dataset_handler, matching the "max: 10" note below.
HF_INSPECT_DATASET_TOOL_SPEC = {
    "name": "hf_inspect_dataset",
    "description": (
        "Inspect a Hugging Face dataset comprehensively in one call.\n\n"
        "## What you get\n"
        "- Status check (validates dataset works without errors)\n"
        "- All configs and splits\n"
        "- Column names and types (schema)\n"
        "- Sample rows to understand data format\n"
        "- Parquet file structure and sizes\n\n"
        "## CRITICAL\n"
        "**Always inspect datasets before writing training code** to understand:\n"
        "- Column names for your dataloader\n"
        "- Data types and format\n"
        "- Available splits (train/test/validation)\n\n"
        "Supports private/gated datasets when HF_TOKEN is set.\n\n"
        "## Examples\n"
        '{"dataset": "stanfordnlp/imdb"}\n'
        '{"dataset": "nyu-mll/glue", "config": "mrpc", "sample_rows": 5}\n'
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "dataset": {
                "type": "string",
                "description": "Dataset ID in 'org/name' format (e.g., 'stanfordnlp/imdb')",
            },
            "config": {
                "type": "string",
                "description": "Config/subset name. Auto-detected if not specified.",
            },
            "split": {
                "type": "string",
                "description": "Split for sample rows. Auto-detected if not specified.",
            },
            "sample_rows": {
                "type": "integer",
                "description": "Number of sample rows to show (default: 3, max: 10)",
                "default": 3,
            },
        },
        "required": ["dataset"],
    },
}
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
async def hf_inspect_dataset_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
    """Adapter between the agent tool router and inspect_dataset.

    Returns (formatted_markdown, success_flag); any failure is reported as
    an error string with success=False rather than raised to the router.
    """
    try:
        report = await inspect_dataset(
            dataset=arguments["dataset"],
            config=arguments.get("config"),
            split=arguments.get("split"),
            # Cap at 10 rows, per the tool-spec contract.
            sample_rows=min(arguments.get("sample_rows", 3), 10),
        )
        success = not report.get("isError", False)
        return report["formatted"], success
    except Exception as e:
        return f"Error inspecting dataset: {str(e)}", False
|