"""
app.py
------
FastAPI server exposing the OpenEnv HTTP interface.

Endpoints:
  GET  /                   – HTML landing page (JSON summary if index.html is missing)
  GET  /api                – machine-readable API summary
  POST /reset              – start a new episode
  POST /step               – take one action
  GET  /state              – inspect internal state (debugging)
  GET  /tasks              – list available tasks
  GET  /health             – liveness probe
  GET  /action_space       – action space description for a task
  GET  /observation_space  – observation space description

Sessions are keyed by the `session_id` query parameter (any string; a UUID is
a good choice). If omitted, "default" is used (fine for sequential
single-agent runs).
"""

from typing import Dict, List, Optional, Union
from pathlib import Path

from fastapi import FastAPI, HTTPException, Query, Request
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel, Field

from env.schemas import Action, ActionType, TaskInfo
from server.tasks.task1 import Task1Environment
from server.tasks.task2 import Task2Environment
from server.tasks.task3 import Task3Environment

# ─────────────────────────────────────────────────────────────────────────────
# App
# ─────────────────────────────────────────────────────────────────────────────

# Single source for the version reported by OpenAPI, /api and /health.
API_VERSION = "1.2.0"

app = FastAPI(
    title="Smart Contract Audit RL Environment",
    description=(
        "OpenEnv-compliant reinforcement learning environment for smart contract "
        "security analysis. Train and evaluate agents on real-world Solidity audit tasks."
    ),
    version=API_VERSION,
)

# Directory of this file; index.html for the landing page is looked up here.
_BASE_DIR = Path(__file__).resolve().parent
_INDEX_HTML = _BASE_DIR / "index.html"

# ─────────────────────────────────────────────────────────────────────────────
# Session management
# ─────────────────────────────────────────────────────────────────────────────

TaskEnv = Union[Task1Environment, Task2Environment, Task3Environment]

# In-process session store: session_id -> environment instance.
# NOTE(review): entries are never evicted; fine for single-agent use, but a
# client generating many session ids will grow this without bound.
_sessions: Dict[str, TaskEnv] = {}
DEFAULT_SESSION = "default"
DEFAULT_TASK_ID = "task1_vuln_detection"

TASK_ENV_MAP = {
    "task1_vuln_detection": Task1Environment,
    "task2_property_discovery": Task2Environment,
    "task3_rule_checker": Task3Environment,
}


def _create_env(task_id: str) -> TaskEnv:
    """Instantiate a fresh environment for ``task_id``.

    Raises:
        HTTPException(400): if ``task_id`` is not a key of ``TASK_ENV_MAP``.
    """
    cls = TASK_ENV_MAP.get(task_id)
    if cls is None:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown task_id '{task_id}'. Available: {list(TASK_ENV_MAP)}",
        )
    return cls()


def _get_env(session_id: str) -> TaskEnv:
    """Return the environment registered for ``session_id``.

    Raises:
        HTTPException(400): if no session was created via ``/reset``.
    """
    env = _sessions.get(session_id)
    if env is None:
        raise HTTPException(
            status_code=400,
            detail=f"No active session '{session_id}'. Call /reset first.",
        )
    return env


# ─────────────────────────────────────────────────────────────────────────────
# Request bodies
# ─────────────────────────────────────────────────────────────────────────────

class ResetRequest(BaseModel):
    """Body of ``POST /reset``: which task to start and an optional RNG seed."""

    task_id: str = DEFAULT_TASK_ID
    seed: Optional[int] = None


class StepRequest(BaseModel):
    """Body of ``POST /step``: the action type string and its parameters."""

    action_type: str
    params: dict = Field(default_factory=dict)


# ─────────────────────────────────────────────────────────────────────────────
# Static descriptions
# ─────────────────────────────────────────────────────────────────────────────

# Task catalogue shared by /tasks and the /api summary. Each entry holds the
# keyword arguments for a TaskInfo.
_TASK_CATALOG = [
    {
        "task_id": "task1_vuln_detection",
        "name": "Targeted Vulnerability Detection",
        "difficulty": "medium",
        "description": "Given a Solidity contract, identify the vulnerable function and describe the vulnerability type in 2-3 words.",
        "status": "active",
    },
    {
        "task_id": "task2_property_discovery",
        "name": "Property Discovery",
        "difficulty": "hard",
        "description": "Given a Solidity function, write the natural-language property that describes its correct behaviour.",
        "status": "active",
    },
    {
        "task_id": "task3_rule_checker",
        "name": "Rule Checker",
        "difficulty": "easy",
        "description": "Given a property in English and a Solidity contract, identify which function violates that property.",
        "status": "active",
    },
]

