"""LLM provider runtime with Transformers-first fallback order.
The runtime is intentionally conservative: if an LLM backend is unavailable or
errors, selection falls back to deterministic local ranking.
"""
from __future__ import annotations
from dataclasses import dataclass
import json
import os
import shutil
import subprocess
import time
from typing import Any
from app.common.types import CandidateAction
from app.models.policy.safety_ranker import rank_candidates
@dataclass(slots=True)
class ProviderSelection:
    """Outcome of one provider choosing a candidate action."""

    # Backend that produced the choice: "ollama", "transformers", or
    # "heuristic_fallback" (see PolicyProviderRouter.select_candidate).
    provider: str
    # candidate_id of the chosen CandidateAction; always one of the inputs.
    candidate_id: str
    # Human-readable justification for the selection.
    rationale: str
    # Wall-clock selection time in milliseconds (0.0 for the hard fallback).
    latency_ms: float
    # Unparsed backend stdout, kept for debugging; empty when no model ran.
    raw_output: str = ""
class OllamaProvider:
    """Provider that asks a local Ollama model to pick a candidate action.

    Disabled unless POLYGUARD_ENABLE_OLLAMA is truthy AND the ``ollama``
    binary is on PATH. Every failure mode returns None so the router can
    fall back to deterministic local ranking.
    """

    name = "ollama"

    def __init__(self, model_name: str) -> None:
        self.model_name = model_name

    def is_available(self) -> bool:
        """Return True when the backend is explicitly enabled and installed."""
        if os.getenv("POLYGUARD_ENABLE_OLLAMA", "false").lower() not in {"1", "true", "yes", "on"}:
            return False
        return shutil.which("ollama") is not None

    def ensure_model(self) -> bool:
        """Best-effort pull of the configured model; never raises."""
        if not self.is_available():
            return False
        if os.getenv("POLYGUARD_OLLAMA_AUTO_PULL", "true").lower() not in {"1", "true", "yes", "on"}:
            # Auto-pull disabled: assume the model is already present locally.
            return True
        try:
            subprocess.run(
                ["ollama", "pull", self.model_name],
                check=False,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                timeout=90,
            )
            return True
        except Exception:
            # Pull failures (timeouts included) are non-fatal: "ollama run"
            # may still succeed against a previously cached model.
            return False

    @staticmethod
    def _extract_json_object(raw: str) -> dict[str, Any] | None:
        """Recover a JSON object from raw model output.

        Models frequently wrap the requested JSON in prose; when the whole
        output is not valid JSON, retry on the outermost ``{...}`` span.
        Returns None if no JSON *object* can be recovered (lists and other
        JSON values are rejected).
        """
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError:
            start = raw.find("{")
            end = raw.rfind("}")
            if start < 0 or end <= start:
                return None
            try:
                parsed = json.loads(raw[start : end + 1])
            except json.JSONDecodeError:
                return None
        return parsed if isinstance(parsed, dict) else None

    def select(self, candidates: list[CandidateAction], prompt: dict[str, Any]) -> ProviderSelection | None:
        """Ask the model to choose among *candidates*; None on any failure.

        The model is given a compact JSON view of each candidate and must
        answer with a candidate_id drawn from the input set.
        """
        if not self.is_available() or not candidates:
            return None
        self.ensure_model()
        deadline_seconds = float(os.getenv("POLYGUARD_PROVIDER_TIMEOUT_SECONDS", "7.0"))
        compact_candidates = [
            {
                "candidate_id": c.candidate_id,
                "mode": c.mode.value,
                "action_type": c.action_type.value,
                "estimated_safety_delta": c.estimated_safety_delta,
                "uncertainty_score": c.uncertainty_score,
                "legality_precheck": c.legality_precheck,
            }
            for c in candidates
        ]
        request = {
            "instruction": "Return only JSON with fields candidate_id and rationale.",
            "context": prompt,
            "candidates": compact_candidates,
        }
        start = time.monotonic()
        try:
            proc = subprocess.run(
                ["ollama", "run", self.model_name, json.dumps(request, ensure_ascii=True)],
                check=False,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                timeout=deadline_seconds,
            )
            elapsed_ms = (time.monotonic() - start) * 1000.0
            raw = (proc.stdout or "").strip()
            if not raw:
                return None
            # Tolerate prose-wrapped output instead of requiring the entire
            # stdout to be a bare JSON document.
            data = self._extract_json_object(raw)
            if data is None:
                return None
            candidate_id = str(data.get("candidate_id", "")).strip()
            if not candidate_id:
                return None
            # Reject hallucinated ids that are not in the input set.
            if candidate_id not in {c.candidate_id for c in candidates}:
                return None
            rationale = str(data.get("rationale", "Ollama provider selection.")).strip() or "Ollama provider selection."
            return ProviderSelection(
                provider=self.name,
                candidate_id=candidate_id,
                rationale=rationale,
                latency_ms=elapsed_ms,
                raw_output=raw,
            )
        except Exception:
            # Any subprocess or parse error means "no selection"; the caller
            # has a deterministic fallback.
            return None
class TransformersProvider:
    """Deterministic stand-in provider gated on the transformers package.

    The package is only probed for availability; selection itself delegates
    to the local safety ranker so results stay fast and reproducible.
    """

    name = "transformers"

    def __init__(self, model_name: str) -> None:
        self.model_name = model_name

    def is_available(self) -> bool:
        """Report whether the transformers package can be imported."""
        try:
            import transformers  # noqa: F401
        except Exception:
            return False
        return True

    def select(self, candidates: list[CandidateAction], prompt: dict[str, Any]) -> ProviderSelection | None:
        """Return the top locally-ranked candidate, or None when unusable."""
        _ = prompt
        if not candidates or not self.is_available():
            return None
        # Deterministic local choice: run the safety ranker instead of an
        # actual generation pass, keeping local runs lightweight.
        started_at = time.monotonic()
        best = rank_candidates(candidates)[0]
        elapsed = (time.monotonic() - started_at) * 1000.0
        return ProviderSelection(
            provider=self.name,
            candidate_id=best.candidate_id,
            rationale=f"Transformers fallback selected {best.candidate_id} via local ranker.",
            latency_ms=elapsed,
        )
class PolicyProviderRouter:
    """Routes candidate selection through preferred backends with a fallback.

    Providers are tried in preference order; the deterministic local ranker
    always produces an answer when every backend declines.
    """

    def __init__(self, ollama_model: str = "qwen2.5:1.5b-instruct", hf_model: str = "Qwen/Qwen2.5-0.5B-Instruct") -> None:
        self.ollama = OllamaProvider(ollama_model)
        self.transformers = TransformersProvider(hf_model)

    def select_candidate(
        self,
        candidates: list[CandidateAction],
        prompt: dict[str, Any],
        provider_preference: tuple[str, ...] = ("transformers",),
    ) -> ProviderSelection:
        """Return the first successful provider selection, else the fallback."""
        preferences = tuple(provider_preference) or ("transformers",)
        backends = {"ollama": self.ollama, "transformers": self.transformers}
        for preference in preferences:
            backend = backends.get(preference)
            if backend is None:
                # Unknown preference names are silently skipped.
                continue
            selection = backend.select(candidates, prompt)
            if selection is not None:
                return selection
        # Deterministic hard fallback.
        best = rank_candidates(candidates)[0]
        return ProviderSelection(
            provider="heuristic_fallback",
            candidate_id=best.candidate_id,
            rationale="Fallback ranker selected top legal/safety candidate.",
            latency_ms=0.0,
        )