Spaces:
Running
Running
"""
app.py
------
FastAPI server exposing the OpenEnv HTTP interface.

Endpoints:
    POST /reset              — start a new episode
    POST /step               — take one action
    GET  /state              — inspect internal state (debugging)
    GET  /tasks              — list available tasks
    GET  /health             — liveness probe
    GET  /action_space       — action space description for a task
    GET  /observation_space  — observation space description

Sessions are keyed by a UUID in the `session_id` query parameter.
If omitted, "default" is used (fine for sequential single-agent runs).
"""
| from typing import Dict, Optional, Union | |
| from pathlib import Path | |
| from fastapi import FastAPI, HTTPException, Query, Request | |
| from fastapi.responses import FileResponse, JSONResponse | |
| from pydantic import BaseModel | |
| from env.schemas import Action, ActionType, TaskInfo | |
| from server.tasks.task1 import Task1Environment | |
| from server.tasks.task2 import Task2Environment | |
| from server.tasks.task3 import Task3Environment | |
# ─────────────────────────────────────────────────────────────────────────────
# App
# ─────────────────────────────────────────────────────────────────────────────
# The single FastAPI application object; uvicorn loads it as "app:app".
app = FastAPI(
    title="Smart Contract Audit RL Environment",
    description=(
        "OpenEnv-compliant reinforcement learning environment for smart contract "
        "security analysis. Train and evaluate agents on real-world Solidity audit tasks."
    ),
    version="1.2.0",
)
# ─────────────────────────────────────────────────────────────────────────────
# Session management
# ─────────────────────────────────────────────────────────────────────────────
# In-memory registry of live environments, keyed by the `session_id` query
# parameter. Entries are created by /reset and looked up by /step and /state.
_sessions: Dict[str, Union[Task1Environment, Task2Environment, Task3Environment]] = {}

# Session key used when a request does not supply `session_id`.
DEFAULT_SESSION = "default"

# Maps each public task identifier to the environment class implementing it.
TASK_ENV_MAP = {
    "task1_vuln_detection": Task1Environment,
    "task2_property_discovery": Task2Environment,
    "task3_rule_checker": Task3Environment,
}
def _create_env(task_id: str):
    """Instantiate the environment class registered for ``task_id``.

    Args:
        task_id: One of the keys of ``TASK_ENV_MAP``.

    Returns:
        A new instance of the matching environment class, constructed with
        no arguments.

    Raises:
        HTTPException: With status 400 when ``task_id`` is not a key of
            ``TASK_ENV_MAP``; the detail lists the available task ids.
    """
    if task_id not in TASK_ENV_MAP:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown task_id '{task_id}'. Available: {list(TASK_ENV_MAP)}",
        )
    env_cls = TASK_ENV_MAP[task_id]
    return env_cls()
# ─────────────────────────────────────────────────────────────────────────────
# Request bodies
# ─────────────────────────────────────────────────────────────────────────────
class ResetRequest(BaseModel):
    # JSON body accepted by the reset handler. Both fields are optional.

    # Task whose environment should be created for the new episode.
    task_id: str = "task1_vuln_detection"
    # Forwarded unchanged to the environment's `reset(seed=...)`.
    seed: Optional[int] = None
class StepRequest(BaseModel):
    # JSON body accepted by the step handler.

    # Raw action name; the step handler converts it to an `ActionType`.
    action_type: str
    # Action parameters, passed through to `Action(params=...)`.
    # NOTE(review): pydantic is expected to copy this mutable default per
    # instance, so the shared `{}` literal should be safe; confirm against the
    # pinned pydantic version.
    params: dict = {}
# Static, machine-readable summary of the service: name, version, task
# catalogue, endpoint list, and data sources. Returned as-is by `api_root`.
_ROOT_JSON = {
    "name": "Smart Contract Audit RL Environment",
    "version": "1.2.0",
    "description": (
        "OpenEnv-compliant RL environment for Solidity smart contract security analysis. "
        "Train and evaluate agents on real-world DeFi audit tasks from Certora reports."
    ),
    "tasks": [
        {
            "id": "task1_vuln_detection",
            "name": "Targeted Vulnerability Detection",
            "difficulty": "medium",
        },
        {
            "id": "task2_property_discovery",
            "name": "Property Discovery",
            "difficulty": "hard",
        },
        {
            "id": "task3_rule_checker",
            "name": "Rule Checker",
            "difficulty": "easy",
        },
    ],
    "endpoints": {
        "reset": "POST /reset",
        "step": "POST /step",
        "state": "GET /state",
        "tasks": "GET /tasks",
        "health": "GET /health",
        "action_space": "GET /action_space",
        "observation_space": "GET /observation_space",
        "docs": "GET /docs",
    },
    "data_sources": ["AaveVault", "AaveVaultV2", "Lido Finance"],
}
# ─────────────────────────────────────────────────────────────────────────────
# Routes
# ─────────────────────────────────────────────────────────────────────────────
# NOTE(review): no route registration was visible for this handler; the "/"
# path is inferred from its name and its landing-page role. Confirm.
@app.get("/")
def root(request: Request):
    """Serve the HTML landing page.

    Returns the ``index.html`` file that sits next to this module. Because it
    answers with status 200, it can also be used as a basic health check.

    Args:
        request: The incoming request. Accepted by the signature but not read
            by this handler.

    Returns:
        A ``FileResponse`` for ``index.html`` with media type ``text/html``.
    """
    landing_page = Path(__file__).resolve().parent / "index.html"
    return FileResponse(landing_page, media_type="text/html", status_code=200)
