File size: 5,975 Bytes
877add7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""LLM provider runtime with Transformers-first fallback order.

The runtime is intentionally conservative: if an LLM backend is unavailable or
errors, selection falls back to deterministic local ranking.
"""

from __future__ import annotations

from dataclasses import dataclass
import json
import os
import shutil
import subprocess
import time
from typing import Any

from app.common.types import CandidateAction
from app.models.policy.safety_ranker import rank_candidates


@dataclass(slots=True)
class ProviderSelection:
    """Result of a provider choosing one candidate action.

    Produced by ``OllamaProvider.select``, ``TransformersProvider.select``,
    or the deterministic fallback in ``PolicyProviderRouter``.
    """

    # Provider identifier: "ollama", "transformers", or "heuristic_fallback".
    provider: str
    # candidate_id of the chosen CandidateAction.
    candidate_id: str
    # Human-readable justification for the choice.
    rationale: str
    # Wall-clock selection time in milliseconds (0.0 for the hard fallback).
    latency_ms: float
    # Verbatim backend stdout when one produced any (only Ollama sets this).
    raw_output: str = ""


class OllamaProvider:
    """Candidate selector backed by a local ``ollama`` CLI installation.

    The provider is opt-in via ``POLYGUARD_ENABLE_OLLAMA`` and degrades to
    ``None`` on any failure so the caller can fall back to the deterministic
    local ranker.
    """

    name = "ollama"

    # Environment-variable values treated as "enabled".
    _TRUTHY = frozenset({"1", "true", "yes", "on"})

    def __init__(self, model_name: str) -> None:
        self.model_name = model_name

    @classmethod
    def _env_flag(cls, var: str, default: str) -> bool:
        """Return True when env var *var* (falling back to *default*) is truthy."""
        return os.getenv(var, default).lower() in cls._TRUTHY

    def is_available(self) -> bool:
        """Return True only when explicitly enabled AND ``ollama`` is on PATH."""
        if not self._env_flag("POLYGUARD_ENABLE_OLLAMA", "false"):
            return False
        return shutil.which("ollama") is not None

    def ensure_model(self) -> bool:
        """Best-effort ``ollama pull`` of the configured model.

        Returns False when the provider is unavailable or the pull raised;
        returns True when auto-pull is disabled or the pull command completed
        (its exit status is deliberately ignored — ``check=False``).
        """
        if not self.is_available():
            return False
        if not self._env_flag("POLYGUARD_OLLAMA_AUTO_PULL", "true"):
            return True
        try:
            subprocess.run(
                ["ollama", "pull", self.model_name],
                check=False,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                timeout=90,
            )
            return True
        except Exception:
            # Pull failures are non-fatal; `ollama run` may still succeed
            # against a previously cached model.
            return False

    def select(self, candidates: list[CandidateAction], prompt: dict[str, Any]) -> ProviderSelection | None:
        """Ask the local model to pick one candidate; None on any failure.

        The model is instructed to return JSON with ``candidate_id`` and
        ``rationale``; the id is validated against *candidates* before use.
        """
        if not self.is_available() or not candidates:
            return None
        self.ensure_model()  # best effort; failure is tolerated (see above)
        # BUG FIX: the float() parse used to live outside the try block, so a
        # malformed POLYGUARD_PROVIDER_TIMEOUT_SECONDS raised ValueError out of
        # select() instead of degrading gracefully. Fall back to 7 seconds.
        try:
            deadline_seconds = float(os.getenv("POLYGUARD_PROVIDER_TIMEOUT_SECONDS", "7.0"))
        except ValueError:
            deadline_seconds = 7.0
        compact_candidates = [
            {
                "candidate_id": c.candidate_id,
                "mode": c.mode.value,
                "action_type": c.action_type.value,
                "estimated_safety_delta": c.estimated_safety_delta,
                "uncertainty_score": c.uncertainty_score,
                "legality_precheck": c.legality_precheck,
            }
            for c in candidates
        ]
        request = {
            "instruction": "Return only JSON with fields candidate_id and rationale.",
            "context": prompt,
            "candidates": compact_candidates,
        }
        start = time.monotonic()
        try:
            proc = subprocess.run(
                ["ollama", "run", self.model_name, json.dumps(request, ensure_ascii=True)],
                check=False,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                timeout=deadline_seconds,
            )
            elapsed_ms = (time.monotonic() - start) * 1000.0
            raw = (proc.stdout or "").strip()
            if not raw:
                return None
            data = json.loads(raw)
            # The model may return a JSON list/scalar; only a dict is usable.
            if not isinstance(data, dict):
                return None
            candidate_id = str(data.get("candidate_id", "")).strip()
            if not candidate_id:
                return None
            # Reject hallucinated ids that are not among the offered candidates.
            if candidate_id not in {c.candidate_id for c in candidates}:
                return None
            rationale = str(data.get("rationale", "Ollama provider selection.")).strip() or "Ollama provider selection."
            return ProviderSelection(
                provider=self.name,
                candidate_id=candidate_id,
                rationale=rationale,
                latency_ms=elapsed_ms,
                raw_output=raw,
            )
        except Exception:
            # Conservative by design: timeouts, bad JSON, and subprocess
            # errors all collapse to "no selection" for the caller's fallback.
            return None


class TransformersProvider:
    """Selector that activates when the ``transformers`` package is importable.

    To keep local runs cheap and reproducible, it does not actually run a
    model: it defers to the deterministic local safety ranker.
    """

    name = "transformers"

    def __init__(self, model_name: str) -> None:
        self.model_name = model_name

    def is_available(self) -> bool:
        """Report whether the ``transformers`` package can be imported."""
        try:
            import transformers  # noqa: F401
        except Exception:
            return False
        return True

    def select(self, candidates: list[CandidateAction], prompt: dict[str, Any]) -> ProviderSelection | None:
        """Pick the top candidate from the local ranker; None if unusable."""
        del prompt  # unused: selection here is purely local and deterministic
        if not candidates or not self.is_available():
            return None
        began = time.monotonic()
        best = rank_candidates(candidates)[0]
        elapsed_ms = (time.monotonic() - began) * 1000.0
        return ProviderSelection(
            provider=self.name,
            candidate_id=best.candidate_id,
            rationale=f"Transformers fallback selected {best.candidate_id} via local ranker.",
            latency_ms=elapsed_ms,
        )


class PolicyProviderRouter:
    """Routes candidate selection through an ordered list of LLM backends.

    Providers are tried in preference order; the first non-None selection
    wins. If every provider declines, a deterministic local ranker decides.
    """

    def __init__(self, ollama_model: str = "qwen2.5:1.5b-instruct", hf_model: str = "Qwen/Qwen2.5-0.5B-Instruct") -> None:
        self.ollama = OllamaProvider(ollama_model)
        self.transformers = TransformersProvider(hf_model)

    def select_candidate(
        self,
        candidates: list[CandidateAction],
        prompt: dict[str, Any],
        provider_preference: tuple[str, ...] = ("transformers",),
    ) -> ProviderSelection:
        """Return a ProviderSelection from the first willing provider.

        Unknown names in *provider_preference* are skipped silently.
        NOTE(review): an empty *candidates* list reaches the fallback ranker
        and would raise there — confirm callers never pass an empty list.
        """
        preference = tuple(provider_preference) or ("transformers",)

        # Dispatch table instead of an if/elif ladder; unknown keys map to None.
        backends = {"ollama": self.ollama, "transformers": self.transformers}
        for key in preference:
            backend = backends.get(key)
            if backend is None:
                continue
            choice = backend.select(candidates, prompt)
            if choice is not None:
                return choice

        # Deterministic hard fallback.
        winner = rank_candidates(candidates)[0]
        return ProviderSelection(
            provider="heuristic_fallback",
            candidate_id=winner.candidate_id,
            rationale="Fallback ranker selected top legal/safety candidate.",
            latency_ms=0.0,
        )