# Source: polyguard-openenv-workbench / polyguard-rl / scripts/install_hf_active_bundle.py
# Author: TheJackBright
# Commit c296d62: "Deploy GitHub root master to Space"
#!/usr/bin/env python3
"""Download a PolyGuard usable model bundle from the Hugging Face Hub and activate it.
The default bundle matches the public artifact folder:
TheJackBright/polyguard-openenv-training-full-artifacts/main/
usable_model_bundles/local-qwen-0-5b-active-smoke/
This copies checkpoints into ``checkpoints/active/`` and installs ``active_model_manifest.json``,
which ``app.models.policy.active_model`` reads when ``POLYGUARD_ENABLE_ACTIVE_MODEL`` is true.
Usage:
python scripts/install_hf_active_bundle.py
python scripts/install_hf_active_bundle.py --no-reports
"""
from __future__ import annotations
import argparse
import json
import shutil
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
# Repository root: the parent of the ``scripts/`` directory holding this file.
# Resolved so every derived path is absolute, independent of the current
# working directory.
ROOT = Path(__file__).resolve().parents[1]

# Hub model repository that publishes the artifact bundles, and the bundle
# folder inside it that is installed when ``--bundle-path`` is not given.
DEFAULT_REPO = "TheJackBright/polyguard-openenv-training-full-artifacts"
DEFAULT_BUNDLE = "usable_model_bundles/local-qwen-0-5b-active-smoke"
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the installer's command-line options.

    Args:
        argv: Arguments to parse, without the program name. ``None`` (the
            default, as used by ``main``) makes argparse read ``sys.argv[1:]``.

    Returns:
        Namespace with ``repo_id``, ``revision``, ``bundle_path``,
        ``local_snapshot_dir``, ``no_reports`` and ``touch_manifest_time``.
    """
    p = argparse.ArgumentParser(
        description=__doc__,
        # The module docstring contains line-oriented "Usage:" examples; the
        # default formatter would re-wrap them into a single paragraph.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument("--repo-id", default=DEFAULT_REPO)
    p.add_argument("--revision", default="main")
    p.add_argument("--bundle-path", default=DEFAULT_BUNDLE, help="Path inside the repo to the bundle root.")
    p.add_argument(
        "--local-snapshot-dir",
        default="",
        help="Optional directory for snapshot_download (default: checkpoints/.hf_bundles/<bundle tail>).",
    )
    p.add_argument("--no-reports", action="store_true", help="Skip copying reports into outputs/reports/active_model.")
    p.add_argument(
        "--touch-manifest-time",
        action="store_true",
        help="Set activated_at_utc in the installed manifest to now (for bookkeeping only).",
    )
    return p.parse_args(argv)
def _read_json(path: Path) -> dict[str, Any]:
    """Decode the UTF-8 JSON file at ``path`` and return the parsed value.

    NOTE(review): the return annotation says ``dict`` but the top-level JSON
    type is not checked here; callers that need an object must verify it.
    """
    with path.open(encoding="utf-8") as handle:
        return json.load(handle)
def _remove_path(path: Path) -> None:
    """Delete ``path`` whether it is a file, a symlink or a directory tree; no-op if absent."""
    if path.is_symlink() or path.is_file():
        # Unlink symlinks rather than following them into their target.
        path.unlink()
    elif path.exists():
        shutil.rmtree(path)


def _replace_tree(src: Path, dest: Path) -> None:
    """Replace ``dest`` with a fresh copy of the directory tree at ``src``.

    The copy is written to a hidden staging directory next to ``dest`` first
    and is only moved into place after it completed. A failure part-way
    through (missing source, disk full, Ctrl-C) therefore leaves the previous
    ``dest`` intact instead of deleting it and leaving a half-copied tree.

    Args:
        src: Existing directory to copy.
        dest: Target path. An existing file, symlink or directory there is
            replaced. Missing parent directories are created.

    Raises:
        OSError / shutil.Error: if the copy fails; ``dest`` is left untouched
            and the staging directory is removed.
    """
    dest.parent.mkdir(parents=True, exist_ok=True)
    staging = dest.with_name(f".{dest.name}.staging")
    # A previous run may have been killed mid-copy; start from a clean slate.
    _remove_path(staging)
    try:
        shutil.copytree(src, staging)
    except BaseException:
        shutil.rmtree(staging, ignore_errors=True)
        raise
    _remove_path(dest)
    staging.rename(dest)
def _download_snapshot(repo_id: str, revision: str, local_dir: Path, allow_pattern: str) -> None:
    """Download the files of ``repo_id@revision`` matching ``allow_pattern`` into ``local_dir``.

    Raises:
        SystemExit: ``huggingface_hub`` is not importable, or the Hub call
            failed (exit status 1 after printing a troubleshooting hint).
    """
    try:
        from huggingface_hub import snapshot_download
    except ImportError as exc:
        raise SystemExit("install huggingface_hub (pip install huggingface-hub)") from exc
    print(f"Downloading snapshot of {repo_id}@{revision} (pattern {allow_pattern}) …", flush=True)
    try:
        snapshot_download(
            repo_id=repo_id,
            repo_type="model",
            revision=revision,
            local_dir=str(local_dir),
            allow_patterns=[allow_pattern],
        )
    except Exception as exc:  # CLI boundary: any Hub/network failure is fatal, with a hint
        err = f"{type(exc).__name__}: {exc}"
        hint = (
            "\n[install_hf_active_bundle] Hub returned an error (401/404 often means the artifact repo is private or gated).\n"
            " • Hugging Face Space: Space Settings → Secrets → add HF_TOKEN (read access to that model repo).\n"
            " • Or change the repo to public / accept the license on the model card while logged in.\n"
            " • Without a successful download, POLYGUARD falls back to heuristics / ranker (no local GRPO weights).\n"
        )
        print(f"{hint} • Raw error: {err}\n", flush=True)
        raise SystemExit(1) from exc


def _print_next_steps() -> None:
    """Print the ``.env`` settings that make the services use the installed checkpoint."""
    print(
        "\nNext: set in .env (see .env.example):\n"
        " POLYGUARD_ENABLE_ACTIVE_MODEL=true\n"
        " POLYGUARD_HF_MODEL=Qwen/Qwen2.5-0.5B-Instruct\n"
        "Prefer the trained Transformers checkpoint but keep Ollama as fallback:\n"
        " POLYGUARD_PROVIDER_PREFERENCE=transformers,ollama\n"
        "Or disable Ollama entirely:\n"
        " POLYGUARD_ENABLE_OLLAMA=false\n"
        "Then restart the API / env services.\n",
        flush=True,
    )


def main() -> None:
    """Download a model bundle from the Hub and install it under ``checkpoints/active/``.

    Steps:
      1. Download only the files under ``--bundle-path`` into the snapshot
         directory (``--local-snapshot-dir`` or
         ``checkpoints/.hf_bundles/<bundle tail>``).
      2. Check that the bundle contains ``checkpoints/{grpo_adapter,merged,sft_adapter}``
         and a JSON-object ``manifests/active_model_manifest.json``. This happens
         before ``checkpoints/active/`` is modified, so an incomplete bundle
         cannot leave a half-installed model.
      3. Replace each artifact directory, write the manifest, optionally
         copy ``reports/``, mirror the manifest into ``docs/results/`` and
         print the ``.env`` settings to enable it.

    Raises:
        SystemExit: on an empty bundle path, missing ``huggingface_hub``, a
            failed download, or an incomplete/invalid bundle.
    """
    args = parse_args()

    # Normalise once and use the same value for the download pattern, the
    # snapshot folder name and the local lookup. Joining the raw value breaks
    # on a leading "/": ``snap_root / "/x"`` is the absolute path "/x", not a
    # child of ``snap_root``.
    bundle_rel = args.bundle_path.strip("/")
    if not bundle_rel:
        raise SystemExit(f"empty_bundle_path:{args.bundle_path!r}")
    bundle_tail = bundle_rel.split("/")[-1]
    snap_root = (
        Path(args.local_snapshot_dir).expanduser().resolve()
        if args.local_snapshot_dir
        else ROOT / "checkpoints" / ".hf_bundles" / bundle_tail
    )
    _download_snapshot(args.repo_id, args.revision, snap_root, f"{bundle_rel}/**")

    bundle_root = snap_root / bundle_rel
    ckpt_src = bundle_root / "checkpoints"
    manifest_src = bundle_root / "manifests" / "active_model_manifest.json"
    if not ckpt_src.is_dir():
        raise SystemExit(f"missing_bundle_checkpoints:{ckpt_src}")
    if not manifest_src.is_file():
        raise SystemExit(f"missing_bundle_manifest:{manifest_src}")

    # Validate every input before touching checkpoints/active/: checking inside
    # the install loop would replace the first artifacts and then abort,
    # leaving new weights next to the previous manifest.
    artifact_names = ("grpo_adapter", "merged", "sft_adapter")
    for name in artifact_names:
        sub = ckpt_src / name
        if not sub.is_dir():
            raise SystemExit(f"missing_artifact_dir:{sub}")
    try:
        manifest = _read_json(manifest_src)
    except ValueError as exc:  # json.JSONDecodeError and UnicodeDecodeError subclass ValueError
        raise SystemExit(f"invalid_bundle_manifest:{manifest_src}") from exc
    if not isinstance(manifest, dict):
        raise SystemExit(f"invalid_bundle_manifest:{manifest_src}")

    active_dir = ROOT / "checkpoints" / "active"
    active_dir.mkdir(parents=True, exist_ok=True)
    for name in artifact_names:
        print(f"Installing checkpoints/active/{name} …", flush=True)
        _replace_tree(ckpt_src / name, active_dir / name)

    if args.touch_manifest_time:
        manifest["activated_at_utc"] = datetime.now(timezone.utc).isoformat()
    active_manifest = active_dir / "active_model_manifest.json"
    active_manifest.write_text(json.dumps(manifest, ensure_ascii=True, indent=2) + "\n", encoding="utf-8")
    print(f"Wrote {active_manifest.relative_to(ROOT)}", flush=True)

    if not args.no_reports:
        rep_src = bundle_root / "reports"
        rep_dest = ROOT / "outputs" / "reports" / "active_model"
        if rep_src.is_dir():
            print(f"Copying reports → {rep_dest.relative_to(ROOT)} …", flush=True)
            _replace_tree(rep_src, rep_dest)

    docs_mirror = ROOT / "docs" / "results" / "active_model_manifest.json"
    docs_mirror.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(active_manifest, docs_mirror)
    print(f"Mirrored manifest to {docs_mirror.relative_to(ROOT)}", flush=True)

    _print_next_steps()
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl-C: exit with status 130 (128 + SIGINT, the code shells report
        # for an interrupted command) instead of printing a traceback.
        sys.exit(130)