Aksel Joonas Reedi committed on
Commit
00ea4db
·
2 Parent(s): 64bb402197e841

Merge pull request #20 from huggingface/dataset_tool_improved

Browse files
Files changed (1) hide show
  1. agent/tools/dataset_tools.py +53 -13
agent/tools/dataset_tools.py CHANGED
@@ -7,7 +7,7 @@ to provide everything needed for ML tasks in a single tool call.
7
 
8
  import asyncio
9
  import os
10
- from typing import Any
11
 
12
  import httpx
13
 
@@ -15,6 +15,16 @@ from agent.tools.types import ToolResult
15
 
16
  BASE_URL = "https://datasets-server.huggingface.co"
17
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  def _get_headers() -> dict:
20
  """Get auth headers for private/gated datasets"""
@@ -148,9 +158,9 @@ def _format_status(data: dict) -> str:
148
  return "## Status\n✗ Dataset may have issues"
149
 
150
 
151
- def _extract_configs(splits_data: dict) -> list[dict]:
152
  """Group splits by config"""
153
- configs: dict[str, dict] = {}
154
  for s in splits_data.get("splits", []):
155
  cfg = s.get("config", "default")
156
  if cfg not in configs:
@@ -159,12 +169,29 @@ def _extract_configs(splits_data: dict) -> list[dict]:
159
  return list(configs.values())
160
 
161
 
162
- def _format_structure(configs: list) -> str:
163
- """Format splits as markdown table"""
164
- lines = ["## Structure", "| Config | Split |", "|--------|-------|"]
 
 
 
 
 
 
165
  for cfg in configs:
166
  for split_name in cfg["splits"]:
 
 
167
  lines.append(f"| {cfg['name']} | {split_name} |")
 
 
 
 
 
 
 
 
 
168
  return "\n".join(lines)
169
 
170
 
@@ -205,8 +232,8 @@ def _format_samples(rows_data: dict, config: str, split: str, limit: int) -> str
205
  messages_col_data = val
206
 
207
  val_str = str(val)
208
- if len(val_str) > 150:
209
- val_str = val_str[:150] + "..."
210
  lines.append(f"- {key}: {val_str}")
211
 
212
  # If we found a messages column, add format analysis
@@ -322,8 +349,8 @@ def _format_messages_structure(messages_data: Any) -> str | None:
322
  return "\n".join(lines)
323
 
324
 
325
- def _format_parquet_files(data: dict) -> str | None:
326
- """Format parquet file info, return None if no files"""
327
  files = data.get("parquet_files", [])
328
  if not files:
329
  return None
@@ -334,13 +361,26 @@ def _format_parquet_files(data: dict) -> str | None:
334
  key = f"{f.get('config', 'default')}/{f.get('split', 'train')}"
335
  if key not in groups:
336
  groups[key] = {"count": 0, "size": 0}
 
 
 
337
  groups[key]["count"] += 1
338
- groups[key]["size"] += f.get("size", 0)
339
 
340
  lines = ["## Files (Parquet)"]
341
- for key, info in groups.items():
 
 
 
 
342
  size_mb = info["size"] / (1024 * 1024)
343
  lines.append(f"- {key}: {info['count']} file(s) ({size_mb:.1f} MB)")
 
 
 
 
 
 
344
  return "\n".join(lines)
345
 
346
 
