"""Utilities for inspecting LeRobot dataset repositories on the Hugging Face Hub."""
| import os | |
| import re | |
| import json | |
| import logging | |
| from typing import Optional, Dict, Any | |
| from huggingface_hub import HfApi, hf_hub_download | |
# Module-level logger. Configure a root handler only when nothing has set one
# up yet, so importing this module inside a larger app (which configures its
# own logging) does not install duplicate handlers or override its format.
logger = logging.getLogger(__name__)
if not logger.handlers:
    logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
def _load_info_metadata(repo_id: str, token: Optional[str], stats: Dict[str, Any]) -> None:
    """Fetch meta/info.json (if present) and fill version/episode fields of *stats*.

    Falls back to assuming the v2.1 layout when info.json cannot be fetched
    (missing file, private repo without token, network error, ...).
    """
    try:
        info_path = hf_hub_download(
            repo_id=repo_id,
            filename="meta/info.json",
            repo_type="dataset",
            token=token,
        )
        with open(info_path, "r") as f:
            info_data = json.load(f)
        stats["info_metadata"] = info_data
        stats["codebase_version"] = info_data.get("codebase_version")
        # A codebase_version like "v3.0" marks the consolidated v3 layout;
        # anything else (including a missing field) is treated as v2.1.
        if stats["codebase_version"] and stats["codebase_version"].startswith("v3"):
            stats["format_version"] = "v3.0"
            # In v3.0 the episode count lives in info.json rather than being
            # derivable from per-episode file names.
            stats["total_episodes"] = info_data.get("total_episodes", 0)
            if stats["total_episodes"] > 0:
                stats["episode_numbers"] = list(range(stats["total_episodes"]))
            logger.info(
                "Detected v3.0 format with %s episodes from info.json",
                stats["total_episodes"],
            )
        else:
            stats["format_version"] = "v2.1"
            logger.info("Detected v2.1 format")
        logger.info("Successfully fetched metadata from info.json")
    except Exception as e:
        logger.warning("Could not fetch info.json: %s", e)
        # Assume v2.1 if we can't read info.json.
        stats["format_version"] = "v2.1"


def _count_v3_files(files, stats: Dict[str, Any]) -> None:
    """Count parquet/video files for the v3.0 layout.

    v3.0 stores data as data/chunk-XXX/file-XXX.parquet and videos as
    videos/{camera}/chunk-XXX/file-XXX.mp4.
    """
    parquet_pattern = re.compile(r"data/chunk-\d+/file-\d+\.parquet")
    video_pattern = re.compile(r"videos/.+/chunk-\d+/file-\d+\.mp4")
    for file_path in files:
        if parquet_pattern.search(file_path):
            stats["total_parquet_files"] += 1
        elif video_pattern.search(file_path):
            stats["total_video_files"] += 1


def _count_v21_files(files, stats: Dict[str, Any]) -> None:
    """Count files and derive episode numbers for the v2.1 layout.

    v2.1 stores one data/chunk-XXX/episode_XXXX.parquet per episode plus
    (optionally) episode_XXXX.mp4 video files.
    """
    parquet_pattern = re.compile(r"data/chunk-\d+/episode_(\d+)\.parquet")
    episode_numbers = set()
    for file_path in files:
        match = parquet_pattern.search(file_path)
        if match:
            episode_numbers.add(int(match.group(1)))
            stats["total_parquet_files"] += 1
        # Video check is independent of the parquet match: a .mp4 path can
        # never match the parquet pattern, so nesting this under the match
        # branch would make it dead code.
        if file_path.endswith(".mp4") and "episode_" in file_path:
            stats["total_video_files"] += 1
    # Only overwrite episode info when parquet files were actually found,
    # keeping whatever info.json already provided otherwise.
    if episode_numbers:
        stats["episode_numbers"] = sorted(episode_numbers)
        stats["total_episodes"] = len(episode_numbers)


def get_dataset_stats(
    repo_id: str,
    hf_token: Optional[str] = None,
) -> Dict[str, Any]:
    """Get statistics for a Hugging Face dataset without downloading all files.

    Supports both v2.1 and v3.0 LeRobot dataset formats.

    Args:
        repo_id: The HuggingFace dataset repo ID
        hf_token: Optional HuggingFace token for private datasets

    Returns:
        Dictionary containing dataset statistics:
            - total_episodes: Number of episodes (from info.json for v3.0, or counted for v2.1)
            - episode_numbers: List of episode numbers found
            - total_parquet_files: Total number of parquet files
            - total_video_files: Total number of video files (if present)
            - info_metadata: Complete metadata from info.json (if present)
            - codebase_version: Dataset version (if present)
            - format_version: Detected format version (v2.1 or v3.0)
            - error: Error message string, or None on success
    """
    api = HfApi()
    # Public Space: use an explicit token only if the caller passes one.
    # Do not implicitly use HF_TOKEN, so visitors cannot proxy private data through this app.
    token = hf_token
    logger.info("Fetching stats for dataset: %s", repo_id)
    stats: Dict[str, Any] = {
        "repo_id": repo_id,
        "total_episodes": 0,
        "episode_numbers": [],
        "total_parquet_files": 0,
        "total_video_files": 0,
        "info_metadata": None,
        "codebase_version": None,
        "format_version": None,
        "error": None,
    }
    try:
        # info.json (when readable) decides which layout we count below.
        _load_info_metadata(repo_id, token, stats)
        # List all files in the repository, then count per detected format.
        files = api.list_repo_files(repo_id=repo_id, repo_type="dataset", token=token)
        if stats["format_version"] == "v3.0":
            _count_v3_files(files, stats)
        else:
            _count_v21_files(files, stats)
        logger.info(
            "Stats for %s (%s): %s episodes, %s parquet files, %s video files",
            repo_id,
            stats["format_version"],
            stats["total_episodes"],
            stats["total_parquet_files"],
            stats["total_video_files"],
        )
    except Exception as e:
        # Surface the failure in the result rather than raising, so the UI
        # can display it; see format_stats_display's error branch.
        error_msg = f"Error fetching stats: {str(e)}"
        logger.error(error_msg)
        stats["error"] = error_msg
    return stats
