Spaces:
Running
Running
| """LLM provider runtime with Transformers-first fallback order. | |
| The runtime is intentionally conservative: if an LLM backend is unavailable or | |
| errors, selection falls back to deterministic local ranking. | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| import json | |
| import os | |
| import shutil | |
| import subprocess | |
| import time | |
| from typing import Any | |
| from app.common.types import CandidateAction | |
| from app.models.policy.safety_ranker import rank_candidates | |
@dataclass
class ProviderSelection:
    """Outcome of one provider picking a candidate action.

    Bug fix: the ``@dataclass`` decorator was missing, so the annotated
    fields generated no ``__init__`` and every keyword construction of
    ``ProviderSelection(...)`` elsewhere in this module raised TypeError.
    The unused ``dataclass`` import at the top of the file confirms the
    decorator was intended.
    """

    provider: str  # name of the backend that made the selection
    candidate_id: str  # id of the chosen candidate action
    rationale: str  # human-readable justification for the pick
    latency_ms: float  # wall-clock selection latency in milliseconds
    raw_output: str = ""  # raw model output; empty for non-LLM providers
class OllamaProvider:
    """Candidate selector backed by a locally installed ``ollama`` CLI.

    Every failure path returns ``None`` so the caller can fall back to
    the deterministic local ranker, per the module's conservative design.
    """

    name = "ollama"

    def __init__(self, model_name: str) -> None:
        self.model_name = model_name

    def is_available(self) -> bool:
        """Return True when Ollama is both enabled via env and on PATH."""
        enabled = os.getenv("POLYGUARD_ENABLE_OLLAMA", "false").lower()
        if enabled not in {"1", "true", "yes", "on"}:
            return False
        return shutil.which("ollama") is not None

    def ensure_model(self) -> bool:
        """Best-effort pull of the configured model; never raises."""
        if not self.is_available():
            return False
        auto_pull = os.getenv("POLYGUARD_OLLAMA_AUTO_PULL", "true").lower()
        if auto_pull not in {"1", "true", "yes", "on"}:
            # Auto-pull disabled: assume the model is already present.
            return True
        try:
            subprocess.run(
                ["ollama", "pull", self.model_name],
                check=False,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                timeout=90,
            )
        except Exception:
            return False
        return True

    def select(self, candidates: list[CandidateAction], prompt: dict[str, Any]) -> ProviderSelection | None:
        """Ask the local model to choose a candidate; ``None`` on any failure.

        The prompt instructs the model to emit bare JSON; anything that is
        empty, unparseable, or names an unknown candidate is rejected.
        """
        if not self.is_available() or not candidates:
            return None
        self.ensure_model()
        timeout_s = float(os.getenv("POLYGUARD_PROVIDER_TIMEOUT_SECONDS", "7.0"))
        payload = {
            "instruction": "Return only JSON with fields candidate_id and rationale.",
            "context": prompt,
            "candidates": [
                {
                    "candidate_id": c.candidate_id,
                    "mode": c.mode.value,
                    "action_type": c.action_type.value,
                    "estimated_safety_delta": c.estimated_safety_delta,
                    "uncertainty_score": c.uncertainty_score,
                    "legality_precheck": c.legality_precheck,
                }
                for c in candidates
            ],
        }
        started = time.monotonic()
        try:
            completed = subprocess.run(
                ["ollama", "run", self.model_name, json.dumps(payload, ensure_ascii=True)],
                check=False,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                timeout=timeout_s,
            )
            latency = (time.monotonic() - started) * 1000.0
            text_out = (completed.stdout or "").strip()
            if not text_out:
                return None
            parsed = json.loads(text_out)
            chosen = str(parsed.get("candidate_id", "")).strip()
            if not chosen:
                return None
            known_ids = {c.candidate_id for c in candidates}
            if chosen not in known_ids:
                # The model hallucinated an id; reject rather than guess.
                return None
            why = str(parsed.get("rationale", "Ollama provider selection.")).strip()
            return ProviderSelection(
                provider=self.name,
                candidate_id=chosen,
                rationale=why or "Ollama provider selection.",
                latency_ms=latency,
                raw_output=text_out,
            )
        except Exception:
            return None
class TransformersProvider:
    """Transformers-flavoured provider that stays deterministic locally.

    Rather than loading a model, it defers to the local safety ranker so
    local runs remain cheap and reproducible.
    """

    name = "transformers"

    def __init__(self, model_name: str) -> None:
        self.model_name = model_name

    def is_available(self) -> bool:
        """Return True only when the ``transformers`` package imports cleanly."""
        try:
            import transformers  # noqa: F401
        except Exception:
            return False
        return True

    def select(self, candidates: list[CandidateAction], prompt: dict[str, Any]) -> ProviderSelection | None:
        """Return the top locally-ranked candidate, or ``None`` when unusable."""
        _ = prompt
        if not self.is_available() or not candidates:
            return None
        # Deliberately lightweight: the safety ranker stands in for an LLM
        # so the result is deterministic for local runs.
        began = time.monotonic()
        best = rank_candidates(candidates)[0]
        elapsed_ms = (time.monotonic() - began) * 1000.0
        return ProviderSelection(
            provider=self.name,
            candidate_id=best.candidate_id,
            rationale=f"Transformers fallback selected {best.candidate_id} via local ranker.",
            latency_ms=elapsed_ms,
        )
class PolicyProviderRouter:
    """Routes candidate selection through preferred providers, in order.

    Any provider that declines (returns ``None``) is skipped; if every
    preference fails, a deterministic local ranking decides.
    """

    def __init__(self, ollama_model: str = "qwen2.5:1.5b-instruct", hf_model: str = "Qwen/Qwen2.5-0.5B-Instruct") -> None:
        self.ollama = OllamaProvider(ollama_model)
        self.transformers = TransformersProvider(hf_model)

    def select_candidate(
        self,
        candidates: list[CandidateAction],
        prompt: dict[str, Any],
        provider_preference: tuple[str, ...] = ("transformers",),
    ) -> ProviderSelection:
        """Try each preferred provider; always return some selection."""
        preferences = tuple(provider_preference) or ("transformers",)
        # Dispatch table; unrecognised preference names are skipped, matching
        # the original if/elif chain.
        backends = {"ollama": self.ollama, "transformers": self.transformers}
        for key in preferences:
            backend = backends.get(key)
            if backend is None:
                continue
            choice = backend.select(candidates, prompt)
            if choice is not None:
                return choice
        # Deterministic hard fallback.
        best = rank_candidates(candidates)[0]
        return ProviderSelection(
            provider="heuristic_fallback",
            candidate_id=best.candidate_id,
            rationale="Fallback ranker selected top legal/safety candidate.",
            latency_ms=0.0,
        )