# polyguard-openenv / app/models/policy/provider_runtime.py
"""LLM provider runtime with Transformers-first fallback order.
The runtime is intentionally conservative: if an LLM backend is unavailable or
errors, selection falls back to deterministic local ranking.
"""
from __future__ import annotations
from dataclasses import dataclass
import json
import os
import shutil
import subprocess
import time
from typing import Any
from app.common.types import CandidateAction
from app.models.policy.safety_ranker import rank_candidates
@dataclass(slots=True)
class ProviderSelection:
    """Result of a provider choosing one candidate action.

    Carries provenance (which backend decided), the chosen candidate id,
    a human-readable rationale, and timing for observability.
    """

    # Backend that made the choice: "ollama", "transformers", or
    # "heuristic_fallback" (set by PolicyProviderRouter's hard fallback).
    provider: str
    # Identifier of the selected candidate action.
    candidate_id: str
    # Short human-readable explanation of why this candidate was picked.
    rationale: str
    # Wall-clock selection latency in milliseconds.
    latency_ms: float
    # Raw backend output (e.g. Ollama stdout); empty for local/deterministic paths.
    raw_output: str = ""
class OllamaProvider:
    """Candidate selector backed by a local Ollama CLI model.

    Every failure mode (binary missing, disabled by env, pull failure,
    timeout, malformed model output) degrades to ``None`` so the caller
    can fall back to deterministic local ranking.
    """

    name = "ollama"

    def __init__(self, model_name: str) -> None:
        self.model_name = model_name

    def is_available(self) -> bool:
        """Return True only when Ollama is explicitly enabled AND on PATH."""
        if os.getenv("POLYGUARD_ENABLE_OLLAMA", "false").lower() not in {"1", "true", "yes", "on"}:
            return False
        return shutil.which("ollama") is not None

    def ensure_model(self) -> bool:
        """Best-effort ``ollama pull`` of the configured model.

        Returns False when the provider is unavailable or the pull raised
        (e.g. timed out). A non-zero pull exit status is tolerated
        (``check=False``) because the model may already exist locally.
        """
        if not self.is_available():
            return False
        if os.getenv("POLYGUARD_OLLAMA_AUTO_PULL", "true").lower() not in {"1", "true", "yes", "on"}:
            return True
        try:
            subprocess.run(
                ["ollama", "pull", self.model_name],
                check=False,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                timeout=90,
            )
            return True
        except Exception:
            return False

    @staticmethod
    def _timeout_seconds() -> float:
        """Parse POLYGUARD_PROVIDER_TIMEOUT_SECONDS, tolerating bad values.

        Previously a malformed env value raised ValueError and crashed
        selection instead of falling back; default to 7 seconds instead.
        """
        try:
            return float(os.getenv("POLYGUARD_PROVIDER_TIMEOUT_SECONDS", "7.0"))
        except (TypeError, ValueError):
            return 7.0

    def select(self, candidates: list[CandidateAction], prompt: dict[str, Any]) -> ProviderSelection | None:
        """Ask the Ollama model to pick one candidate; ``None`` on any failure.

        The model is instructed to emit JSON with ``candidate_id`` and
        ``rationale``; the returned id must match a supplied candidate or
        the response is discarded.
        """
        if not self.is_available() or not candidates:
            return None
        # Pull result intentionally ignored: the model may already be local.
        self.ensure_model()
        deadline_seconds = self._timeout_seconds()
        # Send only the fields the model needs, keeping the prompt compact.
        compact_candidates = [
            {
                "candidate_id": c.candidate_id,
                "mode": c.mode.value,
                "action_type": c.action_type.value,
                "estimated_safety_delta": c.estimated_safety_delta,
                "uncertainty_score": c.uncertainty_score,
                "legality_precheck": c.legality_precheck,
            }
            for c in candidates
        ]
        request = {
            "instruction": "Return only JSON with fields candidate_id and rationale.",
            "context": prompt,
            "candidates": compact_candidates,
        }
        start = time.monotonic()
        try:
            proc = subprocess.run(
                ["ollama", "run", self.model_name, json.dumps(request, ensure_ascii=True)],
                check=False,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                timeout=deadline_seconds,
            )
            elapsed_ms = (time.monotonic() - start) * 1000.0
            raw = (proc.stdout or "").strip()
            if not raw:
                return None
            data = json.loads(raw)
            candidate_id = str(data.get("candidate_id", "")).strip()
            if not candidate_id:
                return None
            # Reject hallucinated ids that don't match a real candidate.
            if candidate_id not in {c.candidate_id for c in candidates}:
                return None
            rationale = str(data.get("rationale", "Ollama provider selection.")).strip() or "Ollama provider selection."
            return ProviderSelection(
                provider=self.name,
                candidate_id=candidate_id,
                rationale=rationale,
                latency_ms=elapsed_ms,
                raw_output=raw,
            )
        except Exception:
            # Conservative by design: any backend error means "no selection".
            return None
