# Provenance (from hosting-page scrape residue, converted to a comment so the
# module parses): anugrah55 -- commit 31715b5 (verified) -- "OpenEnv 0.2.3
# conformance: mount /openenv sub-app, add adapter + tests + example client"
"""FastAPI server exposing the OpenSleuth environment over HTTP.
Two HTTP surfaces are served from this app:
* The legacy OpenSleuth contract (``/health``, ``/functions``, ``/tasks``,
``/reset``, ``/step``, ``/state/{episode_id}``, ``/probe_once``) used by the
in-flight trainer and eval harness.
* The OpenEnv-conformant sub-app mounted at ``/openenv/*`` (added in v0.5.0
for hackathon conformance) -- exposes ``/openenv/reset``, ``/openenv/step``,
``/openenv/state``, ``/openenv/health``, ``/openenv/metadata``,
``/openenv/schema``, and the canonical ``/openenv/ws`` WebSocket. See
:mod:`opensleuth_env.openenv_adapter` and
https://github.com/meta-pytorch/OpenEnv (v0.2.3).
"""
from __future__ import annotations
import logging
import random
from typing import Optional
from fastapi import FastAPI, HTTPException, Query
from opensleuth_env import (
BLACK_BOX_FUNCTIONS,
OpenSleuthEnv,
ProbeAction,
ResetRequest,
StepRequest,
StepResponse,
SubmitAction,
TaskCatalog,
)
from opensleuth_env.task_catalog import TaskResolutionError
# Root-logger config: timestamped, leveled lines so our logs interleave
# readably with the ASGI server's own output.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
log = logging.getLogger("opensleuth.server")
# Single module-level app + environment instance; every route below closes
# over `env`, so all episodes live in this one process.
app = FastAPI(title="OpenSleuth Env", version="0.5.0")
env = OpenSleuthEnv()
# ---------------------------------------------------------------------------
# OpenEnv conformance: mount an upstream-spec sub-app at /openenv.
# This is kept additive so the existing trainer (which talks to the bare
# /reset and /step routes above) is completely unaffected.
# ---------------------------------------------------------------------------
try:
    # Both imports live inside the try so that a missing or broken
    # openenv-core install degrades gracefully instead of killing startup.
    from openenv.core.env_server.http_server import HTTPEnvServer
    from opensleuth_env.openenv_adapter import (
        OPENENV_AVAILABLE,
        OpenSleuthAction,
        OpenSleuthEnvironment,
        OpenSleuthObservation,
    )
    if OPENENV_AVAILABLE:
        # Separate FastAPI instance so the OpenEnv routes carry their own
        # title/description/OpenAPI schema, isolated from the legacy surface.
        openenv_app = FastAPI(
            title="OpenSleuth (OpenEnv-conformant)",
            version="0.5.0",
            description=(
                "OpenEnv 0.2.x conformant surface for the OpenSleuth environment.\n\n"
                "See https://github.com/meta-pytorch/OpenEnv -- this sub-app implements"
                " the canonical reset/step/state/health/metadata/schema HTTP routes plus"
                " the /ws WebSocket session protocol."
            ),
        )
        # NOTE(review): HTTPEnvServer is handed the environment *class*, not an
        # instance -- presumably it instantiates one env per session, capped at
        # 8 concurrent; confirm against openenv-core's docs.
        _openenv_server = HTTPEnvServer(
            env=OpenSleuthEnvironment,
            action_cls=OpenSleuthAction,
            observation_cls=OpenSleuthObservation,
            max_concurrent_envs=8,
        )
        _openenv_server.register_routes(openenv_app)
        app.mount("/openenv", openenv_app)
        log.info("Mounted OpenEnv-conformant sub-app at /openenv (openenv-core %s)",
                 _openenv_server.__class__.__module__)
    else:  # pragma: no cover
        log.warning("openenv-core not importable; /openenv/* will be unavailable.")
except Exception as e:  # pragma: no cover - fail open so legacy routes keep working
    # Broad catch is intentional: the conformance surface is best-effort and
    # must never take down the legacy trainer routes.
    log.warning("Could not register OpenEnv sub-app: %s: %s", type(e).__name__, e)
@app.get("/health")
def health():
    """Liveness probe: report tracked-episode count and Hub status."""
    # Reaching into env._states is deliberate -- the env exposes no public
    # episode counter -- hence the SLF001 suppression.
    tracked = len(env._states)  # noqa: SLF001
    payload = {"status": "ok", "episodes_tracked": tracked}
    payload["hub"] = env.catalog.hub_status()
    return payload
@app.get("/functions")
def list_functions(
    difficulty: Optional[str] = Query(
        None,
        description="Optional filter: easy / medium / hard. Used by the trainer for curriculum scheduling.",
    ),
):
    """List the builtin black-box function specs, optionally filtered by difficulty.

    NOTE -- backwards compatibility: this endpoint deliberately keeps the
    exact v0.3 shape (just the 9 builtin functions, with the original
    field set), because the in-flight trainer queries it. The new "source"
    field is additive. Open-ended / Hub tasks are exposed via /tasks.
    """
    # Idiomatic comprehension instead of the manual filter-and-append loop;
    # getattr defaults keep older spec objects (no difficulty/edge_cases
    # attributes) from raising.
    items = [
        {
            "name": s.name,
            "signature": s.signature,
            "description": s.description,
            "difficulty": getattr(s, "difficulty", None),
            "edge_case_count": len(getattr(s, "edge_cases", []) or []),
            "source": "builtin",
        }
        for s in BLACK_BOX_FUNCTIONS.values()
        if difficulty is None or getattr(s, "difficulty", None) == difficulty
    ]
    return {"functions": items}
