#!/usr/bin/env python3
"""Pull remote HF training artifacts back into the local PolyGuard repo."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import shutil
from huggingface_hub import snapshot_download
# Repository root; assumes this script lives in a first-level subdirectory
# (e.g. scripts/) so parents[1] resolves to the repo root — TODO confirm layout.
ROOT = Path(__file__).resolve().parents[1]
def parse_args() -> argparse.Namespace:
    """Parse command-line options for pulling remote training artifacts."""
    cli = argparse.ArgumentParser(
        description="Download PolyGuard remote training artifacts."
    )
    cli.add_argument(
        "--artifact-repo-id",
        default="TheJackBright/polyguard-openenv-training-full-artifacts",
    )
    cli.add_argument("--cache-dir", default="/tmp/polyguard-training-artifacts")
    cli.add_argument(
        "--training-mode",
        default="auto",
        choices=["auto", "full", "sft-baseline"],
        help="Artifact validation mode. Auto reads outputs/reports/hf_sweep_summary.json.",
    )
    return cli.parse_args()
def _copy_tree(src: Path, dst: Path) -> None:
if src.exists():
dst.parent.mkdir(parents=True, exist_ok=True)
shutil.copytree(src, dst, dirs_exist_ok=True)
def _mirror_docs_results() -> None:
    """Mirror report/plot files (.json/.txt/.png) into docs/results."""
    destination = ROOT / "docs" / "results"
    destination.mkdir(parents=True, exist_ok=True)
    wanted_suffixes = {".json", ".txt", ".png"}
    for source in (ROOT / "outputs" / "reports", ROOT / "outputs" / "plots"):
        if not source.exists():
            continue
        for item in source.rglob("*"):
            if not item.is_file() or item.suffix.lower() not in wanted_suffixes:
                continue
            # Preserve the relative layout of each source tree under docs/results.
            mirrored = destination / item.relative_to(source)
            mirrored.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(item, mirrored)
def _assert_remote_training_ready(training_mode: str = "auto") -> None:
    """Validate pulled training artifacts; raise ``SystemExit`` on failure.

    Two validation paths exist:

    * Sweep path: if ``outputs/reports/hf_sweep_summary.json`` is present,
      validate per-model completion, backend/reward flags, artifact paths,
      chart files, and the anti-hacking/overfit report.
    * Legacy path: otherwise validate the single-run SFT/GRPO reports and the
      post-save inference report.

    Args:
        training_mode: ``"full"`` requires GRPO artifacts, ``"sft-baseline"``
            skips GRPO checks, and ``"auto"`` defers to the mode recorded in
            the sweep summary (treated as ``"full"`` when no summary exists).

    Raises:
        SystemExit: with an ``artifact_checks_failed:`` message listing every
            individual check that failed.
    """
    reports = ROOT / "outputs" / "reports"
    failures: list[str] = []

    def read_json(path: Path) -> dict:
        # Missing reports become empty dicts rather than raising, so the
        # per-field checks below produce targeted failure messages instead
        # of a FileNotFoundError.
        if not path.exists():
            return {}
        return json.loads(path.read_text(encoding="utf-8"))

    sweep_summary = read_json(reports / "hf_sweep_summary.json")
    anti_hacking = read_json(reports / "anti_hacking_overfit_report.json")

    if sweep_summary:
        # --- Sweep path ---
        summary_mode = str(sweep_summary.get("training_mode") or "full")
        effective_mode = summary_mode if training_mode == "auto" else training_mode
        sft_only = effective_mode == "sft-baseline"
        if int(sweep_summary.get("completed_models", 0) or 0) <= 0:
            failures.append("HF sweep has no completed models")
        for row in sweep_summary.get("models", []):
            # Only completed rows are validated; pending/failed rows are
            # already reflected in completed_models above.
            if not isinstance(row, dict) or row.get("status") != "completed":
                continue
            label = str(row.get("label") or row.get("model_id") or "model")
            if row.get("fallback_detected"):
                failures.append(f"{label} used fallback backend")
            if not row.get("reward_range_ok"):
                failures.append(f"{label} has reward range/precision failures")
            artifact_paths = row.get("artifact_paths", {})
            if not isinstance(artifact_paths, dict):
                artifact_paths = {}
            if not artifact_paths.get("sft"):
                failures.append(f"{label} missing SFT artifact path")
            if not sft_only and not artifact_paths.get("grpo"):
                failures.append(f"{label} missing GRPO artifact path")
        charts = sweep_summary.get("charts", {})
        for chart_name, rel_path in charts.items():
            if not (ROOT / str(rel_path)).exists():
                failures.append(f"missing chart {chart_name}")
        if anti_hacking.get("passed") is not True:
            failures.append("anti-hacking/overfit report did not pass")
        if failures:
            raise SystemExit("artifact_checks_failed:" + "; ".join(failures))
        return

    # --- Legacy path (no sweep summary) ---
    # Without a summary there is no recorded mode, so "auto" defaults to the
    # full requirements.
    sft_only = training_mode == "sft-baseline"
    sft = read_json(reports / "sft_trl_run.json")
    if sft.get("status") != "ok":
        failures.append("SFT status is not ok")
    if sft.get("backend") not in {"trl_unsloth", "trl_transformers"}:
        failures.append("SFT backend is not TRL")
    if not sft.get("artifact_path"):
        failures.append("SFT artifact path is empty")
    if int(sft.get("examples_used", 0) or 0) <= 0:
        failures.append("SFT examples_used is zero")
    if not sft_only:
        # Bug fix: previously the legacy path ignored training_mode and
        # required GRPO artifacts even for an SFT-only baseline run.
        grpo = read_json(reports / "grpo_trl_run.json")
        if grpo.get("status") != "ok":
            failures.append("GRPO status is not ok")
        if not grpo.get("artifact_path"):
            failures.append("GRPO artifact path is empty")
    postsave = read_json(reports / "postsave_inference.json")
    if postsave.get("model_source") == "fallback_policy":
        failures.append("post-save inference still uses fallback policy")
    if failures:
        raise SystemExit("artifact_checks_failed:" + "; ".join(failures))
def main() -> None:
    """Download remote artifacts, mirror them into the repo, then validate."""
    args = parse_args()
    # Fetch only artifact subtrees; the rest of the model repo is skipped.
    snapshot_path = snapshot_download(
        repo_id=args.artifact_repo_id,
        repo_type="model",
        cache_dir=args.cache_dir,
        allow_patterns=[
            "outputs/reports/**",
            "outputs/plots/**",
            "docs/results/**",
            "checkpoints/sft_adapter/**",
            "checkpoints/grpo_adapter/**",
            "checkpoints/merged/**",
            "checkpoints/sweeps/**",
        ],
    )
    snapshot = Path(snapshot_path)
    relative_dirs = (
        "outputs/reports",
        "outputs/plots",
        "docs/results",
        "checkpoints/sft_adapter",
        "checkpoints/grpo_adapter",
        "checkpoints/merged",
        "checkpoints/sweeps",
    )
    for rel in relative_dirs:
        _copy_tree(snapshot / rel, ROOT / rel)
    _mirror_docs_results()
    _assert_remote_training_ready(training_mode=args.training_mode)
    print(f"artifacts_pulled_from={args.artifact_repo_id}")
# Script entry point: pull artifacts, mirror docs, and validate them.
if __name__ == "__main__":
    main()
|