"""API service layer."""
from __future__ import annotations

import json
from pathlib import Path
from typing import Any

from app.agents.orchestrator import Orchestrator
from app.env.catalog import apply_task_preset, env_catalog
from app.env.env_core import PolyGuardEnv
from app.evaluation.benchmark_report import build_benchmark_report
from app.evaluation.dosing_eval import dosing_eval
from app.knowledge.evidence_retriever import retrieve_evidence
from app.models.baselines import (
    choose_beam_search,
    choose_contextual_bandit,
    choose_contextual_bandit_topk,
    choose_greedy,
    choose_no_change,
    choose_rules_only,
)
from app.models.policy.provider_runtime import PolicyProviderRouter, default_provider_preference
from app.models.retrieval.retriever import retrieve
from app.training import train_dosing_grpo, train_planner_grpo, train_supervisor_grpo


class APIService:
    """Facade exposing environment, policy, evaluation, and training operations to the API layer."""

    def __init__(self) -> None:
        self.env = PolyGuardEnv()
        self.orchestrator = Orchestrator(self.env)
        self.policy_router = PolicyProviderRouter()
        self.training_metrics: dict[str, Any] = {}
        # Project root, resolved relative to this module's location.
        self.root = Path(__file__).resolve().parents[2]

    def reset(self, **kwargs: Any) -> dict[str, Any]:
        kwargs = apply_task_preset(dict(kwargs))
        obs = self.env.reset(**kwargs)
        return obs.model_dump(mode="json")

    def step(self, action: dict[str, Any]) -> dict[str, Any]:
        obs, reward, done, info = self.env.step(action)
        reason = str(info.get("termination_reason", "")) if isinstance(info, dict) else ""
        # Budget and timeout terminations are surfaced as truncations, mirroring the Gymnasium convention.
        truncated = reason in {"wall_clock_timeout", "step_timeout", "step_budget_exhausted"}
        return {
            "observation": obs.model_dump(mode="json"),
            "reward": reward,
            "done": done,
            "terminated": done,
            "truncated": truncated,
            "info": info,
        }

    def catalog(self) -> dict[str, Any]:
        return env_catalog()

    def step_candidate(self, candidate_id: str, confidence: float, rationale_brief: str) -> dict[str, Any] | None:
        # Apply the legal action matching candidate_id, attaching the caller-supplied confidence and rationale.
        for action in self.env.get_legal_actions():
            if action.get("candidate_id") != candidate_id:
                continue
            payload = dict(action)
            payload["confidence"] = confidence
            payload["rationale_brief"] = rationale_brief
            return self.step(payload)
        return None

    def orchestrate(self, coordination_mode: str | None = None) -> dict[str, Any]:
        return self.orchestrator.run_step(coordination_mode=coordination_mode)

    def infer_policy(self) -> dict[str, Any]:
        legal = self.env.get_legal_actions()
        if not legal:
            return {}
        # Only candidates that pass the legality precheck are offered to the policy provider.
        candidate_payloads = [
            item for item in self.env.get_candidate_actions() if bool(item.get("legality_precheck", False))
        ]
        if not candidate_payloads:
            return legal[0]
        candidates = [self._candidate_obj(item) for item in candidate_payloads]
        state = self.env.state
        selection = self.policy_router.select_candidate(
            candidates=candidates,
            prompt={
                "patient_id": state.patient.patient_id,
                "difficulty": state.difficulty.value,
                "sub_environment": state.sub_environment.value,
                "step_count": state.step_count,
            },
            provider_preference=default_provider_preference(),
        )
        # Fall back to the first legal action if the provider picks an unknown candidate_id.
        selected = next((item for item in legal if item.get("candidate_id") == selection.candidate_id), legal[0])
        payload = dict(selected)
        payload["policy_selection"] = {
            "provider": selection.provider,
            "candidate_id": selection.candidate_id,
            "rationale": selection.rationale,
            "latency_ms": round(selection.latency_ms, 3),
            "raw_output": selection.raw_output,
        }
        return payload

    def model_status(self) -> dict[str, Any]:
        return self.policy_router.model_status()

    def batch_infer(self, batch_size: int = 4) -> list[dict[str, Any]]:
        legal = self.env.get_legal_actions()
        return legal[:batch_size]

    def run_baselines(self) -> dict[str, Any]:
        candidates = [c for c in self.env.get_candidate_actions() if c.get("legality_precheck")]
        if not candidates:
            # Reset once to obtain a fresh set of legal candidates before scoring the baselines.
            self.env.reset()
            candidates = [c for c in self.env.get_candidate_actions() if c.get("legality_precheck")]
        candidate_objs = [self._candidate_obj(c) for c in candidates]
        return {
            "no_change": choose_no_change().model_dump(mode="json"),
            "rules_only": choose_rules_only(candidate_objs).model_dump(mode="json"),
            "greedy": choose_greedy(candidate_objs).model_dump(mode="json"),
            "contextual_bandit": choose_contextual_bandit(candidate_objs).model_dump(mode="json"),
            "contextual_bandit_topk": [
                {
                    "candidate_id": item.candidate_id,
                    "score": item.score,
                    "exploration_bonus": item.exploration_bonus,
                    "algorithm": item.algorithm,
                }
                for item in choose_contextual_bandit_topk(candidate_objs, top_k=3)
            ],
            "beam_search": choose_beam_search(candidate_objs).model_dump(mode="json"),
        }

    def run_policy_eval(self) -> dict[str, Any]:
        return build_benchmark_report(Path("outputs/reports/benchmark_report.txt"))

    def run_dosing_eval(self) -> dict[str, Any]:
        return dosing_eval()

    def run_training(self) -> dict[str, Any]:
        out_dir = Path("checkpoints")
        out_dir.mkdir(parents=True, exist_ok=True)
        self.training_metrics = {
            "supervisor": train_supervisor_grpo(episodes=4, checkpoint_dir=out_dir),
            "planner": train_planner_grpo(episodes=6, checkpoint_dir=out_dir),
            "dosing": train_dosing_grpo(episodes=4, checkpoint_dir=out_dir),
        }
        return self.training_metrics

    def get_metrics(self) -> dict[str, Any]:
        # Prefer in-memory metrics from the most recent run_training call.
        if self.training_metrics:
            if "planner" in self.training_metrics:
                merged = dict(self.training_metrics["planner"])
                merged["model_metrics"] = self.training_metrics
                return merged
            return self.training_metrics
        # Otherwise, fall back to the JSON reports written by earlier training runs.
        reports_dir = Path("outputs/reports")
        metrics: dict[str, Any] = {}
        for name in ["supervisor_grpo", "planner_grpo", "dosing_grpo"]:
            path = reports_dir / f"{name}.json"
            if path.exists():
                metrics[name] = json.loads(path.read_text(encoding="utf-8"))
        self.training_metrics = metrics
        if "planner_grpo" in metrics:
            merged = dict(metrics["planner_grpo"])
            merged["model_metrics"] = metrics
            return merged
        return metrics

    def sample_case(self) -> dict[str, Any]:
        obs = self.env.reset()
        return obs.model_dump(mode="json")

    def search_cases(self, query: str) -> list[dict[str, Any]]:
        index_file = self.root / "data" / "retrieval_index" / "index.json"
        hits = retrieve(index_file=index_file, query=query, top_k=5)
        if hits:
            return [
                {
                    "patient_id": Path(item.get("path", f"case_{idx}")).stem,
                    "query": query,
                    "source_path": item.get("path", ""),
                    "snippet": str(item.get("text", ""))[:280],
                }
                for idx, item in enumerate(hits)
            ]
        # The index returned nothing: fall back to a keyword scan over the raw retrieval corpus.
        fallback: list[dict[str, Any]] = []
        corpus = self.root / "data" / "processed" / "retrieval_corpus.jsonl"
        if corpus.exists():
            query_tokens = {token for token in query.lower().split() if token}
            with corpus.open("r", encoding="utf-8") as handle:
                for idx, line in enumerate(handle):
                    if len(fallback) >= 5:
                        break
                    text = line.strip()
                    if not text:
                        continue
                    hay = text.lower()
                    if query_tokens and not any(token in hay for token in query_tokens):
                        continue
                    fallback.append(
                        {
                            "patient_id": f"retrieval_corpus_{idx}",
                            "query": query,
                            "source_path": str(corpus),
                            "snippet": text[:280],
                        }
                    )
        return fallback

    def evidence_query(self, query: str, top_k: int = 5) -> list[dict[str, str]]:
        return retrieve_evidence(query=query, top_k=top_k)

    @staticmethod
    def _candidate_obj(payload: dict) -> Any:
        # Local import keeps app.common.types out of the module-level import graph.
        from app.common.types import CandidateAction

        return CandidateAction.model_validate(payload)
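

# Minimal usage sketch, assuming the default PolyGuardEnv configuration: reset the environment,
# let the provider-backed policy pick a candidate, and apply it as a single step. Guarded so that
# importing this module has no side effects; the printed keys come from step()'s return payload.
if __name__ == "__main__":
    service = APIService()
    observation = service.reset()
    print("initial observation keys:", sorted(observation))
    chosen = service.infer_policy()
    if chosen:
        result = service.step(chosen)
        print("reward:", result["reward"], "terminated:", result["terminated"], "truncated:", result["truncated"])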