#!/usr/bin/env python3
"""
Download only the files needed for training (defined by filtered_index.json)
from HuggingFaceVLA/community_dataset_v3.

Rate-limit-aware greedy scheduler: downloads small files when we have rate
limit headroom, swaps to large files (videos) when approaching the limit to
keep bandwidth busy while the window recovers. Goal: never actually hit a 429.
"""
import argparse
import json
import os
import sys
import time
import threading
from collections import defaultdict, deque
from concurrent.futures import ThreadPoolExecutor, as_completed, Future

from huggingface_hub import hf_hub_download

RATE_LIMIT = 100   # 1000 actual / ~10 API calls per hf_hub_download
RATE_WINDOW = 300  # 5 minutes

# Attempts per file in worker(); only 429s are retried (with backoff).
MAX_ATTEMPTS = 10


class RateLimitTracker:
    """Sliding-window request counter.

    Thread-safe for individual calls. NOTE(review): there is a small
    check-then-act race between headroom()/wait_if_needed() and record()
    across workers, so the window count can overshoot by up to the worker
    count; RATE_LIMIT already carries a ~10x safety margin, so this is
    acceptable by design.
    """

    def __init__(self):
        self.lock = threading.Lock()
        self.timestamps: deque[float] = deque()

    def record(self) -> None:
        """Record one request at the current time."""
        now = time.time()
        with self.lock:
            self.timestamps.append(now)
            self._prune(now)

    def count(self) -> int:
        """Number of requests inside the current window."""
        now = time.time()
        with self.lock:
            self._prune(now)
            return len(self.timestamps)

    def headroom(self) -> int:
        """How many more requests we can make in this window."""
        return max(0, RATE_LIMIT - self.count())

    def wait_if_needed(self) -> None:
        """If we've exhausted the window, sleep until oldest request expires."""
        while self.headroom() <= 0:
            wait = 0.0  # always bound before use, even if timestamps emptied
            with self.lock:
                if self.timestamps:
                    # +1s slack so the oldest entry has definitely expired.
                    wait = RATE_WINDOW - (time.time() - self.timestamps[0]) + 1
                    if wait > 0:
                        print(f" Rate limit reached, waiting {wait:.0f}s...", flush=True)
            # Sleep with the lock released so other threads can prune/record.
            if wait > 0:
                time.sleep(wait)

    def _prune(self, now: float) -> None:
        # Drop timestamps older than the window. Caller must hold self.lock.
        cutoff = now - RATE_WINDOW
        while self.timestamps and self.timestamps[0] < cutoff:
            self.timestamps.popleft()


class FileQueue:
    """Thread-safe queue that serves small or large files on demand."""

    def __init__(self, small_files: list[str], large_files: list[str]):
        self.lock = threading.Lock()
        self.small = deque(small_files)
        self.large = deque(large_files)
        self.total = len(small_files) + len(large_files)

    def get(self, prefer_small: bool) -> str | None:
        """Pop the next filepath, honoring the caller's size preference.

        Falls back to the other size class when the preferred one is empty;
        returns None when both queues are exhausted.
        """
        with self.lock:
            if prefer_small and self.small:
                return self.small.popleft()
            elif self.large:
                return self.large.popleft()
            elif self.small:
                return self.small.popleft()
            return None

    def remaining(self) -> int:
        """Total files still queued."""
        with self.lock:
            return len(self.small) + len(self.large)

    def small_remaining(self) -> int:
        """Small (metadata/parquet) files still queued."""
        with self.lock:
            return len(self.small)

    def large_remaining(self) -> int:
        """Large (video) files still queued."""
        with self.lock:
            return len(self.large)


def build_file_lists(index_path: str, output_dir: str) -> tuple[list[str], list[str], int]:
    """Returns (small_files, large_files, skipped_count) from filtered_index.json.
    Skips files already on disk."""
    with open(index_path) as f:
        index = json.load(f)

    # Group requested episode indices by dataset name.
    datasets = defaultdict(list)
    for ep in index["episodes"]:
        datasets[ep["dataset"]].append(ep["episode_index"])

    small = []
    large = []
    skipped = 0

    def add_if_missing(filepath, target_list):
        # Only queue files not already present under output_dir.
        nonlocal skipped
        if os.path.exists(os.path.join(output_dir, filepath)):
            skipped += 1
        else:
            target_list.append(filepath)

    for dataset_name, episode_indices in datasets.items():
        prefix = dataset_name
        # Per-dataset metadata (small).
        add_if_missing(f"{prefix}/meta/info.json", small)
        add_if_missing(f"{prefix}/meta/tasks.jsonl", small)
        add_if_missing(f"{prefix}/meta/episodes.jsonl", small)
        for ep_idx in episode_indices:
            ep_str = f"episode_{ep_idx:06d}"
            # Parquet data is small; the two camera streams are large.
            add_if_missing(f"{prefix}/data/chunk-000/{ep_str}.parquet", small)
            add_if_missing(f"{prefix}/videos/chunk-000/observation.images.image/{ep_str}.mp4", large)
            add_if_missing(f"{prefix}/videos/chunk-000/observation.images.image2/{ep_str}.mp4", large)

    return small, large, skipped


