akseljoonas (HF Staff) committed
Commit e96ab7e · 2 parents: e2b91f0 + a8d5460

Merge branch 'dataset_tools'

agent/core/tools.py CHANGED
@@ -13,6 +13,10 @@ from lmnr import observe
 from mcp.types import EmbeddedResource, ImageContent, TextContent
 
 from agent.config import MCPServerConfig
+from agent.tools.dataset_tools import (
+    HF_INSPECT_DATASET_TOOL_SPEC,
+    hf_inspect_dataset_handler,
+)
 from agent.tools.docs_tools import (
     EXPLORE_HF_DOCS_TOOL_SPEC,
     HF_DOCS_FETCH_TOOL_SPEC,
@@ -257,6 +261,13 @@ def create_builtin_tools() -> list[ToolSpec]:
             parameters=HF_DOCS_FETCH_TOOL_SPEC["parameters"],
             handler=hf_docs_fetch_handler,
         ),
+        # Dataset inspection tool (unified)
+        ToolSpec(
+            name=HF_INSPECT_DATASET_TOOL_SPEC["name"],
+            description=HF_INSPECT_DATASET_TOOL_SPEC["description"],
+            parameters=HF_INSPECT_DATASET_TOOL_SPEC["parameters"],
+            handler=hf_inspect_dataset_handler,
+        ),
         # Planning and job management tools
         ToolSpec(
             name=PLAN_TOOL_SPEC["name"],
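
For reviewers, a minimal smoke test (hypothetical, not part of this commit; it assumes `ToolSpec` exposes the `name` field it is constructed with above) to confirm the new tool is registered:

```python
# Hypothetical check: the new tool should now appear among the built-ins.
from agent.core.tools import create_builtin_tools

tools = create_builtin_tools()
assert any(spec.name == "hf_inspect_dataset" for spec in tools)
```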
agent/tools/__init__.py CHANGED
@@ -18,6 +18,10 @@ from agent.tools.github_search_code import (
     GITHUB_SEARCH_CODE_TOOL_SPEC,
     github_search_code_handler,
 )
+from agent.tools.dataset_tools import (
+    HF_INSPECT_DATASET_TOOL_SPEC,
+    hf_inspect_dataset_handler,
+)
 from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, HfJobsTool, hf_jobs_handler
 from agent.tools.types import ToolResult
 
@@ -34,4 +38,6 @@ __all__ = [
     "github_read_file_handler",
     "GITHUB_SEARCH_CODE_TOOL_SPEC",
     "github_search_code_handler",
+    "HF_INSPECT_DATASET_TOOL_SPEC",
+    "hf_inspect_dataset_handler",
 ]
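
With these re-exports in `__all__`, the spec and handler become importable from the package root; a sketch (not part of this commit):

```python
# Sketch: the names added to __all__ above are now available
# directly from the agent.tools package.
from agent.tools import HF_INSPECT_DATASET_TOOL_SPEC, hf_inspect_dataset_handler

print(HF_INSPECT_DATASET_TOOL_SPEC["name"])  # "hf_inspect_dataset"
```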
agent/tools/dataset_tools.py ADDED
@@ -0,0 +1,405 @@
+"""
+Dataset Inspection Tool - Comprehensive dataset analysis in one call
+
+Combines /is-valid, /splits, /info, /first-rows, and /parquet endpoints
+to provide everything needed for ML tasks in a single tool call.
+"""
+
+import asyncio
+import os
+from typing import Any
+
+import httpx
+
+from agent.tools.types import ToolResult
+
+BASE_URL = "https://datasets-server.huggingface.co"
+
+
+def _get_headers() -> dict:
+    """Get auth headers for private/gated datasets"""
+    token = os.environ.get("HF_TOKEN")
+    if token:
+        return {"Authorization": f"Bearer {token}"}
+    return {}
+
+
+async def inspect_dataset(
+    dataset: str,
+    config: str | None = None,
+    split: str | None = None,
+    sample_rows: int = 3,
+) -> ToolResult:
+    """
+    Get comprehensive dataset info in one call.
+    All API calls made in parallel for speed.
+    """
+    headers = _get_headers()
+    output_parts = []
+    errors = []
+
+    async with httpx.AsyncClient(timeout=15, headers=headers) as client:
+        # Phase 1: Parallel calls for structure info (no dependencies)
+        is_valid_task = client.get(f"{BASE_URL}/is-valid", params={"dataset": dataset})
+        splits_task = client.get(f"{BASE_URL}/splits", params={"dataset": dataset})
+        parquet_task = client.get(f"{BASE_URL}/parquet", params={"dataset": dataset})
+
+        results = await asyncio.gather(
+            is_valid_task,
+            splits_task,
+            parquet_task,
+            return_exceptions=True,
+        )
+
+        # Process is-valid
+        if not isinstance(results[0], Exception):
+            try:
+                output_parts.append(_format_status(results[0].json()))
+            except Exception as e:
+                errors.append(f"is-valid: {e}")
+
+        # Process splits and auto-detect config/split
+        configs = []
+        if not isinstance(results[1], Exception):
+            try:
+                splits_data = results[1].json()
+                configs = _extract_configs(splits_data)
+                if not config:
+                    config = configs[0]["name"] if configs else "default"
+                if not split:
+                    split = configs[0]["splits"][0] if configs else "train"
+                output_parts.append(_format_structure(configs))
+            except Exception as e:
+                errors.append(f"splits: {e}")
+
+        if not config:
+            config = "default"
+        if not split:
+            split = "train"
+
+        # Process parquet (will be added at the end)
+        parquet_section = None
+        if not isinstance(results[2], Exception):
+            try:
+                parquet_section = _format_parquet_files(results[2].json())
+            except Exception:
+                pass  # Silently skip if no parquet
+
+        # Phase 2: Parallel calls for content (depend on config/split)
+        info_task = client.get(
+            f"{BASE_URL}/info", params={"dataset": dataset, "config": config}
+        )
+        rows_task = client.get(
+            f"{BASE_URL}/first-rows",
+            params={"dataset": dataset, "config": config, "split": split},
+            timeout=30,
+        )
+
+        content_results = await asyncio.gather(
+            info_task,
+            rows_task,
+            return_exceptions=True,
+        )
+
+        # Process info (schema)
+        if not isinstance(content_results[0], Exception):
+            try:
+                output_parts.append(_format_schema(content_results[0].json(), config))
+            except Exception as e:
+                errors.append(f"info: {e}")
+
+        # Process sample rows
+        if not isinstance(content_results[1], Exception):
+            try:
+                output_parts.append(
+                    _format_samples(
+                        content_results[1].json(), config, split, sample_rows
+                    )
+                )
+            except Exception as e:
+                errors.append(f"rows: {e}")
+
+    # Add parquet section at the end if available
+    if parquet_section:
+        output_parts.append(parquet_section)
+
+    # Combine output
+    formatted = f"# {dataset}\n\n" + "\n\n".join(output_parts)
+    if errors:
+        formatted += f"\n\n**Warnings:** {'; '.join(errors)}"
+
+    return {
+        "formatted": formatted,
+        "totalResults": 1,
+        "resultsShared": 1,
+        "isError": len(output_parts) == 0,
+    }
+
+
+def _format_status(data: dict) -> str:
+    """Format /is-valid response as status line"""
+    available = [
+        k
+        for k in ["viewer", "preview", "search", "filter", "statistics"]
+        if data.get(k)
+    ]
+    if available:
+        return f"## Status\n✓ Valid ({', '.join(available)})"
+    return "## Status\n✗ Dataset may have issues"
+
+
+def _extract_configs(splits_data: dict) -> list[dict]:
+    """Group splits by config"""
+    configs: dict[str, dict] = {}
+    for s in splits_data.get("splits", []):
+        cfg = s.get("config", "default")
+        if cfg not in configs:
+            configs[cfg] = {"name": cfg, "splits": []}
+        configs[cfg]["splits"].append(s.get("split"))
+    return list(configs.values())
+
+
+def _format_structure(configs: list) -> str:
+    """Format splits as markdown table"""
+    lines = ["## Structure", "| Config | Split |", "|--------|-------|"]
+    for cfg in configs:
+        for split_name in cfg["splits"]:
+            lines.append(f"| {cfg['name']} | {split_name} |")
+    return "\n".join(lines)
+
+
+def _format_schema(info: dict, config: str) -> str:
+    """Extract features and format as table"""
+    features = info.get("dataset_info", {}).get("features", {})
+    lines = [f"## Schema ({config})", "| Column | Type |", "|--------|------|"]
+    for col_name, col_info in features.items():
+        col_type = _get_type_str(col_info)
+        lines.append(f"| {col_name} | {col_type} |")
+    return "\n".join(lines)
+
+
+def _get_type_str(col_info: dict) -> str:
+    """Convert feature info to readable type string"""
+    dtype = col_info.get("dtype") or col_info.get("_type", "unknown")
+    if col_info.get("_type") == "ClassLabel":
+        names = col_info.get("names", [])
+        if names and len(names) <= 5:
+            return f"ClassLabel ({', '.join(f'{n}={i}' for i, n in enumerate(names))})"
+        return f"ClassLabel ({len(names)} classes)"
+    return str(dtype)
+
+
+def _format_samples(rows_data: dict, config: str, split: str, limit: int) -> str:
+    """Format sample rows, truncate long values"""
+    rows = rows_data.get("rows", [])[:limit]
+    lines = [f"## Sample Rows ({config}/{split})"]
+
+    messages_col_data = None
+
+    for i, row_wrapper in enumerate(rows, 1):
+        row = row_wrapper.get("row", {})
+        lines.append(f"**Row {i}:**")
+        for key, val in row.items():
+            # Check for messages column and capture first one for format analysis
+            if key.lower() == "messages" and messages_col_data is None:
+                messages_col_data = val
+
+            val_str = str(val)
+            if len(val_str) > 150:
+                val_str = val_str[:150] + "..."
+            lines.append(f"- {key}: {val_str}")
+
+    # If we found a messages column, add format analysis
+    if messages_col_data is not None:
+        messages_format = _format_messages_structure(messages_col_data)
+        if messages_format:
+            lines.append("")
+            lines.append(messages_format)
+
+    return "\n".join(lines)
+
+
+def _format_messages_structure(messages_data: Any) -> str | None:
+    """
+    Analyze and format the structure of a messages column.
+    Common in chat/instruction datasets.
+    """
+    import json
+
+    # Parse if string
+    if isinstance(messages_data, str):
+        try:
+            messages_data = json.loads(messages_data)
+        except json.JSONDecodeError:
+            return None
+
+    if not isinstance(messages_data, list) or not messages_data:
+        return None
+
+    lines = ["## Messages Column Format"]
+
+    # Analyze message structure
+    roles_seen = set()
+    has_tool_calls = False
+    has_tool_results = False
+    message_keys = set()
+
+    for msg in messages_data:
+        if not isinstance(msg, dict):
+            continue
+
+        message_keys.update(msg.keys())
+
+        role = msg.get("role", "")
+        if role:
+            roles_seen.add(role)
+
+        if "tool_calls" in msg or "function_call" in msg:
+            has_tool_calls = True
+        if role in ("tool", "function") or msg.get("tool_call_id"):
+            has_tool_results = True
+
+    # Format the analysis
+    lines.append(
+        f"**Roles:** {', '.join(sorted(roles_seen)) if roles_seen else 'unknown'}"
+    )
+
+    # Show common message keys with presence indicators
+    common_keys = [
+        "role",
+        "content",
+        "tool_calls",
+        "tool_call_id",
+        "name",
+        "function_call",
+    ]
+    key_status = []
+    for key in common_keys:
+        if key in message_keys:
+            key_status.append(f"{key} ✓")
+        else:
+            key_status.append(f"{key} ✗")
+    lines.append(f"**Message keys:** {', '.join(key_status)}")
+
+    if has_tool_calls:
+        lines.append("**Tool calls:** ✓ Present")
+    if has_tool_results:
+        lines.append("**Tool results:** ✓ Present")
+
+    # Show example message structure
+    # Priority: 1) message with tool_calls, 2) first assistant message, 3) first non-system message
+    example = None
+    fallback = None
+    for msg in messages_data:
+        if not isinstance(msg, dict):
+            continue
+        role = msg.get("role", "")
+        # Check for actual tool_calls/function_call values (not None)
+        if msg.get("tool_calls") or msg.get("function_call"):
+            example = msg
+            break
+        if role == "assistant" and example is None:
+            example = msg
+        elif role != "system" and fallback is None:
+            fallback = msg
+    if example is None:
+        example = fallback
+
+    if example:
+        lines.append("")
+        lines.append("**Example message structure:**")
+        # Build a copy with truncated content but keep all keys
+        example_clean = {}
+        for key, val in example.items():
+            if key == "content" and isinstance(val, str) and len(val) > 100:
+                example_clean[key] = val[:100] + "..."
+            else:
+                example_clean[key] = val
+        lines.append("```json")
+        lines.append(json.dumps(example_clean, indent=2, ensure_ascii=False))
+        lines.append("```")
+
+    return "\n".join(lines)
+
+
+def _format_parquet_files(data: dict) -> str | None:
+    """Format parquet file info, return None if no files"""
+    files = data.get("parquet_files", [])
+    if not files:
+        return None
+
+    # Group by config/split
+    groups: dict[str, dict] = {}
+    for f in files:
+        key = f"{f.get('config', 'default')}/{f.get('split', 'train')}"
+        if key not in groups:
+            groups[key] = {"count": 0, "size": 0}
+        groups[key]["count"] += 1
+        groups[key]["size"] += f.get("size", 0)
+
+    lines = ["## Files (Parquet)"]
+    for key, info in groups.items():
+        size_mb = info["size"] / (1024 * 1024)
+        lines.append(f"- {key}: {info['count']} file(s) ({size_mb:.1f} MB)")
+    return "\n".join(lines)
+
+
+# Tool specification
+HF_INSPECT_DATASET_TOOL_SPEC = {
+    "name": "hf_inspect_dataset",
+    "description": (
+        "Inspect a Hugging Face dataset comprehensively in one call.\n\n"
+        "## What you get\n"
+        "- Status check (validates dataset works without errors)\n"
+        "- All configs and splits\n"
+        "- Column names and types (schema)\n"
+        "- Sample rows to understand data format\n"
+        "- Parquet file structure and sizes\n\n"
+        "## CRITICAL\n"
+        "**Always inspect datasets before writing training code** to understand:\n"
+        "- Column names for your dataloader\n"
+        "- Data types and format\n"
+        "- Available splits (train/test/validation)\n\n"
+        "Supports private/gated datasets when HF_TOKEN is set.\n\n"
+        "## Examples\n"
+        '{"dataset": "stanfordnlp/imdb"}\n'
+        '{"dataset": "nyu-mll/glue", "config": "mrpc", "sample_rows": 5}\n'
+    ),
+    "parameters": {
+        "type": "object",
+        "properties": {
+            "dataset": {
+                "type": "string",
+                "description": "Dataset ID in 'org/name' format (e.g., 'stanfordnlp/imdb')",
+            },
+            "config": {
+                "type": "string",
+                "description": "Config/subset name. Auto-detected if not specified.",
+            },
+            "split": {
+                "type": "string",
+                "description": "Split for sample rows. Auto-detected if not specified.",
+            },
+            "sample_rows": {
+                "type": "integer",
+                "description": "Number of sample rows to show (default: 3, max: 10)",
+                "default": 3,
+            },
+        },
+        "required": ["dataset"],
+    },
+}
+
+
+async def hf_inspect_dataset_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
+    """Handler for agent tool router"""
+    try:
+        result = await inspect_dataset(
+            dataset=arguments["dataset"],
+            config=arguments.get("config"),
+            split=arguments.get("split"),
+            sample_rows=min(arguments.get("sample_rows", 3), 10),
+        )
+        return result["formatted"], not result.get("isError", False)
+    except Exception as e:
+        return f"Error inspecting dataset: {str(e)}", False