"""API service layer.""" from __future__ import annotations from pathlib import Path from typing import Any from app.agents.orchestrator import Orchestrator from app.env.catalog import apply_task_preset, env_catalog from app.env.env_core import PolyGuardEnv from app.evaluation.benchmark_report import build_benchmark_report from app.evaluation.dosing_eval import dosing_eval from app.knowledge.evidence_retriever import retrieve_evidence from app.models.retrieval.retriever import retrieve from app.models.policy.provider_runtime import PolicyProviderRouter, default_provider_preference from app.models.baselines import ( choose_beam_search, choose_contextual_bandit, choose_contextual_bandit_topk, choose_greedy, choose_no_change, choose_rules_only, ) from app.training import train_dosing_grpo, train_planner_grpo, train_supervisor_grpo class APIService: def __init__(self) -> None: self.env = PolyGuardEnv() self.orchestrator = Orchestrator(self.env) self.policy_router = PolicyProviderRouter() self.training_metrics: dict[str, Any] = {} self.root = Path(__file__).resolve().parents[2] def reset(self, **kwargs: Any) -> dict[str, Any]: kwargs = apply_task_preset(dict(kwargs)) obs = self.env.reset(**kwargs) return obs.model_dump(mode="json") def step(self, action: dict[str, Any]) -> dict[str, Any]: obs, reward, done, info = self.env.step(action) reason = str(info.get("termination_reason", "")) if isinstance(info, dict) else "" truncated = reason in {"wall_clock_timeout", "step_timeout", "step_budget_exhausted"} return { "observation": obs.model_dump(mode="json"), "reward": reward, "done": done, "terminated": done, "truncated": truncated, "info": info, } def catalog(self) -> dict[str, Any]: return env_catalog() def step_candidate(self, candidate_id: str, confidence: float, rationale_brief: str) -> dict[str, Any] | None: for action in self.env.get_legal_actions(): if action.get("candidate_id") != candidate_id: continue payload = dict(action) payload["confidence"] = confidence payload["rationale_brief"] = rationale_brief return self.step(payload) return None def orchestrate(self, coordination_mode: str | None = None) -> dict[str, Any]: return self.orchestrator.run_step(coordination_mode=coordination_mode) def infer_policy(self) -> dict[str, Any]: legal = self.env.get_legal_actions() if not legal: return {} candidate_payloads = [ item for item in self.env.get_candidate_actions() if bool(item.get("legality_precheck", False)) ] if not candidate_payloads: return legal[0] candidates = [self._candidate_obj(item) for item in candidate_payloads] state = self.env.state selection = self.policy_router.select_candidate( candidates=candidates, prompt={ "patient_id": state.patient.patient_id, "difficulty": state.difficulty.value, "sub_environment": state.sub_environment.value, "step_count": state.step_count, }, provider_preference=default_provider_preference(), ) selected = next((item for item in legal if item.get("candidate_id") == selection.candidate_id), legal[0]) payload = dict(selected) payload["policy_selection"] = { "provider": selection.provider, "candidate_id": selection.candidate_id, "rationale": selection.rationale, "latency_ms": round(selection.latency_ms, 3), "raw_output": selection.raw_output, } return payload def model_status(self) -> dict[str, Any]: return self.policy_router.model_status() def batch_infer(self, batch_size: int = 4) -> list[dict[str, Any]]: legal = self.env.get_legal_actions() return legal[:batch_size] def run_baselines(self) -> dict[str, Any]: candidates = [c for c in self.env.get_candidate_actions() if c.get("legality_precheck")] if not candidates: self.env.reset() candidates = [c for c in self.env.get_candidate_actions() if c.get("legality_precheck")] baseline_results = { "no_change": choose_no_change().model_dump(mode="json"), "rules_only": choose_rules_only([self._candidate_obj(c) for c in candidates]).model_dump(mode="json"), "greedy": choose_greedy([self._candidate_obj(c) for c in candidates]).model_dump(mode="json"), "contextual_bandit": choose_contextual_bandit([self._candidate_obj(c) for c in candidates]).model_dump(mode="json"), "contextual_bandit_topk": [ { "candidate_id": item.candidate_id, "score": item.score, "exploration_bonus": item.exploration_bonus, "algorithm": item.algorithm, } for item in choose_contextual_bandit_topk([self._candidate_obj(c) for c in candidates], top_k=3) ], "beam_search": choose_beam_search([self._candidate_obj(c) for c in candidates]).model_dump(mode="json"), } return baseline_results def run_policy_eval(self) -> dict[str, Any]: out = build_benchmark_report(Path("outputs/reports/benchmark_report.txt")) return out def run_dosing_eval(self) -> dict[str, Any]: return dosing_eval() def run_training(self) -> dict[str, Any]: out_dir = Path("checkpoints") out_dir.mkdir(parents=True, exist_ok=True) self.training_metrics = { "supervisor": train_supervisor_grpo(episodes=4, checkpoint_dir=out_dir), "planner": train_planner_grpo(episodes=6, checkpoint_dir=out_dir), "dosing": train_dosing_grpo(episodes=4, checkpoint_dir=out_dir), } return self.training_metrics def get_metrics(self) -> dict[str, Any]: if self.training_metrics: if "planner" in self.training_metrics: merged = dict(self.training_metrics["planner"]) merged["model_metrics"] = self.training_metrics return merged return self.training_metrics reports_dir = Path("outputs/reports") metrics: dict[str, Any] = {} for name in ["supervisor_grpo", "planner_grpo", "dosing_grpo"]: path = reports_dir / f"{name}.json" if path.exists(): import json metrics[name] = json.loads(path.read_text(encoding="utf-8")) self.training_metrics = metrics if "planner_grpo" in metrics: merged = dict(metrics["planner_grpo"]) merged["model_metrics"] = metrics return merged return metrics def sample_case(self) -> dict[str, Any]: obs = self.env.reset() return obs.model_dump(mode="json") def search_cases(self, query: str) -> list[dict[str, Any]]: index_file = self.root / "data" / "retrieval_index" / "index.json" hits = retrieve(index_file=index_file, query=query, top_k=5) if hits: return [ { "patient_id": Path(item.get("path", f"case_{idx}")).stem, "query": query, "source_path": item.get("path", ""), "snippet": str(item.get("text", ""))[:280], } for idx, item in enumerate(hits) ] fallback: list[dict[str, Any]] = [] corpus = self.root / "data" / "processed" / "retrieval_corpus.jsonl" if corpus.exists(): query_tokens = {token for token in query.lower().split() if token} with corpus.open("r", encoding="utf-8") as handle: for idx, line in enumerate(handle): if len(fallback) >= 5: break text = line.strip() if not text: continue hay = text.lower() if query_tokens and not any(token in hay for token in query_tokens): continue fallback.append( { "patient_id": f"retrieval_corpus_{idx}", "query": query, "source_path": str(corpus), "snippet": text[:280], } ) return fallback def evidence_query(self, query: str, top_k: int = 5) -> list[dict[str, str]]: return retrieve_evidence(query=query, top_k=top_k) @staticmethod def _candidate_obj(payload: dict) -> Any: from app.common.types import CandidateAction return CandidateAction.model_validate(payload)