"""Pull remote HF training artifacts back into the local PolyGuard repo."""

from __future__ import annotations

import argparse
import json
from pathlib import Path
import shutil

from huggingface_hub import snapshot_download
|
|
|
|
# Repository root: this script sits one directory below it.
_THIS_FILE = Path(__file__).resolve()
ROOT = _THIS_FILE.parents[1]
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the CLI options for pulling remote PolyGuard training artifacts.

    Returns:
        argparse.Namespace with ``artifact_repo_id``, ``cache_dir`` and
        ``training_mode`` attributes.
    """
    cli = argparse.ArgumentParser(description="Download PolyGuard remote training artifacts.")
    cli.add_argument("--artifact-repo-id", default="TheJackBright/polyguard-openenv-training-full-artifacts")
    cli.add_argument("--cache-dir", default="/tmp/polyguard-training-artifacts")
    cli.add_argument(
        "--training-mode",
        choices=["auto", "full", "sft-baseline"],
        default="auto",
        help="Artifact validation mode. Auto reads outputs/reports/hf_sweep_summary.json.",
    )
    return cli.parse_args()
|
|
|
|
| def _copy_tree(src: Path, dst: Path) -> None: |
| if src.exists(): |
| dst.parent.mkdir(parents=True, exist_ok=True) |
| shutil.copytree(src, dst, dirs_exist_ok=True) |
|
|
|
|
def _mirror_docs_results() -> None:
    """Mirror report/plot files from outputs/ into docs/results/.

    Only ``.json``, ``.txt`` and ``.png`` files are mirrored; the relative
    layout under each source directory is preserved.
    """
    mirror_root = ROOT / "docs" / "results"
    mirror_root.mkdir(parents=True, exist_ok=True)
    mirrored_suffixes = {".json", ".txt", ".png"}
    for source_dir in (ROOT / "outputs" / "reports", ROOT / "outputs" / "plots"):
        if not source_dir.exists():
            continue
        for item in source_dir.rglob("*"):
            if not item.is_file() or item.suffix.lower() not in mirrored_suffixes:
                continue
            destination = mirror_root / item.relative_to(source_dir)
            destination.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(item, destination)
|
|
|
|
def _assert_remote_training_ready(training_mode: str = "auto") -> None:
    """Validate the pulled training artifacts and abort on any failed check.

    Prefers the sweep summary (``outputs/reports/hf_sweep_summary.json``) when
    present; otherwise falls back to the legacy per-run SFT/GRPO/post-save
    reports.

    Args:
        training_mode: ``"full"``, ``"sft-baseline"``, or ``"auto"`` to take
            the mode recorded in the sweep summary itself.

    Raises:
        SystemExit: with an ``artifact_checks_failed:`` message listing every
            failed check.
    """
    reports = ROOT / "outputs" / "reports"
    sweep_summary = _read_report_json(reports / "hf_sweep_summary.json")

    if sweep_summary:
        failures = _check_sweep_summary(sweep_summary, training_mode)
        # The anti-hacking report is only meaningful alongside a sweep summary.
        anti_hacking = _read_report_json(reports / "anti_hacking_overfit_report.json")
        if anti_hacking.get("passed") is not True:
            failures.append("anti-hacking/overfit report did not pass")
    else:
        failures = _check_legacy_reports(reports)

    # Single exit path: aggregate every failure into one message.
    if failures:
        raise SystemExit("artifact_checks_failed:" + "; ".join(failures))


def _read_report_json(path: Path) -> dict:
    """Load a JSON report file; a missing file reads as an empty dict."""
    if not path.exists():
        return {}
    return json.loads(path.read_text(encoding="utf-8"))


def _check_sweep_summary(sweep_summary: dict, training_mode: str) -> list[str]:
    """Return failure messages for the HF sweep summary report."""
    failures: list[str] = []
    summary_mode = str(sweep_summary.get("training_mode") or "full")
    effective_mode = summary_mode if training_mode == "auto" else training_mode
    # In sft-baseline mode a missing GRPO artifact is expected, not a failure.
    sft_only = effective_mode == "sft-baseline"
    if int(sweep_summary.get("completed_models", 0) or 0) <= 0:
        failures.append("HF sweep has no completed models")
    for row in sweep_summary.get("models", []):
        # Only completed, well-formed rows are subject to artifact checks.
        if not isinstance(row, dict) or row.get("status") != "completed":
            continue
        label = str(row.get("label") or row.get("model_id") or "model")
        if row.get("fallback_detected"):
            failures.append(f"{label} used fallback backend")
        if not row.get("reward_range_ok"):
            failures.append(f"{label} has reward range/precision failures")
        artifact_paths = row.get("artifact_paths", {})
        if not isinstance(artifact_paths, dict):
            artifact_paths = {}
        if not artifact_paths.get("sft"):
            failures.append(f"{label} missing SFT artifact path")
        if not sft_only and not artifact_paths.get("grpo"):
            failures.append(f"{label} missing GRPO artifact path")
    # Every chart referenced by the summary must exist on disk.
    for chart_name, rel_path in sweep_summary.get("charts", {}).items():
        if not (ROOT / str(rel_path)).exists():
            failures.append(f"missing chart {chart_name}")
    return failures


def _check_legacy_reports(reports: Path) -> list[str]:
    """Return failure messages for the legacy single-run SFT/GRPO reports."""
    failures: list[str] = []
    sft = _read_report_json(reports / "sft_trl_run.json")
    if sft.get("status") != "ok":
        failures.append("SFT status is not ok")
    if sft.get("backend") not in {"trl_unsloth", "trl_transformers"}:
        failures.append("SFT backend is not TRL")
    if not sft.get("artifact_path"):
        failures.append("SFT artifact path is empty")
    if int(sft.get("examples_used", 0) or 0) <= 0:
        failures.append("SFT examples_used is zero")

    grpo = _read_report_json(reports / "grpo_trl_run.json")
    if grpo.get("status") != "ok":
        failures.append("GRPO status is not ok")
    if not grpo.get("artifact_path"):
        failures.append("GRPO artifact path is empty")

    postsave = _read_report_json(reports / "postsave_inference.json")
    if postsave.get("model_source") == "fallback_policy":
        failures.append("post-save inference still uses fallback policy")
    return failures
|
|
|
|
def main() -> None:
    """Download the artifact snapshot and mirror it into the repository tree.

    Pulls the configured HF repo, copies the relevant subtrees into the repo,
    mirrors reports/plots into docs/results, and validates the artifacts.
    """
    args = parse_args()
    # One canonical list drives both the download filter and the local copy.
    artifact_dirs = [
        "outputs/reports",
        "outputs/plots",
        "docs/results",
        "checkpoints/sft_adapter",
        "checkpoints/grpo_adapter",
        "checkpoints/merged",
        "checkpoints/sweeps",
    ]
    snapshot_root = snapshot_download(
        repo_id=args.artifact_repo_id,
        repo_type="model",
        cache_dir=args.cache_dir,
        allow_patterns=[f"{rel}/**" for rel in artifact_dirs],
    )
    snapshot = Path(snapshot_root)

    for rel in artifact_dirs:
        _copy_tree(snapshot / rel, ROOT / rel)

    _mirror_docs_results()
    _assert_remote_training_ready(training_mode=args.training_mode)
    print(f"artifacts_pulled_from={args.artifact_repo_id}")
|
|
|
|
# Script entry point: pull and validate artifacts when executed directly.
if __name__ == "__main__":
    main()
|
|