| """LLM provider runtime with Transformers-first fallback order. |
| |
| The runtime is intentionally conservative: if an LLM backend is unavailable or |
| errors, selection falls back to deterministic local ranking. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
| import json |
| import os |
| import re |
| import shutil |
| import subprocess |
| import time |
| from pathlib import Path |
| from typing import Any |
|
|
| from app.common.types import CandidateAction |
| from app.models.policy.active_model import active_model_status, available_artifact_path |
| from app.models.policy.safety_ranker import rank_candidates |
|
|
|
|
| def _transformers_low_cpu_mem() -> bool: |
| """Use lazy/meta init only on CUDA; on CPU it often breaks ``.to(device)`` (meta tensors).""" |
| try: |
| import torch |
|
|
| return torch.cuda.is_available() |
| except Exception: |
| return False |
|
|
|
|
| def _peft_base_model_id(artifact_path: str, status: dict[str, Any], fallback: str) -> str: |
| cfg = Path(artifact_path) / "adapter_config.json" |
| if cfg.is_file(): |
| try: |
| payload = json.loads(cfg.read_text(encoding="utf-8")) |
| raw = payload.get("base_model_name_or_path") |
| if isinstance(raw, str) and raw.strip(): |
| return raw.strip() |
| except Exception: |
| pass |
| return str(status.get("base_model") or fallback) |
|
|
|
|
| def _env_truthy(name: str, default: bool = False) -> bool: |
| raw = os.getenv(name) |
| if raw is None: |
| return default |
| return raw.strip().lower() in {"1", "true", "yes", "on"} |
|
|
|
|
def default_provider_preference() -> tuple[str, ...]:
    """Return the ordered provider names the router should try.

    ``POLYGUARD_PROVIDER_PREFERENCE`` (comma-separated, case-insensitive)
    wins when it names at least one known provider (``ollama`` or
    ``transformers``). Unknown names are dropped, and repeated names collapse
    to their first position so a provider is never tried twice. Otherwise
    Ollama is tried first only when ``POLYGUARD_ENABLE_OLLAMA`` is truthy,
    and Transformers is always included.
    """
    known = {"ollama", "transformers"}
    raw = os.getenv("POLYGUARD_PROVIDER_PREFERENCE", "").strip()
    if raw:
        names = (part.strip().lower() for part in raw.split(","))
        # dict.fromkeys keeps first-seen order while discarding duplicates.
        cleaned = tuple(dict.fromkeys(name for name in names if name in known))
        if cleaned:
            return cleaned
    if _env_truthy("POLYGUARD_ENABLE_OLLAMA"):
        return ("ollama", "transformers")
    return ("transformers",)
|
|
|
|
| def _extract_candidate_id(text: str, legal_ids: set[str]) -> str | None: |
| lowered = text.lower() |
| for candidate_id in sorted(legal_ids): |
| if candidate_id.lower() in lowered: |
| return candidate_id |
| return None |
|
|
|
|
| def _compact_prompt(candidates: list[CandidateAction], prompt: dict[str, Any]) -> str: |
| compact_candidates = [ |
| { |
| "candidate_id": c.candidate_id, |
| "mode": c.mode.value, |
| "action_type": c.action_type.value, |
| "target_drug": c.target_drug, |
| "replacement_drug": c.replacement_drug, |
| "dose_bucket": c.dose_bucket.value, |
| "safety_delta": c.estimated_safety_delta, |
| "uncertainty": c.uncertainty_score, |
| "legal": c.legality_precheck, |
| "tags": c.rationale_tags[:4], |
| } |
| for c in candidates |
| ] |
| payload = { |
| "instruction": "Select the safest legal medication action candidate_id.", |
| "context": prompt, |
| "candidate_ids": [c.candidate_id for c in candidates], |
| "candidates": compact_candidates, |
| "answer": "", |
| "format": "Return candidate_id=<one candidate_id>; rationale=<brief clinical reason>.", |
| } |
| return json.dumps(payload, ensure_ascii=True) |
|
|
|
|
| @dataclass(slots=True) |
| class ProviderSelection: |
| provider: str |
| candidate_id: str |
| rationale: str |
| latency_ms: float |
| raw_output: str = "" |
|
|
|
|
class OllamaProvider:
    """Candidate selection backed by the local ``ollama`` command-line tool.

    Opt-in: the provider only reports itself available when
    ``POLYGUARD_ENABLE_OLLAMA`` is truthy and an ``ollama`` executable is on
    ``PATH``. Every failure path in ``select`` returns ``None`` and records a
    short message (surfaced by ``status()["last_error"]``) so the router can
    move on to the next provider.
    """

    name = "ollama"

    # ANSI CSI escape sequences (colors, cursor movement) stripped from CLI output.
    _ANSI_ESCAPE = re.compile(r"\x1b\[[0-?]*[ -/]*[@-~]")
    # Used when POLYGUARD_PROVIDER_TIMEOUT_SECONDS is unset or invalid.
    _DEFAULT_TIMEOUT_SECONDS = 25.0

    def __init__(self, model_name: str) -> None:
        self.model_name = model_name
        self._last_error = ""
        # ``ollama pull`` may run for up to 90 s, so it is attempted at most once
        # per provider instance and its outcome is remembered rather than
        # re-pulling on every selection.
        self._pull_attempted = False
        self._pull_ok = False

    def is_available(self) -> bool:
        """Return True when Ollama is enabled via env and the CLI is on ``PATH``."""
        # Same flag parsing as status()["enabled"], so the two never disagree.
        if not _env_truthy("POLYGUARD_ENABLE_OLLAMA"):
            return False
        return shutil.which("ollama") is not None

    def ensure_model(self) -> bool:
        """Best-effort ``ollama pull`` of ``self.model_name``, attempted at most once.

        Returns:
            False if Ollama is unavailable or the pull failed or timed out.
            True if auto-pull is disabled (``POLYGUARD_OLLAMA_AUTO_PULL`` falsy;
            it defaults to on) or the pull exited with status 0. Calls after
            the first attempt return that attempt's cached result.
        """
        if not self.is_available():
            return False
        if not _env_truthy("POLYGUARD_OLLAMA_AUTO_PULL", default=True):
            return True
        if self._pull_attempted:
            return self._pull_ok
        self._pull_attempted = True
        try:
            proc = subprocess.run(
                ["ollama", "pull", self.model_name],
                check=False,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                timeout=90,
            )
        except Exception as exc:  # best-effort: a failed pull must not abort selection
            self._last_error = f"ollama pull failed: {exc}"[:500]
            self._pull_ok = False
            return False
        self._pull_ok = proc.returncode == 0
        if not self._pull_ok:
            self._last_error = (proc.stderr or "ollama pull failed").strip()[:500]
        return self._pull_ok

    @classmethod
    def _timeout_seconds(cls) -> float:
        """Read ``POLYGUARD_PROVIDER_TIMEOUT_SECONDS``, falling back to the default.

        Unset, non-numeric, non-positive, NaN, or infinite values all yield
        ``_DEFAULT_TIMEOUT_SECONDS`` instead of raising, so a bad setting cannot
        crash provider selection.
        """
        raw = os.getenv("POLYGUARD_PROVIDER_TIMEOUT_SECONDS", "")
        try:
            seconds = float(raw)
        except ValueError:
            return cls._DEFAULT_TIMEOUT_SECONDS
        # The chained comparison is False for NaN, so NaN also takes the default.
        return seconds if 0.0 < seconds < float("inf") else cls._DEFAULT_TIMEOUT_SECONDS

    @staticmethod
    def _parse_json_object(raw: str) -> Any:
        """Decode *raw* as JSON, tolerating prose or code fences around one object.

        Tries the whole text first, then the span from the first ``{`` to the
        last ``}``. Returns the decoded value, or ``{}`` when nothing parses.
        """
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            pass
        start, end = raw.find("{"), raw.rfind("}")
        if 0 <= start < end:
            try:
                return json.loads(raw[start : end + 1])
            except json.JSONDecodeError:
                pass
        return {}

    def select(self, candidates: list[CandidateAction], prompt: dict[str, Any]) -> ProviderSelection | None:
        """Ask the Ollama model to choose one of *candidates*.

        Sends a JSON request (instruction, *prompt* as context, and compact
        candidate summaries) on stdin to ``ollama run``. The reply is parsed
        as a JSON object when possible. Otherwise the first legal candidate id
        mentioned in the text is used.

        Returns:
            A ``ProviderSelection`` on success, or ``None`` when Ollama is
            unavailable, *candidates* is empty, the CLI fails or times out, or
            the reply names no legal candidate (details go to ``status()``).
        """
        if not self.is_available() or not candidates:
            return None
        self.ensure_model()
        deadline_seconds = self._timeout_seconds()
        legal_ids = {c.candidate_id for c in candidates}
        compact_candidates = [
            {
                "candidate_id": c.candidate_id,
                "mode": c.mode.value,
                "action_type": c.action_type.value,
                "estimated_safety_delta": c.estimated_safety_delta,
                "uncertainty_score": c.uncertainty_score,
                "legality_precheck": c.legality_precheck,
            }
            for c in candidates
        ]
        request = {
            "instruction": (
                "Choose exactly one safest legal medication action. "
                "Return a single JSON object only: {\"candidate_id\":\"cand_XX\",\"rationale\":\"brief reason\"}. "
                "Do not return arrays or multiple candidates."
            ),
            "context": prompt,
            "candidates": compact_candidates,
        }
        start = time.monotonic()
        try:
            prompt_text = json.dumps(request, ensure_ascii=True)
            proc = subprocess.run(
                ["ollama", "run", self.model_name],
                check=False,
                input=prompt_text,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                timeout=deadline_seconds,
                # Ask the CLI for plain output without colors or terminal control codes.
                env={**os.environ, "TERM": "dumb", "NO_COLOR": "1"},
            )
            elapsed_ms = (time.monotonic() - start) * 1000.0
            if proc.returncode != 0:
                self._last_error = (proc.stderr or "ollama run failed").strip()[:500]
                return None
            raw = self._ANSI_ESCAPE.sub("", proc.stdout or "").strip()
            if not raw:
                self._last_error = (proc.stderr or "ollama returned empty output").strip()[:500]
                return None
            data = self._parse_json_object(raw)
            parsed_candidate = data.get("candidate_id") if isinstance(data, dict) else None
            if isinstance(parsed_candidate, list):
                # The model ignored the single-candidate instruction: take the first legal entry.
                parsed_candidate = next((str(item) for item in parsed_candidate if str(item) in legal_ids), "")
            candidate_id = str(parsed_candidate or "").strip() or (_extract_candidate_id(raw, legal_ids) or "")
            if not candidate_id or candidate_id not in legal_ids:
                self._last_error = f"ollama returned no legal candidate_id: {raw[:240]}"
                return None
            parsed_rationale = data.get("rationale") if isinstance(data, dict) else None
            if isinstance(parsed_rationale, list):
                parsed_rationale = " ".join(str(item) for item in parsed_rationale[:2])
            rationale = str(parsed_rationale or "Ollama provider selection.").strip() or "Ollama provider selection."
            self._last_error = ""
            return ProviderSelection(
                provider=self.name,
                candidate_id=candidate_id,
                rationale=rationale,
                latency_ms=elapsed_ms,
                raw_output=raw,
            )
        except Exception as exc:
            # Includes subprocess.TimeoutExpired; the router falls through to the next provider.
            self._last_error = str(exc)[:500]
            return None

    def status(self) -> dict[str, Any]:
        """Report enablement, availability, model name and the last error."""
        return {
            "enabled": _env_truthy("POLYGUARD_ENABLE_OLLAMA"),
            "available": self.is_available(),
            "model": self.model_name,
            "provider": self.name,
            "last_error": self._last_error,
        }
|
|
|
|
class TransformersProvider:
    """Select a candidate with a locally trained Hugging Face model, else the ranker.

    ``select`` first tries the "active" model artifact reported by
    ``active_model_status()``. That artifact is either a merged checkpoint or
    an adapter applied on top of a base model via PEFT. If no artifact loads,
    or the model's output names no legal candidate, it falls back to
    ``rank_candidates`` and labels the result ``transformers_ranker_fallback``.
    """

    name = "transformers"

    def __init__(self, model_name: str) -> None:
        # Reported in status() and used as the last-resort base model id for adapters.
        self.model_name = model_name
        self._model: Any | None = None
        self._tokenizer: Any | None = None
        self._model_source = ""  # e.g. "active_merged"; empty until a load succeeds
        self._load_error = ""  # last load/generation error, surfaced via status()

    @staticmethod
    def _env_int(name: str, default: int) -> int:
        """Read a positive integer from env var *name*; use *default* if unset or invalid."""
        raw = os.getenv(name)
        if raw is None:
            return default
        try:
            value = int(raw.strip())
        except ValueError:
            return default
        return value if value > 0 else default

    def is_available(self) -> bool:
        """Return True when the ``transformers`` package can be imported."""
        try:
            import transformers  # noqa: F401  (the import itself is the probe)
        except Exception:
            return False
        return True

    def status(self) -> dict[str, Any]:
        """Return ``active_model_status()`` augmented with this provider's load state."""
        status = active_model_status()
        status["provider"] = self.name
        status["loaded_source"] = self._model_source
        status["load_error"] = self._load_error
        status["runtime_model_name"] = self.model_name
        return status

    def _load_artifact(self, artifact_name: str, artifact_path: Any, status: dict[str, Any]) -> bool:
        """Load one artifact into ``self._model`` / ``self._tokenizer``.

        ``artifact_name == "merged"`` loads tokenizer and weights directly from
        *artifact_path*. Any other name is treated as a PEFT adapter applied to
        the base model resolved by ``_peft_base_model_id``. Returns True on
        success. On any error it records the message in ``self._load_error``,
        clears the loaded state and returns False.
        """
        try:
            import torch
            from transformers import AutoModelForCausalLM, AutoTokenizer

            artifact_path = os.fspath(artifact_path)
            # float16 on CUDA, float32 otherwise.
            dtype = torch.float16 if torch.cuda.is_available() else torch.float32
            low_mem = _transformers_low_cpu_mem()
            # NOTE(review): ``dtype=`` assumes a transformers release that accepts it
            # (older releases spell it ``torch_dtype``) — confirm the pinned version.
            if artifact_name == "merged":
                tokenizer = AutoTokenizer.from_pretrained(artifact_path, trust_remote_code=True)
                model = AutoModelForCausalLM.from_pretrained(
                    artifact_path,
                    dtype=dtype,
                    low_cpu_mem_usage=low_mem,
                    trust_remote_code=True,
                )
                source = "active_merged"
            else:
                from peft import PeftModel

                base_model = _peft_base_model_id(artifact_path, status, self.model_name)
                tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
                base = AutoModelForCausalLM.from_pretrained(
                    base_model,
                    dtype=dtype,
                    low_cpu_mem_usage=low_mem,
                    trust_remote_code=True,
                )
                model = PeftModel.from_pretrained(base, artifact_path)
                source = f"active_{artifact_name}"

            # generate() needs a pad token; reuse EOS when the tokenizer has none.
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            device = "cuda" if torch.cuda.is_available() else "cpu"
            model = model.to(device)
            model.eval()
            self._model = model
            self._tokenizer = tokenizer
            self._model_source = source
            self._load_error = ""
            return True
        except Exception as exc:
            self._load_error = str(exc)
            self._model = None
            self._tokenizer = None
            self._model_source = ""
            return False

    def _load_active_model(self) -> bool:
        """Ensure an active model is loaded, trying artifacts in ``load_order``.

        Returns True immediately if a model is already cached. Otherwise it
        walks ``status["load_order"]``, skipping entries not marked available
        or lacking a path, and stops at the first artifact that loads. When
        every attempt fails, their errors are joined into ``self._load_error``.
        """
        if self._model is not None and self._tokenizer is not None:
            return True

        status = active_model_status()
        if available_artifact_path(status) is None:
            return False

        paths = status.get("paths", {})
        availability = status.get("availability", {})
        errors: list[str] = []
        if not isinstance(paths, dict) or not isinstance(availability, dict):
            return False
        # NOTE(review): failures are not cached, so every select() retries loading
        # while no artifact succeeds — confirm this retry cost is intended.
        for artifact_name in status.get("load_order", []):
            if not availability.get(artifact_name) or not paths.get(artifact_name):
                continue
            if self._load_artifact(str(artifact_name), paths[artifact_name], status):
                return True
            errors.append(f"{artifact_name}:{self._load_error}")
        if errors:
            self._load_error = " | ".join(errors)
        return False

    def _select_with_active_model(
        self,
        candidates: list[CandidateAction],
        prompt: dict[str, Any],
    ) -> ProviderSelection | None:
        """Ask the loaded active model to choose a candidate without sampling.

        Returns None when no model loads, generation fails (the error is
        recorded in ``self._load_error``), or the generated continuation names
        no legal candidate.
        """
        if not self._load_active_model() or self._model is None or self._tokenizer is None:
            return None

        import torch

        legal_ids = {c.candidate_id for c in candidates}
        prompt_text = _compact_prompt(candidates, prompt)
        # Parsed defensively: a malformed env value must not crash selection.
        max_new_tokens = self._env_int("POLYGUARD_PROVIDER_MAX_NEW_TOKENS", 64)
        started = time.monotonic()
        try:
            device = next(self._model.parameters()).device
            encoded = self._tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=768)
            encoded = {key: value.to(device) for key, value in encoded.items()}
            prompt_length = encoded["input_ids"].shape[-1]
            with torch.no_grad():
                generated = self._model.generate(
                    **encoded,
                    max_new_tokens=max_new_tokens,
                    # Non-sampling decoding; a temperature setting would be ignored.
                    do_sample=False,
                    eos_token_id=self._tokenizer.eos_token_id,
                    pad_token_id=self._tokenizer.pad_token_id,
                )
            # A causal LM's generate() returns prompt + continuation. Decode only the
            # continuation so candidate ids listed in the prompt are never taken as
            # the answer. String-prefix stripping of the decoded text fails whenever
            # the prompt was truncated or decoding does not round-trip exactly.
            new_tokens = generated[0][prompt_length:]
            completion = self._tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
            candidate_id = _extract_candidate_id(completion, legal_ids)
            if candidate_id is None:
                return None
            rationale = completion.strip() or f"Active model selected {candidate_id}."
            return ProviderSelection(
                provider=self._model_source or self.name,
                candidate_id=candidate_id,
                rationale=rationale[:500],
                latency_ms=(time.monotonic() - started) * 1000.0,
                raw_output=completion,
            )
        except Exception as exc:
            self._load_error = str(exc)
            return None

    def select(self, candidates: list[CandidateAction], prompt: dict[str, Any]) -> ProviderSelection | None:
        """Pick a candidate via the active model, falling back to the local ranker.

        Returns None only when ``transformers`` is not importable or
        *candidates* is empty. Otherwise it returns the active model's pick,
        or the top ``rank_candidates`` entry tagged
        ``transformers_ranker_fallback``.
        """
        if not self.is_available() or not candidates:
            return None

        active_selection = self._select_with_active_model(candidates, prompt)
        if active_selection is not None:
            return active_selection

        start = time.monotonic()
        top = rank_candidates(candidates)[0]
        status = active_model_status()
        load_note = f" active_model_error={self._load_error}" if self._load_error else ""
        return ProviderSelection(
            provider="transformers_ranker_fallback",
            candidate_id=top.candidate_id,
            rationale=(
                f"Transformers fallback selected {top.candidate_id} via local ranker; "
                f"active_model_enabled={status.get('enabled')}; active_model_available={status.get('active')}."
                f"{load_note}"
            ),
            latency_ms=(time.monotonic() - start) * 1000.0,
        )