# NOTE(review): no route decorator is visible for this handler and its path
# cannot be determined from this file, so it is currently unreachable over
# HTTP. Register it (e.g. with `@app.get(...)`) once the path is confirmed.
def api_root():
    """Return the machine-readable API summary (``_ROOT_JSON``) as JSON."""
    return JSONResponse(content=_ROOT_JSON, status_code=200)
@app.get("/health")
def health():
    """Liveness probe.

    Returns:
        A dict with a fixed ``"status": "ok"`` and the service version string.
    """
    return {"status": "ok", "version": "1.2.0"}
@app.get("/tasks")
def list_tasks():
    """List all tasks with their status.

    Returns:
        ``{"tasks": [...]}`` where each entry is a ``TaskInfo`` dumped to a
        plain dict (task_id, name, difficulty, description, status).
    """
    catalogue = [
        TaskInfo(
            task_id="task1_vuln_detection",
            name="Targeted Vulnerability Detection",
            difficulty="medium",
            description="Given a Solidity contract, identify the vulnerable function and describe the vulnerability type in 2-3 words.",
            status="active",
        ),
        TaskInfo(
            task_id="task2_property_discovery",
            name="Property Discovery",
            difficulty="hard",
            description="Given a Solidity function, write the natural-language property that describes its correct behaviour.",
            status="active",
        ),
        TaskInfo(
            task_id="task3_rule_checker",
            name="Rule Checker",
            difficulty="easy",
            description="Given a property in English and a Solidity contract, identify which function violates that property.",
            status="active",
        ),
    ]
    return {"tasks": [info.model_dump() for info in catalogue]}
@app.post("/reset")
def reset(
    body: Optional[ResetRequest] = None,
    session_id: str = Query(default=DEFAULT_SESSION),
):
    """Reset the environment and start a new episode.

    Creates a fresh environment for the requested task, resets it, and stores
    it under ``session_id``, replacing any environment already stored there.

    Args:
        body: Optional request body. When omitted (the OpenEnv validator
            sends no body), the ``ResetRequest`` defaults are used.
        session_id: Key under which the new environment is stored.

    Returns:
        A ``JSONResponse`` with the dumped result of ``env.reset``.

    Raises:
        HTTPException: With status 400 when the task id is unknown.
    """
    # A missing body means "use the model defaults", so the defaults live in
    # one place (ResetRequest) instead of being repeated here.
    request_body = body if body is not None else ResetRequest()

    env = _create_env(request_body.task_id)
    result = env.reset(seed=request_body.seed)
    # Store the environment only after a successful reset, so a failed reset
    # does not replace a working session with an uninitialised one.
    _sessions[session_id] = env
    return JSONResponse(content=result.model_dump(), status_code=200)
@app.post("/step")
def step(
    body: StepRequest,
    session_id: str = Query(default=DEFAULT_SESSION),
):
    """Apply one action and advance the episode.

    Args:
        body: The action name and its parameters.
        session_id: Key of the session whose environment receives the action.

    Returns:
        A ``JSONResponse`` with the dumped step result. If the environment
        raises ``RuntimeError``, the error message is returned as the JSON
        body, still with status 200.

    Raises:
        HTTPException: With status 400 when no environment is stored for
            ``session_id``.
    """
    env = _sessions.get(session_id)
    if env is None:
        raise HTTPException(
            status_code=400,
            detail=f"No active session '{session_id}'. Call /reset first.",
        )

    # Convert by value lookup rather than `in ActionType`: before Python 3.12
    # a containment check with a plain string raises TypeError instead of
    # returning False. Unrecognised names map to ActionType.UNKNOWN.
    try:
        action_type = ActionType(body.action_type)
    except ValueError:
        action_type = ActionType.UNKNOWN
    action = Action(action_type=action_type, params=body.params)

    try:
        result = env.step(action)
    except RuntimeError as exc:
        # NOTE(review): the error is deliberately reported with status 200 and
        # a bare string body; confirm clients depend on this before changing.
        return JSONResponse(content=str(exc), status_code=200)
    return JSONResponse(content=result.model_dump(), status_code=200)
@app.get("/state")
def state(session_id: str = Query(default=DEFAULT_SESSION)):
    """Return internal state for debugging (not for agents).

    Args:
        session_id: Key of the session to inspect.

    Returns:
        A ``JSONResponse`` with the dumped result of ``env.state()``.

    Raises:
        HTTPException: With status 400 when no environment is stored for
            ``session_id``.
    """
    env = _sessions.get(session_id)
    if env is None:
        raise HTTPException(
            status_code=400,
            detail=f"No active session '{session_id}'. Call /reset first.",
        )
    snapshot = env.state().model_dump()
    return JSONResponse(content=snapshot, status_code=200)
