File size: 4,431 Bytes
d3cd20c ee14542 d3cd20c 536dda7 d3cd20c 536dda7 d3cd20c 77e65fb d3cd20c 77e65fb ee14542 d3cd20c 77e65fb d3cd20c 536dda7 77e65fb 536dda7 d3cd20c 536dda7 77e65fb d3cd20c 536dda7 d3cd20c 77e65fb d3cd20c 77e65fb d3cd20c 77e65fb d3cd20c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 | """FastAPI server exposing the OpenSleuth environment over HTTP."""
from __future__ import annotations
import logging
from typing import Optional
from fastapi import FastAPI, HTTPException, Query
from opensleuth_env import (
BLACK_BOX_FUNCTIONS,
OpenSleuthEnv,
ProbeAction,
ResetRequest,
StepRequest,
StepResponse,
SubmitAction,
TaskCatalog,
)
# Configure the root logger at import time (INFO level, timestamped
# "<time> <level> <logger>: <message>" format) before any handler logs.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
# Named logger for this server module.
# NOTE(review): `log` is not referenced by any endpoint visible in this file.
log = logging.getLogger("opensleuth.server")
# The ASGI application; every endpoint below is registered on it via decorators.
app = FastAPI(title="OpenSleuth Env", version="0.4.0")
# Single module-level environment instance shared by every endpoint, so all
# episodes live in one in-process episode table.
env = OpenSleuthEnv()
@app.get("/health")
def health():
    """Report liveness, the number of tracked episodes, and Hub catalog status."""
    # There is no public accessor for the episode table, so read the private
    # mapping directly.
    tracked = len(env._states)  # noqa: SLF001
    hub_info = env.catalog.hub_status()
    payload = {"status": "ok"}
    payload["episodes_tracked"] = tracked
    payload["hub"] = hub_info
    return payload
@app.get("/functions")
def list_functions(
    difficulty: Optional[str] = Query(
        None,
        description="Optional filter: easy / medium / hard. Used by the trainer for curriculum scheduling.",
    ),
):
    """List the builtin black-box functions, optionally filtered by difficulty.

    Backwards compatibility: this endpoint intentionally keeps the v0.3
    response shape (only entries from ``BLACK_BOX_FUNCTIONS``, with the
    original field set) because an in-flight trainer depends on it. The
    ``source`` field is an additive extension. Open-ended and Hub tasks are
    served from ``/tasks`` instead.
    """
    wanted = (
        spec
        for spec in BLACK_BOX_FUNCTIONS.values()
        if difficulty is None or getattr(spec, "difficulty", None) == difficulty
    )
    entries = [
        {
            "name": spec.name,
            "signature": spec.signature,
            "description": spec.description,
            "difficulty": getattr(spec, "difficulty", None),
            # edge_cases may be absent or None; both count as zero.
            "edge_case_count": len(getattr(spec, "edge_cases", []) or []),
            "source": "builtin",
        }
        for spec in wanted
    ]
    return {"functions": entries}
@app.get("/tasks")
def list_tasks(
    source: str = Query(
        "all",
        description="Filter by source: 'builtin', 'hub', or 'all' (default).",
    ),
    difficulty: Optional[str] = Query(None, description="Optional curriculum filter."),
):
    """List catalog tasks by source (case-insensitive), optionally filtered by difficulty.

    Raises:
        HTTPException: 400 if ``source`` is not one of builtin / hub / all.
    """
    # Lambdas keep the catalog lookup lazy: nothing is fetched for a bad source.
    fetchers = {
        "builtin": lambda: env.catalog.list_builtin(),
        "hub": lambda: env.catalog.list_hub(),
        "all": lambda: env.catalog.list_all(),
    }
    fetch = fetchers.get(source.lower())
    if fetch is None:
        raise HTTPException(
            status_code=400, detail="source must be one of: builtin, hub, all"
        )
    found = fetch()
    if difficulty is not None:
        found = [task for task in found if task.get("difficulty") == difficulty]
    return {
        "tasks": found,
        "count": len(found),
        "hub": env.catalog.hub_status(),
    }
@app.post("/reset")
def reset(req: ResetRequest):
    """Start a new episode and return its initial observation.

    Two request shapes are accepted. Legacy callers set only ``target_name``.
    Open-ended callers set ``target_code`` together with
    ``target_function_name``.

    Raises:
        HTTPException: 400 if neither shape is populated, if ``target_code``
            lacks ``target_function_name``, or if ``env.reset`` raises
            ValueError.
    """
    has_name = bool(req.target_name)
    has_code = bool(req.target_code)
    if not (has_name or has_code):
        raise HTTPException(
            status_code=400,
            detail="Either 'target_name' or ('target_code' + 'target_function_name') must be set.",
        )
    if has_code and not req.target_function_name:
        raise HTTPException(
            status_code=400,
            detail="'target_function_name' is required when 'target_code' is provided.",
        )
    reset_kwargs = dict(
        target_name=req.target_name,
        seed=req.seed,
        max_steps=req.max_steps,
        target_code=req.target_code,
        target_function_name=req.target_function_name,
        edge_cases=req.edge_cases,
        fuzz_spec=req.fuzz_spec,
    )
    try:
        return env.reset(**reset_kwargs)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
@app.post("/step", response_model=StepResponse)
def step(req: StepRequest):
    """Apply one action to an existing episode.

    Raises:
        HTTPException: 404 when ``env.step`` raises KeyError (unknown episode).
    """
    episode_id, action = req.episode_id, req.action
    try:
        outcome = env.step(episode_id, action)
    except KeyError as missing:
        raise HTTPException(status_code=404, detail=str(missing)) from missing
    return outcome
@app.get("/state/{episode_id}")
def get_state(episode_id: str):
    """Return the stored state for ``episode_id``.

    Raises:
        HTTPException: 404 when ``env.get_state`` returns a falsy value.
    """
    snapshot = env.get_state(episode_id)
    if snapshot:
        return snapshot
    raise HTTPException(status_code=404, detail=f"Unknown episode_id {episode_id!r}")
@app.post("/probe_once")
def probe_once(target_name: str, input_repr: str):
    """Run a single probe against ``target_name`` in a fresh episode.

    Convenience endpoint. It resets a new episode for ``target_name`` and
    immediately applies one ``ProbeAction`` built from ``input_repr``,
    returning the result of ``env.step``.

    Raises:
        HTTPException: 400 if ``env.reset`` raises ValueError (for example an
            unknown target). This matches the ``/reset`` endpoint.
        HTTPException: 404 if ``env.step`` raises KeyError. This matches the
            ``/step`` endpoint.
    """
    try:
        obs = env.reset(target_name=target_name)
    except ValueError as e:
        # Without this, a bad target name surfaces as an unhandled 500.
        raise HTTPException(status_code=400, detail=str(e)) from e
    # NOTE(review): nothing in this handler removes the episode it just
    # created; whether env evicts it later is not visible here. Confirm this
    # does not grow the tracked-episode count without bound.
    action = ProbeAction(input_repr=input_repr)
    try:
        return env.step(obs.episode_id, action)
    except KeyError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
|