|
|
|
|
class PolicyProviderRouter:
    """Try LLM providers in preference order, then fall back to the local ranker."""

    def __init__(self, ollama_model: str = "qwen2.5:1.5b-instruct", hf_model: str = "Qwen/Qwen2.5-0.5B-Instruct") -> None:
        # Environment variables override the constructor defaults.
        self.ollama = OllamaProvider(os.getenv("POLYGUARD_OLLAMA_MODEL", ollama_model))
        self.transformers = TransformersProvider(
            os.getenv("POLYGUARD_HF_MODEL") or os.getenv("POLYGUARD_FRONTIER_MODEL") or hf_model
        )

    def select_candidate(
        self,
        candidates: list[CandidateAction],
        prompt: dict[str, Any],
        provider_preference: tuple[str, ...] | None = None,
    ) -> ProviderSelection:
        """Return the first non-None selection from the preferred providers.

        Args:
            candidates: Actions to choose from. NOTE(review): the final fallback
                indexes ``rank_candidates(candidates)[0]``, so an empty list
                presumably raises there — confirm callers never pass one.
            prompt: Decision context forwarded to each provider.
            provider_preference: Provider names to try in order. Names are
                matched case-insensitively after stripping whitespace, unknown
                names are skipped, and repeats are tried only once. A falsy
                value means ``default_provider_preference()``.

        Returns:
            The first provider's selection, or a ``heuristic_fallback``
            selection of the top-ranked candidate when every provider declines.
        """
        if provider_preference:
            # Normalize and de-duplicate while keeping the caller's order.
            order = tuple(dict.fromkeys(str(name).strip().lower() for name in provider_preference))
        else:
            order = tuple(default_provider_preference())

        providers = {"ollama": self.ollama, "transformers": self.transformers}
        for name in order:
            provider = providers.get(name)
            if provider is None:
                continue
            picked = provider.select(candidates, prompt)
            if picked is not None:
                return picked

        fallback = rank_candidates(candidates)[0]
        return ProviderSelection(
            provider="heuristic_fallback",
            candidate_id=fallback.candidate_id,
            rationale="Fallback ranker selected top legal/safety candidate.",
            latency_ms=0.0,
        )

    def model_status(self) -> dict[str, Any]:
        """Combine Transformers status, Ollama status, and the default provider order."""
        status = self.transformers.status()
        status["ollama"] = self.ollama.status()
        status["provider_preference"] = list(default_provider_preference())
        return status
|
|