Spaces:
Running
Running
File size: 9,623 Bytes
3c558f5 d6ea67d 3c558f5 a1410a1 3c558f5 a1410a1 3c558f5 d6ea67d 3c558f5 fc5b955 3c558f5 d6ea67d 3c558f5 d6ea67d 3c558f5 a1410a1 3c558f5 a1410a1 d6ea67d a1410a1 d6ea67d a1410a1 d6ea67d 3c558f5 d6ea67d 3c558f5 d6ea67d 3c558f5 d6ea67d 3c558f5 d6ea67d 3c558f5 a1410a1 3c558f5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 | import os
import re
import json
import logging
from typing import Optional, Dict, Any
from huggingface_hub import HfApi, hf_hub_download
# Module-level logger for this file.
logger = logging.getLogger(__name__)
if not logger.handlers:
    # NOTE(review): this checks the *module* logger's handlers but configures the
    # *root* logger via basicConfig — harmless in practice (output is identical),
    # but worth confirming that is the intent.
    logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
def get_dataset_stats(
    repo_id: str,
    hf_token: Optional[str] = None,
) -> Dict[str, Any]:
    """Get statistics for a Hugging Face dataset without downloading all files.

    Supports both v2.1 and v3.0 LeRobot dataset formats.

    Args:
        repo_id: The HuggingFace dataset repo ID.
        hf_token: Optional HuggingFace token for private datasets.

    Returns:
        Dictionary containing dataset statistics:
            - repo_id: Echo of the requested repo ID
            - total_episodes: Number of episodes (from info.json for v3.0, or counted for v2.1)
            - episode_numbers: List of episode numbers found
            - total_parquet_files: Total number of parquet files
            - total_video_files: Total number of video files (if present)
            - info_metadata: Complete metadata from info.json (if present)
            - codebase_version: Dataset version (if present)
            - format_version: Detected format version (v2.1 or v3.0)
            - error: Error message string when the lookup failed, else None
    """
    api = HfApi()
    # Public Space: use an explicit token only if the caller passes one.
    # Do not implicitly use HF_TOKEN, so visitors cannot proxy private data through this app.
    token = hf_token
    logger.info("Fetching stats for dataset: %s", repo_id)
    stats: Dict[str, Any] = {
        "repo_id": repo_id,
        "total_episodes": 0,
        "episode_numbers": [],
        "total_parquet_files": 0,
        "total_video_files": 0,
        "info_metadata": None,
        "codebase_version": None,
        "format_version": None,
        "error": None,
    }
    try:
        # Try to fetch metadata from info.json first to determine the format version.
        try:
            info_path = hf_hub_download(
                repo_id=repo_id,
                filename="meta/info.json",
                repo_type="dataset",
                token=token,
            )
            with open(info_path, "r", encoding="utf-8") as f:
                info_data = json.load(f)
            stats["info_metadata"] = info_data
            stats["codebase_version"] = info_data.get("codebase_version")
            # Determine format version from codebase_version.
            if stats["codebase_version"] and stats["codebase_version"].startswith("v3"):
                stats["format_version"] = "v3.0"
                # In v3.0, total_episodes is authoritative in info.json.
                stats["total_episodes"] = info_data.get("total_episodes", 0)
                # Episode numbers are implicit (0..N-1) in v3.0.
                if stats["total_episodes"] > 0:
                    stats["episode_numbers"] = list(range(stats["total_episodes"]))
                logger.info(
                    "Detected v3.0 format with %s episodes from info.json",
                    stats["total_episodes"],
                )
            else:
                stats["format_version"] = "v2.1"
                logger.info("Detected v2.1 format")
            logger.info("Successfully fetched metadata from info.json")
        except Exception as e:
            # Best-effort: a missing/unreadable info.json is not fatal.
            logger.warning("Could not fetch info.json: %s", e)
            # Assume v2.1 if we can't read info.json.
            stats["format_version"] = "v2.1"
        # List all files in the repository and count them based on the detected format.
        files = api.list_repo_files(repo_id=repo_id, repo_type="dataset", token=token)
        if stats["format_version"] == "v3.0":
            # v3.0 layout: data/chunk-XXX/file-XXX.parquet and
            # videos/{camera}/chunk-XXX/file-XXX.mp4
            parquet_pattern = re.compile(r"data/chunk-\d+/file-\d+\.parquet")
            video_pattern = re.compile(r"videos/.+/chunk-\d+/file-\d+\.mp4")
            for file_path in files:
                if parquet_pattern.search(file_path):
                    stats["total_parquet_files"] += 1
                elif video_pattern.search(file_path):
                    stats["total_video_files"] += 1
        else:
            # v2.1 layout: data/chunk-XXX/episode_XXXX.parquet and episode_XXXX.mp4
            parquet_pattern = re.compile(r"data/chunk-\d+/episode_(\d+)\.parquet")
            episode_numbers = set()
            for file_path in files:
                match = parquet_pattern.search(file_path)
                if match:
                    episode_numbers.add(int(match.group(1)))
                    stats["total_parquet_files"] += 1
                # Count video files (v2.1 format).
                if file_path.endswith(".mp4") and "episode_" in file_path:
                    stats["total_video_files"] += 1
            # Prefer counted episodes over whatever info.json claimed (or lacked).
            if episode_numbers:
                stats["episode_numbers"] = sorted(episode_numbers)
                stats["total_episodes"] = len(episode_numbers)
        logger.info(
            "Stats for %s (%s): %s episodes, %s parquet files, %s video files",
            repo_id,
            stats["format_version"],
            stats["total_episodes"],
            stats["total_parquet_files"],
            stats["total_video_files"],
        )
    except Exception as e:
        # Top-level boundary: report the failure in the result instead of raising,
        # so the UI can display it.
        error_msg = f"Error fetching stats: {str(e)}"
        logger.error(error_msg)
        stats["error"] = error_msg
    return stats
