File size: 14,684 Bytes
08c19c7
 
 
 
 
 
9c888b7
 
 
 
 
 
 
 
 
 
08c19c7
 
88875f7
45bd962
08c19c7
17ed3a7
 
08c19c7
 
 
0304fd3
 
 
08c19c7
9c888b7
 
 
08c19c7
 
 
 
 
 
 
7203787
08c19c7
 
9c888b7
08c19c7
9c888b7
08c19c7
88875f7
08c19c7
 
9c888b7
 
 
7203787
9c888b7
08c19c7
 
9c888b7
 
 
 
 
 
 
 
08c19c7
 
9c888b7
 
 
08c19c7
 
 
 
 
 
 
 
 
17ed3a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
08c19c7
 
9c888b7
08c19c7
9c888b7
08c19c7
88875f7
17ed3a7
 
c27e659
17ed3a7
c27e659
7940abd
c27e659
 
 
 
17ed3a7
88875f7
08c19c7
 
9c888b7
c27e659
08c19c7
 
 
 
9c888b7
08c19c7
 
 
 
 
9c888b7
08c19c7
 
 
 
 
 
9c888b7
 
08c19c7
 
 
 
 
9c888b7
7203787
08c19c7
 
 
 
 
 
a503619
08c19c7
 
 
a503619
 
 
 
 
 
 
 
 
 
08c19c7
a503619
 
cfd3cfa
08c19c7
 
 
 
 
 
 
9c888b7
08c19c7
 
 
 
 
 
5235476
 
 
 
08c19c7
 
 
 
5235476
cfd3cfa
08c19c7
 
 
 
9c888b7
08c19c7
 
 
 
 
 
cfd3cfa
08c19c7
 
 
 
 
cfd3cfa
08c19c7
 
9c888b7
 
 
 
 
 
 
 
cfd3cfa
9c888b7
cfd3cfa
9c888b7
 
 
 
 
 
c27e659
9c888b7
c27e659
08c19c7
cfd3cfa
7203787
cfd3cfa
7203787
 
c27e659
 
 
 
 
 
 
7203787
cfd3cfa
 
08c19c7
 
 
9c888b7
cfd3cfa
08c19c7
 
9c888b7
 
08c19c7
9c888b7
 
 
 
 
 
 
08c19c7
cfd3cfa
08c19c7
 
9c888b7
08c19c7
9c888b7
08c19c7
88875f7
08c19c7
 
88875f7
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
"""
app.py
------
FastAPI server exposing the OpenEnv HTTP interface.

Endpoints:
  POST /reset              – start a new episode
  POST /step               – take one action
  GET  /state              – inspect internal state (debugging)
  GET  /tasks              – list available tasks
  GET  /health             – liveness probe
  GET  /action_space       – action space description for a task
  GET  /observation_space  – observation space description

Sessions are keyed by a UUID in the `session_id` query parameter.
If omitted, "default" is used (fine for sequential single-agent runs).
"""

from typing import Dict, Optional, Union
from pathlib import Path

from fastapi import FastAPI, HTTPException, Query, Request
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel

from env.schemas import Action, ActionType, TaskInfo
from server.tasks.task1 import Task1Environment
from server.tasks.task2 import Task2Environment
from server.tasks.task3 import Task3Environment

# ─────────────────────────────────────────────────────────────────────────────
# App
# ─────────────────────────────────────────────────────────────────────────────

# The single ASGI application object. Title, version and description are
# metadata FastAPI publishes in its generated OpenAPI schema.
app = FastAPI(
    title="Smart Contract Audit RL Environment",
    version="1.2.0",
    description=(
        "OpenEnv-compliant reinforcement learning environment for smart contract "
        "security analysis. Train and evaluate agents on real-world Solidity audit tasks."
    ),
)

# ─────────────────────────────────────────────────────────────────────────────
# Session management
# ─────────────────────────────────────────────────────────────────────────────

# Any concrete environment a session slot may hold.
_TaskEnv = Union[Task1Environment, Task2Environment, Task3Environment]

