Spaces:
Running
Running
ajaxwin commited on
Commit Β·
5235476
1
Parent(s): 409c8b7
refactor: Update ActionType to include costs and modified grader for task 1
Browse files- env/schemas.py +33 -25
- server/app.py +5 -9
- server/tasks/task1/actions.py +35 -62
- server/tasks/task1/environment.py +18 -28
- server/tasks/task1/grader.py +15 -43
- utils/semanticmatcher.py +2 -14
env/schemas.py
CHANGED
|
@@ -24,28 +24,43 @@ from pydantic import BaseModel, Field
|
|
| 24 |
# ---------------------------------------------------------------------------
|
| 25 |
|
| 26 |
class ActionType(str, Enum):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
# ββ Task 1 β Vulnerability Detection βββββββββββββββββββββββββββββββββββ
|
| 28 |
-
LIST_FUNCTIONS = "list_functions"
|
| 29 |
-
GET_FUNCTION_CODE = "get_function_code"
|
| 30 |
-
GET_FUNCTION_SUMMARY = "get_function_summary"
|
| 31 |
-
GET_FILE_METADATA = "get_file_metadata"
|
| 32 |
-
GET_STATE_VARIABLE = "get_state_variable"
|
| 33 |
-
GET_CALL_GRAPH = "get_call_graph"
|
| 34 |
-
SUBMIT = "submit"
|
| 35 |
|
| 36 |
# ββ Task 2 β Property Discovery βββββββββββββββββββββββββββββββββββββββββ
|
| 37 |
-
GET_SIMILAR_RULE = "get_similar_rule"
|
| 38 |
-
GET_FILE_NATSPEC = "get_file_natspec"
|
| 39 |
-
GET_FUNCTION_NATSPEC = "get_function_natspec"
|
| 40 |
-
GET_RELATED_FUNCTIONS = "get_related_functions"
|
| 41 |
-
GET_SIGNATURE = "get_signature"
|
| 42 |
-
SUBMIT_PROPERTY = "submit_property"
|
| 43 |
|
| 44 |
# ββ Task 3 β Rule Checker ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 45 |
-
GET_PROPERTY_SPECIFICATION = "get_property_specification"
|
| 46 |
-
GET_FUNCTION_METADATA = "get_function_metadata"
|
| 47 |
-
SUBMIT_FUNCTION = "submit_function"
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
class Action(BaseModel):
|
| 51 |
"""
|
|
@@ -73,27 +88,20 @@ class Observation(BaseModel):
|
|
| 73 |
|
| 74 |
task_id : which task is active
|
| 75 |
contract_name : name of the Solidity contract
|
| 76 |
-
contract_description : high-level description of what the contract does
|
| 77 |
available_actions : list of valid ActionType strings
|
| 78 |
last_action : the action that produced this observation (None on reset)
|
| 79 |
last_action_result : human-readable result of the last action
|
| 80 |
-
step_count : number of steps taken so far
|
| 81 |
-
cumulative_reward : running reward total
|
| 82 |
done : whether the episode has ended
|
| 83 |
extra : any additional task-specific context
|
| 84 |
"""
|
| 85 |
task_id: str
|
| 86 |
contract_name: str
|
| 87 |
-
|
| 88 |
-
available_actions: List[str]
|
| 89 |
last_action: Optional[str] = None
|
| 90 |
last_action_result: Optional[str] = None
|
| 91 |
-
step_count: int = 0
|
| 92 |
-
cumulative_reward: float = 0.0
|
| 93 |
done: bool = False
|
| 94 |
extra: Dict[str, Any] = Field(default_factory=dict)
|
| 95 |
|
| 96 |
-
|
| 97 |
# ---------------------------------------------------------------------------
|
| 98 |
# Reward
|
| 99 |
# ---------------------------------------------------------------------------
|
|
|
|
| 24 |
# ---------------------------------------------------------------------------
|
| 25 |
|
| 26 |
class ActionType(str, Enum):
|
| 27 |
+
"""(Action type, cost)"""
|
| 28 |
+
|
| 29 |
+
# Attribute to store the cost of each action
|
| 30 |
+
cost: float
|
| 31 |
+
|
| 32 |
# ββ Task 1 β Vulnerability Detection βββββββββββββββββββββββββββββββββββ
|
| 33 |
+
LIST_FUNCTIONS = ("list_functions", -0.04)
|
| 34 |
+
GET_FUNCTION_CODE = ("get_function_code", -0.14)
|
| 35 |
+
GET_FUNCTION_SUMMARY = ("get_function_summary", -0.07)
|
| 36 |
+
GET_FILE_METADATA = ("get_file_metadata", -0.02)
|
| 37 |
+
GET_STATE_VARIABLE = ("get_state_variable", -0.06)
|
| 38 |
+
GET_CALL_GRAPH = ("get_call_graph", -0.08)
|
| 39 |
+
SUBMIT = ("submit", 0.0)
|
| 40 |
|
| 41 |
# ββ Task 2 β Property Discovery βββββββββββββββββββββββββββββββββββββββββ
|
| 42 |
+
GET_SIMILAR_RULE = ("get_similar_rule", 0.0)
|
| 43 |
+
GET_FILE_NATSPEC = ("get_file_natspec", 0.0)
|
| 44 |
+
GET_FUNCTION_NATSPEC = ("get_function_natspec", 0.0)
|
| 45 |
+
GET_RELATED_FUNCTIONS = ("get_related_functions", 0.0)
|
| 46 |
+
GET_SIGNATURE = ("get_signature", 0.0)
|
| 47 |
+
SUBMIT_PROPERTY = ("submit_property", 0.0)
|
| 48 |
|
| 49 |
# ββ Task 3 β Rule Checker ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 50 |
+
GET_PROPERTY_SPECIFICATION = ("get_property_specification", 0.0)
|
| 51 |
+
GET_FUNCTION_METADATA = ("get_function_metadata", 0.0)
|
| 52 |
+
SUBMIT_FUNCTION = ("submit_function", 0.0)
|
| 53 |
+
|
| 54 |
+
# βββββββ General Actions βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 55 |
+
UNKNOWN = ("unknown", 0.0)
|
| 56 |
+
REPEATED = ("repeated", -0.22)
|
| 57 |
+
RESUBMIT = ("resubmit", 0.0)
|
| 58 |
+
|
| 59 |
+
def __new__(cls, str_value: str, cost: float):
|
| 60 |
+
obj = str.__new__(cls, str_value)
|
| 61 |
+
obj._value_ = str_value
|
| 62 |
+
obj.cost = cost
|
| 63 |
+
return obj
|
| 64 |
|
| 65 |
class Action(BaseModel):
|
| 66 |
"""
|
|
|
|
| 88 |
|
| 89 |
task_id : which task is active
|
| 90 |
contract_name : name of the Solidity contract
|
|
|
|
| 91 |
available_actions : list of valid ActionType strings
|
| 92 |
last_action : the action that produced this observation (None on reset)
|
| 93 |
last_action_result : human-readable result of the last action
|
|
|
|
|
|
|
| 94 |
done : whether the episode has ended
|
| 95 |
extra : any additional task-specific context
|
| 96 |
"""
|
| 97 |
task_id: str
|
| 98 |
contract_name: str
|
| 99 |
+
# available_actions: List[str] # May need it, may not depends on the agent
|
|
|
|
| 100 |
last_action: Optional[str] = None
|
| 101 |
last_action_result: Optional[str] = None
|
|
|
|
|
|
|
| 102 |
done: bool = False
|
| 103 |
extra: Dict[str, Any] = Field(default_factory=dict)
|
| 104 |
|
|
|
|
| 105 |
# ---------------------------------------------------------------------------
|
| 106 |
# Reward
|
| 107 |
# ---------------------------------------------------------------------------
|
server/app.py
CHANGED
|
@@ -190,18 +190,15 @@ def step(
|
|
| 190 |
status_code=400,
|
| 191 |
detail=f"No active session '{session_id}'. Call /reset first.",
|
| 192 |
)
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
status_code=400,
|
| 198 |
-
detail=f"Unknown action_type '{body.action_type}'. Valid: {[a.value for a in ActionType]}",
|
| 199 |
-
)
|
| 200 |
action = Action(action_type=action_type, params=body.params)
|
| 201 |
try:
|
| 202 |
result = env.step(action)
|
| 203 |
except RuntimeError as e:
|
| 204 |
-
|
| 205 |
return JSONResponse(content=result.model_dump(), status_code=200)
|
| 206 |
|
| 207 |
|
|
@@ -216,7 +213,6 @@ def state(session_id: str = Query(default=DEFAULT_SESSION)):
|
|
| 216 |
)
|
| 217 |
return JSONResponse(content=env.state().model_dump(), status_code=200)
|
| 218 |
|
| 219 |
-
|
| 220 |
@app.get("/action_space")
|
| 221 |
def action_space(task_id: str = "task1_vuln_detection"):
|
| 222 |
"""Describe the action space for a task."""
|
|
|
|
| 190 |
status_code=400,
|
| 191 |
detail=f"No active session '{session_id}'. Call /reset first.",
|
| 192 |
)
|
| 193 |
+
|
| 194 |
+
# removed error handling here
|
| 195 |
+
action_type = ActionType(body.action_type) if body.action_type in ActionType else ActionType.UNKNOWN
|
| 196 |
+
|
|
|
|
|
|
|
|
|
|
| 197 |
action = Action(action_type=action_type, params=body.params)
|
| 198 |
try:
|
| 199 |
result = env.step(action)
|
| 200 |
except RuntimeError as e:
|
| 201 |
+
return JSONResponse(content=str(e), status_code = 200)
|
| 202 |
return JSONResponse(content=result.model_dump(), status_code=200)
|
| 203 |
|
| 204 |
|
|
|
|
| 213 |
)
|
| 214 |
return JSONResponse(content=env.state().model_dump(), status_code=200)
|
| 215 |
|
|
|
|
| 216 |
@app.get("/action_space")
|
| 217 |
def action_space(task_id: str = "task1_vuln_detection"):
|
| 218 |
"""Describe the action space for a task."""
|
server/tasks/task1/actions.py
CHANGED
|
@@ -1,13 +1,8 @@
|
|
| 1 |
"""Actions for Task 1: Targeted Vulnerability Detection.
|
| 2 |
-
Actions & rewards:
|
| 3 |
-
list_functions -0.05 (broad overview of contract)
|
| 4 |
-
get_function_code -0.10 (wrong function) / +0.05 (correct function)
|
| 5 |
-
get_function_summary -0.05 (wrong function) / +0.03 (correct function)
|
| 6 |
-
get_file_metadata -0.04 (general contract info)
|
| 7 |
"""
|
| 8 |
|
| 9 |
from typing import Any, Dict, Tuple
|
| 10 |
-
from env.schemas import Reward
|
| 11 |
from data.data_loader import (
|
| 12 |
list_function_names,
|
| 13 |
get_function_by_name,
|
|
@@ -19,11 +14,11 @@ from data.data_loader import (
|
|
| 19 |
def list_functions(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
|
| 20 |
"""Handle LIST_FUNCTIONS action."""
|
| 21 |
if ctx._is_repeated(qkey):
|
| 22 |
-
return "Repeated query.", Reward(value=
|
| 23 |
names = list_function_names(ctx._contract)
|
| 24 |
return (
|
| 25 |
f"Functions in {ctx._contract['contract_name']}: {', '.join(names)}",
|
| 26 |
-
Reward(value=
|
| 27 |
)
|
| 28 |
|
| 29 |
|
|
@@ -31,20 +26,19 @@ def get_function_code(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
|
|
| 31 |
"""Handle GET_FUNCTION_CODE action."""
|
| 32 |
fn_name = params.get("function_name", "")
|
| 33 |
if ctx._is_repeated(qkey):
|
| 34 |
-
return "Repeated query.", Reward(value=
|
|
|
|
| 35 |
fn = get_function_by_name(ctx._contract, fn_name)
|
| 36 |
if fn is None:
|
| 37 |
return (
|
| 38 |
f"Function '{fn_name}' not found. Available: {list_function_names(ctx._contract)}",
|
| 39 |
-
Reward(value=
|
| 40 |
)
|
| 41 |
-
|
| 42 |
code = fn.get("code", "// no code available")
|
| 43 |
-
reward_val = 0.05 if is_target else -0.10
|
| 44 |
-
reason = "Fetched target function code (+)" if is_target else "Fetched non-target function (-)"
|
| 45 |
return (
|
| 46 |
f"// {fn['name']}\n{code}",
|
| 47 |
-
Reward(value=
|
| 48 |
)
|
| 49 |
|
| 50 |
|
|
@@ -52,71 +46,74 @@ def get_function_summary(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward
|
|
| 52 |
"""Handle GET_FUNCTION_SUMMARY action."""
|
| 53 |
fn_name = params.get("function_name", "")
|
| 54 |
if ctx._is_repeated(qkey):
|
| 55 |
-
return "Repeated query.", Reward(value=
|
|
|
|
| 56 |
fn = get_function_by_name(ctx._contract, fn_name)
|
| 57 |
if fn is None:
|
| 58 |
return (
|
| 59 |
f"Function '{fn_name}' not found.",
|
| 60 |
-
Reward(value=
|
| 61 |
)
|
| 62 |
-
|
| 63 |
comment = fn.get("comment", "No summary available.")
|
| 64 |
-
reward_val = 0.03 if is_target else -0.05
|
| 65 |
-
reason = "Fetched target function summary (+)" if is_target else "Fetched non-target summary (-)"
|
| 66 |
return (
|
| 67 |
f"Summary of '{fn['name']}': {comment}",
|
| 68 |
-
Reward(value=
|
| 69 |
)
|
| 70 |
|
| 71 |
|
| 72 |
def get_file_metadata(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
|
| 73 |
"""Handle GET_FILE_METADATA action."""
|
| 74 |
if ctx._is_repeated(qkey):
|
| 75 |
-
return "Repeated query.", Reward(value=
|
|
|
|
| 76 |
meta = ctx._contract.get("metadata", {})
|
| 77 |
result = (
|
| 78 |
f"Contract: {ctx._contract['contract_name']} | "
|
| 79 |
f"Solidity: {meta.get('solidity_version', 'N/A')} | "
|
| 80 |
f"Description: {meta.get('description', 'N/A')}"
|
| 81 |
)
|
| 82 |
-
return result, Reward(value=
|
| 83 |
|
| 84 |
|
| 85 |
def get_state_variable(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
|
| 86 |
"""Handle GET_STATE_VARIABLE action."""
|
| 87 |
var_name = params.get("variable_name", "")
|
| 88 |
if ctx._is_repeated(qkey):
|
| 89 |
-
return "Repeated query.", Reward(value=
|
|
|
|
| 90 |
if not var_name:
|
| 91 |
names = list_state_variable_names(ctx._contract)
|
| 92 |
return (
|
| 93 |
f"State variables: {', '.join(names)}",
|
| 94 |
-
Reward(value=
|
| 95 |
)
|
|
|
|
| 96 |
sv = get_state_variable_by_name(ctx._contract, var_name)
|
| 97 |
if sv is None:
|
| 98 |
return (
|
| 99 |
f"Variable '{var_name}' not found.",
|
| 100 |
-
Reward(value=
|
| 101 |
)
|
|
|
|
| 102 |
return (
|
| 103 |
f"{sv['type']} {sv['visibility']} {sv['name']}: {sv.get('description', '')}",
|
| 104 |
-
Reward(value=
|
| 105 |
)
|
| 106 |
|
| 107 |
|
| 108 |
def get_call_graph(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
|
| 109 |
"""Handle GET_CALL_GRAPH action."""
|
| 110 |
if ctx._is_repeated(qkey):
|
| 111 |
-
return "Repeated query.", Reward(value=
|
|
|
|
| 112 |
cg = ctx._contract.get("call_graph", {})
|
| 113 |
cg_str = "; ".join(f"{fn} β [{', '.join(callees)}]" for fn, callees in cg.items())
|
| 114 |
return (
|
| 115 |
f"Call graph: {cg_str}",
|
| 116 |
-
Reward(value=
|
| 117 |
)
|
| 118 |
|
| 119 |
-
|
| 120 |
def submit(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
|
| 121 |
"""Handle SUBMIT action for Task 1.
|
| 122 |
|
|
@@ -127,14 +124,15 @@ def submit(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
|
|
| 127 |
"""
|
| 128 |
if ctx._done:
|
| 129 |
return (
|
| 130 |
-
"β You have already submitted for this episode. "
|
| 131 |
"Only ONE submission is allowed.",
|
| 132 |
-
Reward(value=
|
|
|
|
|
|
|
| 133 |
)
|
| 134 |
|
| 135 |
fn_name = params.get("function_name", "").strip()
|
| 136 |
vuln_type = params.get("vulnerability_type", "").strip()
|
| 137 |
-
|
| 138 |
if not fn_name or not vuln_type:
|
| 139 |
return (
|
| 140 |
"submit_function requires both 'function_name' and "
|
|
@@ -142,35 +140,10 @@ def submit(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
|
|
| 142 |
Reward(value=0.0, reason="Malformed submission", partial=False),
|
| 143 |
)
|
| 144 |
|
| 145 |
-
ctx._done
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
correct = ctx._grader.get_canonical_answer()
|
| 150 |
-
|
| 151 |
-
if score == 1.0:
|
| 152 |
-
msg = (
|
| 153 |
-
f"β
CORRECT! '{fn_name}' is the vulnerable function "
|
| 154 |
-
f"and the vulnerability type matches. "
|
| 155 |
-
f"Score: 1.0 β Reward: {reward_val:.3f}"
|
| 156 |
-
)
|
| 157 |
-
elif score == 0.5:
|
| 158 |
-
msg = (
|
| 159 |
-
f"π‘ PARTIAL. '{fn_name}' is the correct function but the "
|
| 160 |
-
f"vulnerability type was not recognised. "
|
| 161 |
-
f"Score: 0.5 β Reward: {reward_val:.3f}. "
|
| 162 |
-
f"Expected vulnerability: '{correct['vulnerability']}'."
|
| 163 |
-
)
|
| 164 |
-
else:
|
| 165 |
-
msg = (
|
| 166 |
-
f"β INCORRECT. '{fn_name}' is not the target function. "
|
| 167 |
-
f"Score: 0.0 β Reward: {reward_val:.3f}. "
|
| 168 |
-
f"Correct answer: function='{correct['function']}', "
|
| 169 |
-
f"vulnerability='{correct['vulnerability']}'."
|
| 170 |
-
)
|
| 171 |
-
|
| 172 |
-
return msg, Reward(
|
| 173 |
-
value=reward_val,
|
| 174 |
reason=f"submit_function score={score:.1f}",
|
| 175 |
partial=False,
|
| 176 |
)
|
|
@@ -180,5 +153,5 @@ def unknown_action(ctx: Any, qkey: str, params: Dict, action_type: str) -> Tuple
|
|
| 180 |
"""Fallback for unknown actions."""
|
| 181 |
return (
|
| 182 |
f"Unknown action type: {action_type}",
|
| 183 |
-
Reward(value=
|
| 184 |
)
|
|
|
|
| 1 |
"""Actions for Task 1: Targeted Vulnerability Detection.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
"""
|
| 3 |
|
| 4 |
from typing import Any, Dict, Tuple
|
| 5 |
+
from env.schemas import ActionType, Reward
|
| 6 |
from data.data_loader import (
|
| 7 |
list_function_names,
|
| 8 |
get_function_by_name,
|
|
|
|
| 14 |
def list_functions(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
|
| 15 |
"""Handle LIST_FUNCTIONS action."""
|
| 16 |
if ctx._is_repeated(qkey):
|
| 17 |
+
return "Repeated query.", Reward(value=ActionType.REPEATED.cost, reason="Repeated query", partial=True)
|
| 18 |
names = list_function_names(ctx._contract)
|
| 19 |
return (
|
| 20 |
f"Functions in {ctx._contract['contract_name']}: {', '.join(names)}",
|
| 21 |
+
Reward(value=ActionType.LIST_FUNCTIONS.cost, reason="list_functions cost", partial=True),
|
| 22 |
)
|
| 23 |
|
| 24 |
|
|
|
|
| 26 |
"""Handle GET_FUNCTION_CODE action."""
|
| 27 |
fn_name = params.get("function_name", "")
|
| 28 |
if ctx._is_repeated(qkey):
|
| 29 |
+
return "Repeated query.", Reward(value=ActionType.REPEATED.cost, reason="Repeated query", partial=True)
|
| 30 |
+
|
| 31 |
fn = get_function_by_name(ctx._contract, fn_name)
|
| 32 |
if fn is None:
|
| 33 |
return (
|
| 34 |
f"Function '{fn_name}' not found. Available: {list_function_names(ctx._contract)}",
|
| 35 |
+
Reward(value=ActionType.GET_FUNCTION_CODE.cost, reason="Wrong/unknown function name", partial=True),
|
| 36 |
)
|
| 37 |
+
|
| 38 |
code = fn.get("code", "// no code available")
|
|
|
|
|
|
|
| 39 |
return (
|
| 40 |
f"// {fn['name']}\n{code}",
|
| 41 |
+
Reward(value=ActionType.GET_FUNCTION_CODE.cost, reason="Fetched code", partial=True),
|
| 42 |
)
|
| 43 |
|
| 44 |
|
|
|
|
| 46 |
"""Handle GET_FUNCTION_SUMMARY action."""
|
| 47 |
fn_name = params.get("function_name", "")
|
| 48 |
if ctx._is_repeated(qkey):
|
| 49 |
+
return "Repeated query.", Reward(value=ActionType.REPEATED.cost, reason="Repeated query", partial=True)
|
| 50 |
+
|
| 51 |
fn = get_function_by_name(ctx._contract, fn_name)
|
| 52 |
if fn is None:
|
| 53 |
return (
|
| 54 |
f"Function '{fn_name}' not found.",
|
| 55 |
+
Reward(value=ActionType.GET_FUNCTION_SUMMARY.cost, reason="Wrong function name", partial=True),
|
| 56 |
)
|
| 57 |
+
|
| 58 |
comment = fn.get("comment", "No summary available.")
|
|
|
|
|
|
|
| 59 |
return (
|
| 60 |
f"Summary of '{fn['name']}': {comment}",
|
| 61 |
+
Reward(value=ActionType.GET_FUNCTION_SUMMARY.cost, reason="Fetched summary", partial=True),
|
| 62 |
)
|
| 63 |
|
| 64 |
|
| 65 |
def get_file_metadata(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
|
| 66 |
"""Handle GET_FILE_METADATA action."""
|
| 67 |
if ctx._is_repeated(qkey):
|
| 68 |
+
return "Repeated query.", Reward(value=ActionType.REPEATED.cost, reason="Repeated query", partial=True)
|
| 69 |
+
|
| 70 |
meta = ctx._contract.get("metadata", {})
|
| 71 |
result = (
|
| 72 |
f"Contract: {ctx._contract['contract_name']} | "
|
| 73 |
f"Solidity: {meta.get('solidity_version', 'N/A')} | "
|
| 74 |
f"Description: {meta.get('description', 'N/A')}"
|
| 75 |
)
|
| 76 |
+
return result, Reward(value=ActionType.GET_FILE_METADATA.cost, reason="get_file_metadata cost", partial=True)
|
| 77 |
|
| 78 |
|
| 79 |
def get_state_variable(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
|
| 80 |
"""Handle GET_STATE_VARIABLE action."""
|
| 81 |
var_name = params.get("variable_name", "")
|
| 82 |
if ctx._is_repeated(qkey):
|
| 83 |
+
return "Repeated query.", Reward(value=ActionType.REPEATED.cost, reason="Repeated query", partial=True)
|
| 84 |
+
|
| 85 |
if not var_name:
|
| 86 |
names = list_state_variable_names(ctx._contract)
|
| 87 |
return (
|
| 88 |
f"State variables: {', '.join(names)}",
|
| 89 |
+
Reward(value=ActionType.GET_STATE_VARIABLE.cost, reason="Listed state variables", partial=True),
|
| 90 |
)
|
| 91 |
+
|
| 92 |
sv = get_state_variable_by_name(ctx._contract, var_name)
|
| 93 |
if sv is None:
|
| 94 |
return (
|
| 95 |
f"Variable '{var_name}' not found.",
|
| 96 |
+
Reward(value=ActionType.GET_STATE_VARIABLE.cost, reason="Unknown state variable", partial=True),
|
| 97 |
)
|
| 98 |
+
|
| 99 |
return (
|
| 100 |
f"{sv['type']} {sv['visibility']} {sv['name']}: {sv.get('description', '')}",
|
| 101 |
+
Reward(value=ActionType.GET_STATE_VARIABLE.cost, reason="get_state_variable cost", partial=True),
|
| 102 |
)
|
| 103 |
|
| 104 |
|
| 105 |
def get_call_graph(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
|
| 106 |
"""Handle GET_CALL_GRAPH action."""
|
| 107 |
if ctx._is_repeated(qkey):
|
| 108 |
+
return "Repeated query.", Reward(value=ActionType.REPEATED.cost, reason="Repeated query", partial=True)
|
| 109 |
+
|
| 110 |
cg = ctx._contract.get("call_graph", {})
|
| 111 |
cg_str = "; ".join(f"{fn} β [{', '.join(callees)}]" for fn, callees in cg.items())
|
| 112 |
return (
|
| 113 |
f"Call graph: {cg_str}",
|
| 114 |
+
Reward(value=ActionType.GET_CALL_GRAPH.cost, reason="get_call_graph cost", partial=True),
|
| 115 |
)
|
| 116 |
|
|
|
|
| 117 |
def submit(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
|
| 118 |
"""Handle SUBMIT action for Task 1.
|
| 119 |
|
|
|
|
| 124 |
"""
|
| 125 |
if ctx._done:
|
| 126 |
return (
|
|
|
|
| 127 |
"Only ONE submission is allowed.",
|
| 128 |
+
Reward(value=ActionType.RESUBMIT.cost,
|
| 129 |
+
reason="Second submit_function attempt",
|
| 130 |
+
partial=False),
|
| 131 |
)
|
| 132 |
|
| 133 |
fn_name = params.get("function_name", "").strip()
|
| 134 |
vuln_type = params.get("vulnerability_type", "").strip()
|
| 135 |
+
|
| 136 |
if not fn_name or not vuln_type:
|
| 137 |
return (
|
| 138 |
"submit_function requires both 'function_name' and "
|
|
|
|
| 140 |
Reward(value=0.0, reason="Malformed submission", partial=False),
|
| 141 |
)
|
| 142 |
|
| 143 |
+
ctx._done = True
|
| 144 |
+
score = ctx._grader.grade(fn_name, vuln_type, ctx._step_count, ctx._cummulative_cost)
|
| 145 |
+
return (f"Correct Answer: {ctx._grader.get_canonical_answer}"), Reward(
|
| 146 |
+
value=score,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
reason=f"submit_function score={score:.1f}",
|
| 148 |
partial=False,
|
| 149 |
)
|
|
|
|
| 153 |
"""Fallback for unknown actions."""
|
| 154 |
return (
|
| 155 |
f"Unknown action type: {action_type}",
|
| 156 |
+
Reward(value=ActionType.UNKNOWN.cost, reason="Unknown action", partial=True),
|
| 157 |
)
|
server/tasks/task1/environment.py
CHANGED
|
@@ -9,21 +9,11 @@ Episode flow:
|
|
| 9 |
3. The agent uses actions to explore the contract (each costs a small penalty).
|
| 10 |
4. When the agent submits, the Grader scores the answer and the episode ends.
|
| 11 |
|
| 12 |
-
Reward shaping:
|
| 13 |
-
list_functions : -0.05
|
| 14 |
-
get_function_code : -0.10 (wrong function) / +0.05 (correct function)
|
| 15 |
-
get_function_summary : -0.05 (wrong function) / +0.03 (correct function)
|
| 16 |
-
get_file_metadata : -0.04
|
| 17 |
-
get_state_variable : -0.05
|
| 18 |
-
get_call_graph : -0.08
|
| 19 |
-
correct submit (score=1.0) : +5.0
|
| 20 |
-
partially correct submit (score=0.5) : +1.0
|
| 21 |
-
wrong submit (score=0.0) : -1.5
|
| 22 |
-
repeated query : -0.40
|
| 23 |
"""
|
| 24 |
|
| 25 |
from __future__ import annotations
|
| 26 |
|
|
|
|
| 27 |
import random
|
| 28 |
from typing import Any, Dict, List, Optional, Set
|
| 29 |
|
|
@@ -60,16 +50,19 @@ class Task1Environment(BaseEnv):
|
|
| 60 |
def __init__(self, contracts_path: Optional[str] = None) -> None:
|
| 61 |
self._contracts = load_contracts(contracts_path) if contracts_path else load_contracts()
|
| 62 |
self._rng = random.Random()
|
|
|
|
| 63 |
|
| 64 |
# Episode state (initialised by reset)
|
| 65 |
self._contract: Dict[str, Any] = {}
|
| 66 |
self._target_fn: Dict[str, Any] = {}
|
| 67 |
self._grader: Optional[Task1Grader] = None
|
| 68 |
self._step_count: int = 0
|
| 69 |
-
self.
|
| 70 |
self._done: bool = False
|
| 71 |
self._query_history: List[str] = []
|
| 72 |
self._seen_queries: Set[str] = set()
|
|
|
|
|
|
|
| 73 |
|
| 74 |
# ------------------------------------------------------------------
|
| 75 |
# OpenEnv interface
|
|
@@ -84,9 +77,10 @@ class Task1Environment(BaseEnv):
|
|
| 84 |
self._grader = Task1Grader(
|
| 85 |
target_function=self._target_fn["name"],
|
| 86 |
vulnerability_issue=self._target_fn["vulnerability_details"]["issue"],
|
|
|
|
| 87 |
)
|
| 88 |
self._step_count = 0
|
| 89 |
-
self.
|
| 90 |
self._done = False
|
| 91 |
self._query_history = []
|
| 92 |
self._seen_queries = set()
|
|
@@ -102,28 +96,28 @@ class Task1Environment(BaseEnv):
|
|
| 102 |
|
| 103 |
def step(self, action: Action) -> StepResult:
|
| 104 |
"""Execute one agent action."""
|
|
|
|
| 105 |
if self._done:
|
| 106 |
raise RuntimeError("Episode is done. Call reset() to start a new episode.")
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
result_text, reward = self._dispatch(action)
|
| 112 |
-
|
| 113 |
-
self.
|
| 114 |
-
self._query_history.append(f"[{action.action_type}] β {result_text[:120]}")
|
| 115 |
-
|
| 116 |
obs = self._build_observation(
|
| 117 |
last_action=action.action_type,
|
| 118 |
last_result=result_text,
|
| 119 |
)
|
|
|
|
| 120 |
return StepResult(
|
| 121 |
observation=obs,
|
| 122 |
reward=reward,
|
| 123 |
done=self._done,
|
| 124 |
info={
|
| 125 |
"step": self._step_count,
|
| 126 |
-
"cumulative_reward": self.
|
| 127 |
},
|
| 128 |
)
|
| 129 |
|
|
@@ -133,7 +127,7 @@ class Task1Environment(BaseEnv):
|
|
| 133 |
contract_name=self._contract.get("contract_name", ""),
|
| 134 |
target_function=self._target_fn.get("name", ""),
|
| 135 |
step_count=self._step_count,
|
| 136 |
-
cumulative_reward=self.
|
| 137 |
done=self._done,
|
| 138 |
query_history=list(self._query_history),
|
| 139 |
)
|
|
@@ -150,12 +144,8 @@ class Task1Environment(BaseEnv):
|
|
| 150 |
return Observation(
|
| 151 |
task_id=TASK_ID,
|
| 152 |
contract_name=self._contract.get("contract_name", ""),
|
| 153 |
-
contract_description=self._contract.get("metadata", {}).get("description", ""),
|
| 154 |
-
available_actions=[a.value for a in AVAILABLE_ACTIONS],
|
| 155 |
last_action=last_action,
|
| 156 |
last_action_result=last_result,
|
| 157 |
-
step_count=self._step_count,
|
| 158 |
-
cumulative_reward=self._cumulative_reward,
|
| 159 |
done=self._done,
|
| 160 |
extra={
|
| 161 |
"solidity_version": self._contract.get("metadata", {}).get("solidity_version", ""),
|
|
@@ -181,7 +171,7 @@ class Task1Environment(BaseEnv):
|
|
| 181 |
at = action.action_type
|
| 182 |
params = action.params
|
| 183 |
qkey = self._query_key(at, params)
|
| 184 |
-
|
| 185 |
# Mapping from ActionType to handler function
|
| 186 |
handlers = {
|
| 187 |
ActionType.LIST_FUNCTIONS: actions.list_functions,
|
|
|
|
| 9 |
3. The agent uses actions to explore the contract (each costs a small penalty).
|
| 10 |
4. When the agent submits, the Grader scores the answer and the episode ends.
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
"""
|
| 13 |
|
| 14 |
from __future__ import annotations
|
| 15 |
|
| 16 |
+
from math import floor, log2
|
| 17 |
import random
|
| 18 |
from typing import Any, Dict, List, Optional, Set
|
| 19 |
|
|
|
|
| 50 |
def __init__(self, contracts_path: Optional[str] = None) -> None:
|
| 51 |
self._contracts = load_contracts(contracts_path) if contracts_path else load_contracts()
|
| 52 |
self._rng = random.Random()
|
| 53 |
+
self._max_steps: int = 0
|
| 54 |
|
| 55 |
# Episode state (initialised by reset)
|
| 56 |
self._contract: Dict[str, Any] = {}
|
| 57 |
self._target_fn: Dict[str, Any] = {}
|
| 58 |
self._grader: Optional[Task1Grader] = None
|
| 59 |
self._step_count: int = 0
|
| 60 |
+
self._cummulative_cost: float = 0.0
|
| 61 |
self._done: bool = False
|
| 62 |
self._query_history: List[str] = []
|
| 63 |
self._seen_queries: Set[str] = set()
|
| 64 |
+
self._cost_free_steps: int = 0
|
| 65 |
+
self._decay: float = 0.0
|
| 66 |
|
| 67 |
# ------------------------------------------------------------------
|
| 68 |
# OpenEnv interface
|
|
|
|
| 77 |
self._grader = Task1Grader(
|
| 78 |
target_function=self._target_fn["name"],
|
| 79 |
vulnerability_issue=self._target_fn["vulnerability_details"]["issue"],
|
| 80 |
+
n = floor(log2(len(self._contract["functions"])))
|
| 81 |
)
|
| 82 |
self._step_count = 0
|
| 83 |
+
self._cummulative_cost = 0.0
|
| 84 |
self._done = False
|
| 85 |
self._query_history = []
|
| 86 |
self._seen_queries = set()
|
|
|
|
| 96 |
|
| 97 |
def step(self, action: Action) -> StepResult:
|
| 98 |
"""Execute one agent action."""
|
| 99 |
+
|
| 100 |
if self._done:
|
| 101 |
raise RuntimeError("Episode is done. Call reset() to start a new episode.")
|
| 102 |
+
if self._step_count > self._max_steps:
|
| 103 |
+
raise RuntimeError("Exceeded maximum number of steps allowed. Call reset() to start a new episode.")
|
| 104 |
+
|
| 105 |
+
self._step_count += 1
|
| 106 |
result_text, reward = self._dispatch(action)
|
| 107 |
+
self._cummulative_cost += reward.value
|
| 108 |
+
self._query_history.append(f"[{action.action_type}] β {result_text[:200]}")
|
|
|
|
|
|
|
| 109 |
obs = self._build_observation(
|
| 110 |
last_action=action.action_type,
|
| 111 |
last_result=result_text,
|
| 112 |
)
|
| 113 |
+
|
| 114 |
return StepResult(
|
| 115 |
observation=obs,
|
| 116 |
reward=reward,
|
| 117 |
done=self._done,
|
| 118 |
info={
|
| 119 |
"step": self._step_count,
|
| 120 |
+
"cumulative_reward": self._cummulative_cost,
|
| 121 |
},
|
| 122 |
)
|
| 123 |
|
|
|
|
| 127 |
contract_name=self._contract.get("contract_name", ""),
|
| 128 |
target_function=self._target_fn.get("name", ""),
|
| 129 |
step_count=self._step_count,
|
| 130 |
+
cumulative_reward=self._cummulative_cost,
|
| 131 |
done=self._done,
|
| 132 |
query_history=list(self._query_history),
|
| 133 |
)
|
|
|
|
| 144 |
return Observation(
|
| 145 |
task_id=TASK_ID,
|
| 146 |
contract_name=self._contract.get("contract_name", ""),
|
|
|
|
|
|
|
| 147 |
last_action=last_action,
|
| 148 |
last_action_result=last_result,
|
|
|
|
|
|
|
| 149 |
done=self._done,
|
| 150 |
extra={
|
| 151 |
"solidity_version": self._contract.get("metadata", {}).get("solidity_version", ""),
|
|
|
|
| 171 |
at = action.action_type
|
| 172 |
params = action.params
|
| 173 |
qkey = self._query_key(at, params)
|
| 174 |
+
|
| 175 |
# Mapping from ActionType to handler function
|
| 176 |
handlers = {
|
| 177 |
ActionType.LIST_FUNCTIONS: actions.list_functions,
|
server/tasks/task1/grader.py
CHANGED
|
@@ -1,58 +1,30 @@
|
|
| 1 |
"""
|
| 2 |
grader.py (Task 1 β Targeted Vulnerability Detection)
|
| 3 |
-------------------------------------------------------
|
| 4 |
-
Deterministic grader. Grade range: 0
|
| 5 |
-
|
| 6 |
-
1.0 β correct function + correct vulnerability keyword
|
| 7 |
-
0.5 β correct function + wrong/unrecognised vulnerability keyword
|
| 8 |
-
0.0 β wrong function name
|
| 9 |
-
|
| 10 |
-
reward_for_score() normalises the raw RL reward to [0.0, 1.0]
|
| 11 |
-
using the fixed reward bounds [MIN_REWARD=-1.5, MAX_REWARD=5.0]:
|
| 12 |
-
normalised = (raw + 1.5) / 6.5
|
| 13 |
"""
|
|
|
|
| 14 |
from __future__ import annotations
|
| 15 |
from typing import Dict
|
| 16 |
from utils import SemanticMatcher
|
| 17 |
|
| 18 |
-
# Raw reward bounds β used only for normalisation
|
| 19 |
-
_MIN_REWARD = -1.5
|
| 20 |
-
_MAX_REWARD = 5.0
|
| 21 |
-
_REWARD_RANGE = _MAX_REWARD - _MIN_REWARD # 6.5
|
| 22 |
-
|
| 23 |
-
_SCORE_MIN = 0.001 # grades are strictly (0, 1)
|
| 24 |
-
_SCORE_MAX = 0.999
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
def _clamp(v: float) -> float:
|
| 28 |
-
return max(_SCORE_MIN, min(_SCORE_MAX, v))
|
| 29 |
-
|
| 30 |
-
|
| 31 |
class Task1Grader:
|
| 32 |
-
def __init__(self, target_function: str, vulnerability_issue: str) -> None:
|
| 33 |
self.target_function = target_function.lower()
|
| 34 |
self.vulnerability_issue = vulnerability_issue
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
-
def
|
| 37 |
"""Returns grade strictly in (0, 1)."""
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
Raw rewards: correct=+5.0, partial=+1.0, wrong=-1.5
|
| 47 |
-
Normalised: (raw + 1.5) / 6.5 then clamped to (0.001, 0.999)
|
| 48 |
-
"""
|
| 49 |
-
if score >= _SCORE_MAX:
|
| 50 |
-
raw = 5.0
|
| 51 |
-
elif score >= 0.5:
|
| 52 |
-
raw = 1.0
|
| 53 |
-
else:
|
| 54 |
-
raw = -1.5
|
| 55 |
-
return _clamp((raw - _MIN_REWARD) / _REWARD_RANGE)
|
| 56 |
-
|
| 57 |
def get_canonical_answer(self) -> Dict[str, str]:
|
| 58 |
return {"function": self.target_function, "vulnerability": self.vulnerability_issue}
|
|
|
|
| 1 |
"""
|
| 2 |
grader.py (Task 1 β Targeted Vulnerability Detection)
|
| 3 |
-------------------------------------------------------
|
| 4 |
+
Deterministic grader. Grade range: (0, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
"""
|
| 6 |
+
|
| 7 |
from __future__ import annotations
|
| 8 |
from typing import Dict
|
| 9 |
from utils import SemanticMatcher
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
class Task1Grader:
|
| 12 |
+
def __init__(self, target_function: str, vulnerability_issue: str, n: int) -> None:
|
| 13 |
self.target_function = target_function.lower()
|
| 14 |
self.vulnerability_issue = vulnerability_issue
|
| 15 |
+
|
| 16 |
+
# Log of No. of functions (n) is a heurisitic used to decided the size of contract code
|
| 17 |
+
self.n = n
|
| 18 |
+
self._decay = 0.75
|
| 19 |
|
| 20 |
+
def grade(self, submitted_function: str, submitted_vuln_type: str, steps: int, cummulative_cost: int) -> float:
|
| 21 |
"""Returns grade strictly in (0, 1)."""
|
| 22 |
+
func_match = submitted_function.strip().lower() != self.target_function
|
| 23 |
+
issue_match = SemanticMatcher().match(self.vulnerability_issue, submitted_vuln_type)
|
| 24 |
+
|
| 25 |
+
# Score formula
|
| 26 |
+
free_budget = (cummulative_cost / steps) * (self.n + 2)
|
| 27 |
+
return func_match * issue_match * (self._decay ** max(0, cummulative_cost - free_budget))
|
| 28 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
def get_canonical_answer(self) -> Dict[str, str]:
|
| 30 |
return {"function": self.target_function, "vulnerability": self.vulnerability_issue}
|
utils/semanticmatcher.py
CHANGED
|
@@ -142,18 +142,6 @@ def cosine_similarity(vec_a: np.ndarray, vec_b: np.ndarray) -> float:
|
|
| 142 |
return 0.0
|
| 143 |
return float(np.dot(vec_a, vec_b) / (norm_a * norm_b))
|
| 144 |
|
| 145 |
-
|
| 146 |
-
# ββ Score clamping βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 147 |
-
|
| 148 |
-
_SCORE_MIN = 0.001 # scores are strictly (0, 1) β never touch 0 or 1
|
| 149 |
-
_SCORE_MAX = 0.999
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
def _clamp(score: float) -> float:
|
| 153 |
-
"""Clamp score to the open interval (0, 1): [_SCORE_MIN, _SCORE_MAX]."""
|
| 154 |
-
return max(_SCORE_MIN, min(_SCORE_MAX, score))
|
| 155 |
-
|
| 156 |
-
|
| 157 |
# ββ Core matcher ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 158 |
|
| 159 |
class SemanticMatcher:
|
|
@@ -212,7 +200,7 @@ class SemanticMatcher:
|
|
| 212 |
# Fast-path: normalized exact match
|
| 213 |
if normalize(text_a) == normalize(text_b):
|
| 214 |
self.confidence_level = "strong"
|
| 215 |
-
return
|
| 216 |
|
| 217 |
tokens_a = tokenize_and_lemmatize(text_a)
|
| 218 |
tokens_b = tokenize_and_lemmatize(text_b)
|
|
@@ -230,7 +218,7 @@ class SemanticMatcher:
|
|
| 230 |
self.confidence_level = "moderate"
|
| 231 |
else:
|
| 232 |
self.confidence_level = "no_match"
|
| 233 |
-
return
|
| 234 |
|
| 235 |
def match(self, text_a: str, text_b: str) -> bool:
|
| 236 |
"""Return True if the two texts are considered a match based on the score."""
|
|
|
|
| 142 |
return 0.0
|
| 143 |
return float(np.dot(vec_a, vec_b) / (norm_a * norm_b))
|
| 144 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
# ββ Core matcher ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 146 |
|
| 147 |
class SemanticMatcher:
|
|
|
|
| 200 |
# Fast-path: normalized exact match
|
| 201 |
if normalize(text_a) == normalize(text_b):
|
| 202 |
self.confidence_level = "strong"
|
| 203 |
+
return 1.0
|
| 204 |
|
| 205 |
tokens_a = tokenize_and_lemmatize(text_a)
|
| 206 |
tokens_b = tokenize_and_lemmatize(text_b)
|
|
|
|
| 218 |
self.confidence_level = "moderate"
|
| 219 |
else:
|
| 220 |
self.confidence_level = "no_match"
|
| 221 |
+
return score
|
| 222 |
|
| 223 |
def match(self, text_a: str, text_b: str) -> bool:
|
| 224 |
"""Return True if the two texts are considered a match based on the score."""
|