"""LLM provider runtime with Transformers-first fallback order. The runtime is intentionally conservative: if an LLM backend is unavailable or errors, selection falls back to deterministic local ranking. """ from __future__ import annotations from dataclasses import dataclass import json import os import re import shutil import subprocess import time from pathlib import Path from typing import Any from app.common.types import CandidateAction from app.models.policy.active_model import active_model_status, available_artifact_path from app.models.policy.safety_ranker import rank_candidates def _transformers_low_cpu_mem() -> bool: """Use lazy/meta init only on CUDA; on CPU it often breaks ``.to(device)`` (meta tensors).""" try: import torch return torch.cuda.is_available() except Exception: return False def _peft_base_model_id(artifact_path: str, status: dict[str, Any], fallback: str) -> str: cfg = Path(artifact_path) / "adapter_config.json" if cfg.is_file(): try: payload = json.loads(cfg.read_text(encoding="utf-8")) raw = payload.get("base_model_name_or_path") if isinstance(raw, str) and raw.strip(): return raw.strip() except Exception: pass return str(status.get("base_model") or fallback) def _env_truthy(name: str, default: bool = False) -> bool: raw = os.getenv(name) if raw is None: return default return raw.strip().lower() in {"1", "true", "yes", "on"} def default_provider_preference() -> tuple[str, ...]: raw = os.getenv("POLYGUARD_PROVIDER_PREFERENCE", "").strip() if raw: order = tuple(p.strip().lower() for p in raw.split(",") if p.strip()) cleaned = tuple(p for p in order if p in {"ollama", "transformers"}) if cleaned: return cleaned if _env_truthy("POLYGUARD_ENABLE_OLLAMA"): return ("ollama", "transformers") return ("transformers",) def _extract_candidate_id(text: str, legal_ids: set[str]) -> str | None: lowered = text.lower() for candidate_id in sorted(legal_ids): if candidate_id.lower() in lowered: return candidate_id return None def _compact_prompt(candidates: list[CandidateAction], prompt: dict[str, Any]) -> str: compact_candidates = [ { "candidate_id": c.candidate_id, "mode": c.mode.value, "action_type": c.action_type.value, "target_drug": c.target_drug, "replacement_drug": c.replacement_drug, "dose_bucket": c.dose_bucket.value, "safety_delta": c.estimated_safety_delta, "uncertainty": c.uncertainty_score, "legal": c.legality_precheck, "tags": c.rationale_tags[:4], } for c in candidates ] payload = { "instruction": "Select the safest legal medication action candidate_id.", "context": prompt, "candidate_ids": [c.candidate_id for c in candidates], "candidates": compact_candidates, "answer": "", "format": "Return candidate_id=; rationale=.", } return json.dumps(payload, ensure_ascii=True) @dataclass(slots=True) class ProviderSelection: provider: str candidate_id: str rationale: str latency_ms: float raw_output: str = "" class OllamaProvider: name = "ollama" def __init__(self, model_name: str) -> None: self.model_name = model_name self._last_error = "" def is_available(self) -> bool: if os.getenv("POLYGUARD_ENABLE_OLLAMA", "false").lower() not in {"1", "true", "yes", "on"}: return False return shutil.which("ollama") is not None def ensure_model(self) -> bool: if not self.is_available(): return False if os.getenv("POLYGUARD_OLLAMA_AUTO_PULL", "true").lower() not in {"1", "true", "yes", "on"}: return True try: subprocess.run( ["ollama", "pull", self.model_name], check=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, timeout=90, ) return True except Exception: return False def select(self, candidates: list[CandidateAction], prompt: dict[str, Any]) -> ProviderSelection | None: if not self.is_available() or not candidates: return None self.ensure_model() deadline_seconds = float(os.getenv("POLYGUARD_PROVIDER_TIMEOUT_SECONDS", "25.0")) legal_ids = {c.candidate_id for c in candidates} compact_candidates = [ { "candidate_id": c.candidate_id, "mode": c.mode.value, "action_type": c.action_type.value, "estimated_safety_delta": c.estimated_safety_delta, "uncertainty_score": c.uncertainty_score, "legality_precheck": c.legality_precheck, } for c in candidates ] request = { "instruction": ( "Choose exactly one safest legal medication action. " "Return a single JSON object only: {\"candidate_id\":\"cand_XX\",\"rationale\":\"brief reason\"}. " "Do not return arrays or multiple candidates." ), "context": prompt, "candidates": compact_candidates, } start = time.monotonic() try: prompt_text = json.dumps(request, ensure_ascii=True) proc = subprocess.run( ["ollama", "run", self.model_name], check=False, input=prompt_text, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, timeout=deadline_seconds, env={**os.environ, "TERM": "dumb", "NO_COLOR": "1"}, ) elapsed_ms = (time.monotonic() - start) * 1000.0 if proc.returncode != 0: self._last_error = (proc.stderr or "ollama run failed").strip()[:500] return None raw = re.sub(r"\x1b\[[0-?]*[ -/]*[@-~]", "", proc.stdout or "").strip() if not raw: self._last_error = (proc.stderr or "ollama returned empty output").strip()[:500] return None try: data = json.loads(raw) except json.JSONDecodeError: data = {} parsed_candidate = data.get("candidate_id") if isinstance(data, dict) else None if isinstance(parsed_candidate, list): parsed_candidate = next((str(item) for item in parsed_candidate if str(item) in legal_ids), "") candidate_id = str(parsed_candidate or "").strip() or (_extract_candidate_id(raw, legal_ids) or "") if not candidate_id or candidate_id not in legal_ids: self._last_error = f"ollama returned no legal candidate_id: {raw[:240]}" return None parsed_rationale = data.get("rationale") if isinstance(data, dict) else None if isinstance(parsed_rationale, list): parsed_rationale = " ".join(str(item) for item in parsed_rationale[:2]) rationale = str(parsed_rationale or "Ollama provider selection.").strip() or "Ollama provider selection." self._last_error = "" return ProviderSelection( provider=self.name, candidate_id=candidate_id, rationale=rationale, latency_ms=elapsed_ms, raw_output=raw, ) except Exception as exc: self._last_error = str(exc)[:500] return None def status(self) -> dict[str, Any]: return { "enabled": _env_truthy("POLYGUARD_ENABLE_OLLAMA"), "available": self.is_available(), "model": self.model_name, "provider": self.name, "last_error": self._last_error, } class TransformersProvider: name = "transformers" def __init__(self, model_name: str) -> None: self.model_name = model_name self._model: Any | None = None self._tokenizer: Any | None = None self._model_source = "" self._load_error = "" def is_available(self) -> bool: try: import transformers # noqa: F401 return True except Exception: return False def status(self) -> dict[str, Any]: status = active_model_status() status["provider"] = self.name status["loaded_source"] = self._model_source status["load_error"] = self._load_error status["runtime_model_name"] = self.model_name return status def _load_artifact(self, artifact_name: str, artifact_path: Any, status: dict[str, Any]) -> bool: try: import torch from transformers import AutoModelForCausalLM, AutoTokenizer artifact_path = os.fspath(artifact_path) dtype = torch.float16 if torch.cuda.is_available() else torch.float32 low_mem = _transformers_low_cpu_mem() if artifact_name == "merged": tokenizer = AutoTokenizer.from_pretrained(artifact_path, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained( artifact_path, dtype=dtype, low_cpu_mem_usage=low_mem, trust_remote_code=True, ) source = "active_merged" else: from peft import PeftModel base_model = _peft_base_model_id(artifact_path, status, self.model_name) tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True) base = AutoModelForCausalLM.from_pretrained( base_model, dtype=dtype, low_cpu_mem_usage=low_mem, trust_remote_code=True, ) model = PeftModel.from_pretrained(base, artifact_path) source = f"active_{artifact_name}" if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token device = "cuda" if torch.cuda.is_available() else "cpu" model = model.to(device) model.eval() self._model = model self._tokenizer = tokenizer self._model_source = source self._load_error = "" return True except Exception as exc: # noqa: BLE001 self._load_error = str(exc) self._model = None self._tokenizer = None self._model_source = "" return False def _load_active_model(self) -> bool: if self._model is not None and self._tokenizer is not None: return True status = active_model_status() if available_artifact_path(status) is None: return False paths = status.get("paths", {}) availability = status.get("availability", {}) errors: list[str] = [] if not isinstance(paths, dict) or not isinstance(availability, dict): return False for artifact_name in status.get("load_order", []): if not availability.get(artifact_name) or not paths.get(artifact_name): continue if self._load_artifact(str(artifact_name), paths[artifact_name], status): return True errors.append(f"{artifact_name}:{self._load_error}") if errors: self._load_error = " | ".join(errors) return False def _select_with_active_model( self, candidates: list[CandidateAction], prompt: dict[str, Any], ) -> ProviderSelection | None: if not self._load_active_model() or self._model is None or self._tokenizer is None: return None import torch legal_ids = {c.candidate_id for c in candidates} prompt_text = _compact_prompt(candidates, prompt) max_new_tokens = int(os.getenv("POLYGUARD_PROVIDER_MAX_NEW_TOKENS", "64")) started = time.monotonic() try: device = next(self._model.parameters()).device encoded = self._tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=768) encoded = {key: value.to(device) for key, value in encoded.items()} with torch.no_grad(): generated = self._model.generate( **encoded, max_new_tokens=max_new_tokens, do_sample=False, temperature=0.0, eos_token_id=self._tokenizer.eos_token_id, pad_token_id=self._tokenizer.pad_token_id, ) decoded = self._tokenizer.decode(generated[0], skip_special_tokens=True) completion = decoded[len(prompt_text) :].strip() if decoded.startswith(prompt_text) else decoded candidate_id = _extract_candidate_id(completion, legal_ids) if candidate_id is None: return None rationale = completion.strip() or f"Active model selected {candidate_id}." return ProviderSelection( provider=self._model_source or self.name, candidate_id=candidate_id, rationale=rationale[:500], latency_ms=(time.monotonic() - started) * 1000.0, raw_output=completion, ) except Exception as exc: # noqa: BLE001 self._load_error = str(exc) return None def select(self, candidates: list[CandidateAction], prompt: dict[str, Any]) -> ProviderSelection | None: if not self.is_available() or not candidates: return None active_selection = self._select_with_active_model(candidates, prompt) if active_selection is not None: return active_selection # Keep this lightweight and deterministic when no active artifact is # configured or model loading fails. start = time.monotonic() top = rank_candidates(candidates)[0] status = active_model_status() load_note = f" active_model_error={self._load_error}" if self._load_error else "" return ProviderSelection( provider="transformers_ranker_fallback", candidate_id=top.candidate_id, rationale=( f"Transformers fallback selected {top.candidate_id} via local ranker; " f"active_model_enabled={status.get('enabled')}; active_model_available={status.get('active')}." f"{load_note}" ), latency_ms=(time.monotonic() - start) * 1000.0, ) class PolicyProviderRouter: def __init__(self, ollama_model: str = "qwen2.5:1.5b-instruct", hf_model: str = "Qwen/Qwen2.5-0.5B-Instruct") -> None: self.ollama = OllamaProvider(os.getenv("POLYGUARD_OLLAMA_MODEL", ollama_model)) self.transformers = TransformersProvider( os.getenv("POLYGUARD_HF_MODEL") or os.getenv("POLYGUARD_FRONTIER_MODEL") or hf_model ) def select_candidate( self, candidates: list[CandidateAction], prompt: dict[str, Any], provider_preference: tuple[str, ...] | None = None, ) -> ProviderSelection: provider_preference = tuple(provider_preference or default_provider_preference()) for provider in provider_preference: if provider == "ollama": picked = self.ollama.select(candidates, prompt) if picked is not None: return picked elif provider == "transformers": picked = self.transformers.select(candidates, prompt) if picked is not None: return picked # Deterministic hard fallback. fallback = rank_candidates(candidates)[0] return ProviderSelection( provider="heuristic_fallback", candidate_id=fallback.candidate_id, rationale="Fallback ranker selected top legal/safety candidate.", latency_ms=0.0, ) def model_status(self) -> dict[str, Any]: status = self.transformers.status() status["ollama"] = self.ollama.status() status["provider_preference"] = list(default_provider_preference()) return status