# In-process session registry: session_id -> live environment instance.
# NOTE(review): nothing in this module removes entries, so every distinct
# session_id passed to /reset stays in memory for the life of the process.
_sessions: Dict[str, _TaskEnv] = {}

# Slot used when a request omits the `session_id` query parameter.
DEFAULT_SESSION = "default"

# task_id -> environment class; _create_env() instantiates from this table.
TASK_ENV_MAP = {
    "task1_vuln_detection": Task1Environment,
    "task2_property_discovery": Task2Environment,
    "task3_rule_checker": Task3Environment,
}


def _create_env(task_id: str) -> Union[Task1Environment, Task2Environment, Task3Environment]:
    """Build a brand-new environment instance for *task_id*.

    Raises:
        HTTPException: 400 when *task_id* is not a key of TASK_ENV_MAP; the
            detail message lists every accepted id.
    """
    if task_id not in TASK_ENV_MAP:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown task_id '{task_id}'. Available: {list(TASK_ENV_MAP)}",
        )
    env_cls = TASK_ENV_MAP[task_id]
    return env_cls()


# ─────────────────────────────────────────────────────────────────────────────
# Request bodies
# ─────────────────────────────────────────────────────────────────────────────

class ResetRequest(BaseModel):
    """JSON body accepted by POST /reset; every field has a default."""

    # Key into TASK_ENV_MAP selecting which task to start.
    task_id: str = "task1_vuln_detection"
    # Optional seed handed to env.reset(seed=...); None leaves it unseeded.
    seed: Optional[int] = None


class StepRequest(BaseModel):
    """JSON body accepted by POST /step."""

    # Raw action name; the /step handler maps it onto ActionType.
    action_type: str
    # Action-specific arguments (e.g. {"function_name": "..."}); {} if omitted.
    # A literal {} default is safe on a pydantic model: pydantic copies
    # mutable defaults per instance instead of sharing them.
    params: dict = {}
    

_ROOT_JSON = {
    "name": "Smart Contract Audit RL Environment",
    "version": "1.2.0",
    "description": (
        "OpenEnv-compliant RL environment for Solidity smart contract security analysis. "
        "Train and evaluate agents on real-world DeFi audit tasks from Certora reports."
    ),
    "tasks": [
        {"id": "task1_vuln_detection",     "name": "Targeted Vulnerability Detection", "difficulty": "medium"},
        {"id": "task2_property_discovery", "name": "Property Discovery",                "difficulty": "hard"},
        {"id": "task3_rule_checker",       "name": "Rule Checker",                      "difficulty": "easy"},
    ],
    "endpoints": {
        "reset":             "POST /reset",
        "step":              "POST /step",
        "state":             "GET  /state",
        "tasks":             "GET  /tasks",
        "health":            "GET  /health",
        "action_space":      "GET  /action_space",
        "observation_space": "GET  /observation_space",
        "docs":              "GET  /docs",
    },
    "data_sources": ["AaveVault", "AaveVaultV2", "Lido Finance"],
}


# ─────────────────────────────────────────────────────────────────────────────
# Routes
# ─────────────────────────────────────────────────────────────────────────────

@app.get("/")
def root(request: Request):
    """
    Landing page with human-readable description and API summary. Also serves as a health check.

    Serves ``index.html`` from this file's directory. If that file is absent
    (e.g. not copied into the deployment image), the static JSON summary is
    returned instead. Otherwise the route would answer 500, because
    FileResponse only discovers a missing file while sending the response.

    ``request`` is unused but kept so the handler signature stays unchanged.
    """
    index_path = Path(__file__).resolve().parent / "index.html"
    if not index_path.is_file():
        return JSONResponse(content=_ROOT_JSON, status_code=200)
    return FileResponse(index_path, media_type="text/html", status_code=200)

@app.get("/api")
def api_root():
    """Return the static _ROOT_JSON summary as a JSON response."""
    summary = _ROOT_JSON
    return JSONResponse(status_code=200, content=summary)

@app.get("/health")
def health():
    """Liveness probe: always reports status "ok" together with the server version."""
    return dict(status="ok", version="1.2.0")