class TransformersProvider:
    """Fallback provider gated on the ``transformers`` package being importable.

    Selection itself is deterministic and cheap: it defers to the local
    safety ranker rather than loading any model weights.
    """

    name = "transformers"

    def __init__(self, model_name: str) -> None:
        self.model_name = model_name

    def is_available(self) -> bool:
        """Return True when ``transformers`` can be imported."""
        try:
            import transformers  # noqa: F401
        except Exception:
            return False
        return True

    def select(self, candidates: list[CandidateAction], prompt: dict[str, Any]) -> ProviderSelection | None:
        """Pick the top locally-ranked candidate, or ``None`` when unavailable."""
        _ = prompt  # Context is unused on this deterministic path.
        if not self.is_available() or not candidates:
            return None
        started = time.monotonic()
        best = rank_candidates(candidates)[0]
        elapsed_ms = (time.monotonic() - started) * 1000.0
        return ProviderSelection(
            provider=self.name,
            candidate_id=best.candidate_id,
            rationale=f"Transformers fallback selected {best.candidate_id} via local ranker.",
            latency_ms=elapsed_ms,
        )
class PolicyProviderRouter:
    """Routes candidate selection across LLM providers with a safe fallback.

    Providers are tried in preference order; any provider that is
    unavailable or errors yields to the next, and the deterministic local
    ranker is the final authority.
    """

    def __init__(self, ollama_model: str = "qwen2.5:1.5b-instruct", hf_model: str = "Qwen/Qwen2.5-0.5B-Instruct") -> None:
        self.ollama = OllamaProvider(ollama_model)
        self.transformers = TransformersProvider(hf_model)

    def select_candidate(
        self,
        candidates: list[CandidateAction],
        prompt: dict[str, Any],
        provider_preference: tuple[str, ...] = ("transformers",),
    ) -> ProviderSelection:
        """Return a selection from the first preferred provider that answers.

        Args:
            candidates: Pool of candidate actions; must be non-empty.
            prompt: Contextual payload forwarded to LLM providers.
            provider_preference: Provider names tried in order; unknown
                names are skipped rather than treated as errors.

        Raises:
            ValueError: If ``candidates`` is empty. (Previously this
                surfaced as an opaque IndexError from the fallback ranker.)
        """
        if not candidates:
            raise ValueError("select_candidate requires at least one candidate.")
        # Name -> provider dispatch keeps the loop flat and extensible.
        providers = {"ollama": self.ollama, "transformers": self.transformers}
        for name in tuple(provider_preference) or ("transformers",):
            provider = providers.get(name)
            if provider is None:
                continue  # Unknown provider names are skipped, not fatal.
            picked = provider.select(candidates, prompt)
            if picked is not None:
                return picked
        # Deterministic hard fallback: local safety ranker.
        fallback = rank_candidates(candidates)[0]
        return ProviderSelection(
            provider="heuristic_fallback",
            candidate_id=fallback.candidate_id,
            rationale="Fallback ranker selected top legal/safety candidate.",
            latency_ms=0.0,
        )