@app.get("/tasks")
def list_tasks(
    source: str = Query(
        "all",
        description="Filter by source: 'builtin', 'hub', or 'all' (default).",
    ),
    difficulty: Optional[str] = Query(None, description="Optional curriculum filter."),
):
    """Enumerate tasks from the catalog, filtered by source and difficulty."""
    # Dispatch table replaces the if/elif chain; an unrecognized source
    # falls through to a 400 exactly as before.
    listers = {
        "builtin": env.catalog.list_builtin,
        "hub": env.catalog.list_hub,
        "all": env.catalog.list_all,
    }
    lister = listers.get(source.lower())
    if lister is None:
        raise HTTPException(
            status_code=400, detail="source must be one of: builtin, hub, all"
        )
    tasks = lister()
    if difficulty is not None:
        tasks = [t for t in tasks if t.get("difficulty") == difficulty]
    return {
        "tasks": tasks,
        "count": len(tasks),
        "hub": env.catalog.hub_status(),
    }
@app.post("/reset")
def reset(req: ResetRequest):
    """Start a new episode and return its first observation.

    Legacy callers populate only ``target_name``; open-ended callers send
    ``target_code`` plus ``target_function_name``. At least one of those
    two paths must be populated, otherwise we reject with a 400.
    """
    has_name = bool(req.target_name)
    has_code = bool(req.target_code)
    if not (has_name or has_code):
        raise HTTPException(
            status_code=400,
            detail="Either 'target_name' or ('target_code' + 'target_function_name') must be set.",
        )
    if has_code and not req.target_function_name:
        raise HTTPException(
            status_code=400,
            detail="'target_function_name' is required when 'target_code' is provided.",
        )
    try:
        # env.reset raises ValueError on bad task parameters; surface as 400.
        return env.reset(
            target_name=req.target_name,
            seed=req.seed,
            max_steps=req.max_steps,
            target_code=req.target_code,
            target_function_name=req.target_function_name,
            edge_cases=req.edge_cases,
            fuzz_spec=req.fuzz_spec,
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
@app.post("/step", response_model=StepResponse)
def step(req: StepRequest):
    """Apply one action to an episode; 404 when the episode is unknown."""
    try:
        result = env.step(req.episode_id, req.action)
    except KeyError as e:
        # env.step signals an unknown episode_id with KeyError.
        raise HTTPException(status_code=404, detail=str(e)) from e
    return result
@app.get("/state/{episode_id}")
def get_state(episode_id: str):
    """Return the recorded state for an episode, or 404 if unknown."""
    state = env.get_state(episode_id)
    if state:
        return state
    raise HTTPException(status_code=404, detail=f"Unknown episode_id {episode_id!r}")
@app.post("/probe_once")
def probe_once(target_name: str, input_repr: str):
    """One-shot convenience: reset a fresh episode and probe it once.

    Fix: an unknown ``target_name`` makes ``env.reset`` raise ValueError
    (the same error /reset maps to a 400); previously that escaped here as
    an unhandled 500. Now it is mapped to 400, consistent with /reset.

    NOTE(review): every call leaves a throwaway episode in env state until
    process restart -- acceptable for a debug endpoint, but worth a cleanup
    hook if this gets trainer traffic; confirm with the env's owners.
    """
    try:
        obs = env.reset(target_name=target_name)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    return env.step(obs.episode_id, ProbeAction(input_repr=input_repr))
@app.get("/tasks/{name}/sample_inputs")
def sample_inputs(
    name: str,
    n: int = Query(8, ge=1, le=64, description="How many inputs to draw."),
    seed: int = Query(0, description="Deterministic seed for the fuzzer."),
):
    """Return ``n`` Python-literal `repr` strings drawn from the task's
    auto-fuzzer (or hand-written fuzzer for builtins).

    Used by the trainer to build in-context probe pools without having to
    duplicate the auto-fuzzer logic on the trainer side. Each returned
    string is `ast.literal_eval`-safe and can be POSTed straight back to
    `/step` as a `ProbeAction.input_repr`.
    """
    # Unknown task name -> 404, mirroring /state's not-found behavior.
    try:
        spec = env.catalog.resolve(target_name=name)
    except TaskResolutionError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    # A task-supplied fuzzer may do anything, so any failure is reported as
    # a 500 with the exception class and message for debuggability.
    try:
        drawn = spec.fuzzer(random.Random(seed), n)
    except Exception as e:  # noqa: BLE001
        raise HTTPException(
            status_code=500,
            detail=f"fuzzer for {name!r} failed: {type(e).__name__}: {e}",
        ) from e
    response = {
        "name": name,
        "n": n,
        "seed": seed,
        "unpack_args": bool(getattr(spec, "unpack_args", False)),
        "inputs": [repr(item) for item in drawn],
    }
    return response