# dataset-stats / app.py — Hugging Face Space entry point.
# Last change (hqfang, commit 9208e38, verified): scope the stats viewer to
# the MolmoAct2 BimanualYAM collection.
import re
from collections import defaultdict
from datetime import datetime
from functools import lru_cache
from typing import List
import gradio as gr
from huggingface_hub import HfApi
from get_dataset_stats import get_dataset_stats
# Hugging Face collection that scopes which dataset repos this Space may show.
COLLECTION_SLUG = "allenai/molmoact2-bimanualyam-dataset"
# Public landing page for the collection; linked throughout the UI Markdown.
COLLECTION_URL = f"https://huggingface.co/collections/{COLLECTION_SLUG}"
@lru_cache(maxsize=1)
def get_collection_datasets() -> List[str]:
    """Return public dataset repos from the MolmoAct2-BimanualYAM collection.

    Queries the Hub anonymously (token=False) and keeps only items of type
    "dataset", deduplicated while preserving collection order. Results are
    cached for the process lifetime; call ``.cache_clear()`` to force a
    refresh (the Refresh button does exactly that).

    Returns:
        Ordered list of unique ``"namespace/name"`` dataset repo ids.
    """
    api = HfApi()
    collection = api.get_collection(COLLECTION_SLUG, token=False)
    dataset_ids = [item.item_id for item in collection.items if item.item_type == "dataset"]
    # dict.fromkeys keeps first-seen order — simpler than a manual seen-set loop.
    return list(dict.fromkeys(dataset_ids))
@lru_cache(maxsize=2048)
def get_cached_dataset_stats(repo_id: str):
    """Cache per-repo stats in the Space process to make repeated UI use cheaper.

    Delegates to ``get_dataset_stats`` with ``hf_token=None`` (anonymous
    access, so only public repos resolve). The bounded LRU cache (2048
    entries) has no invalidation hook: stale stats persist until eviction
    or a process restart.
    """
    return get_dataset_stats(repo_id, hf_token=None)
def get_allowed_dataset_set() -> set[str]:
    """Return the collection's dataset repo ids as a set for O(1) membership tests."""
    collection_repos = get_collection_datasets()
    return set(collection_repos)
def format_duration(seconds):
    """Format a duration in seconds as 'Hh Mm Ss', dropping leading zero units."""
    hours, remainder = divmod(seconds, 3600)
    minutes, secs = divmod(remainder, 60)
    hours, minutes, secs = int(hours), int(minutes), int(secs)
    if hours > 0:
        return f"{hours}h {minutes}m {secs}s"
    if minutes > 0:
        return f"{minutes}m {secs}s"
    return f"{secs}s"
def _stats_duration_seconds(stats) -> float:
    """Return the recording duration (seconds) implied by a stats dict, or 0.0.

    Reads ``info_metadata.total_frames / info_metadata.fps``. Falls back to
    30 fps when fps is missing or zero (avoids ZeroDivisionError on
    malformed metadata; 30 matches the default used throughout this file).
    """
    info_meta = stats.get("info_metadata", {}) or {}
    total_frames = info_meta.get("total_frames")
    if not total_frames:
        return 0.0
    fps = info_meta.get("fps") or 30  # `or` also guards fps == 0
    return total_frames / fps


