bruAristimunha committed on
Commit
bb09849
·
1 Parent(s): 403c90e

Add --workers parallel push (ThreadPoolExecutor)

Browse files
Files changed (1) hide show
  1. scripts/push_metadata_stubs.py +47 -9
scripts/push_metadata_stubs.py CHANGED
@@ -33,11 +33,13 @@ from __future__ import annotations
33
 
34
  import argparse
35
  import ast
 
36
  import json
37
  import logging
38
  import os
39
  import sys
40
  import tempfile
 
41
  import time
42
  import urllib.error
43
  import urllib.request
@@ -754,6 +756,12 @@ def main(argv: list[str] | None = None) -> int:
754
  parser.add_argument("--dry-run-out", type=Path, default=Path("/tmp/stub_preview"))
755
  parser.add_argument("--private", action="store_true")
756
  parser.add_argument("--token", default=os.environ.get("HF_TOKEN"))
 
 
 
 
 
 
757
  parser.add_argument("-v", "--verbose", action="count", default=0)
758
  args = parser.parse_args(argv)
759
 
@@ -790,28 +798,58 @@ def main(argv: list[str] | None = None) -> int:
790
  logger.info("Dry-run output: %s", args.dry_run_out)
791
  return 0
792
 
793
- failed: list[tuple[str, str]] = []
794
  for r in rows:
795
  slug = str(r["dataset"]).lower()
796
  if slug in existing:
797
  logger.info("skipping %s (exists)", slug)
798
- continue
 
 
 
 
 
 
799
  try:
800
  ctx = _build_context(r)
801
- repo_id = _push_one(ctx, args)
802
- logger.info("pushed %s", repo_id)
803
  except Exception as exc: # noqa: BLE001
804
- logger.exception("failed %s", slug)
805
- failed.append((slug, str(exc)))
806
- # Be polite to the API and HF.
807
- time.sleep(0.25)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
808
 
809
  if failed:
810
  logger.error("%d failures:", len(failed))
811
  for slug, err in failed:
812
  logger.error(" %s — %s", slug, err)
813
  return 1
814
- logger.info("done — %d stubs processed", len(rows) - len(existing))
815
  return 0
816
 
817
 
 
33
 
34
  import argparse
35
  import ast
36
+ import concurrent.futures
37
  import json
38
  import logging
39
  import os
40
  import sys
41
  import tempfile
42
+ import threading
43
  import time
44
  import urllib.error
45
  import urllib.request
 
756
  parser.add_argument("--dry-run-out", type=Path, default=Path("/tmp/stub_preview"))
757
  parser.add_argument("--private", action="store_true")
758
  parser.add_argument("--token", default=os.environ.get("HF_TOKEN"))
759
+ parser.add_argument(
760
+ "--workers",
761
+ type=int,
762
+ default=1,
763
+ help="Parallel pushes (IO-bound — 8-16 is safe; higher risks rate-limits).",
764
+ )
765
  parser.add_argument("-v", "--verbose", action="count", default=0)
766
  args = parser.parse_args(argv)
767
 
 
798
  logger.info("Dry-run output: %s", args.dry_run_out)
799
  return 0
800
 
801
+ pending = [r for r in rows if str(r["dataset"]).lower() not in existing]
802
  for r in rows:
803
  slug = str(r["dataset"]).lower()
804
  if slug in existing:
805
  logger.info("skipping %s (exists)", slug)
806
+
807
+ failed: list[tuple[str, str]] = []
808
+ done = 0
809
+ done_lock = threading.Lock()
810
+
811
+ def _one(r: pd.Series) -> tuple[str, Exception | None]:
812
+ slug = str(r["dataset"]).lower()
813
  try:
814
  ctx = _build_context(r)
815
+ _push_one(ctx, args)
816
+ return slug, None
817
  except Exception as exc: # noqa: BLE001
818
+ return slug, exc
819
+
820
+ if args.workers and args.workers > 1:
821
+ logger.info(
822
+ "parallel push: %d workers, %d pending", args.workers, len(pending)
823
+ )
824
+ with concurrent.futures.ThreadPoolExecutor(max_workers=args.workers) as pool:
825
+ futures = {pool.submit(_one, r): r for r in pending}
826
+ for fut in concurrent.futures.as_completed(futures):
827
+ slug, err = fut.result()
828
+ if err is None:
829
+ with done_lock:
830
+ done += 1
831
+ logger.info("pushed EEGDash/%s (%d/%d)", slug, done, len(pending))
832
+ else:
833
+ logger.exception("failed %s", slug, exc_info=err)
834
+ failed.append((slug, str(err)))
835
+ else:
836
+ for r in pending:
837
+ slug, err = _one(r)
838
+ if err is None:
839
+ done += 1
840
+ logger.info("pushed EEGDash/%s (%d/%d)", slug, done, len(pending))
841
+ else:
842
+ logger.exception("failed %s", slug, exc_info=err)
843
+ failed.append((slug, str(err)))
844
+ # Serial mode only — parallel mode doesn't need the spacer.
845
+ time.sleep(0.15)
846
 
847
  if failed:
848
  logger.error("%d failures:", len(failed))
849
  for slug, err in failed:
850
  logger.error(" %s — %s", slug, err)
851
  return 1
852
+ logger.info("done — %d stubs processed (%d skipped)", done, len(existing))
853
  return 0
854
 
855