"""FastAPI server exposing the OpenSleuth environment over HTTP.

Two HTTP surfaces are served from this app:
* The legacy OpenSleuth contract (``/health``, ``/functions``, ``/tasks``,
  ``/tasks/{name}/sample_inputs``, ``/reset``, ``/step``,
  ``/state/{episode_id}``, ``/probe_once``) used by the in-flight trainer and
  eval harness.
* The OpenEnv-conformant sub-app mounted at ``/openenv/*`` (added in v0.5.0
  for hackathon conformance) -- exposes ``/openenv/reset``, ``/openenv/step``,
  ``/openenv/state``, ``/openenv/health``, ``/openenv/metadata``,
  ``/openenv/schema``, and the canonical ``/openenv/ws`` WebSocket. See
  :mod:`opensleuth_env.openenv_adapter` and
  https://github.com/meta-pytorch/OpenEnv (v0.2.3).
"""
from __future__ import annotations
import logging
import random
from typing import Optional
from fastapi import FastAPI, HTTPException, Query
from opensleuth_env import (
BLACK_BOX_FUNCTIONS,
OpenSleuthEnv,
ProbeAction,
ResetRequest,
StepRequest,
StepResponse,
SubmitAction,
TaskCatalog,
)
from opensleuth_env.task_catalog import TaskResolutionError
# Process-wide logging: INFO level, "timestamp level logger: message" lines.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    level=logging.INFO,
)
log = logging.getLogger("opensleuth.server")

# The FastAPI application plus the single environment instance that every
# legacy route in this module reads from and writes to.
app = FastAPI(version="0.5.0", title="OpenSleuth Env")
env = OpenSleuthEnv()
# ---------------------------------------------------------------------------
# OpenEnv conformance: mount an upstream-spec sub-app at /openenv.
# This is kept additive so the existing trainer (which talks to the bare
# /reset and /step routes defined further down in this module) is completely
# unaffected. Registration is fail-open: any error is logged and swallowed so
# the legacy routes keep working even when openenv-core is missing or broken.
# ---------------------------------------------------------------------------
try:
    from openenv.core.env_server.http_server import HTTPEnvServer
    from opensleuth_env.openenv_adapter import (
        OPENENV_AVAILABLE,
        OpenSleuthAction,
        OpenSleuthEnvironment,
        OpenSleuthObservation,
    )

    if OPENENV_AVAILABLE:
        openenv_app = FastAPI(
            title="OpenSleuth (OpenEnv-conformant)",
            # Reuse the parent app's version so the two surfaces cannot drift.
            version=app.version,
            description=(
                "OpenEnv 0.2.x conformant surface for the OpenSleuth environment.\n\n"
                "See https://github.com/meta-pytorch/OpenEnv -- this sub-app implements"
                " the canonical reset/step/state/health/metadata/schema HTTP routes plus"
                " the /ws WebSocket session protocol."
            ),
        )
        # The environment *class* (not an instance) is handed over; presumably
        # HTTPEnvServer instantiates one per session, capped by
        # max_concurrent_envs -- confirm against openenv-core.
        _openenv_server = HTTPEnvServer(
            env=OpenSleuthEnvironment,
            action_cls=OpenSleuthAction,
            observation_cls=OpenSleuthObservation,
            max_concurrent_envs=8,
        )
        _openenv_server.register_routes(openenv_app)
        app.mount("/openenv", openenv_app)

        # Report the installed openenv-core distribution version (best-effort;
        # a missing distribution record must not undo a successful mount).
        import importlib.metadata

        try:
            _openenv_core_version = importlib.metadata.version("openenv-core")
        except importlib.metadata.PackageNotFoundError:
            _openenv_core_version = "unknown"
        log.info(
            "Mounted OpenEnv-conformant sub-app at /openenv (openenv-core %s)",
            _openenv_core_version,
        )
    else:  # pragma: no cover
        log.warning("openenv-core not importable; /openenv/* will be unavailable.")
except Exception as e:  # pragma: no cover - fail open so legacy routes keep working
    log.warning("Could not register OpenEnv sub-app: %s: %s", type(e).__name__, e)
@app.get("/health")
def health():
    """Report liveness, the number of tracked episodes and the Hub status."""
    tracked = len(env._states)  # noqa: SLF001 - read-only peek at private state
    hub_status = env.catalog.hub_status()
    return {"status": "ok", "episodes_tracked": tracked, "hub": hub_status}
@app.get("/functions")
def list_functions(
    difficulty: Optional[str] = Query(
        None,
        description="Optional filter: easy / medium / hard. Used by the trainer for curriculum scheduling.",
    ),
):
    """List the builtin black-box functions, optionally filtered by difficulty.

    NOTE -- backwards compatibility: this endpoint deliberately keeps the
    exact v0.3 shape (only the builtin functions, with the original field
    set) because the in-flight trainer queries it. The ``source`` field is
    additive. Open-ended / Hub tasks are exposed via ``/tasks`` instead.
    """

    def _row(spec):
        # One summary record per builtin; missing optional attributes
        # degrade to None / 0 rather than raising.
        return {
            "name": spec.name,
            "signature": spec.signature,
            "description": spec.description,
            "difficulty": getattr(spec, "difficulty", None),
            "edge_case_count": len(getattr(spec, "edge_cases", []) or []),
            "source": "builtin",
        }

    rows = [
        _row(spec)
        for spec in BLACK_BOX_FUNCTIONS.values()
        if difficulty is None or getattr(spec, "difficulty", None) == difficulty
    ]
    return {"functions": rows}