@app.get("/tasks")
def list_tasks():
    """Return every task as a serialized TaskInfo; all are currently "active"."""
    # (task_id, name, difficulty, description) for each task, in display order.
    catalogue = (
        (
            "task1_vuln_detection",
            "Targeted Vulnerability Detection",
            "medium",
            "Given a Solidity contract, identify the vulnerable function and describe the vulnerability type in 2-3 words.",
        ),
        (
            "task2_property_discovery",
            "Property Discovery",
            "hard",
            "Given a Solidity function, write the natural-language property that describes its correct behaviour.",
        ),
        (
            "task3_rule_checker",
            "Rule Checker",
            "easy",
            "Given a property in English and a Solidity contract, identify which function violates that property.",
        ),
    )
    return {
        "tasks": [
            TaskInfo(
                task_id=tid,
                name=title,
                difficulty=level,
                description=summary,
                status="active",
            ).model_dump()
            for tid, title, level, summary in catalogue
        ]
    }

@app.post("/reset")
def reset(
    body: Optional[ResetRequest] = None,
    session_id: str = Query(default=DEFAULT_SESSION),
):
    """Reset the environment and start a new episode.

    Args:
        body: Optional JSON body. When it is omitted entirely (the OpenEnv
            validator case) the ResetRequest field defaults are used.
        session_id: Session slot to (re)initialise.

    Returns:
        JSONResponse with the serialized result of ``env.reset``.

    Raises:
        HTTPException: 400 for an unknown task_id (raised by _create_env).
    """
    # A missing body behaves exactly like an empty one; reuse the model's
    # defaults rather than duplicating them here.
    req = body if body is not None else ResetRequest()

    env = _create_env(req.task_id)
    result = env.reset(seed=req.seed)
    # Publish the env only after reset() succeeded. If reset raises, the
    # session keeps its previous environment instead of pointing at a
    # half-initialised one.
    _sessions[session_id] = env
    return JSONResponse(content=result.model_dump(), status_code=200)


@app.post("/step")
def step(
    body: StepRequest,
    session_id: str = Query(default=DEFAULT_SESSION),
):
    """Apply one action and advance the episode.

    An ``action_type`` string that is not a valid ActionType value is mapped
    to ``ActionType.UNKNOWN`` and still forwarded to the environment.

    Raises:
        HTTPException: 400 if no environment exists for ``session_id``
            (call /reset first).
    """
    env = _sessions.get(session_id)
    if env is None:
        raise HTTPException(
            status_code=400,
            detail=f"No active session '{session_id}'. Call /reset first.",
        )

    # Resolve by value instead of `str in ActionType`. On Python < 3.12 an
    # Enum membership test with a non-member (a plain str) raises TypeError,
    # which would turn every /step call into a 500.
    try:
        action_type = ActionType(body.action_type)
    except ValueError:
        action_type = ActionType.UNKNOWN

    action = Action(action_type=action_type, params=body.params)
    try:
        result = env.step(action)
    except RuntimeError as e:
        # NOTE(review): environment RuntimeErrors are returned as a bare JSON
        # string with HTTP 200, unlike the 400 used above. This is kept as-is
        # because existing clients may depend on it.
        return JSONResponse(content=str(e), status_code=200)
    return JSONResponse(content=result.model_dump(), status_code=200)


@app.get("/state")
def state(session_id: str = Query(default=DEFAULT_SESSION)):
    """Return internal state for debugging (not for agents).

    Raises:
        HTTPException: 400 if no environment exists for ``session_id``.
    """
    if session_id not in _sessions:
        raise HTTPException(
            status_code=400,
            detail=f"No active session '{session_id}'. Call /reset first.",
        )
    snapshot = _sessions[session_id].state()
    return JSONResponse(content=snapshot.model_dump(), status_code=200)