def fetch_stats_for_selected(selected_datasets: List[str], progress=gr.Progress()):
    """Fetch statistics for selected collection datasets.

    Renders a Markdown report: grand totals, per-date groups for v3.0 repos
    (an 8-digit DDMMYYYY date is parsed from the repo id), an "Other
    Formats" section for non-v3 repos, and any per-repo errors. Selections
    containing repos outside the public collection are rejected outright.

    Args:
        selected_datasets: Repo ids chosen in the CheckboxGroup.
        progress: Gradio progress tracker (injected by the UI).

    Returns:
        A Markdown string for the stats output panel.
    """
    if not selected_datasets:
        return "Please select at least one dataset."
    allowed = get_allowed_dataset_set()
    # Deduplicate while preserving first-seen order.
    selected_datasets = list(dict.fromkeys(selected_datasets))
    outside_collection = [repo_id for repo_id in selected_datasets if repo_id not in allowed]
    if outside_collection:
        blocked = "\n".join(f"- `{repo_id}`" for repo_id in outside_collection[:20])
        extra = "\n- ..." if len(outside_collection) > 20 else ""
        return (
            "Selection contains repos outside the public "
            f"[{COLLECTION_SLUG}]({COLLECTION_URL}) collection.\n\n"
            f"{blocked}{extra}"
        )
    total_episodes = 0
    v3_by_date = defaultdict(list)
    non_v3_results = []
    errors = []
    for i, repo_id in enumerate(selected_datasets):
        try:
            progress((i + 1) / len(selected_datasets), desc=f"Processing {repo_id}...")
            stats = get_cached_dataset_stats(repo_id)
            if stats.get("error"):
                errors.append(f"Error for `{repo_id}`: {stats['error']}")
                continue
            episodes = stats["total_episodes"]
            total_episodes += episodes
            if stats.get("format_version") == "v3.0":
                # Repo ids embed an 8-digit capture date (DDMMYYYY) after a "/".
                date_match = re.search(r"/(\d{8})", repo_id)
                if date_match:
                    date_str = date_match.group(1)
                    try:
                        date_obj = datetime.strptime(date_str, "%d%m%Y")
                        date_key = date_obj.strftime("%Y-%m-%d")  # lexicographically sortable
                        date_display = date_obj.strftime("%B %d, %Y")
                    except ValueError:
                        # 8 digits but not a valid DDMMYYYY date: keep raw string.
                        date_key = date_str
                        date_display = date_str
                else:
                    date_key = "unknown"
                    date_display = "Unknown Date"
                v3_by_date[date_key].append(
                    {
                        "repo_id": repo_id,
                        "episodes": episodes,
                        "date_display": date_display,
                        "stats": stats,
                    }
                )
            else:
                non_v3_results.append(
                    {
                        "repo_id": repo_id,
                        "episodes": episodes,
                        "stats": stats,
                    }
                )
        except Exception as e:
            # Best-effort: one failing repo must not abort the whole report.
            errors.append(f"Error for `{repo_id}`: {e}")
    # Grand total duration across every successfully processed repo.
    all_results = [d for group in v3_by_date.values() for d in group] + non_v3_results
    total_duration_seconds = sum(_stats_duration_seconds(d["stats"]) for d in all_results)
    duration_display = f" • {format_duration(total_duration_seconds)}" if total_duration_seconds > 0 else ""
    output = [
        f"## Total Episodes: {total_episodes}{duration_display}",
        f"Selected datasets: **{len(selected_datasets)}** / **{len(allowed)}**",
    ]
    if v3_by_date:
        # Newest date first; the "unknown" bucket always renders last.
        sorted_dates = sorted([key for key in v3_by_date if key != "unknown"], reverse=True)
        if "unknown" in v3_by_date:
            sorted_dates.append("unknown")
        for date_key in sorted_dates:
            datasets = v3_by_date[date_key]
            date_display = datasets[0]["date_display"]
            date_total_episodes = sum(dataset["episodes"] for dataset in datasets)
            date_total_seconds = sum(_stats_duration_seconds(dataset["stats"]) for dataset in datasets)
            output.append(
                f"\n**{date_display}** — Total: **{date_total_episodes} episodes**"
                f" • {format_duration(date_total_seconds)}"
            )
            for dataset in sorted(datasets, key=lambda item: item["repo_id"]):
                repo_name = dataset["repo_id"].split("/")[-1]
                seconds = _stats_duration_seconds(dataset["stats"])
                duration_str = f" • {format_duration(seconds)}" if seconds > 0 else ""
                output.append(f"- `{repo_name}`: **{dataset['episodes']} episodes**{duration_str}")
    if non_v3_results:
        output.append("\n### Other Formats")
        for dataset in non_v3_results:
            seconds = _stats_duration_seconds(dataset["stats"])
            duration_str = f" • {format_duration(seconds)}" if seconds > 0 else ""
            output.append(f"- `{dataset['repo_id']}`: **{dataset['episodes']} episodes**{duration_str}")
    if errors:
        output.append("\n### Errors")
        output.extend(f"- {error}" for error in errors)
    return "\n".join(output)
