ajaxwin committed on
Commit
5235476
·
1 Parent(s): 409c8b7

refactor: Update ActionType to include costs and modified grader for task 1

Browse files
env/schemas.py CHANGED
@@ -24,28 +24,43 @@ from pydantic import BaseModel, Field
24
  # ---------------------------------------------------------------------------
25
 
26
  class ActionType(str, Enum):
 
 
 
 
 
27
  # ── Task 1 – Vulnerability Detection ───────────────────────────────────
28
- LIST_FUNCTIONS = "list_functions"
29
- GET_FUNCTION_CODE = "get_function_code"
30
- GET_FUNCTION_SUMMARY = "get_function_summary"
31
- GET_FILE_METADATA = "get_file_metadata"
32
- GET_STATE_VARIABLE = "get_state_variable"
33
- GET_CALL_GRAPH = "get_call_graph"
34
- SUBMIT = "submit"
35
 
36
  # ── Task 2 – Property Discovery ─────────────────────────────────────────
37
- GET_SIMILAR_RULE = "get_similar_rule" # -0.20
38
- GET_FILE_NATSPEC = "get_file_natspec" # -0.03
39
- GET_FUNCTION_NATSPEC = "get_function_natspec" # -0.08
40
- GET_RELATED_FUNCTIONS = "get_related_functions" # -0.06
41
- GET_SIGNATURE = "get_signature" # -0.04
42
- SUBMIT_PROPERTY = "submit_property" # scored 0–5, one attempt
43
 
44
  # ── Task 3 – Rule Checker ────────────────────────────────────────────────
45
- GET_PROPERTY_SPECIFICATION = "get_property_specification" # -0.03
46
- GET_FUNCTION_METADATA = "get_function_metadata" # -0.05
47
- SUBMIT_FUNCTION = "submit_function" # +5.0 / +1.5 / -1.5, one attempt
48
-
 
 
 
 
 
 
 
 
 
 
49
 
50
  class Action(BaseModel):
51
  """
@@ -73,27 +88,20 @@ class Observation(BaseModel):
73
 
74
  task_id : which task is active
75
  contract_name : name of the Solidity contract
76
- contract_description : high-level description of what the contract does
77
  available_actions : list of valid ActionType strings
78
  last_action : the action that produced this observation (None on reset)
79
  last_action_result : human-readable result of the last action
80
- step_count : number of steps taken so far
81
- cumulative_reward : running reward total
82
  done : whether the episode has ended
83
  extra : any additional task-specific context
84
  """
85
  task_id: str
86
  contract_name: str
87
- contract_description: str
88
- available_actions: List[str]
89
  last_action: Optional[str] = None
90
  last_action_result: Optional[str] = None
91
- step_count: int = 0
92
- cumulative_reward: float = 0.0
93
  done: bool = False
94
  extra: Dict[str, Any] = Field(default_factory=dict)
95
 
96
-
97
  # ---------------------------------------------------------------------------
98
  # Reward
99
  # ---------------------------------------------------------------------------
 
24
  # ---------------------------------------------------------------------------
25
 
26
  class ActionType(str, Enum):
27
+ """(Action type, cost)"""
28
+
29
+ # Attribute to store the cost of each action
30
+ cost: float
31
+
32
  # ── Task 1 – Vulnerability Detection ───────────────────────────────────
33
+ LIST_FUNCTIONS = ("list_functions", -0.04)
34
+ GET_FUNCTION_CODE = ("get_function_code", -0.14)
35
+ GET_FUNCTION_SUMMARY = ("get_function_summary", -0.07)
36
+ GET_FILE_METADATA = ("get_file_metadata", -0.02)
37
+ GET_STATE_VARIABLE = ("get_state_variable", -0.06)
38
+ GET_CALL_GRAPH = ("get_call_graph", -0.08)
39
+ SUBMIT = ("submit", 0.0)
40
 
41
  # ── Task 2 – Property Discovery ─────────────────────────────────────────
42
+ GET_SIMILAR_RULE = ("get_similar_rule", 0.0)
43
+ GET_FILE_NATSPEC = ("get_file_natspec", 0.0)
44
+ GET_FUNCTION_NATSPEC = ("get_function_natspec", 0.0)
45
+ GET_RELATED_FUNCTIONS = ("get_related_functions", 0.0)
46
+ GET_SIGNATURE = ("get_signature", 0.0)
47
+ SUBMIT_PROPERTY = ("submit_property", 0.0)
48
 
49
  # ── Task 3 – Rule Checker ────────────────────────────────────────────────
