# Provenance (copied from the Hugging Face file viewer):
# author adithya9903, commit fd0c71a "Deploy PolyGuard HF training Space" (verified).
#!/usr/bin/env python3
"""Pull one model-sweep run from the HF artifact repo when it is available."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import shutil
from huggingface_hub import HfApi, snapshot_download
# Repository root: two levels above this file (the script presumably lives in a
# subdirectory such as ``scripts/`` -- confirm against the repo layout).
ROOT = Path(__file__).resolve().parents[1]
# Sweep run fetched when ``--run-id`` is not supplied.
DEFAULT_RUN_ID = "qwen-qwen2-5-0-5b-instruct"


def parse_args() -> argparse.Namespace:
    """Parse the command line for the sweep-run downloader.

    Options:
        --artifact-repo-id  HF model repo holding the uploaded sweep artifacts.
        --run-id            Sweep run to fetch (defaults to ``DEFAULT_RUN_ID``).
        --cache-dir         Local cache directory handed to ``snapshot_download``.
        --allow-missing     Continue instead of exiting when the run is absent.
    """
    cli = argparse.ArgumentParser(description="Download one PolyGuard HF sweep run.")
    # Insertion order is preserved, so options register (and appear in --help)
    # in the same order as before.
    valued_options = {
        "--artifact-repo-id": "TheJackBright/polyguard-openenv-training-full-artifacts",
        "--run-id": DEFAULT_RUN_ID,
        "--cache-dir": "/tmp/polyguard-training-artifacts",
    }
    for flag, fallback in valued_options.items():
        cli.add_argument(flag, default=fallback)
    cli.add_argument("--allow-missing", action="store_true")
    namespace = cli.parse_args()
    return namespace
def _copy_tree(source: Path, target: Path) -> bool:
if not source.exists():
return False
target.parent.mkdir(parents=True, exist_ok=True)
shutil.copytree(source, target, dirs_exist_ok=True)
return True
# Repo-relative JSON reports shared by all sweep runs; each is copied if present.
_SHARED_REPORT_FILES = (
    "outputs/reports/hf_sweep_summary.json",
    "outputs/reports/anti_hacking_overfit_report.json",
    "docs/results/hf_sweep_summary.json",
    "docs/results/anti_hacking_overfit_report.json",
)
# Shared plot directory mirrored wholesale into ROOT.
_SHARED_PLOT_DIR = "outputs/plots"
# Directory whose top-level ``*.png`` files are downloaded and copied into ROOT.
_RESULTS_PNG_DIR = "docs/results"


def _run_dirs(run_id: str) -> list[str]:
    """Return the repo-relative directories that hold artifacts for ``run_id``."""
    return [
        f"outputs/reports/sweeps/{run_id}",
        f"docs/results/sweeps/{run_id}",
        f"checkpoints/sweeps/{run_id}",
    ]


def _allow_patterns(run_id: str) -> list[str]:
    """Build the ``snapshot_download`` filter: one run plus the shared outputs."""
    return [
        *(f"{rel}/**" for rel in _run_dirs(run_id)),
        f"{_SHARED_PLOT_DIR}/**",
        *_SHARED_REPORT_FILES,
        # NOTE: allow_patterns are fnmatch-style, where ``*`` also matches ``/``,
        # so this may pull PNGs from nested sweep folders as well; only the
        # top-level ones are copied by _copy_shared_files.
        f"{_RESULTS_PNG_DIR}/*.png",
    ]


def _copy_file(source: Path, target: Path) -> None:
    """Copy one file (with metadata), creating ``target``'s parent if needed."""
    target.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(source, target)


def _copy_shared_files(snapshot: Path) -> list[str]:
    """Copy the shared JSON reports and top-level result PNGs into ROOT.

    Returns the repo-relative paths that were actually copied.
    """
    copied: list[str] = []
    for rel in _SHARED_REPORT_FILES:
        source = snapshot / rel
        if source.exists():
            _copy_file(source, ROOT / rel)
            copied.append(rel)
    # These PNGs were requested in the download patterns but previously never
    # left the cache; copy them alongside the docs/results JSON reports.
    # Path.glob on a missing directory yields nothing; sorted() keeps the
    # reported order deterministic.
    for source in sorted((snapshot / _RESULTS_PNG_DIR).glob("*.png")):
        if source.is_file():
            rel = f"{_RESULTS_PNG_DIR}/{source.name}"
            _copy_file(source, ROOT / rel)
            copied.append(rel)
    return copied


def main() -> None:
    """Fetch one sweep run's artifacts from the HF repo and copy them into ROOT.

    If the repo lists no files under the run's directories, exits via
    ``SystemExit("sweep_artifacts_not_uploaded_yet:<json>")``, unless
    ``--allow-missing`` is given. In that case the download still runs, but only
    the shared plots/reports can match.

    Prints a JSON summary with the run id, whether run-specific artifacts were
    found in the repo listing, and every repo-relative path copied.
    """
    args = parse_args()
    files = HfApi().list_repo_files(args.artifact_repo_id, repo_type="model")

    run_dirs = _run_dirs(args.run_id)
    run_prefixes = tuple(f"{rel}/" for rel in run_dirs)
    run_found = any(path.startswith(run_prefixes) for path in files)
    if not run_found and not args.allow_missing:
        raise SystemExit(
            "sweep_artifacts_not_uploaded_yet:"
            + json.dumps(
                {
                    "artifact_repo_id": args.artifact_repo_id,
                    "run_id": args.run_id,
                    "repo_file_count": len(files),
                    "available_files": files[:20],
                },
                ensure_ascii=True,
            )
        )

    snapshot = Path(
        snapshot_download(
            repo_id=args.artifact_repo_id,
            repo_type="model",
            cache_dir=args.cache_dir,
            allow_patterns=_allow_patterns(args.run_id),
        )
    )

    copied: list[str] = []
    for rel in (*run_dirs, _SHARED_PLOT_DIR):
        if _copy_tree(snapshot / rel, ROOT / rel):
            copied.append(rel)
    copied.extend(_copy_shared_files(snapshot))

    print(
        json.dumps(
            {
                "status": "ok",
                "run_id": args.run_id,
                # Lets --allow-missing callers tell a real hit from an empty run.
                "run_artifacts_found": run_found,
                "copied": copied,
            },
            ensure_ascii=True,
            indent=2,
        )
    )


if __name__ == "__main__":
    main()