@@ -351,7 +391,7 @@ HF_INSPECT_DATASET_TOOL_SPEC = {
351
  "Inspect a Hugging Face dataset comprehensively in one call.\n\n"
352
  "## What you get\n"
353
  "- Status check (validates dataset works without errors)\n"
354
- "- All configs and splits\n"
355
  "- Column names and types (schema)\n"
356
  "- Sample rows to understand data format\n"
357
  "- Parquet file structure and sizes\n\n"
 
7
 
8
  import asyncio
9
  import os
10
+ from typing import Any, TypedDict
11
 
12
  import httpx
13
 
 
15
 
16
  BASE_URL = "https://datasets-server.huggingface.co"
17
 
18
# Maximum number of characters of a sample value to show before it is
# truncated with an ellipsis in the formatted output.
MAX_SAMPLE_VALUE_LEN = 150


class SplitConfig(TypedDict):
    """A dataset configuration name paired with the names of its splits."""

    name: str
    splits: list[str]
 
29
  def _get_headers() -> dict:
30
  """Get auth headers for private/gated datasets"""
 
158
  return "## Status\n✗ Dataset may have issues"
159
 
160
 
161
+ def _extract_configs(splits_data: dict) -> list[SplitConfig]:
162
  """Group splits by config"""
163
+ configs: dict[str, SplitConfig] = {}
164
  for s in splits_data.get("splits", []):
165
  cfg = s.get("config", "default")
166
  if cfg not in configs:
 
169
  return list(configs.values())
170
 
171
 
172
+ def _format_structure(
173
+ configs: list[SplitConfig], max_rows: int = 10
174
+ ) -> str:
175
+ """Format configs and splits as a markdown table."""
176
+ lines = ["## Structure (configs & splits)", "| Config | Split |", "|--------|-------|"]
177
+
178
+ total_splits = sum(len(cfg["splits"]) for cfg in configs)
179
+ added_rows = 0
180
+
181
  for cfg in configs:
182
  for split_name in cfg["splits"]:
183
+ if added_rows >= max_rows:
184
+ break
185
  lines.append(f"| {cfg['name']} | {split_name} |")
186
+ added_rows += 1
187
+ if added_rows >= max_rows:
188
+ break
189
+
190
+ if total_splits > added_rows:
191
+ lines.append(
192
+ f"| ... | ... | (_showing {added_rows} of {total_splits} config/split rows_) |"
193
+ )
194
+
195
  return "\n".join(lines)
196
 
197
 
 
232
  messages_col_data = val
233
 
234
  val_str = str(val)
235
+ if len(val_str) > MAX_SAMPLE_VALUE_LEN:
236
+ val_str = val_str[:MAX_SAMPLE_VALUE_LEN] + "..."
237
  lines.append(f"- {key}: {val_str}")
238
 
239
  # If we found a messages column, add format analysis
 
349
  return "\n".join(lines)
350
 
351
 
352
+ def _format_parquet_files(data: dict, max_rows: int = 10) -> str | None:
353
+ """Format parquet file info, return None if no files."""
354
  files = data.get("parquet_files", [])
355
  if not files:
356
  return None
 
361
  key = f"{f.get('config', 'default')}/{f.get('split', 'train')}"
362
  if key not in groups:
363
  groups[key] = {"count": 0, "size": 0}
364
+ size = f.get("size") or 0
365
+ if not isinstance(size, (int, float)):
366
+ size = 0
367
  groups[key]["count"] += 1
368
+ groups[key]["size"] += int(size)
369
 
370
  lines = ["## Files (Parquet)"]
371
+ items = list(groups.items())
372
+ total_groups = len(items)
373
+
374
+ shown = 0
375
+ for key, info in items[:max_rows]:
376
  size_mb = info["size"] / (1024 * 1024)
377
  lines.append(f"- {key}: {info['count']} file(s) ({size_mb:.1f} MB)")
378
+ shown += 1
379
+
380
+ if total_groups > shown:
381
+ lines.append(
382
+ f"- ... (_showing {shown} of {total_groups} parquet groups_)"
383
+ )
384
  return "\n".join(lines)
385
 
386
 
 
391
  "Inspect a Hugging Face dataset comprehensively in one call.\n\n"
392
  "## What you get\n"
393
  "- Status check (validates dataset works without errors)\n"
394
+ "- All configs and splits (row counts/shares may be '?' when metadata is missing)\n"
395
  "- Column names and types (schema)\n"
396
  "- Sample rows to understand data format\n"
397
  "- Parquet file structure and sizes\n\n"