anugrah55's picture
Level 2 open-ended env: auto-fuzzer + TaskCatalog + Hub-driven catalog + extended /reset
77e65fb verified
raw
history blame
4.43 kB
"""FastAPI server exposing the OpenSleuth environment over HTTP."""
from __future__ import annotations
import logging
from typing import Optional
from fastapi import FastAPI, HTTPException, Query
from opensleuth_env import (
BLACK_BOX_FUNCTIONS,
OpenSleuthEnv,
ProbeAction,
ResetRequest,
StepRequest,
StepResponse,
SubmitAction,
TaskCatalog,
)
# Process-wide logging configuration for the server.
_LOG_FORMAT = "%(asctime)s %(levelname)s %(name)s: %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
log = logging.getLogger("opensleuth.server")

# The HTTP application; every route below is registered on it.
app = FastAPI(version="0.4.0", title="OpenSleuth Env")

# One environment instance shared by all requests handled by this process.
env = OpenSleuthEnv()
@app.get("/health")
def health():
    """Report liveness along with a small snapshot of environment state.

    Returns:
        A dict with a constant ``"ok"`` status, the number of entries in
        ``env._states`` under ``"episodes_tracked"``, and the value of
        ``env.catalog.hub_status()`` under ``"hub"``.
    """
    tracked_episodes = len(env._states)  # noqa: SLF001
    hub_info = env.catalog.hub_status()
    return {"status": "ok", "episodes_tracked": tracked_episodes, "hub": hub_info}
@app.get("/functions")
def list_functions(
    difficulty: Optional[str] = Query(
        None,
        description="Optional filter: easy / medium / hard. Used by the trainer for curriculum scheduling.",
    ),
):
    """List the builtin black-box functions, optionally filtered by difficulty.

    Args:
        difficulty: When given, only specs whose ``difficulty`` attribute
            equals this value are returned (exact, case-sensitive match).

    Returns:
        ``{"functions": [...]}``, where each entry carries ``name``,
        ``signature``, ``description``, ``difficulty``, ``edge_case_count``
        and a constant ``source`` of ``"builtin"``.
    """
    # NOTE -- backwards compatibility: the response deliberately keeps the
    # exact v0.3 shape (only the builtin functions, original field set)
    # because the in-flight trainer queries it. The "source" field is
    # additive; open-ended / Hub tasks are served by /tasks instead.

    def _summarize(spec):
        # Missing optional attributes degrade to None / zero edge cases.
        return {
            "name": spec.name,
            "signature": spec.signature,
            "description": spec.description,
            "difficulty": getattr(spec, "difficulty", None),
            "edge_case_count": len(getattr(spec, "edge_cases", []) or []),
            "source": "builtin",
        }

    return {
        "functions": [
            _summarize(spec)
            for spec in BLACK_BOX_FUNCTIONS.values()
            if difficulty is None or getattr(spec, "difficulty", None) == difficulty
        ]
    }
@app.get("/tasks")
def list_tasks(
    source: str = Query(
        "all",
        description="Filter by source: 'builtin', 'hub', or 'all' (default).",
    ),
    difficulty: Optional[str] = Query(None, description="Optional curriculum filter."),
):
    """List catalog tasks, filtered by source and optionally by difficulty.

    Args:
        source: One of ``builtin``, ``hub`` or ``all`` (case-insensitive).
        difficulty: When given, keep only tasks whose ``"difficulty"`` key
            equals this value.

    Returns:
        A dict with the matching ``tasks``, their ``count``, and the value of
        ``env.catalog.hub_status()`` under ``"hub"``.

    Raises:
        HTTPException: 400 when ``source`` is not one of the accepted values.
    """
    # Map each accepted source to the catalog method that lists it. The
    # method is only looked up after the source has been validated.
    lister_by_source = {
        "builtin": "list_builtin",
        "hub": "list_hub",
        "all": "list_all",
    }
    lister_name = lister_by_source.get(source.lower())
    if lister_name is None:
        raise HTTPException(
            status_code=400, detail="source must be one of: builtin, hub, all"
        )
    tasks = getattr(env.catalog, lister_name)()

    if difficulty is not None:
        tasks = [task for task in tasks if task.get("difficulty") == difficulty]

    return {"tasks": tasks, "count": len(tasks), "hub": env.catalog.hub_status()}
