# Source: pi05-so100-diverse / download_subset.py
# Last commit: a8eb6e5 (bot) — "Update lerobot to latest with SO100 rename_map fix"
#!/usr/bin/env python3
"""
Download only the files needed for training (defined by filtered_index.json)
from HuggingFaceVLA/community_dataset_v3.
Rate-limit-aware greedy scheduler: downloads small files when we have rate limit
headroom, swaps to large files (videos) when approaching the limit to keep
bandwidth busy while the window recovers. Goal: never actually hit a 429.
"""
import argparse
import json
import os
import sys
import time
import threading
from collections import defaultdict, deque
from concurrent.futures import ThreadPoolExecutor, as_completed, Future
from huggingface_hub import hf_hub_download
# HF Hub budget: the real limit is ~1000 requests per window, and each
# hf_hub_download issues roughly 10 API calls, so cap at 100 downloads/window.
RATE_LIMIT = 100 # 1000 actual / ~10 API calls per hf_hub_download
RATE_WINDOW = 300 # 5 minutes
class RateLimitTracker:
"""Sliding window request counter."""
def __init__(self):
self.lock = threading.Lock()
self.timestamps: deque[float] = deque()
def record(self):
now = time.time()
with self.lock:
self.timestamps.append(now)
self._prune(now)
def count(self) -> int:
now = time.time()
with self.lock:
self._prune(now)
return len(self.timestamps)
def headroom(self) -> int:
"""How many more requests we can make in this window."""
return max(0, RATE_LIMIT - self.count())
def wait_if_needed(self):
"""If we've exhausted the window, sleep until oldest request expires."""
while self.headroom() <= 0:
with self.lock:
if self.timestamps:
wait = RATE_WINDOW - (time.time() - self.timestamps[0]) + 1
if wait > 0:
print(f" Rate limit reached, waiting {wait:.0f}s...", flush=True)
# Release lock while sleeping
else:
wait = 0
if wait > 0:
time.sleep(wait)
def _prune(self, now):
cutoff = now - RATE_WINDOW
while self.timestamps and self.timestamps[0] < cutoff:
self.timestamps.popleft()
class FileQueue:
"""Thread-safe queue that serves small or large files on demand."""
def __init__(self, small_files: list[str], large_files: list[str]):
self.lock = threading.Lock()
self.small = deque(small_files)
self.large = deque(large_files)
self.total = len(small_files) + len(large_files)
def get(self, prefer_small: bool) -> str | None:
with self.lock:
if prefer_small and self.small:
return self.small.popleft()
elif self.large:
return self.large.popleft()
elif self.small:
return self.small.popleft()
return None
def remaining(self) -> int:
with self.lock:
return len(self.small) + len(self.large)
def small_remaining(self) -> int:
with self.lock:
return len(self.small)
def large_remaining(self) -> int:
with self.lock:
return len(self.large)
def build_file_lists(index_path: str, output_dir: str) -> tuple[list[str], list[str], int]:
    """Derive the download work lists from filtered_index.json.

    Returns (small_files, large_files, skipped_count). Small files are the
    per-dataset metadata plus one parquet per episode; large files are the two
    camera videos per episode. Anything already present under output_dir is
    counted as skipped instead of queued.
    """
    with open(index_path) as fh:
        episodes = json.load(fh)["episodes"]

    # Group the requested episode indices by dataset name.
    by_dataset: defaultdict[str, list[int]] = defaultdict(list)
    for entry in episodes:
        by_dataset[entry["dataset"]].append(entry["episode_index"])

    small: list[str] = []
    large: list[str] = []
    skipped = 0

    def enqueue(rel_path: str, bucket: list[str]) -> None:
        # Queue the file unless a copy already exists on disk.
        nonlocal skipped
        if os.path.exists(os.path.join(output_dir, rel_path)):
            skipped += 1
        else:
            bucket.append(rel_path)

    for name, indices in by_dataset.items():
        for meta in ("info.json", "tasks.jsonl", "episodes.jsonl"):
            enqueue(f"{name}/meta/{meta}", small)
        for idx in indices:
            ep = f"episode_{idx:06d}"
            enqueue(f"{name}/data/chunk-000/{ep}.parquet", small)
            enqueue(f"{name}/videos/chunk-000/observation.images.image/{ep}.mp4", large)
            enqueue(f"{name}/videos/chunk-000/observation.images.image2/{ep}.mp4", large)

    return small, large, skipped
