SuveenE committed
Commit d6ea67d · 1 Parent(s): 723ccd3

Add support for v3

Files changed (2)
  1. app.py (+4 -4)
  2. get_dataset_stats.py (+97 -35)
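For context on what "v3" means here: v2.1 LeRobot repos store one parquet file per episode under data/chunk-*/episode_*.parquet, while v3.0 repos consolidate data into file-*.parquet chunks (per the comments in this commit). A minimal sketch, with illustrative sample paths, exercising the two patterns that get_dataset_stats.py introduces below:

import re

# Patterns as introduced in get_dataset_stats.py below.
V21_PARQUET = re.compile(r"data/chunk-\d+/episode_(\d+)\.parquet")
V30_PARQUET = re.compile(r"data/.+/file[-_]\d+\.parquet")

# Sample paths are illustrative; real repos may nest differently.
assert V21_PARQUET.search("data/chunk-000/episode_000042.parquet")
assert V30_PARQUET.search("data/chunk-000/file-0000.parquet")
assert not V30_PARQUET.search("data/chunk-000/episode_000042.parquet")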
app.py CHANGED
@@ -63,13 +63,13 @@ with gr.Blocks(title="LeRobot Dataset Stats Viewer") as demo:
     gr.Markdown("**View statistics for Hugging Face datasets (LeRobot format).**")
 
     # Load initial datasets
-    _initial_choices = search_datasets_fn("griffinlabs-cortex")
+    _initial_choices = search_datasets_fn("")
 
     with gr.Row():
         org_input = gr.Textbox(
-            label="Organization or keyword",
-            value="griffinlabs-cortex",
-            placeholder="e.g., lerobot, griffinlabs-cortex"
+            label="Organization",
+            value="",
+            placeholder="e.g., cortexairobot"
         )
         load_btn = gr.Button("Load Datasets")
 
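Note: search_datasets_fn is defined elsewhere in app.py and is not part of this diff. A minimal sketch of the shape it plausibly has, assuming it queries the Hub via huggingface_hub.HfApi.list_datasets (the body below is an assumption, not the actual implementation):

from huggingface_hub import HfApi

def search_datasets_fn(query: str) -> list[str]:
    # Hypothetical sketch; the real helper is not shown in this commit.
    # An empty query yields no initial choices, matching the new default.
    if not query:
        return []
    api = HfApi()
    # list_datasets supports search/author filters and a result limit.
    return [d.id for d in api.list_datasets(search=query, limit=50)]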
get_dataset_stats.py CHANGED
@@ -17,6 +17,8 @@ def get_dataset_stats(
 ) -> Dict[str, Any]:
     """Get statistics for a Hugging Face dataset without downloading all files.
 
+    Supports both v2.1 and v3.0 LeRobot dataset formats.
+
     Args:
         repo_id: The HuggingFace dataset repo ID
         hf_token: Optional HuggingFace token for private datasets
@@ -28,7 +30,9 @@ def get_dataset_stats(
         - total_parquet_files: Total number of parquet files
         - total_video_files: Total number of video files (if present)
         - info_metadata: Metadata from info.json (if present)
+        - episodes_metadata: Metadata from episodes.json (v3.0 only)
         - codebase_version: Dataset version (if present)
+        - format_version: Detected format version (v2.1 or v3.0)
     """
     api = HfApi()
     token = hf_token or os.environ.get("HF_TOKEN")
@@ -42,34 +46,14 @@ def get_dataset_stats(
         "total_parquet_files": 0,
         "total_video_files": 0,
         "info_metadata": None,
+        "episodes_metadata": None,
         "codebase_version": None,
+        "format_version": None,
         "error": None,
     }
 
     try:
-        # List all files in the repository
-        files = api.list_repo_files(repo_id=repo_id, repo_type="dataset", token=token)
-
-        # Count parquet files and extract episode numbers
-        parquet_pattern = re.compile(r"data/chunk-\d+/episode_(\d+)\.parquet")
-        episode_numbers = set()
-
-        for file_path in files:
-            match = parquet_pattern.search(file_path)
-            if match:
-                episode_num = int(match.group(1))
-                episode_numbers.add(episode_num)
-                stats["total_parquet_files"] += 1
-
-            # Count video files
-            if file_path.endswith(".mp4") and "episode_" in file_path:
-                stats["total_video_files"] += 1
-
-        # Sort episode numbers and update stats
-        stats["episode_numbers"] = sorted(list(episode_numbers))
-        stats["total_episodes"] = len(episode_numbers)
-
-        # Try to fetch metadata from info.json
+        # Try to fetch metadata from info.json first to determine version
         try:
             info_path = hf_hub_download(
                 repo_id=repo_id,
@@ -87,8 +71,79 @@ def get_dataset_stats(
         except Exception as e:
             logger.warning(f"Could not fetch info.json: {str(e)}")
 
+        # Try to fetch episodes.json (v3.0 format)
+        try:
+            episodes_path = hf_hub_download(
+                repo_id=repo_id,
+                filename="meta/episodes.json",
+                repo_type="dataset",
+                token=token,
+            )
+
+            with open(episodes_path, "r") as f:
+                episodes_data = json.load(f)
+            stats["episodes_metadata"] = episodes_data
+            stats["format_version"] = "v3.0"
+
+            # In v3.0, episodes.json contains episode information
+            # It's typically a list of episode metadata
+            if isinstance(episodes_data, list):
+                stats["total_episodes"] = len(episodes_data)
+                stats["episode_numbers"] = list(range(len(episodes_data)))
+            elif isinstance(episodes_data, dict):
+                # Handle different possible structures
+                if "episodes" in episodes_data:
+                    episodes_list = episodes_data["episodes"]
+                    stats["total_episodes"] = len(episodes_list)
+                    stats["episode_numbers"] = list(range(len(episodes_list)))
+                else:
+                    # Fallback: count keys
+                    stats["total_episodes"] = len(episodes_data)
+                    stats["episode_numbers"] = sorted([int(k) for k in episodes_data.keys() if k.isdigit()])
+
+            logger.info("Successfully fetched episodes.json - detected v3.0 format")
+        except Exception as e:
+            logger.info(f"Could not fetch episodes.json (may be v2.1 format): {str(e)}")
+            stats["format_version"] = "v2.1"
+
+        # List all files in the repository
+        files = api.list_repo_files(repo_id=repo_id, repo_type="dataset", token=token)
+
+        # Count files according to the detected format version
+        if stats["format_version"] == "v3.0":
+            # v3.0 format: file-XXXX.parquet and file-XXXX.mp4
+            parquet_pattern = re.compile(r"data/.+/file[-_]\d+\.parquet")
+            video_pattern = re.compile(r"videos/.+/file[-_]\d+\.mp4")
+
+            for file_path in files:
+                if parquet_pattern.search(file_path):
+                    stats["total_parquet_files"] += 1
+                elif video_pattern.search(file_path):
+                    stats["total_video_files"] += 1
+        else:
+            # v2.1 format: episode_XXXX.parquet and episode_XXXX.mp4
+            parquet_pattern = re.compile(r"data/chunk-\d+/episode_(\d+)\.parquet")
+            episode_numbers = set()
+
+            for file_path in files:
+                match = parquet_pattern.search(file_path)
+                if match:
+                    episode_num = int(match.group(1))
+                    episode_numbers.add(episode_num)
+                    stats["total_parquet_files"] += 1
+
+                # Count video files (v2.1 format)
+                if file_path.endswith(".mp4") and "episode_" in file_path:
+                    stats["total_video_files"] += 1
+
+            # Update stats if we didn't get episodes from episodes.json
+            if episode_numbers:
+                stats["episode_numbers"] = sorted(list(episode_numbers))
+                stats["total_episodes"] = len(episode_numbers)
+
         logger.info(
-            f"Stats for {repo_id}: {stats['total_episodes']} episodes, "
+            f"Stats for {repo_id} ({stats['format_version']}): "
+            f"{stats['total_episodes']} episodes, "
             f"{stats['total_parquet_files']} parquet files, "
             f"{stats['total_video_files']} video files"
         )
@@ -117,25 +172,32 @@ def format_stats_display(stats: Dict[str, Any]) -> str:
     lines.append(f"📊 **Dataset Statistics for {stats['repo_id']}**")
     lines.append("")
 
-    # Basic stats
-    lines.append(f"**Total Episodes:** {stats['total_episodes']}")
-    lines.append(f"**Total Parquet Files:** {stats['total_parquet_files']}")
-    lines.append(f"**Total Video Files:** {stats['total_video_files']}")
+    # Format version
+    if stats.get("format_version"):
+        lines.append(f"**Format Version:** {stats['format_version']}")
 
     # Version info
     if stats.get("codebase_version"):
         lines.append(f"**Codebase Version:** {stats['codebase_version']}")
 
-    # Episode range
-    if stats["episode_numbers"]:
+    lines.append("")
+
+    # Basic stats
+    lines.append(f"**Total Episodes:** {stats['total_episodes']}")
+    lines.append(f"**Total Parquet Files:** {stats['total_parquet_files']}")
+    lines.append(f"**Total Video Files:** {stats['total_video_files']}")
+
+    # Episode range (mainly for v2.1 or when episode numbers are sequential)
+    if stats["episode_numbers"] and len(stats["episode_numbers"]) > 0:
         episode_nums = stats["episode_numbers"]
         lines.append(f"**Episode Range:** {episode_nums[0]} to {episode_nums[-1]}")
 
-        # Check for gaps in episodes
-        expected = list(range(episode_nums[0], episode_nums[-1] + 1))
-        missing = set(expected) - set(episode_nums)
-        if missing:
-            lines.append(f"**⚠️ Missing Episodes:** {sorted(list(missing))}")
+        # Check for gaps in episodes (only for v2.1)
+        if stats.get("format_version") == "v2.1":
+            expected = list(range(episode_nums[0], episode_nums[-1] + 1))
+            missing = set(expected) - set(episode_nums)
+            if missing:
+                lines.append(f"**⚠️ Missing Episodes:** {sorted(list(missing))}")
 
     # Additional metadata from info.json
     if stats.get("info_metadata"):
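A quick usage sketch of the updated module, covering both format paths; the repo ID is illustrative:

from get_dataset_stats import get_dataset_stats, format_stats_display

# If meta/episodes.json exists, stats come from the v3.0 branch;
# otherwise the function falls back to v2.1 episode_* file counting.
stats = get_dataset_stats("lerobot/pusht")  # illustrative repo ID
if stats["error"]:
    print(f"Failed: {stats['error']}")
else:
    print(format_stats_display(stats))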