| |
| """Download a PolyGuard usable model bundle from the Hugging Face Hub and activate it. |
| |
| The default bundle matches the public artifact folder: |
| TheJackBright/polyguard-openenv-training-full-artifacts/main/ |
| usable_model_bundles/local-qwen-0-5b-active-smoke/ |
| |
| This copies checkpoints into ``checkpoints/active/`` and installs ``active_model_manifest.json``, |
| which ``app.models.policy.active_model`` reads when ``POLYGUARD_ENABLE_ACTIVE_MODEL`` is true. |
| |
| Usage: |
| python scripts/install_hf_active_bundle.py |
| python scripts/install_hf_active_bundle.py --no-reports |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import shutil |
| import sys |
| from datetime import datetime, timezone |
| from pathlib import Path |
| from typing import Any |
|
|
|
|
# Repository root: two directory levels above this file (scripts/<this file> -> root).
ROOT = Path(__file__).resolve().parent.parent
# Hugging Face model repository that the bundle is downloaded from by default.
DEFAULT_REPO = "TheJackBright/polyguard-openenv-training-full-artifacts"
# Bundle folder inside the repository used when --bundle-path is not given.
DEFAULT_BUNDLE = "usable_model_bundles/local-qwen-0-5b-active-smoke"
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Build the command-line parser for this script and parse ``argv``.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            command-line entry point relies on. Passing an explicit list lets
            other callers and tests drive the parser directly.

    Returns:
        Namespace with ``repo_id``, ``revision``, ``bundle_path``,
        ``local_snapshot_dir``, ``no_reports`` and ``touch_manifest_time``.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # Keep the module docstring's line breaks (repo path, usage examples)
        # in --help output instead of letting argparse re-wrap them into one
        # run-on paragraph.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--repo-id", default=DEFAULT_REPO)
    parser.add_argument("--revision", default="main")
    parser.add_argument("--bundle-path", default=DEFAULT_BUNDLE, help="Path inside the repo to the bundle root.")
    parser.add_argument(
        "--local-snapshot-dir",
        default="",
        help="Optional directory for snapshot_download (default: checkpoints/.hf_bundles/<bundle tail>).",
    )
    parser.add_argument("--no-reports", action="store_true", help="Skip copying reports into outputs/reports/active_model.")
    parser.add_argument(
        "--touch-manifest-time",
        action="store_true",
        help="Set activated_at_utc in the installed manifest to now (for bookkeeping only).",
    )
    return parser.parse_args(argv)
|
|
|
|
def _read_json(path: Path) -> dict[str, Any]:
    """Decode the UTF-8 JSON document stored at ``path``.

    The decoded value is returned as-is. No schema or type check happens here,
    so a file whose top level is not an object is passed straight through.
    """
    with path.open("r", encoding="utf-8") as handle:
        return json.load(handle)
|
|
|
|
def _remove_path(path: Path) -> None:
    """Delete whatever is at ``path``: a file, a symlink (not its target), or a directory tree."""
    if path.is_symlink() or path.is_file():
        path.unlink()
    elif path.exists():
        shutil.rmtree(path)


def _replace_tree(src: Path, dest: Path) -> None:
    """Replace ``dest`` with a recursive copy of the directory ``src``.

    Whatever currently sits at ``dest`` (a directory, a file, or a symlink) is
    removed. The copy is first made into a hidden staging directory next to
    ``dest`` and only swapped in once it is complete. If copying fails part-way
    (missing ``src``, disk full, permission error), the previous ``dest`` is
    still intact instead of already deleted.

    ``shutil.copytree`` is called with its defaults, so symlinks inside ``src``
    are followed and their targets' contents are copied.

    Raises:
        OSError: ``src`` cannot be read, or the copy or final rename fails.
    """
    dest.parent.mkdir(parents=True, exist_ok=True)
    # Staging lives in the same parent directory as dest, so the final rename
    # stays on one filesystem and does not copy the data a second time.
    staging = dest.with_name(f".{dest.name}.installing")
    _remove_path(staging)  # leftovers from an earlier interrupted run
    try:
        shutil.copytree(src, staging)
    except BaseException:
        # Don't leave a half-written staging tree behind; dest is untouched.
        shutil.rmtree(staging, ignore_errors=True)
        raise
    _remove_path(dest)
    staging.rename(dest)
|
|
|
|
def main() -> None:
    """Download the configured bundle snapshot and install it as the active model.

    Steps:

    1. ``snapshot_download`` only the files matching ``<bundle-path>/**`` from
       ``--repo-id`` at ``--revision`` into the snapshot directory.
    2. Check that the bundle contains ``checkpoints/grpo_adapter``,
       ``checkpoints/merged`` and ``checkpoints/sft_adapter``, plus a
       ``manifests/active_model_manifest.json`` that parses to a JSON object.
       All of this happens before anything under ``checkpoints/active/`` is
       modified, so a broken bundle cannot leave a half-replaced install.
    3. Replace those three directories under ``checkpoints/active/`` and write
       ``checkpoints/active/active_model_manifest.json``.
    4. Unless ``--no-reports`` is given, replace ``outputs/reports/active_model``
       with the bundle's ``reports/`` directory (skipped if the bundle has none).
    5. Copy the installed manifest to ``docs/results/active_model_manifest.json``.

    Raises:
        SystemExit: ``huggingface_hub`` is not installed; ``--bundle-path`` is
            empty; the download fails (exit status 1, after printing a hint);
            or the bundle is missing a required file or directory, or its
            manifest is invalid.
    """
    args = parse_args()
    try:
        from huggingface_hub import snapshot_download
    except ImportError as exc:
        raise SystemExit("install huggingface_hub (pip install huggingface-hub)") from exc

    # Strip surrounding slashes once and reuse the result everywhere. An
    # un-stripped leading "/" would turn ``snap_root / bundle_path`` into an
    # absolute path (pathlib discards the left operand), pointing at the
    # filesystem root instead of into the snapshot.
    bundle_rel = args.bundle_path.strip("/")
    if not bundle_rel:
        raise SystemExit(f"invalid_bundle_path:{args.bundle_path!r}")
    bundle_tail = bundle_rel.split("/")[-1]
    snap_root = (
        Path(args.local_snapshot_dir).expanduser().resolve()
        if args.local_snapshot_dir
        else ROOT / "checkpoints" / ".hf_bundles" / bundle_tail
    )
    allow = f"{bundle_rel}/**"

    print(f"Downloading snapshot of {args.repo_id}@{args.revision} (pattern {allow}) …", flush=True)
    try:
        snapshot_download(
            repo_id=args.repo_id,
            repo_type="model",
            revision=args.revision,
            local_dir=str(snap_root),
            allow_patterns=[allow],
        )
    except Exception as exc:  # CLI boundary: report any Hub failure with a hint, exit 1
        err = f"{type(exc).__name__}: {exc}"
        hint = (
            "\n[install_hf_active_bundle] Hub returned an error (401/404 often means the artifact repo is private or gated).\n"
            " • Hugging Face Space: Space Settings → Secrets → add HF_TOKEN (read access to that model repo).\n"
            " • Or change the repo to public / accept the license on the model card while logged in.\n"
            " • Without a successful download, POLYGUARD falls back to heuristics / ranker (no local GRPO weights).\n"
        )
        print(f"{hint} • Raw error: {err}\n", flush=True)
        raise SystemExit(1) from exc

    bundle_root = snap_root / bundle_rel
    ckpt_src = bundle_root / "checkpoints"
    manifest_src = bundle_root / "manifests" / "active_model_manifest.json"
    if not ckpt_src.is_dir():
        raise SystemExit(f"missing_bundle_checkpoints:{ckpt_src}")
    if not manifest_src.is_file():
        raise SystemExit(f"missing_bundle_manifest:{manifest_src}")

    # Validate every artifact directory up front. Checking inside the install
    # loop would already have overwritten earlier directories when a later
    # one turns out to be missing.
    artifact_dirs = {name: ckpt_src / name for name in ("grpo_adapter", "merged", "sft_adapter")}
    for sub in artifact_dirs.values():
        if not sub.is_dir():
            raise SystemExit(f"missing_artifact_dir:{sub}")

    # Parse the manifest before touching checkpoints/active for the same reason.
    try:
        manifest = _read_json(manifest_src)
    except ValueError as exc:  # JSONDecodeError and UnicodeDecodeError are both ValueErrors
        raise SystemExit(f"invalid_bundle_manifest:{manifest_src}:{exc}") from exc
    if not isinstance(manifest, dict):
        raise SystemExit(f"invalid_bundle_manifest:{manifest_src}:top-level value is not a JSON object")
    if args.touch_manifest_time:
        manifest["activated_at_utc"] = datetime.now(timezone.utc).isoformat()

    active_dir = ROOT / "checkpoints" / "active"
    active_dir.mkdir(parents=True, exist_ok=True)
    for name, sub in artifact_dirs.items():
        print(f"Installing checkpoints/active/{name} …", flush=True)
        _replace_tree(sub, active_dir / name)

    active_manifest = active_dir / "active_model_manifest.json"
    active_manifest.write_text(json.dumps(manifest, ensure_ascii=True, indent=2), encoding="utf-8")
    print(f"Wrote {active_manifest.relative_to(ROOT)}", flush=True)

    if not args.no_reports:
        rep_src = bundle_root / "reports"
        rep_dest = ROOT / "outputs" / "reports" / "active_model"
        if rep_src.is_dir():
            print(f"Copying reports → {rep_dest.relative_to(ROOT)} …", flush=True)
            _replace_tree(rep_src, rep_dest)

    docs_mirror = ROOT / "docs" / "results" / "active_model_manifest.json"
    docs_mirror.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(active_manifest, docs_mirror)
    print(f"Mirrored manifest to {docs_mirror.relative_to(ROOT)}", flush=True)

    print(
        "\nNext: set in .env (see .env.example):\n"
        " POLYGUARD_ENABLE_ACTIVE_MODEL=true\n"
        " POLYGUARD_HF_MODEL=Qwen/Qwen2.5-0.5B-Instruct\n"
        "Prefer the trained Transformers checkpoint but keep Ollama as fallback:\n"
        " POLYGUARD_PROVIDER_PREFERENCE=transformers,ollama\n"
        "Or disable Ollama entirely:\n"
        " POLYGUARD_ENABLE_OLLAMA=false\n"
        "Then restart the API / env services.\n",
        flush=True,
    )
|
|
|
|
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl-C: exit with the shell-conventional status 130 (128 + SIGINT)
        # rather than letting the interpreter print a traceback.
        interrupted_status = 130
        sys.exit(interrupted_status)
|
|