# app/api/service.py
"""API service layer."""
from __future__ import annotations
from pathlib import Path
from typing import Any
from app.agents.orchestrator import Orchestrator
from app.env.catalog import apply_task_preset, env_catalog
from app.env.env_core import PolyGuardEnv
from app.evaluation.benchmark_report import build_benchmark_report
from app.evaluation.dosing_eval import dosing_eval
from app.knowledge.evidence_retriever import retrieve_evidence
from app.models.retrieval.retriever import retrieve
from app.models.baselines import (
choose_beam_search,
choose_contextual_bandit,
choose_contextual_bandit_topk,
choose_greedy,
choose_no_change,
choose_rules_only,
)
from app.training import train_dosing_grpo, train_planner_grpo, train_supervisor_grpo
class APIService:
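    """Service facade over a single PolyGuardEnv instance and its Orchestrator."""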
def __init__(self) -> None:
self.env = PolyGuardEnv()
self.orchestrator = Orchestrator(self.env)
self.training_metrics: dict[str, Any] = {}
        self.root = Path(__file__).resolve().parents[2]  # repository root (the directory containing app/)
def reset(self, **kwargs: Any) -> dict[str, Any]:
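        """Reset the environment, expanding any task preset in ``kwargs``, and return the initial observation as JSON."""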
kwargs = apply_task_preset(dict(kwargs))
obs = self.env.reset(**kwargs)
return obs.model_dump(mode="json")
def step(self, action: dict[str, Any]) -> dict[str, Any]:
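        """Step the environment and map its (obs, reward, done, info) tuple onto a
        Gymnasium-style payload, inferring ``truncated`` from timeout/budget
        termination reasons.
        """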
obs, reward, done, info = self.env.step(action)
reason = str(info.get("termination_reason", "")) if isinstance(info, dict) else ""
truncated = reason in {"wall_clock_timeout", "step_timeout", "step_budget_exhausted"}
return {
"observation": obs.model_dump(mode="json"),
"reward": reward,
"done": done,
"terminated": done,
"truncated": truncated,
"info": info,
}
def catalog(self) -> dict[str, Any]:
return env_catalog()
def step_candidate(self, candidate_id: str, confidence: float, rationale_brief: str) -> dict[str, Any] | None:
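        """Step with the legal action matching ``candidate_id``, attaching the caller's
        confidence and rationale; return ``None`` if the candidate is not currently legal.
        """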
for action in self.env.get_legal_actions():
if action.get("candidate_id") != candidate_id:
continue
payload = dict(action)
payload["confidence"] = confidence
payload["rationale_brief"] = rationale_brief
return self.step(payload)
return None
def orchestrate(self, coordination_mode: str | None = None) -> dict[str, Any]:
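        """Run a single orchestrator step, optionally overriding the coordination mode."""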
return self.orchestrator.run_step(coordination_mode=coordination_mode)
def infer_policy(self) -> dict[str, Any]:
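        """Return the first currently legal action, or an empty dict when none are legal."""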
legal = self.env.get_legal_actions()
return legal[0] if legal else {}
def batch_infer(self, batch_size: int = 4) -> list[dict[str, Any]]:
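        """Return up to ``batch_size`` of the currently legal actions."""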
legal = self.env.get_legal_actions()
return legal[:batch_size]
def run_baselines(self) -> dict[str, Any]:
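        """Run the baseline action choosers over the currently legal candidates,
        resetting the environment first if none pass the legality precheck.
        """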
        candidates = [c for c in self.env.get_candidate_actions() if c.get("legality_precheck")]
        if not candidates:
            self.env.reset()
            candidates = [c for c in self.env.get_candidate_actions() if c.get("legality_precheck")]
        candidate_objs = [self._candidate_obj(c) for c in candidates]
        return {
            "no_change": choose_no_change().model_dump(mode="json"),
            "rules_only": choose_rules_only(candidate_objs).model_dump(mode="json"),
            "greedy": choose_greedy(candidate_objs).model_dump(mode="json"),
            "contextual_bandit": choose_contextual_bandit(candidate_objs).model_dump(mode="json"),
            "contextual_bandit_topk": [
                {
                    "candidate_id": item.candidate_id,
                    "score": item.score,
                    "exploration_bonus": item.exploration_bonus,
                    "algorithm": item.algorithm,
                }
                for item in choose_contextual_bandit_topk(candidate_objs, top_k=3)
            ],
            "beam_search": choose_beam_search(candidate_objs).model_dump(mode="json"),
        }
def run_policy_eval(self) -> dict[str, Any]:
out = build_benchmark_report(Path("outputs/reports/benchmark_report.txt"))
return out
def run_dosing_eval(self) -> dict[str, Any]:
return dosing_eval()
def run_training(self) -> dict[str, Any]:
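        """Run short GRPO training loops for the supervisor, planner, and dosing models,
        caching their metrics on the service.
        """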
out_dir = Path("checkpoints")
out_dir.mkdir(parents=True, exist_ok=True)
self.training_metrics = {
"supervisor": train_supervisor_grpo(episodes=4, checkpoint_dir=out_dir),
"planner": train_planner_grpo(episodes=6, checkpoint_dir=out_dir),
"dosing": train_dosing_grpo(episodes=4, checkpoint_dir=out_dir),
}
return self.training_metrics
def get_metrics(self) -> dict[str, Any]:
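        """Return cached training metrics if present; otherwise load any GRPO report
        JSON files from ``outputs/reports`` and cache those instead. When planner
        metrics exist, they are returned at the top level with the full set nested
        under ``model_metrics``.
        """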
        if self.training_metrics:
            if "planner" in self.training_metrics:
                merged = dict(self.training_metrics["planner"])
                merged["model_metrics"] = self.training_metrics
                return merged
            return self.training_metrics
        import json

        reports_dir = Path("outputs/reports")
        metrics: dict[str, Any] = {}
        for name in ["supervisor_grpo", "planner_grpo", "dosing_grpo"]:
            path = reports_dir / f"{name}.json"
            if path.exists():
                metrics[name] = json.loads(path.read_text(encoding="utf-8"))
        self.training_metrics = metrics
        if "planner_grpo" in metrics:
            merged = dict(metrics["planner_grpo"])
            merged["model_metrics"] = metrics
            return merged
        return metrics
def sample_case(self) -> dict[str, Any]:
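        """Reset the environment and return the resulting case observation as JSON."""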
obs = self.env.reset()
return obs.model_dump(mode="json")
def search_cases(self, query: str) -> list[dict[str, Any]]:
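        """Search cases via the retrieval index, falling back to a keyword scan of the
        processed retrieval corpus when the index returns no hits.
        """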
index_file = self.root / "data" / "retrieval_index" / "index.json"
hits = retrieve(index_file=index_file, query=query, top_k=5)
if hits:
return [
{
"patient_id": Path(item.get("path", f"case_{idx}")).stem,
"query": query,
"source_path": item.get("path", ""),
"snippet": str(item.get("text", ""))[:280],
}
for idx, item in enumerate(hits)
]
fallback: list[dict[str, Any]] = []
corpus = self.root / "data" / "processed" / "retrieval_corpus.jsonl"
if corpus.exists():
query_tokens = {token for token in query.lower().split() if token}
with corpus.open("r", encoding="utf-8") as handle:
for idx, line in enumerate(handle):
if len(fallback) >= 5:
break
text = line.strip()
if not text:
continue
hay = text.lower()
if query_tokens and not any(token in hay for token in query_tokens):
continue
fallback.append(
{
"patient_id": f"retrieval_corpus_{idx}",
"query": query,
"source_path": str(corpus),
"snippet": text[:280],
}
)
return fallback
def evidence_query(self, query: str, top_k: int = 5) -> list[dict[str, str]]:
return retrieve_evidence(query=query, top_k=top_k)
@staticmethod
def _candidate_obj(payload: dict) -> Any:
from app.common.types import CandidateAction
return CandidateAction.model_validate(payload)
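

# Minimal local smoke test: a sketch, not part of the deployed API surface.
# It assumes PolyGuardEnv can reset without extra kwargs and that the data
# files used by the baselines exist on disk; adjust or remove as needed.
if __name__ == "__main__":
    service = APIService()
    first_obs = service.reset()
    print("initial observation keys:", sorted(first_obs))
    action = service.infer_policy()
    if action:
        result = service.step(action)
        print("reward:", result["reward"], "done:", result["done"])
    print("baseline choices:", list(service.run_baselines()))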