# polyguard-openenv / scripts / pull_training_artifacts.py
# (Hugging Face Space page header residue preserved below as comments)
# TheJackBright's picture
# Deploy PolyGuard OpenEnv Space
# 877add7 verified
#!/usr/bin/env python3
"""Pull remote HF training artifacts back into the local PolyGuard repo."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import shutil
from huggingface_hub import snapshot_download
# Repository root: one directory above this script's scripts/ folder.
ROOT = Path(__file__).resolve().parents[1]
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the artifact pull.

    Args:
        argv: Optional explicit argument list (useful for testing);
            defaults to ``sys.argv[1:]`` when ``None``.

    Returns:
        Namespace with ``artifact_repo_id`` (HF repo to pull from) and
        ``cache_dir`` (local snapshot cache location).
    """
    parser = argparse.ArgumentParser(description="Download PolyGuard remote training artifacts.")
    parser.add_argument("--artifact-repo-id", default="TheJackBright/polyguard-openenv-training-artifacts")
    parser.add_argument("--cache-dir", default="/tmp/polyguard-training-artifacts")
    # Passing argv=None keeps the original sys.argv behavior for CLI use.
    return parser.parse_args(argv)
def _copy_tree(src: Path, dst: Path) -> None:
if src.exists():
dst.parent.mkdir(parents=True, exist_ok=True)
shutil.copytree(src, dst, dirs_exist_ok=True)
def _mirror_docs_results() -> None:
    """Copy report/plot artifacts from outputs/ into docs/results for publishing."""
    target = ROOT / "docs" / "results"
    target.mkdir(parents=True, exist_ok=True)
    # Only these artifact types are mirrored into the docs tree.
    wanted = {".json", ".txt", ".png"}
    sources = (ROOT / "outputs" / "reports", ROOT / "outputs" / "plots")
    for source in sources:
        if not source.exists():
            continue
        for entry in source.iterdir():
            if entry.is_file() and entry.suffix.lower() in wanted:
                shutil.copy2(entry, target / entry.name)
def _assert_remote_training_ready() -> None:
    """Validate the pulled training reports.

    Checks the SFT, GRPO, and post-save inference reports under
    ``outputs/reports`` and raises ``SystemExit`` with a combined
    message listing every failed check.
    """
    reports = ROOT / "outputs" / "reports"

    def _load(name: str) -> dict:
        # A missing report degrades to {} so the checks below fail loudly.
        report = reports / name
        if not report.exists():
            return {}
        return json.loads(report.read_text(encoding="utf-8"))

    problems: list[str] = []

    sft = _load("sft_trl_run.json")
    if sft.get("status") != "ok":
        problems.append("SFT status is not ok")
    if sft.get("backend") not in {"trl_unsloth", "trl_transformers"}:
        problems.append("SFT backend is not TRL")
    if not sft.get("artifact_path"):
        problems.append("SFT artifact path is empty")
    # `or 0` guards against an explicit null in the report JSON.
    if int(sft.get("examples_used", 0) or 0) <= 0:
        problems.append("SFT examples_used is zero")

    grpo = _load("grpo_trl_run.json")
    if grpo.get("status") != "ok":
        problems.append("GRPO status is not ok")
    if not grpo.get("artifact_path"):
        problems.append("GRPO artifact path is empty")

    postsave = _load("postsave_inference.json")
    if postsave.get("model_source") == "fallback_policy":
        problems.append("post-save inference still uses fallback policy")

    if problems:
        raise SystemExit("artifact_checks_failed:" + "; ".join(problems))
def main() -> None:
    """Download the remote artifact snapshot and mirror it into the repo.

    Pulls only the whitelisted paths, copies them over the local tree,
    refreshes docs/results, and then fails hard (SystemExit) if the
    pulled reports do not show a completed TRL training run.
    """
    args = parse_args()
    patterns = [
        "outputs/reports/*",
        "outputs/plots/*",
        "docs/results/*",
        "checkpoints/sft_adapter/*",
        "checkpoints/grpo_adapter/*",
        "checkpoints/merged/*",
    ]
    snapshot_root = Path(
        snapshot_download(
            repo_id=args.artifact_repo_id,
            repo_type="model",
            cache_dir=args.cache_dir,
            allow_patterns=patterns,
        )
    )
    mirrored = (
        "outputs/reports",
        "outputs/plots",
        "docs/results",
        "checkpoints/sft_adapter",
        "checkpoints/grpo_adapter",
        "checkpoints/merged",
    )
    for rel in mirrored:
        _copy_tree(snapshot_root / rel, ROOT / rel)
    _mirror_docs_results()
    _assert_remote_training_ready()
    print(f"artifacts_pulled_from={args.artifact_repo_id}")
if __name__ == "__main__":
main()