File size: 3,525 Bytes
877add7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#!/usr/bin/env python3
"""Pull remote HF training artifacts back into the local PolyGuard repo."""

from __future__ import annotations

import argparse
import json
from pathlib import Path
import shutil

from huggingface_hub import snapshot_download


# Two levels up from this file: the parent of the directory that holds the script.
# Presumably the PolyGuard repo root — confirm against the repository layout.
ROOT = Path(__file__).resolve().parents[1]


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the artifact pull script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, so existing no-argument callers are unaffected.

    Returns:
        Namespace with the ``artifact_repo_id`` and ``cache_dir`` attributes.
    """
    parser = argparse.ArgumentParser(description="Download PolyGuard remote training artifacts.")
    parser.add_argument(
        "--artifact-repo-id",
        default="TheJackBright/polyguard-openenv-training-artifacts",
        help="Hugging Face repo id to download artifacts from (default: %(default)s).",
    )
    parser.add_argument(
        "--cache-dir",
        default="/tmp/polyguard-training-artifacts",
        help="Directory used as the download cache (default: %(default)s).",
    )
    return parser.parse_args(argv)


def _copy_tree(src: Path, dst: Path) -> None:
    """Recursively copy ``src`` onto ``dst``; do nothing when ``src`` is absent.

    An existing ``dst`` is merged into rather than rejected, because the copy
    runs with ``dirs_exist_ok=True``.
    """
    if not src.exists():
        return
    # Make sure the ancestors of the destination are in place before copying.
    parent = dst.parent
    parent.mkdir(parents=True, exist_ok=True)
    shutil.copytree(src, dst, dirs_exist_ok=True)


def _mirror_docs_results() -> None:
    """Copy report and plot files from ``outputs/`` into ``docs/results``.

    Only regular files directly inside ``outputs/reports`` and ``outputs/plots``
    whose suffix is ``.json``, ``.txt`` or ``.png`` (compared case-insensitively)
    are copied. Source directories that do not exist are skipped, and the
    destination directory is created when needed.
    """
    mirrored_suffixes = {".json", ".txt", ".png"}
    target = ROOT / "docs" / "results"
    target.mkdir(parents=True, exist_ok=True)
    outputs = ROOT / "outputs"
    for source_dir in (outputs / "reports", outputs / "plots"):
        if source_dir.exists():
            for entry in source_dir.iterdir():
                if not entry.is_file():
                    continue
                if entry.suffix.lower() not in mirrored_suffixes:
                    continue
                # Files are flattened into the target by name.
                shutil.copy2(entry, target / entry.name)


def _assert_remote_training_ready() -> None:
    """Validate the pulled training reports and abort when any check fails.

    Reads ``sft_trl_run.json``, ``grpo_trl_run.json`` and
    ``postsave_inference.json`` from ``outputs/reports`` under ``ROOT``.
    A missing report is treated as an empty object, so its required fields
    fail their checks; note that a missing post-save report therefore passes,
    because only an explicit ``fallback_policy`` source is rejected.

    A report that cannot be read, is not valid JSON, or whose top-level value
    is not an object is recorded as a failure rather than crashing with a
    traceback, so every problem is reported in a single message.

    Raises:
        SystemExit: With ``artifact_checks_failed:`` followed by the
            ``"; "``-joined failure messages when at least one check fails.
    """
    reports = ROOT / "outputs" / "reports"
    sft_path = reports / "sft_trl_run.json"
    grpo_path = reports / "grpo_trl_run.json"
    postsave_path = reports / "postsave_inference.json"
    failures: list[str] = []

    def read_json(path: Path) -> dict:
        """Return the report at ``path`` as a dict, or ``{}`` when unusable."""
        if not path.exists():
            return {}
        try:
            payload = json.loads(path.read_text(encoding="utf-8"))
        except (OSError, ValueError) as exc:
            # JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
            failures.append(f"{path.name} is unreadable ({type(exc).__name__})")
            return {}
        if not isinstance(payload, dict):
            failures.append(f"{path.name} is not a JSON object")
            return {}
        return payload

    sft = read_json(sft_path)
    if sft.get("status") != "ok":
        failures.append("SFT status is not ok")
    # Tuple membership compares with ==, so an unhashable value cannot raise.
    if sft.get("backend") not in ("trl_unsloth", "trl_transformers"):
        failures.append("SFT backend is not TRL")
    if not sft.get("artifact_path"):
        failures.append("SFT artifact path is empty")
    try:
        examples_used = int(sft.get("examples_used", 0) or 0)
    except (TypeError, ValueError, OverflowError):
        # A non-numeric count is as unusable as a zero count.
        examples_used = 0
    if examples_used <= 0:
        failures.append("SFT examples_used is zero")

    grpo = read_json(grpo_path)
    if grpo.get("status") != "ok":
        failures.append("GRPO status is not ok")
    if not grpo.get("artifact_path"):
        failures.append("GRPO artifact path is empty")

    postsave = read_json(postsave_path)
    if postsave.get("model_source") == "fallback_policy":
        failures.append("post-save inference still uses fallback policy")

    if failures:
        raise SystemExit("artifact_checks_failed:" + "; ".join(failures))


def main() -> None:
    """Download the artifact snapshot, copy it into the repo, and validate it.

    Steps, in order: parse the CLI options, download the selected directories
    of the artifact repo via ``snapshot_download``, copy each downloaded
    directory to the same relative path under ``ROOT``, mirror report and plot
    files into ``docs/results``, run the readiness checks, and print the
    source repo id.
    """
    args = parse_args()

    # Single list of mirrored directories; the download filter is derived from it.
    artifact_dirs = (
        "outputs/reports",
        "outputs/plots",
        "docs/results",
        "checkpoints/sft_adapter",
        "checkpoints/grpo_adapter",
        "checkpoints/merged",
    )
    snapshot_root = snapshot_download(
        repo_id=args.artifact_repo_id,
        repo_type="model",
        cache_dir=args.cache_dir,
        allow_patterns=[f"{rel}/*" for rel in artifact_dirs],
    )
    snapshot = Path(snapshot_root)

    for rel in artifact_dirs:
        _copy_tree(snapshot / rel, ROOT / rel)

    _mirror_docs_results()
    _assert_remote_training_ready()
    print(f"artifacts_pulled_from={args.artifact_repo_id}")


# Run the pull only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()