"""API service layer."""

from __future__ import annotations

import json
from pathlib import Path
from typing import Any

from app.agents.orchestrator import Orchestrator
from app.env.catalog import apply_task_preset, env_catalog
from app.env.env_core import PolyGuardEnv
from app.evaluation.benchmark_report import build_benchmark_report
from app.evaluation.dosing_eval import dosing_eval
from app.knowledge.evidence_retriever import retrieve_evidence
from app.models.baselines import (
    choose_beam_search,
    choose_contextual_bandit,
    choose_contextual_bandit_topk,
    choose_greedy,
    choose_no_change,
    choose_rules_only,
)
from app.models.policy.provider_runtime import PolicyProviderRouter, default_provider_preference
from app.models.retrieval.retriever import retrieve
from app.training import train_dosing_grpo, train_planner_grpo, train_supervisor_grpo


class APIService:
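    """Facade over the PolyGuard environment, agent orchestrator, policy provider
    router, baselines, evaluation reports, and GRPO training entry points.

    Every public method returns JSON-serializable structures for the API layer.
    """
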
    def __init__(self) -> None:
        self.env = PolyGuardEnv()
        self.orchestrator = Orchestrator(self.env)
        self.policy_router = PolicyProviderRouter()
        self.training_metrics: dict[str, Any] = {}
        self.root = Path(__file__).resolve().parents[2]

    def reset(self, **kwargs: Any) -> dict[str, Any]:
        kwargs = apply_task_preset(dict(kwargs))
        obs = self.env.reset(**kwargs)
        return obs.model_dump(mode="json")

    def step(self, action: dict[str, Any]) -> dict[str, Any]:
        obs, reward, done, info = self.env.step(action)
        reason = str(info.get("termination_reason", "")) if isinstance(info, dict) else ""
        truncated = reason in {"wall_clock_timeout", "step_timeout", "step_budget_exhausted"}
        return {
            "observation": obs.model_dump(mode="json"),
            "reward": reward,
            "done": done,
            # Gymnasium-style split: budget/timeout endings count as truncation,
            # so an episode is "terminated" only when it ended for another reason.
            "terminated": done and not truncated,
            "truncated": truncated,
            "info": info,
        }

    def catalog(self) -> dict[str, Any]:
        return env_catalog()

    def step_candidate(self, candidate_id: str, confidence: float, rationale_brief: str) -> dict[str, Any] | None:
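        """Execute the legal action matching ``candidate_id``, attaching the given
        confidence and rationale; return None if no such action is currently legal."""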
        for action in self.env.get_legal_actions():
            if action.get("candidate_id") != candidate_id:
                continue
            payload = dict(action)
            payload["confidence"] = confidence
            payload["rationale_brief"] = rationale_brief
            return self.step(payload)
        return None

    def orchestrate(self, coordination_mode: str | None = None) -> dict[str, Any]:
        return self.orchestrator.run_step(coordination_mode=coordination_mode)

    def infer_policy(self) -> dict[str, Any]:
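        """Select an action via the policy provider router.

        Falls back to the first legal action when no candidate passes the
        legality precheck or when the router's choice is no longer legal; the
        returned payload carries the router's selection metadata.
        """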
        legal = self.env.get_legal_actions()
        if not legal:
            return {}
        candidate_payloads = [
            item for item in self.env.get_candidate_actions() if bool(item.get("legality_precheck", False))
        ]
        if not candidate_payloads:
            return legal[0]
        candidates = [self._candidate_obj(item) for item in candidate_payloads]
        state = self.env.state
        selection = self.policy_router.select_candidate(
            candidates=candidates,
            prompt={
                "patient_id": state.patient.patient_id,
                "difficulty": state.difficulty.value,
                "sub_environment": state.sub_environment.value,
                "step_count": state.step_count,
            },
            provider_preference=default_provider_preference(),
        )
        selected = next((item for item in legal if item.get("candidate_id") == selection.candidate_id), legal[0])
        payload = dict(selected)
        payload["policy_selection"] = {
            "provider": selection.provider,
            "candidate_id": selection.candidate_id,
            "rationale": selection.rationale,
            "latency_ms": round(selection.latency_ms, 3),
            "raw_output": selection.raw_output,
        }
        return payload

    def model_status(self) -> dict[str, Any]:
        return self.policy_router.model_status()

    def batch_infer(self, batch_size: int = 4) -> list[dict[str, Any]]:
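        """Return up to ``batch_size`` currently legal actions for batched inference."""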
        legal = self.env.get_legal_actions()
        return legal[:batch_size]

    def run_baselines(self) -> dict[str, Any]:
        candidates = [c for c in self.env.get_candidate_actions() if c.get("legality_precheck")]
        if not candidates:
            self.env.reset()
            candidates = [c for c in self.env.get_candidate_actions() if c.get("legality_precheck")]
        # Validate the raw payloads once instead of re-converting them per baseline.
        candidate_objs = [self._candidate_obj(c) for c in candidates]
        baseline_results = {
            "no_change": choose_no_change().model_dump(mode="json"),
            "rules_only": choose_rules_only(candidate_objs).model_dump(mode="json"),
            "greedy": choose_greedy(candidate_objs).model_dump(mode="json"),
            "contextual_bandit": choose_contextual_bandit(candidate_objs).model_dump(mode="json"),
            "contextual_bandit_topk": [
                {
                    "candidate_id": item.candidate_id,
                    "score": item.score,
                    "exploration_bonus": item.exploration_bonus,
                    "algorithm": item.algorithm,
                }
                for item in choose_contextual_bandit_topk(candidate_objs, top_k=3)
            ],
            "beam_search": choose_beam_search(candidate_objs).model_dump(mode="json"),
        }
        return baseline_results

    def run_policy_eval(self) -> dict[str, Any]:
        return build_benchmark_report(Path("outputs/reports/benchmark_report.txt"))

    def run_dosing_eval(self) -> dict[str, Any]:
        return dosing_eval()

    def run_training(self) -> dict[str, Any]:
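        """Run short GRPO passes for the supervisor, planner, and dosing policies,
        writing checkpoints and caching the returned metrics on the service."""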
        out_dir = Path("checkpoints")
        out_dir.mkdir(parents=True, exist_ok=True)
        self.training_metrics = {
            "supervisor": train_supervisor_grpo(episodes=4, checkpoint_dir=out_dir),
            "planner": train_planner_grpo(episodes=6, checkpoint_dir=out_dir),
            "dosing": train_dosing_grpo(episodes=4, checkpoint_dir=out_dir),
        }
        return self.training_metrics

    def get_metrics(self) -> dict[str, Any]:
        """Return in-memory training metrics, falling back to persisted report files."""
        if self.training_metrics:
            if "planner" in self.training_metrics:
                merged = dict(self.training_metrics["planner"])
                merged["model_metrics"] = self.training_metrics
                return merged
            return self.training_metrics
        reports_dir = Path("outputs/reports")
        metrics: dict[str, Any] = {}
        for name in ["supervisor_grpo", "planner_grpo", "dosing_grpo"]:
            path = reports_dir / f"{name}.json"
            if path.exists():
                metrics[name] = json.loads(path.read_text(encoding="utf-8"))
        self.training_metrics = metrics
        if "planner_grpo" in metrics:
            merged = dict(metrics["planner_grpo"])
            merged["model_metrics"] = metrics
            return merged
        return metrics

    def sample_case(self) -> dict[str, Any]:
        obs = self.env.reset()
        return obs.model_dump(mode="json")

    def search_cases(self, query: str) -> list[dict[str, Any]]:
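        """Search patient cases via the retrieval index, falling back to a keyword
        scan over the processed retrieval corpus when the index returns no hits."""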
        index_file = self.root / "data" / "retrieval_index" / "index.json"
        hits = retrieve(index_file=index_file, query=query, top_k=5)
        if hits:
            return [
                {
                    "patient_id": Path(item.get("path", f"case_{idx}")).stem,
                    "query": query,
                    "source_path": item.get("path", ""),
                    "snippet": str(item.get("text", ""))[:280],
                }
                for idx, item in enumerate(hits)
            ]

        fallback: list[dict[str, Any]] = []
        corpus = self.root / "data" / "processed" / "retrieval_corpus.jsonl"
        if corpus.exists():
            query_tokens = {token for token in query.lower().split() if token}
            with corpus.open("r", encoding="utf-8") as handle:
                for idx, line in enumerate(handle):
                    if len(fallback) >= 5:
                        break
                    text = line.strip()
                    if not text:
                        continue
                    hay = text.lower()
                    if query_tokens and not any(token in hay for token in query_tokens):
                        continue
                    fallback.append(
                        {
                            "patient_id": f"retrieval_corpus_{idx}",
                            "query": query,
                            "source_path": str(corpus),
                            "snippet": text[:280],
                        }
                    )
        return fallback

    def evidence_query(self, query: str, top_k: int = 5) -> list[dict[str, str]]:
        return retrieve_evidence(query=query, top_k=top_k)

    @staticmethod
    def _candidate_obj(payload: dict) -> Any:
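        # Deferred import, presumably to avoid a circular dependency at module load.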
        from app.common.types import CandidateAction

        return CandidateAction.model_validate(payload)
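

# Usage sketch (not part of the original module): a minimal, hedged example of
# how APIService might be driven end to end, assuming the `app` package and its
# data/model dependencies are importable and configured. The printed keys and
# fields reflect the payloads the methods above actually build.
if __name__ == "__main__":
    service = APIService()
    first_obs = service.reset()
    print("initial observation keys:", sorted(first_obs))
    action = service.infer_policy()
    if action:
        result = service.step(action)
        print("reward:", result["reward"], "terminated:", result["terminated"])
    print("baseline strategies:", list(service.run_baselines()))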