File size: 5,951 Bytes
21c7db9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e21fe7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21c7db9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#!/usr/bin/env python3
"""Download a PolyGuard usable model bundle from the Hugging Face Hub and activate it.

The default bundle matches the public artifact folder:
  TheJackBright/polyguard-openenv-training-full-artifacts/main/
  usable_model_bundles/local-qwen-0-5b-active-smoke/

This copies checkpoints into ``checkpoints/active/`` and installs ``active_model_manifest.json``,
which ``app.models.policy.active_model`` reads when ``POLYGUARD_ENABLE_ACTIVE_MODEL`` is true.

Usage:
  python scripts/install_hf_active_bundle.py
  python scripts/install_hf_active_bundle.py --no-reports
"""

from __future__ import annotations

import argparse
import json
import shutil
import sys
import tempfile
from datetime import datetime, timezone
from pathlib import Path
from typing import Any


# Hub defaults used when --repo-id / --bundle-path are not given: the public
# artifact repo and the bundle folder (relative to that repo's root) to install.
DEFAULT_REPO = "TheJackBright/polyguard-openenv-training-full-artifacts"
DEFAULT_BUNDLE = "usable_model_bundles/local-qwen-0-5b-active-smoke"

# Project root: two levels above this file (it is run as scripts/<this file>).
# All install destinations (checkpoints/, outputs/, docs/) hang off this path.
ROOT = Path(__file__).resolve().parents[1]


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Build the installer's CLI parser and parse the command line.

    The module docstring doubles as the ``--help`` description. It is printed
    with ``RawDescriptionHelpFormatter`` so its hand-formatted bundle path and
    ``Usage:`` lines are shown as written instead of being re-wrapped into a
    single paragraph by argparse's default formatter.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly like the former
            zero-argument signature.

    Returns:
        Namespace with ``repo_id``, ``revision``, ``bundle_path``,
        ``local_snapshot_dir``, ``no_reports`` and ``touch_manifest_time``.
    """
    p = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument(
        "--repo-id",
        default=DEFAULT_REPO,
        help="Hugging Face model repo that hosts the bundle (default: %(default)s).",
    )
    p.add_argument(
        "--revision",
        default="main",
        help="Branch, tag or commit of the repo to download (default: %(default)s).",
    )
    p.add_argument("--bundle-path", default=DEFAULT_BUNDLE, help="Path inside the repo to the bundle root.")
    p.add_argument(
        "--local-snapshot-dir",
        default="",
        help="Optional directory for snapshot_download (default: checkpoints/.hf_bundles/<bundle tail>).",
    )
    p.add_argument("--no-reports", action="store_true", help="Skip copying reports into outputs/reports/active_model.")
    p.add_argument(
        "--touch-manifest-time",
        action="store_true",
        help="Set activated_at_utc in the installed manifest to now (for bookkeeping only).",
    )
    return p.parse_args(argv)


def _read_json(path: Path) -> dict[str, Any]:
    """Decode *path* as UTF-8 text and return the parsed JSON document."""
    with path.open("r", encoding="utf-8") as handle:
        return json.load(handle)


def _replace_tree(src: Path, dest: Path) -> None:
    """Replace ``dest`` with a fresh copy of the directory tree at ``src``.

    Whatever currently occupies ``dest`` is discarded: a directory tree is
    removed recursively, and a plain file or a symlink is unlinked (the link
    itself, not its target). The copy is first staged in a hidden temporary
    directory next to ``dest``, and the old ``dest`` is only removed once that
    copy has completed. A copy that fails part-way (missing source, disk full,
    Ctrl-C during a large checkpoint copy) therefore leaves the previous
    ``dest`` intact instead of already deleted.

    Args:
        src: Directory to copy.
        dest: Destination path; missing parent directories are created.

    Raises:
        OSError: propagated from the copy, removal, or rename steps.
    """
    dest.parent.mkdir(parents=True, exist_ok=True)
    # Stage inside dest's own parent so the final rename stays on one filesystem.
    staging_root = Path(tempfile.mkdtemp(prefix=f".{dest.name}.staging-", dir=dest.parent))
    try:
        staged = staging_root / dest.name
        shutil.copytree(src, staged)
        # Only now that the new tree is complete do we drop the old one.
        if dest.is_symlink() or dest.is_file():
            dest.unlink()
        elif dest.exists():
            shutil.rmtree(dest)
        staged.rename(dest)
    finally:
        # Empty after a successful rename; holds the partial copy on failure.
        shutil.rmtree(staging_root, ignore_errors=True)


def _download_bundle(args: argparse.Namespace, bundle_path: str, snap_root: Path) -> None:
    """Download only ``<bundle_path>/**`` of ``args.repo_id`` at ``args.revision`` into ``snap_root``.

    Raises:
        SystemExit: with an install hint when ``huggingface_hub`` is not
            importable, or with status 1 after printing troubleshooting hints
            to stderr when the Hub call raises.
    """
    try:
        from huggingface_hub import snapshot_download
    except ImportError as exc:
        raise SystemExit("install huggingface_hub (pip install huggingface-hub)") from exc

    allow = f"{bundle_path}/**"
    print(f"Downloading snapshot of {args.repo_id}@{args.revision} (pattern {allow}) …", flush=True)
    try:
        snapshot_download(
            repo_id=args.repo_id,
            repo_type="model",
            revision=args.revision,
            local_dir=str(snap_root),
            allow_patterns=[allow],
        )
    except Exception as exc:  # any failure of the Hub call is reported with the same hints
        err = f"{type(exc).__name__}: {exc}"
        hint = (
            "\n[install_hf_active_bundle] Hub returned an error (401/404 often means the artifact repo is private or gated).\n"
            "  • Hugging Face Space: Space Settings → Secrets → add HF_TOKEN (read access to that model repo).\n"
            "  • Or change the repo to public / accept the license on the model card while logged in.\n"
            "  • Without a successful download, POLYGUARD falls back to heuristics / ranker (no local GRPO weights).\n"
        )
        print(f"{hint}  • Raw error: {err}\n", file=sys.stderr, flush=True)
        raise SystemExit(1) from exc


def _load_manifest(manifest_src: Path) -> dict[str, Any]:
    """Parse the bundle manifest; exit with ``invalid_bundle_manifest:<path>`` unless it is a JSON object."""
    try:
        manifest = _read_json(manifest_src)
    except ValueError as exc:  # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError
        raise SystemExit(f"invalid_bundle_manifest:{manifest_src}") from exc
    if not isinstance(manifest, dict):
        raise SystemExit(f"invalid_bundle_manifest:{manifest_src}")
    return manifest


def _install_checkpoints(ckpt_src: Path, active_dir: Path) -> None:
    """Copy the required artifact directories from ``ckpt_src`` into ``active_dir``.

    Every required directory is checked before any is replaced, so a bundle
    that lacks one of them leaves the current active checkpoints untouched.
    """
    names = ("grpo_adapter", "merged", "sft_adapter")
    missing = [ckpt_src / name for name in names if not (ckpt_src / name).is_dir()]
    if missing:
        raise SystemExit(f"missing_artifact_dir:{missing[0]}")

    active_dir.mkdir(parents=True, exist_ok=True)
    for name in names:
        print(f"Installing checkpoints/active/{name} …", flush=True)
        _replace_tree(ckpt_src / name, active_dir / name)


def main() -> None:
    """Download the configured bundle from the Hub and install it as the active model.

    Steps:
      1. Download ``<bundle-path>/**`` into the snapshot directory.
      2. Validate the bundle layout and parse its manifest.
      3. Replace ``checkpoints/active/{grpo_adapter,merged,sft_adapter}``.
      4. Write ``checkpoints/active/active_model_manifest.json``, optionally
         stamping ``activated_at_utc``.
      5. Unless ``--no-reports``, copy the bundle's ``reports/`` (if present)
         to ``outputs/reports/active_model``.
      6. Mirror the manifest to ``docs/results/`` and print next steps.

    Validation (step 2) runs before anything under ``checkpoints/active/`` is
    replaced, so a malformed bundle does not leave new weights paired with a
    stale manifest.

    Raises:
        SystemExit: on a missing dependency, a Hub error, an empty
            ``--bundle-path``, a bundle lacking required files, or a manifest
            that is not a JSON object.
    """
    args = parse_args()

    # Normalize once: a leading "/" would otherwise make ``snap_root / path``
    # discard snap_root entirely (pathlib treats it as an absolute path).
    bundle_path = args.bundle_path.strip("/")
    if not bundle_path:
        raise SystemExit(f"invalid_bundle_path:{args.bundle_path!r}")

    bundle_tail = bundle_path.split("/")[-1]
    snap_root = (
        Path(args.local_snapshot_dir).expanduser().resolve()
        if args.local_snapshot_dir
        else ROOT / "checkpoints" / ".hf_bundles" / bundle_tail
    )
    _download_bundle(args, bundle_path, snap_root)

    bundle_root = snap_root / bundle_path
    ckpt_src = bundle_root / "checkpoints"
    manifest_src = bundle_root / "manifests" / "active_model_manifest.json"
    if not ckpt_src.is_dir():
        raise SystemExit(f"missing_bundle_checkpoints:{ckpt_src}")
    if not manifest_src.is_file():
        raise SystemExit(f"missing_bundle_manifest:{manifest_src}")
    manifest = _load_manifest(manifest_src)

    active_dir = ROOT / "checkpoints" / "active"
    _install_checkpoints(ckpt_src, active_dir)

    if args.touch_manifest_time:
        manifest["activated_at_utc"] = datetime.now(timezone.utc).isoformat()
    active_manifest = active_dir / "active_model_manifest.json"
    active_manifest.write_text(json.dumps(manifest, ensure_ascii=True, indent=2), encoding="utf-8")
    print(f"Wrote {active_manifest.relative_to(ROOT)}", flush=True)

    if not args.no_reports:
        rep_src = bundle_root / "reports"
        rep_dest = ROOT / "outputs" / "reports" / "active_model"
        if rep_src.is_dir():
            print(f"Copying reports → {rep_dest.relative_to(ROOT)} …", flush=True)
            _replace_tree(rep_src, rep_dest)

    docs_mirror = ROOT / "docs" / "results" / "active_model_manifest.json"
    docs_mirror.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(active_manifest, docs_mirror)
    print(f"Mirrored manifest to {docs_mirror.relative_to(ROOT)}", flush=True)

    print(
        "\nNext: set in .env (see .env.example):\n"
        "  POLYGUARD_ENABLE_ACTIVE_MODEL=true\n"
        "  POLYGUARD_HF_MODEL=Qwen/Qwen2.5-0.5B-Instruct\n"
        "Prefer the trained Transformers checkpoint but keep Ollama as fallback:\n"
        "  POLYGUARD_PROVIDER_PREFERENCE=transformers,ollama\n"
        "Or disable Ollama entirely:\n"
        "  POLYGUARD_ENABLE_OLLAMA=false\n"
        "Then restart the API / env services.\n",
        flush=True,
    )


if __name__ == "__main__":
    # Turn Ctrl-C into a quiet exit with status 130 (the common shell
    # convention for SIGINT, 128 + 2) rather than a KeyboardInterrupt traceback.
    try:
        main()
    except KeyboardInterrupt:
        raise SystemExit(130)