File size: 3,487 Bytes
21c7db9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#!/usr/bin/env python3
"""Pull one model-sweep run from the HF artifact repo when it is available."""

from __future__ import annotations

import argparse
import json
from pathlib import Path
import shutil

from huggingface_hub import HfApi, snapshot_download


# Repository root: this script sits one directory below it — TODO confirm layout.
ROOT = Path(__file__).resolve().parents[1]
# Sweep run pulled when --run-id is not given on the command line.
DEFAULT_RUN_ID = "qwen-qwen2-5-0-5b-instruct"


def parse_args() -> argparse.Namespace:
    """Parse command-line options for the sweep-artifact download.

    Returns:
        argparse.Namespace with ``artifact_repo_id``, ``run_id``,
        ``cache_dir`` and ``allow_missing`` attributes.
    """
    cli = argparse.ArgumentParser(description="Download one PolyGuard HF sweep run.")
    cli.add_argument("--artifact-repo-id", default="TheJackBright/polyguard-openenv-training-full-artifacts")
    cli.add_argument("--run-id", default=DEFAULT_RUN_ID)
    cli.add_argument("--cache-dir", default="/tmp/polyguard-training-artifacts")
    cli.add_argument("--allow-missing", action="store_true")
    return cli.parse_args()


def _copy_tree(source: Path, target: Path) -> bool:
    if not source.exists():
        return False
    target.parent.mkdir(parents=True, exist_ok=True)
    shutil.copytree(source, target, dirs_exist_ok=True)
    return True


def main() -> None:
    """Download one sweep run's artifacts from the HF repo into this checkout.

    Lists the artifact repository, fails fast (unless --allow-missing) when
    no file matches the run's prefixes, then snapshots only the relevant
    patterns and copies them under ``ROOT``. Prints a JSON status summary.
    """
    opts = parse_args()
    remote_files = HfApi().list_repo_files(opts.artifact_repo_id, repo_type="model")

    # The three locations a sweep run's artifacts may live under in the repo.
    prefixes = (
        f"outputs/reports/sweeps/{opts.run_id}/",
        f"docs/results/sweeps/{opts.run_id}/",
        f"checkpoints/sweeps/{opts.run_id}/",
    )
    hits = [name for name in remote_files if name.startswith(prefixes)]
    if not hits and not opts.allow_missing:
        # Machine-parseable failure: prefix tag + JSON detail payload.
        detail = json.dumps(
            {
                "artifact_repo_id": opts.artifact_repo_id,
                "run_id": opts.run_id,
                "repo_file_count": len(remote_files),
                "available_files": remote_files[:20],
            },
            ensure_ascii=True,
        )
        raise SystemExit("sweep_artifacts_not_uploaded_yet:" + detail)

    # Restrict the snapshot to this run plus the shared summary artifacts.
    snapshot = Path(
        snapshot_download(
            repo_id=opts.artifact_repo_id,
            repo_type="model",
            cache_dir=opts.cache_dir,
            allow_patterns=[
                f"outputs/reports/sweeps/{opts.run_id}/**",
                f"docs/results/sweeps/{opts.run_id}/**",
                f"checkpoints/sweeps/{opts.run_id}/**",
                "outputs/plots/**",
                "outputs/reports/hf_sweep_summary.json",
                "outputs/reports/anti_hacking_overfit_report.json",
                "docs/results/*.png",
                "docs/results/hf_sweep_summary.json",
                "docs/results/anti_hacking_overfit_report.json",
            ],
        )
    )

    copied = []
    # Whole directory trees first.
    tree_rels = (
        f"outputs/reports/sweeps/{opts.run_id}",
        f"docs/results/sweeps/{opts.run_id}",
        f"checkpoints/sweeps/{opts.run_id}",
        "outputs/plots",
    )
    for rel in tree_rels:
        if _copy_tree(snapshot / rel, ROOT / rel):
            copied.append(rel)
    # Then the individual summary files, skipping any the snapshot lacks.
    file_rels = (
        "outputs/reports/hf_sweep_summary.json",
        "outputs/reports/anti_hacking_overfit_report.json",
        "docs/results/hf_sweep_summary.json",
        "docs/results/anti_hacking_overfit_report.json",
    )
    for rel in file_rels:
        src = snapshot / rel
        if src.exists():
            dest = ROOT / rel
            dest.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(src, dest)
            copied.append(rel)

    print(json.dumps({"status": "ok", "run_id": opts.run_id, "copied": copied}, ensure_ascii=True, indent=2))


# Script entry point: run the download/copy workflow when executed directly.
if __name__ == "__main__":
    main()