# NOTE(review): removed non-code residue ("Spaces: Running Running") — an
# HF Spaces status banner captured during extraction, not part of the script.
#!/usr/bin/env python3
"""Pull remote HF training artifacts back into the local PolyGuard repo."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import shutil
from huggingface_hub import snapshot_download
# Repository root, taken as the grandparent of this file — assumes the script
# lives one directory below the repo root (e.g. scripts/) — TODO confirm layout.
ROOT = Path(__file__).resolve().parents[1]
def parse_args() -> argparse.Namespace:
    """Parse CLI options for the artifact pull.

    Returns:
        Namespace with ``artifact_repo_id`` (HF repo to pull from) and
        ``cache_dir`` (local snapshot cache location).
    """
    cli = argparse.ArgumentParser(
        description="Download PolyGuard remote training artifacts."
    )
    cli.add_argument(
        "--artifact-repo-id",
        default="TheJackBright/polyguard-openenv-training-artifacts",
    )
    cli.add_argument("--cache-dir", default="/tmp/polyguard-training-artifacts")
    return cli.parse_args()
| def _copy_tree(src: Path, dst: Path) -> None: | |
| if src.exists(): | |
| dst.parent.mkdir(parents=True, exist_ok=True) | |
| shutil.copytree(src, dst, dirs_exist_ok=True) | |
def _mirror_docs_results() -> None:
    """Flat-copy report/plot files into ``docs/results`` for publishing.

    Only files with a ``.json``, ``.txt``, or ``.png`` suffix (case-insensitive)
    are mirrored; subdirectories are skipped. Missing source dirs are ignored.
    """
    dest = ROOT / "docs" / "results"
    dest.mkdir(parents=True, exist_ok=True)
    source_dirs = (ROOT / "outputs" / "reports", ROOT / "outputs" / "plots")
    wanted_suffixes = {".json", ".txt", ".png"}
    for src_dir in source_dirs:
        if not src_dir.exists():
            continue
        for entry in src_dir.iterdir():
            if entry.is_file() and entry.suffix.lower() in wanted_suffixes:
                shutil.copy2(entry, dest / entry.name)
def _assert_remote_training_ready() -> None:
    """Validate the pulled training reports and abort if any check fails.

    Reads the SFT, GRPO, and post-save inference JSON reports under
    ``outputs/reports`` and accumulates human-readable failure messages.

    Raises:
        SystemExit: with an ``artifact_checks_failed:`` summary when at least
            one check does not pass.
    """
    reports_dir = ROOT / "outputs" / "reports"

    def _load(path: Path) -> dict:
        # A missing report yields {} so every downstream check fails with
        # its own specific message instead of a FileNotFoundError.
        if not path.exists():
            return {}
        return json.loads(path.read_text(encoding="utf-8"))

    problems: list[str] = []

    sft = _load(reports_dir / "sft_trl_run.json")
    if sft.get("status") != "ok":
        problems.append("SFT status is not ok")
    if sft.get("backend") not in {"trl_unsloth", "trl_transformers"}:
        problems.append("SFT backend is not TRL")
    if not sft.get("artifact_path"):
        problems.append("SFT artifact path is empty")
    # "or 0" guards against an explicit null in the report JSON.
    if int(sft.get("examples_used", 0) or 0) <= 0:
        problems.append("SFT examples_used is zero")

    grpo = _load(reports_dir / "grpo_trl_run.json")
    if grpo.get("status") != "ok":
        problems.append("GRPO status is not ok")
    if not grpo.get("artifact_path"):
        problems.append("GRPO artifact path is empty")

    postsave = _load(reports_dir / "postsave_inference.json")
    if postsave.get("model_source") == "fallback_policy":
        problems.append("post-save inference still uses fallback policy")

    if problems:
        raise SystemExit("artifact_checks_failed:" + "; ".join(problems))
def main() -> None:
    """Download the remote artifact snapshot and mirror it into the repo.

    Pulls the selected paths from the HF model repo, copies each directory
    into the matching location under ``ROOT``, republishes docs/results,
    then verifies the training reports (exits non-zero on failure).
    """
    args = parse_args()
    # One list drives both the download filter and the local copy targets.
    patterns = [
        "outputs/reports/*",
        "outputs/plots/*",
        "docs/results/*",
        "checkpoints/sft_adapter/*",
        "checkpoints/grpo_adapter/*",
        "checkpoints/merged/*",
    ]
    local_snapshot = snapshot_download(
        repo_id=args.artifact_repo_id,
        repo_type="model",
        cache_dir=args.cache_dir,
        allow_patterns=patterns,
    )
    snapshot = Path(local_snapshot)
    # Strip the trailing "/*" from each pattern to get the directory to mirror.
    for rel in (p[:-2] for p in patterns):
        _copy_tree(snapshot / rel, ROOT / rel)
    _mirror_docs_results()
    _assert_remote_training_ready()
    print(f"artifacts_pulled_from={args.artifact_repo_id}")
| if __name__ == "__main__": | |
| main() | |