#!/usr/bin/env python3
"""Pull remote HF training artifacts back into the local PolyGuard repo."""

from __future__ import annotations

import argparse
import json
from pathlib import Path
import shutil

from huggingface_hub import snapshot_download


ROOT = Path(__file__).resolve().parents[1]

# Repo-relative directories that are downloaded from the artifact repo and
# copied into the local checkout. Used for both the download filter and the
# copy step so the two can never drift apart.
_SYNCED_DIRS = (
    "outputs/reports",
    "outputs/plots",
    "docs/results",
    "checkpoints/sft_adapter",
    "checkpoints/grpo_adapter",
    "checkpoints/merged",
    "checkpoints/sweeps",
)

# File suffixes mirrored from outputs/ into docs/results.
_MIRRORED_SUFFIXES = frozenset({".json", ".txt", ".png"})

# SFT backends accepted by the single-run validation path.
_TRL_BACKENDS = frozenset({"trl_unsloth", "trl_transformers"})


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the artifact pull script."""
    parser = argparse.ArgumentParser(description="Download PolyGuard remote training artifacts.")
    parser.add_argument("--artifact-repo-id", default="TheJackBright/polyguard-openenv-training-full-artifacts")
    parser.add_argument("--cache-dir", default="/tmp/polyguard-training-artifacts")
    parser.add_argument(
        "--training-mode",
        choices=["auto", "full", "sft-baseline"],
        default="auto",
        help="Artifact validation mode. Auto reads outputs/reports/hf_sweep_summary.json.",
    )
    return parser.parse_args()


def _copy_tree(src: Path, dst: Path) -> None:
    """Copy directory *src* into *dst*, merging into an existing *dst*.

    Does nothing when *src* does not exist.
    """
    if src.exists():
        dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copytree(src, dst, dirs_exist_ok=True)


def _mirror_docs_results() -> None:
    """Copy report/plot files (.json, .txt, .png) into ``docs/results``.

    Files keep their path relative to ``outputs/reports`` or
    ``outputs/plots``; both directories are mirrored into the same target.
    """
    docs = ROOT / "docs" / "results"
    docs.mkdir(parents=True, exist_ok=True)
    for directory in (ROOT / "outputs" / "reports", ROOT / "outputs" / "plots"):
        if not directory.exists():
            continue
        for path in directory.rglob("*"):
            if path.is_file() and path.suffix.lower() in _MIRRORED_SUFFIXES:
                target = docs / path.relative_to(directory)
                target.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy2(path, target)


def _read_json(path: Path, failures: list[str]) -> dict:
    """Load a JSON object from *path*.

    Returns an empty dict when the file is missing. An unreadable, malformed
    or non-object file also yields an empty dict, but is recorded in
    *failures* so the run ends with the structured failure message instead of
    an unhandled traceback.
    """
    if not path.exists():
        return {}
    try:
        payload = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError) as exc:  # JSON/Unicode decode errors are ValueErrors
        failures.append(f"{path.name} could not be read as JSON ({exc})")
        return {}
    if not isinstance(payload, dict):
        failures.append(f"{path.name} is not a JSON object")
        return {}
    return payload


def _as_int(value: object) -> int:
    """Coerce a report counter to ``int``; missing or invalid values count as 0."""
    try:
        return int(value or 0)
    except (TypeError, ValueError):
        return 0


def _check_sweep_summary(sweep_summary: dict, anti_hacking: dict, training_mode: str, failures: list[str]) -> None:
    """Validate a multi-model sweep summary, appending problems to *failures*."""
    summary_mode = str(sweep_summary.get("training_mode") or "full")
    effective_mode = summary_mode if training_mode == "auto" else training_mode
    sft_only = effective_mode == "sft-baseline"

    if _as_int(sweep_summary.get("completed_models")) <= 0:
        failures.append("HF sweep has no completed models")

    models = sweep_summary.get("models")
    # A null / non-list "models" entry is treated as "no rows" rather than crashing.
    for row in models if isinstance(models, list) else []:
        if not isinstance(row, dict) or row.get("status") != "completed":
            continue
        label = str(row.get("label") or row.get("model_id") or "model")
        if row.get("fallback_detected"):
            failures.append(f"{label} used fallback backend")
        if not row.get("reward_range_ok"):
            failures.append(f"{label} has reward range/precision failures")
        artifact_paths = row.get("artifact_paths", {})
        if not isinstance(artifact_paths, dict):
            artifact_paths = {}
        if not artifact_paths.get("sft"):
            failures.append(f"{label} missing SFT artifact path")
        if not sft_only and not artifact_paths.get("grpo"):
            failures.append(f"{label} missing GRPO artifact path")

    charts = sweep_summary.get("charts")
    if charts is None:
        charts = {}
    elif not isinstance(charts, dict):
        failures.append("sweep summary charts entry is not an object")
        charts = {}
    for chart_name, rel_path in charts.items():
        if not (ROOT / str(rel_path)).exists():
            failures.append(f"missing chart {chart_name}")

    if anti_hacking.get("passed") is not True:
        failures.append("anti-hacking/overfit report did not pass")


def _check_single_run_reports(reports: Path, training_mode: str, failures: list[str]) -> None:
    """Validate the per-stage run reports used when no sweep summary exists."""
    sft = _read_json(reports / "sft_trl_run.json", failures)
    if sft.get("status") != "ok":
        failures.append("SFT status is not ok")
    if sft.get("backend") not in _TRL_BACKENDS:
        failures.append("SFT backend is not TRL")
    if not sft.get("artifact_path"):
        failures.append("SFT artifact path is empty")
    if _as_int(sft.get("examples_used")) <= 0:
        failures.append("SFT examples_used is zero")

    # An explicitly requested SFT-only baseline has no GRPO stage to validate;
    # "auto" without a sweep summary keeps the full (SFT + GRPO) requirements.
    if training_mode != "sft-baseline":
        grpo = _read_json(reports / "grpo_trl_run.json", failures)
        if grpo.get("status") != "ok":
            failures.append("GRPO status is not ok")
        if not grpo.get("artifact_path"):
            failures.append("GRPO artifact path is empty")

    postsave = _read_json(reports / "postsave_inference.json", failures)
    if postsave.get("model_source") == "fallback_policy":
        failures.append("post-save inference still uses fallback policy")


def _assert_remote_training_ready(training_mode: str = "auto") -> None:
    """Validate the pulled reports and abort when they look incomplete.

    When ``outputs/reports/hf_sweep_summary.json`` holds a non-empty object,
    the sweep summary and the anti-hacking report are validated; otherwise
    the individual SFT / GRPO / post-save reports are checked.

    Args:
        training_mode: ``"auto"`` takes the mode from the sweep summary
            (defaulting to ``"full"``); ``"full"`` and ``"sft-baseline"``
            override it. In ``"sft-baseline"`` mode GRPO artifacts are not
            required.

    Raises:
        SystemExit: with ``artifact_checks_failed:<reasons>`` when any check
            fails.
    """
    reports = ROOT / "outputs" / "reports"
    failures: list[str] = []

    sweep_summary = _read_json(reports / "hf_sweep_summary.json", failures)
    if sweep_summary:
        anti_hacking = _read_json(reports / "anti_hacking_overfit_report.json", failures)
        _check_sweep_summary(sweep_summary, anti_hacking, training_mode, failures)
    else:
        _check_single_run_reports(reports, training_mode, failures)

    if failures:
        raise SystemExit("artifact_checks_failed:" + "; ".join(failures))


def main() -> None:
    """Download the artifact snapshot, copy it into the repo and validate it."""
    args = parse_args()
    snapshot = Path(
        snapshot_download(
            repo_id=args.artifact_repo_id,
            repo_type="model",
            cache_dir=args.cache_dir,
            allow_patterns=[f"{rel}/**" for rel in _SYNCED_DIRS],
        )
    )
    for rel in _SYNCED_DIRS:
        _copy_tree(snapshot / rel, ROOT / rel)
    _mirror_docs_results()
    _assert_remote_training_ready(training_mode=args.training_mode)
    print(f"artifacts_pulled_from={args.artifact_repo_id}")


if __name__ == "__main__":
    main()