Spaces:
Running
Running
| """API service layer.""" | |
| from __future__ import annotations | |
| from pathlib import Path | |
| from typing import Any | |
| from app.agents.orchestrator import Orchestrator | |
| from app.env.catalog import apply_task_preset, env_catalog | |
| from app.env.env_core import PolyGuardEnv | |
| from app.evaluation.benchmark_report import build_benchmark_report | |
| from app.evaluation.dosing_eval import dosing_eval | |
| from app.knowledge.evidence_retriever import retrieve_evidence | |
| from app.models.retrieval.retriever import retrieve | |
| from app.models.baselines import ( | |
| choose_beam_search, | |
| choose_contextual_bandit, | |
| choose_contextual_bandit_topk, | |
| choose_greedy, | |
| choose_no_change, | |
| choose_rules_only, | |
| ) | |
| from app.training import train_dosing_grpo, train_planner_grpo, train_supervisor_grpo | |
class APIService:
    """Facade over the PolyGuard environment for the API layer.

    Wraps a ``PolyGuardEnv`` instance and an ``Orchestrator`` bound to it, and
    exposes baselines, evaluation, training, metrics and retrieval helpers.
    Pydantic observations are converted with ``model_dump(mode="json")``
    before being returned.
    """

    # Termination reasons that mean the episode was cut short (time/step
    # budget) rather than ending naturally; reported as ``truncated``.
    _TRUNCATION_REASONS = frozenset({"wall_clock_timeout", "step_timeout", "step_budget_exhausted"})

    # Maximum number of results returned by ``search_cases``.
    _MAX_SEARCH_RESULTS = 5

    # Maximum number of characters kept in a search-result snippet.
    _SNIPPET_CHARS = 280

    def __init__(self) -> None:
        """Create a fresh environment and an orchestrator bound to it."""
        self.env = PolyGuardEnv()
        self.orchestrator = Orchestrator(self.env)
        # Cache of training metrics: either the results of ``run_training``
        # or the reports loaded from disk by ``get_metrics``.
        self.training_metrics: dict[str, Any] = {}
        # Third ancestor of this file (``parents[2]``); presumably the
        # repository root that holds ``data/`` — confirm.
        self.root = Path(__file__).resolve().parents[2]

    def reset(self, **kwargs: Any) -> dict[str, Any]:
        """Reset the environment and return the initial observation as JSON.

        The keyword arguments are passed (as a copy) through
        ``apply_task_preset`` and the result is forwarded to
        ``PolyGuardEnv.reset``.
        """
        kwargs = apply_task_preset(dict(kwargs))
        obs = self.env.reset(**kwargs)
        return obs.model_dump(mode="json")

    def step(self, action: dict[str, Any]) -> dict[str, Any]:
        """Apply ``action`` to the environment and return the transition.

        Returns:
            A dict with ``observation`` (JSON-dumped), ``reward``, ``done``,
            ``terminated`` (mirrors ``done``), ``truncated`` (True when
            ``info["termination_reason"]`` is one of ``_TRUNCATION_REASONS``)
            and the raw ``info``.
        """
        obs, reward, done, info = self.env.step(action)
        # ``info`` is only inspected when it is a dict; otherwise no reason.
        reason = str(info.get("termination_reason", "")) if isinstance(info, dict) else ""
        truncated = reason in self._TRUNCATION_REASONS
        return {
            "observation": obs.model_dump(mode="json"),
            "reward": reward,
            "done": done,
            # NOTE(review): ``terminated`` is True for truncated episodes too;
            # Gymnasium treats the two flags as distinct — confirm clients
            # rely on the current behaviour before changing it.
            "terminated": done,
            "truncated": truncated,
            "info": info,
        }

    def catalog(self) -> dict[str, Any]:
        """Return the environment catalog produced by ``env_catalog()``."""
        return env_catalog()

    def step_candidate(self, candidate_id: str, confidence: float, rationale_brief: str) -> dict[str, Any] | None:
        """Step with the first legal action whose ``candidate_id`` matches.

        The matching action is copied and annotated with ``confidence`` and
        ``rationale_brief`` before being passed to ``step``.

        Returns:
            The ``step`` result, or ``None`` (without stepping) when no legal
            action has the requested id.
        """
        for action in self.env.get_legal_actions():
            if action.get("candidate_id") != candidate_id:
                continue
            payload = dict(action)
            payload["confidence"] = confidence
            payload["rationale_brief"] = rationale_brief
            return self.step(payload)
        return None

    def orchestrate(self, coordination_mode: str | None = None) -> dict[str, Any]:
        """Run one orchestrator step with the given coordination mode."""
        return self.orchestrator.run_step(coordination_mode=coordination_mode)

    def infer_policy(self) -> dict[str, Any]:
        """Return the first legal action, or ``{}`` when there is none."""
        legal = self.env.get_legal_actions()
        return legal[0] if legal else {}

    def batch_infer(self, batch_size: int = 4) -> list[dict[str, Any]]:
        """Return up to ``batch_size`` legal actions, in environment order.

        A negative ``batch_size`` is treated as 0 (returns ``[]``). Without the
        clamp, slicing would return "all but the last N" actions.
        """
        legal = self.env.get_legal_actions()
        return legal[: max(batch_size, 0)]

    def run_baselines(self) -> dict[str, Any]:
        """Run every baseline policy against the currently legal candidates.

        Candidates are the entries of ``get_candidate_actions()`` whose
        ``legality_precheck`` is truthy. If there are none, the environment is
        reset once and the list is rebuilt.

        Returns:
            A dict keyed by baseline name. Each value is the JSON-dumped choice,
            except ``contextual_bandit_topk``, which is a list of up to 3 scored
            entries.
        """

        def legal_candidates() -> list[dict[str, Any]]:
            return [c for c in self.env.get_candidate_actions() if c.get("legality_precheck")]

        candidates = legal_candidates()
        if not candidates:
            self.env.reset()
            candidates = legal_candidates()

        def fresh_pool() -> list[Any]:
            # Each baseline gets its own freshly validated objects (as the
            # original code did), so no baseline can observe mutations that
            # another one might make.
            return [self._candidate_obj(c) for c in candidates]

        # The dict literal keeps the original call order, which matters if any
        # baseline draws from a shared RNG.
        return {
            "no_change": choose_no_change().model_dump(mode="json"),
            "rules_only": choose_rules_only(fresh_pool()).model_dump(mode="json"),
            "greedy": choose_greedy(fresh_pool()).model_dump(mode="json"),
            "contextual_bandit": choose_contextual_bandit(fresh_pool()).model_dump(mode="json"),
            "contextual_bandit_topk": [
                {
                    "candidate_id": item.candidate_id,
                    "score": item.score,
                    "exploration_bonus": item.exploration_bonus,
                    "algorithm": item.algorithm,
                }
                for item in choose_contextual_bandit_topk(fresh_pool(), top_k=3)
            ],
            "beam_search": choose_beam_search(fresh_pool()).model_dump(mode="json"),
        }

    def run_policy_eval(self) -> dict[str, Any]:
        """Build the benchmark report for ``outputs/reports/benchmark_report.txt``.

        The path is relative to the current working directory.
        """
        return build_benchmark_report(Path("outputs/reports/benchmark_report.txt"))

    def run_dosing_eval(self) -> dict[str, Any]:
        """Return the result of ``dosing_eval()``."""
        return dosing_eval()

    def run_training(self) -> dict[str, Any]:
        """Run the supervisor, planner and dosing GRPO trainers.

        Creates the cwd-relative ``checkpoints`` directory if needed and runs
        the trainers with 4, 6 and 4 episodes respectively. The results are
        cached in ``self.training_metrics`` under the keys ``supervisor``,
        ``planner`` and ``dosing``, and returned.
        """
        out_dir = Path("checkpoints")
        out_dir.mkdir(parents=True, exist_ok=True)
        self.training_metrics = {
            "supervisor": train_supervisor_grpo(episodes=4, checkpoint_dir=out_dir),
            "planner": train_planner_grpo(episodes=6, checkpoint_dir=out_dir),
            "dosing": train_dosing_grpo(episodes=4, checkpoint_dir=out_dir),
        }
        return self.training_metrics

    def get_metrics(self) -> dict[str, Any]:
        """Return training metrics, preferring the in-memory cache.

        When the cache is empty, the ``supervisor_grpo``, ``planner_grpo`` and
        ``dosing_grpo`` JSON reports are loaded from the cwd-relative
        ``outputs/reports`` directory and cached. Missing files are skipped.

        Returns:
            If planner metrics are present (key ``planner`` from
            ``run_training`` or ``planner_grpo`` from disk), a copy of the
            planner dict with the full metrics mapping under ``model_metrics``.
            Otherwise, the metrics mapping itself. The shape is the same
            whether the metrics were just loaded or were already cached.
        """
        if not self.training_metrics:
            self.training_metrics = self._load_metric_reports(Path("outputs/reports"))
        return self._merge_planner_metrics(self.training_metrics)

    @staticmethod
    def _load_metric_reports(reports_dir: Path) -> dict[str, Any]:
        """Load the per-model GRPO JSON reports found in ``reports_dir``."""
        import json

        metrics: dict[str, Any] = {}
        for name in ("supervisor_grpo", "planner_grpo", "dosing_grpo"):
            path = reports_dir / f"{name}.json"
            if path.exists():
                metrics[name] = json.loads(path.read_text(encoding="utf-8"))
        return metrics

    @staticmethod
    def _merge_planner_metrics(metrics: dict[str, Any]) -> dict[str, Any]:
        """Surface planner metrics at top level with the full mapping attached.

        Accepts both naming schemes: ``planner`` (``run_training``) and
        ``planner_grpo`` (reports on disk). Returns ``metrics`` unchanged when
        neither key is present.
        """
        for key in ("planner", "planner_grpo"):
            if key in metrics:
                merged = dict(metrics[key])
                merged["model_metrics"] = metrics
                return merged
        return metrics

    def sample_case(self) -> dict[str, Any]:
        """Reset the environment with default arguments and return the observation."""
        obs = self.env.reset()
        return obs.model_dump(mode="json")

    def search_cases(self, query: str) -> list[dict[str, Any]]:
        """Search for cases matching ``query``.

        First queries the retrieval index at
        ``<root>/data/retrieval_index/index.json`` for the top 5 hits. If that
        yields nothing, falls back to a line scan of
        ``<root>/data/processed/retrieval_corpus.jsonl``.

        Returns:
            Up to 5 dicts with ``patient_id``, ``query``, ``source_path`` and a
            ``snippet`` truncated to 280 characters.
        """
        index_file = self.root / "data" / "retrieval_index" / "index.json"
        hits = retrieve(index_file=index_file, query=query, top_k=self._MAX_SEARCH_RESULTS)
        if hits:
            return [
                {
                    # ``or`` also covers a present-but-empty/None path, which
                    # would otherwise give an empty id or raise in ``Path``.
                    "patient_id": Path(item.get("path") or f"case_{idx}").stem,
                    "query": query,
                    "source_path": item.get("path", ""),
                    "snippet": str(item.get("text", ""))[: self._SNIPPET_CHARS],
                }
                for idx, item in enumerate(hits)
            ]
        corpus = self.root / "data" / "processed" / "retrieval_corpus.jsonl"
        return self._scan_corpus(corpus, query, limit=self._MAX_SEARCH_RESULTS)

    @staticmethod
    def _scan_corpus(corpus: Path, query: str, limit: int = 5) -> list[dict[str, Any]]:
        """Return up to ``limit`` non-blank corpus lines containing any query token.

        Matching is a case-insensitive substring test per whitespace-separated
        token. A blank query matches every non-blank line. Returns ``[]`` when
        the corpus file does not exist. ``patient_id`` uses the 0-based line
        number, counting blank lines.
        """
        results: list[dict[str, Any]] = []
        if not corpus.exists():
            return results
        # ``str.split()`` never yields empty tokens, so no extra filter needed.
        query_tokens = set(query.lower().split())
        with corpus.open("r", encoding="utf-8") as handle:
            for idx, line in enumerate(handle):
                if len(results) >= limit:
                    break
                text = line.strip()
                if not text:
                    continue
                hay = text.lower()
                if query_tokens and not any(token in hay for token in query_tokens):
                    continue
                results.append(
                    {
                        "patient_id": f"retrieval_corpus_{idx}",
                        "query": query,
                        "source_path": str(corpus),
                        "snippet": text[:280],
                    }
                )
        return results

    def evidence_query(self, query: str, top_k: int = 5) -> list[dict[str, str]]:
        """Return the top ``top_k`` evidence entries from ``retrieve_evidence``."""
        return retrieve_evidence(query=query, top_k=top_k)

    @staticmethod
    def _candidate_obj(payload: dict) -> Any:
        """Validate a raw candidate dict into a ``CandidateAction`` model.

        Declared ``@staticmethod`` because it is called as
        ``self._candidate_obj(c)``. Without the decorator, ``self`` would be
        bound to ``payload`` and every call would raise ``TypeError``.
        """
        # Imported lazily; presumably to avoid an import cycle — confirm.
        from app.common.types import CandidateAction

        return CandidateAction.model_validate(payload)