# Per-task action catalogue served by /action_space.
# Each row is (type, params, reward, description).
_ACTION_SPACES = {
    "task1_vuln_detection": (
        ("list_functions", {}, -0.05, "List all function names"),
        ("get_function_code", {"function_name": "string"}, "+0.05 (target) / -0.10 (other)", "Get full Solidity source of a function"),
        ("get_function_summary", {"function_name": "string"}, "+0.03 (target) / -0.05 (other)", "Get NatSpec comment of a function"),
        ("get_file_metadata", {}, -0.04, "Get contract-level metadata"),
        ("get_state_variable", {"variable_name": "string (optional)"}, -0.05, "Get a state variable or list all"),
        ("get_call_graph", {}, -0.08, "Get function call graph"),
        ("submit", {"function_name": "str", "vulnerability_type": "str"}, "+5.0 / +1.0 / -1.5", "Submit answer. Ends episode."),
    ),
    "task2_property_discovery": (
        ("get_function_code", {}, -0.06, "Read full source of the target function"),
        ("get_function_natspec", {}, -0.08, "Read NatSpec + expected behaviour"),
        ("get_file_natspec", {}, -0.03, "Read contract-level NatSpec"),
        ("get_related_functions", {}, -0.06, "List caller/callee functions with summaries"),
        ("get_signature", {}, -0.04, "Get structured I/O + expected behaviour"),
        ("get_similar_rule", {}, -0.20, "Get a similar property from another contract"),
        ("submit_property", {"property": "string"}, "0.0–5.0", "Submit property. ONE attempt. Ends episode."),
    ),
    "task3_rule_checker": (
        ("list_functions", {}, -0.05, "List all function names"),
        ("get_function_metadata", {"function_name": "string"}, -0.05, "Get signature, visibility, params of a function"),
        ("get_function_code", {"function_name": "string"}, -0.10, "Read full Solidity source of a function"),
        ("get_state_variable", {"variable_name": "string (opt)"}, -0.05, "Get a state variable or list all"),
        ("get_call_graph", {}, -0.08, "Get function call graph"),
        ("get_property_specification", {}, -0.03, "Get formal pre/post-condition for the property"),
        ("submit_function", {"function_name": "string"}, "+5.0 / +1.5 / -1.5", "Submit answer. ONE attempt. Ends episode."),
    ),
}


@app.get("/action_space")
def action_space(task_id: str = "task1_vuln_detection"):
    """Describe the action space for a task.

    Raises:
        HTTPException: 400 if *task_id* has no entry in _ACTION_SPACES.
    """
    rows = _ACTION_SPACES.get(task_id)
    if rows is None:
        raise HTTPException(status_code=400, detail=f"Unknown task_id '{task_id}'.")
    actions = [
        {"type": kind, "params": params, "reward": reward, "description": text}
        for kind, params, reward, text in rows
    ]
    return JSONResponse(content={"task_id": task_id, "actions": actions}, status_code=200)

@app.get("/observation_space")
def observation_space():
    """Describe the observation space, which is the same for every task."""
    # Field name -> "type – meaning" description.
    fields = {
        "task_id": "string – active task identifier",
        "contract_name": "string – Solidity contract name",
        "contract_description": "string – what the contract does",
        "available_actions": "list[string] – valid action types for this task",
        "last_action": "string|null – previous action type",
        "last_action_result": "string|null – human-readable result of last action",
        "step_count": "int – steps taken in this episode",
        "cumulative_reward": "float – running reward total",
        "done": "bool – True when episode is over",
        "extra": "object – task-specific hints (target_function, hint, etc.)",
    }
    schema = {"type": "object", "fields": fields}
    return JSONResponse(content=schema, status_code=200)


# ─────────────────────────────────────────────────────────────────────────────
# Entry point
# ─────────────────────────────────────────────────────────────────────────────

def main(host: str = "0.0.0.0", port: int = 7860):
    """Serve this application with uvicorn (blocks until shutdown).

    Args:
        host: Interface to bind; the default listens on all interfaces.
        port: TCP port to listen on.

    The ``app`` object is passed directly instead of the ``"app:app"`` import
    string. With ``reload=False`` uvicorn does not need an import string.
    That string only resolves when this file is importable as a top-level
    ``app`` module, and this module appears to live in the ``server`` package
    (it imports ``server.tasks``). It also avoids importing the module a
    second time when run as a script.
    """
    import uvicorn

    uvicorn.run(app, host=host, port=port, reload=False)


if __name__ == "__main__":
    main()