| """LLM provider runtime with Transformers-first fallback order. |
| |
| The runtime is intentionally conservative: if an LLM backend is unavailable or |
| errors, selection falls back to deterministic local ranking. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
| import json |
| import os |
| import shutil |
| import subprocess |
| import time |
| from typing import Any |
|
|
| from app.common.types import CandidateAction |
| from app.models.policy.active_model import active_model_status, available_artifact_path |
| from app.models.policy.safety_ranker import rank_candidates |
|
|
|
|
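
# Optional environment switches read by the providers below:
#   POLYGUARD_ENABLE_OLLAMA            - enable the Ollama provider ("1"/"true"/"yes"/"on"; default "false").
#   POLYGUARD_OLLAMA_AUTO_PULL         - pull the Ollama model before use (default "true").
#   POLYGUARD_PROVIDER_TIMEOUT_SECONDS - subprocess timeout for Ollama runs (default "7.0").
#   POLYGUARD_PROVIDER_MAX_NEW_TOKENS  - generation budget for the Transformers provider (default "64").
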
def _extract_candidate_id(text: str, legal_ids: set[str]) -> str | None:
    lowered = text.lower()
    for candidate_id in legal_ids:
        if candidate_id.lower() in lowered:
            return candidate_id
    return None
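
# Illustrative behaviour (hypothetical ids): for the completion text
# "candidate_id=swap_2; rationale=lower bleeding risk" with legal_ids {"swap_2", "hold_1"},
# the helper above returns "swap_2"; matching is a case-insensitive substring check,
# and None is returned when no legal id appears in the text.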


def _compact_prompt(candidates: list[CandidateAction], prompt: dict[str, Any]) -> str:
    compact_candidates = [
        {
            "candidate_id": c.candidate_id,
            "mode": c.mode.value,
            "action_type": c.action_type.value,
            "target_drug": c.target_drug,
            "replacement_drug": c.replacement_drug,
            "dose_bucket": c.dose_bucket.value,
            "safety_delta": c.estimated_safety_delta,
            "uncertainty": c.uncertainty_score,
            "legal": c.legality_precheck,
            "tags": c.rationale_tags[:4],
        }
        for c in candidates
    ]
    payload = {
        "instruction": "Select the safest legal medication action candidate_id.",
        "context": prompt,
        "candidate_ids": [c.candidate_id for c in candidates],
        "candidates": compact_candidates,
        "answer": "",
        "format": "Return candidate_id=<one candidate_id>; rationale=<brief clinical reason>.",
    }
    return json.dumps(payload, ensure_ascii=True)
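
# Illustrative, abbreviated shape of the serialized payload (hypothetical values):
#   {"instruction": "Select the safest legal medication action candidate_id.",
#    "context": {...},
#    "candidate_ids": ["swap_2", "hold_1"],
#    "candidates": [{"candidate_id": "swap_2", "mode": "...", "safety_delta": 0.12, ...}],
#    "answer": "",
#    "format": "Return candidate_id=<one candidate_id>; rationale=<brief clinical reason>."}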


@dataclass(slots=True)
class ProviderSelection:
    provider: str
    candidate_id: str
    rationale: str
    latency_ms: float
    raw_output: str = ""


class OllamaProvider:
    name = "ollama"

    def __init__(self, model_name: str) -> None:
        self.model_name = model_name

    def is_available(self) -> bool:
        if os.getenv("POLYGUARD_ENABLE_OLLAMA", "false").lower() not in {"1", "true", "yes", "on"}:
            return False
        return shutil.which("ollama") is not None

    def ensure_model(self) -> bool:
        if not self.is_available():
            return False
        if os.getenv("POLYGUARD_OLLAMA_AUTO_PULL", "true").lower() not in {"1", "true", "yes", "on"}:
            return True
        try:
            subprocess.run(
                ["ollama", "pull", self.model_name],
                check=False,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                timeout=90,
            )
            return True
        except Exception:
            return False

    def select(self, candidates: list[CandidateAction], prompt: dict[str, Any]) -> ProviderSelection | None:
        if not self.is_available() or not candidates:
            return None
        self.ensure_model()
        deadline_seconds = float(os.getenv("POLYGUARD_PROVIDER_TIMEOUT_SECONDS", "7.0"))
        compact_candidates = [
            {
                "candidate_id": c.candidate_id,
                "mode": c.mode.value,
                "action_type": c.action_type.value,
                "estimated_safety_delta": c.estimated_safety_delta,
                "uncertainty_score": c.uncertainty_score,
                "legality_precheck": c.legality_precheck,
            }
            for c in candidates
        ]
        request = {
            "instruction": "Return only JSON with fields candidate_id and rationale.",
            "context": prompt,
            "candidates": compact_candidates,
        }
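        # The model is instructed to reply with bare JSON, e.g. (illustrative):
        #   {"candidate_id": "swap_2", "rationale": "Lower interaction risk."}
        # Any unparseable reply or unknown candidate_id is treated as "no selection" below.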
        start = time.monotonic()
        try:
            proc = subprocess.run(
                ["ollama", "run", self.model_name, json.dumps(request, ensure_ascii=True)],
                check=False,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                timeout=deadline_seconds,
            )
            elapsed_ms = (time.monotonic() - start) * 1000.0
            raw = (proc.stdout or "").strip()
            if not raw:
                return None
            data = json.loads(raw)
            candidate_id = str(data.get("candidate_id", "")).strip()
            if not candidate_id:
                return None
            if candidate_id not in {c.candidate_id for c in candidates}:
                return None
            rationale = str(data.get("rationale", "Ollama provider selection.")).strip() or "Ollama provider selection."
            return ProviderSelection(
                provider=self.name,
                candidate_id=candidate_id,
                rationale=rationale,
                latency_ms=elapsed_ms,
                raw_output=raw,
            )
        except Exception:
            return None

class TransformersProvider:
    name = "transformers"

    def __init__(self, model_name: str) -> None:
        self.model_name = model_name
        self._model: Any | None = None
        self._tokenizer: Any | None = None
        self._model_source = ""
        self._load_error = ""

    def is_available(self) -> bool:
        try:
            import transformers  # noqa: F401  (import is only an availability probe)

            return True
        except Exception:
            return False

    def status(self) -> dict[str, Any]:
        status = active_model_status()
        status["provider"] = self.name
        status["loaded_source"] = self._model_source
        status["load_error"] = self._load_error
        status["runtime_model_name"] = self.model_name
        return status

    def _load_artifact(self, artifact_name: str, artifact_path: Any, status: dict[str, Any]) -> bool:
        try:
            import torch
            from transformers import AutoModelForCausalLM, AutoTokenizer

            artifact_path = os.fspath(artifact_path)
            dtype = torch.float16 if torch.cuda.is_available() else torch.float32
            if artifact_name == "merged":
                tokenizer = AutoTokenizer.from_pretrained(artifact_path)
                model = AutoModelForCausalLM.from_pretrained(
                    artifact_path,
                    torch_dtype=dtype,
                    low_cpu_mem_usage=True,
                )
                source = "active_merged"
            else:
                from peft import PeftModel

                base_model = str(status.get("base_model") or self.model_name)
                tokenizer = AutoTokenizer.from_pretrained(base_model)
                base = AutoModelForCausalLM.from_pretrained(
                    base_model,
                    torch_dtype=dtype,
                    low_cpu_mem_usage=True,
                )
                model = PeftModel.from_pretrained(base, artifact_path)
                source = f"active_{artifact_name}"

            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            device = "cuda" if torch.cuda.is_available() else "cpu"
            model = model.to(device)
            model.eval()
            self._model = model
            self._tokenizer = tokenizer
            self._model_source = source
            self._load_error = ""
            return True
        except Exception as exc:
            self._load_error = str(exc)
            self._model = None
            self._tokenizer = None
            self._model_source = ""
            return False

    def _load_active_model(self) -> bool:
        if self._model is not None and self._tokenizer is not None:
            return True

        status = active_model_status()
        if available_artifact_path(status) is None:
            return False

        paths = status.get("paths", {})
        availability = status.get("availability", {})
        errors: list[str] = []
        if not isinstance(paths, dict) or not isinstance(availability, dict):
            return False
        for artifact_name in status.get("load_order", []):
            if not availability.get(artifact_name) or not paths.get(artifact_name):
                continue
            if self._load_artifact(str(artifact_name), paths[artifact_name], status):
                return True
            errors.append(f"{artifact_name}:{self._load_error}")
        if errors:
            self._load_error = " | ".join(errors)
        return False

    def _select_with_active_model(
        self,
        candidates: list[CandidateAction],
        prompt: dict[str, Any],
    ) -> ProviderSelection | None:
        if not self._load_active_model() or self._model is None or self._tokenizer is None:
            return None

        import torch

        legal_ids = {c.candidate_id for c in candidates}
        prompt_text = _compact_prompt(candidates, prompt)
        max_new_tokens = int(os.getenv("POLYGUARD_PROVIDER_MAX_NEW_TOKENS", "64"))
        started = time.monotonic()
        try:
            device = next(self._model.parameters()).device
            encoded = self._tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=768)
            encoded = {key: value.to(device) for key, value in encoded.items()}
            with torch.no_grad():
                generated = self._model.generate(
                    **encoded,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    temperature=0.0,
                    eos_token_id=self._tokenizer.eos_token_id,
                    pad_token_id=self._tokenizer.pad_token_id,
                )
            decoded = self._tokenizer.decode(generated[0], skip_special_tokens=True)
            completion = decoded[len(prompt_text) :].strip() if decoded.startswith(prompt_text) else decoded
            candidate_id = _extract_candidate_id(completion, legal_ids)
            if candidate_id is None:
                return None
            rationale = completion.strip() or f"Active model selected {candidate_id}."
            return ProviderSelection(
                provider=self._model_source or self.name,
                candidate_id=candidate_id,
                rationale=rationale[:500],
                latency_ms=(time.monotonic() - started) * 1000.0,
                raw_output=completion,
            )
        except Exception as exc:
            self._load_error = str(exc)
            return None

    def select(self, candidates: list[CandidateAction], prompt: dict[str, Any]) -> ProviderSelection | None:
        if not self.is_available() or not candidates:
            return None

        active_selection = self._select_with_active_model(candidates, prompt)
        if active_selection is not None:
            return active_selection

        # No usable active model: fall back to the deterministic local ranker.
        start = time.monotonic()
        top = rank_candidates(candidates)[0]
        status = active_model_status()
        load_note = f" active_model_error={self._load_error}" if self._load_error else ""
        return ProviderSelection(
            provider="transformers_ranker_fallback",
            candidate_id=top.candidate_id,
            rationale=(
                f"Transformers fallback selected {top.candidate_id} via local ranker; "
                f"active_model_enabled={status.get('enabled')}; active_model_available={status.get('active')}."
                f"{load_note}"
            ),
            latency_ms=(time.monotonic() - start) * 1000.0,
        )


class PolicyProviderRouter:
    def __init__(self, ollama_model: str = "qwen2.5:1.5b-instruct", hf_model: str = "Qwen/Qwen2.5-0.5B-Instruct") -> None:
        self.ollama = OllamaProvider(ollama_model)
        self.transformers = TransformersProvider(hf_model)

    def select_candidate(
        self,
        candidates: list[CandidateAction],
        prompt: dict[str, Any],
        provider_preference: tuple[str, ...] = ("transformers",),
    ) -> ProviderSelection:
        provider_preference = tuple(provider_preference) or ("transformers",)

        for provider in provider_preference:
            if provider == "ollama":
                picked = self.ollama.select(candidates, prompt)
                if picked is not None:
                    return picked
            elif provider == "transformers":
                picked = self.transformers.select(candidates, prompt)
                if picked is not None:
                    return picked

        # No provider produced a selection: fall back to the deterministic ranker.
        fallback = rank_candidates(candidates)[0]
        return ProviderSelection(
            provider="heuristic_fallback",
            candidate_id=fallback.candidate_id,
            rationale="Fallback ranker selected top legal/safety candidate.",
            latency_ms=0.0,
        )

    def model_status(self) -> dict[str, Any]:
        return self.transformers.status()
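
# Illustrative usage (assumes `candidates: list[CandidateAction]` and the clinical
# context dict `prompt` are built by upstream modules; the values shown are hypothetical):
#
#   router = PolicyProviderRouter()
#   selection = router.select_candidate(candidates, prompt, provider_preference=("transformers", "ollama"))
#   print(selection.provider, selection.candidate_id, round(selection.latency_ms, 1))
#
# When no provider yields a valid selection, the router still returns a ProviderSelection
# tagged "heuristic_fallback" chosen by the deterministic safety ranker.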
|