Spaces:
Running
Running
| import re | |
| from collections import defaultdict | |
| from datetime import datetime | |
| from functools import lru_cache | |
| from typing import List | |
| import gradio as gr | |
| from huggingface_hub import HfApi | |
| from get_dataset_stats import get_dataset_stats | |
# Slug of the public Hugging Face collection this Space exposes stats for.
COLLECTION_SLUG = "allenai/molmoact2-bimanualyam-dataset"
# Browser-facing URL of the collection, used in Markdown links throughout the UI.
COLLECTION_URL = f"https://huggingface.co/collections/{COLLECTION_SLUG}"
@lru_cache(maxsize=1)
def get_collection_datasets() -> List[str]:
    """Return public dataset repos from the MolmoAct2-BimanualYAM collection.

    The result is cached for the life of the process; ``load_collection_datasets``
    calls ``get_collection_datasets.cache_clear()`` to force a refresh, so this
    ``lru_cache`` decorator is required (without it ``cache_clear`` raises
    ``AttributeError`` — and ``lru_cache`` was imported but unused).

    Returns:
        Repo ids of dataset items, deduplicated with first-seen order preserved.
    """
    api = HfApi()
    # token=False forces an anonymous request: only public items are returned.
    collection = api.get_collection(COLLECTION_SLUG, token=False)
    dataset_ids = [
        item.item_id for item in collection.items if item.item_type == "dataset"
    ]
    # dict.fromkeys dedupes while keeping collection order.
    return list(dict.fromkeys(dataset_ids))
@lru_cache(maxsize=None)
def get_cached_dataset_stats(repo_id: str):
    """Cache per-repo stats in the Space process to make repeated UI use cheaper.

    The docstring's caching promise requires the ``lru_cache`` decorator (it was
    imported but not applied); stats for a repo are fetched at most once per
    process. NOTE(review): entries are never evicted — fine while the key space
    is bounded by the collection, but a leak if arbitrary repo ids were cached.
    """
    # Anonymous access only: this public Space never forwards a user token.
    return get_dataset_stats(repo_id, hf_token=None)
def get_allowed_dataset_set() -> set[str]:
    """Return the repo ids currently allowed, i.e. those in the public collection."""
    allowed_ids = get_collection_datasets()
    return set(allowed_ids)
