| """FastAPI server exposing the OpenSleuth environment over HTTP.""" |
|
|
| from __future__ import annotations |
|
|
| import logging |
|
|
| from fastapi import FastAPI, HTTPException |
|
|
| from opensleuth_env import ( |
| BLACK_BOX_FUNCTIONS, |
| OpenSleuthEnv, |
| ProbeAction, |
| ResetRequest, |
| StepRequest, |
| StepResponse, |
| SubmitAction, |
| ) |
|
|
# Root logging setup for the server process; handlers below log under
# the "opensleuth.server" logger name.
_LOG_FORMAT = "%(asctime)s %(levelname)s %(name)s: %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
log = logging.getLogger("opensleuth.server")


# The ASGI application, plus one module-level environment instance that every
# request handler in this module reads from and mutates.
app = FastAPI(version="0.2.0", title="OpenSleuth Env")
env = OpenSleuthEnv()
|
|
|
|
@app.get("/health")
def health():
    """Liveness probe.

    Returns a fixed ``"ok"`` status together with the number of entries in
    the shared environment's episode store.
    """
    # NOTE(review): this reaches into OpenSleuthEnv's private ``_states``;
    # prefer a public accessor if the env exposes one.
    tracked = len(env._states)
    return dict(status="ok", episodes_tracked=tracked)
|
|
|
|
@app.get("/functions")
def list_functions():
    """Describe every black-box function that can be used as an episode target.

    Returns:
        ``{"functions": [...]}``. Each entry has the keys ``name``,
        ``signature`` and ``description``, in that order, copied from the
        corresponding spec in ``BLACK_BOX_FUNCTIONS``.
    """
    exposed = ("name", "signature", "description")
    catalog = [
        {attr: getattr(spec, attr) for attr in exposed}
        for spec in BLACK_BOX_FUNCTIONS.values()
    ]
    return {"functions": catalog}
|
|
|
|
@app.post("/reset")
def reset(req: ResetRequest):
    """Start a new episode and return its initial observation.

    The request's ``target_name``, ``seed`` and ``max_steps`` are forwarded
    unchanged to ``env.reset``.

    Raises:
        HTTPException: 400 when ``env.reset`` raises ``ValueError``. The
            error text becomes the response detail.
    """
    try:
        observation = env.reset(
            target_name=req.target_name,
            seed=req.seed,
            max_steps=req.max_steps,
        )
    except ValueError as err:
        raise HTTPException(status_code=400, detail=str(err)) from err
    else:
        return observation
|
|
|
|
@app.post("/step", response_model=StepResponse)
def step(req: StepRequest):
    """Apply one action to an existing episode and return the step result.

    Args:
        req: Carries the target ``episode_id`` and the ``action`` to apply.

    Raises:
        HTTPException: 404 when ``env.step`` raises ``KeyError``. This is
            presumably an unknown ``episode_id``; confirm against
            OpenSleuthEnv.
    """
    try:
        return env.step(req.episode_id, req.action)
    except KeyError as e:
        # str(KeyError(msg)) is repr(msg), which would wrap the detail in an
        # extra pair of quotes. Use the raw message when there is exactly one
        # argument, and keep it a str so the JSON error response can
        # serialize it.
        detail = str(e.args[0]) if len(e.args) == 1 else str(e)
        raise HTTPException(status_code=404, detail=detail) from e
|
|
|
|
@app.get("/state/{episode_id}")
def get_state(episode_id: str):
    """Return the environment's stored state for ``episode_id``.

    Raises:
        HTTPException: 404 when ``env.get_state`` returns a falsy value
            (e.g. ``None``) for the id.
    """
    state = env.get_state(episode_id)
    # NOTE(review): any falsy state is treated as "missing". If a valid state
    # could ever be falsy, ``state is None`` may be the intended test.
    if state:
        return state
    raise HTTPException(status_code=404, detail=f"Unknown episode_id {episode_id!r}")
|
|
|
|
| |
| |
@app.post("/probe_once")
def probe_once(target_name: str, input_repr: str):
    """Run a single probe against ``target_name`` in a fresh episode.

    Calls ``env.reset`` with only ``target_name``, so its defaults for seed
    and step budget apply. It then applies one ``ProbeAction`` carrying
    ``input_repr`` and returns whatever ``env.step`` returns.

    Args:
        target_name: Name of the black-box function to probe.
        input_repr: Probe input, passed through to ``ProbeAction``.

    Raises:
        HTTPException: 400 when ``env.reset`` raises ``ValueError``, matching
            the ``/reset`` endpoint. Previously this surfaced as a 500.
    """
    try:
        obs = env.reset(target_name=target_name)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    # NOTE(review): nothing here discards the episode that env.reset just
    # created. If the env retains every episode (the health check's
    # len(env._states) suggests it does), each call adds another tracked
    # episode. Consider a cleanup/eviction hook.
    return env.step(obs.episode_id, ProbeAction(input_repr=input_repr))
|
|