def format_stats_display(stats: Dict[str, Any]) -> str:
    """Format stats dictionary into a readable markdown string for display.

    Args:
        stats: Dictionary of dataset statistics, as produced by
            ``get_dataset_stats``.

    Returns:
        Formatted string for display (an error banner when ``stats['error']``
        is set).
    """
    if stats.get("error"):
        return f"β Error: {stats['error']}"

    lines = [f"π **Dataset Statistics for {stats['repo_id']}**", ""]

    # Format / codebase version headers (shown only when detected).
    if stats.get("format_version"):
        lines.append(f"**Format Version:** {stats['format_version']}")
    if stats.get("codebase_version"):
        lines.append(f"**Codebase Version:** {stats['codebase_version']}")
    lines.append("")

    # Basic counts.
    lines.append(f"**Total Episodes:** {stats['total_episodes']}")
    lines.append(f"**Total Parquet Files:** {stats['total_parquet_files']}")
    lines.append(f"**Total Video Files:** {stats['total_video_files']}")

    # Episode range (mainly for v2.1 or when episode numbers are sequential).
    episode_nums = stats["episode_numbers"]
    if episode_nums:
        lines.append(f"**Episode Range:** {episode_nums[0]} to {episode_nums[-1]}")
        # Gap detection only makes sense for v2.1, where episode numbers come
        # from individual file names; v3.0 numbers are synthesized as a
        # contiguous range and can never have holes.
        if stats.get("format_version") == "v2.1":
            missing = set(range(episode_nums[0], episode_nums[-1] + 1)) - set(episode_nums)
            if missing:
                lines.append(f"**β οΈ Missing Episodes:** {sorted(missing)}")

    # Additional metadata from info.json.
    if stats.get("info_metadata"):
        info = stats["info_metadata"]
        lines.append("")
        lines.append("**Metadata from info.json:**")
        # Show key metadata fields (v3.0 has more fields).
        if stats.get("format_version") == "v3.0":
            metadata_fields = [
                ("fps", "FPS"),
                ("robot_type", "Robot Type"),
                ("total_frames", "Total Frames"),
                ("total_tasks", "Total Tasks"),
                ("chunks_size", "Chunks Size"),
                ("data_files_size_in_mb", "Data Files Size (MB)"),
                ("video_files_size_in_mb", "Video Files Size (MB)"),
            ]
        else:
            # v2.1 fields
            metadata_fields = [
                ("fps", "FPS"),
                ("robot_type", "Robot Type"),
                ("total_episodes", "Total Episodes (from metadata)"),
                ("total_videos", "Total Videos (from metadata)"),
                ("total_tasks", "Total Tasks"),
                ("total_frames", "Total Frames"),
            ]
        for key, label in metadata_fields:
            if key in info:
                lines.append(f" - **{label}:** {info[key]}")

    return "\n".join(lines)
| def compare_metadata_with_actual(stats: Dict[str, Any]) -> str: | |
| """Compare metadata from info.json with actual file counts. | |
| Args: | |
| stats: Dictionary of dataset statistics | |
| Returns: | |
| Comparison report string | |
| """ | |
| if not stats.get("info_metadata"): | |
| return "No metadata available for comparison" | |
| info = stats["info_metadata"] | |
| lines = [] | |
| lines.append("**π Metadata vs Actual Comparison:**") | |
| lines.append("") | |
| # Compare episodes | |
| metadata_episodes = info.get("total_episodes", "N/A") | |
| actual_episodes = stats["total_episodes"] | |
| match_episodes = "β " if metadata_episodes == actual_episodes else "β" | |
| lines.append( | |
| f"{match_episodes} **Episodes:** Metadata={metadata_episodes}, Actual={actual_episodes}" | |
| ) | |
| # Compare videos | |
| metadata_videos = info.get("total_videos", "N/A") | |
| actual_videos = stats["total_video_files"] | |
| match_videos = "β " if metadata_videos == actual_videos else "β" | |
| lines.append( | |
| f"{match_videos} **Videos:** Metadata={metadata_videos}, Actual={actual_videos}" | |
| ) | |
| return "\n".join(lines) | |