_ROOT_JSON = {
    "name": "Smart Contract Audit RL Environment",
    "version": API_VERSION,
    "description": (
        "OpenEnv-compliant RL environment for Solidity smart contract security analysis. "
        "Train and evaluate agents on real-world DeFi audit tasks from Certora reports."
    ),
    "tasks": [
        {"id": t["task_id"], "name": t["name"], "difficulty": t["difficulty"]}
        for t in _TASK_CATALOG
    ],
    "endpoints": {
        "reset": "POST /reset",
        "step": "POST /step",
        "state": "GET /state",
        "tasks": "GET /tasks",
        "health": "GET /health",
        "action_space": "GET /action_space",
        "observation_space": "GET /observation_space",
        "docs": "GET /docs",
    },
    "data_sources": ["AaveVault", "AaveVaultV2", "Lido Finance"],
}

# Per-task action descriptions served by /action_space.
_ACTION_SPACES: Dict[str, List[dict]] = {
    "task1_vuln_detection": [
        {"type": "list_functions", "params": {}, "reward": -0.05, "description": "List all function names"},
        {"type": "get_function_code", "params": {"function_name": "string"}, "reward": "+0.05 (target) / -0.10 (other)", "description": "Get full Solidity source of a function"},
        {"type": "get_function_summary", "params": {"function_name": "string"}, "reward": "+0.03 (target) / -0.05 (other)", "description": "Get NatSpec comment of a function"},
        {"type": "get_file_metadata", "params": {}, "reward": -0.04, "description": "Get contract-level metadata"},
        {"type": "get_state_variable", "params": {"variable_name": "string (optional)"}, "reward": -0.05, "description": "Get a state variable or list all"},
        {"type": "get_call_graph", "params": {}, "reward": -0.08, "description": "Get function call graph"},
        {"type": "submit", "params": {"function_name": "str", "vulnerability_type": "str"}, "reward": "+5.0 / +1.0 / -1.5", "description": "Submit answer. Ends episode."},
    ],
    "task2_property_discovery": [
        {"type": "get_function_code", "params": {}, "reward": -0.06, "description": "Read full source of the target function"},
        {"type": "get_function_natspec", "params": {}, "reward": -0.08, "description": "Read NatSpec + expected behaviour"},
        {"type": "get_file_natspec", "params": {}, "reward": -0.03, "description": "Read contract-level NatSpec"},
        {"type": "get_related_functions", "params": {}, "reward": -0.06, "description": "List caller/callee functions with summaries"},
        {"type": "get_signature", "params": {}, "reward": -0.04, "description": "Get structured I/O + expected behaviour"},
        {"type": "get_similar_rule", "params": {}, "reward": -0.20, "description": "Get a similar property from another contract"},
        {"type": "submit_property", "params": {"property": "string"}, "reward": "0.0–5.0", "description": "Submit property. ONE attempt. Ends episode."},
    ],
    "task3_rule_checker": [
        {"type": "list_functions", "params": {}, "reward": -0.05, "description": "List all function names"},
        {"type": "get_function_metadata", "params": {"function_name": "string"}, "reward": -0.05, "description": "Get signature, visibility, params of a function"},
        {"type": "get_function_code", "params": {"function_name": "string"}, "reward": -0.10, "description": "Read full Solidity source of a function"},
        {"type": "get_state_variable", "params": {"variable_name": "string (opt)"}, "reward": -0.05, "description": "Get a state variable or list all"},
        {"type": "get_call_graph", "params": {}, "reward": -0.08, "description": "Get function call graph"},
        {"type": "get_property_specification", "params": {}, "reward": -0.03, "description": "Get formal pre/post-condition for the property"},
        {"type": "submit_function", "params": {"function_name": "string"}, "reward": "+5.0 / +1.5 / -1.5", "description": "Submit answer. ONE attempt. Ends episode."},
    ],
}

# Observation schema served by /observation_space (identical for all tasks).
_OBSERVATION_SPACE = {
    "type": "object",
    "fields": {
        "task_id": "string – active task identifier",
        "contract_name": "string – Solidity contract name",
        "contract_description": "string – what the contract does",
        "available_actions": "list[string] – valid action types for this task",
        "last_action": "string|null – previous action type",
        "last_action_result": "string|null – human-readable result of last action",
        "step_count": "int – steps taken in this episode",
        "cumulative_reward": "float – running reward total",
        "done": "bool – True when episode is over",
        "extra": "object – task-specific hints (target_function, hint, etc.)",
    },
}

