File size: 4,431 Bytes
d3cd20c
ee14542
d3cd20c
 
 
536dda7
d3cd20c
536dda7
d3cd20c
 
 
 
 
 
 
 
 
77e65fb
d3cd20c
 
 
 
 
77e65fb
ee14542
 
d3cd20c
 
 
77e65fb
 
 
 
 
d3cd20c
 
 
536dda7
 
 
 
 
 
77e65fb
 
 
 
536dda7
 
 
 
 
d3cd20c
 
 
 
536dda7
 
77e65fb
d3cd20c
536dda7
 
d3cd20c
 
77e65fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d3cd20c
 
77e65fb
 
 
 
 
 
 
 
 
 
 
 
 
d3cd20c
77e65fb
 
 
 
 
 
 
 
 
d3cd20c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
"""FastAPI server exposing the OpenSleuth environment over HTTP."""

from __future__ import annotations

import logging
from typing import Optional

from fastapi import FastAPI, HTTPException, Query

from opensleuth_env import (
    BLACK_BOX_FUNCTIONS,
    OpenSleuthEnv,
    ProbeAction,
    ResetRequest,
    StepRequest,
    StepResponse,
    SubmitAction,
    TaskCatalog,
)

# Root logging is configured once, at import time, with a timestamped format.
_LOG_FORMAT = "%(asctime)s %(levelname)s %(name)s: %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
log = logging.getLogger("opensleuth.server")

# The ASGI application object served by this module.
app = FastAPI(version="0.4.0", title="OpenSleuth Env")
# One environment instance shared by every request handler in this module;
# /reset, /step, /state and /probe_once all create or read episodes through it.
env = OpenSleuthEnv()


@app.get("/health")
def health():
    """Liveness endpoint: report status, tracked-episode count and Hub status."""
    # ``_states`` is private to OpenSleuthEnv; it is read here only for its size.
    tracked = len(env._states)  # noqa: SLF001
    hub_info = env.catalog.hub_status()
    return {"status": "ok", "episodes_tracked": tracked, "hub": hub_info}


@app.get("/functions")
def list_functions(
    difficulty: Optional[str] = Query(
        None,
        description="Optional filter: easy / medium / hard. Used by the trainer for curriculum scheduling.",
    ),
):
    """List the builtin black-box functions, optionally filtered by difficulty.

    Backwards compatibility: the response deliberately keeps the v0.3 shape
    (only the entries of BLACK_BOX_FUNCTIONS, with the original field set)
    because the in-flight trainer queries it. The "source" field is an
    additive extension. Open-ended / Hub tasks are served by /tasks instead.
    """

    def _summarise(spec):
        # Optional attributes are read defensively: specs may lack them.
        return {
            "name": spec.name,
            "signature": spec.signature,
            "description": spec.description,
            "difficulty": getattr(spec, "difficulty", None),
            "edge_case_count": len(getattr(spec, "edge_cases", []) or []),
            "source": "builtin",
        }

    # No filter requested -> keep everything; otherwise exact-match on difficulty.
    wanted = (
        spec
        for spec in BLACK_BOX_FUNCTIONS.values()
        if difficulty is None or getattr(spec, "difficulty", None) == difficulty
    )
    return {"functions": [_summarise(spec) for spec in wanted]}


@app.get("/tasks")
def list_tasks(
    source: str = Query(
        "all",
        description="Filter by source: 'builtin', 'hub', or 'all' (default).",
    ),
    difficulty: Optional[str] = Query(None, description="Optional curriculum filter."),
):
    """List catalog tasks from the requested source, optionally by difficulty.

    ``source`` is matched case-insensitively. Any value other than
    builtin / hub / all yields HTTP 400. The response also carries the
    catalog's Hub status.
    """
    # Lazy dispatch: only the selected catalog listing is actually called.
    listers = {
        "builtin": lambda: env.catalog.list_builtin(),
        "hub": lambda: env.catalog.list_hub(),
        "all": lambda: env.catalog.list_all(),
    }
    lister = listers.get(source.lower())
    if lister is None:
        raise HTTPException(
            status_code=400, detail="source must be one of: builtin, hub, all"
        )
    everything = lister()
    selected = (
        everything
        if difficulty is None
        else [task for task in everything if task.get("difficulty") == difficulty]
    )
    return {
        "tasks": selected,
        "count": len(selected),
        "hub": env.catalog.hub_status(),
    }


@app.post("/reset")
def reset(req: ResetRequest):
    """Start a new episode and return the environment's initial observation.

    Two calling conventions are accepted:
      * legacy: only ``target_name`` is set;
      * open-ended: ``target_code`` plus ``target_function_name``.

    Raises:
        HTTPException(400): neither convention is satisfied, ``target_code``
            is given without ``target_function_name``, or ``env.reset``
            raises ValueError.
    """
    has_code = bool(req.target_code)
    if not (req.target_name or has_code):
        raise HTTPException(
            status_code=400,
            detail="Either 'target_name' or ('target_code' + 'target_function_name') must be set.",
        )
    if has_code and not req.target_function_name:
        raise HTTPException(
            status_code=400,
            detail="'target_function_name' is required when 'target_code' is provided.",
        )

    reset_kwargs = dict(
        target_name=req.target_name,
        seed=req.seed,
        max_steps=req.max_steps,
        target_code=req.target_code,
        target_function_name=req.target_function_name,
        edge_cases=req.edge_cases,
        fuzz_spec=req.fuzz_spec,
    )
    try:
        return env.reset(**reset_kwargs)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc


@app.post("/step", response_model=StepResponse)
def step(req: StepRequest):
    """Apply one action to an existing episode.

    A KeyError raised by ``env.step`` is reported as HTTP 404 (presumably an
    unknown episode id — confirm against OpenSleuthEnv.step).
    """
    episode_id, action = req.episode_id, req.action
    try:
        outcome = env.step(episode_id, action)
    except KeyError as missing:
        raise HTTPException(status_code=404, detail=str(missing)) from missing
    return outcome


@app.get("/state/{episode_id}")
def get_state(episode_id: str):
    """Return the stored state of ``episode_id``; 404 when it is missing or falsy."""
    state = env.get_state(episode_id)
    if state:
        return state
    raise HTTPException(status_code=404, detail=f"Unknown episode_id {episode_id!r}")


@app.post("/probe_once", response_model=StepResponse)
def probe_once(target_name: str, input_repr: str):
    """Start a fresh episode on a builtin target and probe it once.

    Args:
        target_name: name of the target passed to ``env.reset``.
        input_repr: probe input, wrapped in a ``ProbeAction``.

    Returns:
        The result of a single ``env.step`` (same response model as /step).

    Raises:
        HTTPException(400): ``env.reset`` raised ValueError (e.g. an unknown
            target), matching the error mapping of /reset.
        HTTPException(404): ``env.step`` raised KeyError, matching /step.

    NOTE(review): each call creates a new episode through ``env.reset`` and
    nothing here removes it, so repeated calls grow the count reported by
    /health. Confirm whether OpenSleuthEnv evicts finished episodes.
    """
    # Build the action first so an invalid probe fails before an episode is created.
    action = ProbeAction(input_repr=input_repr)
    try:
        obs = env.reset(target_name=target_name)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    try:
        return env.step(obs.episode_id, action)
    except KeyError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e