@app.post("/reset")
def reset(req: ResetRequest):
    """Start a new episode and return its initial observation.

    Two request shapes are accepted: the legacy shape, which sets only
    ``target_name``, and the open-ended shape, which sets ``target_code``
    together with ``target_function_name``. All other request fields are
    forwarded unchanged to ``env.reset``.

    Raises:
        HTTPException: 400 when neither shape is populated, when
            ``target_code`` lacks ``target_function_name``, or when
            ``env.reset`` raises ``ValueError``.
    """
    has_code = bool(req.target_code)
    if not (req.target_name or has_code):
        raise HTTPException(
            status_code=400,
            detail="Either 'target_name' or ('target_code' + 'target_function_name') must be set.",
        )
    if has_code and not req.target_function_name:
        raise HTTPException(
            status_code=400,
            detail="'target_function_name' is required when 'target_code' is provided.",
        )

    reset_kwargs = dict(
        target_name=req.target_name,
        seed=req.seed,
        max_steps=req.max_steps,
        target_code=req.target_code,
        target_function_name=req.target_function_name,
        edge_cases=req.edge_cases,
        fuzz_spec=req.fuzz_spec,
    )
    try:
        return env.reset(**reset_kwargs)
    except ValueError as err:
        # env-level validation failures are client errors, not server faults.
        raise HTTPException(status_code=400, detail=str(err)) from err
@app.post("/step", response_model=StepResponse)
def step(req: StepRequest):
    """Apply one action to an existing episode.

    Raises:
        HTTPException: 404 when ``env.step`` raises ``KeyError``, which
            includes the case of an unknown ``episode_id``.
    """
    try:
        return env.step(req.episode_id, req.action)
    except KeyError as missing:
        raise HTTPException(status_code=404, detail=str(missing)) from missing
@app.get("/state/{episode_id}")
def get_state(episode_id: str):
    """Return the stored state for ``episode_id``.

    Raises:
        HTTPException: 404 when ``env.get_state`` returns a falsy value.
    """
    snapshot = env.get_state(episode_id)
    if snapshot:
        return snapshot
    raise HTTPException(status_code=404, detail=f"Unknown episode_id {episode_id!r}")
@app.post("/probe_once", response_model=StepResponse)
def probe_once(target_name: str, input_repr: str):
    """Run a single probe against ``target_name`` in a fresh episode.

    This is a convenience wrapper. It resets a new episode for
    ``target_name`` and immediately steps it with one ``ProbeAction``
    carrying ``input_repr``. Both arguments arrive as query parameters.

    Returns:
        The ``env.step`` response, declared as ``StepResponse`` just as
        ``/step`` declares it.

    Raises:
        HTTPException: 400 when ``env.reset`` raises ``ValueError``, which
            mirrors ``/reset``; 404 when ``env.step`` raises ``KeyError``,
            which mirrors ``/step``.
    """
    try:
        obs = env.reset(target_name=target_name)
    except ValueError as e:
        # Without this handler a bad target surfaced as an HTTP 500 here,
        # even though /reset reports the same failure as a 400.
        raise HTTPException(status_code=400, detail=str(e)) from e

    # NOTE(review): the episode created above is never discarded here, so
    # each call presumably leaves an entry in env._states (counted by
    # /health). Confirm whether OpenSleuthEnv offers a way to drop it.
    try:
        return env.step(obs.episode_id, ProbeAction(input_repr=input_repr))
    except KeyError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e