def format_stats_display(stats: Dict[str, Any]) -> str:
    """Format a stats dictionary into a readable Markdown string for display.

    Args:
        stats: Dictionary of dataset statistics as produced by ``get_dataset_stats``.

    Returns:
        Formatted Markdown string for display.
    """
    # NOTE(review): emoji below were mojibake-garbled in the original source and
    # have been restored to the conventional glyphs — confirm against the rendered UI.
    if stats.get("error"):
        return f"❌ Error: {stats['error']}"
    lines = [f"📊 **Dataset Statistics for {stats['repo_id']}**", ""]
    # Format version (v2.1 / v3.0, detected).
    if stats.get("format_version"):
        lines.append(f"**Format Version:** {stats['format_version']}")
    # Codebase version as declared in info.json.
    if stats.get("codebase_version"):
        lines.append(f"**Codebase Version:** {stats['codebase_version']}")
    lines.append("")
    # Basic counts.
    lines.append(f"**Total Episodes:** {stats['total_episodes']}")
    lines.append(f"**Total Parquet Files:** {stats['total_parquet_files']}")
    lines.append(f"**Total Video Files:** {stats['total_video_files']}")
    # Episode range (mainly for v2.1 or when episode numbers are sequential).
    episode_nums = stats.get("episode_numbers") or []
    if episode_nums:
        lines.append(f"**Episode Range:** {episode_nums[0]} to {episode_nums[-1]}")
        # Check for gaps in episode numbering (only meaningful for v2.1,
        # where numbers are parsed from file names).
        if stats.get("format_version") == "v2.1":
            expected = range(episode_nums[0], episode_nums[-1] + 1)
            missing = set(expected) - set(episode_nums)
            if missing:
                lines.append(f"**⚠️ Missing Episodes:** {sorted(missing)}")
    # Additional metadata from info.json.
    if stats.get("info_metadata"):
        info = stats["info_metadata"]
        lines.append("")
        lines.append("**Metadata from info.json:**")
        # Show key metadata fields (v3.0 has more fields).
        if stats.get("format_version") == "v3.0":
            metadata_fields = [
                ("fps", "FPS"),
                ("robot_type", "Robot Type"),
                ("total_frames", "Total Frames"),
                ("total_tasks", "Total Tasks"),
                ("chunks_size", "Chunks Size"),
                ("data_files_size_in_mb", "Data Files Size (MB)"),
                ("video_files_size_in_mb", "Video Files Size (MB)"),
            ]
        else:
            # v2.1 fields.
            metadata_fields = [
                ("fps", "FPS"),
                ("robot_type", "Robot Type"),
                ("total_episodes", "Total Episodes (from metadata)"),
                ("total_videos", "Total Videos (from metadata)"),
                ("total_tasks", "Total Tasks"),
                ("total_frames", "Total Frames"),
            ]
        for key, label in metadata_fields:
            if key in info:
                lines.append(f"  - **{label}:** {info[key]}")
    return "\n".join(lines)
def compare_metadata_with_actual(stats: Dict[str, Any]) -> str:
    """Compare metadata from info.json with actual file counts.

    Args:
        stats: Dictionary of dataset statistics as produced by ``get_dataset_stats``.

    Returns:
        Comparison report string (Markdown), or a short notice when no
        info.json metadata is available.
    """
    if not stats.get("info_metadata"):
        return "No metadata available for comparison"
    info = stats["info_metadata"]
    # NOTE(review): the status glyphs were mojibake-garbled (one literal was even
    # split across lines) in the original source; restored to ✅/❌ — confirm.
    lines = ["**🔍 Metadata vs Actual Comparison:**", ""]
    # Compare episode counts.
    metadata_episodes = info.get("total_episodes", "N/A")
    actual_episodes = stats["total_episodes"]
    episodes_mark = "✅" if metadata_episodes == actual_episodes else "❌"
    lines.append(
        f"{episodes_mark} **Episodes:** Metadata={metadata_episodes}, Actual={actual_episodes}"
    )
    # Compare video counts.
    metadata_videos = info.get("total_videos", "N/A")
    actual_videos = stats["total_video_files"]
    videos_mark = "✅" if metadata_videos == actual_videos else "❌"
    lines.append(
        f"{videos_mark} **Videos:** Metadata={metadata_videos}, Actual={actual_videos}"
    )
    return "\n".join(lines)
|