# Shared state
tracker = RateLimitTracker()
queue: FileQueue | None = None  # assigned in main() before workers start
stats_lock = threading.Lock()
downloaded = 0
total_bytes = 0
failed: list[tuple[str, str]] = []
start_time = 0.0

# When headroom drops below this, prefer large files
HEADROOM_THRESHOLD = 50


def worker(output_dir, token):
    """Worker loop: grab a file based on rate limit state, download it, repeat."""
    global downloaded, total_bytes
    while True:
        headroom = tracker.headroom()
        # Plenty of headroom -> burn cheap API calls on small files;
        # low headroom -> switch to long video downloads while window recovers.
        prefer_small = headroom > HEADROOM_THRESHOLD
        filepath = queue.get(prefer_small)
        if filepath is None:
            return  # queue drained; worker exits

        for attempt in range(MAX_ATTEMPTS):
            tracker.wait_if_needed()
            tracker.record()
            try:
                path = hf_hub_download(
                    repo_id="HuggingFaceVLA/community_dataset_v3",
                    repo_type="dataset",
                    filename=filepath,
                    local_dir=output_dir,
                    token=token,
                )
                size = os.path.getsize(path)
                with stats_lock:
                    downloaded += 1
                    total_bytes += size
                    _maybe_log()
                break
            except Exception as e:
                # hf_hub errors don't expose a status code uniformly; match
                # on the message. Retry only 429s, with linear backoff.
                if "429" in str(e) and attempt < MAX_ATTEMPTS - 1:
                    time.sleep(30 * (attempt + 1))
                    continue
                with stats_lock:
                    failed.append((filepath, str(e)))
                    _maybe_log()
                break


def _maybe_log():
    """Log progress every 100 files. Must be called with stats_lock held."""
    total = downloaded + len(failed)
    if total % 100 == 0 and total > 0:
        elapsed = time.time() - start_time
        rate = total / elapsed if elapsed > 0 else 0
        mb_s = (total_bytes / 1024 / 1024) / elapsed if elapsed > 0 else 0
        gb_done = total_bytes / 1024 / 1024 / 1024
        headroom = tracker.headroom()
        remaining = queue.remaining()
        est_min = remaining / rate / 60 if rate > 0 else 0
        print(f" [{total}/{queue.total}] {gb_done:.1f}GB, "
              f"{mb_s:.0f} MB/s, {rate:.1f} files/s, "
              f"headroom: {headroom}/{RATE_LIMIT}, "
              f"queued: {queue.small_remaining()}s+{queue.large_remaining()}L, "
              f"~{est_min:.0f}min left", flush=True)


def main():
    global queue, start_time
    parser = argparse.ArgumentParser(description="Download training subset from community_dataset_v3")
    parser.add_argument("--index", type=str, default="filtered_index.json")
    parser.add_argument("--output", type=str, default="/data/community_dataset_v3")
    parser.add_argument("--token", type=str, default=os.environ.get("HF_TOKEN"))
    parser.add_argument("--workers", type=int, default=8)
    args = parser.parse_args()

    if not args.token:
        print("ERROR: Set HF_TOKEN or pass --token")
        return

    small, large, skipped = build_file_lists(args.index, args.output)
    queue = FileQueue(small, large)

    print(f"Files to download: {queue.total} ({skipped} already on disk, skipped)")
    print(f" Small (metadata+parquets): {len(small)}")
    print(f" Large (videos): {len(large)}")
    print(f" Workers: {args.workers}")
    print(f" Rate limit: {RATE_LIMIT}/{RATE_WINDOW}s, "
          f"swap to large files at <{HEADROOM_THRESHOLD} headroom")
    print()

    start_time = time.time()
    with ThreadPoolExecutor(max_workers=args.workers) as pool:
        futures = [pool.submit(worker, args.output, args.token) for _ in range(args.workers)]
        # Propagate any worker exception instead of swallowing it.
        for f in futures:
            f.result()

    elapsed = time.time() - start_time
    gb_total = total_bytes / 1024 / 1024 / 1024
    print(f"\nDone in {elapsed/60:.1f} min: {downloaded} files, "
          f"{gb_total:.1f}GB, {len(failed)} failed")

    if failed:
        print("Failed files:")
        for f, err in failed[:20]:
            print(f" {f}: {err}")
        if len(failed) > 20:
            print(f" ... and {len(failed) - 20} more")
        sys.exit(1)


if __name__ == "__main__":
    main()