| """API service layer.""" |
|
|
| from __future__ import annotations |
|
|
| from pathlib import Path |
| from typing import Any |
|
|
| from app.agents.orchestrator import Orchestrator |
| from app.env.catalog import apply_task_preset, env_catalog |
| from app.env.env_core import PolyGuardEnv |
| from app.evaluation.benchmark_report import build_benchmark_report |
| from app.evaluation.dosing_eval import dosing_eval |
| from app.knowledge.evidence_retriever import retrieve_evidence |
| from app.models.retrieval.retriever import retrieve |
| from app.models.policy.provider_runtime import PolicyProviderRouter, default_provider_preference |
| from app.models.baselines import ( |
| choose_beam_search, |
| choose_contextual_bandit, |
| choose_contextual_bandit_topk, |
| choose_greedy, |
| choose_no_change, |
| choose_rules_only, |
| ) |
| from app.training import train_dosing_grpo, train_planner_grpo, train_supervisor_grpo |


class APIService:
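    """Service facade exposing the environment, orchestrator, policy router, evaluation, and training to the API."""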

    def __init__(self) -> None:
        self.env = PolyGuardEnv()
        self.orchestrator = Orchestrator(self.env)
        self.policy_router = PolicyProviderRouter()
        self.training_metrics: dict[str, Any] = {}
        self.root = Path(__file__).resolve().parents[2]

    def reset(self, **kwargs: Any) -> dict[str, Any]:
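        """Reset the environment, applying any task preset found in ``kwargs``, and return the initial observation."""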
        kwargs = apply_task_preset(dict(kwargs))
        obs = self.env.reset(**kwargs)
        return obs.model_dump(mode="json")

    def step(self, action: dict[str, Any]) -> dict[str, Any]:
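        """Apply ``action`` to the environment and return the observation, reward, and termination/truncation flags."""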
        obs, reward, done, info = self.env.step(action)
        reason = str(info.get("termination_reason", "")) if isinstance(info, dict) else ""
        truncated = reason in {"wall_clock_timeout", "step_timeout", "step_budget_exhausted"}
        return {
            "observation": obs.model_dump(mode="json"),
            "reward": reward,
            "done": done,
            "terminated": done,
            "truncated": truncated,
            "info": info,
        }

    def catalog(self) -> dict[str, Any]:
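        """Return the environment catalog."""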
        return env_catalog()

    def step_candidate(self, candidate_id: str, confidence: float, rationale_brief: str) -> dict[str, Any] | None:
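        """Step with the legal action matching ``candidate_id``, or return ``None`` when no such action is legal."""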
        for action in self.env.get_legal_actions():
            if action.get("candidate_id") != candidate_id:
                continue
            payload = dict(action)
            payload["confidence"] = confidence
            payload["rationale_brief"] = rationale_brief
            return self.step(payload)
        return None

    def orchestrate(self, coordination_mode: str | None = None) -> dict[str, Any]:
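        """Run a single orchestrator step, optionally overriding the coordination mode."""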
        return self.orchestrator.run_step(coordination_mode=coordination_mode)

    def infer_policy(self) -> dict[str, Any]:
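        """Select an action via the policy provider router and annotate it with selection metadata, falling back to the first legal action."""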
        legal = self.env.get_legal_actions()
        if not legal:
            return {}
        candidate_payloads = [
            item for item in self.env.get_candidate_actions() if bool(item.get("legality_precheck", False))
        ]
        if not candidate_payloads:
            return legal[0]
        candidates = [self._candidate_obj(item) for item in candidate_payloads]
        state = self.env.state
        selection = self.policy_router.select_candidate(
            candidates=candidates,
            prompt={
                "patient_id": state.patient.patient_id,
                "difficulty": state.difficulty.value,
                "sub_environment": state.sub_environment.value,
                "step_count": state.step_count,
            },
            provider_preference=default_provider_preference(),
        )
        selected = next((item for item in legal if item.get("candidate_id") == selection.candidate_id), legal[0])
        payload = dict(selected)
        payload["policy_selection"] = {
            "provider": selection.provider,
            "candidate_id": selection.candidate_id,
            "rationale": selection.rationale,
            "latency_ms": round(selection.latency_ms, 3),
            "raw_output": selection.raw_output,
        }
        return payload

    def model_status(self) -> dict[str, Any]:
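        """Return the status of the policy provider models."""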
        return self.policy_router.model_status()

    def batch_infer(self, batch_size: int = 4) -> list[dict[str, Any]]:
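        """Return up to ``batch_size`` legal action payloads; no model call is made here."""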
        legal = self.env.get_legal_actions()
        return legal[:batch_size]

    def run_baselines(self) -> dict[str, Any]:
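        """Run each baseline selector over the current legal candidates and return their choices."""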
        candidates = [c for c in self.env.get_candidate_actions() if c.get("legality_precheck")]
        if not candidates:
            self.env.reset()
            candidates = [c for c in self.env.get_candidate_actions() if c.get("legality_precheck")]
        candidate_objs = [self._candidate_obj(c) for c in candidates]
        return {
            "no_change": choose_no_change().model_dump(mode="json"),
            "rules_only": choose_rules_only(candidate_objs).model_dump(mode="json"),
            "greedy": choose_greedy(candidate_objs).model_dump(mode="json"),
            "contextual_bandit": choose_contextual_bandit(candidate_objs).model_dump(mode="json"),
            "contextual_bandit_topk": [
                {
                    "candidate_id": item.candidate_id,
                    "score": item.score,
                    "exploration_bonus": item.exploration_bonus,
                    "algorithm": item.algorithm,
                }
                for item in choose_contextual_bandit_topk(candidate_objs, top_k=3)
            ],
            "beam_search": choose_beam_search(candidate_objs).model_dump(mode="json"),
        }

    def run_policy_eval(self) -> dict[str, Any]:
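        """Build and return the benchmark report."""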
        return build_benchmark_report(Path("outputs/reports/benchmark_report.txt"))

    def run_dosing_eval(self) -> dict[str, Any]:
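        """Run the dosing evaluation."""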
        return dosing_eval()

    def run_training(self) -> dict[str, Any]:
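        """Train the supervisor, planner, and dosing policies with GRPO and cache the resulting metrics."""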
        out_dir = Path("checkpoints")
        out_dir.mkdir(parents=True, exist_ok=True)
        self.training_metrics = {
            "supervisor": train_supervisor_grpo(episodes=4, checkpoint_dir=out_dir),
            "planner": train_planner_grpo(episodes=6, checkpoint_dir=out_dir),
            "dosing": train_dosing_grpo(episodes=4, checkpoint_dir=out_dir),
        }
        return self.training_metrics

    def get_metrics(self) -> dict[str, Any]:
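        """Return cached training metrics, loading them from the JSON reports on disk if needed."""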
        if self.training_metrics:
            if "planner" in self.training_metrics:
                merged = dict(self.training_metrics["planner"])
                merged["model_metrics"] = self.training_metrics
                return merged
            return self.training_metrics
        reports_dir = Path("outputs/reports")
        metrics: dict[str, Any] = {}
        for name in ["supervisor_grpo", "planner_grpo", "dosing_grpo"]:
            path = reports_dir / f"{name}.json"
            if path.exists():
                metrics[name] = json.loads(path.read_text(encoding="utf-8"))
        self.training_metrics = metrics
        if "planner_grpo" in metrics:
            merged = dict(metrics["planner_grpo"])
            merged["model_metrics"] = metrics
            return merged
        return metrics

    def sample_case(self) -> dict[str, Any]:
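        """Reset the environment and return the resulting case observation."""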
        obs = self.env.reset()
        return obs.model_dump(mode="json")

    def search_cases(self, query: str) -> list[dict[str, Any]]:
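        """Search the retrieval index for ``query``, falling back to a token scan of the raw corpus."""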
        index_file = self.root / "data" / "retrieval_index" / "index.json"
        hits = retrieve(index_file=index_file, query=query, top_k=5)
        if hits:
            return [
                {
                    "patient_id": Path(item.get("path", f"case_{idx}")).stem,
                    "query": query,
                    "source_path": item.get("path", ""),
                    "snippet": str(item.get("text", ""))[:280],
                }
                for idx, item in enumerate(hits)
            ]
        fallback: list[dict[str, Any]] = []
        corpus = self.root / "data" / "processed" / "retrieval_corpus.jsonl"
        if corpus.exists():
            query_tokens = {token for token in query.lower().split() if token}
            with corpus.open("r", encoding="utf-8") as handle:
                for idx, line in enumerate(handle):
                    if len(fallback) >= 5:
                        break
                    text = line.strip()
                    if not text:
                        continue
                    hay = text.lower()
                    if query_tokens and not any(token in hay for token in query_tokens):
                        continue
                    fallback.append(
                        {
                            "patient_id": f"retrieval_corpus_{idx}",
                            "query": query,
                            "source_path": str(corpus),
                            "snippet": text[:280],
                        }
                    )
        return fallback

    def evidence_query(self, query: str, top_k: int = 5) -> list[dict[str, str]]:
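        """Retrieve up to ``top_k`` evidence snippets for ``query``."""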
        return retrieve_evidence(query=query, top_k=top_k)

    @staticmethod
    def _candidate_obj(payload: dict[str, Any]) -> Any:
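        """Validate a raw action payload into a ``CandidateAction``."""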
        # Local import, likely kept here to avoid a circular import at module load time.
        from app.common.types import CandidateAction

        return CandidateAction.model_validate(payload)