"""API service layer.""" from __future__ import annotations from pathlib import Path from typing import Any from app.agents.orchestrator import Orchestrator from app.env.catalog import apply_task_preset, env_catalog from app.env.env_core import PolyGuardEnv from app.evaluation.benchmark_report import build_benchmark_report from app.evaluation.dosing_eval import dosing_eval from app.knowledge.evidence_retriever import retrieve_evidence from app.models.retrieval.retriever import retrieve from app.models.baselines import ( choose_beam_search, choose_contextual_bandit, choose_contextual_bandit_topk, choose_greedy, choose_no_change, choose_rules_only, ) from app.training import train_dosing_grpo, train_planner_grpo, train_supervisor_grpo class APIService: def __init__(self) -> None: self.env = PolyGuardEnv() self.orchestrator = Orchestrator(self.env) self.training_metrics: dict[str, Any] = {} self.root = Path(__file__).resolve().parents[2] def reset(self, **kwargs: Any) -> dict[str, Any]: kwargs = apply_task_preset(dict(kwargs)) obs = self.env.reset(**kwargs) return obs.model_dump(mode="json") def step(self, action: dict[str, Any]) -> dict[str, Any]: obs, reward, done, info = self.env.step(action) reason = str(info.get("termination_reason", "")) if isinstance(info, dict) else "" truncated = reason in {"wall_clock_timeout", "step_timeout", "step_budget_exhausted"} return { "observation": obs.model_dump(mode="json"), "reward": reward, "done": done, "terminated": done, "truncated": truncated, "info": info, } def catalog(self) -> dict[str, Any]: return env_catalog() def step_candidate(self, candidate_id: str, confidence: float, rationale_brief: str) -> dict[str, Any] | None: for action in self.env.get_legal_actions(): if action.get("candidate_id") != candidate_id: continue payload = dict(action) payload["confidence"] = confidence payload["rationale_brief"] = rationale_brief return self.step(payload) return None def orchestrate(self, coordination_mode: str | None = None) -> dict[str, Any]: return self.orchestrator.run_step(coordination_mode=coordination_mode) def infer_policy(self) -> dict[str, Any]: legal = self.env.get_legal_actions() return legal[0] if legal else {} def batch_infer(self, batch_size: int = 4) -> list[dict[str, Any]]: legal = self.env.get_legal_actions() return legal[:batch_size] def run_baselines(self) -> dict[str, Any]: candidates = [c for c in self.env.get_candidate_actions() if c.get("legality_precheck")] if not candidates: self.env.reset() candidates = [c for c in self.env.get_candidate_actions() if c.get("legality_precheck")] baseline_results = { "no_change": choose_no_change().model_dump(mode="json"), "rules_only": choose_rules_only([self._candidate_obj(c) for c in candidates]).model_dump(mode="json"), "greedy": choose_greedy([self._candidate_obj(c) for c in candidates]).model_dump(mode="json"), "contextual_bandit": choose_contextual_bandit([self._candidate_obj(c) for c in candidates]).model_dump(mode="json"), "contextual_bandit_topk": [ { "candidate_id": item.candidate_id, "score": item.score, "exploration_bonus": item.exploration_bonus, "algorithm": item.algorithm, } for item in choose_contextual_bandit_topk([self._candidate_obj(c) for c in candidates], top_k=3) ], "beam_search": choose_beam_search([self._candidate_obj(c) for c in candidates]).model_dump(mode="json"), } return baseline_results def run_policy_eval(self) -> dict[str, Any]: out = build_benchmark_report(Path("outputs/reports/benchmark_report.txt")) return out def run_dosing_eval(self) -> dict[str, Any]: return 
dosing_eval() def run_training(self) -> dict[str, Any]: out_dir = Path("checkpoints") out_dir.mkdir(parents=True, exist_ok=True) self.training_metrics = { "supervisor": train_supervisor_grpo(episodes=4, checkpoint_dir=out_dir), "planner": train_planner_grpo(episodes=6, checkpoint_dir=out_dir), "dosing": train_dosing_grpo(episodes=4, checkpoint_dir=out_dir), } return self.training_metrics def get_metrics(self) -> dict[str, Any]: if self.training_metrics: if "planner" in self.training_metrics: merged = dict(self.training_metrics["planner"]) merged["model_metrics"] = self.training_metrics return merged return self.training_metrics reports_dir = Path("outputs/reports") metrics: dict[str, Any] = {} for name in ["supervisor_grpo", "planner_grpo", "dosing_grpo"]: path = reports_dir / f"{name}.json" if path.exists(): import json metrics[name] = json.loads(path.read_text(encoding="utf-8")) self.training_metrics = metrics if "planner_grpo" in metrics: merged = dict(metrics["planner_grpo"]) merged["model_metrics"] = metrics return merged return metrics def sample_case(self) -> dict[str, Any]: obs = self.env.reset() return obs.model_dump(mode="json") def search_cases(self, query: str) -> list[dict[str, Any]]: index_file = self.root / "data" / "retrieval_index" / "index.json" hits = retrieve(index_file=index_file, query=query, top_k=5) if hits: return [ { "patient_id": Path(item.get("path", f"case_{idx}")).stem, "query": query, "source_path": item.get("path", ""), "snippet": str(item.get("text", ""))[:280], } for idx, item in enumerate(hits) ] fallback: list[dict[str, Any]] = [] corpus = self.root / "data" / "processed" / "retrieval_corpus.jsonl" if corpus.exists(): query_tokens = {token for token in query.lower().split() if token} with corpus.open("r", encoding="utf-8") as handle: for idx, line in enumerate(handle): if len(fallback) >= 5: break text = line.strip() if not text: continue hay = text.lower() if query_tokens and not any(token in hay for token in query_tokens): continue fallback.append( { "patient_id": f"retrieval_corpus_{idx}", "query": query, "source_path": str(corpus), "snippet": text[:280], } ) return fallback def evidence_query(self, query: str, top_k: int = 5) -> list[dict[str, str]]: return retrieve_evidence(query=query, top_k=top_k) @staticmethod def _candidate_obj(payload: dict) -> Any: from app.common.types import CandidateAction return CandidateAction.model_validate(payload)
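

# --- Usage sketch ----------------------------------------------------------
# Illustrative only, not part of the service API: a minimal smoke test,
# assuming the `app` package is importable and the environment needs no extra
# reset arguments. Endpoint wiring (FastAPI, CLI, etc.) is out of scope.
if __name__ == "__main__":
    service = APIService()

    # Start an episode, then take one step with the first legal action,
    # mirroring what infer_policy() returns.
    first_obs = service.reset()
    print("observation keys:", sorted(first_obs))

    action = service.infer_policy()
    if action:
        result = service.step(action)
        print("reward:", result["reward"], "truncated:", result["truncated"])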