def format_duration(seconds):
    """Render a duration in seconds as a compact 'Hh Mm Ss' string.

    Leading zero units are dropped: 59 -> '59s', 125 -> '2m 5s'.
    """
    h = int(seconds // 3600)
    m = int((seconds % 3600) // 60)
    s = int(seconds % 60)
    if h > 0:
        parts = [f"{h}h", f"{m}m", f"{s}s"]
    elif m > 0:
        parts = [f"{m}m", f"{s}s"]
    else:
        parts = [f"{s}s"]
    return " ".join(parts)
def _info_duration_seconds(stats) -> float:
    """Return a dataset's duration in seconds from its info metadata.

    Computed as total_frames / fps (fps defaults to 30 when absent). Returns 0
    when the frame count is missing or zero so callers can skip the display.
    This replaces five copy-pasted occurrences of the same computation.
    """
    info_meta = stats.get("info_metadata", {}) or {}
    total_frames = info_meta.get("total_frames")
    if not total_frames:
        return 0
    return total_frames / info_meta.get("fps", 30)


def _parse_repo_date(repo_id: str):
    """Extract the 8-digit DDMMYYYY date embedded in a repo id after a '/'.

    Returns (sort_key, display) where sort_key is ISO "YYYY-MM-DD" so a plain
    string sort orders chronologically. Falls back to the raw digits when they
    do not parse as DDMMYYYY, and to ("unknown", "Unknown Date") when no
    8-digit run is found.
    """
    date_match = re.search(r"/(\d{8})", repo_id)
    if not date_match:
        return "unknown", "Unknown Date"
    date_str = date_match.group(1)
    try:
        date_obj = datetime.strptime(date_str, "%d%m%Y")
    except ValueError:
        # Digits present but not a valid date: show them verbatim.
        return date_str, date_str
    return date_obj.strftime("%Y-%m-%d"), date_obj.strftime("%B %d, %Y")


def fetch_stats_for_selected(selected_datasets: List[str], progress=gr.Progress()):
    """Fetch statistics for selected collection datasets.

    Args:
        selected_datasets: Repo ids chosen in the CheckboxGroup.
        progress: Gradio progress tracker. The mutable default is intentional —
            it is how Gradio injects progress reporting into an event handler.

    Returns:
        A Markdown report: overall totals, per-date groups for v3.0 datasets,
        an "Other Formats" section, and any per-repo errors.
    """
    if not selected_datasets:
        return "Please select at least one dataset."
    allowed = get_allowed_dataset_set()
    # Deduplicate while preserving the user's selection order.
    selected_datasets = list(dict.fromkeys(selected_datasets))
    # Refuse anything outside the public collection (guards against forged events).
    outside_collection = [repo_id for repo_id in selected_datasets if repo_id not in allowed]
    if outside_collection:
        blocked = "\n".join(f"- `{repo_id}`" for repo_id in outside_collection[:20])
        extra = "\n- ..." if len(outside_collection) > 20 else ""
        return (
            "Selection contains repos outside the public "
            f"[{COLLECTION_SLUG}]({COLLECTION_URL}) collection.\n\n"
            f"{blocked}{extra}"
        )

    total_episodes = 0
    v3_by_date = defaultdict(list)  # ISO date key -> per-dataset dicts
    non_v3_results = []
    errors = []
    for i, repo_id in enumerate(selected_datasets):
        try:
            progress((i + 1) / len(selected_datasets), desc=f"Processing {repo_id}...")
            stats = get_cached_dataset_stats(repo_id)
            if stats.get("error"):
                errors.append(f"Error for `{repo_id}`: {stats['error']}")
                continue
            episodes = stats["total_episodes"]
            total_episodes += episodes
            if stats.get("format_version") == "v3.0":
                date_key, date_display = _parse_repo_date(repo_id)
                v3_by_date[date_key].append(
                    {
                        "repo_id": repo_id,
                        "episodes": episodes,
                        "date_display": date_display,
                        "stats": stats,
                    }
                )
            else:
                non_v3_results.append(
                    {
                        "repo_id": repo_id,
                        "episodes": episodes,
                        "stats": stats,
                    }
                )
        except Exception as e:
            # One bad repo must not abort the whole report.
            errors.append(f"Error for `{repo_id}`: {e}")

    # Grand-total duration across every successfully fetched dataset.
    all_results = [d for datasets in v3_by_date.values() for d in datasets] + non_v3_results
    total_duration_seconds = sum(_info_duration_seconds(d["stats"]) for d in all_results)
    duration_display = (
        f" • {format_duration(total_duration_seconds)}" if total_duration_seconds > 0 else ""
    )
    output = [
        f"## Total Episodes: {total_episodes}{duration_display}",
        f"Selected datasets: **{len(selected_datasets)}** / **{len(allowed)}**",
    ]

    if v3_by_date:
        # Newest dates first; unknown-dated repos sink to the bottom.
        sorted_dates = sorted((key for key in v3_by_date if key != "unknown"), reverse=True)
        if "unknown" in v3_by_date:
            sorted_dates.append("unknown")
        for date_key in sorted_dates:
            datasets = v3_by_date[date_key]
            date_total_episodes = sum(dataset["episodes"] for dataset in datasets)
            date_total_seconds = sum(
                _info_duration_seconds(dataset["stats"]) for dataset in datasets
            )
            output.append(
                f"\n**{datasets[0]['date_display']}** — Total: **{date_total_episodes} episodes**"
                f" • {format_duration(date_total_seconds)}"
            )
            for dataset in sorted(datasets, key=lambda item: item["repo_id"]):
                repo_name = dataset["repo_id"].split("/")[-1]
                seconds = _info_duration_seconds(dataset["stats"])
                duration_str = f" • {format_duration(seconds)}" if seconds > 0 else ""
                output.append(f"- `{repo_name}`: **{dataset['episodes']} episodes**{duration_str}")

    if non_v3_results:
        output.append("\n### Other Formats")
        for dataset in non_v3_results:
            seconds = _info_duration_seconds(dataset["stats"])
            duration_str = f" • {format_duration(seconds)}" if seconds > 0 else ""
            output.append(
                f"- `{dataset['repo_id']}`: **{dataset['episodes']} episodes**{duration_str}"
            )

    if errors:
        output.append("\n### Errors")
        output.extend(f"- {error}" for error in errors)
    return "\n".join(output)
def load_collection_datasets():
    """Refresh the dataset list from the Hub and reset the UI selection.

    Returns:
        [CheckboxGroup update (new choices, selection cleared),
         new choices for the State component,
         status Markdown string].
    """
    # NOTE(review): .cache_clear() only exists if get_collection_datasets is
    # decorated with functools.lru_cache; lru_cache is imported at the top of
    # the file but no decorator is visible in this view — confirm, otherwise
    # this line raises AttributeError.
    get_collection_datasets.cache_clear()
    datasets = get_collection_datasets()
    return [
        gr.update(choices=datasets, value=[]),
        datasets,
        f"Loaded **{len(datasets)}** datasets from [{COLLECTION_SLUG}]({COLLECTION_URL}).",
    ]
def select_matching(filter_text: str, choices: List[str]):
    """Select every choice whose repo id contains the filter, case-insensitively.

    An empty or whitespace-only filter behaves like "Select All".
    """
    available = choices or []
    needle = (filter_text or "").strip().lower()
    if not needle:
        return gr.update(value=available)
    matches = [repo_id for repo_id in available if needle in repo_id.lower()]
    return gr.update(value=matches)
# Fetched once at import time so the checkbox list is populated on first render.
_initial_datasets = get_collection_datasets()

with gr.Blocks(title="MolmoAct2-BimanualYAM Dataset Stats") as demo:
    gr.Markdown(
        "# MolmoAct2-BimanualYAM Dataset Stats\n"
        f"Public stats viewer for datasets in [{COLLECTION_SLUG}]({COLLECTION_URL})."
    )
    # Mirrors the CheckboxGroup's full choice list so "Select All" /
    # "Select Matching" still know every choice after a refresh.
    current_choices = gr.State(_initial_datasets)
    collection_status = gr.Markdown(
        f"Loaded **{len(_initial_datasets)}** datasets from [{COLLECTION_SLUG}]({COLLECTION_URL})."
    )
    with gr.Row():
        refresh_btn = gr.Button("Refresh Collection", variant="secondary")
        filter_box = gr.Textbox(label="Filter", placeholder="Example: tablebuss, scan, 02012026")
    dataset_checkboxes = gr.CheckboxGroup(
        label="Select Datasets",
        choices=_initial_datasets,
        interactive=True,
    )
    with gr.Row():
        select_all_btn = gr.Button("Select All", size="sm")
        select_matching_btn = gr.Button("Select Matching", size="sm")
        deselect_all_btn = gr.Button("Deselect All", size="sm")
    fetch_btn = gr.Button("Fetch Statistics", variant="primary")
    stats_output = gr.Markdown(
        label="Dataset Statistics",
        value="Select datasets and click **Fetch Statistics**.",
    )

    # Event wiring: each button maps to one handler defined above.
    refresh_btn.click(
        load_collection_datasets,
        outputs=[dataset_checkboxes, current_choices, collection_status],
    )
    select_all_btn.click(
        lambda choices: gr.update(value=choices),
        inputs=current_choices,
        outputs=dataset_checkboxes,
    )
    select_matching_btn.click(
        select_matching,
        inputs=[filter_box, current_choices],
        outputs=dataset_checkboxes,
    )
    deselect_all_btn.click(
        lambda: gr.update(value=[]),
        outputs=dataset_checkboxes,
    )
    fetch_btn.click(
        fetch_stats_for_selected,
        inputs=dataset_checkboxes,
        outputs=stats_output,
    )

if __name__ == "__main__":
    demo.launch()