#!/usr/bin/env python3
"""Pull remote HF training artifacts back into the local PolyGuard repo."""

from __future__ import annotations

import argparse
import json
from pathlib import Path
import shutil

from huggingface_hub import snapshot_download

ROOT = Path(__file__).resolve().parents[1]

# Repo-relative directories that are downloaded from the artifact repo and then
# copied to the same relative location under ROOT. Kept in one place so the
# download filter and the copy loop cannot drift apart.
_ARTIFACT_DIRS = (
    "outputs/reports",
    "outputs/plots",
    "docs/results",
    "checkpoints/sft_adapter",
    "checkpoints/grpo_adapter",
    "checkpoints/merged",
)

# SFT backends that count as a TRL run.
_TRL_BACKENDS = ("trl_unsloth", "trl_transformers")

# File suffixes (compared lower-cased) mirrored into docs/results.
_MIRRORED_SUFFIXES = {".json", ".txt", ".png"}


def parse_args() -> argparse.Namespace:
    """Parse the command-line options.

    Returns:
        Namespace with ``artifact_repo_id`` (repo to download from) and
        ``cache_dir`` (download cache location).
    """
    parser = argparse.ArgumentParser(description="Download PolyGuard remote training artifacts.")
    parser.add_argument(
        "--artifact-repo-id",
        default="TheJackBright/polyguard-openenv-training-artifacts",
    )
    parser.add_argument("--cache-dir", default="/tmp/polyguard-training-artifacts")
    return parser.parse_args()


def _copy_tree(src: Path, dst: Path) -> None:
    """Copy directory ``src`` onto ``dst``, merging into an existing ``dst``.

    Does nothing when ``src`` does not exist, so directories absent from the
    downloaded snapshot are skipped.
    """
    if src.exists():
        dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copytree(src, dst, dirs_exist_ok=True)


def _mirror_docs_results() -> None:
    """Copy report and plot files into ``docs/results`` as a flat directory.

    Only regular files directly inside ``outputs/reports`` and
    ``outputs/plots`` with a ``.json``, ``.txt`` or ``.png`` suffix are copied.
    Files are written under their bare name, so a plot with the same name as a
    report replaces it (plots are processed second).
    """
    docs = ROOT / "docs" / "results"
    docs.mkdir(parents=True, exist_ok=True)
    for directory in (ROOT / "outputs" / "reports", ROOT / "outputs" / "plots"):
        if not directory.exists():
            continue
        for path in directory.iterdir():
            if path.is_file() and path.suffix.lower() in _MIRRORED_SUFFIXES:
                shutil.copy2(path, docs / path.name)


def _read_report(path: Path, failures: list[str]) -> dict:
    """Load a JSON report, recording a failure instead of raising on bad content.

    Args:
        path: Report file to read.
        failures: List that receives a message when the file exists but cannot
            be read, is not valid JSON, or does not hold a JSON object.

    Returns:
        The parsed object, or ``{}`` when the file is missing or unusable. A
        missing file adds no message here; the caller's field checks cover it.
    """
    if not path.exists():
        return {}
    try:
        payload = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError) as exc:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        failures.append(f"{path.name} could not be read as JSON: {exc}")
        return {}
    if not isinstance(payload, dict):
        # The checks below rely on .get(); a list/str/number would crash them.
        failures.append(f"{path.name} is not a JSON object")
        return {}
    return payload


def _as_int(value: object) -> int:
    """Convert ``value`` to ``int``; falsy or unconvertible values become 0."""
    try:
        return int(value or 0)
    except (TypeError, ValueError, OverflowError):
        # e.g. "n/a", a list, NaN or Infinity in the report.
        return 0


def _assert_remote_training_ready() -> None:
    """Verify that the pulled reports describe a successful remote training run.

    Reads ``sft_trl_run.json``, ``grpo_trl_run.json`` and
    ``postsave_inference.json`` from ``outputs/reports`` under ROOT and
    collects every failed check before exiting, so one run reports them all.

    Raises:
        SystemExit: With ``artifact_checks_failed:<reasons>`` when any check
            fails.
    """
    reports = ROOT / "outputs" / "reports"
    failures: list[str] = []

    sft = _read_report(reports / "sft_trl_run.json", failures)
    if sft.get("status") != "ok":
        failures.append("SFT status is not ok")
    # Tuple membership: stays safe if the JSON value is unhashable (list/dict).
    if sft.get("backend") not in _TRL_BACKENDS:
        failures.append("SFT backend is not TRL")
    if not sft.get("artifact_path"):
        failures.append("SFT artifact path is empty")
    if _as_int(sft.get("examples_used", 0)) <= 0:
        failures.append("SFT examples_used is zero")

    grpo = _read_report(reports / "grpo_trl_run.json", failures)
    if grpo.get("status") != "ok":
        failures.append("GRPO status is not ok")
    if not grpo.get("artifact_path"):
        failures.append("GRPO artifact path is empty")

    # Only an explicit fallback marker fails this check; a missing post-save
    # report is accepted.
    postsave = _read_report(reports / "postsave_inference.json", failures)
    if postsave.get("model_source") == "fallback_policy":
        failures.append("post-save inference still uses fallback policy")

    if failures:
        raise SystemExit("artifact_checks_failed:" + "; ".join(failures))


def main() -> None:
    """Download the artifact snapshot, copy it into the repo and validate it."""
    args = parse_args()
    snapshot = Path(
        snapshot_download(
            repo_id=args.artifact_repo_id,
            repo_type="model",
            cache_dir=args.cache_dir,
            allow_patterns=[f"{rel}/*" for rel in _ARTIFACT_DIRS],
        )
    )
    for rel in _ARTIFACT_DIRS:
        _copy_tree(snapshot / rel, ROOT / rel)
    _mirror_docs_results()
    _assert_remote_training_ready()
    print(f"artifacts_pulled_from={args.artifact_repo_id}")


if __name__ == "__main__":
    main()