50
+ GET_PROPERTY_SPECIFICATION = ("get_property_specification", 0.0)
51
+ GET_FUNCTION_METADATA = ("get_function_metadata", 0.0)
52
+ SUBMIT_FUNCTION = ("submit_function", 0.0)
53
+
54
+ # ─────── General Actions ─────────────────────────────────────────────────
55
+ UNKNOWN = ("unknown", 0.0)
56
+ REPEATED = ("repeated", -0.22)
57
+ RESUBMIT = ("resubmit", 0.0)
58
+
59
+ def __new__(cls, str_value: str, cost: float):
60
+ obj = str.__new__(cls, str_value)
61
+ obj._value_ = str_value
62
+ obj.cost = cost
63
+ return obj
64
 
65
  class Action(BaseModel):
66
  """
 
88
 
89
  task_id : which task is active
90
  contract_name : name of the Solidity contract
 
91
  available_actions : list of valid ActionType strings
92
  last_action : the action that produced this observation (None on reset)
93
  last_action_result : human-readable result of the last action
 
 
94
  done : whether the episode has ended
95
  extra : any additional task-specific context
96
  """
97
  task_id: str
98
  contract_name: str
99
+ # available_actions: List[str] # May need it, may not depends on the agent
 
100
  last_action: Optional[str] = None
101
  last_action_result: Optional[str] = None
 
 
102
  done: bool = False
103
  extra: Dict[str, Any] = Field(default_factory=dict)
104
 
 
105
  # ---------------------------------------------------------------------------
106
  # Reward
107
  # ---------------------------------------------------------------------------
server/app.py CHANGED
@@ -190,18 +190,15 @@ def step(
190
  status_code=400,
191
  detail=f"No active session '{session_id}'. Call /reset first.",
192
  )
193
- try:
194
- action_type = ActionType(body.action_type)
195
- except ValueError:
196
- raise HTTPException(
197
- status_code=400,
198
- detail=f"Unknown action_type '{body.action_type}'. Valid: {[a.value for a in ActionType]}",
199
- )
200
  action = Action(action_type=action_type, params=body.params)
201
  try:
202
  result = env.step(action)
203
  except RuntimeError as e:
204
- raise HTTPException(status_code=409, detail=str(e))
205
  return JSONResponse(content=result.model_dump(), status_code=200)
206
 
207
 
@@ -216,7 +213,6 @@ def state(session_id: str = Query(default=DEFAULT_SESSION)):
216
  )
217
  return JSONResponse(content=env.state().model_dump(), status_code=200)
218
 
219
-
220
  @app.get("/action_space")
221
  def action_space(task_id: str = "task1_vuln_detection"):
222
  """Describe the action space for a task."""
 
190
  status_code=400,
191
  detail=f"No active session '{session_id}'. Call /reset first.",
192
  )
193
+
194
+ # removed error handling here
195
+ action_type = ActionType(body.action_type) if body.action_type in ActionType else ActionType.UNKNOWN
196
+
 
 
 
197
  action = Action(action_type=action_type, params=body.params)
198
  try:
199
  result = env.step(action)
200
  except RuntimeError as e:
201
+ return JSONResponse(content=str(e), status_code = 200)
202
  return JSONResponse(content=result.model_dump(), status_code=200)
203
 
204
 
 
213
  )
214
  return JSONResponse(content=env.state().model_dump(), status_code=200)
215
 
 
216
  @app.get("/action_space")
217
  def action_space(task_id: str = "task1_vuln_detection"):
218
  """Describe the action space for a task."""
server/tasks/task1/actions.py CHANGED
@@ -1,13 +1,8 @@
1
  """Actions for Task 1: Targeted Vulnerability Detection.
2
- Actions & rewards:
3
- list_functions -0.05 (broad overview of contract)
4
- get_function_code -0.10 (wrong function) / +0.05 (correct function)
5
- get_function_summary -0.05 (wrong function) / +0.03 (correct function)
6
- get_file_metadata -0.04 (general contract info)
7
  """
8
 
9
  from typing import Any, Dict, Tuple