@app.get("/tasks")
def list_tasks(
    source: str = Query(
        "all",
        description="Filter by source: 'builtin', 'hub', or 'all' (default).",
    ),
    difficulty: Optional[str] = Query(None, description="Optional curriculum filter."),
):
    """List catalog tasks by source, optionally filtered by difficulty.

    ``source`` is matched case-insensitively against builtin / hub / all;
    anything else is rejected with HTTP 400. Task entries are filtered via
    ``.get("difficulty")``, so they are expected to be mappings.
    """
    listers = {
        "builtin": env.catalog.list_builtin,
        "hub": env.catalog.list_hub,
        "all": env.catalog.list_all,
    }
    lister = listers.get(source.lower())
    if lister is None:
        raise HTTPException(
            status_code=400, detail="source must be one of: builtin, hub, all"
        )

    tasks = lister()
    if difficulty is not None:
        tasks = [task for task in tasks if task.get("difficulty") == difficulty]

    return {"tasks": tasks, "count": len(tasks), "hub": env.catalog.hub_status()}
@app.post("/reset")
def reset(req: ResetRequest):
    """Start a new episode and return its initial observation.

    Two request shapes are accepted: legacy callers send only
    ``target_name``; open-ended callers send ``target_code`` together with
    ``target_function_name``. Missing targets, and ``target_code`` without
    ``target_function_name``, are rejected with HTTP 400, as is any
    ``ValueError`` raised by ``env.reset``.
    """
    has_named_target = bool(req.target_name)
    has_inline_code = bool(req.target_code)

    if not (has_named_target or has_inline_code):
        raise HTTPException(
            status_code=400,
            detail="Either 'target_name' or ('target_code' + 'target_function_name') must be set.",
        )
    if has_inline_code and not req.target_function_name:
        raise HTTPException(
            status_code=400,
            detail="'target_function_name' is required when 'target_code' is provided.",
        )

    reset_kwargs = dict(
        target_name=req.target_name,
        seed=req.seed,
        max_steps=req.max_steps,
        target_code=req.target_code,
        target_function_name=req.target_function_name,
        edge_cases=req.edge_cases,
        fuzz_spec=req.fuzz_spec,
    )
    try:
        return env.reset(**reset_kwargs)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
@app.post("/step", response_model=StepResponse)
def step(req: StepRequest):
    """Apply one action to an existing episode and return the step result.

    A ``KeyError`` from the environment (presumably an unknown episode id)
    is surfaced as HTTP 404.
    """
    episode_id, action = req.episode_id, req.action
    try:
        result = env.step(episode_id, action)
    except KeyError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    return result
@app.get("/state/{episode_id}")
def get_state(episode_id: str):
    """Return the stored state of ``episode_id``; 404 if the env returns nothing."""
    state = env.get_state(episode_id)
    if state:
        return state
    raise HTTPException(status_code=404, detail=f"Unknown episode_id {episode_id!r}")
@app.post("/probe_once")
def probe_once(target_name: str, input_repr: str):
    """Start a fresh episode on ``target_name`` and send it a single probe.

    Returns whatever ``env.step`` returns for that one ``ProbeAction``.

    Raises:
        HTTPException(400): ``env.reset`` raised ``ValueError`` (e.g. a bad
            target), mapped the same way ``/reset`` maps it instead of
            escaping as a 500.

    NOTE(review): every call creates a new episode and nothing here discards
    it, so episodes presumably accumulate in the env (``/health`` reports
    ``episodes_tracked``) -- confirm whether the env evicts them.
    """
    try:
        obs = env.reset(target_name=target_name)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    return env.step(obs.episode_id, ProbeAction(input_repr=input_repr))
@app.get("/tasks/{name}/sample_inputs")
def sample_inputs(
    name: str,
    n: int = Query(8, ge=1, le=64, description="How many inputs to draw."),
    seed: int = Query(0, description="Deterministic seed for the fuzzer."),
):
    """Return ``n`` Python-literal ``repr`` strings drawn from the task's
    auto-fuzzer (or hand-written fuzzer for builtins).

    Used by the trainer to build in-context probe pools without having to
    duplicate the auto-fuzzer logic on the trainer side. Provided the fuzzer
    yields only literal-representable values, each returned string is
    ``ast.literal_eval``-safe and can be POSTed straight back to ``/step`` as
    a ``ProbeAction.input_repr``.

    Raises:
        HTTPException(404): ``name`` does not resolve to a known task.
        HTTPException(500): the task's fuzzer raised, either when called or
            while its output was being consumed.
    """
    try:
        spec = env.catalog.resolve(target_name=name)
    except TaskResolutionError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e

    rng = random.Random(seed)
    try:
        # Materialise inside the try: if the fuzzer is a generator, its
        # exceptions only surface during iteration and would otherwise escape
        # this handler without the descriptive error detail.
        raw_inputs = list(spec.fuzzer(rng, n))
    except Exception as e:  # noqa: BLE001
        raise HTTPException(
            status_code=500,
            detail=f"fuzzer for {name!r} failed: {type(e).__name__}: {e}",
        ) from e

    return {
        "name": name,
        "n": n,
        "seed": seed,
        "unpack_args": bool(getattr(spec, "unpack_args", False)),
        "inputs": [repr(x) for x in raw_inputs],
    }
|