#!/usr/bin/env python3
"""Download a PolyGuard usable model bundle from the Hugging Face Hub and activate it.

The default bundle matches the public artifact folder:

    TheJackBright/polyguard-openenv-training-full-artifacts/main/
        usable_model_bundles/local-qwen-0-5b-active-smoke/

This copies checkpoints into ``checkpoints/active/`` and installs
``active_model_manifest.json``, which ``app.models.policy.active_model`` reads
when ``POLYGUARD_ENABLE_ACTIVE_MODEL`` is true.

Usage:
    python scripts/install_hf_active_bundle.py
    python scripts/install_hf_active_bundle.py --no-reports
"""

from __future__ import annotations

import argparse
import json
import shutil
import sys
import tempfile
from collections.abc import Sequence
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

ROOT = Path(__file__).resolve().parents[1]
DEFAULT_REPO = "TheJackBright/polyguard-openenv-training-full-artifacts"
DEFAULT_BUNDLE = "usable_model_bundles/local-qwen-0-5b-active-smoke"

# Checkpoint sub-directories every bundle must ship; each is swapped into checkpoints/active/.
ARTIFACT_DIRS = ("grpo_adapter", "merged", "sft_adapter")


def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
    """Parse command-line options.

    Args:
        argv: Arguments to parse; ``None`` (the default) means ``sys.argv[1:]``.

    Returns:
        The parsed namespace.
    """
    p = argparse.ArgumentParser(
        description=__doc__,
        # Keep the module docstring's line breaks (repo paths, usage lines) in --help.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument("--repo-id", default=DEFAULT_REPO)
    p.add_argument("--revision", default="main")
    p.add_argument("--bundle-path", default=DEFAULT_BUNDLE, help="Path inside the repo to the bundle root.")
    p.add_argument(
        "--local-snapshot-dir",
        default="",
        help="Optional directory for snapshot_download (default: checkpoints/.hf_bundles/<bundle name>/).",
    )
    p.add_argument("--no-reports", action="store_true", help="Skip copying reports into outputs/reports/active_model.")
    p.add_argument(
        "--touch-manifest-time",
        action="store_true",
        help="Set activated_at_utc in the installed manifest to now (for bookkeeping only).",
    )
    return p.parse_args(argv)


def _read_json(path: Path) -> dict[str, Any]:
    """Load *path* as UTF-8 JSON."""
    return json.loads(path.read_text(encoding="utf-8"))


def _rel(path: Path) -> Path:
    """Return *path* relative to ROOT for display, or unchanged when it lies outside ROOT."""
    try:
        return path.relative_to(ROOT)
    except ValueError:
        return path


def _normalize_bundle_path(raw: str) -> str:
    """Return *raw* as a clean repo-relative POSIX path such as ``a/b``.

    Empty and ``.`` segments (leading/trailing/duplicate slashes) are dropped so the same
    string can serve both as the Hub ``allow_patterns`` prefix and as a sub-path of the
    local snapshot directory.

    Raises:
        SystemExit: if nothing remains, or a ``..`` segment could escape the snapshot dir.
    """
    parts = [part for part in raw.split("/") if part not in ("", ".")]
    if not parts or ".." in parts:
        raise SystemExit(f"invalid_bundle_path:{raw!r}")
    return "/".join(parts)


def _replace_tree(src: Path, dest: Path) -> None:
    """Replace *dest* (directory, file, or symlink) with a copy of the *src* directory tree.

    The copy is staged in a temporary sibling of *dest* and swapped in with renames, so a
    failed copy (disk full, permission error, missing *src*) leaves the existing *dest*
    untouched instead of deleted.
    """
    dest.parent.mkdir(parents=True, exist_ok=True)
    # Staging lives next to dest so the final renames stay on the same filesystem.
    staging = Path(tempfile.mkdtemp(prefix=f".{dest.name}.staging-", dir=dest.parent))
    try:
        staged = staging / dest.name
        shutil.copytree(src, staged)
        backup = staging / f"{dest.name}.old"
        # is_symlink() also catches dangling links, for which exists() is False.
        if dest.exists() or dest.is_symlink():
            dest.rename(backup)
        try:
            staged.rename(dest)
        except OSError:
            if backup.exists() or backup.is_symlink():
                backup.rename(dest)  # put the previous tree back before propagating
            raise
    finally:
        # Removes the old tree (now under staging) or any partial copy.
        shutil.rmtree(staging, ignore_errors=True)


def _download_snapshot(repo_id: str, revision: str, bundle_rel: str, snap_root: Path) -> None:
    """Fetch only ``<bundle_rel>/**`` of *repo_id*@*revision* into *snap_root*.

    Raises:
        SystemExit: if huggingface_hub is not installed, or (code 1) if the Hub download
            fails; in that case troubleshooting hints are printed to stderr first.
    """
    try:
        from huggingface_hub import snapshot_download
    except ImportError as exc:
        raise SystemExit("install huggingface_hub (pip install huggingface-hub)") from exc

    allow = f"{bundle_rel}/**"
    print(f"Downloading snapshot of {repo_id}@{revision} (pattern {allow}) …", flush=True)
    try:
        snapshot_download(
            repo_id=repo_id,
            repo_type="model",
            revision=revision,
            local_dir=str(snap_root),
            allow_patterns=[allow],
        )
    except Exception as exc:  # top-level CLI boundary: any Hub/network failure ends the run
        err = f"{type(exc).__name__}: {exc}"
        hint = (
            "\n[install_hf_active_bundle] Hub returned an error (401/404 often means the artifact repo is private or gated).\n"
            " • Hugging Face Space: Space Settings → Secrets → add HF_TOKEN (read access to that model repo).\n"
            " • Or change the repo to public / accept the license on the model card while logged in.\n"
            " • Without a successful download, POLYGUARD falls back to heuristics / ranker (no local GRPO weights).\n"
        )
        print(f"{hint} • Raw error: {err}\n", file=sys.stderr, flush=True)
        raise SystemExit(1) from exc


def _install_checkpoints(ckpt_src: Path, active_dir: Path) -> None:
    """Copy each ``ARTIFACT_DIRS`` entry from *ckpt_src* into *active_dir*, replacing old copies.

    Every source is checked before anything is replaced, so a bundle missing one artifact
    leaves the currently active checkpoints intact.

    Raises:
        SystemExit: ``missing_artifact_dir:<path>`` for the first absent artifact directory.
    """
    sources = [(name, ckpt_src / name) for name in ARTIFACT_DIRS]
    for _, sub in sources:
        if not sub.is_dir():
            raise SystemExit(f"missing_artifact_dir:{sub}")
    active_dir.mkdir(parents=True, exist_ok=True)
    for name, sub in sources:
        print(f"Installing checkpoints/active/{name} …", flush=True)
        _replace_tree(sub, active_dir / name)


def _write_active_manifest(manifest_src: Path, active_dir: Path, *, touch_time: bool = False) -> Path:
    """Install the bundle manifest as ``<active_dir>/active_model_manifest.json``.

    Args:
        manifest_src: The bundle's manifest JSON (must be a JSON object).
        active_dir: Destination directory (created if missing).
        touch_time: When true, set ``activated_at_utc`` to the current UTC time (ISO 8601).

    Returns:
        Path of the installed manifest.

    Raises:
        SystemExit: ``invalid_bundle_manifest:<path>`` if the JSON is not an object.
    """
    manifest = _read_json(manifest_src)
    if not isinstance(manifest, dict):
        raise SystemExit(f"invalid_bundle_manifest:{manifest_src}")
    if touch_time:
        manifest["activated_at_utc"] = datetime.now(timezone.utc).isoformat()
    active_dir.mkdir(parents=True, exist_ok=True)
    active_manifest = active_dir / "active_model_manifest.json"
    # Write-then-rename so a concurrent reader never sees a half-written manifest.
    tmp = active_manifest.with_suffix(".json.tmp")
    tmp.write_text(json.dumps(manifest, ensure_ascii=True, indent=2) + "\n", encoding="utf-8")
    tmp.replace(active_manifest)
    print(f"Wrote {_rel(active_manifest)}", flush=True)
    return active_manifest


def _copy_reports(bundle_root: Path, active_manifest: Path) -> None:
    """Copy the bundle's reports into outputs/reports/active_model and mirror the manifest into docs/results."""
    rep_src = bundle_root / "reports"
    rep_dest = ROOT / "outputs" / "reports" / "active_model"
    if rep_src.is_dir():
        print(f"Copying reports → {_rel(rep_dest)} …", flush=True)
        _replace_tree(rep_src, rep_dest)
    else:
        print(f"No reports directory in bundle ({rep_src}); skipping report copy.", flush=True)
    # NOTE(review): the docs mirror is skipped together with reports under --no-reports;
    # confirm that is intended.
    docs_mirror = ROOT / "docs" / "results" / "active_model_manifest.json"
    docs_mirror.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(active_manifest, docs_mirror)
    print(f"Mirrored manifest to {_rel(docs_mirror)}", flush=True)


def _print_next_steps() -> None:
    """Tell the operator which environment variables enable the freshly installed model."""
    print(
        "\nNext: set in .env (see .env.example):\n"
        " POLYGUARD_ENABLE_ACTIVE_MODEL=true\n"
        " POLYGUARD_HF_MODEL=Qwen/Qwen2.5-0.5B-Instruct\n"
        "Prefer the trained Transformers checkpoint but keep Ollama as fallback:\n"
        " POLYGUARD_PROVIDER_PREFERENCE=transformers,ollama\n"
        "Or disable Ollama entirely:\n"
        " POLYGUARD_ENABLE_OLLAMA=false\n"
        "Then restart the API / env services.\n",
        flush=True,
    )


def main(argv: Sequence[str] | None = None) -> None:
    """Download the configured bundle and install it as the active PolyGuard model.

    Args:
        argv: Optional argument list; ``None`` reads ``sys.argv``.

    Raises:
        SystemExit: on download failure or when the bundle layout is incomplete.
    """
    args = parse_args(argv)
    # Normalize once: the same value feeds allow_patterns and the local bundle path.
    # Joining a raw "/usable_model_bundles/..." would yield an absolute path outside snap_root.
    bundle_rel = _normalize_bundle_path(args.bundle_path)
    snap_root = (
        Path(args.local_snapshot_dir).expanduser().resolve()
        if args.local_snapshot_dir
        else ROOT / "checkpoints" / ".hf_bundles" / bundle_rel.rsplit("/", 1)[-1]
    )
    _download_snapshot(args.repo_id, args.revision, bundle_rel, snap_root)

    bundle_root = snap_root / bundle_rel
    ckpt_src = bundle_root / "checkpoints"
    manifest_src = bundle_root / "manifests" / "active_model_manifest.json"
    if not ckpt_src.is_dir():
        raise SystemExit(f"missing_bundle_checkpoints:{ckpt_src}")
    if not manifest_src.is_file():
        raise SystemExit(f"missing_bundle_manifest:{manifest_src}")

    active_dir = ROOT / "checkpoints" / "active"
    _install_checkpoints(ckpt_src, active_dir)
    # The manifest goes in after the checkpoints, so it never points at stale weights.
    active_manifest = _write_active_manifest(manifest_src, active_dir, touch_time=args.touch_manifest_time)
    if not args.no_reports:
        _copy_reports(bundle_root, active_manifest)
    _print_next_steps()


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        sys.exit(130)