| """FastAPI server exposing the OpenSleuth environment over HTTP.""" |
|
|
| from __future__ import annotations |
|
|
| import logging |
| from typing import Optional |
|
|
| from fastapi import FastAPI, HTTPException, Query |
|
|
| from opensleuth_env import ( |
| BLACK_BOX_FUNCTIONS, |
| OpenSleuthEnv, |
| ProbeAction, |
| ResetRequest, |
| StepRequest, |
| StepResponse, |
| SubmitAction, |
| TaskCatalog, |
| ) |
|
|
# Configure root logging at import time so every module logs with the same format.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
log = logging.getLogger("opensleuth.server")


# The ASGI application and the single environment instance shared by all routes below.
app = FastAPI(title="OpenSleuth Env", version="0.4.0")
env = OpenSleuthEnv()
|
|
|
|
@app.get("/health")
def health():
    """Return a liveness payload with the tracked-episode count and hub status."""
    episode_count = len(env._states)
    hub_info = env.catalog.hub_status()
    return {"status": "ok", "episodes_tracked": episode_count, "hub": hub_info}
|
|
|
|
@app.get("/functions")
def list_functions(
    difficulty: Optional[str] = Query(
        None,
        description="Optional filter: easy / medium / hard. Used by the trainer for curriculum scheduling.",
    ),
):
    """List the entries of ``BLACK_BOX_FUNCTIONS`` as plain dictionaries.

    Args:
        difficulty: When given, only entries whose ``difficulty`` attribute
            equals this value are returned. Entries without that attribute
            are treated as having difficulty ``None``.

    Returns:
        ``{"functions": [...]}`` where each item carries the name, signature,
        description, difficulty, edge-case count and the fixed source tag
        ``"builtin"``.
    """

    def _describe(spec):
        # A missing or empty ``edge_cases`` attribute counts as zero edge cases.
        edge_cases = getattr(spec, "edge_cases", []) or []
        return {
            "name": spec.name,
            "signature": spec.signature,
            "description": spec.description,
            "difficulty": getattr(spec, "difficulty", None),
            "edge_case_count": len(edge_cases),
            "source": "builtin",
        }

    described = [
        _describe(spec)
        for spec in BLACK_BOX_FUNCTIONS.values()
        if difficulty is None or getattr(spec, "difficulty", None) == difficulty
    ]
    return {"functions": described}
|
|
|
|
@app.get("/tasks")
def list_tasks(
    source: str = Query(
        "all",
        description="Filter by source: 'builtin', 'hub', or 'all' (default).",
    ),
    difficulty: Optional[str] = Query(None, description="Optional curriculum filter."),
):
    """List tasks from the environment catalog, filtered by source and difficulty.

    Args:
        source: One of ``builtin``, ``hub`` or ``all``; compared case-insensitively.
        difficulty: When given, keep only tasks whose ``"difficulty"`` entry
            equals this value.

    Returns:
        A dict with the selected ``tasks``, their ``count`` and the catalog's
        ``hub`` status.

    Raises:
        HTTPException: 400 when ``source`` is not one of the accepted values.
    """
    # Map each accepted source value to the catalog method that serves it.
    lister_names = {"builtin": "list_builtin", "hub": "list_hub", "all": "list_all"}
    lister_name = lister_names.get(source.lower())
    if lister_name is None:
        raise HTTPException(
            status_code=400, detail="source must be one of: builtin, hub, all"
        )

    selected = getattr(env.catalog, lister_name)()
    if difficulty is not None:
        selected = [task for task in selected if task.get("difficulty") == difficulty]

    return {
        "tasks": selected,
        "count": len(selected),
        "hub": env.catalog.hub_status(),
    }
|
|
|
|
@app.post("/reset")
def reset(req: ResetRequest):
    """Validate a reset request and forward it to ``env.reset``.

    The request must name a target (``target_name``) or supply ``target_code``
    together with ``target_function_name``.

    Returns:
        Whatever ``env.reset`` returns for the requested target.

    Raises:
        HTTPException: 400 when neither target field is set, when
            ``target_code`` comes without ``target_function_name``, or when
            ``env.reset`` raises ``ValueError``.
    """
    has_name = bool(req.target_name)
    has_code = bool(req.target_code)

    if not (has_name or has_code):
        raise HTTPException(
            status_code=400,
            detail="Either 'target_name' or ('target_code' + 'target_function_name') must be set.",
        )
    if has_code and not req.target_function_name:
        raise HTTPException(
            status_code=400,
            detail="'target_function_name' is required when 'target_code' is provided.",
        )

    reset_kwargs = {
        "target_name": req.target_name,
        "seed": req.seed,
        "max_steps": req.max_steps,
        "target_code": req.target_code,
        "target_function_name": req.target_function_name,
        "edge_cases": req.edge_cases,
        "fuzz_spec": req.fuzz_spec,
    }
    try:
        return env.reset(**reset_kwargs)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
|
|
|
|
@app.post("/step", response_model=StepResponse)
def step(req: StepRequest):
    """Apply ``req.action`` to the episode ``req.episode_id`` via ``env.step``.

    Returns:
        The value returned by ``env.step``, serialized as ``StepResponse``.

    Raises:
        HTTPException: 404 when ``env.step`` raises ``KeyError``.
    """
    try:
        return env.step(req.episode_id, req.action)
    except KeyError as e:
        # str(KeyError("msg")) yields the repr-quoted "'msg'"; unwrap the single
        # argument so the client sees the message without stray quotes.
        detail = str(e.args[0]) if len(e.args) == 1 else str(e)
        raise HTTPException(status_code=404, detail=detail) from e
|
|
|
|
@app.get("/state/{episode_id}")
def get_state(episode_id: str):
    """Return the state ``env.get_state`` holds for ``episode_id``.

    Raises:
        HTTPException: 404 when ``env.get_state`` returns a falsy value.
    """
    found = env.get_state(episode_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail=f"Unknown episode_id {episode_id!r}")
|
|
|
|
@app.post("/probe_once")
def probe_once(target_name: str, input_repr: str):
    """Reset the environment for ``target_name`` and run a single probe.

    Calls ``env.reset`` for the target, then steps the resulting episode once
    with a ``ProbeAction`` built from ``input_repr``.

    Returns:
        The value returned by ``env.step`` for that single probe.

    Raises:
        HTTPException: 400 when ``env.reset`` raises ``ValueError``, the same
            mapping the ``/reset`` handler uses, so a bad target is reported
            as a client error instead of an unhandled 500.
    """
    try:
        obs = env.reset(target_name=target_name)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    return env.step(obs.episode_id, ProbeAction(input_repr=input_repr))
|
|