Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
Merge pull request #20 from huggingface/dataset_tool_improved
Browse files- agent/tools/dataset_tools.py +53 -13
agent/tools/dataset_tools.py
CHANGED
|
@@ -7,7 +7,7 @@ to provide everything needed for ML tasks in a single tool call.
|
|
| 7 |
|
| 8 |
import asyncio
|
| 9 |
import os
|
| 10 |
-
from typing import Any
|
| 11 |
|
| 12 |
import httpx
|
| 13 |
|
|
@@ -15,6 +15,16 @@ from agent.tools.types import ToolResult
|
|
| 15 |
|
| 16 |
BASE_URL = "https://datasets-server.huggingface.co"
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
def _get_headers() -> dict:
|
| 20 |
"""Get auth headers for private/gated datasets"""
|
|
@@ -148,9 +158,9 @@ def _format_status(data: dict) -> str:
|
|
| 148 |
return "## Status\n✗ Dataset may have issues"
|
| 149 |
|
| 150 |
|
| 151 |
-
def _extract_configs(splits_data: dict) -> list[
|
| 152 |
"""Group splits by config"""
|
| 153 |
-
configs: dict[str,
|
| 154 |
for s in splits_data.get("splits", []):
|
| 155 |
cfg = s.get("config", "default")
|
| 156 |
if cfg not in configs:
|
|
@@ -159,12 +169,29 @@ def _extract_configs(splits_data: dict) -> list[dict]:
|
|
| 159 |
return list(configs.values())
|
| 160 |
|
| 161 |
|
| 162 |
-
def _format_structure(
|
| 163 |
-
|
| 164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
for cfg in configs:
|
| 166 |
for split_name in cfg["splits"]:
|
|
|
|
|
|
|
| 167 |
lines.append(f"| {cfg['name']} | {split_name} |")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
return "\n".join(lines)
|
| 169 |
|
| 170 |
|
|
@@ -205,8 +232,8 @@ def _format_samples(rows_data: dict, config: str, split: str, limit: int) -> str
|
|
| 205 |
messages_col_data = val
|
| 206 |
|
| 207 |
val_str = str(val)
|
| 208 |
-
if len(val_str) >
|
| 209 |
-
val_str = val_str[:
|
| 210 |
lines.append(f"- {key}: {val_str}")
|
| 211 |
|
| 212 |
# If we found a messages column, add format analysis
|
|
@@ -322,8 +349,8 @@ def _format_messages_structure(messages_data: Any) -> str | None:
|
|
| 322 |
return "\n".join(lines)
|
| 323 |
|
| 324 |
|
| 325 |
-
def _format_parquet_files(data: dict) -> str | None:
|
| 326 |
-
"""Format parquet file info, return None if no files"""
|
| 327 |
files = data.get("parquet_files", [])
|
| 328 |
if not files:
|
| 329 |
return None
|
|
@@ -334,13 +361,26 @@ def _format_parquet_files(data: dict) -> str | None:
|
|
| 334 |
key = f"{f.get('config', 'default')}/{f.get('split', 'train')}"
|
| 335 |
if key not in groups:
|
| 336 |
groups[key] = {"count": 0, "size": 0}
|
|
|
|
|
|
|
|
|
|
| 337 |
groups[key]["count"] += 1
|
| 338 |
-
groups[key]["size"] +=
|
| 339 |
|
| 340 |
lines = ["## Files (Parquet)"]
|
| 341 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 342 |
size_mb = info["size"] / (1024 * 1024)
|
| 343 |
lines.append(f"- {key}: {info['count']} file(s) ({size_mb:.1f} MB)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 344 |
return "\n".join(lines)
|
| 345 |
|
| 346 |
|
|
@@ -351,7 +391,7 @@ HF_INSPECT_DATASET_TOOL_SPEC = {
|
|
| 351 |
"Inspect a Hugging Face dataset comprehensively in one call.\n\n"
|
| 352 |
"## What you get\n"
|
| 353 |
"- Status check (validates dataset works without errors)\n"
|
| 354 |
-
"- All configs and splits\n"
|
| 355 |
"- Column names and types (schema)\n"
|
| 356 |
"- Sample rows to understand data format\n"
|
| 357 |
"- Parquet file structure and sizes\n\n"
|
|
|
|
| 7 |
|
| 8 |
import asyncio
|
| 9 |
import os
|
| 10 |
+
from typing import Any, TypedDict
|
| 11 |
|
| 12 |
import httpx
|
| 13 |
|
|
|
|
| 15 |
|
| 16 |
BASE_URL = "https://datasets-server.huggingface.co"
|
| 17 |
|
| 18 |
+
# Truncation limit for long sample values in the output
|
| 19 |
+
MAX_SAMPLE_VALUE_LEN = 150
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# One dataset config together with the names of its splits, e.g.
# {"name": "default", "splits": ["train", "test"]}.
# Functional TypedDict form; equivalent to the class-based declaration.
SplitConfig = TypedDict("SplitConfig", {"name": str, "splits": list[str]})
|
| 27 |
+
|
| 28 |
|
| 29 |
def _get_headers() -> dict:
|
| 30 |
"""Get auth headers for private/gated datasets"""
|
|
|
|
| 158 |
return "## Status\n✗ Dataset may have issues"
|
| 159 |
|
| 160 |
|
| 161 |
+
def _extract_configs(splits_data: dict) -> list[SplitConfig]:
|
| 162 |
"""Group splits by config"""
|
| 163 |
+
configs: dict[str, SplitConfig] = {}
|
| 164 |
for s in splits_data.get("splits", []):
|
| 165 |
cfg = s.get("config", "default")
|
| 166 |
if cfg not in configs:
|
|
|
|
| 169 |
return list(configs.values())
|
| 170 |
|
| 171 |
|
| 172 |
+
def _format_structure(
|
| 173 |
+
configs: list[SplitConfig], max_rows: int = 10
|
| 174 |
+
) -> str:
|
| 175 |
+
"""Format configs and splits as a markdown table."""
|
| 176 |
+
lines = ["## Structure (configs & splits)", "| Config | Split |", "|--------|-------|"]
|
| 177 |
+
|
| 178 |
+
total_splits = sum(len(cfg["splits"]) for cfg in configs)
|
| 179 |
+
added_rows = 0
|
| 180 |
+
|
| 181 |
for cfg in configs:
|
| 182 |
for split_name in cfg["splits"]:
|
| 183 |
+
if added_rows >= max_rows:
|
| 184 |
+
break
|
| 185 |
lines.append(f"| {cfg['name']} | {split_name} |")
|
| 186 |
+
added_rows += 1
|
| 187 |
+
if added_rows >= max_rows:
|
| 188 |
+
break
|
| 189 |
+
|
| 190 |
+
if total_splits > added_rows:
|
| 191 |
+
lines.append(
|
| 192 |
+
f"| ... | ... | (_showing {added_rows} of {total_splits} config/split rows_) |"
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
return "\n".join(lines)
|
| 196 |
|
| 197 |
|
|
|
|
| 232 |
messages_col_data = val
|
| 233 |
|
| 234 |
val_str = str(val)
|
| 235 |
+
if len(val_str) > MAX_SAMPLE_VALUE_LEN:
|
| 236 |
+
val_str = val_str[:MAX_SAMPLE_VALUE_LEN] + "..."
|
| 237 |
lines.append(f"- {key}: {val_str}")
|
| 238 |
|
| 239 |
# If we found a messages column, add format analysis
|
|
|
|
| 349 |
return "\n".join(lines)
|
| 350 |
|
| 351 |
|
| 352 |
+
def _format_parquet_files(data: dict, max_rows: int = 10) -> str | None:
|
| 353 |
+
"""Format parquet file info, return None if no files."""
|
| 354 |
files = data.get("parquet_files", [])
|
| 355 |
if not files:
|
| 356 |
return None
|
|
|
|
| 361 |
key = f"{f.get('config', 'default')}/{f.get('split', 'train')}"
|
| 362 |
if key not in groups:
|
| 363 |
groups[key] = {"count": 0, "size": 0}
|
| 364 |
+
size = f.get("size") or 0
|
| 365 |
+
if not isinstance(size, (int, float)):
|
| 366 |
+
size = 0
|
| 367 |
groups[key]["count"] += 1
|
| 368 |
+
groups[key]["size"] += int(size)
|
| 369 |
|
| 370 |
lines = ["## Files (Parquet)"]
|
| 371 |
+
items = list(groups.items())
|
| 372 |
+
total_groups = len(items)
|
| 373 |
+
|
| 374 |
+
shown = 0
|
| 375 |
+
for key, info in items[:max_rows]:
|
| 376 |
size_mb = info["size"] / (1024 * 1024)
|
| 377 |
lines.append(f"- {key}: {info['count']} file(s) ({size_mb:.1f} MB)")
|
| 378 |
+
shown += 1
|
| 379 |
+
|
| 380 |
+
if total_groups > shown:
|
| 381 |
+
lines.append(
|
| 382 |
+
f"- ... (_showing {shown} of {total_groups} parquet groups_)"
|
| 383 |
+
)
|
| 384 |
return "\n".join(lines)
|
| 385 |
|
| 386 |
|
|
|
|
| 391 |
"Inspect a Hugging Face dataset comprehensively in one call.\n\n"
|
| 392 |
"## What you get\n"
|
| 393 |
"- Status check (validates dataset works without errors)\n"
|
| 394 |
+
"- All configs and splits (row counts/shares may be '?' when metadata is missing)\n"
|
| 395 |
"- Column names and types (schema)\n"
|
| 396 |
"- Sample rows to understand data format\n"
|
| 397 |
"- Parquet file structure and sizes\n\n"
|