# Shared state (module-level so worker threads and _maybe_log can reach it)
tracker = RateLimitTracker()
# Built in main() once the file lists are known.
queue: FileQueue | None = None
stats_lock = threading.Lock()  # guards downloaded / total_bytes / failed
downloaded = 0  # count of successfully fetched files
total_bytes = 0  # bytes of successfully fetched files
failed = []  # (filepath, error-string) pairs
start_time = 0  # wall-clock start, set in main()
# When headroom drops below this, prefer large files
HEADROOM_THRESHOLD = 50
def worker(output_dir, token):
    """Worker loop: pull a file off the shared queue, download it, repeat.

    The small/large preference is re-evaluated for every file: with plenty of
    rate-limit headroom we take small files, and near the limit we switch to
    large videos so bandwidth stays busy while the request window recovers.
    Exits when the queue is drained.
    """
    global downloaded, total_bytes
    while True:
        prefer_small = tracker.headroom() > HEADROOM_THRESHOLD
        filepath = queue.get(prefer_small)
        if filepath is None:
            return  # queue drained
        for attempt in range(10):
            tracker.wait_if_needed()
            tracker.record()
            try:
                local_path = hf_hub_download(
                    repo_id="HuggingFaceVLA/community_dataset_v3",
                    repo_type="dataset",
                    filename=filepath,
                    local_dir=output_dir,
                    token=token,
                )
            except Exception as exc:
                # Progressive backoff on explicit 429s; any other error (or
                # the final attempt) records the file as failed and moves on.
                if "429" in str(exc) and attempt < 9:
                    time.sleep(30 * (attempt + 1))
                    continue
                with stats_lock:
                    failed.append((filepath, str(exc)))
                    _maybe_log()
                break
            size = os.path.getsize(local_path)
            with stats_lock:
                downloaded += 1
                total_bytes += size
                _maybe_log()
            break
def _maybe_log():
    """Print a progress line once every 100 finished (ok or failed) files.

    Must be called with stats_lock held: it reads the shared counters.
    """
    total = downloaded + len(failed)
    if total == 0 or total % 100 != 0:
        return
    elapsed = time.time() - start_time
    rate = total / elapsed if elapsed > 0 else 0
    mb_s = (total_bytes / 1024 / 1024) / elapsed if elapsed > 0 else 0
    gb_done = total_bytes / 1024 / 1024 / 1024
    est_min = queue.remaining() / rate / 60 if rate > 0 else 0
    print(f" [{total}/{queue.total}] {gb_done:.1f}GB, "
          f"{mb_s:.0f} MB/s, {rate:.1f} files/s, "
          f"headroom: {tracker.headroom()}/{RATE_LIMIT}, "
          f"queued: {queue.small_remaining()}s+{queue.large_remaining()}L, "
          f"~{est_min:.0f}min left", flush=True)
def main():
    """CLI entry point: build the download queue and fan out worker threads."""
    global queue, start_time
    parser = argparse.ArgumentParser(
        description="Download training subset from community_dataset_v3")
    parser.add_argument("--index", type=str, default="filtered_index.json")
    parser.add_argument("--output", type=str, default="/data/community_dataset_v3")
    parser.add_argument("--token", type=str, default=os.environ.get("HF_TOKEN"))
    parser.add_argument("--workers", type=int, default=8)
    args = parser.parse_args()

    if not args.token:
        print("ERROR: Set HF_TOKEN or pass --token")
        return

    small_files, large_files, skipped = build_file_lists(args.index, args.output)
    queue = FileQueue(small_files, large_files)

    print(f"Files to download: {queue.total} ({skipped} already on disk, skipped)")
    print(f" Small (metadata+parquets): {len(small_files)}")
    print(f" Large (videos): {len(large_files)}")
    print(f" Workers: {args.workers}")
    print(f" Rate limit: {RATE_LIMIT}/{RATE_WINDOW}s, "
          f"swap to large files at <{HEADROOM_THRESHOLD} headroom")
    print()

    start_time = time.time()
    with ThreadPoolExecutor(max_workers=args.workers) as pool:
        # Resolve every future so any worker exception propagates here.
        for fut in [pool.submit(worker, args.output, args.token)
                    for _ in range(args.workers)]:
            fut.result()

    elapsed = time.time() - start_time
    gb_total = total_bytes / 1024 / 1024 / 1024
    print(f"\nDone in {elapsed/60:.1f} min: {downloaded} files, "
          f"{gb_total:.1f}GB, {len(failed)} failed")
    if failed:
        print("Failed files:")
        for path, err in failed[:20]:
            print(f" {path}: {err}")
        if len(failed) > 20:
            print(f" ... and {len(failed) - 20} more")
        sys.exit(1)


if __name__ == "__main__":
    main()