10
- from env.schemas import Reward
11
  from data.data_loader import (
12
  list_function_names,
13
  get_function_by_name,
@@ -19,11 +14,11 @@ from data.data_loader import (
19
  def list_functions(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
20
  """Handle LIST_FUNCTIONS action."""
21
  if ctx._is_repeated(qkey):
22
- return "Repeated query.", Reward(value=-0.40, reason="Repeated query", partial=True)
23
  names = list_function_names(ctx._contract)
24
  return (
25
  f"Functions in {ctx._contract['contract_name']}: {', '.join(names)}",
26
- Reward(value=-0.05, reason="list_functions cost", partial=True),
27
  )
28
 
29
 
@@ -31,20 +26,19 @@ def get_function_code(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
31
  """Handle GET_FUNCTION_CODE action."""
32
  fn_name = params.get("function_name", "")
33
  if ctx._is_repeated(qkey):
34
- return "Repeated query.", Reward(value=-0.40, reason="Repeated query", partial=True)
 
35
  fn = get_function_by_name(ctx._contract, fn_name)
36
  if fn is None:
37
  return (
38
  f"Function '{fn_name}' not found. Available: {list_function_names(ctx._contract)}",
39
- Reward(value=-0.10, reason="Wrong/unknown function name", partial=True),
40
  )
41
- is_target = fn["name"].lower() == ctx._target_fn["name"].lower()
42
  code = fn.get("code", "// no code available")
43
- reward_val = 0.05 if is_target else -0.10
44
- reason = "Fetched target function code (+)" if is_target else "Fetched non-target function (-)"
45
  return (
46
  f"// {fn['name']}\n{code}",
47
- Reward(value=reward_val, reason=reason, partial=True),
48
  )
49
 
50
 
@@ -52,71 +46,74 @@ def get_function_summary(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward
52
  """Handle GET_FUNCTION_SUMMARY action."""
53
  fn_name = params.get("function_name", "")
54
  if ctx._is_repeated(qkey):
55
- return "Repeated query.", Reward(value=-0.40, reason="Repeated query", partial=True)
 
56
  fn = get_function_by_name(ctx._contract, fn_name)
57
  if fn is None:
58
  return (
59
  f"Function '{fn_name}' not found.",
60
- Reward(value=-0.05, reason="Wrong function name", partial=True),
61
  )
62
- is_target = fn["name"].lower() == ctx._target_fn["name"].lower()
63
  comment = fn.get("comment", "No summary available.")
64
- reward_val = 0.03 if is_target else -0.05
65
- reason = "Fetched target function summary (+)" if is_target else "Fetched non-target summary (-)"
66
  return (
67
  f"Summary of '{fn['name']}': {comment}",
68
- Reward(value=reward_val, reason=reason, partial=True),
69
  )
70
 
71
 
72
  def get_file_metadata(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
73
  """Handle GET_FILE_METADATA action."""
74
  if ctx._is_repeated(qkey):
75
- return "Repeated query.", Reward(value=-0.40, reason="Repeated query", partial=True)
 
76
  meta = ctx._contract.get("metadata", {})
77
  result = (
78
  f"Contract: {ctx._contract['contract_name']} | "
79
  f"Solidity: {meta.get('solidity_version', 'N/A')} | "
80
  f"Description: {meta.get('description', 'N/A')}"
81
  )
82
- return result, Reward(value=-0.04, reason="get_file_metadata cost", partial=True)
83
 
84
 
85
  def get_state_variable(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
86
  """Handle GET_STATE_VARIABLE action."""
87
  var_name = params.get("variable_name", "")
88
  if ctx._is_repeated(qkey):
89
- return "Repeated query.", Reward(value=-0.40, reason="Repeated query", partial=True)
 
90
  if not var_name:
91
  names = list_state_variable_names(ctx._contract)
92
  return (
93
  f"State variables: {', '.join(names)}",
94
- Reward(value=-0.05, reason="Listed state variables", partial=True),
95
  )
 
96
  sv = get_state_variable_by_name(ctx._contract, var_name)
97
  if sv is None:
98
  return (
99
  f"Variable '{var_name}' not found.",
100
- Reward(value=-0.05, reason="Unknown state variable", partial=True),
101
  )
 
102
  return (
103
  f"{sv['type']} {sv['visibility']} {sv['name']}: {sv.get('description', '')}",
104
- Reward(value=-0.05, reason="get_state_variable cost", partial=True),
105
  )
106
 
107
 
108
  def get_call_graph(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
109
  """Handle GET_CALL_GRAPH action."""
110
  if ctx._is_repeated(qkey):
111
- return "Repeated query.", Reward(value=-0.40, reason="Repeated query", partial=True)
 
112
  cg = ctx._contract.get("call_graph", {})
113
  cg_str = "; ".join(f"{fn} β†’ [{', '.join(callees)}]" for fn, callees in cg.items())
114
  return (
115
  f"Call graph: {cg_str}",
116
- Reward(value=-0.08, reason="get_call_graph cost", partial=True),
117
  )
118
 
119
-
120
  def submit(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
121
  """Handle SUBMIT action for Task 1.
122
 
@@ -127,14 +124,15 @@ def submit(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
127
  """
128
  if ctx._done:
129
  return (
130
- "❌ You have already submitted for this episode. "
131
  "Only ONE submission is allowed.",
132
- Reward(value=0.0, reason="Second submit_function attempt", partial=False),
 
 
133
  )
134
 
135
  fn_name = params.get("function_name", "").strip()
136
  vuln_type = params.get("vulnerability_type", "").strip()
137
-
138
  if not fn_name or not vuln_type:
139
  return (
140
  "submit_function requires both 'function_name' and "
@@ -142,35 +140,10 @@ def submit(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
142
  Reward(value=0.0, reason="Malformed submission", partial=False),
143
  )
144
 
145
- ctx._done = True
146
-
147
- score = ctx._grader.grade_submission(fn_name, vuln_type) # {0.0, 0.5, 1.0}
148
- reward_val = ctx._grader.reward_for_score(score) # [0.0, 1.0]
149
- correct = ctx._grader.get_canonical_answer()
150
-
151
- if score == 1.0:
152
- msg = (
153
- f"βœ… CORRECT! '{fn_name}' is the vulnerable function "
154
- f"and the vulnerability type matches. "
155
- f"Score: 1.0 β†’ Reward: {reward_val:.3f}"
156
- )
157
- elif score == 0.5:
158
- msg = (
159
- f"🟑 PARTIAL. '{fn_name}' is the correct function but the "
160
- f"vulnerability type was not recognised. "
161
- f"Score: 0.5 β†’ Reward: {reward_val:.3f}. "
162
- f"Expected vulnerability: '{correct['vulnerability']}'."
163
- )
164
- else:
165
- msg = (
166
- f"❌ INCORRECT. '{fn_name}' is not the target function. "
167
- f"Score: 0.0 β†’ Reward: {reward_val:.3f}. "
168
- f"Correct answer: function='{correct['function']}', "
169
- f"vulnerability='{correct['vulnerability']}'."
170
- )
171
-
172
- return msg, Reward(
173
- value=reward_val,
174
  reason=f"submit_function score={score:.1f}",
175
  partial=False,
176
  )
@@ -180,5 +153,5 @@ def unknown_action(ctx: Any, qkey: str, params: Dict, action_type: str) -> Tuple
180
  """Fallback for unknown actions."""
181
  return (
182
  f"Unknown action type: {action_type}",
183
- Reward(value=-0.10, reason="Unknown action", partial=True),
184
  )
 
1
  """Actions for Task 1: Targeted Vulnerability Detection.
 
 
 
 
 
2
  """
3
 
4
  from typing import Any, Dict, Tuple
5
+ from env.schemas import ActionType, Reward
6
  from data.data_loader import (
7
  list_function_names,
8
  get_function_by_name,
 
14
  def list_functions(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
15
  """Handle LIST_FUNCTIONS action."""
16
  if ctx._is_repeated(qkey):
17
+ return "Repeated query.", Reward(value=ActionType.REPEATED.cost, reason="Repeated query", partial=True)
18
  names = list_function_names(ctx._contract)
19
  return (
20
  f"Functions in {ctx._contract['contract_name']}: {', '.join(names)}",
21
+ Reward(value=ActionType.LIST_FUNCTIONS.cost, reason="list_functions cost", partial=True),
22
  )
23
 
24
 
 
26
  """Handle GET_FUNCTION_CODE action."""
27
  fn_name = params.get("function_name", "")
28
  if ctx._is_repeated(qkey):
29
+ return "Repeated query.", Reward(value=ActionType.REPEATED.cost, reason="Repeated query", partial=True)
30
+
31
  fn = get_function_by_name(ctx._contract, fn_name)
32
  if fn is None:
33
  return (
34
  f"Function '{fn_name}' not found. Available: {list_function_names(ctx._contract)}",
35
+ Reward(value=ActionType.GET_FUNCTION_CODE.cost, reason="Wrong/unknown function name", partial=True),
36
  )
37
+
38
  code = fn.get("code", "// no code available")
 
 
39
  return (
40
  f"// {fn['name']}\n{code}",
41
+ Reward(value=ActionType.GET_FUNCTION_CODE.cost, reason="Fetched code", partial=True),
42
  )
43
 
44
 
 
46
  """Handle GET_FUNCTION_SUMMARY action."""
47
  fn_name = params.get("function_name", "")
48
  if ctx._is_repeated(qkey):
49
+ return "Repeated query.", Reward(value=ActionType.REPEATED.cost, reason="Repeated query", partial=True)
50
+
51
  fn = get_function_by_name(ctx._contract, fn_name)
52
  if fn is None:
53
  return (
54
  f"Function '{fn_name}' not found.",
55
+ Reward(value=ActionType.GET_FUNCTION_SUMMARY.cost, reason="Wrong function name", partial=True),
56
  )
57
+
58
  comment = fn.get("comment", "No summary available.")
 
 
59
  return (
60
  f"Summary of '{fn['name']}': {comment}",
61
+ Reward(value=ActionType.GET_FUNCTION_SUMMARY.cost, reason="Fetched summary", partial=True),
62
  )
63
 
64
 
65
  def get_file_metadata(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
66
  """Handle GET_FILE_METADATA action."""
67
  if ctx._is_repeated(qkey):
68
+ return "Repeated query.", Reward(value=ActionType.REPEATED.cost, reason="Repeated query", partial=True)
69
+
70
  meta = ctx._contract.get("metadata", {})
71
  result = (
72
  f"Contract: {ctx._contract['contract_name']} | "
73
  f"Solidity: {meta.get('solidity_version', 'N/A')} | "
74
  f"Description: {meta.get('description', 'N/A')}"
75
  )
76
+ return result, Reward(value=ActionType.GET_FILE_METADATA.cost, reason="get_file_metadata cost", partial=True)
77
 
78
 
79
  def get_state_variable(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
80
  """Handle GET_STATE_VARIABLE action."""
81
  var_name = params.get("variable_name", "")
82
  if ctx._is_repeated(qkey):
83
+ return "Repeated query.", Reward(value=ActionType.REPEATED.cost, reason="Repeated query", partial=True)
84
+
85
  if not var_name:
86
  names = list_state_variable_names(ctx._contract)
87
  return (
88
  f"State variables: {', '.join(names)}",
89
+ Reward(value=ActionType.GET_STATE_VARIABLE.cost, reason="Listed state variables", partial=True),
90
  )
91
+
92
  sv = get_state_variable_by_name(ctx._contract, var_name)
93
  if sv is None:
94
  return (
95
  f"Variable '{var_name}' not found.",
96
+ Reward(value=ActionType.GET_STATE_VARIABLE.cost, reason="Unknown state variable", partial=True),
97
  )
98
+
99
  return (
100
  f"{sv['type']} {sv['visibility']} {sv['name']}: {sv.get('description', '')}",
101
+ Reward(value=ActionType.GET_STATE_VARIABLE.cost, reason="get_state_variable cost", partial=True),
102
  )
103
 
104
 
105
  def get_call_graph(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
106
  """Handle GET_CALL_GRAPH action."""
107
  if ctx._is_repeated(qkey):
108
+ return "Repeated query.", Reward(value=ActionType.REPEATED.cost, reason="Repeated query", partial=True)
109
+
110
  cg = ctx._contract.get("call_graph", {})
111
  cg_str = "; ".join(f"{fn} β†’ [{', '.join(callees)}]" for fn, callees in cg.items())
112
  return (
113
  f"Call graph: {cg_str}",
114
+ Reward(value=ActionType.GET_CALL_GRAPH.cost, reason="get_call_graph cost", partial=True),
115
  )
116
 
 
117
  def submit(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
118
  """Handle SUBMIT action for Task 1.
119
 
 
124
  """
125
  if ctx._done:
126
  return (
 
127
  "Only ONE submission is allowed.",
128
+ Reward(value=ActionType.RESUBMIT.cost,
129
+ reason="Second submit_function attempt",
130
+ partial=False),
131
  )
132
 
133
  fn_name = params.get("function_name", "").strip()
134
  vuln_type = params.get("vulnerability_type", "").strip()
135
+
136
  if not fn_name or not vuln_type:
137
  return (
138
  "submit_function requires both 'function_name' and "
 
140
  Reward(value=0.0, reason="Malformed submission", partial=False),
141
  )
142
 
143
+ ctx._done = True
144
+ score = ctx._grader.grade(fn_name, vuln_type, ctx._step_count, ctx._cummulative_cost)
145
+ return (f"Correct Answer: {ctx._grader.get_canonical_answer}"), Reward(
146
+ value=score,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  reason=f"submit_function score={score:.1f}",
148
  partial=False,
149
  )
 
153
  """Fallback for unknown actions."""
154
  return (
155
  f"Unknown action type: {action_type}",
156
+ Reward(value=ActionType.UNKNOWN.cost, reason="Unknown action", partial=True),
157
  )
server/tasks/task1/environment.py CHANGED
@@ -9,21 +9,11 @@ Episode flow:
9
  3. The agent uses actions to explore the contract (each costs a small penalty).
10
  4. When the agent submits, the Grader scores the answer and the episode ends.
11
 
12
- Reward shaping:
13
- list_functions : -0.05
14
- get_function_code : -0.10 (wrong function) / +0.05 (correct function)
15
- get_function_summary : -0.05 (wrong function) / +0.03 (correct function)
16
- get_file_metadata : -0.04
17
- get_state_variable : -0.05
18
- get_call_graph : -0.08
19
- correct submit (score=1.0) : +5.0
20
- partially correct submit (score=0.5) : +1.0
21
- wrong submit (score=0.0) : -1.5
22
- repeated query : -0.40
23
  """
24
 
25
  from __future__ import annotations
26
 
 
27
  import random
28
  from typing import Any, Dict, List, Optional, Set
29
 
@@ -60,16 +50,19 @@ class Task1Environment(BaseEnv):
60
  def __init__(self, contracts_path: Optional[str] = None) -> None:
61
  self._contracts = load_contracts(contracts_path) if contracts_path else load_contracts()
62
  self._rng = random.Random()
 
63
 
64
  # Episode state (initialised by reset)
65
  self._contract: Dict[str, Any] = {}
66
  self._target_fn: Dict[str, Any] = {}
67
  self._grader: Optional[Task1Grader] = None
68
  self._step_count: int = 0
69
- self._cumulative_reward: float = 0.0
70
  self._done: bool = False
71
  self._query_history: List[str] = []
72
  self._seen_queries: Set[str] = set()
 
 
73
 
74
  # ------------------------------------------------------------------
75
  # OpenEnv interface
@@ -84,9 +77,10 @@ class Task1Environment(BaseEnv):
84
  self._grader = Task1Grader(
85
  target_function=self._target_fn["name"],
86
  vulnerability_issue=self._target_fn["vulnerability_details"]["issue"],
 
87
  )
88
  self._step_count = 0
89
- self._cumulative_reward = 0.0
90
  self._done = False
91
  self._query_history = []
92
  self._seen_queries = set()
@@ -102,28 +96,28 @@ class Task1Environment(BaseEnv):
102
 
103
  def step(self, action: Action) -> StepResult:
104
  """Execute one agent action."""
 
105
  if self._done:
106
  raise RuntimeError("Episode is done. Call reset() to start a new episode.")
107
-
108
- self._step_count += 1
109
-
110
- # Dispatch
111
  result_text, reward = self._dispatch(action)
112
-
113
- self._cumulative_reward += reward.value
114
- self._query_history.append(f"[{action.action_type}] β†’ {result_text[:120]}")
115
-
116
  obs = self._build_observation(
117
  last_action=action.action_type,
118
  last_result=result_text,
119
  )
 
120
  return StepResult(
121
  observation=obs,
122
  reward=reward,
123
  done=self._done,
124
  info={
125
  "step": self._step_count,
126
- "cumulative_reward": self._cumulative_reward,
127
  },
128
  )
129
 
@@ -133,7 +127,7 @@ class Task1Environment(BaseEnv):
133
  contract_name=self._contract.get("contract_name", ""),
134
  target_function=self._target_fn.get("name", ""),
135
  step_count=self._step_count,
136
- cumulative_reward=self._cumulative_reward,
137
  done=self._done,
138
  query_history=list(self._query_history),
139
  )
@@ -150,12 +144,8 @@ class Task1Environment(BaseEnv):
150
  return Observation(
151
  task_id=TASK_ID,
152
  contract_name=self._contract.get("contract_name", ""),
153
- contract_description=self._contract.get("metadata", {}).get("description", ""),
154
- available_actions=[a.value for a in AVAILABLE_ACTIONS],
155
  last_action=last_action,
156
  last_action_result=last_result,
157
- step_count=self._step_count,
158
- cumulative_reward=self._cumulative_reward,
159
  done=self._done,
160
  extra={
161
  "solidity_version": self._contract.get("metadata", {}).get("solidity_version", ""),
@@ -181,7 +171,7 @@ class Task1Environment(BaseEnv):
181
  at = action.action_type
182
  params = action.params
183
  qkey = self._query_key(at, params)
184
-
185
  # Mapping from ActionType to handler function
186
  handlers = {
187
  ActionType.LIST_FUNCTIONS: actions.list_functions,
 
9
  3. The agent uses actions to explore the contract (each costs a small penalty).
10
  4. When the agent submits, the Grader scores the answer and the episode ends.
11
 
 
 
 
 
 
 
 
 
 
 
 
12
  """
13
 
14
  from __future__ import annotations
15
 
16
+ from math import floor, log2
17
  import random
18
  from typing import Any, Dict, List, Optional, Set
19
 
 
50
  def __init__(self, contracts_path: Optional[str] = None) -> None:
51
  self._contracts = load_contracts(contracts_path) if contracts_path else load_contracts()
52
  self._rng = random.Random()
53
+ self._max_steps: int = 0
54
 
55
  # Episode state (initialised by reset)
56
  self._contract: Dict[str, Any] = {}
57
  self._target_fn: Dict[str, Any] = {}
58
  self._grader: Optional[Task1Grader] = None
59
  self._step_count: int = 0
60
+ self._cummulative_cost: float = 0.0
61
  self._done: bool = False
62
  self._query_history: List[str] = []
63
  self._seen_queries: Set[str] = set()
64
+ self._cost_free_steps: int = 0
65
+ self._decay: float = 0.0
66
 
67
  # ------------------------------------------------------------------
68
  # OpenEnv interface
 
77
  self._grader = Task1Grader(
78
  target_function=self._target_fn["name"],
79
  vulnerability_issue=self._target_fn["vulnerability_details"]["issue"],
80
+ n = floor(log2(len(self._contract["functions"])))
81
  )
82
  self._step_count = 0
83
+ self._cummulative_cost = 0.0
84
  self._done = False
85
  self._query_history = []
86
  self._seen_queries = set()
 
96
 
97
  def step(self, action: Action) -> StepResult:
98
  """Execute one agent action."""
99
+
100
  if self._done:
101
  raise RuntimeError("Episode is done. Call reset() to start a new episode.")
102
+ if self._step_count > self._max_steps:
103
+ raise RuntimeError("Exceeded maximum number of steps allowed. Call reset() to start a new episode.")
104
+
105
+ self._step_count += 1
106
  result_text, reward = self._dispatch(action)
107
+ self._cummulative_cost += reward.value
108
+ self._query_history.append(f"[{action.action_type}] β†’ {result_text[:200]}")
 
 
109
  obs = self._build_observation(
110
  last_action=action.action_type,
111
  last_result=result_text,
112
  )
113
+
114
  return StepResult(
115
  observation=obs,
116
  reward=reward,
117
  done=self._done,
118
  info={
119
  "step": self._step_count,
120
+ "cumulative_reward": self._cummulative_cost,
121
  },
122
  )
123
 
 
127
  contract_name=self._contract.get("contract_name", ""),
128
  target_function=self._target_fn.get("name", ""),
129
  step_count=self._step_count,
130
+ cumulative_reward=self._cummulative_cost,
131
  done=self._done,
132
  query_history=list(self._query_history),
133
  )
 
144
  return Observation(
145
  task_id=TASK_ID,
146
  contract_name=self._contract.get("contract_name", ""),
 
 
147
  last_action=last_action,
148
  last_action_result=last_result,
 
 
149
  done=self._done,
150
  extra={
151
  "solidity_version": self._contract.get("metadata", {}).get("solidity_version", ""),
 
171
  at = action.action_type
172
  params = action.params
173
  qkey = self._query_key(at, params)
174
+
175
  # Mapping from ActionType to handler function
176
  handlers = {
177
  ActionType.LIST_FUNCTIONS: actions.list_functions,
server/tasks/task1/grader.py CHANGED
@@ -1,58 +1,30 @@
1
  """
2
  grader.py (Task 1 – Targeted Vulnerability Detection)
3
  -------------------------------------------------------
4
- Deterministic grader. Grade range: 0.0 – 1.0
5
-
6
- 1.0 – correct function + correct vulnerability keyword
7
- 0.5 – correct function + wrong/unrecognised vulnerability keyword
8
- 0.0 – wrong function name
9
-
10
- reward_for_score() normalises the raw RL reward to [0.0, 1.0]
11
- using the fixed reward bounds [MIN_REWARD=-1.5, MAX_REWARD=5.0]:
12
- normalised = (raw + 1.5) / 6.5
13
  """
 
14
  from __future__ import annotations
15
  from typing import Dict
16
  from utils import SemanticMatcher
17
 
18
- # Raw reward bounds β€” used only for normalisation
19
- _MIN_REWARD = -1.5
20
- _MAX_REWARD = 5.0
21
- _REWARD_RANGE = _MAX_REWARD - _MIN_REWARD # 6.5
22
-
23
- _SCORE_MIN = 0.001 # grades are strictly (0, 1)
24
- _SCORE_MAX = 0.999
25
-
26
-
27
- def _clamp(v: float) -> float:
28
- return max(_SCORE_MIN, min(_SCORE_MAX, v))
29
-
30
-
31
  class Task1Grader:
32
- def __init__(self, target_function: str, vulnerability_issue: str) -> None:
33
  self.target_function = target_function.lower()
34
  self.vulnerability_issue = vulnerability_issue
 
 
 
 
35
 
36
- def grade_submission(self, submitted_function: str, submitted_vuln_type: str) -> float:
37
  """Returns grade strictly in (0, 1)."""
38
- if submitted_function.strip().lower() != self.target_function:
39
- return _clamp(0.0) # β†’ 0.001
40
- return _clamp(1.0) if SemanticMatcher().match(self.vulnerability_issue, submitted_vuln_type) else _clamp(0.5)
41
-
42
- def reward_for_score(self, score: float) -> float:
43
- """
44
- Maps grade score β†’ normalised reward strictly in (0, 1).
45
-
46
- Raw rewards: correct=+5.0, partial=+1.0, wrong=-1.5
47
- Normalised: (raw + 1.5) / 6.5 then clamped to (0.001, 0.999)
48
- """
49
- if score >= _SCORE_MAX:
50
- raw = 5.0
51
- elif score >= 0.5:
52
- raw = 1.0
53
- else:
54
- raw = -1.5
55
- return _clamp((raw - _MIN_REWARD) / _REWARD_RANGE)
56
-
57
  def get_canonical_answer(self) -> Dict[str, str]:
58
  return {"function": self.target_function, "vulnerability": self.vulnerability_issue}
 
1
  """
2
  grader.py (Task 1 – Targeted Vulnerability Detection)
3
  -------------------------------------------------------
4
+ Deterministic grader. Grade range: [0, 1] (0 for a wrong answer; otherwise a cost-decayed score of at most 1)
 
 
 
 
 
 
 
 
5
  """
6
+
7
  from __future__ import annotations
8
  from typing import Dict
9
  from utils import SemanticMatcher
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  class Task1Grader:
12
+ def __init__(self, target_function: str, vulnerability_issue: str, n: int) -> None:
13
  self.target_function = target_function.lower()
14
  self.vulnerability_issue = vulnerability_issue
15
+
16
+ # Log of No. of functions (n) is a heurisitic used to decided the size of contract code
17
+ self.n = n
18
+ self._decay = 0.75
19
 
20
+ def grade(self, submitted_function: str, submitted_vuln_type: str, steps: int, cummulative_cost: int) -> float:
21
  """Returns grade strictly in (0, 1)."""
22
+ func_match = submitted_function.strip().lower() != self.target_function
23
+ issue_match = SemanticMatcher().match(self.vulnerability_issue, submitted_vuln_type)
24
+
25
+ # Score formula
26
+ free_budget = (cummulative_cost / steps) * (self.n + 2)
27
+ return func_match * issue_match * (self._decay ** max(0, cummulative_cost - free_budget))
28
+
 
 
 
 
 
 
 
 
 
 
 
 
29
  def get_canonical_answer(self) -> Dict[str, str]:
30
  return {"function": self.target_function, "vulnerability": self.vulnerability_issue}
utils/semanticmatcher.py CHANGED
@@ -142,18 +142,6 @@ def cosine_similarity(vec_a: np.ndarray, vec_b: np.ndarray) -> float:
142
  return 0.0
143
  return float(np.dot(vec_a, vec_b) / (norm_a * norm_b))
144
 
145
-
146
- # ── Score clamping ───────────────────────────────────────────────────────────
147
-
148
- _SCORE_MIN = 0.001 # scores are strictly (0, 1) β€” never touch 0 or 1
149
- _SCORE_MAX = 0.999
150
-
151
-
152
- def _clamp(score: float) -> float:
153
- """Clamp score to the open interval (0, 1): [_SCORE_MIN, _SCORE_MAX]."""
154
- return max(_SCORE_MIN, min(_SCORE_MAX, score))
155
-
156
-
157
  # ── Core matcher ──────────────────────────────────────────────────────────────
158
 
159
  class SemanticMatcher:
@@ -212,7 +200,7 @@ class SemanticMatcher:
212
  # Fast-path: normalized exact match
213
  if normalize(text_a) == normalize(text_b):
214
  self.confidence_level = "strong"
215
- return _clamp(1.0) # β†’ 0.999 (strictly less than 1)
216
 
217
  tokens_a = tokenize_and_lemmatize(text_a)
218
  tokens_b = tokenize_and_lemmatize(text_b)
@@ -230,7 +218,7 @@ class SemanticMatcher:
230
  self.confidence_level = "moderate"
231
  else:
232
  self.confidence_level = "no_match"
233
- return _clamp(score) # strictly in (0, 1)
234
 
235
  def match(self, text_a: str, text_b: str) -> bool:
236
  """Return True if the two texts are considered a match based on the score."""
 
142
  return 0.0
143
  return float(np.dot(vec_a, vec_b) / (norm_a * norm_b))
144
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  # ── Core matcher ──────────────────────────────────────────────────────────────
146
 
147
  class SemanticMatcher:
 
200
  # Fast-path: normalized exact match
201
  if normalize(text_a) == normalize(text_b):
202
  self.confidence_level = "strong"
203
+ return 1.0
204
 
205
  tokens_a = tokenize_and_lemmatize(text_a)
206
  tokens_b = tokenize_and_lemmatize(text_b)
 
218
  self.confidence_level = "moderate"
219
  else:
220
  self.confidence_level = "no_match"
221
+ return score
222
 
223
  def match(self, text_a: str, text_b: str) -> bool:
224
  """Return True if the two texts are considered a match based on the score."""