"""LLM provider runtime with Transformers-first fallback order.
The runtime is intentionally conservative: if an LLM backend is unavailable or
errors, selection falls back to deterministic local ranking.
"""

from __future__ import annotations

import json
import os
import shutil
import subprocess
import time
from dataclasses import dataclass
from typing import Any

from app.common.types import CandidateAction
from app.models.policy.active_model import active_model_status, available_artifact_path
from app.models.policy.safety_ranker import rank_candidates


def _extract_candidate_id(text: str, legal_ids: set[str]) -> str | None:
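    """Return a candidate_id from legal_ids whose lowercase form appears in text, if any."""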
lowered = text.lower()
for candidate_id in legal_ids:
if candidate_id.lower() in lowered:
return candidate_id
return None


def _compact_prompt(candidates: list[CandidateAction], prompt: dict[str, Any]) -> str:
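    """Serialize the candidate set and request context into one compact JSON prompt line."""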
compact_candidates = [
{
"candidate_id": c.candidate_id,
"mode": c.mode.value,
"action_type": c.action_type.value,
"target_drug": c.target_drug,
"replacement_drug": c.replacement_drug,
"dose_bucket": c.dose_bucket.value,
"safety_delta": c.estimated_safety_delta,
"uncertainty": c.uncertainty_score,
"legal": c.legality_precheck,
"tags": c.rationale_tags[:4],
}
for c in candidates
]
payload = {
"instruction": "Select the safest legal medication action candidate_id.",
"context": prompt,
"candidate_ids": [c.candidate_id for c in candidates],
"candidates": compact_candidates,
"answer": "",
"format": "Return candidate_id=<one candidate_id>; rationale=<brief clinical reason>.",
}
return json.dumps(payload, ensure_ascii=True)


@dataclass(slots=True)
class ProviderSelection:
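    """Outcome of a provider selection, including latency and the raw model output."""
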
provider: str
candidate_id: str
rationale: str
latency_ms: float
raw_output: str = ""


class OllamaProvider:
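    """Select a candidate by shelling out to the local `ollama` CLI."""
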
name = "ollama"

    def __init__(self, model_name: str) -> None:
self.model_name = model_name

    def is_available(self) -> bool:
if os.getenv("POLYGUARD_ENABLE_OLLAMA", "false").lower() not in {"1", "true", "yes", "on"}:
return False
return shutil.which("ollama") is not None

    def ensure_model(self) -> bool:
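        """Best-effort model pull; failures or timeouts never block selection."""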
if not self.is_available():
return False
if os.getenv("POLYGUARD_OLLAMA_AUTO_PULL", "true").lower() not in {"1", "true", "yes", "on"}:
return True
try:
subprocess.run(
["ollama", "pull", self.model_name],
check=False,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
timeout=90,
)
return True
except Exception:
return False

    def select(self, candidates: list[CandidateAction], prompt: dict[str, Any]) -> ProviderSelection | None:
if not self.is_available() or not candidates:
return None
self.ensure_model()
deadline_seconds = float(os.getenv("POLYGUARD_PROVIDER_TIMEOUT_SECONDS", "7.0"))
compact_candidates = [
{
"candidate_id": c.candidate_id,
"mode": c.mode.value,
"action_type": c.action_type.value,
"estimated_safety_delta": c.estimated_safety_delta,
"uncertainty_score": c.uncertainty_score,
"legality_precheck": c.legality_precheck,
}
for c in candidates
]
request = {
"instruction": "Return only JSON with fields candidate_id and rationale.",
"context": prompt,
"candidates": compact_candidates,
}
start = time.monotonic()
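        # One-shot `ollama run <model> <json-request>`; the instruction above asks the
        # model to reply with a JSON object containing candidate_id and rationale.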
try:
proc = subprocess.run(
["ollama", "run", self.model_name, json.dumps(request, ensure_ascii=True)],
check=False,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
timeout=deadline_seconds,
)
elapsed_ms = (time.monotonic() - start) * 1000.0
raw = (proc.stdout or "").strip()
if not raw:
return None
data = json.loads(raw)
candidate_id = str(data.get("candidate_id", "")).strip()
if not candidate_id:
return None
if candidate_id not in {c.candidate_id for c in candidates}:
return None
rationale = str(data.get("rationale", "Ollama provider selection.")).strip() or "Ollama provider selection."
return ProviderSelection(
provider=self.name,
candidate_id=candidate_id,
rationale=rationale,
latency_ms=elapsed_ms,
raw_output=raw,
)
except Exception:
return None


class TransformersProvider:
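    """Select with a locally loaded Transformers model; fall back to the deterministic ranker."""
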
name = "transformers"

    def __init__(self, model_name: str) -> None:
self.model_name = model_name
self._model: Any | None = None
self._tokenizer: Any | None = None
self._model_source = ""
self._load_error = ""

    def is_available(self) -> bool:
try:
import transformers # noqa: F401
return True
except Exception:
return False

    def status(self) -> dict[str, Any]:
status = active_model_status()
status["provider"] = self.name
status["loaded_source"] = self._model_source
status["load_error"] = self._load_error
status["runtime_model_name"] = self.model_name
return status

    def _load_artifact(self, artifact_name: str, artifact_path: Any, status: dict[str, Any]) -> bool:
try:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
artifact_path = os.fspath(artifact_path)
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
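            # A "merged" artifact is a standalone checkpoint; any other artifact is a
            # PEFT adapter loaded on top of the configured base model.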
if artifact_name == "merged":
tokenizer = AutoTokenizer.from_pretrained(artifact_path)
model = AutoModelForCausalLM.from_pretrained(
artifact_path,
torch_dtype=dtype,
low_cpu_mem_usage=True,
)
source = "active_merged"
else:
from peft import PeftModel
base_model = str(status.get("base_model") or self.model_name)
tokenizer = AutoTokenizer.from_pretrained(base_model)
base = AutoModelForCausalLM.from_pretrained(
base_model,
torch_dtype=dtype,
low_cpu_mem_usage=True,
)
model = PeftModel.from_pretrained(base, artifact_path)
source = f"active_{artifact_name}"
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model.eval()
self._model = model
self._tokenizer = tokenizer
self._model_source = source
self._load_error = ""
return True
except Exception as exc: # noqa: BLE001
self._load_error = str(exc)
self._model = None
self._tokenizer = None
self._model_source = ""
return False

    def _load_active_model(self) -> bool:
if self._model is not None and self._tokenizer is not None:
return True
status = active_model_status()
if available_artifact_path(status) is None:
return False
paths = status.get("paths", {})
availability = status.get("availability", {})
errors: list[str] = []
if not isinstance(paths, dict) or not isinstance(availability, dict):
return False
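        # Walk the configured load order; keep the first artifact that loads and
        # record each failure so status() can report why loading did not succeed.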
for artifact_name in status.get("load_order", []):
if not availability.get(artifact_name) or not paths.get(artifact_name):
continue
if self._load_artifact(str(artifact_name), paths[artifact_name], status):
return True
errors.append(f"{artifact_name}:{self._load_error}")
if errors:
self._load_error = " | ".join(errors)
return False

    def _select_with_active_model(
self,
candidates: list[CandidateAction],
prompt: dict[str, Any],
) -> ProviderSelection | None:
if not self._load_active_model() or self._model is None or self._tokenizer is None:
return None
import torch
legal_ids = {c.candidate_id for c in candidates}
prompt_text = _compact_prompt(candidates, prompt)
max_new_tokens = int(os.getenv("POLYGUARD_PROVIDER_MAX_NEW_TOKENS", "64"))
started = time.monotonic()
try:
device = next(self._model.parameters()).device
encoded = self._tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=768)
encoded = {key: value.to(device) for key, value in encoded.items()}
            with torch.no_grad():
                generated = self._model.generate(
                    **encoded,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,  # greedy decoding; no temperature needed when sampling is off
                    eos_token_id=self._tokenizer.eos_token_id,
                    pad_token_id=self._tokenizer.pad_token_id,
                )
decoded = self._tokenizer.decode(generated[0], skip_special_tokens=True)
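            # decode() returns prompt + completion, so strip the echoed prompt before
            # scanning the tail for a legal candidate_id.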
            completion = (decoded[len(prompt_text):] if decoded.startswith(prompt_text) else decoded).strip()
candidate_id = _extract_candidate_id(completion, legal_ids)
if candidate_id is None:
return None
rationale = completion.strip() or f"Active model selected {candidate_id}."
return ProviderSelection(
provider=self._model_source or self.name,
candidate_id=candidate_id,
rationale=rationale[:500],
latency_ms=(time.monotonic() - started) * 1000.0,
raw_output=completion,
)
except Exception as exc: # noqa: BLE001
self._load_error = str(exc)
return None

    def select(self, candidates: list[CandidateAction], prompt: dict[str, Any]) -> ProviderSelection | None:
if not self.is_available() or not candidates:
return None
active_selection = self._select_with_active_model(candidates, prompt)
if active_selection is not None:
return active_selection
# Keep this lightweight and deterministic when no active artifact is
# configured or model loading fails.
start = time.monotonic()
top = rank_candidates(candidates)[0]
status = active_model_status()
load_note = f" active_model_error={self._load_error}" if self._load_error else ""
return ProviderSelection(
provider="transformers_ranker_fallback",
candidate_id=top.candidate_id,
rationale=(
f"Transformers fallback selected {top.candidate_id} via local ranker; "
f"active_model_enabled={status.get('enabled')}; active_model_available={status.get('active')}."
f"{load_note}"
),
latency_ms=(time.monotonic() - start) * 1000.0,
)


class PolicyProviderRouter:
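    """Route selection across providers in preference order, ending in a deterministic fallback."""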

    def __init__(
        self,
        ollama_model: str = "qwen2.5:1.5b-instruct",
        hf_model: str = "Qwen/Qwen2.5-0.5B-Instruct",
    ) -> None:
self.ollama = OllamaProvider(ollama_model)
self.transformers = TransformersProvider(hf_model)

    def select_candidate(
        self,
        candidates: list[CandidateAction],
        prompt: dict[str, Any],
        provider_preference: tuple[str, ...] = ("transformers",),
    ) -> ProviderSelection:
        # Guard: every downstream path (providers and the hard fallback) assumes at
        # least one candidate, so fail loudly here instead of with an IndexError later.
        if not candidates:
            raise ValueError("select_candidate requires at least one candidate")
        provider_preference = tuple(provider_preference) or ("transformers",)
for provider in provider_preference:
if provider == "ollama":
picked = self.ollama.select(candidates, prompt)
if picked is not None:
return picked
elif provider == "transformers":
picked = self.transformers.select(candidates, prompt)
if picked is not None:
return picked
# Deterministic hard fallback.
fallback = rank_candidates(candidates)[0]
return ProviderSelection(
provider="heuristic_fallback",
candidate_id=fallback.candidate_id,
rationale="Fallback ranker selected top legal/safety candidate.",
latency_ms=0.0,
)

    def model_status(self) -> dict[str, Any]:
return self.transformers.status()
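

# Illustrative wiring (hypothetical usage; building CandidateAction objects is elided):
#     router = PolicyProviderRouter()
#     selection = router.select_candidate(candidates, prompt_context)
#     print(selection.provider, selection.candidate_id, selection.rationale)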