Spaces:
Running
Running
ajaxwin commited on
Commit Β·
cfd3cfa
1
Parent(s): df6af9d
fix: Update API responses to return JSON format and remove deprecated file references
Browse files- README.md +2 -3
- inference.py +1 -1
- server/app.py +12 -13
README.md
CHANGED
|
@@ -88,7 +88,6 @@ SmartContractEnv/
|
|
| 88 |
βββ .gitignore
|
| 89 |
βββ demo.py
|
| 90 |
βββ Dockerfile
|
| 91 |
-
βββ Docs.md
|
| 92 |
βββ eval.py
|
| 93 |
βββ inference.py
|
| 94 |
βββ LICENSE.txt
|
|
@@ -308,8 +307,8 @@ The `inference.py` script runs an OpenAI-compatible model against all three task
|
|
| 308 |
|
| 309 |
```bash
|
| 310 |
export OPENAI_API_KEY=your_key
|
| 311 |
-
export API_BASE_URL=
|
| 312 |
-
export MODEL_NAME=
|
| 313 |
|
| 314 |
python inference.py
|
| 315 |
```
|
|
|
|
| 88 |
βββ .gitignore
|
| 89 |
βββ demo.py
|
| 90 |
βββ Dockerfile
|
|
|
|
| 91 |
βββ eval.py
|
| 92 |
βββ inference.py
|
| 93 |
βββ LICENSE.txt
|
|
|
|
| 307 |
|
| 308 |
```bash
|
| 309 |
export OPENAI_API_KEY=your_key
|
| 310 |
+
export API_BASE_URL=your custom endpoint
|
| 311 |
+
export MODEL_NAME=your custom model
|
| 312 |
|
| 313 |
python inference.py
|
| 314 |
```
|
inference.py
CHANGED
|
@@ -44,7 +44,7 @@ from dotenv import dotenv_values
|
|
| 44 |
config = dotenv_values(".env")
|
| 45 |
API_BASE_URL = config.get("API_BASE_URL", "https://api.openai.com/v1")
|
| 46 |
MODEL_NAME = config.get("MODEL_NAME", "gpt-4o")
|
| 47 |
-
HF_TOKEN = config.get("HF_TOKEN"
|
| 48 |
|
| 49 |
if not HF_TOKEN:
|
| 50 |
print("[WARN] HF_TOKEN not set β API calls may fail.", file=sys.stderr)
|
|
|
|
| 44 |
config = dotenv_values(".env")
|
| 45 |
API_BASE_URL = config.get("API_BASE_URL", "https://api.openai.com/v1")
|
| 46 |
MODEL_NAME = config.get("MODEL_NAME", "gpt-4o")
|
| 47 |
+
HF_TOKEN = config.get("HF_TOKEN")
|
| 48 |
|
| 49 |
if not HF_TOKEN:
|
| 50 |
print("[WARN] HF_TOKEN not set β API calls may fail.", file=sys.stderr)
|
server/app.py
CHANGED
|
@@ -166,7 +166,7 @@ def reset(
|
|
| 166 |
env = _create_env(body.task_id)
|
| 167 |
_sessions[session_id] = env
|
| 168 |
result = env.reset(seed=body.seed)
|
| 169 |
-
return result.model_dump()
|
| 170 |
|
| 171 |
|
| 172 |
@app.post("/step")
|
|
@@ -193,7 +193,7 @@ def step(
|
|
| 193 |
result = env.step(action)
|
| 194 |
except RuntimeError as e:
|
| 195 |
raise HTTPException(status_code=409, detail=str(e))
|
| 196 |
-
return result.model_dump()
|
| 197 |
|
| 198 |
|
| 199 |
@app.get("/state")
|
|
@@ -205,14 +205,14 @@ def state(session_id: str = Query(default=DEFAULT_SESSION)):
|
|
| 205 |
status_code=400,
|
| 206 |
detail=f"No active session '{session_id}'. Call /reset first.",
|
| 207 |
)
|
| 208 |
-
return env.state().model_dump()
|
| 209 |
|
| 210 |
|
| 211 |
@app.get("/action_space")
|
| 212 |
def action_space(task_id: str = "task1_vuln_detection"):
|
| 213 |
"""Describe the action space for a task."""
|
| 214 |
if task_id == "task1_vuln_detection":
|
| 215 |
-
return {
|
| 216 |
"task_id": task_id,
|
| 217 |
"actions": [
|
| 218 |
{"type": "list_functions", "params": {}, "reward": -0.05, "description": "List all function names"},
|
|
@@ -223,9 +223,9 @@ def action_space(task_id: str = "task1_vuln_detection"):
|
|
| 223 |
{"type": "get_call_graph", "params": {}, "reward": -0.08, "description": "Get function call graph"},
|
| 224 |
{"type": "submit", "params": {"function_name": "str", "vulnerability_type": "str"},"reward": "+5.0 / +1.0 / -1.5", "description": "Submit answer. Ends episode."},
|
| 225 |
],
|
| 226 |
-
}
|
| 227 |
if task_id == "task2_property_discovery":
|
| 228 |
-
return {
|
| 229 |
"task_id": task_id,
|
| 230 |
"actions": [
|
| 231 |
{"type": "get_function_code", "params": {}, "reward": -0.06, "description": "Read full source of the target function"},
|
|
@@ -236,9 +236,9 @@ def action_space(task_id: str = "task1_vuln_detection"):
|
|
| 236 |
{"type": "get_similar_rule", "params": {}, "reward": -0.20, "description": "Get a similar property from another contract"},
|
| 237 |
{"type": "submit_property", "params": {"property": "string"}, "reward": "0.0β5.0", "description": "Submit property. ONE attempt. Ends episode."},
|
| 238 |
],
|
| 239 |
-
}
|
| 240 |
if task_id == "task3_rule_checker":
|
| 241 |
-
return {
|
| 242 |
"task_id": task_id,
|
| 243 |
"actions": [
|
| 244 |
{"type": "list_functions", "params": {}, "reward": -0.05, "description": "List all function names"},
|
|
@@ -249,14 +249,13 @@ def action_space(task_id: str = "task1_vuln_detection"):
|
|
| 249 |
{"type": "get_property_specification", "params": {}, "reward": -0.03, "description": "Get formal pre/post-condition for the property"},
|
| 250 |
{"type": "submit_function", "params": {"function_name": "string"}, "reward": "+5.0 / +1.5 / -1.5", "description": "Submit answer. ONE attempt. Ends episode."},
|
| 251 |
],
|
| 252 |
-
}
|
| 253 |
-
|
| 254 |
-
|
| 255 |
|
| 256 |
@app.get("/observation_space")
|
| 257 |
def observation_space():
|
| 258 |
"""Describe the observation space (same for all tasks)."""
|
| 259 |
-
return {
|
| 260 |
"type": "object",
|
| 261 |
"fields": {
|
| 262 |
"task_id": "string β active task identifier",
|
|
@@ -270,7 +269,7 @@ def observation_space():
|
|
| 270 |
"done": "bool β True when episode is over",
|
| 271 |
"extra": "object β task-specific hints (target_function, hint, etc.)",
|
| 272 |
},
|
| 273 |
-
}
|
| 274 |
|
| 275 |
|
| 276 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 166 |
env = _create_env(body.task_id)
|
| 167 |
_sessions[session_id] = env
|
| 168 |
result = env.reset(seed=body.seed)
|
| 169 |
+
return JSONResponse(content=result.model_dump(), status_code=200)
|
| 170 |
|
| 171 |
|
| 172 |
@app.post("/step")
|
|
|
|
| 193 |
result = env.step(action)
|
| 194 |
except RuntimeError as e:
|
| 195 |
raise HTTPException(status_code=409, detail=str(e))
|
| 196 |
+
return JSONResponse(content=result.model_dump(), status_code=200)
|
| 197 |
|
| 198 |
|
| 199 |
@app.get("/state")
|
|
|
|
| 205 |
status_code=400,
|
| 206 |
detail=f"No active session '{session_id}'. Call /reset first.",
|
| 207 |
)
|
| 208 |
+
return JSONResponse(content=env.state().model_dump(), status_code=200)
|
| 209 |
|
| 210 |
|
| 211 |
@app.get("/action_space")
|
| 212 |
def action_space(task_id: str = "task1_vuln_detection"):
|
| 213 |
"""Describe the action space for a task."""
|
| 214 |
if task_id == "task1_vuln_detection":
|
| 215 |
+
return JSONResponse(content={
|
| 216 |
"task_id": task_id,
|
| 217 |
"actions": [
|
| 218 |
{"type": "list_functions", "params": {}, "reward": -0.05, "description": "List all function names"},
|
|
|
|
| 223 |
{"type": "get_call_graph", "params": {}, "reward": -0.08, "description": "Get function call graph"},
|
| 224 |
{"type": "submit", "params": {"function_name": "str", "vulnerability_type": "str"},"reward": "+5.0 / +1.0 / -1.5", "description": "Submit answer. Ends episode."},
|
| 225 |
],
|
| 226 |
+
}, status_code=200)
|
| 227 |
if task_id == "task2_property_discovery":
|
| 228 |
+
return JSONResponse(content={
|
| 229 |
"task_id": task_id,
|
| 230 |
"actions": [
|
| 231 |
{"type": "get_function_code", "params": {}, "reward": -0.06, "description": "Read full source of the target function"},
|
|
|
|
| 236 |
{"type": "get_similar_rule", "params": {}, "reward": -0.20, "description": "Get a similar property from another contract"},
|
| 237 |
{"type": "submit_property", "params": {"property": "string"}, "reward": "0.0β5.0", "description": "Submit property. ONE attempt. Ends episode."},
|
| 238 |
],
|
| 239 |
+
}, status_code=200)
|
| 240 |
if task_id == "task3_rule_checker":
|
| 241 |
+
return JSONResponse(content={
|
| 242 |
"task_id": task_id,
|
| 243 |
"actions": [
|
| 244 |
{"type": "list_functions", "params": {}, "reward": -0.05, "description": "List all function names"},
|
|
|
|
| 249 |
{"type": "get_property_specification", "params": {}, "reward": -0.03, "description": "Get formal pre/post-condition for the property"},
|
| 250 |
{"type": "submit_function", "params": {"function_name": "string"}, "reward": "+5.0 / +1.5 / -1.5", "description": "Submit answer. ONE attempt. Ends episode."},
|
| 251 |
],
|
| 252 |
+
}, status_code=200)
|
| 253 |
+
raise HTTPException(status_code=400, detail=f"Unknown task_id '{task_id}'.")
|
|
|
|
| 254 |
|
| 255 |
@app.get("/observation_space")
|
| 256 |
def observation_space():
|
| 257 |
"""Describe the observation space (same for all tasks)."""
|
| 258 |
+
return JSONResponse(content={
|
| 259 |
"type": "object",
|
| 260 |
"fields": {
|
| 261 |
"task_id": "string β active task identifier",
|
|
|
|
| 269 |
"done": "bool β True when episode is over",
|
| 270 |
"extra": "object β task-specific hints (target_function, hint, etc.)",
|
| 271 |
},
|
| 272 |
+
}, status_code=200)
|
| 273 |
|
| 274 |
|
| 275 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|