#!/usr/bin/env python3
"""Pull one model-sweep run from the HF artifact repo when it is available.

The script lists the artifact repo and checks that files for the requested run
were uploaded (unless ``--allow-missing``). It then downloads the run's
directories plus shared plots/summary files into the HF cache, and copies them
into the same repo-relative locations under ``ROOT``.
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path
import shutil

from huggingface_hub import HfApi, snapshot_download

# Parent of this script's directory; downloaded artifacts are copied under it.
ROOT = Path(__file__).resolve().parents[1]
DEFAULT_RUN_ID = "qwen-qwen2-5-0-5b-instruct"

# Per-run artifact directories (repo-relative); ``{run_id}`` is filled in.
_RUN_DIR_TEMPLATES = (
    "outputs/reports/sweeps/{run_id}",
    "docs/results/sweeps/{run_id}",
    "checkpoints/sweeps/{run_id}",
)
# Run-independent directories, downloaded and copied wholesale.
_SHARED_DIRS = ("outputs/plots",)
# Run-independent individual files, copied when present in the snapshot.
_SHARED_FILES = (
    "outputs/reports/hf_sweep_summary.json",
    "outputs/reports/anti_hacking_overfit_report.json",
    "docs/results/hf_sweep_summary.json",
    "docs/results/anti_hacking_overfit_report.json",
)
# Directory whose top-level ``*.png`` files are downloaded and copied.
_SHARED_PNG_DIR = "docs/results"


def parse_args() -> argparse.Namespace:
    """Parse command-line options for the sweep-run download."""
    parser = argparse.ArgumentParser(description="Download one PolyGuard HF sweep run.")
    parser.add_argument(
        "--artifact-repo-id",
        default="TheJackBright/polyguard-openenv-training-full-artifacts",
        help="Hugging Face model repo holding the training artifacts.",
    )
    parser.add_argument("--run-id", default=DEFAULT_RUN_ID, help="Sweep run identifier to pull.")
    parser.add_argument(
        "--cache-dir",
        default="/tmp/polyguard-training-artifacts",
        help="Cache directory passed to snapshot_download.",
    )
    parser.add_argument(
        "--allow-missing",
        action="store_true",
        help="Do not fail when the run has no uploaded files; still download shared files.",
    )
    return parser.parse_args()


def _copy_tree(source: Path, target: Path) -> bool:
    """Copy directory *source* onto *target*, merging into existing contents.

    Returns:
        True if the copy happened, False (copying nothing) when *source* is missing.
    """
    if not source.exists():
        return False
    target.parent.mkdir(parents=True, exist_ok=True)
    shutil.copytree(source, target, dirs_exist_ok=True)
    return True


def _copy_file(source: Path, target: Path) -> bool:
    """Copy file *source* to *target* with metadata (``shutil.copy2``).

    Returns:
        True if the copy happened, False (copying nothing) when *source* is missing.
    """
    if not source.exists():
        return False
    target.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(source, target)
    return True


def _run_dirs(run_id: str) -> list[str]:
    """Return the repo-relative per-run artifact directories for *run_id*."""
    return [template.format(run_id=run_id) for template in _RUN_DIR_TEMPLATES]


def main() -> None:
    """Check the run exists in the artifact repo, download it, and copy it under ROOT.

    Raises:
        SystemExit: if no file of the run is in the repo and ``--allow-missing``
            was not given. The message carries a JSON diagnostic payload.
    """
    args = parse_args()
    run_dirs = _run_dirs(args.run_id)

    files = HfApi().list_repo_files(args.artifact_repo_id, repo_type="model")
    # Trailing "/" so run "foo" does not match a sibling run such as "foo-bar".
    run_prefixes = tuple(f"{rel}/" for rel in run_dirs)
    matched = [path for path in files if path.startswith(run_prefixes)]
    if not matched and not args.allow_missing:
        raise SystemExit(
            "sweep_artifacts_not_uploaded_yet:"
            + json.dumps(
                {
                    "artifact_repo_id": args.artifact_repo_id,
                    "run_id": args.run_id,
                    "repo_file_count": len(files),
                    "available_files": files[:20],
                },
                ensure_ascii=True,
            )
        )

    # NOTE: allow_patterns use fnmatch-style matching, where "*" also matches "/".
    # So "docs/results/*.png" may also pull PNGs nested deeper (e.g. other runs'
    # sweep folders). Only the top-level PNGs are copied below.
    allow_patterns = [
        *(f"{rel}/**" for rel in run_dirs),
        *(f"{rel}/**" for rel in _SHARED_DIRS),
        *_SHARED_FILES,
        f"{_SHARED_PNG_DIR}/*.png",
    ]
    snapshot = Path(
        snapshot_download(
            repo_id=args.artifact_repo_id,
            repo_type="model",
            cache_dir=args.cache_dir,
            allow_patterns=allow_patterns,
        )
    )

    copied: list[str] = []
    for rel in [*run_dirs, *_SHARED_DIRS]:
        if _copy_tree(snapshot / rel, ROOT / rel):
            copied.append(rel)
    for rel in _SHARED_FILES:
        if _copy_file(snapshot / rel, ROOT / rel):
            copied.append(rel)
    # The top-level result PNGs are requested in allow_patterns, so copy them as well.
    png_dir = snapshot / _SHARED_PNG_DIR
    if png_dir.is_dir():
        for png in sorted(png_dir.glob("*.png")):
            if not png.is_file():
                continue
            rel = f"{_SHARED_PNG_DIR}/{png.name}"
            if _copy_file(png, ROOT / rel):
                copied.append(rel)

    print(
        json.dumps(
            {
                "status": "ok",
                "run_id": args.run_id,
                "copied": copied,
                # False only when --allow-missing let a not-yet-uploaded run through.
                "run_found": bool(matched),
            },
            ensure_ascii=True,
            indent=2,
        )
    )


if __name__ == "__main__":
    main()