"""LLM provider runtime with Transformers-first fallback order. The runtime is intentionally conservative: if an LLM backend is unavailable or errors, selection falls back to deterministic local ranking. """ from __future__ import annotations from dataclasses import dataclass import json import os import shutil import subprocess import time from typing import Any from app.common.types import CandidateAction from app.models.policy.active_model import active_model_status, available_artifact_path from app.models.policy.safety_ranker import rank_candidates def _extract_candidate_id(text: str, legal_ids: set[str]) -> str | None: lowered = text.lower() for candidate_id in legal_ids: if candidate_id.lower() in lowered: return candidate_id return None def _compact_prompt(candidates: list[CandidateAction], prompt: dict[str, Any]) -> str: compact_candidates = [ { "candidate_id": c.candidate_id, "mode": c.mode.value, "action_type": c.action_type.value, "target_drug": c.target_drug, "replacement_drug": c.replacement_drug, "dose_bucket": c.dose_bucket.value, "safety_delta": c.estimated_safety_delta, "uncertainty": c.uncertainty_score, "legal": c.legality_precheck, "tags": c.rationale_tags[:4], } for c in candidates ] payload = { "instruction": "Select the safest legal medication action candidate_id.", "context": prompt, "candidate_ids": [c.candidate_id for c in candidates], "candidates": compact_candidates, "answer": "", "format": "Return candidate_id=; rationale=.", } return json.dumps(payload, ensure_ascii=True) @dataclass(slots=True) class ProviderSelection: provider: str candidate_id: str rationale: str latency_ms: float raw_output: str = "" class OllamaProvider: name = "ollama" def __init__(self, model_name: str) -> None: self.model_name = model_name def is_available(self) -> bool: if os.getenv("POLYGUARD_ENABLE_OLLAMA", "false").lower() not in {"1", "true", "yes", "on"}: return False return shutil.which("ollama") is not None def ensure_model(self) -> bool: if not self.is_available(): return False if os.getenv("POLYGUARD_OLLAMA_AUTO_PULL", "true").lower() not in {"1", "true", "yes", "on"}: return True try: subprocess.run( ["ollama", "pull", self.model_name], check=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, timeout=90, ) return True except Exception: return False def select(self, candidates: list[CandidateAction], prompt: dict[str, Any]) -> ProviderSelection | None: if not self.is_available() or not candidates: return None self.ensure_model() deadline_seconds = float(os.getenv("POLYGUARD_PROVIDER_TIMEOUT_SECONDS", "7.0")) compact_candidates = [ { "candidate_id": c.candidate_id, "mode": c.mode.value, "action_type": c.action_type.value, "estimated_safety_delta": c.estimated_safety_delta, "uncertainty_score": c.uncertainty_score, "legality_precheck": c.legality_precheck, } for c in candidates ] request = { "instruction": "Return only JSON with fields candidate_id and rationale.", "context": prompt, "candidates": compact_candidates, } start = time.monotonic() try: proc = subprocess.run( ["ollama", "run", self.model_name, json.dumps(request, ensure_ascii=True)], check=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, timeout=deadline_seconds, ) elapsed_ms = (time.monotonic() - start) * 1000.0 raw = (proc.stdout or "").strip() if not raw: return None data = json.loads(raw) candidate_id = str(data.get("candidate_id", "")).strip() if not candidate_id: return None if candidate_id not in {c.candidate_id for c in candidates}: return None rationale = 
str(data.get("rationale", "Ollama provider selection.")).strip() or "Ollama provider selection." return ProviderSelection( provider=self.name, candidate_id=candidate_id, rationale=rationale, latency_ms=elapsed_ms, raw_output=raw, ) except Exception: return None class TransformersProvider: name = "transformers" def __init__(self, model_name: str) -> None: self.model_name = model_name self._model: Any | None = None self._tokenizer: Any | None = None self._model_source = "" self._load_error = "" def is_available(self) -> bool: try: import transformers # noqa: F401 return True except Exception: return False def status(self) -> dict[str, Any]: status = active_model_status() status["provider"] = self.name status["loaded_source"] = self._model_source status["load_error"] = self._load_error status["runtime_model_name"] = self.model_name return status def _load_artifact(self, artifact_name: str, artifact_path: Any, status: dict[str, Any]) -> bool: try: import torch from transformers import AutoModelForCausalLM, AutoTokenizer artifact_path = os.fspath(artifact_path) dtype = torch.float16 if torch.cuda.is_available() else torch.float32 if artifact_name == "merged": tokenizer = AutoTokenizer.from_pretrained(artifact_path) model = AutoModelForCausalLM.from_pretrained( artifact_path, torch_dtype=dtype, low_cpu_mem_usage=True, ) source = "active_merged" else: from peft import PeftModel base_model = str(status.get("base_model") or self.model_name) tokenizer = AutoTokenizer.from_pretrained(base_model) base = AutoModelForCausalLM.from_pretrained( base_model, torch_dtype=dtype, low_cpu_mem_usage=True, ) model = PeftModel.from_pretrained(base, artifact_path) source = f"active_{artifact_name}" if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token device = "cuda" if torch.cuda.is_available() else "cpu" model = model.to(device) model.eval() self._model = model self._tokenizer = tokenizer self._model_source = source self._load_error = "" return True except Exception as exc: # noqa: BLE001 self._load_error = str(exc) self._model = None self._tokenizer = None self._model_source = "" return False def _load_active_model(self) -> bool: if self._model is not None and self._tokenizer is not None: return True status = active_model_status() if available_artifact_path(status) is None: return False paths = status.get("paths", {}) availability = status.get("availability", {}) errors: list[str] = [] if not isinstance(paths, dict) or not isinstance(availability, dict): return False for artifact_name in status.get("load_order", []): if not availability.get(artifact_name) or not paths.get(artifact_name): continue if self._load_artifact(str(artifact_name), paths[artifact_name], status): return True errors.append(f"{artifact_name}:{self._load_error}") if errors: self._load_error = " | ".join(errors) return False def _select_with_active_model( self, candidates: list[CandidateAction], prompt: dict[str, Any], ) -> ProviderSelection | None: if not self._load_active_model() or self._model is None or self._tokenizer is None: return None import torch legal_ids = {c.candidate_id for c in candidates} prompt_text = _compact_prompt(candidates, prompt) max_new_tokens = int(os.getenv("POLYGUARD_PROVIDER_MAX_NEW_TOKENS", "64")) started = time.monotonic() try: device = next(self._model.parameters()).device encoded = self._tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=768) encoded = {key: value.to(device) for key, value in encoded.items()} with torch.no_grad(): generated = self._model.generate( 

    def _select_with_active_model(
        self,
        candidates: list[CandidateAction],
        prompt: dict[str, Any],
    ) -> ProviderSelection | None:
        if not self._load_active_model() or self._model is None or self._tokenizer is None:
            return None
        import torch

        legal_ids = {c.candidate_id for c in candidates}
        prompt_text = _compact_prompt(candidates, prompt)
        max_new_tokens = int(os.getenv("POLYGUARD_PROVIDER_MAX_NEW_TOKENS", "64"))
        started = time.monotonic()
        try:
            device = next(self._model.parameters()).device
            encoded = self._tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=768)
            encoded = {key: value.to(device) for key, value in encoded.items()}
            with torch.no_grad():
                # Greedy decoding keeps the selection deterministic.
                generated = self._model.generate(
                    **encoded,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    eos_token_id=self._tokenizer.eos_token_id,
                    pad_token_id=self._tokenizer.pad_token_id,
                )
            decoded = self._tokenizer.decode(generated[0], skip_special_tokens=True)
            completion = decoded[len(prompt_text):].strip() if decoded.startswith(prompt_text) else decoded
            candidate_id = _extract_candidate_id(completion, legal_ids)
            if candidate_id is None:
                return None
            rationale = completion.strip() or f"Active model selected {candidate_id}."
            return ProviderSelection(
                provider=self._model_source or self.name,
                candidate_id=candidate_id,
                rationale=rationale[:500],
                latency_ms=(time.monotonic() - started) * 1000.0,
                raw_output=completion,
            )
        except Exception as exc:  # noqa: BLE001
            self._load_error = str(exc)
            return None

    def select(self, candidates: list[CandidateAction], prompt: dict[str, Any]) -> ProviderSelection | None:
        if not self.is_available() or not candidates:
            return None
        active_selection = self._select_with_active_model(candidates, prompt)
        if active_selection is not None:
            return active_selection
        # Keep this lightweight and deterministic when no active artifact is
        # configured or model loading fails.
        start = time.monotonic()
        top = rank_candidates(candidates)[0]
        status = active_model_status()
        load_note = f" active_model_error={self._load_error}" if self._load_error else ""
        return ProviderSelection(
            provider="transformers_ranker_fallback",
            candidate_id=top.candidate_id,
            rationale=(
                f"Transformers fallback selected {top.candidate_id} via local ranker; "
                f"active_model_enabled={status.get('enabled')}; active_model_available={status.get('active')}."
                f"{load_note}"
            ),
            latency_ms=(time.monotonic() - start) * 1000.0,
        )


class PolicyProviderRouter:
    def __init__(
        self,
        ollama_model: str = "qwen2.5:1.5b-instruct",
        hf_model: str = "Qwen/Qwen2.5-0.5B-Instruct",
    ) -> None:
        self.ollama = OllamaProvider(ollama_model)
        self.transformers = TransformersProvider(hf_model)

    def select_candidate(
        self,
        candidates: list[CandidateAction],
        prompt: dict[str, Any],
        provider_preference: tuple[str, ...] = ("transformers",),
    ) -> ProviderSelection:
        provider_preference = tuple(provider_preference) or ("transformers",)
        for provider in provider_preference:
            if provider == "ollama":
                picked = self.ollama.select(candidates, prompt)
                if picked is not None:
                    return picked
            elif provider == "transformers":
                picked = self.transformers.select(candidates, prompt)
                if picked is not None:
                    return picked
        # Deterministic hard fallback.
        fallback = rank_candidates(candidates)[0]
        return ProviderSelection(
            provider="heuristic_fallback",
            candidate_id=fallback.candidate_id,
            rationale="Fallback ranker selected top legal/safety candidate.",
            latency_ms=0.0,
        )

    def model_status(self) -> dict[str, Any]:
        return self.transformers.status()
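

if __name__ == "__main__":
    # Minimal smoke check (a sketch, not part of the runtime contract): report
    # which backends this environment would use. Candidate selection itself
    # needs real CandidateAction objects, so it is not exercised here.
    router = PolicyProviderRouter()
    print("ollama available:", router.ollama.is_available())
    print("transformers available:", router.transformers.is_available())
    print(json.dumps(router.model_status(), indent=2, default=str))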