@app.get("/action_space")
def action_space(task_id: str = "task1_vuln_detection"):
    """Describe the action space for a task.

    Args:
        task_id: Task whose actions should be described.

    Returns:
        A ``JSONResponse`` of the form ``{"task_id": ..., "actions": [...]}``.
        Each action entry has ``type``, ``params``, ``reward`` and
        ``description`` keys.

    Raises:
        HTTPException: With status 400 when ``task_id`` is not one of the
            three known tasks.
    """
    actions_by_task = {
        "task1_vuln_detection": [
            {"type": "list_functions", "params": {}, "reward": -0.05, "description": "List all function names"},
            {"type": "get_function_code", "params": {"function_name": "string"}, "reward": "+0.05 (target) / -0.10 (other)", "description": "Get full Solidity source of a function"},
            {"type": "get_function_summary", "params": {"function_name": "string"}, "reward": "+0.03 (target) / -0.05 (other)", "description": "Get NatSpec comment of a function"},
            {"type": "get_file_metadata", "params": {}, "reward": -0.04, "description": "Get contract-level metadata"},
            {"type": "get_state_variable", "params": {"variable_name": "string (optional)"}, "reward": -0.05, "description": "Get a state variable or list all"},
            {"type": "get_call_graph", "params": {}, "reward": -0.08, "description": "Get function call graph"},
            {"type": "submit", "params": {"function_name": "str", "vulnerability_type": "str"}, "reward": "+5.0 / +1.0 / -1.5", "description": "Submit answer. Ends episode."},
        ],
        "task2_property_discovery": [
            {"type": "get_function_code", "params": {}, "reward": -0.06, "description": "Read full source of the target function"},
            {"type": "get_function_natspec", "params": {}, "reward": -0.08, "description": "Read NatSpec + expected behaviour"},
            {"type": "get_file_natspec", "params": {}, "reward": -0.03, "description": "Read contract-level NatSpec"},
            {"type": "get_related_functions", "params": {}, "reward": -0.06, "description": "List caller/callee functions with summaries"},
            {"type": "get_signature", "params": {}, "reward": -0.04, "description": "Get structured I/O + expected behaviour"},
            {"type": "get_similar_rule", "params": {}, "reward": -0.20, "description": "Get a similar property from another contract"},
            # NOTE(review): the range separator was a garbled character in the
            # source; an en dash is assumed here. Confirm the intended text.
            {"type": "submit_property", "params": {"property": "string"}, "reward": "0.0–5.0", "description": "Submit property. ONE attempt. Ends episode."},
        ],
        "task3_rule_checker": [
            {"type": "list_functions", "params": {}, "reward": -0.05, "description": "List all function names"},
            {"type": "get_function_metadata", "params": {"function_name": "string"}, "reward": -0.05, "description": "Get signature, visibility, params of a function"},
            {"type": "get_function_code", "params": {"function_name": "string"}, "reward": -0.10, "description": "Read full Solidity source of a function"},
            {"type": "get_state_variable", "params": {"variable_name": "string (opt)"}, "reward": -0.05, "description": "Get a state variable or list all"},
            {"type": "get_call_graph", "params": {}, "reward": -0.08, "description": "Get function call graph"},
            {"type": "get_property_specification", "params": {}, "reward": -0.03, "description": "Get formal pre/post-condition for the property"},
            {"type": "submit_function", "params": {"function_name": "string"}, "reward": "+5.0 / +1.5 / -1.5", "description": "Submit answer. ONE attempt. Ends episode."},
        ],
    }

    actions = actions_by_task.get(task_id)
    if actions is None:
        raise HTTPException(status_code=400, detail=f"Unknown task_id '{task_id}'.")
    return JSONResponse(content={"task_id": task_id, "actions": actions}, status_code=200)
@app.get("/observation_space")
def observation_space():
    """Describe the observation space (same for all tasks).

    Returns:
        A ``JSONResponse`` with ``"type": "object"`` and a ``fields`` mapping
        from each observation field name to a short type-and-meaning string.
    """
    # NOTE(review): the separator in these descriptions was a garbled
    # character in the source; an em dash is assumed. Confirm the intended text.
    fields = {
        "task_id": "string — active task identifier",
        "contract_name": "string — Solidity contract name",
        "contract_description": "string — what the contract does",
        "available_actions": "list[string] — valid action types for this task",
        "last_action": "string|null — previous action type",
        "last_action_result": "string|null — human-readable result of last action",
        "step_count": "int — steps taken in this episode",
        "cumulative_reward": "float — running reward total",
        "done": "bool — True when episode is over",
        "extra": "object — task-specific hints (target_function, hint, etc.)",
    }
    return JSONResponse(content={"type": "object", "fields": fields}, status_code=200)
# ─────────────────────────────────────────────────────────────────────────────
# Entry point
# ─────────────────────────────────────────────────────────────────────────────
def main():
    """Run the app with uvicorn on 0.0.0.0:7860, with auto-reload disabled."""
    # Imported here so that importing this module does not require uvicorn.
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)


if __name__ == "__main__":
    main()