def load_collection_datasets():
    """Re-fetch the collection and return UI updates.

    Returns a triple for Gradio: checkbox update (new choices, selection
    cleared), the refreshed choice list for state, and a status message.
    """
    get_collection_datasets.cache_clear()  # bypass the process-level cache
    refreshed = get_collection_datasets()
    checkbox_update = gr.update(choices=refreshed, value=[])
    status = f"Loaded **{len(refreshed)}** datasets from [{COLLECTION_SLUG}]({COLLECTION_URL})."
    return [checkbox_update, refreshed, status]
def select_matching(filter_text: str, choices: List[str]):
    """Select every choice whose repo id contains the filter text (case-insensitive).

    A blank or empty filter selects everything currently in ``choices``.
    """
    available = list(choices) if choices else []
    query = (filter_text or "").strip().lower()
    if not query:
        return gr.update(value=available)
    matched = [repo_id for repo_id in available if query in repo_id.lower()]
    return gr.update(value=matched)
# Fetch once at import so the Space opens with the collection pre-loaded
# (one network call at startup; Refresh re-queries later).
_initial_datasets = get_collection_datasets()
with gr.Blocks(title="MolmoAct2-BimanualYAM Dataset Stats") as demo:
    gr.Markdown(
        "# MolmoAct2-BimanualYAM Dataset Stats\n"
        f"Public stats viewer for datasets in [{COLLECTION_SLUG}]({COLLECTION_URL})."
    )
    # Mirrors the checkbox choices so the select-all / select-matching
    # callbacks can read the current list without another Hub request.
    current_choices = gr.State(_initial_datasets)
    collection_status = gr.Markdown(
        f"Loaded **{len(_initial_datasets)}** datasets from [{COLLECTION_SLUG}]({COLLECTION_URL})."
    )
    with gr.Row():
        refresh_btn = gr.Button("Refresh Collection", variant="secondary")
        filter_box = gr.Textbox(label="Filter", placeholder="Example: tablebuss, scan, 02012026")
    dataset_checkboxes = gr.CheckboxGroup(
        label="Select Datasets",
        choices=_initial_datasets,
        interactive=True,
    )
    with gr.Row():
        select_all_btn = gr.Button("Select All", size="sm")
        select_matching_btn = gr.Button("Select Matching", size="sm")
        deselect_all_btn = gr.Button("Deselect All", size="sm")
    fetch_btn = gr.Button("Fetch Statistics", variant="primary")
    stats_output = gr.Markdown(
        label="Dataset Statistics",
        value="Select datasets and click **Fetch Statistics**.",
    )
    # Refresh re-queries the Hub, repopulates choices, and clears selection.
    refresh_btn.click(
        load_collection_datasets,
        outputs=[dataset_checkboxes, current_choices, collection_status],
    )
    # Select-all pulls the full list from state rather than the Hub.
    select_all_btn.click(
        lambda choices: gr.update(value=choices),
        inputs=current_choices,
        outputs=dataset_checkboxes,
    )
    select_matching_btn.click(
        select_matching,
        inputs=[filter_box, current_choices],
        outputs=dataset_checkboxes,
    )
    deselect_all_btn.click(
        lambda: gr.update(value=[]),
        outputs=dataset_checkboxes,
    )
    fetch_btn.click(
        fetch_stats_for_selected,
        inputs=dataset_checkboxes,
        outputs=stats_output,
    )
if __name__ == "__main__":
    demo.launch()