# ─────────────────────────────────────────────────────────────────────────────
# Routes
# ─────────────────────────────────────────────────────────────────────────────

@app.get("/")
def root(request: Request):
    """Serve the human-readable landing page.

    Also serves as a health check: if ``index.html`` is not present next to
    this file, the JSON API summary is returned instead of failing with 500.
    """
    if _INDEX_HTML.is_file():
        return FileResponse(_INDEX_HTML, media_type="text/html", status_code=200)
    return JSONResponse(content=_ROOT_JSON, status_code=200)


@app.get("/api")
def api_root():
    """Machine-readable API summary."""
    return JSONResponse(content=_ROOT_JSON, status_code=200)


@app.get("/health")
def health():
    """Liveness probe."""
    return {"status": "ok", "version": API_VERSION}


@app.get("/tasks")
def list_tasks():
    """List all tasks with their status."""
    tasks = [TaskInfo(**entry) for entry in _TASK_CATALOG]
    return {"tasks": [t.model_dump() for t in tasks]}


@app.post("/reset")
def reset(
    body: Optional[ResetRequest] = None,
    session_id: str = Query(default=DEFAULT_SESSION),
):
    """Create a fresh environment for the session and start a new episode.

    Raises:
        HTTPException(400): if the requested ``task_id`` is unknown.
    """
    # The OpenEnv validator may POST without a body; use the model defaults.
    req = body if body is not None else ResetRequest()
    env = _create_env(req.task_id)
    result = env.reset(seed=req.seed)
    # Register only after a successful reset so a failing reset does not
    # replace an existing working session with an un-initialised environment.
    _sessions[session_id] = env
    return JSONResponse(content=result.model_dump(), status_code=200)


@app.post("/step")
def step(
    body: StepRequest,
    session_id: str = Query(default=DEFAULT_SESSION),
):
    """Apply one action to the session's environment and advance the episode.

    Unrecognised ``action_type`` strings are mapped to ``ActionType.UNKNOWN``
    and passed on to the environment rather than rejected here.

    Raises:
        HTTPException(400): if ``session_id`` has no active session.
    """
    env = _get_env(session_id)
    # EAFP rather than `value in ActionType`: before Python 3.12 a str
    # membership test against an Enum raises TypeError (-> HTTP 500).
    try:
        action_type = ActionType(body.action_type)
    except ValueError:
        action_type = ActionType.UNKNOWN
    action = Action(action_type=action_type, params=body.params)
    try:
        result = env.step(action)
    except RuntimeError as e:
        # NOTE(review): the environment's RuntimeError message is returned as
        # a bare JSON string with HTTP 200 (kept for client compatibility); a
        # structured error object would be easier for clients to detect.
        return JSONResponse(content=str(e), status_code=200)
    return JSONResponse(content=result.model_dump(), status_code=200)


@app.get("/state")
def state(session_id: str = Query(default=DEFAULT_SESSION)):
    """Return internal state for debugging (not for agents).

    Raises:
        HTTPException(400): if ``session_id`` has no active session.
    """
    env = _get_env(session_id)
    return JSONResponse(content=env.state().model_dump(), status_code=200)


@app.get("/action_space")
def action_space(task_id: str = DEFAULT_TASK_ID):
    """Describe the action space for a task.

    Raises:
        HTTPException(400): if ``task_id`` has no action-space description.
    """
    actions = _ACTION_SPACES.get(task_id)
    if actions is None:
        raise HTTPException(status_code=400, detail=f"Unknown task_id '{task_id}'.")
    return JSONResponse(content={"task_id": task_id, "actions": actions}, status_code=200)


@app.get("/observation_space")
def observation_space():
    """Describe the observation space (same for all tasks)."""
    return JSONResponse(content=_OBSERVATION_SPACE, status_code=200)


# ─────────────────────────────────────────────────────────────────────────────
# Entry point
# ─────────────────────────────────────────────────────────────────────────────

def main(host: str = "0.0.0.0", port: int = 7860) -> None:
    """Run the server with uvicorn (auto-reload disabled)."""
    import uvicorn

    # Pass the app object, not the "app:app" import string. The string only
    # resolves when the working directory contains app.py, and it would import
    # this module a second time (a separate app and session store). An import
    # string is only required when reload/workers are enabled.
    uvicorn.run(app, host=host, port=port, reload=False)


if __name__ == "__main__":
    main()