#!/usr/bin/env python3
"""Pull remote HF training artifacts back into the local PolyGuard repo."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import shutil
from huggingface_hub import snapshot_download
ROOT = Path(__file__).resolve().parents[1]
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Download PolyGuard remote training artifacts.")
parser.add_argument("--artifact-repo-id", default="TheJackBright/polyguard-openenv-training-full-artifacts")
parser.add_argument("--cache-dir", default="/tmp/polyguard-training-artifacts")
parser.add_argument(
"--training-mode",
choices=["auto", "full", "sft-baseline"],
default="auto",
help="Artifact validation mode. Auto reads outputs/reports/hf_sweep_summary.json.",
)
return parser.parse_args()
def _copy_tree(src: Path, dst: Path) -> None:
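    """Copy ``src`` into ``dst``, merging with any existing contents, when ``src`` exists."""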
if src.exists():
dst.parent.mkdir(parents=True, exist_ok=True)
shutil.copytree(src, dst, dirs_exist_ok=True)
def _mirror_docs_results() -> None:
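    """Mirror JSON, TXT, and PNG files from outputs/reports and outputs/plots into docs/results."""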
docs = ROOT / "docs" / "results"
docs.mkdir(parents=True, exist_ok=True)
for directory in [ROOT / "outputs" / "reports", ROOT / "outputs" / "plots"]:
if not directory.exists():
continue
for path in directory.rglob("*"):
if path.is_file() and path.suffix.lower() in {".json", ".txt", ".png"}:
target = docs / path.relative_to(directory)
target.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(path, target)
def _assert_remote_training_ready(training_mode: str = "auto") -> None:
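    """Raise SystemExit when the pulled reports do not describe a completed, healthy training run."""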
sweep_summary_path = ROOT / "outputs" / "reports" / "hf_sweep_summary.json"
anti_hacking_path = ROOT / "outputs" / "reports" / "anti_hacking_overfit_report.json"
sft_path = ROOT / "outputs" / "reports" / "sft_trl_run.json"
grpo_path = ROOT / "outputs" / "reports" / "grpo_trl_run.json"
postsave_path = ROOT / "outputs" / "reports" / "postsave_inference.json"
failures: list[str] = []
def read_json(path: Path) -> dict:
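        """Return the parsed JSON payload, or an empty dict when the file is missing."""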
if not path.exists():
return {}
return json.loads(path.read_text(encoding="utf-8"))
sweep_summary = read_json(sweep_summary_path)
anti_hacking = read_json(anti_hacking_path)
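    # Preferred path: a sweep summary was pulled, so validate the sweep as a whole.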
if sweep_summary:
summary_mode = str(sweep_summary.get("training_mode") or "full")
effective_mode = summary_mode if training_mode == "auto" else training_mode
sft_only = effective_mode == "sft-baseline"
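        # sft-baseline sweeps are not expected to ship GRPO artifacts, so that check is skipped below.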
if int(sweep_summary.get("completed_models", 0) or 0) <= 0:
failures.append("HF sweep has no completed models")
for row in sweep_summary.get("models", []):
if not isinstance(row, dict) or row.get("status") != "completed":
continue
label = str(row.get("label") or row.get("model_id") or "model")
if row.get("fallback_detected"):
failures.append(f"{label} used fallback backend")
if not row.get("reward_range_ok"):
failures.append(f"{label} has reward range/precision failures")
artifact_paths = row.get("artifact_paths", {})
if not isinstance(artifact_paths, dict):
artifact_paths = {}
if not artifact_paths.get("sft"):
failures.append(f"{label} missing SFT artifact path")
if not sft_only and not artifact_paths.get("grpo"):
failures.append(f"{label} missing GRPO artifact path")
charts = sweep_summary.get("charts", {})
for chart_name, rel_path in charts.items():
if not (ROOT / str(rel_path)).exists():
failures.append(f"missing chart {chart_name}")
if anti_hacking.get("passed") is not True:
failures.append("anti-hacking/overfit report did not pass")
if failures:
raise SystemExit("artifact_checks_failed:" + "; ".join(failures))
return
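    # No sweep summary was pulled: fall back to the individual SFT/GRPO/post-save run reports.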
sft = read_json(sft_path)
if sft.get("status") != "ok":
failures.append("SFT status is not ok")
if sft.get("backend") not in {"trl_unsloth", "trl_transformers"}:
failures.append("SFT backend is not TRL")
if not sft.get("artifact_path"):
failures.append("SFT artifact path is empty")
if int(sft.get("examples_used", 0) or 0) <= 0:
failures.append("SFT examples_used is zero")
grpo = read_json(grpo_path)
if grpo.get("status") != "ok":
failures.append("GRPO status is not ok")
if not grpo.get("artifact_path"):
failures.append("GRPO artifact path is empty")
postsave = read_json(postsave_path)
if postsave.get("model_source") == "fallback_policy":
failures.append("post-save inference still uses fallback policy")
if failures:
raise SystemExit("artifact_checks_failed:" + "; ".join(failures))
def main() -> None:
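    """Download the artifact snapshot, copy it into the repo, mirror docs, and validate it."""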
args = parse_args()
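    # Pull only the report, plot, and checkpoint subtrees from the artifact repo.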
snapshot = Path(
snapshot_download(
repo_id=args.artifact_repo_id,
repo_type="model",
cache_dir=args.cache_dir,
allow_patterns=[
"outputs/reports/**",
"outputs/plots/**",
"docs/results/**",
"checkpoints/sft_adapter/**",
"checkpoints/grpo_adapter/**",
"checkpoints/merged/**",
"checkpoints/sweeps/**",
],
)
)
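    # Copy each downloaded subtree into the matching path inside the local repo.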
for rel in [
"outputs/reports",
"outputs/plots",
"docs/results",
"checkpoints/sft_adapter",
"checkpoints/grpo_adapter",
"checkpoints/merged",
"checkpoints/sweeps",
]:
_copy_tree(snapshot / rel, ROOT / rel)
_mirror_docs_results()
_assert_remote_training_ready(training_mode=args.training_mode)
print(f"artifacts_pulled_from={args.artifact_repo_id}")
if __name__ == "__main__":
main()